forked from AI621/ai621
1
0
Fork 0

(real) initial commit

This commit is contained in:
Foxo 2023-07-04 20:19:03 +02:00
parent 56c6bde06d
commit 9e7a120561
9 changed files with 78809 additions and 0 deletions

1
.gitignore vendored
View File

@ -164,3 +164,4 @@ config.py
bot.session* bot.session*
e621/* e621/*
userdata/* userdata/*
ai621.db

BIN
ai621.example.db Normal file

Binary file not shown.

620
bot.py Normal file
View File

@ -0,0 +1,620 @@
import logging, asyncio
from telethon import events
from telethon.tl.custom import Button
from os.path import isfile
from time import time
import httpx
import re
import sqlite3
from os import unlink
from random import randint
from telethon.errors.rpcerrorlist import MessageNotModifiedError, UserIsBlockedError
from telethon.utils import get_display_name
from config import *
from judge_prompt import judge
def get_credits(user_id):
res = conn.execute("""SELECT count(*), prompt.id, daily_cycles, daily_cycles-coalesce(cast(sum((86400-(strftime('%s')-completed_at))/86400*(number*inference_steps/quality)) as int), 0) AS balance, coalesce(86400-(STRFTIME('%s')-max(completed_at)), 0) AS remaining_time, CAST(sum(number*inference_steps/quality)/24 AS integer) AS hourly_gain FROM user
LEFT JOIN prompt ON user_id = user.id AND (strftime('%s')-completed_at) < 86400
WHERE user.id = ?
GROUP BY user.id
""", (user_id,)).fetchone()
return res
async def edit_or_respond(ev, *args, **kwargs):
if hasattr(ev, 'message') and ev.message.input_sender.user_id != (await client.get_me(input_peer=True)).user_id:
await ev.respond(*args, **kwargs)
return
try:
await ev.edit(*args, **kwargs)
except Exception as e:
if isinstance(e, MessageNotModifiedError): return
await ev.respond(*args, **kwargs)
@client.on(events.NewMessage(incoming=True, func=lambda e: e.is_private))
async def maintenance(ev):
conn.execute('INSERT INTO user(id, name) VALUES (?,?) ON CONFLICT DO NOTHING', (ev.input_sender.user_id,get_display_name(ev.sender)))
# if ev.input_sender.user_id != client.admin_id:
# await ev.respond('The bot is currently closed while i implement some new functions. For news and updates, go to @ai621chat')
# raise events.StopPropagation
# New prompt
@client.on(events.NewMessage(pattern='^/new', incoming=True, func=lambda e: e.is_private))
async def new_prompt(ev):
if conn.execute('SELECT 1 FROM prompt WHERE is_done IS NULL AND user_id = ?', (ev.input_sender.user_id,)).fetchone():
await edit_or_respond(ev, 'You already have another prompt in the queue. Do you want to delete it?',
buttons=[[Button.inline(f"Check queue", f"queue")],[Button.inline(f"Delete", "delete_and_new")]])
return
if conn.execute('SELECT 1 FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,)).fetchone():
await edit_or_respond(ev, 'You already have another prompt pending. Do you want to edit or overwrite it?',
buttons=[[Button.inline(f"No, edit old", f"edit")],[Button.inline(f"Overwrite", "delete_and_new")]])
return
buttons = [Button.inline(f"🚫 No image", f"msg_but:no_image"),]
if conn.execute('SELECT * FROM prompt WHERE user_id = ? ORDER BY id DESC LIMIT 1', (ev.input_sender.user_id,)).fetchone():
buttons.append(Button.inline(f"📋 Copy your last prompt", f"copy_last"))
await edit_or_respond(ev, 'Creating a new prompt!\nSend me an image, an e621 link, or just press on "No image" if you want to only use a text prompt.\n\n⚠️ From now on, we will use BB95 instead of yiffy-e18! Pino daeni and other tags are not valid anymore. ⚠️', buttons=buttons)
conn.execute('INSERT INTO pending_prompt(user_id, seed) VALUES (?, ?)', (ev.input_sender.user_id,randint(0,1000000)))
conn.execute('UPDATE user SET pending_action = \'initial_image\' WHERE id = ?', (ev.input_sender.user_id,))
@client.on(events.callbackquery.CallbackQuery(pattern=r'^copy_last'))
async def copy_last_prompt(ev):
await ev.delete()
conn.execute('DELETE FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,))
conn.execute('INSERT INTO pending_prompt SELECT user_id, image, prompt, detail, negative_prompt, inference_steps, number, ?, blend, prompt_e6, image_e6, resolution, crop, hires FROM prompt WHERE user_id = ? ORDER BY id DESC LIMIT 1', (randint(0,1000000), ev.input_sender.user_id))
await edit_prompt(ev)
async def step_two(ev):
msg = await ev.respond("Please give me a prompt to use.\n\nThis is a phrase, or a list of e621 tags, or a link to an e621 post to use as a source.\nYou can increase weight on some tags by enclosing them with ((brackets)).\n\nIf you're starting off from an image, please make sure the tags accurately describe the picture for the best results.\n\nSome examples:\n" +
'- "chunie wolf ((athletic male)) solo (abs) pecs standing topless swimwear"\n' +
'- "a digital drawing by chunie of a fox smiling at the side of a swimming pool, wearing swimwear and with white fur and grey ears"\n' +
'- "https://e621.net/posts/3549862"', link_preview=False, parse_mode='HTML')
conn.execute('UPDATE user SET pending_action = \'initial_prompt\' WHERE id = ?', (ev.input_sender.user_id,))
@client.on(events.callbackquery.CallbackQuery(pattern=r'delete_and_new'))
async def delete_and_new(ev):
conn.execute('DELETE FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,))
conn.execute('DELETE FROM prompt WHERE user_id = ? AND is_done IS NULL AND started_at IS NULL', (ev.input_sender.user_id,))
await new_prompt(ev)
@client.on(events.callbackquery.CallbackQuery(pattern=r'^msg_but:'))
@client.on(events.NewMessage(incoming=True, pattern='^([^\/](\n|.)*)?$', func=lambda e: e.is_private))
async def edit_pending(ev):
log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: got_value')
res = conn.execute('SELECT pending_action FROM user WHERE id = ? AND pending_action IS NOT NULL', (ev.input_sender.user_id,)).fetchone()
pending_prompt = conn.execute('SELECT * FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,)).fetchone()
if not res:
log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: no pending edit')
await ev.respond('You gave me a value, but you\'re not editing any parameter. Maybe try /start again?')
return
field = res[0].replace('initial_', '')
is_initial = res[0].startswith('initial_')
if hasattr(ev, 'message'):
data = ev.message.raw_text.strip()
else:
data = ev.data.decode().split(':', 1)[1]
if field not in ['image', 'prompt', 'negative_prompt', 'seed', 'number', 'detail', 'inference_steps', 'blend', 'resolution', 'crop', 'hires']:
print(field)
await ev.respond('Something went wrong. Try /start again.')
return
if field == 'number':
try:
data = int(data)
assert data <= 4
assert data > 0
except:
await ev.respond('You can only generate between 1 and 4 images. Try again:')
raise events.StopPropagation
if data*pending_prompt['inference_steps'] > 200:
await ev.respond(f"You have choosen to generate {data} images with {pending_prompt['inference_steps']} cycles, resulting in a total {data*pending_prompt['inference_steps']} cycles. If you want to increase the number of generated images, reduce the cycles so that the total is below 200.")
return
if field == 'resolution':
if data not in ['768x512', '512x512', '512x768']: #, '1024x512', '512x1024']:
await ev.respond('This resolution is invalid.')
raise events.StopPropagation
if field == 'detail':
try:
data = float(data)
assert data >= 2
assert data <= 30
except:
await ev.respond('Level of detail must be between 2 and 20. We suggest using 7.5 for best results.')
raise events.StopPropagation
if field == 'crop':
data = data.lower()
if data not in ['yes','no']:
await ev.respond('Answer for this question must be yes or no.')
raise events.StopPropagation
if field == 'hires':
data = data.lower()
if data not in ['yes','no']:
await ev.respond('Answer for this question must be yes or no.')
raise events.StopPropagation
if field == 'blend':
try:
data = float(data)
assert data >= 0.3
assert data <= 0.9
except:
await ev.respond('Level of blend must be between 0.3 and 0.9.')
raise events.StopPropagation
if field == 'inference_steps':
try:
data = int(data)
assert data >= 20
assert data <= 200
assert data % 10 == 0
except:
await ev.respond('Generation time must be an integer between 20 and 200, a multiple of 10.')
raise events.StopPropagation
if data*pending_prompt['number'] > 200:
await ev.respond(f"You have choosen to generate {pending_prompt['number']} images with {data} cycles, a total {pending_prompt['number']*data} cycles. If you want to increase the cycles, reduce the number of generated images so that the total is below 200.")
return
if field == 'image':
conn.execute('UPDATE pending_prompt SET image_e6 = NULL WHERE user_id = ?', (ev.input_sender.user_id,))
if data == 'no_image':
pass
elif ev.message.photo:
await ev.client.download_media(ev.message.photo, f"userdata/{ev.input_sender.user_id}_{ev.message.id}.jpg")
data = f'userdata/{ev.input_sender.user_id}_{ev.message.id}.jpg'
await ev.respond('Alright. I will use this image as a reference when generating.')
elif re.search(r'e621\.net/posts?/([0-9]+)', data):
e6_id = int(re.search(r'e621\.net/posts?/([0-9]+)', data).group(1))
e6_post = e621.execute('SELECT * FROM post WHERE id = ?', (e6_id,)).fetchone()
if not e6_post:
await ev.respond('This post doesn\'t seem to exist in our database. This means it\'s too new (database is refreshed once a day) or it was deleted.')
return
if not (e6_post['file_id'].endswith('png') or e6_post['file_id'].endswith('jpg')):
await ev.respond('The e621 post you gave me doesn\'t look like an image. Are you sure it\'s not a gif or a video?')
return
if not isfile(f"e621/{e6_post['file_id']}"):
msg = await ev.respond('Trying to download the image from e621...')
try:
async with httpx.AsyncClient() as webcli:
with open(f"e621/{e6_post['file_id']}", 'wb') as f:
async with webcli.stream('GET', f"https://static1.e621.net/data/{e6_post['file_id'][0:2]}/{e6_post['file_id'][2:4]}/{e6_post['file_id']}") as response:
async for chunk in response.aiter_bytes():
f.write(chunk)
if response.status_code != 200:
print(f"https://static1.e621.net/data/{e6_post['file_id'][0:2]}/{e6_post['file_id'][2:4]}/{e6_post['file_id']}")
unlink(f"e621/{e6_post['file_id']}")
await ev.respond(f'I wasn\'t able to download this image. (Got status code {response.status_code}) Try again later')
return
except Exception as e:
await msg.edit(f"I encountered an issue downloading your image: {str(e)}. Please try with another image")
return
else:
await msg.edit('Downloaded!')
conn.execute('UPDATE pending_prompt SET image_e6 = ? WHERE user_id = ?', (e6_id, ev.input_sender.user_id))
data = f"e621/{e6_post['file_id']}"
else:
await ev.respond('You need to give me an e621 link or upload an image.', buttons=[Button.inline(f"🚫 No image", f"msg_but:no_image")])
return
if data != 'no_image':
img = Image.open(data)
ar = img.width/img.height
res = None
diff = 10
for r in [[768,512], [512,768], [512,512]]:
if abs(ar-(r[0]/r[1])) < diff:
res = f"{r[0]}x{r[1]}"
diff = abs(ar-(r[0]/r[1]))
if res:
conn.execute('UPDATE pending_prompt SET resolution = ? WHERE user_id = ?', (res, ev.input_sender.user_id))
await ev.respond(f'⚠️ Additionally, i\'ve automatically adjusted the resolution to {res} to match the aspect ratio.')
if field == 'prompt':
conn.execute('UPDATE pending_prompt SET prompt_e6 = NULL WHERE user_id = ?', (ev.input_sender.user_id,))
if re.search(r'e621\.net\/posts?/([0-9]+)', data):
e6_id = re.search(r'e621\.net\/posts?/([0-9]+)', data).group(1)
print('Getting data from e621 (prompt)...')
e6_prompt = e621.execute('SELECT tags FROM post WHERE id = ?', (int(e6_id),)).fetchone()
if not e6_prompt:
await ev.respond('This post does not seem to exist on e621, or perhaps, it\'t too new. You can only use posts from yesterday or older.')
return
data = e6_prompt['tags']
if len(data) > 2000:
tags = data.split(' ')
e6_prompt = ''
while len(e6_prompt) < 1000:
e6_prompt += ' '+tags.pop(randint(0, len(tags)-1))
data = e6_prompt
await ev.respond('⚠️ Since this post had too many tags, I had to delete some of them.')
conn.execute('UPDATE pending_prompt SET prompt_e6 = ? WHERE user_id = ?', (e6_id, ev.input_sender.user_id))
else:
tags = re.split(r'\W+', data)
if len(tags) < 5:
await ev.respond('At least five tags are needed for the prompt. Please try again:')
return
if field == 'seed':
if data == 'last_prompt':
data = conn.execute('SELECT seed FROM prompt WHERE user_id = ? ORDER BY id DESC LIMIT 1', (ev.input_sender.user_id,)).fetchone()
if data: data = data[0]
try:
data = int(data)%1000000
except:
await ev.respond('The seed needs to be an integer between 0 and 1000000. Try again:')
return
conn.execute(f"UPDATE user SET pending_action = NULL WHERE id = ?", (ev.input_sender.user_id,))
conn.execute(f"UPDATE pending_prompt SET {field.replace('initial_','')} = ? WHERE user_id = ?", (data, ev.input_sender.user_id))
log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: {field} -> {str(data)[:32]}')
if is_initial and field == 'image':
await step_two(ev)
else:
await edit_prompt(ev)
return
@client.on(events.callbackquery.CallbackQuery(pattern=r'^delete$'))
async def delete_prompt(ev):
conn.execute('DELETE FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,))
await edit_or_respond(ev, f'Prompt has been deleted. Use /new to create a new one.')
@client.on(events.NewMessage(pattern='^/queue', incoming=True))
async def queue_info(ev):
avg_comp_time = conn.execute('SELECT avg(aa) FROM (SELECT abs(queued_at-started_at) AS aa FROM prompt WHERE queued_at IS NOT NULL AND completed_at IS NOT NULL ORDER BY id DESC LIMIT 10)').fetchone()
avg_wait = conn.execute("SELECT avg(strftime('%s', 'now')-queued_at) AS aa FROM prompt WHERE queued_at IS NOT NULL AND is_done IS NULL").fetchone()
min_max_wait = conn.execute("SELECT max(strftime('%s', 'now')-queued_at), min(strftime('%s', 'now')-queued_at) FROM prompt WHERE queued_at IS NOT NULL AND is_done IS NULL").fetchone()
waiting_usrs = conn.execute("SELECT count(*) FROM prompt WHERE is_done IS NULL").fetchone()
avg_speed = conn.execute("SELECT avg(aa) FROM (SELECT (inference_steps*number)/abs(completed_at-started_at) AS aa FROM prompt WHERE started_at IS NOT NULL AND completed_at IS NOT NULL ORDER BY id DESC LIMIT 10)").fetchone()
await ev.respond("\n".join([
f"👯‍♂️ {waiting_usrs[0]} people are waiting in the queue",
f"{'🔴' if int(avg_comp_time[0]/60) > 20 else ('🟡' if int(avg_comp_time[0]/60) > 3 else '🟢')} Average queue-to-result time: <strong>{int(avg_comp_time[0]/60)} minutes</strong>.",
f"{'🔴' if int(min_max_wait[0]/60) > 20 else ('🟡' if int(min_max_wait[0]/60) > 5 else '🟢')} Current queue duration: {int(min_max_wait[1]/60)}~{int(min_max_wait[0]/60)} minutes",
f"🏃 Average speed: {avg_speed[0]:.2f} steps/sec"
]), parse_mode='HTML')
@client.on(events.NewMessage(pattern='^/queuelist', incoming=True))
async def queue_list(ev):
ret = 'Current queue:\n\n'
queue = conn.execute("""SELECT prompt.id, inference_steps*number AS cycles, strftime('%s', 'now')-prompt.queued_at AS wait, prompt.started_at, user.name FROM prompt
JOIN user ON user.id = prompt.user_id
WHERE is_done IS NULL
ORDER BY prompt.id ASC""")
for item in queue:
ret += f"<strong>#{item['id']}</strong> · {int(item['wait']/60)}:{int(item['wait']%60):0<2}{' ⚙️' if item['started_at'] else ''} {item['name'][:16]}{'...' if len(item['name']) > 16 else ''}\n"
await ev.respond(ret, parse_mode='HTML')
@client.on(events.NewMessage(pattern='^/broke', incoming=True))
async def queue_list(ev):
ret = 'Current queue:\n\n'
queue = conn.execute("""SELECT user.daily_cycles, prompt.started_at, user.name FROM prompt
JOIN user ON user.id = prompt.user_id
WHERE is_done IS NULL
ORDER BY prompt.id ASC""")
for item in queue:
ret += f"<strong>#{item['id']}</strong> · {int(item['wait']/60)}:{int(item['wait']%60):0<2}{' ⚙️' if item['started_at'] else ''} {item['name'][:16]}{'...' if len(item['name']) > 16 else ''}\n"
await ev.respond(ret, parse_mode='HTML')
@client.on(events.NewMessage(pattern='^/addworker', incoming=True))
async def add_worker(ev):
if ev.input_sender.user_id != client.admin_id: return
command, api_url, name = ev.message.raw_text.split(' ')
workers.append(Worker(api_url, client, name))
@client.on(events.NewMessage(pattern='^/stop', incoming=True))
async def stop_queue(ev):
if ev.input_sender.user_id != client.admin_id: return
client.process_queue = False
await ev.respond("Stopping queue processing.")
@client.on(events.NewMessage(pattern='^/process', incoming=True))
async def start_queue(ev):
if ev.input_sender.user_id == client.admin_id:
await ev.respond(f'Starting queue processing')
client.process_queue = True
if not client.process_queue:
await ev.respond('Please hold on a bit more time as the bot is going under maintenance.')
return
for w in workers:
w.start()
@client.on(events.NewMessage(pattern='^/queueraw', incoming=True))
async def queue_list(ev):
if ev.input_sender.user_id != client.admin_id: return
await ev.respond(str(client.queue._queue))
@client.on(events.callbackquery.CallbackQuery(pattern=r'^queue$'))
async def queue(ev, new_message=False):
prompt = conn.execute('SELECT * FROM prompt WHERE is_done IS NULL AND user_id = ?', (ev.input_sender.user_id,)).fetchone()
if not prompt:
await edit_or_respond(ev, 'You don\'t have pending prompts :)\n/new for a new one')
raise events.StopPropagation
avg_comp_time = conn.execute('SELECT avg(aa) FROM (SELECT abs(queued_at-started_at) AS aa FROM prompt WHERE queued_at IS NOT NULL AND completed_at IS NOT NULL ORDER BY id DESC LIMIT 10)').fetchone()
behind_you = conn.execute('SELECT count(*) FROM prompt WHERE is_done IS NULL AND id > ?', (prompt['id'],)).fetchone()[0]
front_you = conn.execute('SELECT count(*) FROM prompt WHERE is_done IS NULL AND id < ?', (prompt['id'],)).fetchone()[0]
if new_message:
await client.send_message(ev.sender, f"Your position in the queue:\n{behind_you} behind you, {front_you} in front of you\n{'🫥'*behind_you}😃{'🫥'*front_you}\n\nCurrent average wait time: {int(avg_comp_time[0]/60)} minutes", buttons=[[Button.inline(f"👯 Refresh queue", f"queue")]])
else:
await edit_or_respond(ev, f"Your position in the queue:\n{behind_you} behind you, {front_you} in front of you\n{'🫥'*behind_you}😃{'🫥'*front_you}\n\nCurrent average wait time: {int(avg_comp_time[0]/60)} minutes", buttons=[[Button.inline(f"👯 Refresh queue", f"queue")]])
@client.on(events.callbackquery.CallbackQuery(pattern=r'^confirm$'))
async def confirm_prompt(ev):
# Do not allow banned users to confirm prompts
user = conn.execute('SELECT * FROM user WHERE id = ?', (ev.input_sender.user_id,)).fetchone()
if user['is_banned'] == 1:
await ev.respond('You have been banned from sending further prompts due to abuse.')
return
# Check if the prompt exists
prompt = conn.execute('SELECT * FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,)).fetchone()
if not prompt or not prompt['prompt']:
await ev.respond('Looks like you have nothing to confirm. Try /new for a new prompt.')
return
# Analyze the prompt
comments, quality = judge(prompt['prompt'])
# Check if the user has used all of his cycles
usage = get_credits(ev.input_sender.user_id)
if usage['balance'] < (prompt['inference_steps']*prompt['number'])/quality:
await ev.respond(f"Sorry. You only have {usage['balance']} cycles left out of the {user['daily_cycles']} 🌀 you can use every day. You can use /cycles to have more info on how many cycles you can use.")
return
log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: confirm prompt')
conn.execute('INSERT INTO prompt SELECT NULL, NULL, NULL, *, ?, NULL, NULL, ? FROM pending_prompt WHERE user_id = ?', (time(),quality,ev.input_sender.user_id)).fetchone()
conn.execute('DELETE FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,))
prompt = conn.execute('SELECT * FROM prompt WHERE user_id = ? AND is_done IS NULL ORDER BY id DESC LIMIT 1', (ev.input_sender.user_id,)).fetchone()
await edit_or_respond(ev,
"\n".join([f"✅ Your prompt {prompt['id']} has been scheduled and will be generated soon!\n",
f"<strong>You spent {prompt['number']*prompt['inference_steps']} cycle/bs. You have {int(usage['balance']-((prompt['inference_steps']*prompt['number'])/quality))} left for today.</strong>"]), parse_mode='HTML')
await queue(ev, new_message=True)
await client.send_message(client.log_channel_id, f"New prompt #{prompt['id']} by {ev.input_sender.user_id} {get_display_name(ev.sender)}\n\n{prompt['prompt']}",
buttons=[[Button.inline(f"Delete prompt", f"delete_mod:{prompt['id']}"),
Button.inline(f"Ban user", f"ban_mod:{prompt['user_id']}")]])
task = (prompt['id'], prompt['id'],)
if task in client.queue._queue:
self.log.error(f"Tried to insert a duplicate task while confirming!!!")
print(task, client.queue._queue)
return
client.queue.put_nowait(task)
await start_queue(ev)
raise events.StopPropagation
desc = {
'seed': 'The seed must be a whole number between 0 and 1000000. It defines the random noise which will be used to begin generation - two images with the same seed will be identical. Write it or get the one from the previous generation.',
'image': 'Upload an image, give me an e621 link or just delete the current one.',
'prompt': 'A list of e621 tags, a phrase or an e621 link to use as a prompt for your image',
'negative_prompt': 'A list of e621 tags you don\'t want to see. Do not put - in front of the tags',
'number': 'Number of images to generate. Write a value or press the buttons.',
'detail': 'The detail (aka "guidance scale") can be set between 2 and 50. Lower values will create softer images, while higher values will create images with a lot of contrast and perhaps distortion. Write a value or press the buttons.',
'inference_steps': 'Define the amount of time, in "cycles", to spend generating the images.\nHigher time usually means higher quality but only if the tags are good enough. Try around ~40/image.',
'blend': '0.0 to 1.0, the amount of transformation to apply on the base image. The higher the value, the more the result image will be different than the source.',
'resolution': 'The width and height of the final image.',
'crop': 'Do you want to crop the base image so that it matches the generated image resolution?',
'hires': 'Do you want to receive a high res image at the end of the generation? (it will take more time)',
}
desc_buttons = {
'seed': [Button.inline(f"Get from last prompt", f"msg_but:last_prompt"),],
'number': [Button.inline(f"1", f"msg_but:1"), Button.inline(f"2", f"msg_but:2"), Button.inline(f"3", f"msg_but:3"), Button.inline(f"4", f"msg_but:4")],
'detail': [Button.inline(f"S (6.0)", f"msg_but:6"), Button.inline(f"M (10)", f"msg_but:10"), Button.inline(f"L (18)", f"msg_but:18")],
'inference_steps': [Button.inline(f"S (20)", f"msg_but:20"), Button.inline(f"M (40)", f"msg_but:40"), Button.inline(f"L (60)", f"msg_but:60")],
'blend': [Button.inline(f"S (0.3)", f"msg_but:0.3"), Button.inline(f"M (0.6)", f"msg_but:0.6"), Button.inline(f"L (0.8)", f"msg_but:0.8")],
'resolution': [[Button.inline(f"Potrait (512x768)", f"msg_but:512x768"), Button.inline(f"Landscape (768x512)", f"msg_but:768x512")],[Button.inline(f"Square (512x512)", f"msg_but:512x512")], #[Button.inline(f"Ultrawide (1024x512)", f"msg_but:1024x512"), Button.inline(f"Ultratall (512x1024)", f"msg_but:512x1024")]
],
'image': [Button.inline(f"🚫 No image", f"msg_but:no_image")],
'crop': [Button.inline(f"🚫 Don't touch it", f"msg_but:no"),Button.inline(f"✂️ Crop it", f"msg_but:yes")],
'hires': [Button.inline(f"Normal image", f"msg_but:no"),Button.inline(f"High resolution", f"msg_but:yes")],
}
@client.on(events.NewMessage(pattern='^/cycles', incoming=True))
async def cycles_notice(ev):
user = conn.execute("SELECT * FROM user WHERE id = ?", (ev.input_sender.user_id,)).fetchone()
usage = get_credits(ev.input_sender.user_id)
await ev.respond(f"Hello {user['name']}. You have {usage['balance']}/{user['daily_cycles']} cycles left." + (f"\nYou are currently earning {usage['hourly_gain']} cycles/hour. Full amount in {int(usage['remaining_time']/3600)}h{int((usage['remaining_time']/60)%60):0>2}m{int(usage['remaining_time']%60):0>2}s." if usage['remaining_time'] else ''))
@client.on(events.NewMessage(pattern='^/ban', incoming=True))
async def ban_user(ev):
if ev.input_sender.user_id != client.admin_id: return
conn.execute('UPDATE user SET is_banned = 1 WHERE id = ?', (int(ev.message.raw_text.split(' ')[1]),))
await ev.respond('User banned.')
@client.on(events.NewMessage(pattern='^/unban', incoming=True))
async def unban_user(ev):
if ev.input_sender.user_id != client.admin_id: return
conn.execute('UPDATE user SET is_banned = NULL WHERE id = ?', (int(ev.message.raw_text.split(' ')[1]),))
await ev.respond('User unbanned.')
@client.on(events.NewMessage(pattern='^/setcredits', incoming=True))
async def unban_user(ev):
if ev.input_sender.user_id != client.admin_id: return
await ev.respond(conn.execute('UPDATE user SET daily_cycles = ? WHERE id = ?', (int(ev.message.raw_text.split(' ')[1]),int(ev.message.raw_text.split(' ')[2]))))
await client.send_message(int(ev.message.raw_text.split(' ')[1]), f"Congrats! You can now use {ev.message.raw_text.split(' ')[2]} cycles/day. Have fun!")
@client.on(events.callbackquery.CallbackQuery(pattern=r'^delete_mod:'))
async def del_mod(ev):
if ev.input_sender.user_id != client.admin_id: return
prompt = conn.execute('SELECT * FROM prompt WHERE id = ?', (int(ev.data.decode().split(':')[1]),)).fetchone()
if prompt:
await client.send_message(prompt['user_id'], 'Hi. Your prompt has been deleted by a moderator.')
conn.execute('UPDATE prompt SET is_done = 1, is_error = 1 WHERE id = ?', (prompt['id'],))
await ev.answer('Prompt deleted.')
@client.on(events.callbackquery.CallbackQuery(pattern=r'^ban_mod:'))
async def del_mod(ev):
if ev.input_sender.user_id != client.admin_id: return
conn.execute('UPDATE user SET is_banned = 1 WHERE id = ?', (int(ev.data.decode().split(':')[1]),))
conn.execute('UPDATE prompt SET is_done = 1, is_error = 1 WHERE user_id = ?', (int(ev.data.decode().split(':')[1]),))
await ev.answer('User banned.')
@client.on(events.callbackquery.CallbackQuery(pattern=r'^change:'))
async def setting(ev):
log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: {ev.data.decode()}')
field = ev.data.decode().split(':')[1]
pending_prompt = conn.execute('SELECT 1 FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,)).fetchone()
if not pending_prompt:
await ev.edit('You cannot edit a prompt that doesn\'t exist anymore.')
return
conn.execute('UPDATE user SET pending_action = ? WHERE id = ?', (field, ev.input_sender.user_id))
await ev.respond(f"Please tell me the value for <strong>{field}</strong>.{chr(10)+chr(10)+desc[field] if field in desc else ''}", parse_mode='HTML', buttons=desc_buttons.get(field, None))
@client.on(events.callbackquery.CallbackQuery(pattern=r'^edit$'))
async def edit_prompt(ev):
log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: edit_prompt')
user = conn.execute('SELECT * FROM user WHERE id = ?', (ev.input_sender.user_id,))
user = user.fetchone()
res = conn.execute('SELECT * FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,))
p = res.fetchone()
if not p:
await ev.respond('Sorry, it looks like you have no pending prompt.\n/new for a new one')
return
await ev.respond(f"""
<b>Your prompt</b>
🖼 Starting image: {'None' if p['image'] == 'no_image' else ('https://e621.net/posts/'+str(p['image_e6']) if p['image_e6'] else 'User uploaded image')}
🌱 Seed: {p['seed']}
👀 Prompt: <pre>{p['prompt']}</pre>
Negative prompt: <pre>{p['negative_prompt'] or 'no negative prompt set.'}</pre>
<b>Do not touch these if you don't know what you're doing</b>
🔢 Number of images: {p['number']}
🖥 Resolution: {p['resolution']}
High resolution: {p['hires']} (<i>slower generation!</i>
💎 Detail: {p['detail']} (<i>Guidance scale</i>)
🌀 Generation cycles: {p['inference_steps']} cycles ({p['inference_steps']*p['number']} total) (<i>Inference steps</i>)""" +
(f"\n🎚 Blend amount: {p['blend']} (<i>Prompt strength</i>)\n✂️ Crop: {'Yes' if p['crop'] == '1' else 'No'}" if p['image'] != 'no_image' else ''),
parse_mode='HTML',
link_preview=False,
buttons = [
[
Button.inline(f"👀 Prompt", f"change:prompt"),
Button.inline(f"⛔ Negative prompt", f"change:negative_prompt"),
],
[
Button.inline(f"🌱 Seed", f"change:seed"),
Button.inline(f"🔢 Number", f"change:number"),
Button.inline(f"✨ High resolution", f"change:hires")
],
[
Button.inline(f"💎 Detail", f"change:detail"),
Button.inline(f"🌀 Gen cycles", f"change:inference_steps"),
],
[
Button.inline(f"🖼 Base image", f"change:image"),
Button.inline(f"🖥 Resolution", f"change:resolution"),
*([Button.inline(f"🎚 Blend", f"change:blend"),Button.inline(f"✂️ Crop", f"change:crop")] if p['image'] != 'no_image' else [])
],
[
Button.inline(f"✅ Confirm", f"confirm"),
Button.inline(f"🕵️ Analyze", f"analyze"),
Button.inline(f"❌ Delete", f"delete")
]
]
)
@client.on(events.callbackquery.CallbackQuery(pattern=r'^analyze$'))
async def analyze_prompt(ev):
log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: analyze')
res = conn.execute('SELECT * FROM pending_prompt WHERE user_id = ? LIMIT 1', (ev.input_sender.user_id,)).fetchone()
if res:
comments, quality = judge(res['prompt'])
await ev.respond("\n".join(comments), parse_mode='HTML')
else:
await ev.respond('What am i supposed to analyze?')
@client.on(events.NewMessage(pattern='^/start', incoming=True, func=lambda e: e.is_private))
async def welcome(ev):
log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: hello')
await ev.respond(f'Hello, and welcome to ai621. This bot can be used to generate yiff. Before beginning, keep in mind that:\n\n1. Images are public, no abusive stuff\n2. Using the bot maliciously or with multiple alts will lead to a ban\n3. The images generated by the bot are of public domain.\nGenerate a new prompt with /new\n\nDiscussion: @ai621chat\nWebsite: https://ai621.foxo.me/')
if __name__ == '__main__':
qqq = conn.execute("""SELECT used, user.daily_cycles, prompt.id, inference_steps*number AS cycles, strftime('%s', 'now')-prompt.queued_at AS wait, prompt.started_at, user.name FROM prompt
JOIN user ON user.id = prompt.user_id
JOIN (SELECT bb.user_id, sum(bb.number*bb.inference_steps) AS used FROM prompt AS bb WHERE bb.queued_at > strftime('%s', 'now')-86400 GROUP BY bb.user_id) aa ON aa.user_id = user.id
WHERE is_done IS NULL
ORDER BY started_at ASC NULLS LAST, prompt.id ASC""").fetchall()
for task in qqq:
if task in client.queue._queue:
self.log.error(f"Tried to insert a duplicate task while Resuming!!!")
print(task, client.queue._queue)
continue
client.queue.put_nowait((task['id'], task['id'],))
client.start()
client.flood_sleep_threshold = 24*60*60
client.run_until_disconnected()

42
config.example.py Normal file
View File

@ -0,0 +1,42 @@
from asyncio import PriorityQueue
from telethon import TelegramClient
import logging
import sqlite3
import coloredlogs
from process_queue import *
coloredlogs.install(level='INFO')
api_id = YOUR TG API ID HERE
api_hash = YOUR TG API HASH HERE
temp_folder = TEMP FOLDER OF THE GENERATIONS
client = TelegramClient('bot', api_id, api_hash)
e621 = sqlite3.connect('e621.db')
e621.row_factory = sqlite3.Row
conn = sqlite3.connect('ai621.db', isolation_level=None)
conn.row_factory = sqlite3.Row
client.conn = conn
client.process = None
client.log_channel_id = ID OF LOG CHANNEL
client.main_channel_id = ID OF RAW CHANNEL
client.queue = PriorityQueue()
client.media_lock = asyncio.Lock()
client.process_queue = False
client.admin_id = USER ID OF THE ADMIN
workers = [
Worker('http://127.0.0.1:9000', client, 'armorlink'),
#Worker('http://127.0.0.1:9001', client, 'armorlink-low'),
#Worker('http://local.proxy:9000', client, 'g14')
]
log = logging.getLogger('bot')

26
e621_import.py Normal file
View File

@ -0,0 +1,26 @@
import csv
import sys
from glob import glob
import sqlite3
conn = sqlite3.connect('e621.db')
conn.execute('DELETE FROM post')
csv.field_size_limit(sys.maxsize)
with open(glob('posts-*')[0]) as csvfile:
posts = csv.reader(csvfile, delimiter=',', quotechar='"')
i = 0
for p in posts:
i += 1
if i == 1:
continue
if i%10000 == 0:
print(p[0], end='\r')
conn.commit()
conn.execute('INSERT INTO post(id, file_id, tags) VALUES (?,?,?)', (int(p[0]), p[3]+'.'+p[11], p[8]))
conn.commit()

65
judge_prompt.py Normal file
View File

@ -0,0 +1,65 @@
import re
good_chars = 'abcdefghijklmnopqrstuvwxyz0123456789_- ,.()'
word_chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
stopwords = [ 'stop', 'the', 'to', 'and', 'a', 'in', 'it', 'is', 'i', 'that', 'had', 'on', 'for', 'were', 'was']
regexes = [
(r"\w_\w", 'It looks like you are using the character "_" to separate words in tags. Use spaces instead.', 0.9),
(r"\W-\w", 'It looks like you are using the "-" character to exclude tags. Put them into the "negative prompt" instead.', 0.9),
(r"^.{,60}$", 'Your prompt is very short. You will probably get a bad image.', 0.8),
(r"\b(stalin|hitler|nazi|nigger|waluigi|luigi|toilet)\b", 'Seriously?', 0.01),
(r"\b(loli|shota|cub|young|age difference|preteen|underage|child|teen)\b", 'Trying to generate cub art will probably make you banned. I warned you.', 0.001),
(r"\b(scat|poop|shit|piss|urine|pooping|rape)\b", 'Some of the tags you sent will be ignored because they were also blacklisted on e621.', 0.75),
(r"\({3,}", 'No need to use that many (((parenthesis))). That will give you a worse image.', 0.9),
(r"\[{3,}", 'Square [braces] will reduce the enphasis on a tag.', 1.0),
(r"\W#", 'There is no need to put # in front of tags. That\'ll worsen the quality of the image', 0.9),
(r"[👎👍🌀🌱💎]", "If you copy prompts from the channel, at least copy <strong>only</strong> the prompt.", 0.001)
]
tags = {}
with open('yiffy_tags.csv') as f:
while 1:
line = f.readline()
if not line: break
tag, value = line.strip().rsplit(',', 1)
if value == 'count': continue
tags[tag] = int(value)
sorted_by_size = sorted(tags.keys(), key=lambda x: len(x), reverse=True)
def judge(prompt):
prompt = ' '+prompt.lower().replace("\n", ' ')+' '
quality = 1.0
comments = []
found_tags = {}
for tag in sorted_by_size:
pos = prompt.find(tag)
if pos == -1: continue
if prompt[pos-1] in word_chars: continue
if prompt[pos+len(tag)] in word_chars: continue
found_tags[tag] = tags[tag]
if len(found_tags) == 0:
quality *= 0.65
comments.append(f"Your prompt doesn't even contain one tag from e621.")
elif len(found_tags) < 6:
quality *= 0.8
comments.append(f"Found only {len(found_tags)} tags in your prompt. The AI knows ~{int(sum(found_tags.values())/len(found_tags))} images from e621 with these tags.")
else:
comments.append(f"Found {len(found_tags)} tags in your prompt. The AI knows ~{int(sum(found_tags.values())/len(found_tags))} images from e621 with these tags.")
for pattern, comment, value in regexes:
match = re.search(pattern, prompt)
if match:
quality *= value
comments.append(comment)
if quality < 1:
comments.append(f"Because of these issues, you will consume {(1/quality):.2f}x the amount of usual cycles.")
return comments, quality

209
process_queue.py Normal file
View File

@ -0,0 +1,209 @@
import asyncio
from PIL import Image
from PIL import ImageOps
from base64 import b64encode, b64decode
from io import BytesIO
import logging
import json
import httpx
from os import listdir
from os.path import join
from time import time
from telethon.errors.rpcerrorlist import MessageNotModifiedError, UserIsBlockedError
log = logging.getLogger('process')
temp_folder = '/home/ed/temp/'
default_vars = {
"use_cpu":False,
"use_full_precision": False,
"stream_progress_updates": True,
"stream_image_progress": False,
"show_only_filtered_image": True,
"sampler_name": "dpm_solver_stability",
"save_to_disk_path": temp_folder,
"output_format": "png",
"use_stable_diffusion_model": "fluffyrock-576-704-832-960-1088-lion-low-lr-e61-terminal-snr-e34",
"metadata_output_format": "embed",
"use_hypernetwork_model": "boring_e621",
"hypernetwork_strength": 0.25,
}
class Worker:
def __init__(self, api, client, name):
self.api = api
self.ready = False
self.client = client
self.queue = client.queue
self.conn = client.conn
self.loop = asyncio.get_event_loop()
self.task = None
self.name = name
self.log = logging.getLogger(name)
self.prompt = None
def start(self, future=None):
if not self.client.process_queue:
asyncio.create_task(self.client.send_message(self.client.admin_id, f"Loop of {self.name} has been stopped."))
return
if self.task and not self.task.done(): return
if future and future.exception():
self.log.error(future.exception())
self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?',(time(), self.prompt))
return
try:
priority, prompt_id = self.queue.get_nowait()
except asyncio.QueueEmpty:
self.log.info('No more tasks to process!')
else:
prompt = self.client.conn.execute('SELECT prompt.* FROM prompt WHERE id = ?', (prompt_id,)).fetchone()
self.log.info(f"Processing {prompt_id}")
self.task = self.loop.create_task(self.process_prompt(prompt, (priority, prompt_id)))
self.task.add_done_callback(self.start)
async def process_prompt(self, prompt, queue_item):
# First of all, check if the user can still be messaged.
async with httpx.AsyncClient() as httpclient:
try:
await httpclient.get(join(self.api, 'ping'), timeout=5)
except Exception as e:
print(str(e))
log.error('Server is dead. Waiting 10 seconds for server availability...')
await self.queue.put(queue_item)
await asyncio.sleep(10)
return
try:
msg = await self.client.send_message(prompt['user_id'], f"Hello 👀 Generating your prompt now.")
except UserIsBlockedError:
self.conn.execute('UPDATE prompt SET is_done = 1, is_error = 1 WHERE id = ?',(prompt['id'],))
return
# Prepare the parameters for the request
params = default_vars.copy()
params['session_id'] = str(prompt['id'])
params['prompt'] = prompt['prompt'] or ''
params['negative_prompt'] = prompt['negative_prompt'] or 'boring_e621_v4'
params['num_outputs'] = int(prompt['number'])
params['num_inference_steps'] = prompt['inference_steps']
params['guidance_scale'] = prompt['detail']
params['width'] = prompt['resolution'].split('x')[0]
params['height'] = prompt['resolution'].split('x')[1]
params['seed'] = str(prompt['seed'])
params['vram_usage_level'] = 'low' if '-low' in self.name else ('medium' if '1024' in prompt['resolution'] else 'high')
self.prompt = prompt['id']
if prompt['hires'] == 'yes':
params['use_upscale'] = 'RealESRGAN_x4plus_anime_6B'
if prompt['image'] != 'no_image':
img = Image.open(prompt['image'])
img = img.convert('RGB')
if prompt['crop'] == 'no':
img = img.resize(list((int(x) for x in prompt['resolution'].split('x'))))
else:
img = ImageOps.fit(img, list((int(x) for x in prompt['resolution'].split('x'))))
imgdata = BytesIO()
img.save(imgdata, format='JPEG')
params['init_image'] = ('data:image/jpeg;base64,'+b64encode(imgdata.getvalue()).decode()).strip()
params['sampler_name'] = 'ddim'
params['prompt_strength'] = prompt['blend']
async with httpx.AsyncClient() as httpclient:
self.conn.execute('UPDATE prompt SET started_at = ? WHERE id = ?', (time(), prompt['id']))
start = time()
failed = False
self.log.info('POST to server')
res = await httpclient.post(join(self.api, 'render'), data=json.dumps(params))
res = res.json()
last_edit = 0
while 1:
step = await httpclient.get(join(self.api, res['stream'][1:]))
try:
data = step.json()
except:
continue
if 'step' in data:
if int(data['step'])%10 == 0:
self.log.info(f"Generation progress of {prompt['id']}: {data['step']}/{data['total_steps']}")
if time() - last_edit > 10:
await msg.edit(f"Generating prompt #{prompt['id']}, step {data['step']} of {data['total_steps']}. {time()-start:.1f}s elapsed.")
last_edit = time()
elif 'status' in data and data['status'] == 'failed':
await self.client.send_message(184151234, f"While generating #{prompt['id']}: {data['detail']}...")
await self.client.send_message(prompt['user_id'], f"While trying to generate your prompt we encountered an error: {data['detail']}\n\nThis might mean a bad combination of parameters, or issues on our sifde. We will retry a couple of times just in case.")
failed = True
self.client.conn.execute('UPDATE prompt SET is_error = 1, is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
break
elif 'status' in data and data['status'] == 'succeeded':
self.log.info('Success!')
images = []
for img in data['output']:
imgdata = BytesIO(b64decode(img['data'].split('base64,',1)[1]))
#imgdata.name = img['path_abs'].rsplit('/', 1)[-1]
imgdata.name = 'image.png'
imgdata.seek(0)
images.append(imgdata)
break
else:
print(data)
await asyncio.sleep(2)
self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
await msg.delete()
if not failed:
asyncio.create_task(self.send_submission(prompt, images))
async def send_submission(self, prompt, images):
tg_files = []
for fn in images:
img = Image.open(fn)
img.thumbnail((1280,1280))
imgdata = BytesIO()
img.save(imgdata, format='JPEG')
siz = imgdata.seek(0,2)
imgdata.seek(0)
tg_files.append(await self.client.upload_file(imgdata, file_size=siz, file_name=fn.name))
results = await self.client.send_file(self.client.main_channel_id, tg_files, caption=
"\n".join([f"#{prompt['id']} · 🌀 {prompt['inference_steps']} · 🌱 {prompt['seed']} · 💎 {prompt['detail']}" + (f" · 🎚 {prompt['blend']} (seed from image)" if prompt['image'] != 'no_image' else ''),
#((f"🖼 https://e621.net/posts/{prompt['image_e6']}" if prompt['image_e6'] else 'user-uploaded image') if prompt['image'] != 'no_image' else ''),
('👍 ' if prompt['negative_prompt'] else '')+(f"Prompt from https://e621.net/posts/{prompt['prompt_e6']}" if prompt['prompt_e6'] else prompt['prompt']),
(f"\n👎 {prompt['negative_prompt']}" if prompt['negative_prompt'] else '')])[:1000], parse_mode='HTML')
await self.client.forward_messages(prompt['user_id'], results)
if prompt['hires'] == 'yes':
await self.client.send_message(prompt['user_id'], 'Uploading raw images... This will take a while')
tg_files = []
for img in images:
siz = img.seek(0,2)
img.seek(0)
tg_files.append(await self.client.upload_file(img, file_size=siz, file_name=img.name))
results = await self.client.send_file(self.client.main_channel_id, tg_files, force_document=True, caption=f"Raw images of #{prompt['id']}")
await self.client.forward_messages(prompt['user_id'], results)
self.log.info(f"Files for prompt #{prompt['id']} have been sent succesfully.")

1
yiffy-e18.json Normal file

File diff suppressed because one or more lines are too long

77845
yiffy_tags.csv Normal file

File diff suppressed because it is too large Load Diff