import asyncio
import html
import json
import logging
from base64 import b64encode, b64decode
from io import BytesIO
from os import listdir
from os.path import join
from time import time

import httpx
from PIL import Image
from PIL import ImageOps
from telethon.errors.rpcerrorlist import MessageNotModifiedError, UserIsBlockedError

log = logging.getLogger('process')

temp_folder = '/home/ed/temp/'

# Base request body for the render endpoint; process_prompt copies it and
# fills in the per-prompt fields.
default_vars = {
    "use_cpu": False,
    "use_full_precision": False,
    "stream_progress_updates": True,
    "stream_image_progress": False,
    "show_only_filtered_image": True,
    "sampler_name": "dpm_solver_stability",
    "save_to_disk_path": temp_folder,
    "output_format": "png",
    "use_stable_diffusion_model": "fluffyrock-576-704-832-960-1088-lion-low-lr-e61-terminal-snr-e34",
    "metadata_output_format": "embed",
    "use_hypernetwork_model": "boring_e621",
    "hypernetwork_strength": 0.25,
}

# Chat that receives render-failure reports.
# NOTE(review): start() notifies self.client.admin_id instead; confirm whether
# these are meant to be the same chat.
ERROR_REPORT_CHAT_ID = 184151234

# Seconds to wait between polls of the render progress stream.
POLL_INTERVAL = 2


