(real) initial commit
This commit is contained in:
parent
56c6bde06d
commit
9e7a120561
|
@ -164,3 +164,4 @@ config.py
|
||||||
bot.session*
|
bot.session*
|
||||||
e621/*
|
e621/*
|
||||||
userdata/*
|
userdata/*
|
||||||
|
ai621.db
|
||||||
|
|
Binary file not shown.
|
@ -0,0 +1,620 @@
|
||||||
|
import logging, asyncio
|
||||||
|
from telethon import events
|
||||||
|
from telethon.tl.custom import Button
|
||||||
|
from os.path import isfile
|
||||||
|
from time import time
|
||||||
|
import httpx
|
||||||
|
import re
|
||||||
|
import sqlite3
|
||||||
|
from os import unlink
|
||||||
|
from random import randint
|
||||||
|
from telethon.errors.rpcerrorlist import MessageNotModifiedError, UserIsBlockedError
|
||||||
|
from telethon.utils import get_display_name
|
||||||
|
from config import *
|
||||||
|
from judge_prompt import judge
|
||||||
|
|
||||||
|
def get_credits(user_id):
|
||||||
|
res = conn.execute("""SELECT count(*), prompt.id, daily_cycles, daily_cycles-coalesce(cast(sum((86400-(strftime('%s')-completed_at))/86400*(number*inference_steps/quality)) as int), 0) AS balance, coalesce(86400-(STRFTIME('%s')-max(completed_at)), 0) AS remaining_time, CAST(sum(number*inference_steps/quality)/24 AS integer) AS hourly_gain FROM user
|
||||||
|
LEFT JOIN prompt ON user_id = user.id AND (strftime('%s')-completed_at) < 86400
|
||||||
|
WHERE user.id = ?
|
||||||
|
GROUP BY user.id
|
||||||
|
""", (user_id,)).fetchone()
|
||||||
|
return res
|
||||||
|
|
||||||
|
async def edit_or_respond(ev, *args, **kwargs):
|
||||||
|
|
||||||
|
if hasattr(ev, 'message') and ev.message.input_sender.user_id != (await client.get_me(input_peer=True)).user_id:
|
||||||
|
await ev.respond(*args, **kwargs)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await ev.edit(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, MessageNotModifiedError): return
|
||||||
|
await ev.respond(*args, **kwargs)
|
||||||
|
|
||||||
|
@client.on(events.NewMessage(incoming=True, func=lambda e: e.is_private))
|
||||||
|
async def maintenance(ev):
|
||||||
|
conn.execute('INSERT INTO user(id, name) VALUES (?,?) ON CONFLICT DO NOTHING', (ev.input_sender.user_id,get_display_name(ev.sender)))
|
||||||
|
# if ev.input_sender.user_id != client.admin_id:
|
||||||
|
# await ev.respond('The bot is currently closed while i implement some new functions. For news and updates, go to @ai621chat')
|
||||||
|
# raise events.StopPropagation
|
||||||
|
|
||||||
|
# New prompt
|
||||||
|
@client.on(events.NewMessage(pattern='^/new', incoming=True, func=lambda e: e.is_private))
|
||||||
|
async def new_prompt(ev):
|
||||||
|
|
||||||
|
if conn.execute('SELECT 1 FROM prompt WHERE is_done IS NULL AND user_id = ?', (ev.input_sender.user_id,)).fetchone():
|
||||||
|
await edit_or_respond(ev, 'You already have another prompt in the queue. Do you want to delete it?',
|
||||||
|
buttons=[[Button.inline(f"Check queue", f"queue")],[Button.inline(f"Delete", "delete_and_new")]])
|
||||||
|
return
|
||||||
|
|
||||||
|
if conn.execute('SELECT 1 FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,)).fetchone():
|
||||||
|
await edit_or_respond(ev, 'You already have another prompt pending. Do you want to edit or overwrite it?',
|
||||||
|
buttons=[[Button.inline(f"No, edit old", f"edit")],[Button.inline(f"Overwrite", "delete_and_new")]])
|
||||||
|
return
|
||||||
|
|
||||||
|
buttons = [Button.inline(f"🚫 No image", f"msg_but:no_image"),]
|
||||||
|
if conn.execute('SELECT * FROM prompt WHERE user_id = ? ORDER BY id DESC LIMIT 1', (ev.input_sender.user_id,)).fetchone():
|
||||||
|
buttons.append(Button.inline(f"📋 Copy your last prompt", f"copy_last"))
|
||||||
|
|
||||||
|
await edit_or_respond(ev, 'Creating a new prompt!\nSend me an image, an e621 link, or just press on "No image" if you want to only use a text prompt.\n\n⚠️ From now on, we will use BB95 instead of yiffy-e18! Pino daeni and other tags are not valid anymore. ⚠️', buttons=buttons)
|
||||||
|
conn.execute('INSERT INTO pending_prompt(user_id, seed) VALUES (?, ?)', (ev.input_sender.user_id,randint(0,1000000)))
|
||||||
|
conn.execute('UPDATE user SET pending_action = \'initial_image\' WHERE id = ?', (ev.input_sender.user_id,))
|
||||||
|
|
||||||
|
@client.on(events.callbackquery.CallbackQuery(pattern=r'^copy_last'))
|
||||||
|
async def copy_last_prompt(ev):
|
||||||
|
|
||||||
|
await ev.delete()
|
||||||
|
|
||||||
|
conn.execute('DELETE FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,))
|
||||||
|
conn.execute('INSERT INTO pending_prompt SELECT user_id, image, prompt, detail, negative_prompt, inference_steps, number, ?, blend, prompt_e6, image_e6, resolution, crop, hires FROM prompt WHERE user_id = ? ORDER BY id DESC LIMIT 1', (randint(0,1000000), ev.input_sender.user_id))
|
||||||
|
|
||||||
|
await edit_prompt(ev)
|
||||||
|
|
||||||
|
async def step_two(ev):
|
||||||
|
msg = await ev.respond("Please give me a prompt to use.\n\nThis is a phrase, or a list of e621 tags, or a link to an e621 post to use as a source.\nYou can increase weight on some tags by enclosing them with ((brackets)).\n\nIf you're starting off from an image, please make sure the tags accurately describe the picture for the best results.\n\nSome examples:\n" +
|
||||||
|
'- "chunie wolf ((athletic male)) solo (abs) pecs standing topless swimwear"\n' +
|
||||||
|
'- "a digital drawing by chunie of a fox smiling at the side of a swimming pool, wearing swimwear and with white fur and grey ears"\n' +
|
||||||
|
'- "https://e621.net/posts/3549862"', link_preview=False, parse_mode='HTML')
|
||||||
|
|
||||||
|
conn.execute('UPDATE user SET pending_action = \'initial_prompt\' WHERE id = ?', (ev.input_sender.user_id,))
|
||||||
|
|
||||||
|
@client.on(events.callbackquery.CallbackQuery(pattern=r'delete_and_new'))
|
||||||
|
async def delete_and_new(ev):
|
||||||
|
conn.execute('DELETE FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,))
|
||||||
|
conn.execute('DELETE FROM prompt WHERE user_id = ? AND is_done IS NULL AND started_at IS NULL', (ev.input_sender.user_id,))
|
||||||
|
await new_prompt(ev)
|
||||||
|
|
||||||
|
@client.on(events.callbackquery.CallbackQuery(pattern=r'^msg_but:'))
|
||||||
|
@client.on(events.NewMessage(incoming=True, pattern='^([^\/](\n|.)*)?$', func=lambda e: e.is_private))
|
||||||
|
async def edit_pending(ev):
|
||||||
|
|
||||||
|
log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: got_value')
|
||||||
|
|
||||||
|
res = conn.execute('SELECT pending_action FROM user WHERE id = ? AND pending_action IS NOT NULL', (ev.input_sender.user_id,)).fetchone()
|
||||||
|
pending_prompt = conn.execute('SELECT * FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,)).fetchone()
|
||||||
|
|
||||||
|
if not res:
|
||||||
|
log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: no pending edit')
|
||||||
|
await ev.respond('You gave me a value, but you\'re not editing any parameter. Maybe try /start again?')
|
||||||
|
return
|
||||||
|
|
||||||
|
field = res[0].replace('initial_', '')
|
||||||
|
is_initial = res[0].startswith('initial_')
|
||||||
|
|
||||||
|
if hasattr(ev, 'message'):
|
||||||
|
data = ev.message.raw_text.strip()
|
||||||
|
else:
|
||||||
|
data = ev.data.decode().split(':', 1)[1]
|
||||||
|
|
||||||
|
if field not in ['image', 'prompt', 'negative_prompt', 'seed', 'number', 'detail', 'inference_steps', 'blend', 'resolution', 'crop', 'hires']:
|
||||||
|
print(field)
|
||||||
|
await ev.respond('Something went wrong. Try /start again.')
|
||||||
|
return
|
||||||
|
|
||||||
|
if field == 'number':
|
||||||
|
try:
|
||||||
|
data = int(data)
|
||||||
|
assert data <= 4
|
||||||
|
assert data > 0
|
||||||
|
except:
|
||||||
|
await ev.respond('You can only generate between 1 and 4 images. Try again:')
|
||||||
|
raise events.StopPropagation
|
||||||
|
|
||||||
|
if data*pending_prompt['inference_steps'] > 200:
|
||||||
|
await ev.respond(f"You have choosen to generate {data} images with {pending_prompt['inference_steps']} cycles, resulting in a total {data*pending_prompt['inference_steps']} cycles. If you want to increase the number of generated images, reduce the cycles so that the total is below 200.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if field == 'resolution':
|
||||||
|
if data not in ['768x512', '512x512', '512x768']: #, '1024x512', '512x1024']:
|
||||||
|
await ev.respond('This resolution is invalid.')
|
||||||
|
raise events.StopPropagation
|
||||||
|
|
||||||
|
if field == 'detail':
|
||||||
|
try:
|
||||||
|
data = float(data)
|
||||||
|
assert data >= 2
|
||||||
|
assert data <= 30
|
||||||
|
except:
|
||||||
|
await ev.respond('Level of detail must be between 2 and 20. We suggest using 7.5 for best results.')
|
||||||
|
raise events.StopPropagation
|
||||||
|
|
||||||
|
if field == 'crop':
|
||||||
|
data = data.lower()
|
||||||
|
if data not in ['yes','no']:
|
||||||
|
await ev.respond('Answer for this question must be yes or no.')
|
||||||
|
raise events.StopPropagation
|
||||||
|
|
||||||
|
if field == 'hires':
|
||||||
|
data = data.lower()
|
||||||
|
if data not in ['yes','no']:
|
||||||
|
await ev.respond('Answer for this question must be yes or no.')
|
||||||
|
raise events.StopPropagation
|
||||||
|
|
||||||
|
if field == 'blend':
|
||||||
|
try:
|
||||||
|
data = float(data)
|
||||||
|
assert data >= 0.3
|
||||||
|
assert data <= 0.9
|
||||||
|
except:
|
||||||
|
await ev.respond('Level of blend must be between 0.3 and 0.9.')
|
||||||
|
raise events.StopPropagation
|
||||||
|
|
||||||
|
if field == 'inference_steps':
|
||||||
|
try:
|
||||||
|
data = int(data)
|
||||||
|
assert data >= 20
|
||||||
|
assert data <= 200
|
||||||
|
assert data % 10 == 0
|
||||||
|
except:
|
||||||
|
await ev.respond('Generation time must be an integer between 20 and 200, a multiple of 10.')
|
||||||
|
raise events.StopPropagation
|
||||||
|
|
||||||
|
if data*pending_prompt['number'] > 200:
|
||||||
|
await ev.respond(f"You have choosen to generate {pending_prompt['number']} images with {data} cycles, a total {pending_prompt['number']*data} cycles. If you want to increase the cycles, reduce the number of generated images so that the total is below 200.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if field == 'image':
|
||||||
|
conn.execute('UPDATE pending_prompt SET image_e6 = NULL WHERE user_id = ?', (ev.input_sender.user_id,))
|
||||||
|
if data == 'no_image':
|
||||||
|
pass
|
||||||
|
elif ev.message.photo:
|
||||||
|
await ev.client.download_media(ev.message.photo, f"userdata/{ev.input_sender.user_id}_{ev.message.id}.jpg")
|
||||||
|
data = f'userdata/{ev.input_sender.user_id}_{ev.message.id}.jpg'
|
||||||
|
await ev.respond('Alright. I will use this image as a reference when generating.')
|
||||||
|
elif re.search(r'e621\.net/posts?/([0-9]+)', data):
|
||||||
|
e6_id = int(re.search(r'e621\.net/posts?/([0-9]+)', data).group(1))
|
||||||
|
e6_post = e621.execute('SELECT * FROM post WHERE id = ?', (e6_id,)).fetchone()
|
||||||
|
if not e6_post:
|
||||||
|
await ev.respond('This post doesn\'t seem to exist in our database. This means it\'s too new (database is refreshed once a day) or it was deleted.')
|
||||||
|
return
|
||||||
|
|
||||||
|
if not (e6_post['file_id'].endswith('png') or e6_post['file_id'].endswith('jpg')):
|
||||||
|
await ev.respond('The e621 post you gave me doesn\'t look like an image. Are you sure it\'s not a gif or a video?')
|
||||||
|
return
|
||||||
|
|
||||||
|
if not isfile(f"e621/{e6_post['file_id']}"):
|
||||||
|
msg = await ev.respond('Trying to download the image from e621...')
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as webcli:
|
||||||
|
with open(f"e621/{e6_post['file_id']}", 'wb') as f:
|
||||||
|
async with webcli.stream('GET', f"https://static1.e621.net/data/{e6_post['file_id'][0:2]}/{e6_post['file_id'][2:4]}/{e6_post['file_id']}") as response:
|
||||||
|
async for chunk in response.aiter_bytes():
|
||||||
|
f.write(chunk)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(f"https://static1.e621.net/data/{e6_post['file_id'][0:2]}/{e6_post['file_id'][2:4]}/{e6_post['file_id']}")
|
||||||
|
unlink(f"e621/{e6_post['file_id']}")
|
||||||
|
await ev.respond(f'I wasn\'t able to download this image. (Got status code {response.status_code}) Try again later')
|
||||||
|
return
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await msg.edit(f"I encountered an issue downloading your image: {str(e)}. Please try with another image")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
await msg.edit('Downloaded!')
|
||||||
|
|
||||||
|
conn.execute('UPDATE pending_prompt SET image_e6 = ? WHERE user_id = ?', (e6_id, ev.input_sender.user_id))
|
||||||
|
data = f"e621/{e6_post['file_id']}"
|
||||||
|
else:
|
||||||
|
await ev.respond('You need to give me an e621 link or upload an image.', buttons=[Button.inline(f"🚫 No image", f"msg_but:no_image")])
|
||||||
|
return
|
||||||
|
|
||||||
|
if data != 'no_image':
|
||||||
|
img = Image.open(data)
|
||||||
|
ar = img.width/img.height
|
||||||
|
res = None
|
||||||
|
diff = 10
|
||||||
|
for r in [[768,512], [512,768], [512,512]]:
|
||||||
|
if abs(ar-(r[0]/r[1])) < diff:
|
||||||
|
res = f"{r[0]}x{r[1]}"
|
||||||
|
diff = abs(ar-(r[0]/r[1]))
|
||||||
|
|
||||||
|
if res:
|
||||||
|
conn.execute('UPDATE pending_prompt SET resolution = ? WHERE user_id = ?', (res, ev.input_sender.user_id))
|
||||||
|
await ev.respond(f'⚠️ Additionally, i\'ve automatically adjusted the resolution to {res} to match the aspect ratio.')
|
||||||
|
|
||||||
|
if field == 'prompt':
|
||||||
|
conn.execute('UPDATE pending_prompt SET prompt_e6 = NULL WHERE user_id = ?', (ev.input_sender.user_id,))
|
||||||
|
if re.search(r'e621\.net\/posts?/([0-9]+)', data):
|
||||||
|
e6_id = re.search(r'e621\.net\/posts?/([0-9]+)', data).group(1)
|
||||||
|
print('Getting data from e621 (prompt)...')
|
||||||
|
e6_prompt = e621.execute('SELECT tags FROM post WHERE id = ?', (int(e6_id),)).fetchone()
|
||||||
|
if not e6_prompt:
|
||||||
|
await ev.respond('This post does not seem to exist on e621, or perhaps, it\'t too new. You can only use posts from yesterday or older.')
|
||||||
|
return
|
||||||
|
data = e6_prompt['tags']
|
||||||
|
|
||||||
|
if len(data) > 2000:
|
||||||
|
tags = data.split(' ')
|
||||||
|
e6_prompt = ''
|
||||||
|
while len(e6_prompt) < 1000:
|
||||||
|
e6_prompt += ' '+tags.pop(randint(0, len(tags)-1))
|
||||||
|
|
||||||
|
data = e6_prompt
|
||||||
|
await ev.respond('⚠️ Since this post had too many tags, I had to delete some of them.')
|
||||||
|
|
||||||
|
conn.execute('UPDATE pending_prompt SET prompt_e6 = ? WHERE user_id = ?', (e6_id, ev.input_sender.user_id))
|
||||||
|
else:
|
||||||
|
tags = re.split(r'\W+', data)
|
||||||
|
|
||||||
|
if len(tags) < 5:
|
||||||
|
await ev.respond('At least five tags are needed for the prompt. Please try again:')
|
||||||
|
return
|
||||||
|
|
||||||
|
if field == 'seed':
|
||||||
|
if data == 'last_prompt':
|
||||||
|
data = conn.execute('SELECT seed FROM prompt WHERE user_id = ? ORDER BY id DESC LIMIT 1', (ev.input_sender.user_id,)).fetchone()
|
||||||
|
if data: data = data[0]
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = int(data)%1000000
|
||||||
|
except:
|
||||||
|
await ev.respond('The seed needs to be an integer between 0 and 1000000. Try again:')
|
||||||
|
return
|
||||||
|
|
||||||
|
conn.execute(f"UPDATE user SET pending_action = NULL WHERE id = ?", (ev.input_sender.user_id,))
|
||||||
|
conn.execute(f"UPDATE pending_prompt SET {field.replace('initial_','')} = ? WHERE user_id = ?", (data, ev.input_sender.user_id))
|
||||||
|
|
||||||
|
log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: {field} -> {str(data)[:32]}')
|
||||||
|
|
||||||
|
if is_initial and field == 'image':
|
||||||
|
await step_two(ev)
|
||||||
|
else:
|
||||||
|
await edit_prompt(ev)
|
||||||
|
return
|
||||||
|
|
||||||
|
@client.on(events.callbackquery.CallbackQuery(pattern=r'^delete$'))
|
||||||
|
async def delete_prompt(ev):
|
||||||
|
conn.execute('DELETE FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,))
|
||||||
|
await edit_or_respond(ev, f'Prompt has been deleted. Use /new to create a new one.')
|
||||||
|
|
||||||
|
@client.on(events.NewMessage(pattern='^/queue', incoming=True))
|
||||||
|
async def queue_info(ev):
|
||||||
|
avg_comp_time = conn.execute('SELECT avg(aa) FROM (SELECT abs(queued_at-started_at) AS aa FROM prompt WHERE queued_at IS NOT NULL AND completed_at IS NOT NULL ORDER BY id DESC LIMIT 10)').fetchone()
|
||||||
|
avg_wait = conn.execute("SELECT avg(strftime('%s', 'now')-queued_at) AS aa FROM prompt WHERE queued_at IS NOT NULL AND is_done IS NULL").fetchone()
|
||||||
|
min_max_wait = conn.execute("SELECT max(strftime('%s', 'now')-queued_at), min(strftime('%s', 'now')-queued_at) FROM prompt WHERE queued_at IS NOT NULL AND is_done IS NULL").fetchone()
|
||||||
|
waiting_usrs = conn.execute("SELECT count(*) FROM prompt WHERE is_done IS NULL").fetchone()
|
||||||
|
avg_speed = conn.execute("SELECT avg(aa) FROM (SELECT (inference_steps*number)/abs(completed_at-started_at) AS aa FROM prompt WHERE started_at IS NOT NULL AND completed_at IS NOT NULL ORDER BY id DESC LIMIT 10)").fetchone()
|
||||||
|
|
||||||
|
await ev.respond("\n".join([
|
||||||
|
f"👯♂️ {waiting_usrs[0]} people are waiting in the queue",
|
||||||
|
|
||||||
|
f"{'🔴' if int(avg_comp_time[0]/60) > 20 else ('🟡' if int(avg_comp_time[0]/60) > 3 else '🟢')} Average queue-to-result time: <strong>{int(avg_comp_time[0]/60)} minutes</strong>.",
|
||||||
|
f"{'🔴' if int(min_max_wait[0]/60) > 20 else ('🟡' if int(min_max_wait[0]/60) > 5 else '🟢')} Current queue duration: {int(min_max_wait[1]/60)}~{int(min_max_wait[0]/60)} minutes",
|
||||||
|
f"🏃 Average speed: {avg_speed[0]:.2f} steps/sec"
|
||||||
|
]), parse_mode='HTML')
|
||||||
|
|
||||||
|
@client.on(events.NewMessage(pattern='^/queuelist', incoming=True))
|
||||||
|
async def queue_list(ev):
|
||||||
|
|
||||||
|
ret = 'Current queue:\n\n'
|
||||||
|
|
||||||
|
queue = conn.execute("""SELECT prompt.id, inference_steps*number AS cycles, strftime('%s', 'now')-prompt.queued_at AS wait, prompt.started_at, user.name FROM prompt
|
||||||
|
JOIN user ON user.id = prompt.user_id
|
||||||
|
WHERE is_done IS NULL
|
||||||
|
ORDER BY prompt.id ASC""")
|
||||||
|
|
||||||
|
for item in queue:
|
||||||
|
ret += f"<strong>#{item['id']}</strong> · {int(item['wait']/60)}:{int(item['wait']%60):0<2}{' ⚙️' if item['started_at'] else ''} {item['name'][:16]}{'...' if len(item['name']) > 16 else ''}\n"
|
||||||
|
|
||||||
|
await ev.respond(ret, parse_mode='HTML')
|
||||||
|
|
||||||
|
@client.on(events.NewMessage(pattern='^/broke', incoming=True))
|
||||||
|
async def queue_list(ev):
|
||||||
|
|
||||||
|
ret = 'Current queue:\n\n'
|
||||||
|
|
||||||
|
queue = conn.execute("""SELECT user.daily_cycles, prompt.started_at, user.name FROM prompt
|
||||||
|
JOIN user ON user.id = prompt.user_id
|
||||||
|
WHERE is_done IS NULL
|
||||||
|
ORDER BY prompt.id ASC""")
|
||||||
|
|
||||||
|
for item in queue:
|
||||||
|
ret += f"<strong>#{item['id']}</strong> · {int(item['wait']/60)}:{int(item['wait']%60):0<2}{' ⚙️' if item['started_at'] else ''} {item['name'][:16]}{'...' if len(item['name']) > 16 else ''}\n"
|
||||||
|
|
||||||
|
await ev.respond(ret, parse_mode='HTML')
|
||||||
|
|
||||||
|
@client.on(events.NewMessage(pattern='^/addworker', incoming=True))
|
||||||
|
async def add_worker(ev):
|
||||||
|
if ev.input_sender.user_id != client.admin_id: return
|
||||||
|
command, api_url, name = ev.message.raw_text.split(' ')
|
||||||
|
workers.append(Worker(api_url, client, name))
|
||||||
|
|
||||||
|
|
||||||
|
@client.on(events.NewMessage(pattern='^/stop', incoming=True))
|
||||||
|
async def stop_queue(ev):
|
||||||
|
if ev.input_sender.user_id != client.admin_id: return
|
||||||
|
client.process_queue = False
|
||||||
|
await ev.respond("Stopping queue processing.")
|
||||||
|
|
||||||
|
@client.on(events.NewMessage(pattern='^/process', incoming=True))
|
||||||
|
async def start_queue(ev):
|
||||||
|
|
||||||
|
if ev.input_sender.user_id == client.admin_id:
|
||||||
|
await ev.respond(f'Starting queue processing')
|
||||||
|
client.process_queue = True
|
||||||
|
|
||||||
|
if not client.process_queue:
|
||||||
|
await ev.respond('Please hold on a bit more time as the bot is going under maintenance.')
|
||||||
|
return
|
||||||
|
|
||||||
|
for w in workers:
|
||||||
|
w.start()
|
||||||
|
|
||||||
|
@client.on(events.NewMessage(pattern='^/queueraw', incoming=True))
|
||||||
|
async def queue_list(ev):
|
||||||
|
if ev.input_sender.user_id != client.admin_id: return
|
||||||
|
await ev.respond(str(client.queue._queue))
|
||||||
|
|
||||||
|
@client.on(events.callbackquery.CallbackQuery(pattern=r'^queue$'))
|
||||||
|
async def queue(ev, new_message=False):
|
||||||
|
|
||||||
|
prompt = conn.execute('SELECT * FROM prompt WHERE is_done IS NULL AND user_id = ?', (ev.input_sender.user_id,)).fetchone()
|
||||||
|
if not prompt:
|
||||||
|
await edit_or_respond(ev, 'You don\'t have pending prompts :)\n/new for a new one')
|
||||||
|
raise events.StopPropagation
|
||||||
|
|
||||||
|
avg_comp_time = conn.execute('SELECT avg(aa) FROM (SELECT abs(queued_at-started_at) AS aa FROM prompt WHERE queued_at IS NOT NULL AND completed_at IS NOT NULL ORDER BY id DESC LIMIT 10)').fetchone()
|
||||||
|
|
||||||
|
behind_you = conn.execute('SELECT count(*) FROM prompt WHERE is_done IS NULL AND id > ?', (prompt['id'],)).fetchone()[0]
|
||||||
|
front_you = conn.execute('SELECT count(*) FROM prompt WHERE is_done IS NULL AND id < ?', (prompt['id'],)).fetchone()[0]
|
||||||
|
|
||||||
|
if new_message:
|
||||||
|
await client.send_message(ev.sender, f"Your position in the queue:\n{behind_you} behind you, {front_you} in front of you\n{'🫥'*behind_you}😃{'🫥'*front_you}\n\nCurrent average wait time: {int(avg_comp_time[0]/60)} minutes", buttons=[[Button.inline(f"👯 Refresh queue", f"queue")]])
|
||||||
|
else:
|
||||||
|
await edit_or_respond(ev, f"Your position in the queue:\n{behind_you} behind you, {front_you} in front of you\n{'🫥'*behind_you}😃{'🫥'*front_you}\n\nCurrent average wait time: {int(avg_comp_time[0]/60)} minutes", buttons=[[Button.inline(f"👯 Refresh queue", f"queue")]])
|
||||||
|
|
||||||
|
@client.on(events.callbackquery.CallbackQuery(pattern=r'^confirm$'))
|
||||||
|
async def confirm_prompt(ev):
|
||||||
|
|
||||||
|
# Do not allow banned users to confirm prompts
|
||||||
|
user = conn.execute('SELECT * FROM user WHERE id = ?', (ev.input_sender.user_id,)).fetchone()
|
||||||
|
if user['is_banned'] == 1:
|
||||||
|
await ev.respond('You have been banned from sending further prompts due to abuse.')
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if the prompt exists
|
||||||
|
prompt = conn.execute('SELECT * FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,)).fetchone()
|
||||||
|
if not prompt or not prompt['prompt']:
|
||||||
|
await ev.respond('Looks like you have nothing to confirm. Try /new for a new prompt.')
|
||||||
|
return
|
||||||
|
|
||||||
|
# Analyze the prompt
|
||||||
|
comments, quality = judge(prompt['prompt'])
|
||||||
|
|
||||||
|
# Check if the user has used all of his cycles
|
||||||
|
usage = get_credits(ev.input_sender.user_id)
|
||||||
|
|
||||||
|
if usage['balance'] < (prompt['inference_steps']*prompt['number'])/quality:
|
||||||
|
await ev.respond(f"Sorry. You only have {usage['balance']} cycles left out of the {user['daily_cycles']} 🌀 you can use every day. You can use /cycles to have more info on how many cycles you can use.")
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: confirm prompt')
|
||||||
|
|
||||||
|
conn.execute('INSERT INTO prompt SELECT NULL, NULL, NULL, *, ?, NULL, NULL, ? FROM pending_prompt WHERE user_id = ?', (time(),quality,ev.input_sender.user_id)).fetchone()
|
||||||
|
conn.execute('DELETE FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,))
|
||||||
|
|
||||||
|
prompt = conn.execute('SELECT * FROM prompt WHERE user_id = ? AND is_done IS NULL ORDER BY id DESC LIMIT 1', (ev.input_sender.user_id,)).fetchone()
|
||||||
|
|
||||||
|
await edit_or_respond(ev,
|
||||||
|
"\n".join([f"✅ Your prompt {prompt['id']} has been scheduled and will be generated soon!\n",
|
||||||
|
f"<strong>You spent {prompt['number']*prompt['inference_steps']} cycle/bs. You have {int(usage['balance']-((prompt['inference_steps']*prompt['number'])/quality))} left for today.</strong>"]), parse_mode='HTML')
|
||||||
|
|
||||||
|
await queue(ev, new_message=True)
|
||||||
|
|
||||||
|
await client.send_message(client.log_channel_id, f"New prompt #{prompt['id']} by {ev.input_sender.user_id} {get_display_name(ev.sender)}\n\n{prompt['prompt']}",
|
||||||
|
buttons=[[Button.inline(f"Delete prompt", f"delete_mod:{prompt['id']}"),
|
||||||
|
Button.inline(f"Ban user", f"ban_mod:{prompt['user_id']}")]])
|
||||||
|
|
||||||
|
task = (prompt['id'], prompt['id'],)
|
||||||
|
if task in client.queue._queue:
|
||||||
|
self.log.error(f"Tried to insert a duplicate task while confirming!!!")
|
||||||
|
print(task, client.queue._queue)
|
||||||
|
return
|
||||||
|
|
||||||
|
client.queue.put_nowait(task)
|
||||||
|
await start_queue(ev)
|
||||||
|
raise events.StopPropagation
|
||||||
|
|
||||||
|
desc = {
|
||||||
|
'seed': 'The seed must be a whole number between 0 and 1000000. It defines the random noise which will be used to begin generation - two images with the same seed will be identical. Write it or get the one from the previous generation.',
|
||||||
|
'image': 'Upload an image, give me an e621 link or just delete the current one.',
|
||||||
|
'prompt': 'A list of e621 tags, a phrase or an e621 link to use as a prompt for your image',
|
||||||
|
'negative_prompt': 'A list of e621 tags you don\'t want to see. Do not put - in front of the tags',
|
||||||
|
'number': 'Number of images to generate. Write a value or press the buttons.',
|
||||||
|
'detail': 'The detail (aka "guidance scale") can be set between 2 and 50. Lower values will create softer images, while higher values will create images with a lot of contrast and perhaps distortion. Write a value or press the buttons.',
|
||||||
|
'inference_steps': 'Define the amount of time, in "cycles", to spend generating the images.\nHigher time usually means higher quality but only if the tags are good enough. Try around ~40/image.',
|
||||||
|
'blend': '0.0 to 1.0, the amount of transformation to apply on the base image. The higher the value, the more the result image will be different than the source.',
|
||||||
|
'resolution': 'The width and height of the final image.',
|
||||||
|
'crop': 'Do you want to crop the base image so that it matches the generated image resolution?',
|
||||||
|
'hires': 'Do you want to receive a high res image at the end of the generation? (it will take more time)',
|
||||||
|
}
|
||||||
|
|
||||||
|
desc_buttons = {
|
||||||
|
'seed': [Button.inline(f"Get from last prompt", f"msg_but:last_prompt"),],
|
||||||
|
'number': [Button.inline(f"1️⃣", f"msg_but:1"), Button.inline(f"2️⃣", f"msg_but:2"), Button.inline(f"3️⃣", f"msg_but:3"), Button.inline(f"4️⃣", f"msg_but:4")],
|
||||||
|
'detail': [Button.inline(f"S (6.0)", f"msg_but:6"), Button.inline(f"M (10)", f"msg_but:10"), Button.inline(f"L (18)", f"msg_but:18")],
|
||||||
|
'inference_steps': [Button.inline(f"S (20)", f"msg_but:20"), Button.inline(f"M (40)", f"msg_but:40"), Button.inline(f"L (60)", f"msg_but:60")],
|
||||||
|
'blend': [Button.inline(f"S (0.3)", f"msg_but:0.3"), Button.inline(f"M (0.6)", f"msg_but:0.6"), Button.inline(f"L (0.8)", f"msg_but:0.8")],
|
||||||
|
'resolution': [[Button.inline(f"Potrait (512x768)", f"msg_but:512x768"), Button.inline(f"Landscape (768x512)", f"msg_but:768x512")],[Button.inline(f"Square (512x512)", f"msg_but:512x512")], #[Button.inline(f"Ultrawide (1024x512)", f"msg_but:1024x512"), Button.inline(f"Ultratall (512x1024)", f"msg_but:512x1024")]
|
||||||
|
],
|
||||||
|
'image': [Button.inline(f"🚫 No image", f"msg_but:no_image")],
|
||||||
|
'crop': [Button.inline(f"🚫 Don't touch it", f"msg_but:no"),Button.inline(f"✂️ Crop it", f"msg_but:yes")],
|
||||||
|
'hires': [Button.inline(f"Normal image", f"msg_but:no"),Button.inline(f"High resolution", f"msg_but:yes")],
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@client.on(events.NewMessage(pattern='^/cycles', incoming=True))
|
||||||
|
async def cycles_notice(ev):
|
||||||
|
user = conn.execute("SELECT * FROM user WHERE id = ?", (ev.input_sender.user_id,)).fetchone()
|
||||||
|
usage = get_credits(ev.input_sender.user_id)
|
||||||
|
|
||||||
|
await ev.respond(f"Hello {user['name']}. You have {usage['balance']}/{user['daily_cycles']} cycles left." + (f"\nYou are currently earning {usage['hourly_gain']} cycles/hour. Full amount in {int(usage['remaining_time']/3600)}h{int((usage['remaining_time']/60)%60):0>2}m{int(usage['remaining_time']%60):0>2}s." if usage['remaining_time'] else ''))
|
||||||
|
|
||||||
|
@client.on(events.NewMessage(pattern='^/ban', incoming=True))
|
||||||
|
async def ban_user(ev):
|
||||||
|
if ev.input_sender.user_id != client.admin_id: return
|
||||||
|
|
||||||
|
conn.execute('UPDATE user SET is_banned = 1 WHERE id = ?', (int(ev.message.raw_text.split(' ')[1]),))
|
||||||
|
await ev.respond('User banned.')
|
||||||
|
|
||||||
|
@client.on(events.NewMessage(pattern='^/unban', incoming=True))
|
||||||
|
async def unban_user(ev):
|
||||||
|
if ev.input_sender.user_id != client.admin_id: return
|
||||||
|
|
||||||
|
conn.execute('UPDATE user SET is_banned = NULL WHERE id = ?', (int(ev.message.raw_text.split(' ')[1]),))
|
||||||
|
await ev.respond('User unbanned.')
|
||||||
|
|
||||||
|
@client.on(events.NewMessage(pattern='^/setcredits', incoming=True))
|
||||||
|
async def unban_user(ev):
|
||||||
|
if ev.input_sender.user_id != client.admin_id: return
|
||||||
|
|
||||||
|
await ev.respond(conn.execute('UPDATE user SET daily_cycles = ? WHERE id = ?', (int(ev.message.raw_text.split(' ')[1]),int(ev.message.raw_text.split(' ')[2]))))
|
||||||
|
await client.send_message(int(ev.message.raw_text.split(' ')[1]), f"Congrats! You can now use {ev.message.raw_text.split(' ')[2]} cycles/day. Have fun!")
|
||||||
|
|
||||||
|
@client.on(events.callbackquery.CallbackQuery(pattern=r'^delete_mod:'))
|
||||||
|
async def del_mod(ev):
|
||||||
|
if ev.input_sender.user_id != client.admin_id: return
|
||||||
|
|
||||||
|
prompt = conn.execute('SELECT * FROM prompt WHERE id = ?', (int(ev.data.decode().split(':')[1]),)).fetchone()
|
||||||
|
if prompt:
|
||||||
|
await client.send_message(prompt['user_id'], 'Hi. Your prompt has been deleted by a moderator.')
|
||||||
|
conn.execute('UPDATE prompt SET is_done = 1, is_error = 1 WHERE id = ?', (prompt['id'],))
|
||||||
|
await ev.answer('Prompt deleted.')
|
||||||
|
|
||||||
|
@client.on(events.callbackquery.CallbackQuery(pattern=r'^ban_mod:'))
|
||||||
|
async def del_mod(ev):
|
||||||
|
if ev.input_sender.user_id != client.admin_id: return
|
||||||
|
|
||||||
|
conn.execute('UPDATE user SET is_banned = 1 WHERE id = ?', (int(ev.data.decode().split(':')[1]),))
|
||||||
|
conn.execute('UPDATE prompt SET is_done = 1, is_error = 1 WHERE user_id = ?', (int(ev.data.decode().split(':')[1]),))
|
||||||
|
await ev.answer('User banned.')
|
||||||
|
|
||||||
|
@client.on(events.callbackquery.CallbackQuery(pattern=r'^change:'))
|
||||||
|
async def setting(ev):
|
||||||
|
|
||||||
|
log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: {ev.data.decode()}')
|
||||||
|
|
||||||
|
field = ev.data.decode().split(':')[1]
|
||||||
|
pending_prompt = conn.execute('SELECT 1 FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,)).fetchone()
|
||||||
|
if not pending_prompt:
|
||||||
|
await ev.edit('You cannot edit a prompt that doesn\'t exist anymore.')
|
||||||
|
return
|
||||||
|
|
||||||
|
conn.execute('UPDATE user SET pending_action = ? WHERE id = ?', (field, ev.input_sender.user_id))
|
||||||
|
await ev.respond(f"Please tell me the value for <strong>{field}</strong>.{chr(10)+chr(10)+desc[field] if field in desc else ''}", parse_mode='HTML', buttons=desc_buttons.get(field, None))
|
||||||
|
|
||||||
|
@client.on(events.callbackquery.CallbackQuery(pattern=r'^edit$'))
|
||||||
|
async def edit_prompt(ev):
|
||||||
|
log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: edit_prompt')
|
||||||
|
|
||||||
|
user = conn.execute('SELECT * FROM user WHERE id = ?', (ev.input_sender.user_id,))
|
||||||
|
user = user.fetchone()
|
||||||
|
|
||||||
|
res = conn.execute('SELECT * FROM pending_prompt WHERE user_id = ?', (ev.input_sender.user_id,))
|
||||||
|
p = res.fetchone()
|
||||||
|
|
||||||
|
if not p:
|
||||||
|
await ev.respond('Sorry, it looks like you have no pending prompt.\n/new for a new one')
|
||||||
|
return
|
||||||
|
|
||||||
|
await ev.respond(f"""
|
||||||
|
<b>Your prompt</b>
|
||||||
|
|
||||||
|
🖼 Starting image: {'None' if p['image'] == 'no_image' else ('https://e621.net/posts/'+str(p['image_e6']) if p['image_e6'] else 'User uploaded image')}
|
||||||
|
🌱 Seed: {p['seed']}
|
||||||
|
👀 Prompt: <pre>{p['prompt']}</pre>
|
||||||
|
⛔ Negative prompt: <pre>{p['negative_prompt'] or 'no negative prompt set.'}</pre>
|
||||||
|
|
||||||
|
<b>Do not touch these if you don't know what you're doing</b>
|
||||||
|
🔢 Number of images: {p['number']}
|
||||||
|
🖥 Resolution: {p['resolution']}
|
||||||
|
✨ High resolution: {p['hires']} (<i>slower generation!</i>
|
||||||
|
💎 Detail: {p['detail']} (<i>Guidance scale</i>)
|
||||||
|
🌀 Generation cycles: {p['inference_steps']} cycles ({p['inference_steps']*p['number']} total) (<i>Inference steps</i>)""" +
|
||||||
|
(f"\n🎚 Blend amount: {p['blend']} (<i>Prompt strength</i>)\n✂️ Crop: {'Yes' if p['crop'] == '1' else 'No'}" if p['image'] != 'no_image' else ''),
|
||||||
|
parse_mode='HTML',
|
||||||
|
link_preview=False,
|
||||||
|
buttons = [
|
||||||
|
[
|
||||||
|
Button.inline(f"👀 Prompt", f"change:prompt"),
|
||||||
|
Button.inline(f"⛔ Negative prompt", f"change:negative_prompt"),
|
||||||
|
],
|
||||||
|
[
|
||||||
|
Button.inline(f"🌱 Seed", f"change:seed"),
|
||||||
|
Button.inline(f"🔢 Number", f"change:number"),
|
||||||
|
Button.inline(f"✨ High resolution", f"change:hires")
|
||||||
|
],
|
||||||
|
[
|
||||||
|
Button.inline(f"💎 Detail", f"change:detail"),
|
||||||
|
Button.inline(f"🌀 Gen cycles", f"change:inference_steps"),
|
||||||
|
],
|
||||||
|
[
|
||||||
|
Button.inline(f"🖼 Base image", f"change:image"),
|
||||||
|
Button.inline(f"🖥 Resolution", f"change:resolution"),
|
||||||
|
*([Button.inline(f"🎚 Blend", f"change:blend"),Button.inline(f"✂️ Crop", f"change:crop")] if p['image'] != 'no_image' else [])
|
||||||
|
],
|
||||||
|
[
|
||||||
|
Button.inline(f"✅ Confirm", f"confirm"),
|
||||||
|
Button.inline(f"🕵️ Analyze", f"analyze"),
|
||||||
|
Button.inline(f"❌ Delete", f"delete")
|
||||||
|
]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@client.on(events.callbackquery.CallbackQuery(pattern=r'^analyze$'))
|
||||||
|
async def analyze_prompt(ev):
|
||||||
|
log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: analyze')
|
||||||
|
|
||||||
|
res = conn.execute('SELECT * FROM pending_prompt WHERE user_id = ? LIMIT 1', (ev.input_sender.user_id,)).fetchone()
|
||||||
|
if res:
|
||||||
|
comments, quality = judge(res['prompt'])
|
||||||
|
await ev.respond("\n".join(comments), parse_mode='HTML')
|
||||||
|
else:
|
||||||
|
await ev.respond('What am i supposed to analyze?')
|
||||||
|
|
||||||
|
@client.on(events.NewMessage(pattern='^/start', incoming=True, func=lambda e: e.is_private))
|
||||||
|
async def welcome(ev):
|
||||||
|
log.info(f'{(ev.input_sender.user_id, get_display_name(ev.sender))}: hello')
|
||||||
|
await ev.respond(f'Hello, and welcome to ai621. This bot can be used to generate yiff. Before beginning, keep in mind that:\n\n1. Images are public, no abusive stuff\n2. Using the bot maliciously or with multiple alts will lead to a ban\n3. The images generated by the bot are of public domain.\nGenerate a new prompt with /new\n\nDiscussion: @ai621chat\nWebsite: https://ai621.foxo.me/')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
qqq = conn.execute("""SELECT used, user.daily_cycles, prompt.id, inference_steps*number AS cycles, strftime('%s', 'now')-prompt.queued_at AS wait, prompt.started_at, user.name FROM prompt
|
||||||
|
JOIN user ON user.id = prompt.user_id
|
||||||
|
JOIN (SELECT bb.user_id, sum(bb.number*bb.inference_steps) AS used FROM prompt AS bb WHERE bb.queued_at > strftime('%s', 'now')-86400 GROUP BY bb.user_id) aa ON aa.user_id = user.id
|
||||||
|
WHERE is_done IS NULL
|
||||||
|
ORDER BY started_at ASC NULLS LAST, prompt.id ASC""").fetchall()
|
||||||
|
|
||||||
|
for task in qqq:
|
||||||
|
if task in client.queue._queue:
|
||||||
|
self.log.error(f"Tried to insert a duplicate task while Resuming!!!")
|
||||||
|
print(task, client.queue._queue)
|
||||||
|
continue
|
||||||
|
client.queue.put_nowait((task['id'], task['id'],))
|
||||||
|
|
||||||
|
client.start()
|
||||||
|
client.flood_sleep_threshold = 24*60*60
|
||||||
|
client.run_until_disconnected()
|
|
@ -0,0 +1,42 @@
|
||||||
|
from asyncio import PriorityQueue
|
||||||
|
from telethon import TelegramClient
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
import coloredlogs
|
||||||
|
from process_queue import *
|
||||||
|
|
||||||
|
coloredlogs.install(level='INFO')
|
||||||
|
|
||||||
|
api_id = YOUR TG API ID HERE
|
||||||
|
api_hash = YOUR TG API HASH HERE
|
||||||
|
|
||||||
|
temp_folder = TEMP FOLDER OF THE GENERATIONS
|
||||||
|
|
||||||
|
client = TelegramClient('bot', api_id, api_hash)
|
||||||
|
|
||||||
|
e621 = sqlite3.connect('e621.db')
|
||||||
|
e621.row_factory = sqlite3.Row
|
||||||
|
conn = sqlite3.connect('ai621.db', isolation_level=None)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
|
client.conn = conn
|
||||||
|
client.process = None
|
||||||
|
|
||||||
|
client.log_channel_id = ID OF LOG CHANNEL
|
||||||
|
client.main_channel_id = ID OF RAW CHANNEL
|
||||||
|
client.queue = PriorityQueue()
|
||||||
|
client.media_lock = asyncio.Lock()
|
||||||
|
client.process_queue = False
|
||||||
|
|
||||||
|
client.admin_id = USER ID OF THE ADMIN
|
||||||
|
|
||||||
|
workers = [
|
||||||
|
Worker('http://127.0.0.1:9000', client, 'armorlink'),
|
||||||
|
#Worker('http://127.0.0.1:9001', client, 'armorlink-low'),
|
||||||
|
#Worker('http://local.proxy:9000', client, 'g14')
|
||||||
|
]
|
||||||
|
|
||||||
|
log = logging.getLogger('bot')
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
import csv
|
||||||
|
import sys
|
||||||
|
from glob import glob
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
conn = sqlite3.connect('e621.db')
|
||||||
|
conn.execute('DELETE FROM post')
|
||||||
|
csv.field_size_limit(sys.maxsize)
|
||||||
|
|
||||||
|
with open(glob('posts-*')[0]) as csvfile:
|
||||||
|
posts = csv.reader(csvfile, delimiter=',', quotechar='"')
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
for p in posts:
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
if i == 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if i%10000 == 0:
|
||||||
|
print(p[0], end='\r')
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
conn.execute('INSERT INTO post(id, file_id, tags) VALUES (?,?,?)', (int(p[0]), p[3]+'.'+p[11], p[8]))
|
||||||
|
|
||||||
|
conn.commit()
|
|
@ -0,0 +1,65 @@
|
||||||
|
import re
|
||||||
|
good_chars = 'abcdefghijklmnopqrstuvwxyz0123456789_- ,.()'
|
||||||
|
word_chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
|
||||||
|
stopwords = [ 'stop', 'the', 'to', 'and', 'a', 'in', 'it', 'is', 'i', 'that', 'had', 'on', 'for', 'were', 'was']
|
||||||
|
|
||||||
|
regexes = [
|
||||||
|
(r"\w_\w", 'It looks like you are using the character "_" to separate words in tags. Use spaces instead.', 0.9),
|
||||||
|
(r"\W-\w", 'It looks like you are using the "-" character to exclude tags. Put them into the "negative prompt" instead.', 0.9),
|
||||||
|
(r"^.{,60}$", 'Your prompt is very short. You will probably get a bad image.', 0.8),
|
||||||
|
(r"\b(stalin|hitler|nazi|nigger|waluigi|luigi|toilet)\b", 'Seriously?', 0.01),
|
||||||
|
(r"\b(loli|shota|cub|young|age difference|preteen|underage|child|teen)\b", 'Trying to generate cub art will probably make you banned. I warned you.', 0.001),
|
||||||
|
(r"\b(scat|poop|shit|piss|urine|pooping|rape)\b", 'Some of the tags you sent will be ignored because they were also blacklisted on e621.', 0.75),
|
||||||
|
(r"\({3,}", 'No need to use that many (((parenthesis))). That will give you a worse image.', 0.9),
|
||||||
|
(r"\[{3,}", 'Square [braces] will reduce the enphasis on a tag.', 1.0),
|
||||||
|
(r"\W#", 'There is no need to put # in front of tags. That\'ll worsen the quality of the image', 0.9),
|
||||||
|
(r"[👎👍🌀🌱💎]", "If you copy prompts from the channel, at least copy <strong>only</strong> the prompt.", 0.001)
|
||||||
|
]
|
||||||
|
|
||||||
|
tags = {}
|
||||||
|
with open('yiffy_tags.csv') as f:
|
||||||
|
while 1:
|
||||||
|
line = f.readline()
|
||||||
|
if not line: break
|
||||||
|
|
||||||
|
tag, value = line.strip().rsplit(',', 1)
|
||||||
|
|
||||||
|
if value == 'count': continue
|
||||||
|
tags[tag] = int(value)
|
||||||
|
sorted_by_size = sorted(tags.keys(), key=lambda x: len(x), reverse=True)
|
||||||
|
|
||||||
|
def judge(prompt):
|
||||||
|
|
||||||
|
prompt = ' '+prompt.lower().replace("\n", ' ')+' '
|
||||||
|
quality = 1.0
|
||||||
|
comments = []
|
||||||
|
|
||||||
|
found_tags = {}
|
||||||
|
for tag in sorted_by_size:
|
||||||
|
pos = prompt.find(tag)
|
||||||
|
if pos == -1: continue
|
||||||
|
if prompt[pos-1] in word_chars: continue
|
||||||
|
if prompt[pos+len(tag)] in word_chars: continue
|
||||||
|
|
||||||
|
found_tags[tag] = tags[tag]
|
||||||
|
|
||||||
|
if len(found_tags) == 0:
|
||||||
|
quality *= 0.65
|
||||||
|
comments.append(f"Your prompt doesn't even contain one tag from e621.")
|
||||||
|
elif len(found_tags) < 6:
|
||||||
|
quality *= 0.8
|
||||||
|
comments.append(f"Found only {len(found_tags)} tags in your prompt. The AI knows ~{int(sum(found_tags.values())/len(found_tags))} images from e621 with these tags.")
|
||||||
|
else:
|
||||||
|
comments.append(f"Found {len(found_tags)} tags in your prompt. The AI knows ~{int(sum(found_tags.values())/len(found_tags))} images from e621 with these tags.")
|
||||||
|
|
||||||
|
|
||||||
|
for pattern, comment, value in regexes:
|
||||||
|
match = re.search(pattern, prompt)
|
||||||
|
if match:
|
||||||
|
quality *= value
|
||||||
|
comments.append(comment)
|
||||||
|
|
||||||
|
if quality < 1:
|
||||||
|
comments.append(f"Because of these issues, you will consume {(1/quality):.2f}x the amount of usual cycles.")
|
||||||
|
|
||||||
|
return comments, quality
|
|
@ -0,0 +1,209 @@
|
||||||
|
import asyncio
|
||||||
|
from PIL import Image
|
||||||
|
from PIL import ImageOps
|
||||||
|
from base64 import b64encode, b64decode
|
||||||
|
from io import BytesIO
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
import httpx
|
||||||
|
from os import listdir
|
||||||
|
from os.path import join
|
||||||
|
from time import time
|
||||||
|
from telethon.errors.rpcerrorlist import MessageNotModifiedError, UserIsBlockedError
|
||||||
|
|
||||||
|
log = logging.getLogger('process')
|
||||||
|
temp_folder = '/home/ed/temp/'
|
||||||
|
|
||||||
|
default_vars = {
|
||||||
|
"use_cpu":False,
|
||||||
|
"use_full_precision": False,
|
||||||
|
"stream_progress_updates": True,
|
||||||
|
"stream_image_progress": False,
|
||||||
|
"show_only_filtered_image": True,
|
||||||
|
"sampler_name": "dpm_solver_stability",
|
||||||
|
"save_to_disk_path": temp_folder,
|
||||||
|
"output_format": "png",
|
||||||
|
"use_stable_diffusion_model": "fluffyrock-576-704-832-960-1088-lion-low-lr-e61-terminal-snr-e34",
|
||||||
|
"metadata_output_format": "embed",
|
||||||
|
"use_hypernetwork_model": "boring_e621",
|
||||||
|
"hypernetwork_strength": 0.25,
|
||||||
|
}
|
||||||
|
|
||||||
|
class Worker:
|
||||||
|
def __init__(self, api, client, name):
|
||||||
|
self.api = api
|
||||||
|
self.ready = False
|
||||||
|
self.client = client
|
||||||
|
self.queue = client.queue
|
||||||
|
self.conn = client.conn
|
||||||
|
self.loop = asyncio.get_event_loop()
|
||||||
|
self.task = None
|
||||||
|
self.name = name
|
||||||
|
self.log = logging.getLogger(name)
|
||||||
|
self.prompt = None
|
||||||
|
|
||||||
|
def start(self, future=None):
|
||||||
|
|
||||||
|
if not self.client.process_queue:
|
||||||
|
asyncio.create_task(self.client.send_message(self.client.admin_id, f"Loop of {self.name} has been stopped."))
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.task and not self.task.done(): return
|
||||||
|
|
||||||
|
if future and future.exception():
|
||||||
|
self.log.error(future.exception())
|
||||||
|
self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?',(time(), self.prompt))
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
priority, prompt_id = self.queue.get_nowait()
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
self.log.info('No more tasks to process!')
|
||||||
|
else:
|
||||||
|
prompt = self.client.conn.execute('SELECT prompt.* FROM prompt WHERE id = ?', (prompt_id,)).fetchone()
|
||||||
|
|
||||||
|
self.log.info(f"Processing {prompt_id}")
|
||||||
|
self.task = self.loop.create_task(self.process_prompt(prompt, (priority, prompt_id)))
|
||||||
|
self.task.add_done_callback(self.start)
|
||||||
|
|
||||||
|
async def process_prompt(self, prompt, queue_item):
|
||||||
|
|
||||||
|
# First of all, check if the user can still be messaged.
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as httpclient:
|
||||||
|
try:
|
||||||
|
await httpclient.get(join(self.api, 'ping'), timeout=5)
|
||||||
|
except Exception as e:
|
||||||
|
print(str(e))
|
||||||
|
log.error('Server is dead. Waiting 10 seconds for server availability...')
|
||||||
|
await self.queue.put(queue_item)
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
msg = await self.client.send_message(prompt['user_id'], f"Hello 👀 Generating your prompt now.")
|
||||||
|
except UserIsBlockedError:
|
||||||
|
self.conn.execute('UPDATE prompt SET is_done = 1, is_error = 1 WHERE id = ?',(prompt['id'],))
|
||||||
|
return
|
||||||
|
|
||||||
|
# Prepare the parameters for the request
|
||||||
|
params = default_vars.copy()
|
||||||
|
params['session_id'] = str(prompt['id'])
|
||||||
|
params['prompt'] = prompt['prompt'] or ''
|
||||||
|
params['negative_prompt'] = prompt['negative_prompt'] or 'boring_e621_v4'
|
||||||
|
params['num_outputs'] = int(prompt['number'])
|
||||||
|
params['num_inference_steps'] = prompt['inference_steps']
|
||||||
|
params['guidance_scale'] = prompt['detail']
|
||||||
|
params['width'] = prompt['resolution'].split('x')[0]
|
||||||
|
params['height'] = prompt['resolution'].split('x')[1]
|
||||||
|
params['seed'] = str(prompt['seed'])
|
||||||
|
params['vram_usage_level'] = 'low' if '-low' in self.name else ('medium' if '1024' in prompt['resolution'] else 'high')
|
||||||
|
|
||||||
|
self.prompt = prompt['id']
|
||||||
|
|
||||||
|
if prompt['hires'] == 'yes':
|
||||||
|
params['use_upscale'] = 'RealESRGAN_x4plus_anime_6B'
|
||||||
|
|
||||||
|
if prompt['image'] != 'no_image':
|
||||||
|
img = Image.open(prompt['image'])
|
||||||
|
img = img.convert('RGB')
|
||||||
|
if prompt['crop'] == 'no':
|
||||||
|
img = img.resize(list((int(x) for x in prompt['resolution'].split('x'))))
|
||||||
|
else:
|
||||||
|
img = ImageOps.fit(img, list((int(x) for x in prompt['resolution'].split('x'))))
|
||||||
|
|
||||||
|
imgdata = BytesIO()
|
||||||
|
img.save(imgdata, format='JPEG')
|
||||||
|
|
||||||
|
params['init_image'] = ('data:image/jpeg;base64,'+b64encode(imgdata.getvalue()).decode()).strip()
|
||||||
|
params['sampler_name'] = 'ddim'
|
||||||
|
params['prompt_strength'] = prompt['blend']
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as httpclient:
|
||||||
|
|
||||||
|
self.conn.execute('UPDATE prompt SET started_at = ? WHERE id = ?', (time(), prompt['id']))
|
||||||
|
|
||||||
|
start = time()
|
||||||
|
failed = False
|
||||||
|
|
||||||
|
self.log.info('POST to server')
|
||||||
|
res = await httpclient.post(join(self.api, 'render'), data=json.dumps(params))
|
||||||
|
res = res.json()
|
||||||
|
|
||||||
|
last_edit = 0
|
||||||
|
|
||||||
|
while 1:
|
||||||
|
step = await httpclient.get(join(self.api, res['stream'][1:]))
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = step.json()
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if 'step' in data:
|
||||||
|
if int(data['step'])%10 == 0:
|
||||||
|
self.log.info(f"Generation progress of {prompt['id']}: {data['step']}/{data['total_steps']}")
|
||||||
|
|
||||||
|
if time() - last_edit > 10:
|
||||||
|
await msg.edit(f"Generating prompt #{prompt['id']}, step {data['step']} of {data['total_steps']}. {time()-start:.1f}s elapsed.")
|
||||||
|
last_edit = time()
|
||||||
|
|
||||||
|
elif 'status' in data and data['status'] == 'failed':
|
||||||
|
await self.client.send_message(184151234, f"While generating #{prompt['id']}: {data['detail']}...")
|
||||||
|
await self.client.send_message(prompt['user_id'], f"While trying to generate your prompt we encountered an error: {data['detail']}\n\nThis might mean a bad combination of parameters, or issues on our sifde. We will retry a couple of times just in case.")
|
||||||
|
failed = True
|
||||||
|
self.client.conn.execute('UPDATE prompt SET is_error = 1, is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
|
||||||
|
break
|
||||||
|
elif 'status' in data and data['status'] == 'succeeded':
|
||||||
|
self.log.info('Success!')
|
||||||
|
images = []
|
||||||
|
for img in data['output']:
|
||||||
|
imgdata = BytesIO(b64decode(img['data'].split('base64,',1)[1]))
|
||||||
|
#imgdata.name = img['path_abs'].rsplit('/', 1)[-1]
|
||||||
|
imgdata.name = 'image.png'
|
||||||
|
imgdata.seek(0)
|
||||||
|
images.append(imgdata)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print(data)
|
||||||
|
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
|
||||||
|
await msg.delete()
|
||||||
|
|
||||||
|
if not failed:
|
||||||
|
asyncio.create_task(self.send_submission(prompt, images))
|
||||||
|
|
||||||
|
async def send_submission(self, prompt, images):
|
||||||
|
|
||||||
|
tg_files = []
|
||||||
|
|
||||||
|
for fn in images:
|
||||||
|
img = Image.open(fn)
|
||||||
|
img.thumbnail((1280,1280))
|
||||||
|
imgdata = BytesIO()
|
||||||
|
img.save(imgdata, format='JPEG')
|
||||||
|
siz = imgdata.seek(0,2)
|
||||||
|
imgdata.seek(0)
|
||||||
|
tg_files.append(await self.client.upload_file(imgdata, file_size=siz, file_name=fn.name))
|
||||||
|
|
||||||
|
results = await self.client.send_file(self.client.main_channel_id, tg_files, caption=
|
||||||
|
"\n".join([f"#{prompt['id']} · 🌀 {prompt['inference_steps']} · 🌱 {prompt['seed']} · 💎 {prompt['detail']}" + (f" · 🎚 {prompt['blend']} (seed from image)" if prompt['image'] != 'no_image' else ''),
|
||||||
|
#((f"🖼 https://e621.net/posts/{prompt['image_e6']}" if prompt['image_e6'] else 'user-uploaded image') if prompt['image'] != 'no_image' else ''),
|
||||||
|
('👍 ' if prompt['negative_prompt'] else '')+(f"Prompt from https://e621.net/posts/{prompt['prompt_e6']}" if prompt['prompt_e6'] else prompt['prompt']),
|
||||||
|
(f"\n👎 {prompt['negative_prompt']}" if prompt['negative_prompt'] else '')])[:1000], parse_mode='HTML')
|
||||||
|
|
||||||
|
await self.client.forward_messages(prompt['user_id'], results)
|
||||||
|
if prompt['hires'] == 'yes':
|
||||||
|
await self.client.send_message(prompt['user_id'], 'Uploading raw images... This will take a while')
|
||||||
|
tg_files = []
|
||||||
|
for img in images:
|
||||||
|
siz = img.seek(0,2)
|
||||||
|
img.seek(0)
|
||||||
|
tg_files.append(await self.client.upload_file(img, file_size=siz, file_name=img.name))
|
||||||
|
|
||||||
|
results = await self.client.send_file(self.client.main_channel_id, tg_files, force_document=True, caption=f"Raw images of #{prompt['id']}")
|
||||||
|
await self.client.forward_messages(prompt['user_id'], results)
|
||||||
|
|
||||||
|
self.log.info(f"Files for prompt #{prompt['id']} have been sent succesfully.")
|
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue