ai621/bot.py

621 lines
31 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging, asyncio
from telethon import events
from telethon.tl.custom import Button
from os.path import isfile
from time import time
import httpx
import re
import sqlite3
from os import unlink
from random import randint
from telethon.errors.rpcerrorlist import MessageNotModifiedError, UserIsBlockedError
from telethon.utils import get_display_name
from config import *
from judge_prompt import judge
def get_credits(user_id):
    """Return the credit summary row for *user_id*, or None if the user is unknown.

    Only prompts completed during the last 24 hours (86400 s) are counted. The returned
    row exposes:

    * ``balance`` - ``daily_cycles`` minus the cost of those prompts. Each prompt's cost
      (``number*inference_steps/quality``) is weighted by the fraction of the 24h window
      it still has left, so spent cycles come back gradually.
    * ``remaining_time`` - seconds until the most recent completed prompt leaves the 24h
      window (0 when there is none).
    * ``hourly_gain`` - the summed cost of those prompts divided by 24.
    """
    # The weight divides by 86400.0 (not 86400) so SQLite uses real division even when
    # completed_at holds integers. With integer division every weight would be 0 and
    # recent prompts would cost nothing.
    query = """SELECT count(*), prompt.id, daily_cycles, daily_cycles-coalesce(cast(sum((86400-(strftime('%s')-completed_at))/86400.0*(number*inference_steps/quality)) as int), 0) AS balance, coalesce(86400-(STRFTIME('%s')-max(completed_at)), 0) AS remaining_time, CAST(sum(number*inference_steps/quality)/24 AS integer) AS hourly_gain FROM user
LEFT JOIN prompt ON user_id = user.id AND (strftime('%s')-completed_at) < 86400
WHERE user.id = ?
GROUP BY user.id
"""
    return conn.execute(query, (user_id,)).fetchone()
async def edit_or_respond(ev, *args, **kwargs):
    """Show a reply by editing the current message when possible, otherwise by sending a new one.

    If the event carries a ``message`` that was not sent by this bot (for example a user's
    command), a new message is always sent. Otherwise ``ev.edit`` is tried first. A
    "message not modified" error is ignored; any other failure falls back to
    ``ev.respond``.
    """
    if hasattr(ev, 'message'):
        sender_id = ev.message.input_sender.user_id
        me = await client.get_me(input_peer=True)
        if sender_id != me.user_id:
            await ev.respond(*args, **kwargs)
            return
    try:
        await ev.edit(*args, **kwargs)
    except MessageNotModifiedError:
        # Same content as before: nothing to update.
        return
    except Exception:
        await ev.respond(*args, **kwargs)
@client.on(events.NewMessage(incoming=True, func=lambda e: e.is_private))
async def maintenance(ev):
    """Register the sender of every incoming private message in the ``user`` table.

    Uses ``ON CONFLICT DO NOTHING``. Presumably the conflict is on an existing ``id``, in
    which case the stored row (including the name) is left unchanged.
    """
    sender_id = ev.input_sender.user_id
    display_name = get_display_name(ev.sender)
    conn.execute('INSERT INTO user(id, name) VALUES (?,?) ON CONFLICT DO NOTHING', (sender_id, display_name))
    # Maintenance gate, currently disabled. Uncomment to turn away everyone but the admin:
    # if ev.input_sender.user_id != client.admin_id:
    #     await ev.respond('The bot is currently closed while i implement some new functions. For news and updates, go to @ai621chat')
    #     raise events.StopPropagation
# New prompt
@client.on(events.NewMessage(pattern='^/new', incoming=True, func=lambda e: e.is_private))
async def new_prompt(ev):
    """Start the /new flow: create a pending prompt and ask for a starting image.

    If the user already has a prompt waiting in the queue, or an unfinished pending
    prompt, they are offered alternatives instead. Also called by the
    ``delete_and_new`` callback.
    """
    uid = ev.input_sender.user_id

    queued = conn.execute('SELECT 1 FROM prompt WHERE is_done IS NULL AND user_id = ?', (uid,)).fetchone()
    if queued:
        await edit_or_respond(
            ev,
            'You already have another prompt in the queue. Do you want to delete it?',
            buttons=[[Button.inline("Check queue", "queue")], [Button.inline("Delete", "delete_and_new")]],
        )
        return

    pending = conn.execute('SELECT 1 FROM pending_prompt WHERE user_id = ?', (uid,)).fetchone()
    if pending:
        await edit_or_respond(
            ev,
            'You already have another prompt pending. Do you want to edit or overwrite it?',
            buttons=[[Button.inline("No, edit old", "edit")], [Button.inline("Overwrite", "delete_and_new")]],
        )
        return

    choices = [Button.inline("🚫 No image", "msg_but:no_image")]
    # Only offer copying when the user has generated something before.
    last = conn.execute('SELECT * FROM prompt WHERE user_id = ? ORDER BY id DESC LIMIT 1', (uid,)).fetchone()
    if last:
        choices.append(Button.inline("📋 Copy your last prompt", "copy_last"))

    await edit_or_respond(ev, 'Creating a new prompt!\nSend me an image, an e621 link, or just press on "No image" if you want to only use a text prompt.\n\n⚠️ From now on, we will use BB95 instead of yiffy-e18! Pino daeni and other tags are not valid anymore. ⚠️', buttons=choices)
    conn.execute('INSERT INTO pending_prompt(user_id, seed) VALUES (?, ?)', (uid, randint(0, 1000000)))
    # The next value the user sends is interpreted by edit_pending as the starting image.
    conn.execute("UPDATE user SET pending_action = 'initial_image' WHERE id = ?", (uid,))
@client.on(events.callbackquery.CallbackQuery(pattern=r'^copy_last'))
async def copy_last_prompt(ev):
    """Replace the user's pending prompt with a copy of their most recent prompt.

    Every setting is copied except the seed, which is re-rolled. The button message is
    deleted and the edit menu is shown for the copied prompt.
    """
    user_id = ev.input_sender.user_id
    await ev.delete()
    conn.execute('DELETE FROM pending_prompt WHERE user_id = ?', (user_id,))
    conn.execute('INSERT INTO pending_prompt SELECT user_id, image, prompt, detail, negative_prompt, inference_steps, number, ?, blend, prompt_e6, image_e6, resolution, crop, hires FROM prompt WHERE user_id = ? ORDER BY id DESC LIMIT 1', (randint(0, 1000000), user_id))
    # /new left pending_action = 'initial_image'. Clear it, otherwise the next photo or text
    # the user sends would still be treated as the answer to "send me an image".
    conn.execute('UPDATE user SET pending_action = NULL WHERE id = ?', (user_id,))
    await edit_prompt(ev)
async def step_two(ev):
    """Second step of /new: ask for the text prompt and wait for it (``initial_prompt``)."""
    instructions = (
        "Please give me a prompt to use.\n\n"
        "This is a phrase, or a list of e621 tags, or a link to an e621 post to use as a source.\n"
        "You can increase weight on some tags by enclosing them with ((brackets)).\n\n"
        "If you're starting off from an image, please make sure the tags accurately describe the picture for the best results.\n\n"
        "Some examples:\n"
        '- "chunie wolf ((athletic male)) solo (abs) pecs standing topless swimwear"\n'
        '- "a digital drawing by chunie of a fox smiling at the side of a swimming pool, wearing swimwear and with white fur and grey ears"\n'
        '- "https://e621.net/posts/3549862"'
    )
    await ev.respond(instructions, link_preview=False, parse_mode='HTML')
    conn.execute("UPDATE user SET pending_action = 'initial_prompt' WHERE id = ?", (ev.input_sender.user_id,))
@client.on(events.callbackquery.CallbackQuery(pattern=r'delete_and_new'))
async def delete_and_new(ev):
    """Drop the user's pending prompt and any queued prompt not yet started, then restart /new."""
    uid = ev.input_sender.user_id
    cleanup = (
        'DELETE FROM pending_prompt WHERE user_id = ?',
        # Prompts a worker has already started are kept.
        'DELETE FROM prompt WHERE user_id = ? AND is_done IS NULL AND started_at IS NULL',
    )
    for statement in cleanup:
        conn.execute(statement, (uid,))
    await new_prompt(ev)
# Link to an e621 post; group 1 is the numeric post id.
_E621_POST_RE = re.compile(r'e621\.net/posts?/([0-9]+)')


def _parse_bounded(text, cast, low, high):
    """Return ``cast(text)`` if it parses and lies in ``[low, high]``, otherwise None.

    Uses explicit checks instead of ``assert`` so validation still happens under
    ``python -O``.
    """
    try:
        value = cast(text)
    except ValueError:
        return None
    if low <= value <= high:
        return value
    return None


def _closest_resolution(path):
    """Return the supported "WxH" resolution whose aspect ratio best matches the image at *path*.

    Candidates are checked in the order 768x512, 512x768, 512x512, and on a tie the
    earlier one wins. The image file is closed before returning.
    """
    # NOTE(review): `Image` (presumably PIL) is not in this file's visible imports; it is
    # assumed to come from `from config import *` - confirm.
    with Image.open(path) as img:
        ratio = img.width / img.height
    best, best_diff = None, 10
    for width, height in ((768, 512), (512, 768), (512, 512)):
        diff = abs(ratio - width / height)
        if diff < best_diff:
            best, best_diff = f"{width}x{height}", diff
    return best


async def _download_e621_file(ev, file_id):
    """Download e621 file *file_id* into the local ``e621/`` cache.

    Returns True once the file is on disk. On failure the user is told why, any partially
    written file is removed, and False is returned. Removing the partial file matters
    because the caller's ``isfile`` check would otherwise treat it as a valid cached image.
    """
    path = f"e621/{file_id}"
    url = f"https://static1.e621.net/data/{file_id[0:2]}/{file_id[2:4]}/{file_id}"
    msg = await ev.respond('Trying to download the image from e621...')
    status = None
    try:
        async with httpx.AsyncClient() as webcli:
            async with webcli.stream('GET', url) as response:
                status = response.status_code
                if status == 200:
                    with open(path, 'wb') as f:
                        async for chunk in response.aiter_bytes():
                            f.write(chunk)
    except Exception as e:
        # User-facing boundary: report any network or disk error.
        if isfile(path):
            unlink(path)
        await msg.edit(f"I encountered an issue downloading your image: {str(e)}. Please try with another image")
        return False
    if status != 200:
        log.warning(f'e621 download failed with status {status}: {url}')
        await ev.respond(f'I wasn\'t able to download this image. (Got status code {status}) Try again later')
        return False
    await msg.edit('Downloaded!')
    return True


async def _resolve_image_value(ev, user_id, data):
    """Turn the user's answer for the ``image`` field into the value to store.

    Accepts three kinds of answer:

    * ``no_image``;
    * an uploaded photo, saved under ``userdata/``;
    * an e621 post link; the file is cached under ``e621/`` and the post id is stored in
      ``image_e6``.

    For a real image the pending resolution is also switched to the closest supported
    aspect ratio. Returns the image path or ``'no_image'``, or None after telling the user
    to try again.
    """
    conn.execute('UPDATE pending_prompt SET image_e6 = NULL WHERE user_id = ?', (user_id,))
    if data == 'no_image':
        return data
    e6_match = _E621_POST_RE.search(data)
    # Inline-button events have no `message`, so only message events can carry a photo.
    if hasattr(ev, 'message') and ev.message.photo:
        path = f"userdata/{user_id}_{ev.message.id}.jpg"
        await ev.client.download_media(ev.message.photo, path)
        await ev.respond('Alright. I will use this image as a reference when generating.')
    elif e6_match:
        e6_id = int(e6_match.group(1))
        e6_post = e621.execute('SELECT * FROM post WHERE id = ?', (e6_id,)).fetchone()
        if not e6_post:
            await ev.respond('This post doesn\'t seem to exist in our database. This means it\'s too new (database is refreshed once a day) or it was deleted.')
            return None
        file_id = e6_post['file_id']
        if not file_id.endswith(('png', 'jpg')):
            await ev.respond('The e621 post you gave me doesn\'t look like an image. Are you sure it\'s not a gif or a video?')
            return None
        path = f"e621/{file_id}"
        if not isfile(path) and not await _download_e621_file(ev, file_id):
            return None
        conn.execute('UPDATE pending_prompt SET image_e6 = ? WHERE user_id = ?', (e6_id, user_id))
    else:
        await ev.respond('You need to give me an e621 link or upload an image.', buttons=[Button.inline("🚫 No image", "msg_but:no_image")])
        return None
    resolution = _closest_resolution(path)
    if resolution:
        conn.execute('UPDATE pending_prompt SET resolution = ? WHERE user_id = ?', (resolution, user_id))
        await ev.respond(f'⚠️ Additionally, i\'ve automatically adjusted the resolution to {resolution} to match the aspect ratio.')
    return path


