# ai621/process_queue.py
#
# (listing metadata: 210 lines, 7.7 KiB, Python)

import asyncio
import html
import json
import logging
from base64 import b64encode, b64decode
from io import BytesIO
from os import listdir
from os.path import join
from time import time

import httpx
from PIL import Image
from PIL import ImageOps
from telethon.errors.rpcerrorlist import MessageNotModifiedError, UserIsBlockedError
log = logging.getLogger('process')

# Directory handed to the render server as its output location.
temp_folder = '/home/ed/temp/'

# Base payload for every render request. Worker.process_prompt copies this
# dict and then overlays the per-prompt fields (prompt text, size, seed, ...).
default_vars = dict(
    use_cpu=False,
    use_full_precision=False,
    stream_progress_updates=True,
    stream_image_progress=False,
    show_only_filtered_image=True,
    sampler_name="dpm_solver_stability",
    save_to_disk_path=temp_folder,
    output_format="png",
    use_stable_diffusion_model="fluffyrock-576-704-832-960-1088-lion-low-lr-e61-terminal-snr-e34",
    metadata_output_format="embed",
    use_hypernetwork_model="boring_e621",
    hypernetwork_strength=0.25,
)
class Worker:
    """Render queued prompts on one image-generation server and deliver the results.

    ``start`` pops one ``(priority, prompt_id)`` item from ``client.queue``,
    runs :meth:`process_prompt` for it in a task, and registers itself as that
    task's done-callback, so the next item is picked up when the task ends.
    """

    def __init__(self, api, client, name):
        """Bind the worker to a render server and a bot client.

        Args:
            api: Base URL of the render server; ``ping``, ``render`` and the
                stream path returned by ``render`` are joined onto it.
            client: Bot client. Its ``queue`` and ``conn`` are used directly;
                ``process_queue``, ``admin_id``, ``main_channel_id`` and the
                messaging/upload methods are used by the other methods.
            name: Worker name, also used as the logger name. A name containing
                ``-low`` selects the ``low`` VRAM usage level.
        """
        self.api = api
        self.ready = False  # NOTE(review): not read inside this class; presumably used by callers
        self.client = client
        self.queue = client.queue
        self.conn = client.conn
        self.loop = asyncio.get_event_loop()
        self.task = None
        self.name = name
        self.log = logging.getLogger(name)
        # Id of the prompt most recently taken off the queue.
        self.prompt = None
        # Strong references to fire-and-forget tasks: the event loop only keeps
        # weak references, so an unreferenced task may be garbage collected.
        self._background_tasks = set()

    def _spawn(self, coro):
        """Run *coro* as a background task, keeping it referenced and logging failures."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_done)
        return task

    def _on_background_done(self, task):
        """Drop the finished background *task* and log its exception, if any."""
        self._background_tasks.discard(task)
        if not task.cancelled() and task.exception() is not None:
            self.log.error('Background task failed', exc_info=task.exception())

    def start(self, future=None):
        """Pick the next prompt off the queue and process it in a new task.

        Also serves as the done-callback of that task. The worker keeps
        draining the queue until one of these happens: the queue is empty,
        ``client.process_queue`` is switched off, or a task fails or is
        cancelled.

        Args:
            future: The finished ``process_prompt`` task when invoked as a
                done-callback; ``None`` for the initial kick-off.
        """
        if not self.client.process_queue:
            self._spawn(self.client.send_message(self.client.admin_id, f"Loop of {self.name} has been stopped."))
            return
        if self.task and not self.task.done():
            return
        if future is not None:
            if future.cancelled():
                # future.exception() would raise CancelledError here; treat it as a stop.
                self.log.warning('Processing of #%s was cancelled; stopping the loop.', self.prompt)
                return
            exc = future.exception()
            if exc is not None:
                self.log.error('Processing of #%s failed', self.prompt, exc_info=exc)
                self.conn.execute('UPDATE prompt SET is_done = 1, is_error = 1, completed_at = ? WHERE id = ?', (time(), self.prompt))
                # The loop stops here (as before); tell the admin rather than dying silently.
                self._spawn(self.client.send_message(
                    self.client.admin_id,
                    f"Loop of {self.name} has been stopped after an error in #{self.prompt}: {exc!r}"))
                return
        try:
            priority, prompt_id = self.queue.get_nowait()
        except asyncio.QueueEmpty:
            self.log.info('No more tasks to process!')
            return
        # Record the id before the task starts. A failure early in
        # process_prompt is then attributed to this prompt, not the previous one.
        self.prompt = prompt_id
        prompt = self.conn.execute('SELECT prompt.* FROM prompt WHERE id = ?', (prompt_id,)).fetchone()
        self.log.info(f"Processing {prompt_id}")
        self.task = self.loop.create_task(self.process_prompt(prompt, (priority, prompt_id)))
        self.task.add_done_callback(self.start)

    def _build_params(self, prompt):
        """Return the JSON payload for the server's ``render`` endpoint for *prompt*."""
        dims = prompt['resolution'].split('x')
        params = default_vars.copy()
        params['session_id'] = str(prompt['id'])
        params['prompt'] = prompt['prompt'] or ''
        params['negative_prompt'] = prompt['negative_prompt'] or 'boring_e621_v4'
        params['num_outputs'] = int(prompt['number'])
        params['num_inference_steps'] = prompt['inference_steps']
        params['guidance_scale'] = prompt['detail']
        params['width'] = dims[0]
        params['height'] = dims[1]
        params['seed'] = str(prompt['seed'])
        if '-low' in self.name:
            params['vram_usage_level'] = 'low'
        elif '1024' in prompt['resolution']:
            params['vram_usage_level'] = 'medium'
        else:
            params['vram_usage_level'] = 'high'
        if prompt['hires'] == 'yes':
            params['use_upscale'] = 'RealESRGAN_x4plus_anime_6B'
        if prompt['image'] != 'no_image':
            size = (int(dims[0]), int(dims[1]))
            # Context manager releases the source file handle once converted.
            with Image.open(prompt['image']) as src:
                img = src.convert('RGB')
            img = img.resize(size) if prompt['crop'] == 'no' else ImageOps.fit(img, size)
            buf = BytesIO()
            img.save(buf, format='JPEG')
            params['init_image'] = ('data:image/jpeg;base64,' + b64encode(buf.getvalue()).decode()).strip()
            params['sampler_name'] = 'ddim'
            params['prompt_strength'] = prompt['blend']
        return params

    async def process_prompt(self, prompt, queue_item):
        """Render *prompt* on the server, streaming progress to its author.

        Args:
            prompt: Row of the ``prompt`` table, indexed by column name.
            queue_item: The ``(priority, prompt_id)`` item this prompt came
                from; it is put back on the queue if the server ping fails.

        Raises:
            httpx.HTTPStatusError: if the ``render`` request returns an error status.
        """
        self.prompt = prompt['id']
        async with httpx.AsyncClient() as httpclient:
            # First of all, check that the render server is reachable.
            try:
                await httpclient.get(join(self.api, 'ping'), timeout=5)
            except Exception as e:
                self.log.warning('Ping to %s failed: %s', self.api, e)
                self.log.error('Server is dead. Waiting 10 seconds for server availability...')
                await self.queue.put(queue_item)
                await asyncio.sleep(10)
                return
        # Then check that the user can still be messaged.
        try:
            msg = await self.client.send_message(prompt['user_id'], "Hello 👀 Generating your prompt now.")
        except UserIsBlockedError:
            self.conn.execute('UPDATE prompt SET is_done = 1, is_error = 1 WHERE id = ?', (prompt['id'],))
            return
        params = self._build_params(prompt)
        failed = False
        images = []
        async with httpx.AsyncClient() as httpclient:
            self.conn.execute('UPDATE prompt SET started_at = ? WHERE id = ?', (time(), prompt['id']))
            start = time()
            self.log.info('POST to server')
            res = await httpclient.post(join(self.api, 'render'), data=json.dumps(params))
            res.raise_for_status()
            # Drop the first character (presumably '/'): os.path.join would
            # otherwise treat the stream path as absolute and discard self.api.
            stream_url = join(self.api, res.json()['stream'][1:])
            last_edit = 0
            while True:
                step = await httpclient.get(stream_url)
                try:
                    data = step.json()
                except ValueError:
                    # Empty/partial body: poll again shortly instead of hammering the server.
                    await asyncio.sleep(0.5)
                    continue
                if 'step' in data:
                    if int(data['step']) % 10 == 0:
                        self.log.info(f"Generation progress of {prompt['id']}: {data['step']}/{data['total_steps']}")
                    # Throttle progress edits to one every 10 seconds.
                    if time() - last_edit > 10:
                        try:
                            await msg.edit(f"Generating prompt #{prompt['id']}, step {data['step']} of {data['total_steps']}. {time()-start:.1f}s elapsed.")
                        except MessageNotModifiedError:
                            pass
                        last_edit = time()
                elif 'status' in data and data['status'] == 'failed':
                    # NOTE(review): hard-coded chat id, while start() notifies
                    # self.client.admin_id -- confirm these are meant to differ.
                    await self.client.send_message(184151234, f"While generating #{prompt['id']}: {data['detail']}...")
                    await self.client.send_message(prompt['user_id'], f"While trying to generate your prompt we encountered an error: {data['detail']}\n\nThis might mean a bad combination of parameters, or issues on our side. We will retry a couple of times just in case.")
                    failed = True
                    self.conn.execute('UPDATE prompt SET is_error = 1, is_done = 1, completed_at = ? WHERE id = ?', (time(), prompt['id']))
                    break
                elif 'status' in data and data['status'] == 'succeeded':
                    self.log.info('Success!')
                    for out in data['output']:
                        # Each output carries a data URL; keep only the base64 payload.
                        imgdata = BytesIO(b64decode(out['data'].split('base64,', 1)[1]))
                        imgdata.name = 'image.png'
                        images.append(imgdata)
                    break
                else:
                    self.log.info('Unhandled stream message: %s', data)
                await asyncio.sleep(2)
        self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?', (time(), prompt['id']))
        await msg.delete()
        if not failed:
            self._spawn(self.send_submission(prompt, images))

    @staticmethod
    def _build_caption(prompt, limit=1000):
        """Return the HTML-safe channel caption for *prompt*.

        The plain text is truncated to *limit* characters *before* escaping.
        An entity such as ``&amp;`` is therefore never cut in half, and the
        visible length still stays within *limit*.
        """
        header = f"#{prompt['id']} · 🌀 {prompt['inference_steps']} · 🌱 {prompt['seed']} · 💎 {prompt['detail']}"
        if prompt['image'] != 'no_image':
            header += f" · 🎚 {prompt['blend']} (seed from image)"
        negative = prompt['negative_prompt']
        if prompt['prompt_e6']:
            body = f"Prompt from https://e621.net/posts/{prompt['prompt_e6']}"
        else:
            body = prompt['prompt'] or ''
        if negative:
            body = '👍 ' + body
        footer = f"\n👎 {negative}" if negative else ''
        text = "\n".join([header, body, footer])[:limit]
        # Prompts often contain '<', '>' or '&' (e.g. "<lora:x:1>"), which the
        # HTML parse mode would otherwise reject or interpret as markup.
        return html.escape(text, quote=False)

    async def send_submission(self, prompt, images):
        """Post rendered *images* to the main channel and forward them to the user.

        Each image is downscaled to fit 1280x1280 and re-encoded as JPEG for
        the preview post. For ``hires`` prompts the untouched files are also
        sent as documents.

        Args:
            prompt: Row of the ``prompt`` table, indexed by column name.
            images: ``BytesIO`` objects holding the decoded server output, each
                with a ``name`` attribute.
        """
        if not images:
            self.log.warning(f"Prompt #{prompt['id']} produced no images; nothing to send.")
            return
        tg_files = []
        for raw in images:
            # convert('RGB') guards against alpha channels, which JPEG cannot store.
            preview = Image.open(raw).convert('RGB')
            preview.thumbnail((1280, 1280))
            buf = BytesIO()
            preview.save(buf, format='JPEG')
            size = buf.seek(0, 2)
            buf.seek(0)
            tg_files.append(await self.client.upload_file(buf, file_size=size, file_name=raw.name))
        results = await self.client.send_file(self.client.main_channel_id, tg_files, caption=self._build_caption(prompt), parse_mode='HTML')
        await self.client.forward_messages(prompt['user_id'], results)
        if prompt['hires'] == 'yes':
            await self.client.send_message(prompt['user_id'], 'Uploading raw images... This will take a while')
            raw_files = []
            for raw in images:
                size = raw.seek(0, 2)
                raw.seek(0)
                raw_files.append(await self.client.upload_file(raw, file_size=size, file_name=raw.name))
            results = await self.client.send_file(self.client.main_channel_id, raw_files, force_document=True, caption=f"Raw images of #{prompt['id']}")
            await self.client.forward_messages(prompt['user_id'], results)
        self.log.info(f"Files for prompt #{prompt['id']} have been sent successfully.")