rewrite process_queue.py for A1111 api
- txt2img - img2img - progress polling
This commit is contained in:
parent
862bad5d48
commit
443f26f5fb
117
process_queue.py
117
process_queue.py
|
@ -12,21 +12,9 @@ from time import time
|
|||
from telethon.errors.rpcerrorlist import MessageNotModifiedError, UserIsBlockedError
|
||||
|
||||
log = logging.getLogger('process')
|
||||
temp_folder = '/home/ed/temp/'
|
||||
|
||||
default_vars = {
|
||||
"use_cpu":False,
|
||||
"use_full_precision": False,
|
||||
"stream_progress_updates": True,
|
||||
"stream_image_progress": False,
|
||||
"show_only_filtered_image": True,
|
||||
"sampler_name": "dpm_solver_stability",
|
||||
"save_to_disk_path": temp_folder,
|
||||
"output_format": "png",
|
||||
"use_stable_diffusion_model": "fluffyrock-576-704-832-960-1088-lion-low-lr-e61-terminal-snr-e34",
|
||||
"metadata_output_format": "embed",
|
||||
"use_hypernetwork_model": "boring_e621",
|
||||
"hypernetwork_strength": 0.25,
|
||||
"sampler_name": "DPM++ 2M Karras",
|
||||
}
|
||||
|
||||
class Worker:
|
||||
|
@ -42,6 +30,25 @@ class Worker:
|
|||
self.log = logging.getLogger(name)
|
||||
self.prompt = None
|
||||
|
||||
async def update_progress_bar(self, msg):
|
||||
start = time()
|
||||
last_edit = 0
|
||||
while True:
|
||||
async with httpx.AsyncClient() as httpclient:
|
||||
step = await httpclient.get(self.api + '/sdapi/v1/progress')
|
||||
try:
|
||||
data = step.json()
|
||||
except:
|
||||
continue
|
||||
if data['state']['sampling_step']%10 == 0:
|
||||
self.log.info(f"Generation progress of {self.prompt}: {data['state']['job']} | {data['state']['sampling_step']}/{data['state']['sampling_steps']}")
|
||||
|
||||
if time() - last_edit > 3:
|
||||
await msg.edit(f"Generating prompt #{self.prompt}, {int(data['progress'] * 100)}% done. {time()-start:.1f}s elapsed, {int(data['eta_relative'])}s remaining.")
|
||||
last_edit = time()
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
def start(self, future=None):
|
||||
|
||||
if not self.client.process_queue:
|
||||
|
@ -67,12 +74,13 @@ class Worker:
|
|||
self.task.add_done_callback(self.start)
|
||||
|
||||
async def process_prompt(self, prompt, queue_item):
|
||||
endpoint = 'txt2img'
|
||||
|
||||
# First of all, check if the user can still be messaged.
|
||||
|
||||
async with httpx.AsyncClient() as httpclient:
|
||||
try:
|
||||
await httpclient.get(join(self.api, 'ping'), timeout=5)
|
||||
await httpclient.get(self.api + '/internal/ping', timeout=5)
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
log.error('Server is dead. Waiting 10 seconds for server availability...')
|
||||
|
@ -88,23 +96,23 @@ class Worker:
|
|||
|
||||
# Prepare the parameters for the request
|
||||
params = default_vars.copy()
|
||||
params['session_id'] = str(prompt['id'])
|
||||
params['prompt'] = prompt['prompt'] or ''
|
||||
params['negative_prompt'] = prompt['negative_prompt'] or 'boring_e621_v4'
|
||||
params['num_outputs'] = int(prompt['number'])
|
||||
params['num_inference_steps'] = prompt['inference_steps']
|
||||
params['guidance_scale'] = prompt['detail']
|
||||
params['negative_prompt'] = str(prompt['negative_prompt'])
|
||||
params['n_iter'] = int(prompt['number'])
|
||||
#params['batch_size'] = int(prompt['number'])
|
||||
params['steps'] = prompt['inference_steps']
|
||||
params['cfg_scale'] = prompt['detail']
|
||||
params['width'] = prompt['resolution'].split('x')[0]
|
||||
params['height'] = prompt['resolution'].split('x')[1]
|
||||
params['seed'] = str(prompt['seed'])
|
||||
params['vram_usage_level'] = 'low' if '-low' in self.name else ('medium' if '1024' in prompt['resolution'] else 'high')
|
||||
|
||||
self.prompt = prompt['id']
|
||||
|
||||
if prompt['hires'] == 'yes':
|
||||
params['use_upscale'] = 'RealESRGAN_x4plus_anime_6B'
|
||||
#if prompt['hires'] == 'yes':
|
||||
# params['use_upscale'] = 'RealESRGAN_x4plus_anime_6B'
|
||||
|
||||
if prompt['image'] != 'no_image':
|
||||
endpoint = 'img2img'
|
||||
img = Image.open(prompt['image'])
|
||||
img = img.convert('RGB')
|
||||
if prompt['crop'] == 'no':
|
||||
|
@ -115,60 +123,43 @@ class Worker:
|
|||
imgdata = BytesIO()
|
||||
img.save(imgdata, format='JPEG')
|
||||
|
||||
params['init_image'] = ('data:image/jpeg;base64,'+b64encode(imgdata.getvalue()).decode()).strip()
|
||||
params['sampler_name'] = 'ddim'
|
||||
params['prompt_strength'] = prompt['blend']
|
||||
params['init_images'] = [('data:image/jpeg;base64,'+b64encode(imgdata.getvalue()).decode()).strip()]
|
||||
params['sampler_name'] = 'DDIM'
|
||||
params['denoising_strength'] = prompt['blend']
|
||||
|
||||
async with httpx.AsyncClient() as httpclient:
|
||||
|
||||
self.conn.execute('UPDATE prompt SET started_at = ? WHERE id = ?', (time(), prompt['id']))
|
||||
|
||||
start = time()
|
||||
failed = False
|
||||
|
||||
self.log.info('POST to server')
|
||||
res = await httpclient.post(join(self.api, 'render'), data=json.dumps(params))
|
||||
progress_task = asyncio.create_task(self.update_progress_bar(msg))
|
||||
try:
|
||||
res = await httpclient.post(self.api + f'/sdapi/v1/{endpoint}', data=json.dumps(params), timeout=300)
|
||||
except Exception as e:
|
||||
await self.client.send_message(self.client.admin_id, f"While generating #{prompt['id']}: {e}")
|
||||
await self.client.send_message(prompt['user_id'], f"While trying to generate your prompt we encountered an error.\n\nThis might mean a bad combination of parameters, or issues on our side. We will retry a couple of times just in case.")
|
||||
failed = True
|
||||
self.client.conn.execute('UPDATE prompt SET is_error = 1, is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
|
||||
return
|
||||
res = res.json()
|
||||
|
||||
last_edit = 0
|
||||
self.log.info('Success!')
|
||||
images = []
|
||||
for img in res['images']:
|
||||
imgdata = BytesIO(b64decode(img.split(",",1)[0]))
|
||||
imgdata.name = 'image.png'
|
||||
imgdata.seek(0)
|
||||
images.append(imgdata)
|
||||
|
||||
while 1:
|
||||
step = await httpclient.get(join(self.api, res['stream'][1:]))
|
||||
|
||||
try:
|
||||
data = step.json()
|
||||
except:
|
||||
continue
|
||||
|
||||
if 'step' in data:
|
||||
if int(data['step'])%10 == 0:
|
||||
self.log.info(f"Generation progress of {prompt['id']}: {data['step']}/{data['total_steps']}")
|
||||
|
||||
if time() - last_edit > 10:
|
||||
await msg.edit(f"Generating prompt #{prompt['id']}, step {data['step']} of {data['total_steps']}. {time()-start:.1f}s elapsed.")
|
||||
last_edit = time()
|
||||
|
||||
elif 'status' in data and data['status'] == 'failed':
|
||||
await self.client.send_message(184151234, f"While generating #{prompt['id']}: {data['detail']}...")
|
||||
await self.client.send_message(prompt['user_id'], f"While trying to generate your prompt we encountered an error: {data['detail']}\n\nThis might mean a bad combination of parameters, or issues on our sifde. We will retry a couple of times just in case.")
|
||||
failed = True
|
||||
self.client.conn.execute('UPDATE prompt SET is_error = 1, is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
|
||||
break
|
||||
elif 'status' in data and data['status'] == 'succeeded':
|
||||
self.log.info('Success!')
|
||||
images = []
|
||||
for img in data['output']:
|
||||
imgdata = BytesIO(b64decode(img['data'].split('base64,',1)[1]))
|
||||
#imgdata.name = img['path_abs'].rsplit('/', 1)[-1]
|
||||
imgdata.name = 'image.png'
|
||||
imgdata.seek(0)
|
||||
images.append(imgdata)
|
||||
break
|
||||
else:
|
||||
print(data)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
progress_task.cancel()
|
||||
try:
|
||||
await progress_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
|
||||
await msg.delete()
|
||||
|
||||
|
|
Loading…
Reference in New Issue