async def _resolve_prompt_value(ev, user_id, data):
    """Turn the user's answer for the ``prompt`` field into the prompt text to store.

    An e621 post link is replaced by that post's tags and the post id is stored in
    ``prompt_e6``. If the tags are longer than 2000 characters, random tags are kept until
    they reach about 1000 characters. Free text must split into at least five pieces on
    non-word characters. Returns None after telling the user to try again.
    """
    conn.execute('UPDATE pending_prompt SET prompt_e6 = NULL WHERE user_id = ?', (user_id,))
    e6_match = _E621_POST_RE.search(data)
    if not e6_match:
        if len(re.split(r'\W+', data)) < 5:
            await ev.respond('At least five tags are needed for the prompt. Please try again:')
            return None
        return data
    e6_id = e6_match.group(1)
    log.info('Getting data from e621 (prompt)...')
    post = e621.execute('SELECT tags FROM post WHERE id = ?', (int(e6_id),)).fetchone()
    if not post:
        await ev.respond('This post does not seem to exist on e621, or perhaps, it\'s too new. You can only use posts from yesterday or older.')
        return None
    tags_text = post['tags']
    if len(tags_text) > 2000:
        tags = tags_text.split(' ')
        thinned = ''
        # The loop cannot run out of tags: the source is over 2000 characters long.
        while len(thinned) < 1000:
            thinned += ' ' + tags.pop(randint(0, len(tags) - 1))
        tags_text = thinned
        await ev.respond('⚠️ Since this post had too many tags, I had to delete some of them.')
    conn.execute('UPDATE pending_prompt SET prompt_e6 = ? WHERE user_id = ?', (e6_id, user_id))
    return tags_text


@client.on(events.callbackquery.CallbackQuery(pattern=r'^msg_but:'))
@client.on(events.NewMessage(incoming=True, pattern=r'^([^/](\n|.)*)?$', func=lambda e: e.is_private))
async def edit_pending(ev):
    """Store a value for the prompt field the user is currently editing.

    Triggered by any private message that does not start with "/", and by
    ``msg_but:<value>`` inline buttons. The field comes from ``user.pending_action``,
    which is ``initial_<field>`` during the /new flow. The value is validated and
    normalised for that field, then written to ``pending_prompt``. Afterwards the user
    moves to the prompt step (after the initial image) or back to the edit menu.
    """
    user_id = ev.input_sender.user_id
    who = (user_id, get_display_name(ev.sender))
    log.info(f'{who}: got_value')
    action = conn.execute('SELECT pending_action FROM user WHERE id = ? AND pending_action IS NOT NULL', (user_id,)).fetchone()
    pending_prompt = conn.execute('SELECT * FROM pending_prompt WHERE user_id = ?', (user_id,)).fetchone()
    if not action:
        log.info(f'{who}: no pending edit')
        await ev.respond('You gave me a value, but you\'re not editing any parameter. Maybe try /start again?')
        return
    field = action[0].replace('initial_', '')
    is_initial = action[0].startswith('initial_')
    if hasattr(ev, 'message'):
        data = ev.message.raw_text.strip()
    else:
        # Inline button: the value follows the "msg_but:" prefix.
        data = ev.data.decode().split(':', 1)[1]
    # This whitelist also makes the column-name interpolation in the final UPDATE safe.
    if field not in ('image', 'prompt', 'negative_prompt', 'seed', 'number', 'detail', 'inference_steps', 'blend', 'resolution', 'crop', 'hires'):
        log.warning(f'{who}: unexpected pending_action {action[0]!r}')
        await ev.respond('Something went wrong. Try /start again.')
        return
    if not pending_prompt:
        # Stale edit state (the prompt was deleted or confirmed meanwhile). Without this
        # guard the number/inference_steps checks below would crash on None.
        conn.execute('UPDATE user SET pending_action = NULL WHERE id = ?', (user_id,))
        await ev.respond('Sorry, it looks like you have no pending prompt.\n/new for a new one')
        return

    if field == 'number':
        data = _parse_bounded(data, int, 1, 4)
        if data is None:
            await ev.respond('You can only generate between 1 and 4 images. Try again:')
            raise events.StopPropagation
        if data * pending_prompt['inference_steps'] > 200:
            await ev.respond(f"You have chosen to generate {data} images with {pending_prompt['inference_steps']} cycles, resulting in a total {data*pending_prompt['inference_steps']} cycles. If you want to increase the number of generated images, reduce the cycles so that the total is at most 200.")
            return
    elif field == 'resolution':
        # 1024x512 and 512x1024 are currently disabled.
        if data not in ('768x512', '512x512', '512x768'):
            await ev.respond('This resolution is invalid.')
            raise events.StopPropagation
    elif field == 'detail':
        data = _parse_bounded(data, float, 2, 30)
        if data is None:
            await ev.respond('Level of detail must be between 2 and 30. We suggest using 7.5 for best results.')
            raise events.StopPropagation
    elif field in ('crop', 'hires'):
        data = data.lower()
        if data not in ('yes', 'no'):
            await ev.respond('Answer for this question must be yes or no.')
            raise events.StopPropagation
    elif field == 'blend':
        data = _parse_bounded(data, float, 0.3, 0.9)
        if data is None:
            await ev.respond('Level of blend must be between 0.3 and 0.9.')
            raise events.StopPropagation
    elif field == 'inference_steps':
        data = _parse_bounded(data, int, 20, 200)
        if data is None or data % 10 != 0:
            await ev.respond('Generation time must be an integer between 20 and 200, a multiple of 10.')
            raise events.StopPropagation
        if data * pending_prompt['number'] > 200:
            await ev.respond(f"You have chosen to generate {pending_prompt['number']} images with {data} cycles, a total {pending_prompt['number']*data} cycles. If you want to increase the cycles, reduce the number of generated images so that the total is at most 200.")
            return
    elif field == 'image':
        data = await _resolve_image_value(ev, user_id, data)
        if data is None:
            return
    elif field == 'prompt':
        data = await _resolve_prompt_value(ev, user_id, data)
        if data is None:
            return
    elif field == 'seed':
        if data == 'last_prompt':
            row = conn.execute('SELECT seed FROM prompt WHERE user_id = ? ORDER BY id DESC LIMIT 1', (user_id,)).fetchone()
            data = row[0] if row else None
        try:
            data = int(data) % 1000000
        except (TypeError, ValueError):
            await ev.respond('The seed needs to be an integer between 0 and 1000000. Try again:')
            return

    conn.execute("UPDATE user SET pending_action = NULL WHERE id = ?", (user_id,))
    conn.execute(f"UPDATE pending_prompt SET {field} = ? WHERE user_id = ?", (data, user_id))
    log.info(f'{who}: {field} -> {str(data)[:32]}')
    if is_initial and field == 'image':
        await step_two(ev)
    else:
        await edit_prompt(ev)
