From 443f26f5fbc0fd66d8a17c190fa801a7e54bd62c Mon Sep 17 00:00:00 2001 From: nameless Date: Tue, 15 Aug 2023 18:28:01 +0000 Subject: [PATCH] rewrite process_queue.py for A1111 api - txt2img - img2img - progress polling --- process_queue.py | 119 ++++++++++++++++++++++------------------------- 1 file changed, 55 insertions(+), 64 deletions(-) diff --git a/process_queue.py b/process_queue.py index ca63147..fbd5693 100644 --- a/process_queue.py +++ b/process_queue.py @@ -12,21 +12,9 @@ from time import time from telethon.errors.rpcerrorlist import MessageNotModifiedError, UserIsBlockedError log = logging.getLogger('process') -temp_folder = '/home/ed/temp/' default_vars = { - "use_cpu":False, - "use_full_precision": False, - "stream_progress_updates": True, - "stream_image_progress": False, - "show_only_filtered_image": True, - "sampler_name": "dpm_solver_stability", - "save_to_disk_path": temp_folder, - "output_format": "png", - "use_stable_diffusion_model": "fluffyrock-576-704-832-960-1088-lion-low-lr-e61-terminal-snr-e34", - "metadata_output_format": "embed", - "use_hypernetwork_model": "boring_e621", - "hypernetwork_strength": 0.25, + "sampler_name": "DPM++ 2M Karras", } class Worker: @@ -42,6 +30,25 @@ class Worker: self.log = logging.getLogger(name) self.prompt = None + async def update_progress_bar(self, msg): + start = time() + last_edit = 0 + while True: + async with httpx.AsyncClient() as httpclient: + step = await httpclient.get(self.api + '/sdapi/v1/progress') + try: + data = step.json() + except: + continue + if data['state']['sampling_step']%10 == 0: + self.log.info(f"Generation progress of {self.prompt}: {data['state']['job']} | {data['state']['sampling_step']}/{data['state']['sampling_steps']}") + + if time() - last_edit > 3: + await msg.edit(f"Generating prompt #{self.prompt}, {int(data['progress'] * 100)}% done. {time()-start:.1f}s elapsed, {int(data['eta_relative'])}s remaining.") + last_edit = time() + + await asyncio.sleep(2) + def start(self, future=None): if not self.client.process_queue: @@ -67,12 +74,13 @@ class Worker: self.task.add_done_callback(self.start) async def process_prompt(self, prompt, queue_item): + endpoint = 'txt2img' # First of all, check if the user can still be messaged. async with httpx.AsyncClient() as httpclient: try: - await httpclient.get(join(self.api, 'ping'), timeout=5) + await httpclient.get(self.api + '/internal/ping', timeout=5) except Exception as e: print(str(e)) log.error('Server is dead. Waiting 10 seconds for server availability...') @@ -88,23 +96,23 @@ class Worker: # Prepare the parameters for the request params = default_vars.copy() - params['session_id'] = str(prompt['id']) params['prompt'] = prompt['prompt'] or '' - params['negative_prompt'] = prompt['negative_prompt'] or 'boring_e621_v4' - params['num_outputs'] = int(prompt['number']) - params['num_inference_steps'] = prompt['inference_steps'] - params['guidance_scale'] = prompt['detail'] + params['negative_prompt'] = str(prompt['negative_prompt']) + params['n_iter'] = int(prompt['number']) + #params['batch_size'] = int(prompt['number']) + params['steps'] = prompt['inference_steps'] + params['cfg_scale'] = prompt['detail'] params['width'] = prompt['resolution'].split('x')[0] params['height'] = prompt['resolution'].split('x')[1] params['seed'] = str(prompt['seed']) - params['vram_usage_level'] = 'low' if '-low' in self.name else ('medium' if '1024' in prompt['resolution'] else 'high') self.prompt = prompt['id'] - if prompt['hires'] == 'yes': - params['use_upscale'] = 'RealESRGAN_x4plus_anime_6B' + #if prompt['hires'] == 'yes': + # params['use_upscale'] = 'RealESRGAN_x4plus_anime_6B' if prompt['image'] != 'no_image': + endpoint = 'img2img' img = Image.open(prompt['image']) img = img.convert('RGB') if prompt['crop'] == 'no': @@ -115,60 +123,43 @@ class Worker: imgdata = BytesIO() img.save(imgdata, format='JPEG') - params['init_image'] = ('data:image/jpeg;base64,'+b64encode(imgdata.getvalue()).decode()).strip() - params['sampler_name'] = 'ddim' - params['prompt_strength'] = prompt['blend'] + params['init_images'] = [('data:image/jpeg;base64,'+b64encode(imgdata.getvalue()).decode()).strip()] + params['sampler_name'] = 'DDIM' + params['denoising_strength'] = prompt['blend'] async with httpx.AsyncClient() as httpclient: self.conn.execute('UPDATE prompt SET started_at = ? WHERE id = ?', (time(), prompt['id'])) - start = time() failed = False self.log.info('POST to server') - res = await httpclient.post(join(self.api, 'render'), data=json.dumps(params)) + progress_task = asyncio.create_task(self.update_progress_bar(msg)) + try: + res = await httpclient.post(self.api + f'/sdapi/v1/{endpoint}', data=json.dumps(params), timeout=300) + except Exception as e: + await self.client.send_message(self.client.admin_id, f"While generating #{prompt['id']}: {e}") + await self.client.send_message(prompt['user_id'], f"While trying to generate your prompt we encountered an error.\n\nThis might mean a bad combination of parameters, or issues on our side. We will retry a couple of times just in case.") + failed = True + self.client.conn.execute('UPDATE prompt SET is_error = 1, is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id'])) + return res = res.json() - last_edit = 0 + self.log.info('Success!') + images = [] + for img in res['images']: + imgdata = BytesIO(b64decode(img.split(",",1)[0])) + imgdata.name = 'image.png' + imgdata.seek(0) + images.append(imgdata) - while 1: - step = await httpclient.get(join(self.api, res['stream'][1:])) - - try: - data = step.json() - except: - continue - - if 'step' in data: - if int(data['step'])%10 == 0: - self.log.info(f"Generation progress of {prompt['id']}: {data['step']}/{data['total_steps']}") - - if time() - last_edit > 10: - await msg.edit(f"Generating prompt #{prompt['id']}, step {data['step']} of {data['total_steps']}. {time()-start:.1f}s elapsed.") - last_edit = time() - - elif 'status' in data and data['status'] == 'failed': - await self.client.send_message(184151234, f"While generating #{prompt['id']}: {data['detail']}...") - await self.client.send_message(prompt['user_id'], f"While trying to generate your prompt we encountered an error: {data['detail']}\n\nThis might mean a bad combination of parameters, or issues on our sifde. We will retry a couple of times just in case.") - failed = True - self.client.conn.execute('UPDATE prompt SET is_error = 1, is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id'])) - break - elif 'status' in data and data['status'] == 'succeeded': - self.log.info('Success!') - images = [] - for img in data['output']: - imgdata = BytesIO(b64decode(img['data'].split('base64,',1)[1])) - #imgdata.name = img['path_abs'].rsplit('/', 1)[-1] - imgdata.name = 'image.png' - imgdata.seek(0) - images.append(imgdata) - break - else: - print(data) - - await asyncio.sleep(2) + + progress_task.cancel() + try: + await progress_task + except asyncio.CancelledError: + pass self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id'])) await msg.delete()