diff --git a/README.md b/README.md index 5042560..eb5e343 100644 --- a/README.md +++ b/README.md @@ -11,12 +11,44 @@ Telegram bot for generating images from text. You can put these in `.env` file or pass them as environment variables. -- `TG_BOT_TOKEN` - Telegram bot token (get yours from [@BotFather](https://t.me/BotFather)) -- `SD_API_URL` - URL to Stable Diffusion API (e.g. `http://127.0.0.1:7860/`) -- `ADMIN_USERNAMES` - Comma separated list of usernames of users that can use admin commands - (optional) +- `TG_BOT_TOKEN` - Telegram bot token. Get yours from [@BotFather](https://t.me/BotFather). + Required. +- `SD_API_URL` - URL to Stable Diffusion API. Only used on first run. Default: + `http://127.0.0.1:7860/` +- `TG_ADMIN_USERS` - Comma separated list of usernames of users that can use admin commands. Only + used on first run. Optional. ## Running - Start stable diffusion webui: `cd sd-webui`, `./webui.sh --api` - Start bot: `deno task start` + +## TODO + +- [x] Keep generation history +- [x] Changing params, parsing png info in request +- [x] Cancelling jobs by deleting message +- [x] Multiple parallel workers +- [ ] Replying to another text message to copy prompt and generate +- [ ] Replying to bot message, conversation in DMs +- [ ] Replying to png message to extract png info nad generate +- [ ] Banning tags +- [ ] Img2Img + Upscale +- [ ] Admin WebUI +- [ ] User daily generation limits +- [ ] Querying all generation history, displaying stats +- [ ] Analyzing prompt quality based on tag csv +- [ ] Report aliased/unknown tags based on csv +- [ ] Report unknown loras +- [ ] Investigate "sendMediaGroup failed" +- [ ] Changing sampler without error on unknown sampler +- [ ] Changing model +- [ ] Inpaint using telegram photo edit +- [ ] Outpaint +- [ ] Non-SD (extras) upscale +- [ ] Tiled generation to allow very big images +- [ ] Downloading raw images +- [ ] Extra prompt syntax, fixing `()+++` syntax +- [ ] Translations + - replace fmtDuration usage + - replace formatOrdinal usage diff --git a/bot.ts b/bot.ts deleted file mode 100644 index 625ffe6..0000000 --- a/bot.ts +++ /dev/null @@ -1,273 +0,0 @@ -import { autoQuote, bold, Bot, Context, hydrateReply, log, ParseModeFlavor } from "./deps.ts"; -import { fmt } from "./intl.ts"; -import { getAllJobs, pushJob } from "./queue.ts"; -import { mySession, MySessionFlavor } from "./session.ts"; - -const logger = () => log.getLogger(); - -export type MyContext = ParseModeFlavor & MySessionFlavor; -export const bot = new Bot(Deno.env.get("TG_BOT_TOKEN") ?? ""); -bot.use(autoQuote); -bot.use(hydrateReply); -bot.use(mySession); - -// Automatically retry bot requests if we get a 429 error -bot.api.config.use(async (prev, method, payload, signal) => { - let remainingAttempts = 5; - while (true) { - const result = await prev(method, payload, signal); - if (result.ok) return result; - if (result.error_code !== 429 || remainingAttempts <= 0) return result; - remainingAttempts -= 1; - const retryAfterMs = (result.parameters?.retry_after ?? 30) * 1000; - await new Promise((resolve) => setTimeout(resolve, retryAfterMs)); - } -}); - -// if error happened, try to reply to the user with the error -bot.use(async (ctx, next) => { - try { - await next(); - } catch (err) { - try { - await ctx.reply(`Handling update failed: ${err}`, { - reply_to_message_id: ctx.message?.message_id, - }); - } catch { - throw err; - } - } -}); - -bot.api.setMyShortDescription("I can generate furry images from text"); -bot.api.setMyDescription( - "I can generate furry images from text. " + - "Send /txt2img to generate an image.", -); -bot.api.setMyCommands([ - { command: "txt2img", description: "Generate an image" }, - { command: "queue", description: "Show the current queue" }, - { command: "sdparams", description: "Show the current SD parameters" }, -]); - -bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to generate an image")); - -bot.command("txt2img", async (ctx) => { - if (!ctx.from?.id) { - return ctx.reply("I don't know who you are"); - } - const config = ctx.session.global; - if (config.pausedReason != null) { - return ctx.reply(`I'm paused: ${config.pausedReason || "No reason given"}`); - } - const jobs = await getAllJobs(); - if (jobs.length >= config.maxJobs) { - return ctx.reply( - `The queue is full. Try again later. (Max queue size: ${config.maxJobs})`, - ); - } - const jobCount = jobs.filter((job) => job.user.id === ctx.from.id).length; - if (jobCount >= config.maxUserJobs) { - return ctx.reply( - `You already have ${config.maxUserJobs} jobs in queue. Try again later.`, - ); - } - if (!ctx.match) { - return ctx.reply("Please describe what you want to see after the command"); - } - const statusMessage = await ctx.reply("Accepted. You are now in queue."); - await pushJob({ - params: { prompt: ctx.match }, - user: ctx.from, - chat: ctx.chat, - requestMessage: ctx.message, - statusMessage, - status: { type: "idle" }, - }); - logger().info("Job enqueued", ctx.from.first_name, ctx.chat.type, ctx.match.replace(/\s+/g, " ")); -}); - -bot.command("queue", async (ctx) => { - let jobs = await getAllJobs(); - const getMessageText = () => { - if (jobs.length === 0) return fmt`Queue is empty.`; - const sortedJobs = []; - let place = 0; - for (const job of jobs) { - if (job.status.type === "idle") place += 1; - sortedJobs.push({ ...job, place }); - } - return fmt`Current queue:\n\n${ - sortedJobs.map((job) => - fmt`${job.place}. ${bold(job.user.first_name)} in ${job.chat.type} chat ${ - job.status.type === "processing" ? `(${(job.status.progress * 100).toFixed(0)}%)` : "" - }\n` - ) - }`; - }; - const message = await ctx.replyFmt(getMessageText()); - handleFutureUpdates(); - async function handleFutureUpdates() { - for (let idx = 0; idx < 12; idx++) { - await new Promise((resolve) => setTimeout(resolve, 5000)); - jobs = await getAllJobs(); - const formattedMessage = getMessageText(); - await ctx.api.editMessageText(ctx.chat.id, message.message_id, formattedMessage.text, { - entities: formattedMessage.entities, - }).catch(() => undefined); - } - } -}); - -bot.command("pause", (ctx) => { - if (!ctx.from?.username) return; - const config = ctx.session.global; - if (!config.adminUsernames.includes(ctx.from.username)) return; - if (config.pausedReason != null) { - return ctx.reply(`Already paused: ${config.pausedReason}`); - } - config.pausedReason = ctx.match ?? "No reason given"; - return ctx.reply("Paused"); -}); - -bot.command("resume", (ctx) => { - if (!ctx.from?.username) return; - const config = ctx.session.global; - if (!config.adminUsernames.includes(ctx.from.username)) return; - if (config.pausedReason == null) return ctx.reply("Already running"); - config.pausedReason = null; - return ctx.reply("Resumed"); -}); - -bot.command("setsdapiurl", async (ctx) => { - if (!ctx.from?.username) return; - const config = ctx.session.global; - if (!config.adminUsernames.includes(ctx.from.username)) return; - if (!ctx.match) return ctx.reply("Please specify an URL"); - let url: URL; - try { - url = new URL(ctx.match); - } catch { - return ctx.reply("Invalid URL"); - } - let resp: Response; - try { - resp = await fetch(new URL("config", url)); - } catch (err) { - return ctx.reply(`Could not connect: ${err}`); - } - if (!resp.ok) { - return ctx.reply(`Could not connect: ${resp.status} ${resp.statusText}`); - } - let data: unknown; - try { - data = await resp.json(); - } catch { - return ctx.reply("Invalid response from API"); - } - if (data != null && typeof data === "object" && "version" in data) { - config.sdApiUrl = url.toString(); - return ctx.reply(`Now using SD at ${url} running version ${data.version}`); - } else { - return ctx.reply("Invalid response from API"); - } -}); - -bot.command("setsdparam", (ctx) => { - if (!ctx.from?.username) return; - const config = ctx.session.global; - if (!config.adminUsernames.includes(ctx.from.username)) return; - let [param = "", value] = ctx.match.split("=", 2).map((s) => s.trim()); - if (!param) return ctx.reply("Please specify a parameter"); - if (value == null) return ctx.reply("Please specify a value after the ="); - param = param.toLowerCase().replace(/[\s_]+/g, ""); - if (config.defaultParams == null) config.defaultParams = {}; - switch (param) { - case "steps": { - const steps = parseInt(value); - if (isNaN(steps)) return ctx.reply("Invalid number value"); - if (steps > 100) return ctx.reply("Steps must be less than 100"); - if (steps < 10) return ctx.reply("Steps must be greater than 10"); - config.defaultParams.steps = steps; - return ctx.reply("Steps set to " + steps); - } - case "detail": - case "cfgscale": { - const detail = parseInt(value); - if (isNaN(detail)) return ctx.reply("Invalid number value"); - if (detail > 20) return ctx.reply("Detail must be less than 20"); - if (detail < 1) return ctx.reply("Detail must be greater than 1"); - config.defaultParams.cfg_scale = detail; - return ctx.reply("Detail set to " + detail); - } - case "niter": - case "niters": { - const nIter = parseInt(value); - if (isNaN(nIter)) return ctx.reply("Invalid number value"); - if (nIter > 10) return ctx.reply("Iterations must be less than 10"); - if (nIter < 1) return ctx.reply("Iterations must be greater than 1"); - config.defaultParams.n_iter = nIter; - return ctx.reply("Iterations set to " + nIter); - } - case "batchsize": { - const batchSize = parseInt(value); - if (isNaN(batchSize)) return ctx.reply("Invalid number value"); - if (batchSize > 8) return ctx.reply("Batch size must be less than 8"); - if (batchSize < 1) return ctx.reply("Batch size must be greater than 1"); - config.defaultParams.batch_size = batchSize; - return ctx.reply("Batch size set to " + batchSize); - } - case "size": { - let [width, height] = value.split("x", 2).map((s) => parseInt(s.trim())); - if (!width || !height || isNaN(width) || isNaN(height)) { - return ctx.reply("Invalid size value"); - } - if (width > 2048) return ctx.reply("Width must be less than 2048"); - if (height > 2048) return ctx.reply("Height must be less than 2048"); - // find closest multiple of 64 - width = Math.round(width / 64) * 64; - height = Math.round(height / 64) * 64; - if (width <= 0) return ctx.reply("Width too small"); - if (height <= 0) return ctx.reply("Height too small"); - config.defaultParams.width = width; - config.defaultParams.height = height; - return ctx.reply(`Size set to ${width}x${height}`); - } - case "negativeprompt": { - config.defaultParams.negative_prompt = value; - return ctx.reply(`Negative prompt set to: ${value}`); - } - default: { - return ctx.reply("Invalid parameter"); - } - } -}); - -bot.command("sdparams", (ctx) => { - if (!ctx.from?.username) return; - const config = ctx.session.global; - return ctx.replyFmt( - fmt`Current config:\n\n${ - Object.entries(config.defaultParams ?? {}).map(([key, value]) => - fmt`${bold(key)} = ${String(value)}\n` - ) - }`, - ); -}); - -bot.command("crash", () => { - throw new Error("Crash command used"); -}); - -bot.catch((err) => { - let msg = "Error processing update"; - const { from, chat } = err.ctx; - if (from?.first_name) msg += ` from ${from.first_name}`; - if (from?.last_name) msg += ` ${from.last_name}`; - if (from?.username) msg += ` (@${from.username})`; - if (chat?.type === "supergroup" || chat?.type === "group") { - msg += ` in ${chat.title}`; - if (chat.type === "supergroup" && chat.username) msg += ` (@${chat.username})`; - } - logger().error("handling update failed", from?.first_name, chat?.type, err); -}); diff --git a/bot/mod.ts b/bot/mod.ts new file mode 100644 index 0000000..9ef9106 --- /dev/null +++ b/bot/mod.ts @@ -0,0 +1,92 @@ +import { Grammy, GrammyAutoQuote, GrammyParseMode, Log } from "../deps.ts"; +import { formatUserChat } from "../utils.ts"; +import { session, SessionFlavor } from "./session.ts"; +import { queueCommand } from "./queueCommand.ts"; +import { txt2imgCommand } from "./txt2imgCommand.ts"; + +export const logger = () => Log.getLogger(); + +export type Context = GrammyParseMode.ParseModeFlavor & SessionFlavor; +export const bot = new Grammy.Bot(Deno.env.get("TG_BOT_TOKEN") ?? ""); +bot.use(GrammyAutoQuote.autoQuote); +bot.use(GrammyParseMode.hydrateReply); +bot.use(session); + +bot.catch((err) => { + logger().error(`Handling update from ${formatUserChat(err.ctx)} failed: ${err}`); +}); + +// Automatically retry bot requests if we get a "too many requests" or telegram internal error +bot.api.config.use(async (prev, method, payload, signal) => { + let attempt = 0; + while (true) { + attempt++; + const result = await prev(method, payload, signal); + if ( + result.ok || + ![429, 500].includes(result.error_code) || + attempt >= 5 + ) { + return result; + } + const retryAfterMs = (result.parameters?.retry_after ?? (attempt * 5)) * 1000; + await new Promise((resolve) => setTimeout(resolve, retryAfterMs)); + } +}); + +// if error happened, try to reply to the user with the error +bot.use(async (ctx, next) => { + try { + await next(); + } catch (err) { + try { + await ctx.reply(`Handling update failed: ${err}`, { + reply_to_message_id: ctx.message?.message_id, + }); + } catch { + throw err; + } + } +}); + +bot.api.setMyShortDescription("I can generate furry images from text"); +bot.api.setMyDescription( + "I can generate furry images from text. " + + "Send /txt2img to generate an image.", +); +bot.api.setMyCommands([ + { command: "txt2img", description: "Generate an image" }, + { command: "queue", description: "Show the current queue" }, +]); + +bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to generate an image")); + +bot.command("txt2img", txt2imgCommand); + +bot.command("queue", queueCommand); + +bot.command("pause", (ctx) => { + if (!ctx.from?.username) return; + const config = ctx.session.global; + if (!config.adminUsernames.includes(ctx.from.username)) return; + if (config.pausedReason != null) { + return ctx.reply(`Already paused: ${config.pausedReason}`); + } + config.pausedReason = ctx.match ?? "No reason given"; + logger().warning(`Bot paused by ${ctx.from.first_name} because ${config.pausedReason}`); + return ctx.reply("Paused"); +}); + +bot.command("resume", (ctx) => { + if (!ctx.from?.username) return; + const config = ctx.session.global; + if (!config.adminUsernames.includes(ctx.from.username)) return; + if (config.pausedReason == null) return ctx.reply("Already running"); + config.pausedReason = null; + logger().info(`Bot resumed by ${ctx.from.first_name}`); + return ctx.reply("Resumed"); +}); + +bot.command("crash", () => { + throw new Error("Crash command used"); +}); diff --git a/bot/queueCommand.ts b/bot/queueCommand.ts new file mode 100644 index 0000000..2eb490c --- /dev/null +++ b/bot/queueCommand.ts @@ -0,0 +1,67 @@ +import { Grammy, GrammyParseMode } from "../deps.ts"; +import { fmt, getFlagEmoji } from "../utils.ts"; +import { runningWorkers } from "../tasks/pingWorkers.ts"; +import { jobStore } from "../db/jobStore.ts"; +import { Context, logger } from "./mod.ts"; + +export async function queueCommand(ctx: Grammy.CommandContext) { + let formattedMessage = await getMessageText(); + const queueMessage = await ctx.replyFmt(formattedMessage); + handleFutureUpdates().catch((err) => logger().warning(`Updating queue message failed: ${err}`)); + + async function getMessageText() { + const processingJobs = await jobStore.getBy("status.type", "processing") + .then((jobs) => jobs.map((job) => ({ ...job.value, place: 0 }))); + const waitingJobs = await jobStore.getBy("status.type", "waiting") + .then((jobs) => jobs.map((job, index) => ({ ...job.value, place: index + 1 }))); + const jobs = [...processingJobs, ...waitingJobs]; + const config = ctx.session.global; + const { bold } = GrammyParseMode; + return fmt([ + "Current queue:\n", + ...jobs.length > 0 + ? jobs.flatMap((job) => [ + `${job.place}. `, + fmt`${bold(job.request.from.first_name)} `, + job.request.from.last_name ? fmt`${bold(job.request.from.last_name)} ` : "", + job.request.from.username ? `(@${job.request.from.username}) ` : "", + getFlagEmoji(job.request.from.language_code) ?? "", + job.request.chat.type === "private" + ? " in private chat " + : ` in ${job.request.chat.title} `, + job.request.chat.type !== "private" && job.request.chat.type !== "group" && + job.request.chat.username + ? `(@${job.request.chat.username}) ` + : "", + job.status.type === "processing" + ? `(${(job.status.progress * 100).toFixed(0)}% using ${job.status.worker}) ` + : "", + "\n", + ]) + : ["Queue is empty.\n"], + "\nActive workers:\n", + ...config.workers.flatMap((worker) => [ + runningWorkers.has(worker.name) ? "✅ " : "☠️ ", + fmt`${bold(worker.name)} `, + `(max ${(worker.maxResolution / 1000000).toFixed(1)} Mpx) `, + "\n", + ]), + ]); + } + + async function handleFutureUpdates() { + for (let idx = 0; idx < 20; idx++) { + await new Promise((resolve) => setTimeout(resolve, 3000)); + const nextFormattedMessage = await getMessageText(); + if (nextFormattedMessage.text !== formattedMessage.text) { + await ctx.api.editMessageText( + ctx.chat.id, + queueMessage.message_id, + nextFormattedMessage.text, + { entities: nextFormattedMessage.entities }, + ); + formattedMessage = nextFormattedMessage; + } + } + } +} diff --git a/session.ts b/bot/session.ts similarity index 53% rename from session.ts rename to bot/session.ts index f34ef2a..2f17df1 100644 --- a/session.ts +++ b/bot/session.ts @@ -1,7 +1,7 @@ -import { Context, DenoKVAdapter, session, SessionFlavor } from "./deps.ts"; -import { SdTxt2ImgRequest } from "./sd.ts"; +import { Grammy, GrammyKvStorage } from "../deps.ts"; +import { SdApi, SdTxt2ImgRequest } from "../sd.ts"; -export type MySessionFlavor = SessionFlavor; +export type SessionFlavor = Grammy.SessionFlavor; export interface SessionData { global: GlobalData; @@ -12,45 +12,55 @@ export interface SessionData { export interface GlobalData { adminUsernames: string[]; pausedReason: string | null; - sdApiUrl: string; maxUserJobs: number; maxJobs: number; defaultParams?: Partial; + workers: WorkerData[]; +} + +export interface WorkerData { + name: string; + api: SdApi; + auth?: string; + maxResolution: number; } export interface ChatData { - language: string; + language?: string; } export interface UserData { - steps: number; - detail: number; - batchSize: number; + params?: Partial; } const globalDb = await Deno.openKv("./app.db"); -const globalDbAdapter = new DenoKVAdapter(globalDb); +const globalDbAdapter = new GrammyKvStorage.DenoKVAdapter(globalDb); const getDefaultGlobalData = (): GlobalData => ({ - adminUsernames: (Deno.env.get("ADMIN_USERNAMES") ?? "").split(",").filter(Boolean), + adminUsernames: Deno.env.get("TG_ADMIN_USERS")?.split(",") ?? [], pausedReason: null, - sdApiUrl: Deno.env.get("SD_API_URL") ?? "http://127.0.0.1:7860/", maxUserJobs: 3, maxJobs: 20, defaultParams: { batch_size: 1, n_iter: 1, - width: 128 * 2, - height: 128 * 3, - steps: 20, - cfg_scale: 9, - send_images: true, + width: 512, + height: 768, + steps: 30, + cfg_scale: 10, negative_prompt: "boring_e621_fluffyrock_v4 boring_e621_v4", }, + workers: [ + { + name: "local", + api: { url: Deno.env.get("SD_API_URL") ?? "http://127.0.0.1:7860/" }, + maxResolution: 1024 * 1024, + }, + ], }); -export const mySession = session({ +export const session = Grammy.session({ type: "multi", global: { getSessionKey: () => "global", @@ -58,17 +68,11 @@ export const mySession = session({ storage: globalDbAdapter, }, chat: { - initial: () => ({ - language: "en", - }), + initial: () => ({}), }, user: { getSessionKey: (ctx) => ctx.from?.id.toFixed(), - initial: () => ({ - steps: 20, - detail: 8, - batchSize: 2, - }), + initial: () => ({}), }, }); diff --git a/bot/txt2imgCommand.ts b/bot/txt2imgCommand.ts new file mode 100644 index 0000000..40f205b --- /dev/null +++ b/bot/txt2imgCommand.ts @@ -0,0 +1,39 @@ +import { Grammy } from "../deps.ts"; +import { formatUserChat } from "../utils.ts"; +import { jobStore } from "../db/jobStore.ts"; +import { parsePngInfo } from "../sd.ts"; +import { Context, logger } from "./mod.ts"; + +export async function txt2imgCommand(ctx: Grammy.CommandContext) { + if (!ctx.from?.id) { + return ctx.reply("I don't know who you are"); + } + const config = ctx.session.global; + if (config.pausedReason != null) { + return ctx.reply(`I'm paused: ${config.pausedReason || "No reason given"}`); + } + const jobs = await jobStore.getBy("status.type", "waiting"); + if (jobs.length >= config.maxJobs) { + return ctx.reply( + `The queue is full. Try again later. (Max queue size: ${config.maxJobs})`, + ); + } + const userJobs = jobs.filter((job) => job.value.request.from.id === ctx.from?.id); + if (userJobs.length >= config.maxUserJobs) { + return ctx.reply( + `You already have ${config.maxUserJobs} jobs in queue. Try again later.`, + ); + } + const params = parsePngInfo(ctx.match); + if (!params.prompt) { + return ctx.reply("Please describe what you want to see after the command"); + } + const reply = await ctx.reply("Accepted. You are now in queue."); + await jobStore.create({ + params, + request: ctx.message, + reply, + status: { type: "waiting" }, + }); + logger().debug(`Job enqueued for ${formatUserChat(ctx)}`); +} diff --git a/db/db.ts b/db/db.ts new file mode 100644 index 0000000..032772a --- /dev/null +++ b/db/db.ts @@ -0,0 +1 @@ +export const db = await Deno.openKv("./app.db"); diff --git a/db/jobStore.ts b/db/jobStore.ts new file mode 100644 index 0000000..6b74717 --- /dev/null +++ b/db/jobStore.ts @@ -0,0 +1,18 @@ +import { GrammyTypes, IKV } from "../deps.ts"; +import { SdTxt2ImgInfo, SdTxt2ImgRequest } from "../sd.ts"; +import { db } from "./db.ts"; + +export interface JobSchema { + params: Partial; + request: GrammyTypes.Message.TextMessage & { from: GrammyTypes.User }; + reply?: GrammyTypes.Message.TextMessage; + status: + | { type: "waiting" } + | { type: "processing"; progress: number; worker: string; updatedDate: Date } + | { type: "done"; info?: SdTxt2ImgInfo; startDate?: Date; endDate?: Date }; +} + +export const jobStore = new IKV.Store(db, "job", { + schema: new IKV.Schema(), + indices: ["status.type"], +}); diff --git a/deps.ts b/deps.ts index 03320d9..2a15398 100644 --- a/deps.ts +++ b/deps.ts @@ -1,7 +1,18 @@ -export * from "https://deno.land/x/grammy@v1.18.1/mod.ts"; -export * from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts"; -export * from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts"; -export * from "https://deno.land/x/grammy_storages@v2.3.1/denokv/src/mod.ts"; -export * as types from "https://deno.land/x/grammy_types@v3.2.0/mod.ts"; -export * from "https://deno.land/x/ulid@v0.3.0/mod.ts"; -export * as log from "https://deno.land/std@0.201.0/log/mod.ts"; +export * as Log from "https://deno.land/std@0.201.0/log/mod.ts"; +export * as Async from "https://deno.land/std@0.201.0/async/mod.ts"; +export * as FmtDuration from "https://deno.land/std@0.201.0/fmt/duration.ts"; +export * as Collections from "https://deno.land/std@0.201.0/collections/mod.ts"; +export * as Base64 from "https://deno.land/std@0.201.0/encoding/base64.ts"; +export * as AsyncX from "https://deno.land/x/async@v2.0.2/mod.ts"; +export * as ULID from "https://deno.land/x/ulid@v0.3.0/mod.ts"; +export * as IKV from "https://deno.land/x/indexed_kv@v0.2.0/mod.ts"; +export * as Grammy from "https://deno.land/x/grammy@v1.18.1/mod.ts"; +export * as GrammyTypes from "https://deno.land/x/grammy_types@v3.2.0/mod.ts"; +export * as GrammyAutoQuote from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts"; +export * as GrammyParseMode from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts"; +export * as GrammyKvStorage from "https://deno.land/x/grammy_storages@v2.3.1/denokv/src/mod.ts"; +export * as FileType from "npm:file-type@18.5.0"; +// @deno-types="./types/png-chunks-extract.d.ts" +export * as PngChunksExtract from "npm:png-chunks-extract@1.0.0"; +// @deno-types="./types/png-chunk-text.d.ts" +export * as PngChunkText from "npm:png-chunk-text@1.0.0"; diff --git a/intl.ts b/intl.ts deleted file mode 100644 index fb943fe..0000000 --- a/intl.ts +++ /dev/null @@ -1,43 +0,0 @@ -import { FormattedString } from "./deps.ts"; - -export function formatOrdinal(n: number) { - if (n % 100 === 11 || n % 100 === 12 || n % 100 === 13) return `${n}th`; - if (n % 10 === 1) return `${n}st`; - if (n % 10 === 2) return `${n}nd`; - if (n % 10 === 3) return `${n}rd`; - return `${n}th`; -} - -type DeepArray = Array>; -type StringLikes = DeepArray; - -/** - * Like `fmt` from `grammy_parse_mode` but additionally accepts arrays. - * @see https://deno.land/x/grammy_parse_mode@1.7.1/format.ts?source=#L182 - */ -export const fmt = ( - rawStringParts: TemplateStringsArray | StringLikes, - ...stringLikes: StringLikes -): FormattedString => { - let text = ""; - const entities: ConstructorParameters[1][] = []; - - const length = Math.max(rawStringParts.length, stringLikes.length); - for (let i = 0; i < length; i++) { - for (let stringLike of [rawStringParts[i], stringLikes[i]]) { - if (Array.isArray(stringLike)) { - stringLike = fmt(stringLike); - } - if (stringLike instanceof FormattedString) { - entities.push( - ...stringLike.entities.map((e) => ({ - ...e, - offset: e.offset + text.length, - })), - ); - } - if (stringLike != null) text += stringLike.toString(); - } - } - return new FormattedString(text, entities); -}; diff --git a/main.ts b/main.ts index 55df7d3..1853539 100644 --- a/main.ts +++ b/main.ts @@ -1,24 +1,21 @@ +// Load environment variables from .env file import "https://deno.land/std@0.201.0/dotenv/load.ts"; -import { bot } from "./bot.ts"; -import { processQueue, returnHangedJobs } from "./queue.ts"; -import { log } from "./deps.ts"; -log.setup({ +// Setup logging +import { Log } from "./deps.ts"; +Log.setup({ handlers: { - console: new log.handlers.ConsoleHandler("INFO", { - formatter: (record) => - `[${record.levelName}] ${record.msg} ${ - record.args.map((arg) => JSON.stringify(arg)).join(" ") - } (${record.datetime.toISOString()})`, - }), + console: new Log.handlers.ConsoleHandler("DEBUG"), }, loggers: { - default: { level: "INFO", handlers: ["console"] }, + default: { level: "DEBUG", handlers: ["console"] }, }, }); +// Main program logic +import { bot } from "./bot/mod.ts"; +import { runAllTasks } from "./tasks/mod.ts"; await Promise.all([ bot.start(), - processQueue(), - returnHangedJobs(), + runAllTasks(), ]); diff --git a/mimeType.ts b/mimeType.ts deleted file mode 100644 index 2d61931..0000000 --- a/mimeType.ts +++ /dev/null @@ -1,15 +0,0 @@ -export function mimeTypeFromBase64(base64: string) { - if (base64.startsWith("/9j/")) return "image/jpeg"; - if (base64.startsWith("iVBORw0KGgo")) return "image/png"; - if (base64.startsWith("R0lGODlh")) return "image/gif"; - if (base64.startsWith("UklGRg")) return "image/webp"; - throw new Error("Unknown image type"); -} - -export function extFromMimeType(mimeType: string) { - if (mimeType === "image/jpeg") return "jpg"; - if (mimeType === "image/png") return "png"; - if (mimeType === "image/gif") return "gif"; - if (mimeType === "image/webp") return "webp"; - throw new Error("Unknown image type"); -} diff --git a/queue.ts b/queue.ts deleted file mode 100644 index 04d9c61..0000000 --- a/queue.ts +++ /dev/null @@ -1,181 +0,0 @@ -import { InputFile, InputMediaBuilder, log, types } from "./deps.ts"; -import { bot } from "./bot.ts"; -import { getGlobalSession } from "./session.ts"; -import { formatOrdinal } from "./intl.ts"; -import { SdTxt2ImgRequest, SdTxt2ImgResponse, txt2img } from "./sd.ts"; -import { extFromMimeType, mimeTypeFromBase64 } from "./mimeType.ts"; -import { Model, Schema, Store } from "./store.ts"; - -const logger = () => log.getLogger(); - -interface Job { - params: Partial; - user: types.User; - chat: types.Chat.PrivateChat | types.Chat.GroupChat | types.Chat.SupergroupChat; - requestMessage: types.Message & types.Message.TextMessage; - statusMessage?: types.Message & types.Message.TextMessage; - status: - | { type: "idle" } - | { type: "processing"; progress: number; updatedDate: Date }; -} - -const db = await Deno.openKv("./app.db"); - -const jobStore = new Store(db, "job", { - schema: new Schema(), - indices: ["status.type", "user.id", "chat.id"], -}); - -jobStore.getBy("user.id", 123).then(() => {}); - -export async function pushJob(job: Job) { - await jobStore.create(job); -} - -async function takeJob(): Promise | null> { - const jobs = await jobStore.getAll(); - const job = jobs.find((job) => job.value.status.type === "idle"); - if (!job) return null; - await job.update({ status: { type: "processing", progress: 0, updatedDate: new Date() } }); - return job; -} - -export async function getAllJobs(): Promise> { - return await jobStore.getAll().then((jobs) => jobs.map((job) => job.value)); -} - -export async function processQueue() { - while (true) { - const job = await takeJob().catch((err) => - void logger().warning("failed getting job", err.message) - ); - if (!job) { - await new Promise((resolve) => setTimeout(resolve, 1000)); - continue; - } - let place = 0; - for (const job of await jobStore.getAll().catch(() => [])) { - if (job.value.status.type === "idle") place += 1; - if (place === 0) continue; - const statusMessageText = `You are ${formatOrdinal(place)} in queue.`; - if (!job.value.statusMessage) { - await bot.api.sendMessage(job.value.chat.id, statusMessageText, { - reply_to_message_id: job.value.requestMessage.message_id, - }).catch(() => undefined) - .then((message) => job.update({ statusMessage: message })).catch(() => undefined); - } else { - await bot.api.editMessageText( - job.value.chat.id, - job.value.statusMessage.message_id, - statusMessageText, - ) - .catch(() => undefined); - } - } - try { - if (job.value.statusMessage) { - await bot.api - .deleteMessage(job.value.chat.id, job.value.statusMessage?.message_id) - .catch(() => undefined) - .then(() => job.update({ statusMessage: undefined })); - } - await bot.api.sendMessage( - job.value.chat.id, - "Generating your prompt now...", - { reply_to_message_id: job.value.requestMessage.message_id }, - ).then((message) => job.update({ statusMessage: message })); - const config = await getGlobalSession(); - const response = await txt2img( - config.sdApiUrl, - { ...config.defaultParams, ...job.value.params }, - (progress) => { - job.update({ - status: { type: "processing", progress: progress.progress, updatedDate: new Date() }, - }); - if (job.value.statusMessage) { - bot.api - .editMessageText( - job.value.chat.id, - job.value.statusMessage.message_id, - `Generating your prompt now... ${ - Math.round( - progress.progress * 100, - ) - }%`, - ) - .catch(() => undefined); - } - }, - ); - const jobCount = (await jobStore.getAll()).filter((job) => - job.value.status.type !== "processing" - ).length; - logger().info("Job finished", job.value.user.first_name, job.value.chat.type, { jobCount }); - if (job.value.statusMessage) { - await bot.api.editMessageText( - job.value.chat.id, - job.value.statusMessage.message_id, - `Uploading your images...`, - ).catch(() => undefined); - } - const inputFiles = await Promise.all( - response.images.map(async (imageBase64, idx) => { - const mimeType = mimeTypeFromBase64(imageBase64); - const imageBlob = await fetch(`data:${mimeType};base64,${imageBase64}`) - .then((resp) => resp.blob()); - return InputMediaBuilder.photo( - new InputFile(imageBlob, `image_${idx}.${extFromMimeType(mimeType)}`), - ); - }), - ); - if (job.value.statusMessage) { - await bot.api - .deleteMessage(job.value.chat.id, job.value.statusMessage.message_id) - .catch(() => undefined).then(() => job.update({ statusMessage: undefined })); - } - await bot.api.sendMediaGroup(job.value.chat.id, inputFiles, { - reply_to_message_id: job.value.requestMessage.message_id, - }); - await job.delete(); - } catch (err) { - logger().error("Job failed", job.value.user.first_name, job.value.chat.type, err); - const errorMessage = await bot.api - .sendMessage(job.value.chat.id, err.toString(), { - reply_to_message_id: job.value.requestMessage.message_id, - }) - .catch(() => undefined); - if (errorMessage) { - if (job.value.statusMessage) { - await bot.api - .deleteMessage(job.value.chat.id, job.value.statusMessage.message_id) - .then(() => job.update({ statusMessage: undefined })) - .catch(() => void logger().warning("failed deleting status message", err.message)); - } - await job.update({ status: { type: "idle" } }).catch((err) => - void logger().warning("failed returning job", err.message) - ); - } else { - await job.delete().catch((err) => - void logger().warning("failed deleting job", err.message) - ); - } - } - } -} - -export async function returnHangedJobs() { - while (true) { - await new Promise((resolve) => setTimeout(resolve, 5000)); - const jobs = await jobStore.getAll().catch(() => []); - for (const job of jobs) { - if (job.value.status.type !== "processing") continue; - // if job wasn't updated for 1 minute, return it to the queue - if (job.value.status.updatedDate.getTime() < Date.now() - 60 * 1000) { - logger().warning("Hanged job returned", job.value.user.first_name, job.value.chat.type); - await job.update({ status: { type: "idle" } }).catch((err) => - void logger().warning("failed returning job", err.message) - ); - } - } - } -} diff --git a/sd.ts b/sd.ts index ff1ce30..e6fa570 100644 --- a/sd.ts +++ b/sd.ts @@ -1,53 +1,88 @@ -export async function txt2img( - apiUrl: string, +import { Async, AsyncX, PngChunksExtract, PngChunkText } from "./deps.ts"; + +const neverSignal = new AbortController().signal; + +export interface SdApi { + url: string; + auth?: string; +} + +async function fetchSdApi(api: SdApi, endpoint: string, body?: unknown): Promise { + let options: RequestInit | undefined; + if (body != null) { + options = { + method: "POST", + headers: { + "Content-Type": "application/json", + ...api.auth ? { Authorization: api.auth } : {}, + }, + body: JSON.stringify(body), + }; + } else if (api.auth) { + options = { + headers: { Authorization: api.auth }, + }; + } + const response = await fetch(new URL(endpoint, api.url), options).catch(() => { + throw new SdApiError(endpoint, options, 0, "Network error"); + }); + const result = await response.json().catch(() => { + throw new SdApiError(endpoint, options, response.status, response.statusText, { + detail: "Invalid JSON", + }); + }); + if (!response.ok) { + throw new SdApiError(endpoint, options, response.status, response.statusText, result); + } + return result; +} + +export async function sdTxt2Img( + api: SdApi, params: Partial, onProgress?: (progress: SdProgressResponse) => void, - signal?: AbortSignal, + signal: AbortSignal = neverSignal, ): Promise { - let response: Response | undefined; - let error: unknown; - - fetch(new URL("sdapi/v1/txt2img", apiUrl), { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify(params), - }).then( - (resp) => (response = resp), - (err) => (error = err), - ); + const request = fetchSdApi(api, "sdapi/v1/txt2img", params) + // JSON field "info" is a JSON-serialized string so we need to parse this part second time + .then((data) => ({ + ...data, + info: typeof data.info === "string" ? JSON.parse(data.info) : data.info, + })); try { while (true) { - await new Promise((resolve) => setTimeout(resolve, 3000)); - const progressRequest = await fetch(new URL("sdapi/v1/progress", apiUrl)); - if (progressRequest.ok) { - const progress = (await progressRequest.json()) as SdProgressResponse; - onProgress?.(progress); - } - if (response != null) { - if (response.ok) { - const result = (await response.json()) as SdTxt2ImgResponse; - return result; - } else { - throw new Error(`Request failed: ${response.status} ${response.statusText}`); - } - } - if (error != null) { - throw error; - } - signal?.throwIfAborted(); + await Async.abortable(Promise.race([request, Async.delay(3000)]), signal); + if (await AsyncX.promiseState(request) !== "pending") return await request; + onProgress?.(await fetchSdApi(api, "sdapi/v1/progress")); } } finally { - if (!response && !error) { - await fetch(new URL("sdapi/v1/interrupt", apiUrl), { method: "POST" }); + if (await AsyncX.promiseState(request) === "pending") { + await fetchSdApi(api, "sdapi/v1/interrupt", {}); } } } export interface SdTxt2ImgRequest { + enable_hr: boolean; denoising_strength: number; + firstphase_width: number; + firstphase_height: number; + hr_scale: number; + hr_upscaler: unknown; + hr_second_pass_steps: number; + hr_resize_x: number; + hr_resize_y: number; + hr_sampler_name: unknown; + hr_prompt: string; + hr_negative_prompt: string; prompt: string; + styles: unknown; seed: number; + subseed: number; + subseed_strength: number; + seed_resize_from_h: number; + seed_resize_from_w: number; sampler_name: unknown; batch_size: number; n_iter: number; @@ -55,16 +90,68 @@ export interface SdTxt2ImgRequest { cfg_scale: number; width: number; height: number; + restore_faces: boolean; + tiling: boolean; + do_not_save_samples: boolean; + do_not_save_grid: boolean; negative_prompt: string; + eta: unknown; + s_min_uncond: number; + s_churn: number; + s_tmax: unknown; + s_tmin: number; + s_noise: number; + override_settings: object; + override_settings_restore_afterwards: boolean; + script_args: unknown[]; + sampler_index: string; + script_name: unknown; send_images: boolean; save_images: boolean; + alwayson_scripts: object; } export interface SdTxt2ImgResponse { images: string[]; parameters: SdTxt2ImgRequest; - /** Contains serialized JSON */ - info: string; + // Warning: raw response from API is a JSON-serialized string + info: SdTxt2ImgInfo; +} + +export interface SdTxt2ImgInfo { + prompt: string; + all_prompts: string[]; + negative_prompt: string; + all_negative_prompts: string[]; + seed: number; + all_seeds: number[]; + subseed: number; + all_subseeds: number[]; + subseed_strength: number; + width: number; + height: number; + sampler_name: string; + cfg_scale: number; + steps: number; + batch_size: number; + restore_faces: boolean; + face_restoration_model: unknown; + sd_model_hash: string; + seed_resize_from_w: number; + seed_resize_from_h: number; + denoising_strength: number; + extra_generation_params: SdTxt2ImgInfoExtraParams; + index_of_first_image: number; + infotexts: string[]; + styles: unknown[]; + job_timestamp: string; + clip_skip: number; + is_using_inpainting_conditioning: boolean; +} + +export interface SdTxt2ImgInfoExtraParams { + "Lora hashes": string; + "TI hashes": string; } export interface SdProgressResponse { @@ -86,3 +173,138 @@ export interface SdProgressState { sampling_step: number; sampling_steps: number; } + +export function sdGetConfig(api: SdApi): Promise { + return fetchSdApi(api, "config"); +} + +export interface SdConfigResponse { + /** version with new line at the end for some reason */ + version: string; + mode: string; + dev_mode: boolean; + analytics_enabled: boolean; + components: object[]; + css: unknown; + title: string; + is_space: boolean; + enable_queue: boolean; + show_error: boolean; + show_api: boolean; + is_colab: boolean; + stylesheets: unknown[]; + theme: string; + layout: object; + dependencies: object[]; + root: string; +} + +export interface SdErrorResponse { + /** + * The HTTP status message or array of invalid fields. + * Can also be empty string. + */ + detail: string | Array<{ loc: string[]; msg: string; type: string }>; + /** Can be e.g. "OutOfMemoryError" or undefined. */ + error?: string; + /** Empty string. */ + body?: string; + /** Long description of error. */ + errors?: string; +} + +export class SdApiError extends Error { + constructor( + public readonly endpoint: string, + public readonly options: RequestInit | undefined, + public readonly statusCode: number, + public readonly statusText: string, + public readonly response?: SdErrorResponse, + ) { + let message = `${options?.method ?? "GET"} ${endpoint} : ${statusCode} ${statusText}`; + if (response?.error) { + message += `: ${response.error}`; + if (response.errors) message += ` - ${response.errors}`; + } else if (typeof response?.detail === "string" && response.detail.length > 0) { + message += `: ${response.detail}`; + } else if (response?.detail) { + message += `: ${JSON.stringify(response.detail)}`; + } + super(message); + } +} + +export function getPngInfo(pngData: Uint8Array): string | undefined { + return PngChunksExtract.default(pngData) + .filter((chunk) => chunk.name === "tEXt") + .map((chunk) => PngChunkText.decode(chunk.data)) + .find((textChunk) => textChunk.keyword === "parameters") + ?.text; +} + +export function parsePngInfo(pngInfo: string): Partial { + const tags = pngInfo.split(/[,;]+|\.+\s|\n/u); + let part: "prompt" | "negative_prompt" | "params" = "prompt"; + const params: Partial = {}; + const prompt: string[] = []; + const negativePrompt: string[] = []; + for (const tag of tags) { + const paramValuePair = tag.trim().match(/^(\w+\s*\w*):\s+([\d\w. ]+)\s*$/u); + if (paramValuePair) { + const [, param, value] = paramValuePair; + switch (param.replace(/\s+/u, "").toLowerCase()) { + case "prompt": + part = "prompt"; + prompt.push(value.trim()); + break; + case "negativeprompt": + part = "negative_prompt"; + negativePrompt.push(value.trim()); + break; + case "steps": + case "cycles": { + part = "params"; + const steps = Number(value.trim()); + if (steps > 0) params.steps = Math.min(steps, 50); + break; + } + case "cfgscale": + case "detail": { + part = "params"; + const cfgScale = Number(value.trim()); + if (cfgScale > 0) params.cfg_scale = Math.min(cfgScale, 20); + break; + } + case "size": + case "resolution": { + part = "params"; + const [width, height] = value.trim() + .split(/\s*[x,]\s*/u, 2) + .map((v) => v.trim()) + .map(Number); + if (width > 0 && height > 0) { + params.width = Math.min(width, 2048); + params.height = Math.min(height, 2048); + } + break; + } + default: + break; + } + } else if (tag.trim().length > 0) { + switch (part) { + case "prompt": + prompt.push(tag.trim()); + break; + case "negative_prompt": + negativePrompt.push(tag.trim()); + break; + default: + break; + } + } + } + if (prompt.length > 0) params.prompt = prompt.join(", "); + if (negativePrompt.length > 0) params.negative_prompt = negativePrompt.join(", "); + return params; +} diff --git a/store.test.ts b/store.test.ts deleted file mode 100644 index 64a09f6..0000000 --- a/store.test.ts +++ /dev/null @@ -1,96 +0,0 @@ -import { assert } from "https://deno.land/std@0.198.0/assert/assert.ts"; -import { Schema, Store } from "./store.ts"; -import { log } from "./deps.ts"; - -const db = await Deno.openKv(); - -log.setup({ - handlers: { - console: new log.handlers.ConsoleHandler("DEBUG", {}), - }, - loggers: { - kvStore: { level: "DEBUG", handlers: ["console"] }, - }, -}); - -interface PointSchema { - x: number; - y: number; -} - -interface JobSchema { - name: string; - params: { - a: number; - b: number | null; - }; - status: { type: "idle" } | { type: "processing"; progress: number } | { type: "done" }; - lastUpdateDate: Date; -} - -const pointStore = new Store(db, "points", { - schema: new Schema(), - indices: ["x", "y"], -}); -const jobStore = new Store(db, "jobs", { - schema: new Schema(), - indices: ["name", "status.type"], -}); - -Deno.test("create and delete", async () => { - await pointStore.deleteAll(); - const point1 = await pointStore.create({ x: 1, y: 2 }); - const point2 = await pointStore.create({ x: 3, y: 4 }); - assert((await pointStore.getAll()).length === 2); - const point3 = await pointStore.create({ x: 5, y: 6 }); - assert((await pointStore.getAll()).length === 3); - assert((await pointStore.get(point2.id))?.value.y === 4); - await point1.delete(); - assert((await pointStore.getAll()).length === 2); - await point2.delete(); - await point3.delete(); - assert((await pointStore.getAll()).length === 0); -}); - -Deno.test("list by index", async () => { - await jobStore.deleteAll(); - - const test = await jobStore.create({ - name: "test", - params: { a: 1, b: null }, - status: { type: "idle" }, - lastUpdateDate: new Date(), - }); - assert((await jobStore.getBy("name", "test"))[0].value.params.a === 1); - assert((await jobStore.getBy("status.type", "idle"))[0].value.params.a === 1); - - await test.update({ status: { type: "processing", progress: 33 } }); - assert((await jobStore.getBy("status.type", "processing"))[0].value.params.a === 1); - - await test.update({ status: { type: "done" } }); - assert((await jobStore.getBy("status.type", "done"))[0].value.params.a === 1); - assert((await jobStore.getBy("status.type", "processing")).length === 0); - - await test.delete(); - assert((await jobStore.getBy("status.type", "done")).length === 0); - assert((await jobStore.getBy("name", "test")).length === 0); -}); - -Deno.test("fail on concurrent update", async () => { - await jobStore.deleteAll(); - - const test = await jobStore.create({ - name: "test", - params: { a: 1, b: null }, - status: { type: "idle" }, - lastUpdateDate: new Date(), - }); - - const result = await Promise.all([ - test.update({ status: { type: "processing", progress: 33 } }), - test.update({ status: { type: "done" } }), - ]).catch(() => true); - assert(result === true); - - await test.delete(); -}); diff --git a/store.ts b/store.ts deleted file mode 100644 index 6faf668..0000000 --- a/store.ts +++ /dev/null @@ -1,241 +0,0 @@ -import { log, ulid } from "./deps.ts"; - -const logger = () => log.getLogger("kvStore"); - -export type validIndexKey = { - [K in keyof T]: K extends string ? (T[K] extends Deno.KvKeyPart ? K - : T[K] extends readonly unknown[] ? never - : T[K] extends object ? `${K}.${validIndexKey}` - : never) - : never; -}[keyof T]; - -export type indexValue> = I extends `${infer K}.${infer Rest}` - ? K extends keyof T ? Rest extends validIndexKey ? indexValue - : never - : never - : I extends keyof T ? T[I] - : never; - -export class Schema {} - -interface StoreOptions { - readonly schema: Schema; - readonly indices: readonly I[]; -} - -export class Store> { - readonly #db: Deno.Kv; - readonly #key: Deno.KvKeyPart; - readonly #indices: readonly I[]; - - constructor(db: Deno.Kv, key: Deno.KvKeyPart, options: StoreOptions) { - this.#db = db; - this.#key = key; - this.#indices = options.indices; - } - - async create(value: T): Promise> { - const id = ulid(); - await this.#db.set([this.#key, "id", id], value); - logger().debug(["created", this.#key, "id", id].join(" ")); - for (const index of this.#indices) { - const indexValue: Deno.KvKeyPart = index - .split(".") - .reduce((value, key) => value[key], value as any); - await this.#db.set([this.#key, index, indexValue, id], value); - logger().debug(["created", this.#key, index, indexValue, id].join(" ")); - } - return new Model(this.#db, this.#key, this.#indices, id, value); - } - - async get(id: Deno.KvKeyPart): Promise | null> { - const entry = await this.#db.get([this.#key, "id", id]); - if (entry.versionstamp == null) return null; - return new Model(this.#db, this.#key, this.#indices, id, entry.value); - } - - async getBy( - index: J, - value: indexValue, - options?: Deno.KvListOptions, - ): Promise>> { - const models: Model[] = []; - for await ( - const entry of this.#db.list( - { prefix: [this.#key, index, value as Deno.KvKeyPart] }, - options, - ) - ) { - models.push(new Model(this.#db, this.#key, this.#indices, entry.key[3], entry.value)); - } - return models; - } - - async getAll( - opts?: { limit?: number; reverse?: boolean }, - ): Promise>> { - const { limit, reverse } = opts ?? {}; - const models: Array> = []; - for await ( - const entry of this.#db.list({ - prefix: [this.#key, "id"], - }, { limit, reverse }) - ) { - models.push(new Model(this.#db, this.#key, this.#indices, entry.key[2], entry.value)); - } - return models; - } - - async deleteAll(): Promise { - for await (const entry of this.#db.list({ prefix: [this.#key] })) { - await this.#db.delete(entry.key); - logger().debug(["deleted", ...entry.key].join(" ")); - } - } -} - -export class Model { - readonly #db: Deno.Kv; - readonly #key: Deno.KvKeyPart; - readonly #indices: readonly string[]; - readonly #id: Deno.KvKeyPart; - value: T; - - constructor( - db: Deno.Kv, - key: Deno.KvKeyPart, - indices: readonly string[], - id: Deno.KvKeyPart, - value: T, - ) { - this.#db = db; - this.#key = key; - this.#indices = indices; - this.#id = id; - this.value = value; - } - - get id(): Deno.KvKeyPart { - return this.#id; - } - - async update(updater: Partial | ((value: T) => T)): Promise { - // get current main entry - const oldEntry = await this.#db.get([this.#key, "id", this.#id]); - - // get all current index entries - const oldIndexEntries: Record> = {}; - for (const index of this.#indices) { - const indexKey: Deno.KvKeyPart = index - .split(".") - .reduce((value, key) => value[key], oldEntry.value as any); - oldIndexEntries[index] = await this.#db.get([this.#key, index, indexKey, this.#id]); - } - - // compute new value - if (typeof updater === "function") { - this.value = updater(this.value); - } else { - this.value = { ...this.value, ...updater }; - } - - // begin transaction - const transaction = this.#db.atomic(); - - // set the main entry - transaction - .check(oldEntry) - .set([this.#key, "id", this.#id], this.value); - logger().debug(["updated", this.#key, "id", this.#id].join(" ")); - - // delete and create all changed index entries - for (const index of this.#indices) { - const oldIndexKey: Deno.KvKeyPart = index - .split(".") - .reduce((value, key) => value[key], oldIndexEntries[index].value as any); - const newIndexKey: Deno.KvKeyPart = index - .split(".") - .reduce((value, key) => value[key], this.value as any); - if (newIndexKey !== oldIndexKey) { - transaction - .check(oldIndexEntries[index]) - .delete([this.#key, index, oldIndexKey, this.#id]) - .set([this.#key, index, newIndexKey, this.#id], this.value); - logger().debug(["deleted", this.#key, index, oldIndexKey, this.#id].join(" ")); - logger().debug(["created", this.#key, index, newIndexKey, this.#id].join(" ")); - } - } - - // commit - const result = await transaction.commit(); - if (!result.ok) throw new Error(`Failed to update ${this.#key} ${this.#id}`); - return this.value; - } - - async delete(): Promise { - // get current main entry - const entry = await this.#db.get([this.#key, "id", this.#id]); - - // begin transaction - const transaction = this.#db.atomic(); - - // delete main entry - transaction - .check(entry) - .delete([this.#key, "id", this.#id]); - logger().debug(["deleted", this.#key, "id", this.#id].join(" ")); - - // delete all index entries - for (const index of this.#indices) { - const indexKey: Deno.KvKeyPart = index - .split(".") - .reduce((value, key) => value[key], entry.value as any); - transaction - .delete([this.#key, index, indexKey, this.#id]); - logger().debug(["deleted", this.#key, index, indexKey, this.#id].join(" ")); - } - - // commit - const result = await transaction.commit(); - if (!result.ok) throw new Error(`Failed to delete ${this.#key} ${this.#id}`); - } -} - -export async function retry( - fn: () => Promise, - options: { maxAttempts?: number; delayMs?: number } = {}, -): Promise { - const { maxAttempts = 3, delayMs = 1000 } = options; - let error: unknown; - for (let attempt = 0; attempt < maxAttempts; attempt++) { - try { - return await fn(); - } catch (err) { - error = err; - await new Promise((resolve) => setTimeout(resolve, delayMs)); - } - } - throw error; -} - -export async function collectIterator( - iterator: AsyncIterableIterator, - options: { maxItems?: number; timeoutMs?: number } = {}, -): Promise { - const { maxItems = 1000, timeoutMs = 2000 } = options; - const result: T[] = []; - const timeout = setTimeout(() => iterator.return?.(), timeoutMs); - try { - for await (const item of iterator) { - result.push(item); - if (result.length >= maxItems) { - iterator.return?.(); - break; - } - } - } finally { - clearTimeout(timeout); - } - return result; -} diff --git a/tasks/mod.ts b/tasks/mod.ts new file mode 100644 index 0000000..1f26487 --- /dev/null +++ b/tasks/mod.ts @@ -0,0 +1,13 @@ +import { pingWorkers } from "./pingWorkers.ts"; +import { processJobs } from "./processJobs.ts"; +import { returnHangedJobs } from "./returnHangedJobs.ts"; +import { updateJobStatusMsgs } from "./updateJobStatusMsgs.ts"; + +export async function runAllTasks() { + await Promise.all([ + processJobs(), + updateJobStatusMsgs(), + returnHangedJobs(), + pingWorkers(), + ]); +} diff --git a/tasks/pingWorkers.ts b/tasks/pingWorkers.ts new file mode 100644 index 0000000..cee5e6f --- /dev/null +++ b/tasks/pingWorkers.ts @@ -0,0 +1,32 @@ +import { Async, Log } from "../deps.ts"; +import { getGlobalSession } from "../bot/session.ts"; +import { sdGetConfig } from "../sd.ts"; + +const logger = () => Log.getLogger(); + +export const runningWorkers = new Set(); + +/** + * Periodically ping the workers to see if they are alive. + */ +export async function pingWorkers(): Promise { + while (true) { + try { + const config = await getGlobalSession(); + for (const worker of config.workers) { + const status = await sdGetConfig(worker.api).catch(() => null); + const wasRunning = runningWorkers.has(worker.name); + if (status) { + runningWorkers.add(worker.name); + if (!wasRunning) logger().info(`Worker ${worker.name} is online`); + } else { + runningWorkers.delete(worker.name); + if (wasRunning) logger().warning(`Worker ${worker.name} went offline`); + } + } + await Async.delay(60 * 1000); + } catch (err) { + logger().warning(`Pinging workers failed: ${err}`); + } + } +} diff --git a/tasks/processJobs.ts b/tasks/processJobs.ts new file mode 100644 index 0000000..46f33c3 --- /dev/null +++ b/tasks/processJobs.ts @@ -0,0 +1,221 @@ +import { Base64, FileType, FmtDuration, Grammy, GrammyParseMode, IKV, Log } from "../deps.ts"; +import { bot } from "../bot/mod.ts"; +import { getGlobalSession, GlobalData, WorkerData } from "../bot/session.ts"; +import { fmt, formatUserChat } from "../utils.ts"; +import { SdApiError, sdTxt2Img } from "../sd.ts"; +import { JobSchema, jobStore } from "../db/jobStore.ts"; +import { runningWorkers } from "./pingWorkers.ts"; + +const logger = () => Log.getLogger(); + +/** + * Sends waiting jobs to workers. + */ +export async function processJobs(): Promise { + const busyWorkers = new Set(); + while (true) { + await new Promise((resolve) => setTimeout(resolve, 1000)); + + try { + // get first waiting job + const job = await jobStore.getBy("status.type", "waiting").then((jobs) => jobs[0]); + if (!job) continue; + + // find a worker to handle the job + const config = await getGlobalSession(); + const worker = config.workers?.find((worker) => + runningWorkers.has(worker.name) && + !busyWorkers.has(worker.name) + ); + if (!worker) continue; + + // process the job + await job.update({ + status: { type: "processing", progress: 0, worker: worker.name, updatedDate: new Date() }, + }); + + busyWorkers.add(worker.name); + processJob(job, worker, config) + .catch(async (err) => { + logger().error( + `Job failed for ${formatUserChat(job.value.request)} via ${worker.name}: ${err}`, + ); + if (err instanceof Grammy.GrammyError || err instanceof SdApiError) { + await bot.api.sendMessage( + job.value.request.chat.id, + `Failed to generate your prompt: ${err.message}`, + { reply_to_message_id: job.value.request.message_id }, + ).catch(() => undefined); + await job.update({ status: { type: "waiting" } }).catch(() => undefined); + } + if ( + err instanceof SdApiError && + (err.statusCode === 0 /* Network error */ || err.statusCode === 404) + ) { + runningWorkers.delete(worker.name); + logger().warning( + `Worker ${worker.name} was marked as offline because of network error`, + ); + } + await job.delete().catch(() => undefined); + await jobStore.create(job.value); + }) + .finally(() => busyWorkers.delete(worker.name)); + } catch (err) { + logger().warning(`Processing jobs failed: ${err}`); + } + } +} + +async function processJob(job: IKV.Model, worker: WorkerData, config: GlobalData) { + logger().debug( + `Job started for ${formatUserChat(job.value.request)} using ${worker.name}`, + ); + const startDate = new Date(); + + // if there is already a status message delete it + if (job.value.reply) { + await bot.api.deleteMessage(job.value.reply.chat.id, job.value.reply.message_id) + .catch(() => undefined); + } + + // send a new status message + const newStatusMessage = await bot.api.sendMessage( + job.value.request.chat.id, + `Generating your prompt now... 0% using ${worker.name}`, + { reply_to_message_id: job.value.request.message_id }, + ).catch((err) => { + // don't error if the request message was deleted + if (err instanceof Grammy.GrammyError && err.message.match(/repl(y|ied)/)) return null; + else throw err; + }); + // if the request message was deleted, cancel the job + if (!newStatusMessage) { + await job.delete(); + logger().info( + `Job cancelled for ${formatUserChat(job.value.request)}`, + ); + return; + } + await job.update({ reply: newStatusMessage }); + + // reduce size if worker can't handle the resolution + const size = limitSize({ ...config.defaultParams, ...job.value.params }, worker.maxResolution); + + // process the job + const response = await sdTxt2Img( + worker.api, + { ...config.defaultParams, ...job.value.params, ...size }, + async (progress) => { + // important: don't let any errors escape this callback + if (job.value.reply) { + await bot.api.editMessageText( + job.value.reply.chat.id, + job.value.reply.message_id, + `Generating your prompt now... ${ + (progress.progress * 100).toFixed(0) + }% using ${worker.name}`, + ).catch(() => undefined); + } + await job.update({ + status: { + type: "processing", + progress: progress.progress, + worker: worker.name, + updatedDate: new Date(), + }, + }).catch(() => undefined); + }, + ); + + // upload the result + if (job.value.reply) { + await bot.api.editMessageText( + job.value.reply.chat.id, + job.value.reply.message_id, + `Uploading your images...`, + ).catch(() => undefined); + } + + // render the caption + // const detailedReply = Object.keys(job.value.params).filter((key) => key !== "prompt").length > 0; + const detailedReply = true; + const jobDurationMs = Math.trunc((Date.now() - startDate.getTime()) / 1000) * 1000; + const { bold } = GrammyParseMode; + const caption = fmt([ + `${response.info.prompt}\n`, + ...detailedReply + ? [ + response.info.negative_prompt + ? fmt`${bold("Negative prompt:")} ${response.info.negative_prompt}\n` + : "", + fmt`${bold("Steps:")} ${response.info.steps}, `, + fmt`${bold("Sampler:")} ${response.info.sampler_name}, `, + fmt`${bold("CFG scale:")} ${response.info.cfg_scale}, `, + fmt`${bold("Seed:")} ${response.info.seed}, `, + fmt`${bold("Size")}: ${response.info.width}x${response.info.height}, `, + fmt`${bold("Worker")}: ${worker.name}, `, + fmt`${bold("Time taken")}: ${FmtDuration.format(jobDurationMs, { ignoreZero: true })}`, + ] + : [], + ]); + + // parse files from reply JSON + const inputFiles = await Promise.all( + response.images.map(async (imageBase64, idx) => { + const imageBuffer = Base64.decode(imageBase64); + const imageType = await FileType.fileTypeFromBuffer(imageBuffer); + if (!imageType) throw new Error("Unknown file type returned from worker"); + return Grammy.InputMediaBuilder.photo( + new Grammy.InputFile(imageBuffer, `image${idx}.${imageType.ext}`), + // if it can fit, add caption for first photo + idx === 0 && caption.text.length <= 1024 + ? { caption: caption.text, caption_entities: caption.entities } + : undefined, + ); + }), + ); + + // send the result to telegram + const resultMessage = await bot.api.sendMediaGroup(job.value.request.chat.id, inputFiles, { + reply_to_message_id: job.value.request.message_id, + }); + // send caption in separate message if it couldn't fit + if (caption.text.length > 1024 && caption.text.length <= 4096) { + await bot.api.sendMessage(job.value.request.chat.id, caption.text, { + reply_to_message_id: resultMessage[0].message_id, + entities: caption.entities, + }); + } + + // delete the status message + if (job.value.reply) { + await bot.api.deleteMessage(job.value.reply.chat.id, job.value.reply.message_id) + .catch(() => undefined) + .then(() => job.update({ reply: undefined })) + .catch(() => undefined); + } + + // update job to status done + await job.update({ + status: { type: "done", info: response.info, startDate, endDate: new Date() }, + }); + logger().debug( + `Job finished for ${formatUserChat(job.value.request)} using ${worker.name}`, + ); +} + +function limitSize( + { width, height }: { width?: number; height?: number }, + maxResolution: number, +): { width?: number; height?: number } { + if (!width || !height) return {}; + const ratio = width / height; + if (width * height > maxResolution) { + return { + width: Math.trunc(Math.sqrt(maxResolution * ratio)), + height: Math.trunc(Math.sqrt(maxResolution / ratio)), + }; + } + return { width, height }; +} diff --git a/tasks/returnHangedJobs.ts b/tasks/returnHangedJobs.ts new file mode 100644 index 0000000..162198d --- /dev/null +++ b/tasks/returnHangedJobs.ts @@ -0,0 +1,36 @@ +import { FmtDuration, Log } from "../deps.ts"; +import { formatUserChat } from "../utils.ts"; +import { jobStore } from "../db/jobStore.ts"; + +const logger = () => Log.getLogger(); + +/** + * Returns hanged jobs to the queue. + */ +export async function returnHangedJobs(): Promise { + while (true) { + try { + await new Promise((resolve) => setTimeout(resolve, 5000)); + const jobs = await jobStore.getBy("status.type", "processing"); + for (const job of jobs) { + if (job.value.status.type !== "processing") continue; + // if job wasn't updated for 1 minute, return it to the queue + const timeSinceLastUpdateMs = Date.now() - job.value.status.updatedDate.getTime(); + if (timeSinceLastUpdateMs > 60 * 1000) { + await job.update({ status: { type: "waiting" } }); + logger().warning( + `Job for ${ + formatUserChat(job.value.request) + } was returned to the queue because it hanged for ${ + FmtDuration.format(Math.trunc(timeSinceLastUpdateMs / 1000) * 1000, { + ignoreZero: true, + }) + }`, + ); + } + } + } catch (err) { + logger().warning(`Returning hanged jobs failed: ${err}`); + } + } +} diff --git a/tasks/updateJobStatusMsgs.ts b/tasks/updateJobStatusMsgs.ts new file mode 100644 index 0000000..2fa6149 --- /dev/null +++ b/tasks/updateJobStatusMsgs.ts @@ -0,0 +1,28 @@ +import { Log } from "../deps.ts"; +import { bot } from "../bot/mod.ts"; +import { formatOrdinal } from "../utils.ts"; +import { jobStore } from "../db/jobStore.ts"; + +const logger = () => Log.getLogger(); + +/** + * Updates status messages for jobs in the queue. + */ +export async function updateJobStatusMsgs(): Promise { + while (true) { + try { + await new Promise((resolve) => setTimeout(resolve, 5000)); + const jobs = await jobStore.getBy("status.type", "waiting"); + for (const [index, job] of jobs.entries()) { + if (!job.value.reply) continue; + await bot.api.editMessageText( + job.value.reply.chat.id, + job.value.reply.message_id, + `You are ${formatOrdinal(index + 1)} in queue.`, + ).catch(() => undefined); + } + } catch (err) { + logger().warning(`Updating job status messages failed: ${err}`); + } + } +} diff --git a/types/png-chunk-text.d.ts b/types/png-chunk-text.d.ts new file mode 100644 index 0000000..cb2adfb --- /dev/null +++ b/types/png-chunk-text.d.ts @@ -0,0 +1,2 @@ +export function decode(chunk: Uint8Array): { keyword: string; text: string }; +export function encode(keyword: string, text: string): Uint8Array; diff --git a/types/png-chunks-encode.d.ts b/types/png-chunks-encode.d.ts new file mode 100644 index 0000000..4267e6d --- /dev/null +++ b/types/png-chunks-encode.d.ts @@ -0,0 +1 @@ +export default function encode(chunks: Array<{ name: string; data: Uint8Array }>): Uint8Array; diff --git a/types/png-chunks-extract.d.ts b/types/png-chunks-extract.d.ts new file mode 100644 index 0000000..daee7c8 --- /dev/null +++ b/types/png-chunks-extract.d.ts @@ -0,0 +1 @@ +export default function extract(data: Uint8Array): Array<{ name: string; data: Uint8Array }>; diff --git a/utils.ts b/utils.ts new file mode 100644 index 0000000..736b793 --- /dev/null +++ b/utils.ts @@ -0,0 +1,111 @@ +import { GrammyParseMode, GrammyTypes } from "./deps.ts"; + +export function formatOrdinal(n: number) { + if (n % 100 === 11 || n % 100 === 12 || n % 100 === 13) return `${n}th`; + if (n % 10 === 1) return `${n}st`; + if (n % 10 === 2) return `${n}nd`; + if (n % 10 === 3) return `${n}rd`; + return `${n}th`; +} + +export const fmt = ( + rawStringParts: TemplateStringsArray | GrammyParseMode.Stringable[], + ...stringLikes: GrammyParseMode.Stringable[] +): GrammyParseMode.FormattedString => { + let text = ""; + const entities: GrammyTypes.MessageEntity[] = []; + + const length = Math.max(rawStringParts.length, stringLikes.length); + for (let i = 0; i < length; i++) { + for (const stringLike of [rawStringParts[i], stringLikes[i]]) { + if (stringLike instanceof GrammyParseMode.FormattedString) { + entities.push( + ...stringLike.entities.map((e) => ({ + ...e, + offset: e.offset + text.length, + })), + ); + } + if (stringLike != null) text += stringLike.toString(); + } + } + return new GrammyParseMode.FormattedString(text, entities); +}; + +export function formatUserChat(ctx: { from?: GrammyTypes.User; chat?: GrammyTypes.Chat }) { + const msg: string[] = []; + if (ctx.from) { + msg.push(ctx.from.first_name); + if (ctx.from.last_name) msg.push(ctx.from.last_name); + if (ctx.from.username) msg.push(`(@${ctx.from.username})`); + if (ctx.from.language_code) msg.push(`(${ctx.from.language_code.toUpperCase()})`); + } + if (ctx.chat) { + if ( + ctx.chat.type === "group" || + ctx.chat.type === "supergroup" || + ctx.chat.type === "channel" + ) { + msg.push("in"); + msg.push(ctx.chat.title); + if ( + (ctx.chat.type === "supergroup" || ctx.chat.type === "channel") && + ctx.chat.username + ) { + msg.push(`(@${ctx.chat.username})`); + } + } + } + return msg.join(" "); +} + +/** Language to biggest country emoji map */ +const languageToFlagMap: Record = { + "en": "🇺🇸", // English - United States + "zh": "🇨🇳", // Chinese - China + "es": "🇪🇸", // Spanish - Spain + "hi": "🇮🇳", // Hindi - India + "ar": "🇪🇬", // Arabic - Egypt + "pt": "🇧🇷", // Portuguese - Brazil + "bn": "🇧🇩", // Bengali - Bangladesh + "ru": "🇷🇺", // Russian - Russia + "ja": "🇯🇵", // Japanese - Japan + "pa": "🇮🇳", // Punjabi - India + "de": "🇩🇪", // German - Germany + "ko": "🇰🇷", // Korean - South Korea + "fr": "🇫🇷", // French - France + "tr": "🇹🇷", // Turkish - Turkey + "ur": "🇵🇰", // Urdu - Pakistan + "it": "🇮🇹", // Italian - Italy + "th": "🇹🇭", // Thai - Thailand + "vi": "🇻🇳", // Vietnamese - Vietnam + "pl": "🇵🇱", // Polish - Poland + "uk": "🇺🇦", // Ukrainian - Ukraine + "uz": "🇺🇿", // Uzbek - Uzbekistan + "su": "🇮🇩", // Sundanese - Indonesia + "sw": "🇹🇿", // Swahili - Tanzania + "nl": "🇳🇱", // Dutch - Netherlands + "fi": "🇫🇮", // Finnish - Finland + "el": "🇬🇷", // Greek - Greece + "da": "🇩🇰", // Danish - Denmark + "cs": "🇨🇿", // Czech - Czech Republic + "sk": "🇸🇰", // Slovak - Slovakia + "bg": "🇧🇬", // Bulgarian - Bulgaria + "sv": "🇸🇪", // Swedish - Sweden + "be": "🇧🇾", // Belarusian - Belarus + "hu": "🇭🇺", // Hungarian - Hungary + "lt": "🇱🇹", // Lithuanian - Lithuania + "lv": "🇱🇻", // Latvian - Latvia + "et": "🇪🇪", // Estonian - Estonia + "sl": "🇸🇮", // Slovenian - Slovenia + "hr": "🇭🇷", // Croatian - Croatia + "zu": "🇿🇦", // Zulu - South Africa + "id": "🇮🇩", // Indonesian - Indonesia + "is": "🇮🇸", // Icelandic - Iceland + "lb": "🇱🇺", // Luxembourgish - Luxembourg +}; + +export function getFlagEmoji(countryCode?: string): string | undefined { + if (!countryCode) return; + return languageToFlagMap[countryCode]; +}