@client.on(events.callbackquery.CallbackQuery(pattern=r'^delete$'))
async def delete_prompt(ev):
    """Delete the user's pending prompt and reset any unfinished field edit."""
    user_id = ev.input_sender.user_id
    conn.execute('DELETE FROM pending_prompt WHERE user_id = ?', (user_id,))
    # Also clear the edit state. Otherwise the next text or photo would still be treated
    # as a value for a prompt that no longer exists.
    conn.execute('UPDATE user SET pending_action = NULL WHERE id = ?', (user_id,))
    await edit_or_respond(ev, 'Prompt has been deleted. Use /new to create a new one.')
@client.on(events.NewMessage(pattern=r'^/queue(?!\w)', incoming=True))
async def queue_info(ev):
    """Reply with aggregate statistics about the generation queue.

    The pattern rejects a following word character, so /queuelist and /queueraw no longer
    trigger this handler as well; "/queue" and "/queue@bot" still match. SQL aggregates
    are NULL when there is nothing to aggregate (empty queue, no finished prompts yet);
    those values are shown as 0 instead of crashing.
    """
    # NOTE(review): this measures queued_at -> started_at although the message calls it
    # "queue-to-result" time - confirm which one is intended.
    avg_comp_time = conn.execute('SELECT avg(aa) FROM (SELECT abs(queued_at-started_at) AS aa FROM prompt WHERE queued_at IS NOT NULL AND completed_at IS NOT NULL ORDER BY id DESC LIMIT 10)').fetchone()[0] or 0
    wait_row = conn.execute("SELECT max(strftime('%s', 'now')-queued_at), min(strftime('%s', 'now')-queued_at) FROM prompt WHERE queued_at IS NOT NULL AND is_done IS NULL").fetchone()
    max_wait = wait_row[0] or 0
    min_wait = wait_row[1] or 0
    waiting = conn.execute("SELECT count(*) FROM prompt WHERE is_done IS NULL").fetchone()[0]
    avg_speed = conn.execute("SELECT avg(aa) FROM (SELECT (inference_steps*number)/abs(completed_at-started_at) AS aa FROM prompt WHERE started_at IS NOT NULL AND completed_at IS NOT NULL ORDER BY id DESC LIMIT 10)").fetchone()[0] or 0

    comp_minutes = int(avg_comp_time / 60)
    max_minutes = int(max_wait / 60)
    min_minutes = int(min_wait / 60)
    await ev.respond("\n".join([
        f"👯‍♂️ {waiting} people are waiting in the queue",
        f"{'🔴' if comp_minutes > 20 else ('🟡' if comp_minutes > 3 else '🟢')} Average queue-to-result time: <strong>{comp_minutes} minutes</strong>.",
        f"{'🔴' if max_minutes > 20 else ('🟡' if max_minutes > 5 else '🟢')} Current queue duration: {min_minutes}~{max_minutes} minutes",
        f"🏃 Average speed: {avg_speed:.2f} steps/sec",
    ]), parse_mode='HTML')
@client.on(events.NewMessage(pattern='^/queuelist', incoming=True))
async def queue_list(ev):
    """Reply with every unfinished prompt in queue order.

    Each line shows the prompt id, the time waited as m:ss, a ⚙️ if a worker has started
    it, and the user name truncated to 16 characters. User names are HTML-escaped because
    the reply uses HTML parse mode.
    """
    from html import escape
    rows = conn.execute("""SELECT prompt.id, inference_steps*number AS cycles, strftime('%s', 'now')-prompt.queued_at AS wait, prompt.started_at, user.name FROM prompt
        JOIN user ON user.id = prompt.user_id
        WHERE is_done IS NULL
        ORDER BY prompt.id ASC""")
    lines = []
    for item in rows:
        wait = item['wait'] or 0  # NULL when queued_at is not set
        name = item['name'] or ''
        shown = escape(name[:16]) + ('...' if len(name) > 16 else '')
        # Seconds are right-aligned and zero-padded (":0>2"), so 5 s renders as "05", not "50".
        lines.append(f"<strong>#{item['id']}</strong> · {int(wait/60)}:{int(wait%60):0>2}{' ⚙️' if item['started_at'] else ''} {shown}\n")
    await ev.respond('Current queue:\n\n' + ''.join(lines), parse_mode='HTML')