class Worker:
    """Take prompts off the client's queue and render them one at a time.

    ``start()`` pops the next ``(priority, prompt_id)`` item from
    ``client.queue``, runs ``process_prompt`` for it as a task, and registers
    itself as that task's done-callback, so the queue keeps draining until it
    is empty, ``client.process_queue`` is turned off, or a task fails.
    """

    def __init__(self, api, client, name):
        self.api = api
        self.ready = False
        self.client = client
        self.queue = client.queue
        self.conn = client.conn
        self.loop = asyncio.get_event_loop()
        self.task = None
        self.name = name
        self.log = logging.getLogger(name)
        # Id of the prompt currently being processed (used to flag it on failure).
        self.prompt = None
        # Strong references to fire-and-forget tasks: the event loop keeps only
        # weak references, so an unreferenced task may vanish mid-execution.
        self._background_tasks = set()

    def _spawn(self, coro):
        """Run *coro* as a background task, keeping it alive and logging failures."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_done)
        return task

    def _background_done(self, task):
        """Done-callback for _spawn: drop the reference and surface any exception."""
        self._background_tasks.discard(task)
        if not task.cancelled() and task.exception() is not None:
            self.log.error('Background task failed', exc_info=task.exception())

    def start(self, future=None):
        """Begin processing the next queued prompt, if any.

        Also serves as the done-callback of each processing task, in which case
        *future* is the task that just finished. A failed task marks its prompt
        as done and stops the chain until start() is called again.
        """
        if not self.client.process_queue:
            self._spawn(self.client.send_message(self.client.admin_id, f"Loop of {self.name} has been stopped."))
            return
        if self.task and not self.task.done():
            return
        if future is not None:
            # future.exception() would raise CancelledError on a cancelled task.
            if future.cancelled():
                self.log.warning(f"Task for prompt {self.prompt} was cancelled.")
                return
            exc = future.exception()
            if exc is not None:
                self.log.error(f"Processing prompt {self.prompt} failed: {exc!r}", exc_info=exc)
                self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?', (time(), self.prompt))
                return
        try:
            priority, prompt_id = self.queue.get_nowait()
        except asyncio.QueueEmpty:
            self.log.info('No more tasks to process!')
            return
        prompt = self.conn.execute('SELECT prompt.* FROM prompt WHERE id = ?', (prompt_id,)).fetchone()
        self.log.info(f"Processing {prompt_id}")
        # Record the id before the task runs so any failure is attributed to
        # this prompt rather than the previous one.
        self.prompt = prompt_id
        self.task = self.loop.create_task(self.process_prompt(prompt, (priority, prompt_id)))
        self.task.add_done_callback(self.start)

    async def process_prompt(self, prompt, queue_item):
        """Render one prompt on the backend and schedule delivery of the results.

        prompt: the prompt's database row, indexed by column name.
        queue_item: the (priority, prompt_id) tuple taken from the queue; it is
            put back on the queue when the backend does not respond.
        """
        self.prompt = prompt['id']

        # First of all, make sure the rendering server is reachable.
        if not await self._server_alive():
            self.log.error('Server is dead. Waiting 10 seconds for server availability...')
            await self.queue.put(queue_item)
            await asyncio.sleep(10)
            return

        try:
            msg = await self.client.send_message(prompt['user_id'], "Hello 👀 Generating your prompt now.")
        except UserIsBlockedError:
            self.conn.execute('UPDATE prompt SET is_done = 1, is_error = 1 WHERE id = ?', (prompt['id'],))
            return

        params = self._build_params(prompt)

        async with httpx.AsyncClient() as httpclient:
            self.conn.execute('UPDATE prompt SET started_at = ? WHERE id = ?', (time(), prompt['id']))
            started = time()
            self.log.info('POST to server')
            res = await httpclient.post(join(self.api, 'render'), data=json.dumps(params))
            res = res.json()
            images = await self._poll_render(httpclient, res['stream'][1:], prompt, msg, started)

        self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?', (time(), prompt['id']))
        await msg.delete()
        if images is not None:
            self._spawn(self.send_submission(prompt, images))

    async def _server_alive(self):
        """Return True if ``GET <api>/ping`` gets any HTTP response within 5 seconds."""
        async with httpx.AsyncClient() as httpclient:
            try:
                await httpclient.get(join(self.api, 'ping'), timeout=5)
            except Exception as e:  # health check: any transport failure means "down"
                self.log.error(f"Ping to {self.api} failed: {e}")
                return False
        return True

    def _build_params(self, prompt):
        """Translate a prompt row into the JSON body for the render endpoint."""
        width, height = prompt['resolution'].split('x')[:2]
        params = default_vars.copy()
        params['session_id'] = str(prompt['id'])
        params['prompt'] = prompt['prompt'] or ''
        params['negative_prompt'] = prompt['negative_prompt'] or 'boring_e621_v4'
        params['num_outputs'] = int(prompt['number'])
        params['num_inference_steps'] = prompt['inference_steps']
        params['guidance_scale'] = prompt['detail']
        params['width'] = width
        params['height'] = height
        params['seed'] = str(prompt['seed'])
        if '-low' in self.name:
            params['vram_usage_level'] = 'low'
        elif '1024' in prompt['resolution']:
            params['vram_usage_level'] = 'medium'
        else:
            params['vram_usage_level'] = 'high'
        if prompt['hires'] == 'yes':
            params['use_upscale'] = 'RealESRGAN_x4plus_anime_6B'
        if prompt['image'] != 'no_image':
            size = (int(width), int(height))
            params['init_image'] = self._encode_init_image(prompt['image'], size, crop=prompt['crop'] != 'no')
            params['sampler_name'] = 'ddim'
            params['prompt_strength'] = prompt['blend']
        return params

    @staticmethod
    def _encode_init_image(path, size, crop):
        """Load the image at *path*, fit it to *size* and return it as a JPEG data URL.

        With crop=True the image is cropped to the target aspect ratio
        (ImageOps.fit); otherwise it is stretched with a plain resize.
        """
        # Close the source file as soon as its pixels are copied out.
        with Image.open(path) as src:
            img = src.convert('RGB')
        img = ImageOps.fit(img, size) if crop else img.resize(size)
        buf = BytesIO()
        img.save(buf, format='JPEG')
        return 'data:image/jpeg;base64,' + b64encode(buf.getvalue()).decode()

    async def _poll_render(self, httpclient, stream_path, prompt, msg, started):
        """Poll the render progress stream until the job succeeds or fails.

        Edits *msg* with progress at most every 10 seconds. Returns the decoded
        output images on success. On failure it notifies the report chat and
        the user, flags the prompt as errored, and returns None.
        """
        last_edit = 0
        while True:
            step = await httpclient.get(join(self.api, stream_path))
            try:
                data = step.json()
            except ValueError:
                # Partial / non-JSON response: wait before polling again instead
                # of immediately re-requesting in a tight loop.
                await asyncio.sleep(POLL_INTERVAL)
                continue
            if 'step' in data:
                if int(data['step']) % 10 == 0:
                    self.log.info(f"Generation progress of {prompt['id']}: {data['step']}/{data['total_steps']}")
                if time() - last_edit > 10:
                    try:
                        await msg.edit(f"Generating prompt #{prompt['id']}, step {data['step']} of {data['total_steps']}. {time()-started:.1f}s elapsed.")
                    except MessageNotModifiedError:
                        pass  # progress text unchanged; nothing to update
                    last_edit = time()
            elif 'status' in data and data['status'] == 'failed':
                await self.client.send_message(ERROR_REPORT_CHAT_ID, f"While generating #{prompt['id']}: {data['detail']}...")
                # NOTE(review): this message promises retries, but the prompt is
                # marked done here; confirm a retry mechanism exists elsewhere.
                await self.client.send_message(prompt['user_id'], f"While trying to generate your prompt we encountered an error: {data['detail']}\n\nThis might mean a bad combination of parameters, or issues on our side. We will retry a couple of times just in case.")
                self.conn.execute('UPDATE prompt SET is_error = 1, is_done = 1, completed_at = ? WHERE id = ?', (time(), prompt['id']))
                return None
            elif 'status' in data and data['status'] == 'succeeded':
                self.log.info('Success!')
                return [self._decode_output(output) for output in data['output']]
            else:
                self.log.debug(f"Unhandled stream response: {data}")
            await asyncio.sleep(POLL_INTERVAL)

    @staticmethod
    def _decode_output(output):
        """Decode one backend output (``data`` is a base64 data URL) into a named BytesIO."""
        imgdata = BytesIO(b64decode(output['data'].split('base64,', 1)[1]))
        imgdata.name = 'image.png'
        return imgdata

    async def _upload(self, buf, file_name):
        """Upload the full contents of BytesIO *buf* and return Telegram's file handle."""
        size = buf.seek(0, 2)
        buf.seek(0)
        return await self.client.upload_file(buf, file_size=size, file_name=file_name)

    @staticmethod
    def _jpeg_preview(src):
        """Return a BytesIO holding *src* downscaled to fit 1280x1280, encoded as JPEG."""
        img = Image.open(src)
        img.thumbnail((1280, 1280))
        if img.mode not in ('RGB', 'L'):
            # JPEG cannot store modes such as RGBA or P; saving them raises.
            img = img.convert('RGB')
        out = BytesIO()
        img.save(out, format='JPEG')
        return out

    @staticmethod
    def _caption(prompt):
        """Build the channel caption for *prompt*, HTML-escaped and capped at 1000 visible chars."""
        header = f"#{prompt['id']} · 🌀 {prompt['inference_steps']} · 🌱 {prompt['seed']} · 💎 {prompt['detail']}"
        if prompt['image'] != 'no_image':
            header += f" · 🎚 {prompt['blend']} (seed from image)"
        if prompt['prompt_e6']:
            positive = f"Prompt from https://e621.net/posts/{prompt['prompt_e6']}"
        else:
            positive = prompt['prompt'] or ''
        negative = prompt['negative_prompt']
        caption = "\n".join([
            header,
            ('👍 ' if negative else '') + positive,
            f"\n👎 {negative}" if negative else '',
        ])[:1000]
        # The caption contains no intended markup, but it is sent with
        # parse_mode='HTML', so user text (e.g. the e621 tag "<3") must be escaped.
        # Escaping after truncation keeps the visible text within 1000 chars.
        return html.escape(caption, quote=False)

    async def send_submission(self, prompt, images):
        """Post rendered images to the main channel and forward them to the user.

        images: BytesIO objects with a ``name`` attribute. Downscaled JPEG
        previews are always sent; for hires prompts the original files are also
        sent as documents.
        """
        tg_files = [await self._upload(self._jpeg_preview(fn), fn.name) for fn in images]
        results = await self.client.send_file(self.client.main_channel_id, tg_files, caption=self._caption(prompt), parse_mode='HTML')
        await self.client.forward_messages(prompt['user_id'], results)
        if prompt['hires'] == 'yes':
            await self.client.send_message(prompt['user_id'], 'Uploading raw images... This will take a while')
            tg_files = [await self._upload(img, img.name) for img in images]
            results = await self.client.send_file(self.client.main_channel_id, tg_files, force_document=True, caption=f"Raw images of #{prompt['id']}")
            await self.client.forward_messages(prompt['user_id'], results)
        self.log.info(f"Files for prompt #{prompt['id']} have been sent successfully.")