Compare commits

...

15 Commits
main ... dev

Author SHA1 Message Date
nameless d3de1bb531 example config: admin_id comment 2023-08-15 22:05:22 +00:00
nameless 17a84f8dc3 Update README.md 2023-08-15 21:50:01 +00:00
nameless 4c3a2f405b example config: add comments, change port 2023-08-15 21:36:46 +00:00
nameless b3e13ea6e8 button for negative_prompt removal 2023-08-15 21:18:45 +00:00
nameless 5ba5b6043c markdown: pre -> code
pre doesn't allow to copy prompt by clicking on it
2023-08-15 20:49:45 +00:00
nameless 44cbd89d30 Update README.md
replace Easy Diffusion with A1111 WebUI, add some extra steps
2023-08-15 20:46:47 +00:00
nameless a206eb032e db counter reset, schema update
- prompt id starts from 1
- "no_image" by default for "image" in prompt/pending_prompt tables
2023-08-15 20:11:29 +00:00
nameless 7bed58748b use bot_token for client.start 2023-08-15 19:52:40 +00:00
nameless 17bf9835c7 remove temp_folder, add bot_token in config 2023-08-15 19:50:58 +00:00
nameless 2109757dd0 yiffy-e18 tags -> fluffyrock-3m tags 2023-08-15 19:14:11 +00:00
nameless 263117d149 Delete yiffy-e18.json
not used
2023-08-15 19:05:04 +00:00
nameless e5096335d0 fix TypeError caused by empty db query results 2023-08-15 19:02:13 +00:00
nameless 36db9e79dc minimize compression damage with image resize
not sure if it's effective
2023-08-15 18:45:31 +00:00
nameless 01ed370ab6 little UX changes
- notify user about image upload
- code markdown for prompt/settings
2023-08-15 18:35:59 +00:00
nameless 443f26f5fb rewrite process_queue.py for A1111 api
- txt2img
- img2img 
- progress polling
2023-08-15 18:28:01 +00:00
7 changed files with 144043 additions and 77952 deletions

README.md

@@ -1,35 +1,20 @@
-# ai621
-This is the code currently running over at @ai621bot on telegram.
-Don't worry, it's gonna be rewritten soon.
+# ai621111 (A1111 webui fork)
+This is fork of ai621 (https://git.foxo.me/AI621/ai621) bot that contains some additional features/fixes.
 ## How does it work?
-This is only the "bot" part. The actual generation is done by Easy Diffusion. To run this, you have to install and run an Easy Diffusion (https://github.com/easydiffusion/easydiffusion) instance and define the IP/hostname in the config files so that the bot can query and send data to port 9000 using the embedded api.
+This is only the "bot" part. The actual generation is done by AUTOMATIC1111 Stable Diffusion web UI. To run this, you have to install and run an AUTOMATIC1111 Stable Diffusion web UI (https://github.com/AUTOMATIC1111/stable-diffusion-webui) instance and define the IP/hostname in the config files so that the bot can query and send data to port 7860 using the embedded api.
 ## How do i run it?
-1. Obtain a Telegram api_id: https://docs.telethon.dev/en/stable/basic/signing-in.html
-2. Create a Python virtual env to not pollute your system: https://docs.python.org/3/library/venv.html#creating-virtual-environments
-3. Install all the dependencies needed by the bot using "pip install -r requirements.txt" after activating the venv
-4. Edit the config files following the instructions
-5. Run the bot with "python3 bot.py"
-## User data
-Data is stored inside "ai621.db": three tables (user, prompt and pending_prompt) define all the data. Additionally, uploaded images (for the img2img mode) are stored inside the "userdata/" folder which is inside of the work folder.
-## Can you define multiple video cards / nodes to speed generation up?
-No. Or at least, not anymore. You will notice from the example configuration that three nodes were defined. This was working properly, until i accidentally overwrote the files that were responsible for that with an older version that didn't have that function anymore. I didn't bother.
-## What are the log and raw channels?
-ai621 posts all generated images to a channel called "ai621 raw", and then forwards them to the user at the end of the prompt processing. This is so that everybody can see and get inspired by other prompts. This is what the raw channel is.
-The log channel, instead, is where all prompts and stored along with buttons to be able to moderate users (delete prompts and ban users).
-## What are "cycles"?
-Every day each user has a defined number of cycles, also known as "steps". Since steps can be defined in a sorta stable amount of time, this is used to make sure that all users can fairly use the bot.
-Some specific prompts also ramp up the amount of cycles used as a "penalty", like using blacklisted tags or bad quality prompts.
+- Create bot account. To get a bot account, you need to talk with `@BotFather`.
+- Create two Telegram channels - one for output, and one for log. Add your bot as administrator, enable rights to post messages.
+- Start your A1111 web UI instance. Don't forget to add `--api` to `set COMMANDLINE_ARGS=` in `webui-user.bat`!
+- Obtain a Telegram api_id: https://docs.telethon.dev/en/stable/basic/signing-in.html
+- Create a Python virtual env to not pollute your system: https://docs.python.org/3/library/venv.html#creating-virtual-environments
+- Install all the dependencies needed by the bot using `pip install -r requirements.txt` after activating the venv
+- Copy example config file, rename it to `config.py` and edit it following the instructions.
+- Copy example database and rename it to `ai621.db`.
+- Run the bot with `python3 bot.py`
+Additional info about the bot can be found in the upstream repo (https://git.foxo.me/AI621/ai621).
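Note: the new setup steps assume the A1111 web UI is already running with its API enabled on port 7860 before the bot starts. A minimal standalone sketch of that reachability check, mirroring the `/internal/ping` call the bot itself now makes in `process_queue.py` (the URL is the assumed default; adjust it to your config):

```python
# Quick standalone check that the A1111 web UI API is reachable before
# starting the bot. Mirrors the /internal/ping health check used in
# process_queue.py; the URL is an assumed default, match your config.
import httpx

API_URL = 'http://127.0.0.1:7860'

def check_api(url: str = API_URL) -> bool:
    try:
        r = httpx.get(url + '/internal/ping', timeout=5)
        return r.status_code == 200
    except httpx.HTTPError:
        return False

if __name__ == '__main__':
    print('A1111 API reachable' if check_api()
          else 'A1111 API not reachable - did you add --api to COMMANDLINE_ARGS?')
```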

Binary file not shown.

bot.py (18 changed lines)

@@ -113,6 +113,12 @@ async def edit_pending(ev):
         await ev.respond('Something went wrong. Try /start again.')
         return
+    if field == 'negative_prompt':
+        if data == 'no_negative_prompt':
+            conn.execute('UPDATE pending_prompt SET negative_prompt = NULL WHERE user_id = ?', (ev.input_sender.user_id,))
+            conn.execute(f"UPDATE user SET pending_action = NULL WHERE id = ?", (ev.input_sender.user_id,))
+            return await edit_prompt(ev)
     if field == 'number':
         try:
             data = int(data)
@@ -298,6 +304,10 @@ async def queue_info(ev):
     waiting_usrs = conn.execute("SELECT count(*) FROM prompt WHERE is_done IS NULL").fetchone()
     avg_speed = conn.execute("SELECT avg(aa) FROM (SELECT (inference_steps*number)/abs(completed_at-started_at) AS aa FROM prompt WHERE started_at IS NOT NULL AND completed_at IS NOT NULL ORDER BY id DESC LIMIT 10)").fetchone()
+    avg_comp_time = avg_comp_time if avg_comp_time[0] is not None else [0]
+    min_max_wait = min_max_wait if min_max_wait[0] is not None else [0, 0]
+    avg_speed = avg_speed if avg_speed[0] is not None else [0]
     await ev.respond("\n".join([
         f"👯‍♂️ {waiting_usrs[0]} people are waiting in the queue",
@@ -377,6 +387,7 @@ async def queue(ev, new_message=False):
         raise events.StopPropagation
     avg_comp_time = conn.execute('SELECT avg(aa) FROM (SELECT abs(queued_at-started_at) AS aa FROM prompt WHERE queued_at IS NOT NULL AND completed_at IS NOT NULL ORDER BY id DESC LIMIT 10)').fetchone()
+    avg_comp_time = avg_comp_time if avg_comp_time[0] is not None else [0]
     behind_you = conn.execute('SELECT count(*) FROM prompt WHERE is_done IS NULL AND id > ?', (prompt['id'],)).fetchone()[0]
     front_you = conn.execute('SELECT count(*) FROM prompt WHERE is_done IS NULL AND id < ?', (prompt['id'],)).fetchone()[0]
@@ -455,6 +466,7 @@ desc = {
 desc_buttons = {
     'seed': [Button.inline(f"Get from last prompt", f"msg_but:last_prompt"),],
+    'negative_prompt': [Button.inline(f"🚫 Remove negative prompt", f"msg_but:no_negative_prompt")],
     'number': [Button.inline(f"1", f"msg_but:1"), Button.inline(f"2", f"msg_but:2"), Button.inline(f"3", f"msg_but:3"), Button.inline(f"4", f"msg_but:4")],
     'detail': [Button.inline(f"S (6.0)", f"msg_but:6"), Button.inline(f"M (10)", f"msg_but:10"), Button.inline(f"L (18)", f"msg_but:18")],
     'inference_steps': [Button.inline(f"S (20)", f"msg_but:20"), Button.inline(f"M (40)", f"msg_but:40"), Button.inline(f"L (60)", f"msg_but:60")],
@@ -546,8 +558,8 @@ async def edit_prompt(ev):
 🖼 Starting image: {'None' if p['image'] == 'no_image' else ('https://e621.net/posts/'+str(p['image_e6']) if p['image_e6'] else 'User uploaded image')}
 🌱 Seed: {p['seed']}
-👀 Prompt: <pre>{p['prompt']}</pre>
-Negative prompt: <pre>{p['negative_prompt'] or 'no negative prompt set.'}</pre>
+👀 Prompt: <code>{p['prompt']}</code>
+Negative prompt: <code>{p['negative_prompt'] or 'no negative prompt set.'}</code>
 <b>Do not touch these if you don't know what you're doing</b>
 🔢 Number of images: {p['number']}
@@ -615,6 +627,6 @@ if __name__ == '__main__':
             continue
         client.queue.put_nowait((task['id'], task['id'],))
-    client.start()
+    client.start(bot_token=bot_token)
     client.flood_sleep_threshold = 24*60*60
     client.run_until_disconnected()
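The last hunk switches Telethon to a bot-token login instead of the interactive user sign-in. A minimal sketch of that pattern in isolation (the `api_id`, `api_hash` and `bot_token` values are placeholders; in the repo they come from the config file shown next):

```python
from telethon import TelegramClient

api_id = 12345                 # placeholder, see https://docs.telethon.dev/en/stable/basic/signing-in.html
api_hash = '0123456789abcdef'  # placeholder
bot_token = '123456:ABC-DEF'   # placeholder, issued by @BotFather

client = TelegramClient('bot', api_id, api_hash)
# Passing bot_token makes Telethon log in as the bot account instead of
# prompting for a phone number and login code.
client.start(bot_token=bot_token)
client.run_until_disconnected()
```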

Example config file

@@ -7,10 +7,12 @@ from process_queue import *
 coloredlogs.install(level='INFO')
-api_id = YOUR TG API ID HERE
-api_hash = YOUR TG API HASH HERE
-temp_folder = TEMP FOLDER OF THE GENERATIONS
+# https://docs.telethon.dev/en/stable/basic/signing-in.html
+api_id = <YOUR TG API ID HERE>
+api_hash = <YOUR TG API HASH HERE>
+# Get it from @BotFather bot.
+bot_token = <YOUR BOT TOKEN HERE>
 client = TelegramClient('bot', api_id, api_hash)
@@ -22,16 +24,23 @@ conn.row_factory = sqlite3.Row
 client.conn = conn
 client.process = None
-client.log_channel_id = ID OF LOG CHANNEL
-client.main_channel_id = ID OF RAW CHANNEL
+# To get channel ID: Send something in private(!) channel, tap/click on it,
+# choose "Copy post link", paste it somewhere. It should look like this:
+# https://t.me/c/12345/2
+# 12345 part is ID - add "-100" prefix before it (like this: -10012345)
+# Replace placeholder with result:
+client.log_channel_id = <ID OF LOG CHANNEL>
+client.main_channel_id = <ID OF RAW CHANNEL>
 client.queue = PriorityQueue()
 client.media_lock = asyncio.Lock()
 client.process_queue = False
-client.admin_id = USER ID OF THE ADMIN
+# Send message to @my_id_bot, it will reply with your Telegram ID.
+client.admin_id = <USER ID OF THE ADMIN>
 workers = [
-    Worker('http://127.0.0.1:9000', client, 'armorlink'),
+    Worker('http://127.0.0.1:7860', client, 'armorlink'),
     #Worker('http://127.0.0.1:9001', client, 'armorlink-low'),
     #Worker('http://local.proxy:9000', client, 'g14')
 ]
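The new comments describe how to derive a channel ID from a "Copy post link" URL by taking the numeric part and prefixing it with -100. A small helper sketch (not part of the repo) that applies exactly that rule:

```python
def channel_id_from_post_link(link: str) -> int:
    """Turn a private-channel post link like 'https://t.me/c/12345/2'
    into the ID the config expects, i.e. -10012345."""
    raw = link.rstrip('/').split('/c/', 1)[1].split('/', 1)[0]
    return int('-100' + raw)

assert channel_id_from_post_link('https://t.me/c/12345/2') == -10012345
```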

process_queue.py

@@ -12,21 +12,10 @@ from time import time
 from telethon.errors.rpcerrorlist import MessageNotModifiedError, UserIsBlockedError
 log = logging.getLogger('process')
-temp_folder = '/home/ed/temp/'
+scale_factor = 1.5
 default_vars = {
-    "use_cpu":False,
-    "use_full_precision": False,
-    "stream_progress_updates": True,
-    "stream_image_progress": False,
-    "show_only_filtered_image": True,
-    "sampler_name": "dpm_solver_stability",
-    "save_to_disk_path": temp_folder,
-    "output_format": "png",
-    "use_stable_diffusion_model": "fluffyrock-576-704-832-960-1088-lion-low-lr-e61-terminal-snr-e34",
-    "metadata_output_format": "embed",
-    "use_hypernetwork_model": "boring_e621",
-    "hypernetwork_strength": 0.25,
+    "sampler_name": "DPM++ 2M Karras",
 }
 class Worker:
@@ -42,6 +31,25 @@ class Worker:
         self.log = logging.getLogger(name)
         self.prompt = None
+    async def update_progress_bar(self, msg):
+        start = time()
+        last_edit = 0
+        while True:
+            async with httpx.AsyncClient() as httpclient:
+                step = await httpclient.get(self.api + '/sdapi/v1/progress')
+            try:
+                data = step.json()
+            except:
+                continue
+            if data['state']['sampling_step']%10 == 0:
+                self.log.info(f"Generation progress of {self.prompt}: {data['state']['job']} | {data['state']['sampling_step']}/{data['state']['sampling_steps']}")
+            if time() - last_edit > 3:
+                await msg.edit(f"Generating prompt #{self.prompt}, {int(data['progress'] * 100)}% done. {time()-start:.1f}s elapsed, {int(data['eta_relative'])}s remaining.")
+                last_edit = time()
+            await asyncio.sleep(2)
     def start(self, future=None):
         if not self.client.process_queue:
@@ -67,12 +75,13 @@ class Worker:
         self.task.add_done_callback(self.start)
     async def process_prompt(self, prompt, queue_item):
+        endpoint = 'txt2img'
         # First of all, check if the user can still be messaged.
         async with httpx.AsyncClient() as httpclient:
             try:
-                await httpclient.get(join(self.api, 'ping'), timeout=5)
+                await httpclient.get(self.api + '/internal/ping', timeout=5)
             except Exception as e:
                 print(str(e))
                 log.error('Server is dead. Waiting 10 seconds for server availability...')
@@ -88,23 +97,23 @@ class Worker:
         # Prepare the parameters for the request
         params = default_vars.copy()
-        params['session_id'] = str(prompt['id'])
         params['prompt'] = prompt['prompt'] or ''
-        params['negative_prompt'] = prompt['negative_prompt'] or 'boring_e621_v4'
-        params['num_outputs'] = int(prompt['number'])
-        params['num_inference_steps'] = prompt['inference_steps']
-        params['guidance_scale'] = prompt['detail']
+        params['negative_prompt'] = str(prompt['negative_prompt'])
+        params['n_iter'] = int(prompt['number'])
+        #params['batch_size'] = int(prompt['number'])
+        params['steps'] = prompt['inference_steps']
+        params['cfg_scale'] = prompt['detail']
         params['width'] = prompt['resolution'].split('x')[0]
         params['height'] = prompt['resolution'].split('x')[1]
         params['seed'] = str(prompt['seed'])
-        params['vram_usage_level'] = 'low' if '-low' in self.name else ('medium' if '1024' in prompt['resolution'] else 'high')
         self.prompt = prompt['id']
-        if prompt['hires'] == 'yes':
-            params['use_upscale'] = 'RealESRGAN_x4plus_anime_6B'
+        #if prompt['hires'] == 'yes':
+        #    params['use_upscale'] = 'RealESRGAN_x4plus_anime_6B'
         if prompt['image'] != 'no_image':
+            endpoint = 'img2img'
             img = Image.open(prompt['image'])
             img = img.convert('RGB')
             if prompt['crop'] == 'no':
@@ -115,60 +124,43 @@ class Worker:
             imgdata = BytesIO()
             img.save(imgdata, format='JPEG')
-            params['init_image'] = ('data:image/jpeg;base64,'+b64encode(imgdata.getvalue()).decode()).strip()
-            params['sampler_name'] = 'ddim'
-            params['prompt_strength'] = prompt['blend']
+            params['init_images'] = [('data:image/jpeg;base64,'+b64encode(imgdata.getvalue()).decode()).strip()]
+            params['sampler_name'] = 'DDIM'
+            params['denoising_strength'] = prompt['blend']
         async with httpx.AsyncClient() as httpclient:
             self.conn.execute('UPDATE prompt SET started_at = ? WHERE id = ?', (time(), prompt['id']))
-            start = time()
             failed = False
             self.log.info('POST to server')
-            res = await httpclient.post(join(self.api, 'render'), data=json.dumps(params))
+            progress_task = asyncio.create_task(self.update_progress_bar(msg))
+            try:
+                res = await httpclient.post(self.api + f'/sdapi/v1/{endpoint}', data=json.dumps(params), timeout=300)
+            except Exception as e:
+                await self.client.send_message(self.client.admin_id, f"While generating #{prompt['id']}: {e}")
+                await self.client.send_message(prompt['user_id'], f"While trying to generate your prompt we encountered an error.\n\nThis might mean a bad combination of parameters, or issues on our side. We will retry a couple of times just in case.")
+                failed = True
+                self.client.conn.execute('UPDATE prompt SET is_error = 1, is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
+                return
             res = res.json()
-            last_edit = 0
-            while 1:
-                step = await httpclient.get(join(self.api, res['stream'][1:]))
-                try:
-                    data = step.json()
-                except:
-                    continue
-                if 'step' in data:
-                    if int(data['step'])%10 == 0:
-                        self.log.info(f"Generation progress of {prompt['id']}: {data['step']}/{data['total_steps']}")
-                    if time() - last_edit > 10:
-                        await msg.edit(f"Generating prompt #{prompt['id']}, step {data['step']} of {data['total_steps']}. {time()-start:.1f}s elapsed.")
-                        last_edit = time()
-                elif 'status' in data and data['status'] == 'failed':
-                    await self.client.send_message(184151234, f"While generating #{prompt['id']}: {data['detail']}...")
-                    await self.client.send_message(prompt['user_id'], f"While trying to generate your prompt we encountered an error: {data['detail']}\n\nThis might mean a bad combination of parameters, or issues on our sifde. We will retry a couple of times just in case.")
-                    failed = True
-                    self.client.conn.execute('UPDATE prompt SET is_error = 1, is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
-                    break
-                elif 'status' in data and data['status'] == 'succeeded':
-                    self.log.info('Success!')
-                    images = []
-                    for img in data['output']:
-                        imgdata = BytesIO(b64decode(img['data'].split('base64,',1)[1]))
-                        #imgdata.name = img['path_abs'].rsplit('/', 1)[-1]
-                        imgdata.name = 'image.png'
-                        imgdata.seek(0)
-                        images.append(imgdata)
-                    break
-                else:
-                    print(data)
-                await asyncio.sleep(2)
+            self.log.info('Success!')
+            images = []
+            for img in res['images']:
+                imgdata = BytesIO(b64decode(img.split(",",1)[0]))
+                imgdata.name = 'image.png'
+                imgdata.seek(0)
+                images.append(imgdata)
+            progress_task.cancel()
+            try:
+                await progress_task
+            except asyncio.CancelledError:
+                pass
             self.conn.execute('UPDATE prompt SET is_done = 1, completed_at = ? WHERE id = ?',(time(), prompt['id']))
             await msg.delete()
@@ -176,12 +168,14 @@ class Worker:
         asyncio.create_task(self.send_submission(prompt, images))
     async def send_submission(self, prompt, images):
+        await self.client.send_message(prompt['user_id'], 'Uploading images... This will take a while')
         tg_files = []
         for fn in images:
             img = Image.open(fn)
-            img.thumbnail((1280,1280))
+            width, height = img.size
+            img = img.resize((int(width * scale_factor), int(height * scale_factor)))
             imgdata = BytesIO()
             img.save(imgdata, format='JPEG')
             siz = imgdata.seek(0,2)
@@ -189,10 +183,10 @@ class Worker:
             tg_files.append(await self.client.upload_file(imgdata, file_size=siz, file_name=fn.name))
         results = await self.client.send_file(self.client.main_channel_id, tg_files, caption=
-            "\n".join([f"#{prompt['id']} · 🌀 {prompt['inference_steps']} · 🌱 {prompt['seed']} · 💎 {prompt['detail']}" + (f" · 🎚 {prompt['blend']} (seed from image)" if prompt['image'] != 'no_image' else ''),
+            "\n".join([f"#{prompt['id']} · 🌀 <code>{prompt['inference_steps']}</code> · 🌱 <code>{prompt['seed']}</code> · 💎 <code>{prompt['detail']}</code>" + (f" · 🎚 {prompt['blend']} (seed from image)" if prompt['image'] != 'no_image' else ''),
             #((f"🖼 https://e621.net/posts/{prompt['image_e6']}" if prompt['image_e6'] else 'user-uploaded image') if prompt['image'] != 'no_image' else ''),
-            ('👍 ' if prompt['negative_prompt'] else '')+(f"Prompt from https://e621.net/posts/{prompt['prompt_e6']}" if prompt['prompt_e6'] else prompt['prompt']),
-            (f"\n👎 {prompt['negative_prompt']}" if prompt['negative_prompt'] else '')])[:1000], parse_mode='HTML')
+            ('👍 ' if prompt['negative_prompt'] else '')+(f"Prompt from https://e621.net/posts/{prompt['prompt_e6']}" if prompt['prompt_e6'] else f'<code>{prompt["prompt"]}</code>'),
+            (f"\n👎 <code>{prompt['negative_prompt']}</code>" if prompt['negative_prompt'] else '')])[:1000], parse_mode='HTML')
         await self.client.forward_messages(prompt['user_id'], results)
         if prompt['hires'] == 'yes':
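For reference, the rewritten worker relies on three A1111 endpoints: `/internal/ping` (health check), `/sdapi/v1/progress` (progress polling) and `/sdapi/v1/txt2img` / `/sdapi/v1/img2img` (generation). A trimmed-down standalone sketch of the txt2img call using only the parameters visible in the diff (the URL and values are illustrative, not the bot's defaults):

```python
# Standalone sketch of the txt2img request process_queue.py now makes,
# reduced to the parameters visible in the diff; values are illustrative.
from base64 import b64decode

import httpx

API_URL = 'http://127.0.0.1:7860'  # assumed default from the example config

params = {
    "sampler_name": "DPM++ 2M Karras",
    "prompt": "example prompt",
    "negative_prompt": "",
    "n_iter": 1,        # number of images
    "steps": 20,
    "cfg_scale": 6.0,
    "width": 512,
    "height": 512,
    "seed": "-1",
}

res = httpx.post(API_URL + '/sdapi/v1/txt2img', json=params, timeout=300)
res.raise_for_status()
for i, img in enumerate(res.json()['images']):
    # images come back base64-encoded, the same form process_queue.py decodes
    with open(f'out_{i}.png', 'wb') as f:
        f.write(b64decode(img.split(",", 1)[0]))
```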

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large