@client.on(events.NewMessage(pattern='^/broke', incoming=True))
async def queue_list(ev):
    """Reply with every unfinished prompt in queue order (same line format as /queuelist).

    The query selects ``id`` and ``wait``, which the line format needs. User names are
    HTML-escaped because the reply uses HTML parse mode.

    NOTE(review): ``daily_cycles`` is selected but never displayed; what /broke was meant
    to show beyond the queue list is unclear - confirm with the author. This definition
    also rebinds the module-level name ``queue_list``; the handler itself stays registered
    because the decorator has already run.
    """
    from html import escape
    rows = conn.execute("""SELECT prompt.id, strftime('%s', 'now')-prompt.queued_at AS wait, user.daily_cycles, prompt.started_at, user.name FROM prompt
        JOIN user ON user.id = prompt.user_id
        WHERE is_done IS NULL
        ORDER BY prompt.id ASC""")
    lines = []
    for item in rows:
        wait = item['wait'] or 0  # NULL when queued_at is not set
        name = item['name'] or ''
        shown = escape(name[:16]) + ('...' if len(name) > 16 else '')
        lines.append(f"<strong>#{item['id']}</strong> · {int(wait/60)}:{int(wait%60):0>2}{' ⚙️' if item['started_at'] else ''} {shown}\n")
    await ev.respond('Current queue:\n\n' + ''.join(lines), parse_mode='HTML')
@client.on(events.NewMessage(pattern='^/addworker', incoming=True))
async def add_worker(ev):
    """Admin only: ``/addworker <api_url> <name>`` appends a new ``Worker`` to ``workers``.

    The worker is not started here; ``start_queue`` calls ``.start()`` on every worker.
    The name may contain spaces. A malformed command gets a usage hint instead of raising
    ``ValueError``.
    """
    if ev.input_sender.user_id != client.admin_id:
        return
    parts = ev.message.raw_text.split(maxsplit=2)
    if len(parts) != 3:
        await ev.respond('Usage: /addworker <api_url> <name>')
        return
    _, api_url, name = parts
    workers.append(Worker(api_url, client, name))
    await ev.respond(f'Worker {name} added.')
@client.on(events.NewMessage(pattern='^/stop', incoming=True))
async def stop_queue(ev):
    """Admin only: clear ``client.process_queue`` so queue processing pauses."""
    is_admin = ev.input_sender.user_id == client.admin_id
    if not is_admin:
        return
    client.process_queue = False
    await ev.respond("Stopping queue processing.")
@client.on(events.NewMessage(pattern='^/process', incoming=True))
async def start_queue(ev):
    """Start every worker, provided queue processing is enabled.

    When the caller is the admin, processing is switched on first (with a notice). Other
    callers, including ``confirm_prompt`` which calls this after queueing a prompt, get a
    maintenance notice while processing is switched off.
    """
    if ev.input_sender.user_id == client.admin_id:
        await ev.respond('Starting queue processing')
        client.process_queue = True
    if client.process_queue:
        for worker in workers:
            worker.start()
        return
    await ev.respond('Please hold on a bit more time as the bot is going under maintenance.')
@client.on(events.NewMessage(pattern='^/queueraw', incoming=True))
async def queue_list(ev):
    """Admin only: reply with the raw contents of the in-memory task queue (debugging aid)."""
    if ev.input_sender.user_id == client.admin_id:
        # `_queue` is the private deque behind the queue object.
        pending_tasks = client.queue._queue
        await ev.respond(str(pending_tasks))
@client.on(events.callbackquery.CallbackQuery(pattern=r'^queue$'))
async def queue(ev, new_message=False):
    """Show the user's position in the queue, with a refresh button.

    Args:
        ev: the triggering event (the ``queue`` button, or the event forwarded by
            ``confirm_prompt``).
        new_message: if true, send a fresh message to the user instead of going through
            ``edit_or_respond``.

    Raises:
        events.StopPropagation: when the user has no unfinished prompt.

    The average wait is shown as 0 when there are no finished prompts yet; previously a
    NULL average crashed here.
    """
    uid = ev.input_sender.user_id
    prompt = conn.execute('SELECT * FROM prompt WHERE is_done IS NULL AND user_id = ?', (uid,)).fetchone()
    if not prompt:
        await edit_or_respond(ev, 'You don\'t have pending prompts :)\n/new for a new one')
        raise events.StopPropagation
    avg_comp_time = conn.execute('SELECT avg(aa) FROM (SELECT abs(queued_at-started_at) AS aa FROM prompt WHERE queued_at IS NOT NULL AND completed_at IS NOT NULL ORDER BY id DESC LIMIT 10)').fetchone()[0] or 0
    behind_you = conn.execute('SELECT count(*) FROM prompt WHERE is_done IS NULL AND id > ?', (prompt['id'],)).fetchone()[0]
    front_you = conn.execute('SELECT count(*) FROM prompt WHERE is_done IS NULL AND id < ?', (prompt['id'],)).fetchone()[0]
    text = f"Your position in the queue:\n{behind_you} behind you, {front_you} in front of you\n{'🫥'*behind_you}😃{'🫥'*front_you}\n\nCurrent average wait time: {int(avg_comp_time/60)} minutes"
    buttons = [[Button.inline("👯 Refresh queue", "queue")]]
    if new_message:
        await client.send_message(ev.sender, text, buttons=buttons)
    else:
        await edit_or_respond(ev, text, buttons=buttons)
