rewrite process_queue.py for A1111 api

- txt2img
- img2img 
- progress polling
This commit is contained in:
nameless 2023-08-15 18:28:01 +00:00
parent 862bad5d48
commit 443f26f5fb
1 changed files with 55 additions and 64 deletions

View File

@ -12,21 +12,9 @@ from time import time
from telethon.errors.rpcerrorlist import MessageNotModifiedError, UserIsBlockedError from telethon.errors.rpcerrorlist import MessageNotModifiedError, UserIsBlockedError
log = logging.getLogger('process') log = logging.getLogger('process')
temp_folder = '/home/ed/temp/'
default_vars = { default_vars = {
"use_cpu":False, "sampler_name": "DPM++ 2M Karras",
"use_full_precision": False,
"stream_progress_updates": True,
"stream_image_progress": False,
"show_only_filtered_image": True,
"sampler_name": "dpm_solver_stability",
"save_to_disk_path": temp_folder,
"output_format": "png",
"use_stable_diffusion_model": "fluffyrock-576-704-832-960-1088-lion-low-lr-e61-terminal-snr-e34",
"metadata_output_format": "embed",
"use_hypernetwork_model": "boring_e621",
"hypernetwork_strength": 0.25,
} }
class Worker: class Worker:
@ -42,6 +30,25 @@ class Worker:
self.log = logging.getLogger(name) self.log = logging.getLogger(name)
self.prompt = None self.prompt = None
async def update_progress_bar(self, msg):
start = time()
last_edit = 0
while True:
async with httpx.AsyncClient() as httpclient:
step = await httpclient.get(self.api + '/sdapi/v1/progress')
try:
data = step.json()
except:
continue
if data['state']['sampling_step']%10 == 0:
self.log.info(f"Generation progress of {self.prompt}: {data['state']['job']} | {data['state']['sampling_step']}/{data['state']['sampling_steps']}")
if time() - last_edit > 3:
await msg.edit(f"Generating prompt #{self.prompt}, {int(data['progress'] * 100)}% done. {time()-start:.1f}s elapsed, {int(data['eta_relative'])}s remaining.")
last_edit = time()
await asyncio.sleep(2)
def start(self, future=None): def start(self, future=None):
if not self.client.process_queue: if not self.client.process_queue:
@ -67,12 +74,13 @@ class Worker:
self.task.add_done_callback(self.start) self.task.add_done_callback(self.start)
async def process_prompt(self, prompt, queue_item): async def process_prompt(self, prompt, queue_item):
endpoint = 'txt2img'
# First of all, check if the user can still be messaged. # First of all, check if the user can still be messaged.
async with httpx.AsyncClient() as httpclient: async with httpx.AsyncClient() as httpclient:
try: try:
await httpclient.get(join(self.api, 'ping'), timeout=5) await httpclient.get(self.api + '/internal/ping', timeout=5)
except Exception as e: except Exception as e:
print(str(e)) print(str(e))
log.error('Server is dead. Waiting 10 seconds for server availability...') log.error('Server is dead. Waiting 10 seconds for server availability...')
@ -88,23 +96,23 @@ class Worker:
# Prepare the parameters for the request # Prepare the parameters for the request
params = default_vars.copy() params = default_vars.copy()
params['session_id'] = str(prompt['id'])
params['prompt'] = prompt['prompt'] or '' params['prompt'] = prompt['prompt'] or ''
params['negative_prompt'] = prompt['negative_prompt'] or 'boring_e621_v4' params['negative_prompt'] = str(prompt['negative_prompt'])
params['num_outputs'] = int(prompt['number']) params['n_iter'] = int(prompt['number'])
params['num_inference_steps'] = prompt['inference_steps'] #params['batch_size'] = int(prompt['number'])
params['guidance_scale'] = prompt['detail'] params['steps'] = prompt['inference_steps']
params['cfg_scale'] = prompt['detail']
params['width'] = prompt['resolution'].split('x')[0] params['width'] = prompt['resolution'].split('x')[0]
params['height'] = prompt['resolution'].split('x')[1] params['height'] = prompt['resolution'].split('x')[1]
params['seed'] = str(prompt['seed']) params['seed'] = str(prompt['seed'])
params['vram_usage_level'] = 'low' if '-low' in self.name else ('medium' if '1024' in prompt['resolution'] else 'high')
self.prompt = prompt['id'] self.prompt = prompt['id']
if prompt['hires'] == 'yes': #if prompt['hires'] == 'yes':
params['use_upscale'] = 'RealESRGAN_x4plus_anime_6B' # params['use_upscale'] = 'RealESRGAN_x4plus_anime_6B'
if prompt['image'] != 'no_image': if prompt['image'] != 'no_image':
endpoint = 'img2img'
img = Image.open(prompt['image']) img = Image.open(prompt['image'])
img = img.convert('RGB') img = img.convert('RGB')
if prompt['crop'] == 'no': if prompt['crop'] == 'no':
@ -115,60 +123,43 @@ class Worker:
imgdata = BytesIO() imgdata = BytesIO()
img.save(imgdata, format='JPEG') img.save(imgdata, format='JPEG')
params['init_image'] = ('data:image/jpeg;base64,'+b64encode(imgdata.getvalue()).decode()).strip() params['init_images'] = [('data:image/jpeg;base64,'+b64encode(imgdata.getvalue()).decode()).strip()]
params['sampler_name'] = 'ddim' params['sampler_name'] = 'DDIM'
params['prompt_strength'] = prompt['blend'] params['denoising_strength'] = prompt['blend']
async with httpx.AsyncClient() as httpclient: async with httpx.AsyncClient() as httpclient:
self.conn.execute('UPDATE prompt SET started_at = ? WHERE id = ?', (time(), prompt['id'])) self.conn.execute('UPDATE prompt SET started_at = ? WHERE id = ?', (time(), prompt['id']))
start = time()
failed = False failed = False
self.log.info('POST to server') self.log.info('POST to server')
res = await httpclient.post(join(self.api, 'render'), data=json.dumps(params)) progress_task = asyncio.create_task(self.update_progress_bar(msg))
try:
res = await httpclient.post(self.api + f'/sdapi/v1/{endpoint}', data=json.dumps(params), timeout=300)
except Exception as e:
await self.client.send_message(self.client.admin_id, f"While generating #{prompt['id']}: {e}")
await self.client.send_message(prompt['user_id'], f"While trying to generate your prompt we encountered an error.\n\nThis might mean a bad combination of parameters, or issues on our side. We will retry a couple of times just in case.")
failed = True
self.client.conn.execute('UPDATE prompt SET is_error = 1, is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
return
res = res.json() res = res.json()
last_edit = 0 self.log.info('Success!')
images = []
for img in res['images']:
imgdata = BytesIO(b64decode(img.split(",",1)[0]))
imgdata.name = 'image.png'
imgdata.seek(0)
images.append(imgdata)
while 1:
step = await httpclient.get(join(self.api, res['stream'][1:]))
try:
data = step.json()
except:
continue
if 'step' in data:
if int(data['step'])%10 == 0:
self.log.info(f"Generation progress of {prompt['id']}: {data['step']}/{data['total_steps']}")
if time() - last_edit > 10:
await msg.edit(f"Generating prompt #{prompt['id']}, step {data['step']} of {data['total_steps']}. {time()-start:.1f}s elapsed.")
last_edit = time()
elif 'status' in data and data['status'] == 'failed':
await self.client.send_message(184151234, f"While generating #{prompt['id']}: {data['detail']}...")
await self.client.send_message(prompt['user_id'], f"While trying to generate your prompt we encountered an error: {data['detail']}\n\nThis might mean a bad combination of parameters, or issues on our sifde. We will retry a couple of times just in case.")
failed = True
self.client.conn.execute('UPDATE prompt SET is_error = 1, is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
break
elif 'status' in data and data['status'] == 'succeeded':
self.log.info('Success!')
images = []
for img in data['output']:
imgdata = BytesIO(b64decode(img['data'].split('base64,',1)[1]))
#imgdata.name = img['path_abs'].rsplit('/', 1)[-1]
imgdata.name = 'image.png'
imgdata.seek(0)
images.append(imgdata)
break
else:
print(data)
await asyncio.sleep(2)
progress_task.cancel()
try:
await progress_task
except asyncio.CancelledError:
pass
self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id'])) self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
await msg.delete() await msg.delete()