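"""Render worker for the Telegram image generation bot.

Each Worker pulls queued prompts from the shared queue/database, submits them
to a Stable Diffusion rendering backend over HTTP, polls the progress stream,
and delivers the finished images back to the requesting user via Telethon.
"""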
import asyncio
from PIL import Image
from PIL import ImageOps
from base64 import b64encode, b64decode
from io import BytesIO
import logging
import json
import httpx
from os import listdir
from os.path import join
from time import time
from telethon.errors.rpcerrorlist import MessageNotModifiedError, UserIsBlockedError

log = logging.getLogger('process')
temp_folder = '/home/ed/temp/'

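# Baseline parameters sent with every render request; process_prompt() copies
# this dict and fills in the per-prompt fields (prompt text, size, seed, ...).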
default_vars = {
    "use_cpu": False,
    "use_full_precision": False,
    "stream_progress_updates": True,
    "stream_image_progress": False,
    "show_only_filtered_image": True,
    "sampler_name": "dpm_solver_stability",
    "save_to_disk_path": temp_folder,
    "output_format": "png",
    "use_stable_diffusion_model": "fluffyrock-576-704-832-960-1088-lion-low-lr-e61-terminal-snr-e34",
    "metadata_output_format": "embed",
    "use_hypernetwork_model": "boring_e621",
    "hypernetwork_strength": 0.25,
}

class Worker:
    """One render worker bound to a single backend API instance."""

    def __init__(self, api, client, name):
        self.api = api
        self.ready = False
        self.client = client
        self.queue = client.queue
        self.conn = client.conn
        self.loop = asyncio.get_event_loop()
        self.task = None
        self.name = name
        self.log = logging.getLogger(name)
        self.prompt = None

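    # start() is also registered as the done-callback of each processing task,
    # so the worker keeps draining the queue until it is empty or
    # client.process_queue is switched off.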
    def start(self, future=None):

        if not self.client.process_queue:
            asyncio.create_task(self.client.send_message(self.client.admin_id, f"Loop of {self.name} has been stopped."))
            return

        if self.task and not self.task.done(): return

        if future and future.exception():
            self.log.error(future.exception())
            self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?', (time(), self.prompt))
            return

        try:
            priority, prompt_id = self.queue.get_nowait()
        except asyncio.QueueEmpty:
            self.log.info('No more tasks to process!')
        else:
            prompt = self.client.conn.execute('SELECT prompt.* FROM prompt WHERE id = ?', (prompt_id,)).fetchone()

            self.log.info(f"Processing {prompt_id}")
            self.task = self.loop.create_task(self.process_prompt(prompt, (priority, prompt_id)))
            self.task.add_done_callback(self.start)

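    # Handles a single prompt end to end: verifies the backend is reachable,
    # notifies the user, builds the render request, polls the progress stream,
    # and finally hands the results over to send_submission().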
    async def process_prompt(self, prompt, queue_item):

        # Make sure the render backend is reachable; if not, requeue the prompt
        # and back off before trying again.
        async with httpx.AsyncClient() as httpclient:
            try:
                await httpclient.get(join(self.api, 'ping'), timeout=5)
            except Exception as e:
                log.error(f'Server is dead ({e}). Waiting 10 seconds for server availability...')
                await self.queue.put(queue_item)
                await asyncio.sleep(10)
                return

        # Check if the user can still be messaged.
        try:
            msg = await self.client.send_message(prompt['user_id'], "Hello 👀 Generating your prompt now.")
        except UserIsBlockedError:
            self.conn.execute('UPDATE prompt SET is_done = 1, is_error = 1 WHERE id = ?', (prompt['id'],))
            return

        # Prepare the parameters for the request
        params = default_vars.copy()
        params['session_id'] = str(prompt['id'])
        params['prompt'] = prompt['prompt'] or ''
        params['negative_prompt'] = prompt['negative_prompt'] or 'boring_e621_v4'
        params['num_outputs'] = int(prompt['number'])
        params['num_inference_steps'] = prompt['inference_steps']
        params['guidance_scale'] = prompt['detail']
        params['width'] = prompt['resolution'].split('x')[0]
        params['height'] = prompt['resolution'].split('x')[1]
        params['seed'] = str(prompt['seed'])
        params['vram_usage_level'] = 'low' if '-low' in self.name else ('medium' if '1024' in prompt['resolution'] else 'high')

        self.prompt = prompt['id']

        if prompt['hires'] == 'yes':
            params['use_upscale'] = 'RealESRGAN_x4plus_anime_6B'

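        # img2img: fit the source image to the requested resolution and pass it
        # along as a base64 data URL, switching to the ddim sampler.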
        if prompt['image'] != 'no_image':
            img = Image.open(prompt['image'])
            img = img.convert('RGB')
            if prompt['crop'] == 'no':
                img = img.resize(tuple(int(x) for x in prompt['resolution'].split('x')))
            else:
                img = ImageOps.fit(img, tuple(int(x) for x in prompt['resolution'].split('x')))

            imgdata = BytesIO()
            img.save(imgdata, format='JPEG')

            params['init_image'] = ('data:image/jpeg;base64,' + b64encode(imgdata.getvalue()).decode()).strip()
            params['sampler_name'] = 'ddim'
            params['prompt_strength'] = prompt['blend']

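        # Submit the render job, then poll its progress stream: update the
        # status message every ~10 seconds and stop once the job fails or
        # succeeds.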
        async with httpx.AsyncClient() as httpclient:

            self.conn.execute('UPDATE prompt SET started_at = ? WHERE id = ?', (time(), prompt['id']))

            start = time()
            failed = False

            self.log.info('POST to server')
            res = await httpclient.post(join(self.api, 'render'), data=json.dumps(params))
            res = res.json()

            last_edit = 0

            while True:
                # res['stream'] is an absolute path; strip the leading slash so
                # join() keeps the API base.
                step = await httpclient.get(join(self.api, res['stream'][1:]))

                try:
                    data = step.json()
                except Exception:
                    # The stream is not valid JSON yet; wait and poll again.
                    await asyncio.sleep(2)
                    continue

                if 'step' in data:
                    if int(data['step']) % 10 == 0:
                        self.log.info(f"Generation progress of {prompt['id']}: {data['step']}/{data['total_steps']}")

                    if time() - last_edit > 10:
                        await msg.edit(f"Generating prompt #{prompt['id']}, step {data['step']} of {data['total_steps']}. {time()-start:.1f}s elapsed.")
                        last_edit = time()

                elif 'status' in data and data['status'] == 'failed':
                    await self.client.send_message(184151234, f"While generating #{prompt['id']}: {data['detail']}...")
                    await self.client.send_message(prompt['user_id'], f"While trying to generate your prompt we encountered an error: {data['detail']}\n\nThis might mean a bad combination of parameters, or issues on our side. We will retry a couple of times just in case.")
                    failed = True
                    self.client.conn.execute('UPDATE prompt SET is_error = 1, is_done = 1, completed_at = ? WHERE id = ?', (time(), prompt['id']))
                    break
                elif 'status' in data and data['status'] == 'succeeded':
                    self.log.info('Success!')
                    images = []
                    for img in data['output']:
                        imgdata = BytesIO(b64decode(img['data'].split('base64,', 1)[1]))
                        #imgdata.name = img['path_abs'].rsplit('/', 1)[-1]
                        imgdata.name = 'image.png'
                        imgdata.seek(0)
                        images.append(imgdata)
                    break
                else:
                    self.log.warning(f"Unexpected response from server: {data}")

                await asyncio.sleep(2)

        self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?', (time(), prompt['id']))
        await msg.delete()

        if not failed:
            asyncio.create_task(self.send_submission(prompt, images))

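    # Posts JPEG thumbnails of the results to the main channel and forwards the
    # post to the requester; for hires prompts the raw files are also sent as
    # documents.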
    async def send_submission(self, prompt, images):

        tg_files = []

        for fn in images:
            img = Image.open(fn)
            img.thumbnail((1280, 1280))
            imgdata = BytesIO()
            img.save(imgdata, format='JPEG')
            siz = imgdata.seek(0, 2)  # seek to the end to get the buffer size in bytes
            imgdata.seek(0)
            tg_files.append(await self.client.upload_file(imgdata, file_size=siz, file_name=fn.name))

        results = await self.client.send_file(self.client.main_channel_id, tg_files, caption=
            "\n".join([f"#{prompt['id']} · 🌀 {prompt['inference_steps']} · 🌱 {prompt['seed']} · 💎 {prompt['detail']}" + (f" · 🎚 {prompt['blend']} (seed from image)" if prompt['image'] != 'no_image' else ''),
                       #((f"🖼 https://e621.net/posts/{prompt['image_e6']}" if prompt['image_e6'] else 'user-uploaded image') if prompt['image'] != 'no_image' else ''),
                       ('👍 ' if prompt['negative_prompt'] else '') + (f"Prompt from https://e621.net/posts/{prompt['prompt_e6']}" if prompt['prompt_e6'] else prompt['prompt']),
                       (f"\n👎 {prompt['negative_prompt']}" if prompt['negative_prompt'] else '')])[:1000], parse_mode='HTML')

        await self.client.forward_messages(prompt['user_id'], results)
        if prompt['hires'] == 'yes':
            await self.client.send_message(prompt['user_id'], 'Uploading raw images... This will take a while')
            tg_files = []
            for img in images:
                siz = img.seek(0, 2)
                img.seek(0)
                tg_files.append(await self.client.upload_file(img, file_size=siz, file_name=img.name))

            results = await self.client.send_file(self.client.main_channel_id, tg_files, force_document=True, caption=f"Raw images of #{prompt['id']}")
            await self.client.forward_messages(prompt['user_id'], results)

        self.log.info(f"Files for prompt #{prompt['id']} have been sent successfully.")