@client.on(events.callbackquery.CallbackQuery(pattern=r'^confirm$'))
async def confirm_prompt(ev):
    """Move the user's pending prompt into the generation queue.

    Steps:

    1. Refuse banned users, and refuse when there is no pending prompt with text.
    2. Refuse when the quality-adjusted cost (``inference_steps*number/quality``, where
       ``quality`` comes from ``judge``) exceeds the user's balance.
    3. Copy the pending prompt into ``prompt`` and enqueue ``(id, id)`` on
       ``client.queue``.
    4. Report the queue position, post the prompt with moderation buttons to the log
       channel, and start the workers.

    Raises:
        events.StopPropagation: after a successful confirmation.
    """
    user_id = ev.input_sender.user_id
    # Do not allow banned users to confirm prompts
    user = conn.execute('SELECT * FROM user WHERE id = ?', (user_id,)).fetchone()
    if user['is_banned'] == 1:
        await ev.respond('You have been banned from sending further prompts due to abuse.')
        return
    # Check if the prompt exists
    pending = conn.execute('SELECT * FROM pending_prompt WHERE user_id = ?', (user_id,)).fetchone()
    if not pending or not pending['prompt']:
        await ev.respond('Looks like you have nothing to confirm. Try /new for a new prompt.')
        return
    # Analyze the prompt; only the quality factor is needed here.
    _comments, quality = judge(pending['prompt'])
    # Check if the user has used all of his cycles
    usage = get_credits(user_id)
    cost = (pending['inference_steps'] * pending['number']) / quality
    if usage['balance'] < cost:
        await ev.respond(f"Sorry. You only have {usage['balance']} cycles left out of the {user['daily_cycles']} 🌀 you can use every day. You can use /cycles to have more info on how many cycles you can use.")
        return
    log.info(f'{(user_id, get_display_name(ev.sender))}: confirm prompt')
    conn.execute('INSERT INTO prompt SELECT NULL, NULL, NULL, *, ?, NULL, NULL, ? FROM pending_prompt WHERE user_id = ?', (time(), quality, user_id))
    conn.execute('DELETE FROM pending_prompt WHERE user_id = ?', (user_id,))
    prompt = conn.execute('SELECT * FROM prompt WHERE user_id = ? AND is_done IS NULL ORDER BY id DESC LIMIT 1', (user_id,)).fetchone()
    await edit_or_respond(ev, "\n".join([
        f"✅ Your prompt {prompt['id']} has been scheduled and will be generated soon!\n",
        f"<strong>You spent {prompt['number']*prompt['inference_steps']} cycles. You have {int(usage['balance'] - cost)} left for today.</strong>",
    ]), parse_mode='HTML')
    await queue(ev, new_message=True)
    await client.send_message(
        client.log_channel_id,
        f"New prompt #{prompt['id']} by {user_id} {get_display_name(ev.sender)}\n\n{prompt['prompt']}",
        buttons=[[Button.inline("Delete prompt", f"delete_mod:{prompt['id']}"),
                  Button.inline("Ban user", f"ban_mod:{prompt['user_id']}")]],
    )
    task = (prompt['id'], prompt['id'])
    if task in client.queue._queue:
        # Module-level `log`: there is no `self` in this function.
        log.error(f"Tried to insert a duplicate task while confirming!!! {task} {client.queue._queue}")
        return
    client.queue.put_nowait(task)
    await start_queue(ev)
    raise events.StopPropagation
# Help text that `setting` appends when asking for a field's new value. It is sent with
# HTML parse mode, so keep these strings free of '<' and '&'. The detail and blend ranges
# match the validation in `edit_pending` (2-30 and 0.3-0.9).
desc = {
    'seed': 'The seed must be a whole number between 0 and 1000000. It defines the random noise which will be used to begin generation - two images with the same seed will be identical. Write it or get the one from the previous generation.',
    'image': 'Upload an image, give me an e621 link or just delete the current one.',
    'prompt': 'A list of e621 tags, a phrase or an e621 link to use as a prompt for your image',
    'negative_prompt': 'A list of e621 tags you don\'t want to see. Do not put - in front of the tags',
    'number': 'Number of images to generate. Write a value or press the buttons.',
    'detail': 'The detail (aka "guidance scale") can be set between 2 and 30. Lower values will create softer images, while higher values will create images with a lot of contrast and perhaps distortion. Write a value or press the buttons.',
    'inference_steps': 'Define the amount of time, in "cycles", to spend generating the images.\nHigher time usually means higher quality but only if the tags are good enough. Try around ~40/image.',
    'blend': '0.3 to 0.9, the amount of transformation to apply on the base image. The higher the value, the more the result image will be different than the source.',
    'resolution': 'The width and height of the final image.',
    'crop': 'Do you want to crop the base image so that it matches the generated image resolution?',
    'hires': 'Do you want to receive a high res image at the end of the generation? (it will take more time)',
}
# Quick-answer buttons that `setting` attaches for each field. Every button sends
# "msg_but:<value>", which `edit_pending` treats like a typed answer.
desc_buttons = {
    'seed': [Button.inline("Get from last prompt", "msg_but:last_prompt")],
    'number': [Button.inline(str(n), f"msg_but:{n}") for n in range(1, 5)],
    'detail': [Button.inline("S (6.0)", "msg_but:6"), Button.inline("M (10)", "msg_but:10"), Button.inline("L (18)", "msg_but:18")],
    'inference_steps': [Button.inline("S (20)", "msg_but:20"), Button.inline("M (40)", "msg_but:40"), Button.inline("L (60)", "msg_but:60")],
    'blend': [Button.inline("S (0.3)", "msg_but:0.3"), Button.inline("M (0.6)", "msg_but:0.6"), Button.inline("L (0.8)", "msg_but:0.8")],
    'resolution': [
        [Button.inline("Portrait (512x768)", "msg_but:512x768"), Button.inline("Landscape (768x512)", "msg_but:768x512")],
        [Button.inline("Square (512x512)", "msg_but:512x512")],
        # Ultrawide (1024x512) and Ultratall (512x1024) are disabled; edit_pending rejects them too.
    ],
    'image': [Button.inline("🚫 No image", "msg_but:no_image")],
    'crop': [Button.inline("🚫 Don't touch it", "msg_but:no"), Button.inline("✂️ Crop it", "msg_but:yes")],
    'hires': [Button.inline("Normal image", "msg_but:no"), Button.inline("High resolution", "msg_but:yes")],
}
@client.on(events.NewMessage(pattern='^/cycles', incoming=True))
async def cycles_notice(ev):
    """Tell the user how many daily cycles they have left.

    If some cycles are still recovering, also show the hourly regeneration rate and the
    time until the full amount is back.

    Users are only registered from private messages (``maintenance``). Someone who sends
    /cycles in a group without ever writing to the bot therefore has no row; they get a
    hint instead of a crash on ``None``.
    """
    user_id = ev.input_sender.user_id
    user = conn.execute("SELECT * FROM user WHERE id = ?", (user_id,)).fetchone()
    if not user:
        await ev.respond('I don\'t know you yet. Send me /start in a private chat first.')
        return
    usage = get_credits(user_id)
    remaining = usage['remaining_time']
    text = f"Hello {user['name']}. You have {usage['balance']}/{user['daily_cycles']} cycles left."
    if remaining:
        text += f"\nYou are currently earning {usage['hourly_gain']} cycles/hour. Full amount in {int(remaining/3600)}h{int((remaining/60)%60):0>2}m{int(remaining%60):0>2}s."
    await ev.respond(text)
