forked from AI621/ai621
rewrite process_queue.py for A1111 api
- txt2img - img2img - progress polling
This commit is contained in:
parent
862bad5d48
commit
443f26f5fb
119
process_queue.py
119
process_queue.py
|
@ -12,21 +12,9 @@ from time import time
|
||||||
from telethon.errors.rpcerrorlist import MessageNotModifiedError, UserIsBlockedError
|
from telethon.errors.rpcerrorlist import MessageNotModifiedError, UserIsBlockedError
|
||||||
|
|
||||||
log = logging.getLogger('process')
|
log = logging.getLogger('process')
|
||||||
temp_folder = '/home/ed/temp/'
|
|
||||||
|
|
||||||
default_vars = {
|
default_vars = {
|
||||||
"use_cpu":False,
|
"sampler_name": "DPM++ 2M Karras",
|
||||||
"use_full_precision": False,
|
|
||||||
"stream_progress_updates": True,
|
|
||||||
"stream_image_progress": False,
|
|
||||||
"show_only_filtered_image": True,
|
|
||||||
"sampler_name": "dpm_solver_stability",
|
|
||||||
"save_to_disk_path": temp_folder,
|
|
||||||
"output_format": "png",
|
|
||||||
"use_stable_diffusion_model": "fluffyrock-576-704-832-960-1088-lion-low-lr-e61-terminal-snr-e34",
|
|
||||||
"metadata_output_format": "embed",
|
|
||||||
"use_hypernetwork_model": "boring_e621",
|
|
||||||
"hypernetwork_strength": 0.25,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class Worker:
|
class Worker:
|
||||||
|
@ -42,6 +30,25 @@ class Worker:
|
||||||
self.log = logging.getLogger(name)
|
self.log = logging.getLogger(name)
|
||||||
self.prompt = None
|
self.prompt = None
|
||||||
|
|
||||||
|
async def update_progress_bar(self, msg):
|
||||||
|
start = time()
|
||||||
|
last_edit = 0
|
||||||
|
while True:
|
||||||
|
async with httpx.AsyncClient() as httpclient:
|
||||||
|
step = await httpclient.get(self.api + '/sdapi/v1/progress')
|
||||||
|
try:
|
||||||
|
data = step.json()
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
if data['state']['sampling_step']%10 == 0:
|
||||||
|
self.log.info(f"Generation progress of {self.prompt}: {data['state']['job']} | {data['state']['sampling_step']}/{data['state']['sampling_steps']}")
|
||||||
|
|
||||||
|
if time() - last_edit > 3:
|
||||||
|
await msg.edit(f"Generating prompt #{self.prompt}, {int(data['progress'] * 100)}% done. {time()-start:.1f}s elapsed, {int(data['eta_relative'])}s remaining.")
|
||||||
|
last_edit = time()
|
||||||
|
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
def start(self, future=None):
|
def start(self, future=None):
|
||||||
|
|
||||||
if not self.client.process_queue:
|
if not self.client.process_queue:
|
||||||
|
@ -67,12 +74,13 @@ class Worker:
|
||||||
self.task.add_done_callback(self.start)
|
self.task.add_done_callback(self.start)
|
||||||
|
|
||||||
async def process_prompt(self, prompt, queue_item):
|
async def process_prompt(self, prompt, queue_item):
|
||||||
|
endpoint = 'txt2img'
|
||||||
|
|
||||||
# First of all, check if the user can still be messaged.
|
# First of all, check if the user can still be messaged.
|
||||||
|
|
||||||
async with httpx.AsyncClient() as httpclient:
|
async with httpx.AsyncClient() as httpclient:
|
||||||
try:
|
try:
|
||||||
await httpclient.get(join(self.api, 'ping'), timeout=5)
|
await httpclient.get(self.api + '/internal/ping', timeout=5)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(str(e))
|
print(str(e))
|
||||||
log.error('Server is dead. Waiting 10 seconds for server availability...')
|
log.error('Server is dead. Waiting 10 seconds for server availability...')
|
||||||
|
@ -88,23 +96,23 @@ class Worker:
|
||||||
|
|
||||||
# Prepare the parameters for the request
|
# Prepare the parameters for the request
|
||||||
params = default_vars.copy()
|
params = default_vars.copy()
|
||||||
params['session_id'] = str(prompt['id'])
|
|
||||||
params['prompt'] = prompt['prompt'] or ''
|
params['prompt'] = prompt['prompt'] or ''
|
||||||
params['negative_prompt'] = prompt['negative_prompt'] or 'boring_e621_v4'
|
params['negative_prompt'] = str(prompt['negative_prompt'])
|
||||||
params['num_outputs'] = int(prompt['number'])
|
params['n_iter'] = int(prompt['number'])
|
||||||
params['num_inference_steps'] = prompt['inference_steps']
|
#params['batch_size'] = int(prompt['number'])
|
||||||
params['guidance_scale'] = prompt['detail']
|
params['steps'] = prompt['inference_steps']
|
||||||
|
params['cfg_scale'] = prompt['detail']
|
||||||
params['width'] = prompt['resolution'].split('x')[0]
|
params['width'] = prompt['resolution'].split('x')[0]
|
||||||
params['height'] = prompt['resolution'].split('x')[1]
|
params['height'] = prompt['resolution'].split('x')[1]
|
||||||
params['seed'] = str(prompt['seed'])
|
params['seed'] = str(prompt['seed'])
|
||||||
params['vram_usage_level'] = 'low' if '-low' in self.name else ('medium' if '1024' in prompt['resolution'] else 'high')
|
|
||||||
|
|
||||||
self.prompt = prompt['id']
|
self.prompt = prompt['id']
|
||||||
|
|
||||||
if prompt['hires'] == 'yes':
|
#if prompt['hires'] == 'yes':
|
||||||
params['use_upscale'] = 'RealESRGAN_x4plus_anime_6B'
|
# params['use_upscale'] = 'RealESRGAN_x4plus_anime_6B'
|
||||||
|
|
||||||
if prompt['image'] != 'no_image':
|
if prompt['image'] != 'no_image':
|
||||||
|
endpoint = 'img2img'
|
||||||
img = Image.open(prompt['image'])
|
img = Image.open(prompt['image'])
|
||||||
img = img.convert('RGB')
|
img = img.convert('RGB')
|
||||||
if prompt['crop'] == 'no':
|
if prompt['crop'] == 'no':
|
||||||
|
@ -115,60 +123,43 @@ class Worker:
|
||||||
imgdata = BytesIO()
|
imgdata = BytesIO()
|
||||||
img.save(imgdata, format='JPEG')
|
img.save(imgdata, format='JPEG')
|
||||||
|
|
||||||
params['init_image'] = ('data:image/jpeg;base64,'+b64encode(imgdata.getvalue()).decode()).strip()
|
params['init_images'] = [('data:image/jpeg;base64,'+b64encode(imgdata.getvalue()).decode()).strip()]
|
||||||
params['sampler_name'] = 'ddim'
|
params['sampler_name'] = 'DDIM'
|
||||||
params['prompt_strength'] = prompt['blend']
|
params['denoising_strength'] = prompt['blend']
|
||||||
|
|
||||||
async with httpx.AsyncClient() as httpclient:
|
async with httpx.AsyncClient() as httpclient:
|
||||||
|
|
||||||
self.conn.execute('UPDATE prompt SET started_at = ? WHERE id = ?', (time(), prompt['id']))
|
self.conn.execute('UPDATE prompt SET started_at = ? WHERE id = ?', (time(), prompt['id']))
|
||||||
|
|
||||||
start = time()
|
|
||||||
failed = False
|
failed = False
|
||||||
|
|
||||||
self.log.info('POST to server')
|
self.log.info('POST to server')
|
||||||
res = await httpclient.post(join(self.api, 'render'), data=json.dumps(params))
|
progress_task = asyncio.create_task(self.update_progress_bar(msg))
|
||||||
|
try:
|
||||||
|
res = await httpclient.post(self.api + f'/sdapi/v1/{endpoint}', data=json.dumps(params), timeout=300)
|
||||||
|
except Exception as e:
|
||||||
|
await self.client.send_message(self.client.admin_id, f"While generating #{prompt['id']}: {e}")
|
||||||
|
await self.client.send_message(prompt['user_id'], f"While trying to generate your prompt we encountered an error.\n\nThis might mean a bad combination of parameters, or issues on our side. We will retry a couple of times just in case.")
|
||||||
|
failed = True
|
||||||
|
self.client.conn.execute('UPDATE prompt SET is_error = 1, is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
|
||||||
|
return
|
||||||
res = res.json()
|
res = res.json()
|
||||||
|
|
||||||
last_edit = 0
|
self.log.info('Success!')
|
||||||
|
images = []
|
||||||
|
for img in res['images']:
|
||||||
|
imgdata = BytesIO(b64decode(img.split(",",1)[0]))
|
||||||
|
imgdata.name = 'image.png'
|
||||||
|
imgdata.seek(0)
|
||||||
|
images.append(imgdata)
|
||||||
|
|
||||||
while 1:
|
|
||||||
step = await httpclient.get(join(self.api, res['stream'][1:]))
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = step.json()
|
|
||||||
except:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if 'step' in data:
|
|
||||||
if int(data['step'])%10 == 0:
|
|
||||||
self.log.info(f"Generation progress of {prompt['id']}: {data['step']}/{data['total_steps']}")
|
|
||||||
|
|
||||||
if time() - last_edit > 10:
|
|
||||||
await msg.edit(f"Generating prompt #{prompt['id']}, step {data['step']} of {data['total_steps']}. {time()-start:.1f}s elapsed.")
|
|
||||||
last_edit = time()
|
|
||||||
|
|
||||||
elif 'status' in data and data['status'] == 'failed':
|
|
||||||
await self.client.send_message(184151234, f"While generating #{prompt['id']}: {data['detail']}...")
|
|
||||||
await self.client.send_message(prompt['user_id'], f"While trying to generate your prompt we encountered an error: {data['detail']}\n\nThis might mean a bad combination of parameters, or issues on our sifde. We will retry a couple of times just in case.")
|
|
||||||
failed = True
|
|
||||||
self.client.conn.execute('UPDATE prompt SET is_error = 1, is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
|
|
||||||
break
|
|
||||||
elif 'status' in data and data['status'] == 'succeeded':
|
|
||||||
self.log.info('Success!')
|
|
||||||
images = []
|
|
||||||
for img in data['output']:
|
|
||||||
imgdata = BytesIO(b64decode(img['data'].split('base64,',1)[1]))
|
|
||||||
#imgdata.name = img['path_abs'].rsplit('/', 1)[-1]
|
|
||||||
imgdata.name = 'image.png'
|
|
||||||
imgdata.seek(0)
|
|
||||||
images.append(imgdata)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
print(data)
|
|
||||||
|
|
||||||
await asyncio.sleep(2)
|
|
||||||
|
|
||||||
|
|
||||||
|
progress_task.cancel()
|
||||||
|
try:
|
||||||
|
await progress_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
|
self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
|
||||||
await msg.delete()
|
await msg.delete()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue