2023-07-04 18:19:03 +00:00
import asyncio
import html
import json
import logging
from base64 import b64encode, b64decode
from io import BytesIO
from os import listdir
from os.path import join
from time import time

import httpx
from PIL import Image
from PIL import ImageOps
from telethon.errors.rpcerrorlist import MessageNotModifiedError, UserIsBlockedError
log = logging.getLogger("process")

# Multiplier applied to both sides of every generated image before it is
# re-encoded as a JPEG preview for the channel album.
scale_factor = 1.5

# Request fields shared by every generation call; each prompt starts from a
# copy of this dict and layers its own values on top.
default_vars = dict(sampler_name="DPM++ 2M Karras")
class Worker:
    """Pull prompt ids off the shared queue and render them on one image server.

    Processing is chained: ``start`` pops the next ``(priority, prompt_id)``
    from ``client.queue``, runs ``process_prompt`` as a task and registers
    itself as that task's done-callback, so each finished prompt pulls the
    next one. Calling ``start()`` with no argument begins the chain when no
    task is currently running.
    """

    def __init__(self, api, client, name):
        self.api = api  # base URL; paths such as '/sdapi/v1/txt2img' are appended to it
        self.ready = False  # not read anywhere in this class; kept for external callers
        self.client = client
        self.queue = client.queue  # asyncio queue of (priority, prompt_id) tuples
        self.conn = client.conn  # DB connection shared with the client; rows are indexed by column name
        self.loop = asyncio.get_event_loop()
        self.task = None  # the process_prompt task currently scheduled, if any
        self.name = name
        self.log = logging.getLogger(name)
        self.prompt = None  # id of the prompt currently (or most recently) being processed
        # Strong references to fire-and-forget tasks: the event loop only keeps
        # weak ones, so an unreferenced task can be garbage-collected mid-run.
        self._background_tasks = set()

    def _spawn(self, coro):
        """Run *coro* in the background, keeping it referenced and logging any failure."""
        task = self.loop.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_done)
        return task

    def _on_background_done(self, task):
        """Drop the reference to a finished background task and log its exception, if any."""
        self._background_tasks.discard(task)
        if not task.cancelled() and task.exception() is not None:
            self.log.error(f"Background task failed: {task.exception()!r}", exc_info=task.exception())

    @staticmethod
    async def _cancel_and_wait(task):
        """Cancel *task* and wait until it has actually finished."""
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    @staticmethod
    def _parse_resolution(resolution):
        """Split a ``'<width>x<height>'`` string into an ``(int, int)`` pair."""
        parts = resolution.split('x')
        return int(parts[0]), int(parts[1])

    @staticmethod
    def _decode_image(payload, name='image.png'):
        """Decode one base64 image returned by the API into a named in-memory file.

        Accepts bare base64 as well as a ``data:<mime>;base64,<data>`` URL: the
        payload is whatever follows the first comma, or the whole string if
        there is none.
        """
        buf = BytesIO(b64decode(payload.split(',', 1)[-1]))
        buf.name = name
        return buf

    @staticmethod
    def _encode_init_image(path, size, crop):
        """Load the image at *path*, fit it to *size* and return a JPEG data URL.

        crop=False stretches the image to *size*; crop=True uses ImageOps.fit,
        which resizes and crops to the requested size instead of distorting.
        """
        # `with` closes the underlying file; convert() returns an independent copy.
        with Image.open(path) as source:
            img = source.convert('RGB')
        img = ImageOps.fit(img, size) if crop else img.resize(size)
        buf = BytesIO()
        img.save(buf, format='JPEG')
        return 'data:image/jpeg;base64,' + b64encode(buf.getvalue()).decode()

    @staticmethod
    def _clip(text, limit):
        """Return ``str(text)``, shortened to at most *limit* characters with a trailing '…'."""
        text = str(text)
        return text if len(text) <= limit else text[:limit - 1] + '…'

    @classmethod
    def _build_caption(cls, prompt, field_limit=400):
        """Return the HTML caption for a finished prompt's album.

        User-supplied text (prompt and negative prompt) is HTML-escaped, because
        prompts often contain ``<lora:...>`` tokens that the HTML parse mode would
        otherwise treat as markup and drop. Each of those fields is clipped to
        *field_limit* characters so the caption stays short without slicing the
        finished markup mid-tag.
        """
        header = (
            f"#{prompt['id']} · 🌀 <code>{prompt['inference_steps']}</code>"
            f" · 🌱 <code>{prompt['seed']}</code> · 💎 <code>{prompt['detail']}</code>"
        )
        if prompt['image'] != 'no_image':
            header += f" · 🎚 {prompt['blend']} (seed from image)"
        if prompt['prompt_e6']:
            positive = f"Prompt from https://e621.net/posts/{prompt['prompt_e6']}"
        else:
            positive = f"<code>{html.escape(cls._clip(prompt['prompt'], field_limit))}</code>"
        if prompt['negative_prompt']:
            positive = '👍 ' + positive
            negative = f"\n👎 <code>{html.escape(cls._clip(prompt['negative_prompt'], field_limit))}</code>"
        else:
            negative = ''
        return '\n'.join([header, positive, negative])

    async def update_progress_bar(self, msg):
        """Mirror the server's progress into *msg* until this task is cancelled.

        Polls ``/sdapi/v1/progress`` every 2 s, logs whenever the sampling step
        is a multiple of 10, and edits *msg* at most once every 3 s. A failed
        poll or edit is logged and polling continues, so the coroutine only
        ever ends with ``asyncio.CancelledError`` (the caller awaits it after
        cancelling, and any other exception would propagate into the caller).
        """
        start = time()
        last_edit = 0
        # One client for the whole polling session instead of one per request.
        async with httpx.AsyncClient() as httpclient:
            while True:
                try:
                    step = await httpclient.get(self.api + '/sdapi/v1/progress')
                    data = step.json()
                    state = data['state']
                    if state['sampling_step'] % 10 == 0:
                        self.log.info(
                            f"Generation progress of {self.prompt}: {state['job']} | "
                            f"{state['sampling_step']}/{state['sampling_steps']}"
                        )
                    if time() - last_edit > 3:
                        try:
                            await msg.edit(
                                f"Generating prompt #{self.prompt}, {int(data['progress'] * 100)}% done. "
                                f"{time() - start:.1f}s elapsed, {int(data['eta_relative'])}s remaining."
                            )
                        except MessageNotModifiedError:
                            pass  # identical text; nothing to update
                        last_edit = time()
                except Exception as e:
                    # Best-effort UI: a flaky poll must not kill the task.
                    self.log.warning(f"Progress update for #{self.prompt} failed: {e!r}")
                # Sleep on every path, failures included, so a bad response
                # cannot turn this into a tight loop hammering the server.
                await asyncio.sleep(2)

    def start(self, future=None):
        """Take the next prompt off the queue and run it.

        Also the done-callback of every ``process_prompt`` task, in which case
        *future* is that finished task. A task that raised marks its prompt
        done with ``is_error = 1``, and processing then moves on to the next
        queued item. A cancelled task ends the chain.
        """
        if future is not None:
            if future.cancelled():
                # future.exception() would itself raise on a cancelled task.
                self.log.warning(f"Task for prompt #{self.prompt} was cancelled.")
                return
            exc = future.exception()
            if exc is not None:
                self.log.error(f"Prompt #{self.prompt} failed: {exc!r}", exc_info=exc)
                self.conn.execute(
                    'UPDATE prompt SET is_done = 1, is_error = 1, completed_at = ? WHERE id = ?',
                    (time(), self.prompt),
                )
                # Fall through: one broken prompt must not stall the rest of the queue.
        if not self.client.process_queue:
            self._spawn(self.client.send_message(self.client.admin_id, f"Loop of {self.name} has been stopped."))
            return
        if self.task and not self.task.done():
            return
        try:
            priority, prompt_id = self.queue.get_nowait()
        except asyncio.QueueEmpty:
            self.log.info('No more tasks to process!')
            return
        prompt = self.conn.execute('SELECT prompt.* FROM prompt WHERE id = ?', (prompt_id,)).fetchone()
        self.log.info(f"Processing {prompt_id}")
        # Record the id before the task runs, so a failure that happens before
        # process_prompt reads the row is still attributed to this prompt.
        self.prompt = prompt_id
        self.task = self.loop.create_task(self.process_prompt(prompt, (priority, prompt_id)))
        self.task.add_done_callback(self.start)

    async def process_prompt(self, prompt, queue_item):
        """Generate the images for one prompt row and schedule their delivery.

        prompt: DB row for the prompt, indexed by column name.
        queue_item: the ``(priority, prompt_id)`` tuple popped from the queue;
            it is put back when the server does not answer the ping.

        On success the row is marked done and ``send_submission`` runs in the
        background. If the generation request fails (transport error, non-2xx
        reply or unexpected body), the row is marked done with ``is_error = 1``
        and both the admin and the user are notified.
        """
        self.prompt = prompt['id']
        endpoint = 'txt2img'

        # First of all, make sure the generation server answers.
        async with httpx.AsyncClient() as httpclient:
            try:
                await httpclient.get(self.api + '/internal/ping', timeout=5)
            except Exception as e:
                self.log.error(f"Ping to {self.api} failed: {e!r}")
                self.log.error('Server is dead. Waiting 10 seconds for server availability...')
                await self.queue.put(queue_item)
                await asyncio.sleep(10)
                return

        # Then check the user can still be messaged.
        try:
            msg = await self.client.send_message(prompt['user_id'], "Hello 👀 Generating your prompt now.")
        except UserIsBlockedError:
            self.conn.execute('UPDATE prompt SET is_done = 1, is_error = 1 WHERE id = ?', (prompt['id'],))
            return

        width, height = self._parse_resolution(prompt['resolution'])
        params = default_vars.copy()
        params['prompt'] = prompt['prompt'] or ''
        # `or ''` so a NULL column is sent as empty rather than the literal text "None".
        params['negative_prompt'] = str(prompt['negative_prompt'] or '')
        params['n_iter'] = int(prompt['number'])
        params['steps'] = prompt['inference_steps']
        params['cfg_scale'] = prompt['detail']
        params['width'] = width
        params['height'] = height
        params['seed'] = str(prompt['seed'])

        if prompt['image'] != 'no_image':
            endpoint = 'img2img'
            params['init_images'] = [
                self._encode_init_image(prompt['image'], (width, height), crop=prompt['crop'] != 'no')
            ]
            params['sampler_name'] = 'DDIM'
            params['denoising_strength'] = prompt['blend']

        self.conn.execute('UPDATE prompt SET started_at = ? WHERE id = ?', (time(), prompt['id']))
        self.log.info('POST to server')
        progress_task = asyncio.create_task(self.update_progress_bar(msg))
        error = None
        try:
            async with httpx.AsyncClient() as httpclient:
                res = await httpclient.post(self.api + f'/sdapi/v1/{endpoint}', json=params, timeout=300)
            # A non-2xx reply carries no 'images'; route it through the error path too.
            res.raise_for_status()
            images = [self._decode_image(payload) for payload in res.json()['images']]
        except Exception as e:
            error = e
        finally:
            # Stop the poller on every path; otherwise it keeps editing `msg` forever.
            await self._cancel_and_wait(progress_task)

        if error is not None:
            # Record the failure first so it sticks even if a notification fails.
            self.conn.execute(
                'UPDATE prompt SET is_error = 1, is_done = 1, completed_at = ? WHERE id = ?',
                (time(), prompt['id']),
            )
            await self.client.send_message(self.client.admin_id, f"While generating #{prompt['id']}: {error}")
            # NOTE(review): the user is told we will retry, but this path marks the
            # prompt done+error and nothing visible here retries it — confirm.
            await self.client.send_message(
                prompt['user_id'],
                "While trying to generate your prompt we encountered an error.\n\n"
                "This might mean a bad combination of parameters, or issues on our side. "
                "We will retry a couple of times just in case.",
            )
            return

        self.log.info('Success!')
        self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?', (time(), prompt['id']))
        try:
            await msg.delete()
        except Exception as e:
            # Losing the progress message must not lose the generated images.
            self.log.warning(f"Could not delete progress message for #{prompt['id']}: {e!r}")
        # Delivered in the background so the worker can start the next prompt.
        self._spawn(self.send_submission(prompt, images))

    async def _upload(self, buf, file_name):
        """Upload in-memory file *buf* from its start; return what ``client.upload_file`` returns."""
        size = buf.seek(0, 2)
        buf.seek(0)
        return await self.client.upload_file(buf, file_size=size, file_name=file_name)

    async def send_submission(self, prompt, images):
        """Post the generated images to the main channel and forward them to the user.

        Each image is converted to RGB, scaled by ``scale_factor`` and re-encoded
        as JPEG for the captioned album. When ``prompt['hires'] == 'yes'`` the
        original files are also sent as documents.

        images: in-memory files with a ``name`` attribute, as produced by ``_decode_image``.
        """
        await self.client.send_message(prompt['user_id'], 'Uploading images... This will take a while')
        tg_files = []
        for source in images:
            picture = Image.open(source)
            width, height = picture.size
            # JPEG has no alpha/palette support; convert so save() cannot fail on RGBA/P output.
            picture = picture.convert('RGB').resize((int(width * scale_factor), int(height * scale_factor)))
            jpeg = BytesIO()
            picture.save(jpeg, format='JPEG')
            # The bytes are JPEG now, so don't upload them under the source's '.png' name.
            jpeg_name = source.name.rsplit('.', 1)[0] + '.jpg'
            tg_files.append(await self._upload(jpeg, jpeg_name))
        results = await self.client.send_file(
            self.client.main_channel_id, tg_files, caption=self._build_caption(prompt), parse_mode='HTML'
        )
        await self.client.forward_messages(prompt['user_id'], results)
        if prompt['hires'] == 'yes':
            await self.client.send_message(prompt['user_id'], 'Uploading raw images... This will take a while')
            raw_files = [await self._upload(raw, raw.name) for raw in images]
            results = await self.client.send_file(
                self.client.main_channel_id, raw_files, force_document=True, caption=f"Raw images of #{prompt['id']}"
            )
            await self.client.forward_messages(prompt['user_id'], results)
        self.log.info(f"Files for prompt #{prompt['id']} have been sent succesfully.")