From 2f5058ac45f005713a7b5ab532785cd07dbaa6e5 Mon Sep 17 00:00:00 2001 From: pinks Date: Mon, 4 Sep 2023 18:55:48 +0200 Subject: [PATCH] chore: split into modules --- README.md | 5 +- bot.ts | 113 +++++++++++++++ config.ts | 15 ++ deno.jsonc | 3 + deps.ts | 4 + fmtArray.ts | 0 intl.ts | 33 +++++ main.ts | 392 +--------------------------------------------------- mimeType.ts | 15 ++ queue.ts | 123 +++++++++++++++++ sd.ts | 88 ++++++++++++ 11 files changed, 403 insertions(+), 388 deletions(-) create mode 100644 bot.ts create mode 100644 config.ts create mode 100644 deps.ts create mode 100644 fmtArray.ts create mode 100644 intl.ts create mode 100644 mimeType.ts create mode 100644 queue.ts create mode 100644 sd.ts diff --git a/README.md b/README.md index c60d875..5042560 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ Telegram bot for generating images from text. ## Requirements - [Deno](https://deno.land/) -- [Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui/) +- [Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui/) ## Options @@ -13,7 +13,8 @@ You can put these in `.env` file or pass them as environment variables. - `TG_BOT_TOKEN` - Telegram bot token (get yours from [@BotFather](https://t.me/BotFather)) - `SD_API_URL` - URL to Stable Diffusion API (e.g. `http://127.0.0.1:7860/`) -- `ADMIN_USERNAMES` - Comma separated list of usernames of users that can use admin commands (optional) +- `ADMIN_USERNAMES` - Comma separated list of usernames of users that can use admin commands + (optional) ## Running diff --git a/bot.ts b/bot.ts new file mode 100644 index 0000000..544608d --- /dev/null +++ b/bot.ts @@ -0,0 +1,113 @@ +import { + autoQuote, + autoRetry, + bold, + Bot, + Context, + fmt, + hydrateReply, + ParseModeFlavor, +} from "./deps.ts"; +import { fmtArray, formatOrdinal } from "./intl.ts"; +import { config } from "./config.ts"; +import { queue } from "./queue.ts"; + +export const bot = new Bot>(Deno.env.get("TG_BOT_TOKEN") ?? ""); +bot.use(autoQuote); +bot.use(hydrateReply); +bot.api.config.use(autoRetry({ maxRetryAttempts: 5, maxDelaySeconds: 60 })); + +bot.api.setMyShortDescription("I can generate furry images from text"); +bot.api.setMyDescription( + "I can generate furry images from text. Send /txt2img to generate an image.", +); +bot.api.setMyCommands([ + { command: "txt2img", description: "Generate an image" }, + { command: "queue", description: "Show the current queue" }, +]); + +bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to generate an image")); + +bot.command("txt2img", async (ctx) => { + if (!ctx.from?.id) { + return ctx.reply("I don't know who you are"); + } + if (config.pausedReason != null) { + return ctx.reply(`I'm paused: ${config.pausedReason}`); + } + if (queue.length >= config.maxJobs) { + return ctx.reply( + `The queue is full. Try again later. (Max queue size: ${config.maxJobs})`, + ); + } + const jobCount = queue.filter((job) => job.userId === ctx.from.id).length; + if (jobCount >= config.maxUserJobs) { + return ctx.reply( + `You already have ${config.maxUserJobs} jobs in queue. Try again later.`, + ); + } + if (!ctx.match) { + return ctx.reply("Please describe what you want to see after the command"); + } + const place = queue.length + 1; + const queueMessage = await ctx.reply(`You are ${formatOrdinal(place)} in queue.`); + const userName = [ctx.from.first_name, ctx.from.last_name].filter(Boolean).join(" "); + const chatName = ctx.chat.type === "supergroup" || ctx.chat.type === "group" + ? ctx.chat.title + : "private chat"; + queue.push({ + params: { prompt: ctx.match }, + userId: ctx.from.id, + userName, + chatId: ctx.chat.id, + chatName, + requestMessageId: ctx.message.message_id, + statusMessageId: queueMessage.message_id, + }); + console.log(`Enqueued job for ${userName} in chat ${chatName}`); +}); + +bot.command("queue", async (ctx) => { + if (queue.length === 0) return ctx.reply("Queue is empty"); + return await ctx.replyFmt( + fmt`Current queue:\n\n${ + fmtArray( + queue.map((job, index) => + fmt`${bold(index + 1)}. ${bold(job.userName)} in ${bold(job.chatName)}` + ), + "\n", + ) + }`, + ); +}); + +bot.command("pause", async (ctx) => { + if (!ctx.from?.username) return; + if (!config.adminUsernames.includes(ctx.from.username)) return; + if (config.pausedReason != null) { + return await ctx.reply(`Already paused: ${config.pausedReason}`); + } + config.pausedReason = ctx.match ?? "No reason given"; + return await ctx.reply("Paused"); +}); + +bot.command("resume", async (ctx) => { + if (!ctx.from?.username) return; + if (!config.adminUsernames.includes(ctx.from.username)) return; + if (config.pausedReason == null) return await ctx.reply("Already running"); + config.pausedReason = null; + return await ctx.reply("Resumed"); +}); + +bot.catch((err) => { + let msg = "Error processing update"; + const { from, chat } = err.ctx; + if (from?.first_name) msg += ` from ${from.first_name}`; + if (from?.last_name) msg += ` ${from.last_name}`; + if (from?.username) msg += ` (@${from.username})`; + if (chat?.type === "supergroup" || chat?.type === "group") { + msg += ` in ${chat.title}`; + if (chat.type === "supergroup" && chat.username) msg += ` (@${chat.username})`; + } + console.error(msg, err.error); +}); diff --git a/config.ts b/config.ts new file mode 100644 index 0000000..4d41272 --- /dev/null +++ b/config.ts @@ -0,0 +1,15 @@ +export const config: Config = { + adminUsernames: (Deno.env.get("ADMIN_USERNAMES") ?? "").split(",").filter(Boolean), + pausedReason: null, + sdApiUrl: Deno.env.get("SD_API_URL") ?? "http://127.0.0.1:7860/", + maxUserJobs: 3, + maxJobs: 20, +}; + +interface Config { + adminUsernames: string[]; + pausedReason: string | null; + sdApiUrl: string; + maxUserJobs: number; + maxJobs: number; +} diff --git a/deno.jsonc b/deno.jsonc index 6330019..9be0600 100644 --- a/deno.jsonc +++ b/deno.jsonc @@ -2,5 +2,8 @@ "tasks": { "dev": "deno run --watch --allow-env --allow-read --allow-net main.ts", "start": "deno run --allow-env --allow-read --allow-net main.ts" + }, + "fmt": { + "lineWidth": 100 } } diff --git a/deps.ts b/deps.ts new file mode 100644 index 0000000..faa9429 --- /dev/null +++ b/deps.ts @@ -0,0 +1,4 @@ +export * from "https://deno.land/x/grammy@v1.18.1/mod.ts"; +export * from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts"; +export * from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts"; +export { autoRetry } from "https://esm.sh/@grammyjs/auto-retry@1.1.1"; diff --git a/fmtArray.ts b/fmtArray.ts new file mode 100644 index 0000000..e69de29 diff --git a/intl.ts b/intl.ts new file mode 100644 index 0000000..c955dfa --- /dev/null +++ b/intl.ts @@ -0,0 +1,33 @@ +import { FormattedString } from "./deps.ts"; + +export function formatOrdinal(n: number) { + if (n % 100 === 11 || n % 100 === 12 || n % 100 === 13) return `${n}th`; + if (n % 10 === 1) return `${n}st`; + if (n % 10 === 2) return `${n}nd`; + if (n % 10 === 3) return `${n}rd`; + return `${n}th`; +} + +/** + * Like `fmt` from `grammy_parse_mode` but accepts an array instead of template string. + * @see https://deno.land/x/grammy_parse_mode@1.7.1/format.ts?source=#L182 + */ +export function fmtArray( + stringLikes: FormattedString[], + separator = "", +): FormattedString { + let text = ""; + const entities: ConstructorParameters[1] = []; + for (let i = 0; i < stringLikes.length; i++) { + const stringLike = stringLikes[i]; + entities.push( + ...stringLike.entities.map((e) => ({ + ...e, + offset: e.offset + text.length, + })), + ); + text += stringLike.toString(); + if (i < stringLikes.length - 1) text += separator; + } + return new FormattedString(text, entities); +} diff --git a/main.ts b/main.ts index bc53960..623e567 100644 --- a/main.ts +++ b/main.ts @@ -1,388 +1,8 @@ -import { - Bot, - Context, - InputFile, - InputMediaBuilder, -} from "https://deno.land/x/grammy@v1.18.1/mod.ts"; -import { autoQuote } from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts"; -import { - fmt, - hydrateReply, - ParseModeFlavor, -} from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts"; -import "https://deno.land/x/dotenv@v3.2.2/load.ts"; -import { - FormattedString, - bold, -} from "https://deno.land/x/grammy_parse_mode@1.7.1/format.ts"; -import { autoRetry } from "https://esm.sh/@grammyjs/auto-retry"; -import { MessageEntity } from "https://deno.land/x/grammy@v1.18.1/types.ts"; +import "https://deno.land/std@0.201.0/dotenv/load.ts"; +import { bot } from "./bot.ts"; +import { processQueue } from "./queue.ts"; -const maxUserJobs = 3; -const maxJobs = 10; - -let pausedReason: string | null = null; - -const sdApiUrl = Deno.env.get("SD_API_URL"); -if (!sdApiUrl) throw new Error("SD_API_URL not set"); -console.log("Using SD API URL:", sdApiUrl); -const sdConfigUrl = new URL("/config", sdApiUrl); -const sdConfigRequest = await fetch(sdConfigUrl); -if (!sdConfigRequest.ok) - throw new Error( - `Failed to fetch SD config from ${sdConfigUrl}: ${sdConfigRequest.statusText}` - ); -const sdConfig = await sdConfigRequest.json(); -console.log("Using SD WebUI version:", String(sdConfig.version).trim()); - -const adminUsernames = (Deno.env.get("ADMIN_USERNAMES") ?? "") - .split(",") - .filter(Boolean); - -const tgBotToken = Deno.env.get("TG_BOT_TOKEN"); -if (!tgBotToken) throw new Error("TG_BOT_TOKEN not set"); - -const bot = new Bot>(tgBotToken); -bot.api.config.use(autoRetry({ maxRetryAttempts: 5, maxDelaySeconds: 30 })); - -bot.api.setMyShortDescription("I can generate furry images from text"); -bot.api.setMyDescription( - "I can generate furry images from text. Send /txt2img to generate an image." -); -bot.api.setMyCommands([ - { command: "txt2img", description: "Generate an image" }, - { command: "queue", description: "Show the current queue" }, +await Promise.all([ + bot.start(), + processQueue(), ]); - -bot.use(autoQuote); -bot.use(hydrateReply); - -bot.command("start", (ctx) => - ctx.reply("Hello! Use the /txt2img command to generate an image") -); - -bot.command("txt2img", async (ctx) => { - if (!ctx.from?.id) { - return ctx.reply("I don't know who you are"); - } - if (pausedReason != null) { - return ctx.reply(`I'm paused: ${pausedReason}`); - } - if (queue.length >= maxJobs) { - return ctx.reply( - `The queue is full. Try again later. (Max queue size: ${maxJobs})` - ); - } - const jobCount = queue.filter((job) => job.userId === ctx.from.id).length; - if (jobCount >= maxUserJobs) { - return ctx.reply( - `You already have ${maxUserJobs} jobs in queue. Try again later.` - ); - } - if (!ctx.match) { - return ctx.reply("Please describe what you want to see"); - } - const place = queue.length + 1; - const queueMessage = await ctx.reply( - `You are ${formatOrdinal(place)} in queue.` - ); - const userName = [ctx.from.first_name, ctx.from.last_name] - .filter(Boolean) - .join(" "); - const chatName = - ctx.chat.type === "supergroup" || ctx.chat.type === "group" - ? ctx.chat.title - : "private chat"; - queue.push({ - params: { prompt: ctx.match }, - userId: ctx.from.id, - userName, - chatId: ctx.chat.id, - chatName, - requestMessageId: ctx.message.message_id, - statusMessageId: queueMessage.message_id, - }); - console.log(`Enqueued job for ${userName} in chat ${chatName}`); -}); - -bot.command("queue", async (ctx) => { - if (queue.length === 0) return ctx.reply("Queue is empty"); - return await ctx.replyFmt( - fmt`Current queue:\n\n${fmtArray( - queue.map( - (job, index) => - fmt`${bold(index + 1)}. ${bold(job.userName)} in ${bold( - job.chatName - )}` - ), - "\n" - )}` - ); -}); - -bot.command("pause", async (ctx) => { - if (!ctx.from?.username) return; - if (!adminUsernames.includes(ctx.from.username)) return; - if (pausedReason != null) - return await ctx.reply(`Already paused: ${pausedReason}`); - pausedReason = ctx.match ?? "No reason given"; - return await ctx.reply("Paused"); -}); - -bot.command("resume", async (ctx) => { - if (!ctx.from?.username) return; - if (!adminUsernames.includes(ctx.from.username)) return; - if (pausedReason == null) return await ctx.reply("Already running"); - pausedReason = null; - return await ctx.reply("Resumed"); -}); - -bot.catch((err) => { - let msg = "Error processing update"; - const { from, chat } = err.ctx; - if (from?.first_name) msg += ` from ${from.first_name}`; - if (from?.last_name) msg += ` ${from.last_name}`; - if (from?.username) msg += ` (@${from.username})`; - if (chat?.type === "supergroup" || chat?.type === "group") { - msg += ` in ${chat.title}`; - if (chat.type === "supergroup" && chat.username) - msg += ` (@${chat.username})`; - } - console.error(msg, err.error); -}); - -const queue: Job[] = []; - -interface Job { - params: Partial; - userId: number; - userName: string; - chatId: number; - chatName: string; - requestMessageId: number; - statusMessageId: number; -} - -async function processQueue() { - while (true) { - const job = queue.shift(); - if (!job) { - await new Promise((resolve) => setTimeout(resolve, 1000)); - continue; - } - for (const [index, job] of queue.entries()) { - const place = index + 1; - await bot.api - .editMessageText( - job.chatId, - job.statusMessageId, - `You are ${formatOrdinal(place)} in queue.` - ) - .catch(() => {}); - } - try { - await bot.api.deleteMessage(job.chatId, job.statusMessageId); - const progressMessage = await bot.api.sendMessage( - job.chatId, - "Generating your prompt now...", - { reply_to_message_id: job.requestMessageId } - ); - const onProgress = (progress: SdProgressResponse) => { - bot.api - .editMessageText( - job.chatId, - progressMessage.message_id, - `Generating your prompt now... ${Math.round( - progress.progress * 100 - )}%` - ) - .catch(() => {}); - }; - const response = await txt2img( - { ...defaultParams, ...job.params }, - onProgress - ); - console.log( - `Generated image for ${job.userName} in ${job.chatName}: ${job.params.prompt}` - ); - bot.api.editMessageText( - job.chatId, - progressMessage.message_id, - `Uploading your images...` - ); - const inputFiles = await Promise.all( - response.images.slice(1).map(async (imageBase64) => { - const imageBlob = await fetch( - `data:${mimeTypeFromBase64(imageBase64)};base64,${imageBase64}` - ).then((resp) => resp.blob()); - return InputMediaBuilder.photo(new InputFile(imageBlob)); - }) - ); - await bot.api.sendMediaGroup(job.chatId, inputFiles, { - reply_to_message_id: job.requestMessageId, - }); - await bot.api.deleteMessage(job.chatId, progressMessage.message_id); - console.log(`${queue.length} jobs remaining`); - } catch (err) { - console.error( - `Failed to generate image for ${job.userName} in ${job.chatName}: ${job.params.prompt} - ${err}` - ); - await bot.api - .sendMessage(job.chatId, err.toString(), { - reply_to_message_id: job.requestMessageId, - }) - .catch(() => {}); - } - } -} - -function formatOrdinal(n: number) { - if (n % 100 === 11 || n % 100 === 12 || n % 100 === 13) return `${n}th`; - if (n % 10 === 1) return `${n}st`; - if (n % 10 === 2) return `${n}nd`; - if (n % 10 === 3) return `${n}rd`; - return `${n}th`; -} - -const defaultParams: Partial = { - batch_size: 3, - n_iter: 1, - width: 128 * 5, - height: 128 * 7, - steps: 40, - cfg_scale: 9, - send_images: true, - save_images: true, - negative_prompt: - "id210 boring_e621_fluffyrock_v4 boring_e621_v4 easynegative ng_deepnegative_v1_75t", -}; - -function mimeTypeFromBase64(base64: string) { - if (base64.startsWith("/9j/")) { - return "image/jpeg"; - } - if (base64.startsWith("iVBORw0KGgo")) { - return "image/png"; - } - if (base64.startsWith("R0lGODlh")) { - return "image/gif"; - } - if (base64.startsWith("UklGRg")) { - return "image/webp"; - } - throw new Error("Unknown image type"); -} - -async function txt2img( - params: Partial, - onProgress?: (progress: SdProgressResponse) => void, - signal?: AbortSignal -): Promise { - let response: Response | undefined; - let error: unknown; - fetch(new URL("sdapi/v1/txt2img", sdApiUrl), { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify(params), - }).then( - (resp) => (response = resp), - (err) => (error = err) - ); - try { - while (true) { - await new Promise((resolve) => setTimeout(resolve, 3000)); - const progressRequest = await fetch( - new URL("sdapi/v1/progress", sdApiUrl) - ); - if (progressRequest.ok) { - const progress = (await progressRequest.json()) as SdProgressResponse; - onProgress?.(progress); - } - if (response != null) { - if (response.ok) { - const result = (await response.json()) as SdResponse; - return result; - } else { - throw new Error( - `Request failed: ${response.status} ${response.statusText}` - ); - } - } - if (error != null) { - throw error; - } - signal?.throwIfAborted(); - } - } finally { - if (!response && !error) - await fetch(new URL("sdapi/v1/interrupt", sdApiUrl), { method: "POST" }); - } -} - -interface SdRequest { - denoising_strength: number; - prompt: string; - seed: number; - sampler_name: unknown; - batch_size: number; - n_iter: number; - steps: number; - cfg_scale: number; - width: number; - height: number; - negative_prompt: string; - send_images: boolean; - save_images: boolean; -} - -interface SdResponse { - images: string[]; - parameters: SdRequest; - /** Contains serialized JSON */ - info: string; -} - -interface SdProgressResponse { - progress: number; - eta_relative: number; - state: SdProgressState; - /** base64 encoded preview */ - current_image: string | null; - textinfo: string | null; -} - -interface SdProgressState { - skipped: boolean; - interrupted: boolean; - job: string; - job_count: number; - job_timestamp: string; - job_no: number; - sampling_step: number; - sampling_steps: number; -} - -/** Like {@link fmt} but accepts an array instead of template string. */ -function fmtArray( - stringLikes: FormattedString[], - separator = "" -): FormattedString { - let text = ""; - const entities: MessageEntity[] = []; - for (let i = 0; i < stringLikes.length; i++) { - const stringLike = stringLikes[i]; - - entities.push( - ...stringLike.entities.map((e) => ({ - ...e, - offset: e.offset + text.length, - })) - ); - - text += stringLike.toString(); - if (i < stringLikes.length - 1) text += separator; - } - return new FormattedString(text, entities); -} - -await Promise.all([bot.start(), processQueue()]); diff --git a/mimeType.ts b/mimeType.ts new file mode 100644 index 0000000..2d61931 --- /dev/null +++ b/mimeType.ts @@ -0,0 +1,15 @@ +export function mimeTypeFromBase64(base64: string) { + if (base64.startsWith("/9j/")) return "image/jpeg"; + if (base64.startsWith("iVBORw0KGgo")) return "image/png"; + if (base64.startsWith("R0lGODlh")) return "image/gif"; + if (base64.startsWith("UklGRg")) return "image/webp"; + throw new Error("Unknown image type"); +} + +export function extFromMimeType(mimeType: string) { + if (mimeType === "image/jpeg") return "jpg"; + if (mimeType === "image/png") return "png"; + if (mimeType === "image/gif") return "gif"; + if (mimeType === "image/webp") return "webp"; + throw new Error("Unknown image type"); +} diff --git a/queue.ts b/queue.ts new file mode 100644 index 0000000..a1e0db5 --- /dev/null +++ b/queue.ts @@ -0,0 +1,123 @@ +import { InputFile, InputMediaBuilder } from "./deps.ts"; +import { config } from "./config.ts"; +import { bot } from "./bot.ts"; +import { formatOrdinal } from "./intl.ts"; +import { SdProgressResponse, SdRequest, txt2img } from "./sd.ts"; +import { extFromMimeType, mimeTypeFromBase64 } from "./mimeType.ts"; + +export const queue: Job[] = []; + +interface Job { + params: Partial; + userId: number; + userName: string; + chatId: number; + chatName: string; + requestMessageId: number; + statusMessageId: number; +} + +export async function processQueue() { + while (true) { + const job = queue.shift(); + if (!job) { + await new Promise((resolve) => setTimeout(resolve, 1000)); + continue; + } + for (const [index, job] of queue.entries()) { + const place = index + 1; + await bot.api + .editMessageText( + job.chatId, + job.statusMessageId, + `You are ${formatOrdinal(place)} in queue.`, + ) + .catch(() => {}); + } + try { + await bot.api + .deleteMessage(job.chatId, job.statusMessageId) + .catch(() => {}); + const progressMessage = await bot.api.sendMessage( + job.chatId, + "Generating your prompt now...", + { reply_to_message_id: job.requestMessageId }, + ); + const onProgress = (progress: SdProgressResponse) => { + bot.api + .editMessageText( + job.chatId, + progressMessage.message_id, + `Generating your prompt now... ${ + Math.round( + progress.progress * 100, + ) + }%`, + ) + .catch(() => {}); + }; + const response = await txt2img( + config.sdApiUrl, + { ...defaultParams, ...job.params }, + onProgress, + ); + + console.log( + `Generated ${response.images.length} images (${ + response.images + .map((image) => (image.length / 1024).toFixed(0) + "kB") + .join(", ") + }) for ${job.userName} in ${job.chatName}: ${job.params.prompt?.replace(/\s+/g, " ")}`, + ); + bot.api.editMessageText( + job.chatId, + progressMessage.message_id, + `Uploading your images...`, + ); + const inputFiles = await Promise.all( + response.images.map(async (imageBase64, idx) => { + const mimeType = mimeTypeFromBase64(imageBase64); + const imageBlob = await fetch(`data:${mimeType};base64,${imageBase64}`).then((resp) => + resp.blob() + ); + console.log( + `Uploading image ${idx + 1} of ${response.images.length} (${ + (imageBlob.size / 1024).toFixed(0) + }kB)`, + ); + return InputMediaBuilder.photo( + new InputFile(imageBlob, `${idx}.${extFromMimeType(mimeType)}`), + ); + }), + ); + await bot.api.sendMediaGroup(job.chatId, inputFiles, { + reply_to_message_id: job.requestMessageId, + }); + await bot.api + .deleteMessage(job.chatId, progressMessage.message_id) + .catch(() => {}); + console.log(`${queue.length} jobs remaining`); + } catch (err) { + console.error( + `Failed to generate image for ${job.userName} in ${job.chatName}: ${job.params.prompt} - ${err}`, + ); + await bot.api + .sendMessage(job.chatId, err.toString(), { + reply_to_message_id: job.requestMessageId, + }) + .catch(() => bot.api.sendMessage(job.chatId, err.toString())) + .catch(() => {}); + } + } +} + +const defaultParams: Partial = { + batch_size: 1, + n_iter: 1, + width: 128 * 2, + height: 128 * 3, + steps: 20, + cfg_scale: 9, + send_images: true, + negative_prompt: "boring_e621_fluffyrock_v4 boring_e621_v4", +}; diff --git a/sd.ts b/sd.ts new file mode 100644 index 0000000..01cf0f5 --- /dev/null +++ b/sd.ts @@ -0,0 +1,88 @@ +export async function txt2img( + apiUrl: string, + params: Partial, + onProgress?: (progress: SdProgressResponse) => void, + signal?: AbortSignal, +): Promise { + let response: Response | undefined; + let error: unknown; + + fetch(new URL("sdapi/v1/txt2img", apiUrl), { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(params), + }).then( + (resp) => (response = resp), + (err) => (error = err), + ); + + try { + while (true) { + await new Promise((resolve) => setTimeout(resolve, 3000)); + const progressRequest = await fetch(new URL("sdapi/v1/progress", apiUrl)); + if (progressRequest.ok) { + const progress = (await progressRequest.json()) as SdProgressResponse; + onProgress?.(progress); + } + if (response != null) { + if (response.ok) { + const result = (await response.json()) as SdResponse; + return result; + } else { + throw new Error(`Request failed: ${response.status} ${response.statusText}`); + } + } + if (error != null) { + throw error; + } + signal?.throwIfAborted(); + } + } finally { + if (!response && !error) { + await fetch(new URL("sdapi/v1/interrupt", apiUrl), { method: "POST" }); + } + } +} + +export interface SdRequest { + denoising_strength: number; + prompt: string; + seed: number; + sampler_name: unknown; + batch_size: number; + n_iter: number; + steps: number; + cfg_scale: number; + width: number; + height: number; + negative_prompt: string; + send_images: boolean; + save_images: boolean; +} + +export interface SdResponse { + images: string[]; + parameters: SdRequest; + /** Contains serialized JSON */ + info: string; +} + +export interface SdProgressResponse { + progress: number; + eta_relative: number; + state: SdProgressState; + /** base64 encoded preview */ + current_image: string | null; + textinfo: string | null; +} + +export interface SdProgressState { + skipped: boolean; + interrupted: boolean; + job: string; + job_count: number; + job_timestamp: string; + job_no: number; + sampling_step: number; + sampling_steps: number; +}