@client.on(events.NewMessage(pattern='^/ban', incoming=True))
async def ban_user(ev):
    """Admin only: ``/ban <user_id>`` sets ``is_banned = 1`` for that user.

    Prompts already queued are left untouched; the ``ban_mod`` button cancels those as well.
    """
    if ev.input_sender.user_id != client.admin_id:
        return
    target = int(ev.message.raw_text.split(' ')[1])
    conn.execute('UPDATE user SET is_banned = 1 WHERE id = ?', (target,))
    await ev.respond('User banned.')
@client.on(events.NewMessage(pattern='^/unban', incoming=True))
async def unban_user(ev):
    """Admin only: ``/unban <user_id>`` resets ``is_banned`` to NULL for that user."""
    if ev.input_sender.user_id != client.admin_id:
        return
    target = int(ev.message.raw_text.split(' ')[1])
    conn.execute('UPDATE user SET is_banned = NULL WHERE id = ?', (target,))
    await ev.respond('User unbanned.')
@client.on(events.NewMessage(pattern='^/setcredits', incoming=True))
async def unban_user(ev):
    """Admin only: ``/setcredits <user_id> <daily_cycles>`` sets a user's daily cycle allowance.

    The admin gets a confirmation and the user is notified. Argument order follows the
    notification: the first argument is the user id, the second is the number of cycles.

    NOTE(review): this definition reuses the module-level name ``unban_user``. Both
    handlers stay registered because each decorator has already run.
    """
    if ev.input_sender.user_id != client.admin_id:
        return
    args = ev.message.raw_text.split(' ')
    target, cycles = int(args[1]), int(args[2])
    # Bind (cycles, target) in SQL order: SET daily_cycles = cycles WHERE id = target.
    conn.execute('UPDATE user SET daily_cycles = ? WHERE id = ?', (cycles, target))
    await ev.respond(f'User {target} can now use {cycles} cycles/day.')
    await client.send_message(target, f"Congrats! You can now use {cycles} cycles/day. Have fun!")
@client.on(events.callbackquery.CallbackQuery(pattern=r'^delete_mod:'))
async def del_mod(ev):
    """Moderator button ``delete_mod:<prompt_id>``: cancel a prompt and notify its owner.

    The prompt is marked ``is_done = 1, is_error = 1`` before the owner is messaged. A
    user who has blocked the bot is logged and skipped, so they cannot make the moderation
    action fail.
    """
    if ev.input_sender.user_id != client.admin_id:
        return
    prompt_id = int(ev.data.decode().split(':')[1])
    prompt = conn.execute('SELECT * FROM prompt WHERE id = ?', (prompt_id,)).fetchone()
    if prompt:
        conn.execute('UPDATE prompt SET is_done = 1, is_error = 1 WHERE id = ?', (prompt['id'],))
        try:
            await client.send_message(prompt['user_id'], 'Hi. Your prompt has been deleted by a moderator.')
        except UserIsBlockedError:
            log.info(f"Could not notify user {prompt['user_id']} about deleted prompt #{prompt['id']}: bot is blocked")
    await ev.answer('Prompt deleted.')
@client.on(events.callbackquery.CallbackQuery(pattern=r'^ban_mod:'))
async def del_mod(ev):
    """Moderator button ``ban_mod:<user_id>``: ban the user and mark all their prompts done and errored.

    NOTE(review): the prompt UPDATE has no ``is_done IS NULL`` filter, so prompts that
    already finished are flagged ``is_error`` too - confirm that is intended.
    """
    if ev.input_sender.user_id != client.admin_id:
        return
    banned_id = int(ev.data.decode().split(':')[1])
    conn.execute('UPDATE user SET is_banned = 1 WHERE id = ?', (banned_id,))
    conn.execute('UPDATE prompt SET is_done = 1, is_error = 1 WHERE user_id = ?', (banned_id,))
    await ev.answer('User banned.')
@client.on(events.callbackquery.CallbackQuery(pattern=r'^change:'))
async def setting(ev):
    """Handle a ``change:<field>`` button: remember the field being edited and ask for its value.

    The question includes the field's help text from ``desc`` when one exists, and the
    field's quick-answer buttons from ``desc_buttons``.
    """
    payload = ev.data.decode()
    log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: {payload}')
    field = payload.split(':')[1]
    user_id = ev.input_sender.user_id
    has_pending = conn.execute('SELECT 1 FROM pending_prompt WHERE user_id = ?', (user_id,)).fetchone()
    if not has_pending:
        await ev.edit('You cannot edit a prompt that doesn\'t exist anymore.')
        return
    conn.execute('UPDATE user SET pending_action = ? WHERE id = ?', (field, user_id))
    help_text = "\n\n" + desc[field] if field in desc else ''
    await ev.respond(f"Please tell me the value for <strong>{field}</strong>.{help_text}", parse_mode='HTML', buttons=desc_buttons.get(field, None))
