diff --git a/bot.ts b/bot.ts index 544608d..cf7b33c 100644 --- a/bot.ts +++ b/bot.ts @@ -4,19 +4,81 @@ import { bold, Bot, Context, + DenoKVAdapter, fmt, hydrateReply, ParseModeFlavor, + session, + SessionFlavor, } from "./deps.ts"; import { fmtArray, formatOrdinal } from "./intl.ts"; -import { config } from "./config.ts"; import { queue } from "./queue.ts"; +import { SdRequest } from "./sd.ts"; -export const bot = new Bot>(Deno.env.get("TG_BOT_TOKEN") ?? ""); +type AppContext = ParseModeFlavor & SessionFlavor; + +interface SessionData { + global: { + adminUsernames: string[]; + pausedReason: string | null; + sdApiUrl: string; + maxUserJobs: number; + maxJobs: number; + defaultParams?: Partial; + }; + user: { + steps: number; + detail: number; + batchSize: number; + }; +} + +export const bot = new Bot(Deno.env.get("TG_BOT_TOKEN") ?? ""); bot.use(autoQuote); bot.use(hydrateReply); bot.api.config.use(autoRetry({ maxRetryAttempts: 5, maxDelaySeconds: 60 })); +const db = await Deno.openKv("./app.db"); + +const getDefaultGlobalSession = (): SessionData["global"] => ({ + adminUsernames: (Deno.env.get("ADMIN_USERNAMES") ?? "").split(",").filter(Boolean), + pausedReason: null, + sdApiUrl: Deno.env.get("SD_API_URL") ?? "http://127.0.0.1:7860/", + maxUserJobs: 3, + maxJobs: 20, + defaultParams: { + batch_size: 1, + n_iter: 1, + width: 128 * 2, + height: 128 * 3, + steps: 20, + cfg_scale: 9, + send_images: true, + negative_prompt: "boring_e621_fluffyrock_v4 boring_e621_v4", + }, +}); + +bot.use(session({ + type: "multi", + global: { + getSessionKey: () => "global", + initial: getDefaultGlobalSession, + storage: new DenoKVAdapter(db), + }, + user: { + initial: () => ({ + steps: 20, + detail: 8, + batchSize: 2, + }), + }, +})); + +export async function getGlobalSession(): Promise { + const entry = await db.get(["sessions", "global"]); + return entry.value ?? getDefaultGlobalSession(); +} + bot.api.setMyShortDescription("I can generate furry images from text"); bot.api.setMyDescription( "I can generate furry images from text. Send /txt2img to generate an image.", @@ -24,6 +86,7 @@ bot.api.setMyDescription( bot.api.setMyCommands([ { command: "txt2img", description: "Generate an image" }, { command: "queue", description: "Show the current queue" }, + { command: "sdparams", description: "Show the current SD parameters" }, ]); bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to generate an image")); @@ -32,8 +95,9 @@ bot.command("txt2img", async (ctx) => { if (!ctx.from?.id) { return ctx.reply("I don't know who you are"); } + const config = ctx.session.global; if (config.pausedReason != null) { - return ctx.reply(`I'm paused: ${config.pausedReason}`); + return ctx.reply(`I'm paused: ${config.pausedReason || "No reason given"}`); } if (queue.length >= config.maxJobs) { return ctx.reply( @@ -67,9 +131,9 @@ bot.command("txt2img", async (ctx) => { console.log(`Enqueued job for ${userName} in chat ${chatName}`); }); -bot.command("queue", async (ctx) => { +bot.command("queue", (ctx) => { if (queue.length === 0) return ctx.reply("Queue is empty"); - return await ctx.replyFmt( + return ctx.replyFmt( fmt`Current queue:\n\n${ fmtArray( queue.map((job, index) => @@ -81,22 +145,141 @@ bot.command("queue", async (ctx) => { ); }); -bot.command("pause", async (ctx) => { +bot.command("pause", (ctx) => { if (!ctx.from?.username) return; + const config = ctx.session.global; if (!config.adminUsernames.includes(ctx.from.username)) return; if (config.pausedReason != null) { - return await ctx.reply(`Already paused: ${config.pausedReason}`); + return ctx.reply(`Already paused: ${config.pausedReason}`); } config.pausedReason = ctx.match ?? "No reason given"; - return await ctx.reply("Paused"); + return ctx.reply("Paused"); }); -bot.command("resume", async (ctx) => { +bot.command("resume", (ctx) => { if (!ctx.from?.username) return; + const config = ctx.session.global; if (!config.adminUsernames.includes(ctx.from.username)) return; - if (config.pausedReason == null) return await ctx.reply("Already running"); + if (config.pausedReason == null) return ctx.reply("Already running"); config.pausedReason = null; - return await ctx.reply("Resumed"); + return ctx.reply("Resumed"); +}); + +bot.command("setsdapiurl", async (ctx) => { + if (!ctx.from?.username) return; + const config = ctx.session.global; + if (!config.adminUsernames.includes(ctx.from.username)) return; + if (!ctx.match) return ctx.reply("Please specify an URL"); + let url: URL; + try { + url = new URL(ctx.match); + } catch { + return ctx.reply("Invalid URL"); + } + let resp: Response; + try { + resp = await fetch(new URL("config", url)); + } catch (err) { + return ctx.reply(`Could not connect: ${err}`); + } + if (!resp.ok) { + return ctx.reply(`Could not connect: ${resp.status} ${resp.statusText}`); + } + let data: unknown; + try { + data = await resp.json(); + } catch { + return ctx.reply("Invalid response from API"); + } + if (data != null && typeof data === "object" && "version" in data) { + config.sdApiUrl = url.toString(); + return ctx.reply(`Now using SD at ${url} running version ${data.version}`); + } else { + return ctx.reply("Invalid response from API"); + } +}); + +bot.command("setsdparam", (ctx) => { + if (!ctx.from?.username) return; + const config = ctx.session.global; + if (!config.adminUsernames.includes(ctx.from.username)) return; + let [param = "", value] = ctx.match.split("=", 2).map((s) => s.trim()); + if (!param) return ctx.reply("Please specify a parameter"); + if (value == null) return ctx.reply("Please specify a value after the ="); + param = param.toLowerCase().replace(/[\s_]+/g, ""); + if (config.defaultParams == null) config.defaultParams = {}; + switch (param) { + case "steps": { + const steps = parseInt(value); + if (isNaN(steps)) return ctx.reply("Invalid number value"); + if (steps > 100) return ctx.reply("Steps must be less than 100"); + if (steps < 10) return ctx.reply("Steps must be greater than 10"); + config.defaultParams.steps = steps; + return ctx.reply("Steps set to " + steps); + } + case "detail": + case "cfgscale": { + const detail = parseInt(value); + if (isNaN(detail)) return ctx.reply("Invalid number value"); + if (detail > 20) return ctx.reply("Detail must be less than 20"); + if (detail < 1) return ctx.reply("Detail must be greater than 1"); + config.defaultParams.cfg_scale = detail; + return ctx.reply("Detail set to " + detail); + } + case "niter": + case "niters": { + const nIter = parseInt(value); + if (isNaN(nIter)) return ctx.reply("Invalid number value"); + if (nIter > 10) return ctx.reply("Iterations must be less than 10"); + if (nIter < 1) return ctx.reply("Iterations must be greater than 1"); + config.defaultParams.n_iter = nIter; + return ctx.reply("Iterations set to " + nIter); + } + case "batchsize": { + const batchSize = parseInt(value); + if (isNaN(batchSize)) return ctx.reply("Invalid number value"); + if (batchSize > 8) return ctx.reply("Batch size must be less than 8"); + if (batchSize < 1) return ctx.reply("Batch size must be greater than 1"); + config.defaultParams.batch_size = batchSize; + return ctx.reply("Batch size set to " + batchSize); + } + case "size": { + let [width, height] = value.split("x", 2).map((s) => parseInt(s.trim())); + if (!width || !height || isNaN(width) || isNaN(height)) { + return ctx.reply("Invalid size value"); + } + if (width > 2048) return ctx.reply("Width must be less than 2048"); + if (height > 2048) return ctx.reply("Height must be less than 2048"); + // find closest multiple of 64 + width = Math.round(width / 64) * 64; + height = Math.round(height / 64) * 64; + if (width <= 0) return ctx.reply("Width too small"); + if (height <= 0) return ctx.reply("Height too small"); + config.defaultParams.width = width; + config.defaultParams.height = height; + return ctx.reply(`Size set to ${width}x${height}`); + } + case "negativeprompt": { + config.defaultParams.negative_prompt = value; + return ctx.reply(`Negative prompt set to: ${value}`); + } + default: { + return ctx.reply("Invalid parameter"); + } + } +}); + +bot.command("sdparams", (ctx) => { + if (!ctx.from?.username) return; + const config = ctx.session.global; + return ctx.replyFmt(fmt`Current config:\n\n${ + fmtArray( + Object.entries(config.defaultParams ?? {}).map(([key, value]) => + fmt`${bold(key)} = ${String(value)}` + ), + "\n", + ) + }`); }); bot.catch((err) => { diff --git a/config.ts b/config.ts deleted file mode 100644 index 4d41272..0000000 --- a/config.ts +++ /dev/null @@ -1,15 +0,0 @@ -export const config: Config = { - adminUsernames: (Deno.env.get("ADMIN_USERNAMES") ?? "").split(",").filter(Boolean), - pausedReason: null, - sdApiUrl: Deno.env.get("SD_API_URL") ?? "http://127.0.0.1:7860/", - maxUserJobs: 3, - maxJobs: 20, -}; - -interface Config { - adminUsernames: string[]; - pausedReason: string | null; - sdApiUrl: string; - maxUserJobs: number; - maxJobs: number; -} diff --git a/deno.jsonc b/deno.jsonc index 9be0600..497b91f 100644 --- a/deno.jsonc +++ b/deno.jsonc @@ -1,7 +1,7 @@ { "tasks": { - "dev": "deno run --watch --allow-env --allow-read --allow-net main.ts", - "start": "deno run --allow-env --allow-read --allow-net main.ts" + "dev": "deno run --watch --unstable --allow-env --allow-read --allow-write --allow-net main.ts", + "start": "deno run --unstable --allow-env --allow-read --allow-write --allow-net main.ts" }, "fmt": { "lineWidth": 100 diff --git a/deps.ts b/deps.ts index faa9429..b84b8b8 100644 --- a/deps.ts +++ b/deps.ts @@ -1,4 +1,6 @@ export * from "https://deno.land/x/grammy@v1.18.1/mod.ts"; export * from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts"; export * from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts"; +export * from "https://deno.land/x/grammy_storages@v2.3.1/denokv/src/mod.ts"; export { autoRetry } from "https://esm.sh/@grammyjs/auto-retry@1.1.1"; +export * from "https://deno.land/x/zod/mod.ts"; diff --git a/queue.ts b/queue.ts index a1e0db5..57dcbab 100644 --- a/queue.ts +++ b/queue.ts @@ -1,6 +1,5 @@ import { InputFile, InputMediaBuilder } from "./deps.ts"; -import { config } from "./config.ts"; -import { bot } from "./bot.ts"; +import { bot, getGlobalSession } from "./bot.ts"; import { formatOrdinal } from "./intl.ts"; import { SdProgressResponse, SdRequest, txt2img } from "./sd.ts"; import { extFromMimeType, mimeTypeFromBase64 } from "./mimeType.ts"; @@ -56,12 +55,12 @@ export async function processQueue() { ) .catch(() => {}); }; + const config = await getGlobalSession(); const response = await txt2img( config.sdApiUrl, - { ...defaultParams, ...job.params }, + { ...config.defaultParams, ...job.params }, onProgress, ); - console.log( `Generated ${response.images.length} images (${ response.images @@ -69,11 +68,11 @@ export async function processQueue() { .join(", ") }) for ${job.userName} in ${job.chatName}: ${job.params.prompt?.replace(/\s+/g, " ")}`, ); - bot.api.editMessageText( + await bot.api.editMessageText( job.chatId, progressMessage.message_id, `Uploading your images...`, - ); + ).catch(() => {}); const inputFiles = await Promise.all( response.images.map(async (imageBase64, idx) => { const mimeType = mimeTypeFromBase64(imageBase64); @@ -110,14 +109,3 @@ export async function processQueue() { } } } - -const defaultParams: Partial = { - batch_size: 1, - n_iter: 1, - width: 128 * 2, - height: 128 * 3, - steps: 20, - cfg_scale: 9, - send_images: true, - negative_prompt: "boring_e621_fluffyrock_v4 boring_e621_v4", -};