@client.on(events.callbackquery.CallbackQuery(pattern=r'^edit$'))
async def edit_prompt(ev):
    """Send the edit menu for the user's pending prompt: a summary plus one button per setting.

    Blend and crop are only shown (and editable) when a starting image is set. The prompt
    texts are HTML-escaped because the message uses HTML parse mode.
    """
    from html import escape
    user_id = ev.input_sender.user_id
    log.info(f'{(user_id, get_display_name(ev.sender))}: edit_prompt')
    p = conn.execute('SELECT * FROM pending_prompt WHERE user_id = ?', (user_id,)).fetchone()
    if not p:
        await ev.respond('Sorry, it looks like you have no pending prompt.\n/new for a new one')
        return
    has_image = p['image'] != 'no_image'
    if not has_image:
        starting_image = 'None'
    elif p['image_e6']:
        starting_image = 'https://e621.net/posts/' + str(p['image_e6'])
    else:
        starting_image = 'User uploaded image'
    lines = [
        '',
        '<b>Your prompt</b>',
        f"🖼 Starting image: {starting_image}",
        f"🌱 Seed: {p['seed']}",
        f"👀 Prompt: <pre>{escape(str(p['prompt']))}</pre>",
        f"⛔ Negative prompt: <pre>{escape(str(p['negative_prompt'] or 'no negative prompt set.'))}</pre>",
        "<b>Do not touch these if you don't know what you're doing</b>",
        f"🔢 Number of images: {p['number']}",
        f"🖥 Resolution: {p['resolution']}",
        f"✨ High resolution: {p['hires']} (<i>slower generation!</i>)",
        f"💎 Detail: {p['detail']} (<i>Guidance scale</i>)",
        f"🌀 Generation cycles: {p['inference_steps']} cycles ({p['inference_steps']*p['number']} total) (<i>Inference steps</i>)",
    ]
    if has_image:
        # edit_pending stores crop as 'yes'/'no'. Also accept '1' in case that is the
        # column default (assumption - confirm against the schema).
        crop = 'Yes' if str(p['crop']).lower() in ('1', 'yes') else 'No'
        lines.append(f"🎚 Blend amount: {p['blend']} (<i>Prompt strength</i>)")
        lines.append(f"✂️ Crop: {crop}")
    await ev.respond(
        "\n".join(lines),
        parse_mode='HTML',
        link_preview=False,
        buttons=[
            [
                Button.inline("👀 Prompt", "change:prompt"),
                Button.inline("⛔ Negative prompt", "change:negative_prompt"),
            ],
            [
                Button.inline("🌱 Seed", "change:seed"),
                Button.inline("🔢 Number", "change:number"),
                Button.inline("✨ High resolution", "change:hires"),
            ],
            [
                Button.inline("💎 Detail", "change:detail"),
                Button.inline("🌀 Gen cycles", "change:inference_steps"),
            ],
            [
                Button.inline("🖼 Base image", "change:image"),
                Button.inline("🖥 Resolution", "change:resolution"),
                *([Button.inline("🎚 Blend", "change:blend"), Button.inline("✂️ Crop", "change:crop")] if has_image else []),
            ],
            [
                Button.inline("✅ Confirm", "confirm"),
                Button.inline("🕵️ Analyze", "analyze"),
                Button.inline("❌ Delete", "delete"),
            ],
        ],
    )
@client.on(events.callbackquery.CallbackQuery(pattern=r'^analyze$'))
async def analyze_prompt(ev):
    """Reply with the judge's comments on the pending prompt's text.

    The quality score returned by ``judge`` is not shown. Like ``confirm_prompt``, this
    requires prompt text to be set, so ``judge`` is never called with ``None``.
    """
    log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: analyze')
    pending = conn.execute('SELECT * FROM pending_prompt WHERE user_id = ? LIMIT 1', (ev.input_sender.user_id,)).fetchone()
    if pending and pending['prompt']:
        comments, _quality = judge(pending['prompt'])
        await ev.respond("\n".join(comments), parse_mode='HTML')
    else:
        await ev.respond('What am i supposed to analyze?')
@client.on(events.NewMessage(pattern='^/start', incoming=True, func=lambda e: e.is_private))
async def welcome(ev):
    """Answer /start in a private chat with the bot's rules and community links."""
    log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: hello')
    greeting = (
        'Hello, and welcome to ai621. This bot can be used to generate yiff. Before beginning, keep in mind that:\n\n'
        '1. Images are public, no abusive stuff\n'
        '2. Using the bot maliciously or with multiple alts will lead to a ban\n'
        '3. The images generated by the bot are of public domain.\n'
        'Generate a new prompt with /new\n\n'
        'Discussion: @ai621chat\n'
        'Website: https://ai621.foxo.me/'
    )
    await ev.respond(greeting)
if __name__ == '__main__':
    # On startup, re-queue every unfinished prompt. Prompts a worker had already started
    # come first, then the rest in submission order. There is deliberately no filter on
    # recent usage, so old unfinished prompts are resumed too.
    resume_rows = conn.execute(
        "SELECT prompt.id FROM prompt "
        "JOIN user ON user.id = prompt.user_id "
        "WHERE is_done IS NULL "
        "ORDER BY started_at ASC NULLS LAST, prompt.id ASC"
    ).fetchall()
    for row in resume_rows:
        # Queue entries are (id, id) tuples, so compare the tuple rather than the DB row.
        task = (row['id'], row['id'])
        if task in client.queue._queue:
            log.error(f"Tried to insert a duplicate task while Resuming!!! {task}")
            continue
        client.queue.put_nowait(task)
    client.start()
    client.flood_sleep_threshold = 24*60*60
    client.run_until_disconnected()