feat: allow setting sd parameters via commands
This commit is contained in:
parent
2f5058ac45
commit
5722238c06
205
bot.ts
205
bot.ts
|
@ -4,19 +4,81 @@ import {
|
||||||
bold,
|
bold,
|
||||||
Bot,
|
Bot,
|
||||||
Context,
|
Context,
|
||||||
|
DenoKVAdapter,
|
||||||
fmt,
|
fmt,
|
||||||
hydrateReply,
|
hydrateReply,
|
||||||
ParseModeFlavor,
|
ParseModeFlavor,
|
||||||
|
session,
|
||||||
|
SessionFlavor,
|
||||||
} from "./deps.ts";
|
} from "./deps.ts";
|
||||||
import { fmtArray, formatOrdinal } from "./intl.ts";
|
import { fmtArray, formatOrdinal } from "./intl.ts";
|
||||||
import { config } from "./config.ts";
|
|
||||||
import { queue } from "./queue.ts";
|
import { queue } from "./queue.ts";
|
||||||
|
import { SdRequest } from "./sd.ts";
|
||||||
|
|
||||||
export const bot = new Bot<ParseModeFlavor<Context>>(Deno.env.get("TG_BOT_TOKEN") ?? "");
|
type AppContext = ParseModeFlavor<Context> & SessionFlavor<SessionData>;
|
||||||
|
|
||||||
|
interface SessionData {
|
||||||
|
global: {
|
||||||
|
adminUsernames: string[];
|
||||||
|
pausedReason: string | null;
|
||||||
|
sdApiUrl: string;
|
||||||
|
maxUserJobs: number;
|
||||||
|
maxJobs: number;
|
||||||
|
defaultParams?: Partial<SdRequest>;
|
||||||
|
};
|
||||||
|
user: {
|
||||||
|
steps: number;
|
||||||
|
detail: number;
|
||||||
|
batchSize: number;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export const bot = new Bot<AppContext>(Deno.env.get("TG_BOT_TOKEN") ?? "");
|
||||||
bot.use(autoQuote);
|
bot.use(autoQuote);
|
||||||
bot.use(hydrateReply);
|
bot.use(hydrateReply);
|
||||||
bot.api.config.use(autoRetry({ maxRetryAttempts: 5, maxDelaySeconds: 60 }));
|
bot.api.config.use(autoRetry({ maxRetryAttempts: 5, maxDelaySeconds: 60 }));
|
||||||
|
|
||||||
|
const db = await Deno.openKv("./app.db");
|
||||||
|
|
||||||
|
const getDefaultGlobalSession = (): SessionData["global"] => ({
|
||||||
|
adminUsernames: (Deno.env.get("ADMIN_USERNAMES") ?? "").split(",").filter(Boolean),
|
||||||
|
pausedReason: null,
|
||||||
|
sdApiUrl: Deno.env.get("SD_API_URL") ?? "http://127.0.0.1:7860/",
|
||||||
|
maxUserJobs: 3,
|
||||||
|
maxJobs: 20,
|
||||||
|
defaultParams: {
|
||||||
|
batch_size: 1,
|
||||||
|
n_iter: 1,
|
||||||
|
width: 128 * 2,
|
||||||
|
height: 128 * 3,
|
||||||
|
steps: 20,
|
||||||
|
cfg_scale: 9,
|
||||||
|
send_images: true,
|
||||||
|
negative_prompt: "boring_e621_fluffyrock_v4 boring_e621_v4",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
bot.use(session<SessionData, AppContext>({
|
||||||
|
type: "multi",
|
||||||
|
global: {
|
||||||
|
getSessionKey: () => "global",
|
||||||
|
initial: getDefaultGlobalSession,
|
||||||
|
storage: new DenoKVAdapter(db),
|
||||||
|
},
|
||||||
|
user: {
|
||||||
|
initial: () => ({
|
||||||
|
steps: 20,
|
||||||
|
detail: 8,
|
||||||
|
batchSize: 2,
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
export async function getGlobalSession(): Promise<SessionData["global"]> {
|
||||||
|
const entry = await db.get<SessionData["global"]>(["sessions", "global"]);
|
||||||
|
return entry.value ?? getDefaultGlobalSession();
|
||||||
|
}
|
||||||
|
|
||||||
bot.api.setMyShortDescription("I can generate furry images from text");
|
bot.api.setMyShortDescription("I can generate furry images from text");
|
||||||
bot.api.setMyDescription(
|
bot.api.setMyDescription(
|
||||||
"I can generate furry images from text. Send /txt2img to generate an image.",
|
"I can generate furry images from text. Send /txt2img to generate an image.",
|
||||||
|
@ -24,6 +86,7 @@ bot.api.setMyDescription(
|
||||||
bot.api.setMyCommands([
|
bot.api.setMyCommands([
|
||||||
{ command: "txt2img", description: "Generate an image" },
|
{ command: "txt2img", description: "Generate an image" },
|
||||||
{ command: "queue", description: "Show the current queue" },
|
{ command: "queue", description: "Show the current queue" },
|
||||||
|
{ command: "sdparams", description: "Show the current SD parameters" },
|
||||||
]);
|
]);
|
||||||
|
|
||||||
bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to generate an image"));
|
bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to generate an image"));
|
||||||
|
@ -32,8 +95,9 @@ bot.command("txt2img", async (ctx) => {
|
||||||
if (!ctx.from?.id) {
|
if (!ctx.from?.id) {
|
||||||
return ctx.reply("I don't know who you are");
|
return ctx.reply("I don't know who you are");
|
||||||
}
|
}
|
||||||
|
const config = ctx.session.global;
|
||||||
if (config.pausedReason != null) {
|
if (config.pausedReason != null) {
|
||||||
return ctx.reply(`I'm paused: ${config.pausedReason}`);
|
return ctx.reply(`I'm paused: ${config.pausedReason || "No reason given"}`);
|
||||||
}
|
}
|
||||||
if (queue.length >= config.maxJobs) {
|
if (queue.length >= config.maxJobs) {
|
||||||
return ctx.reply(
|
return ctx.reply(
|
||||||
|
@ -67,9 +131,9 @@ bot.command("txt2img", async (ctx) => {
|
||||||
console.log(`Enqueued job for ${userName} in chat ${chatName}`);
|
console.log(`Enqueued job for ${userName} in chat ${chatName}`);
|
||||||
});
|
});
|
||||||
|
|
||||||
bot.command("queue", async (ctx) => {
|
bot.command("queue", (ctx) => {
|
||||||
if (queue.length === 0) return ctx.reply("Queue is empty");
|
if (queue.length === 0) return ctx.reply("Queue is empty");
|
||||||
return await ctx.replyFmt(
|
return ctx.replyFmt(
|
||||||
fmt`Current queue:\n\n${
|
fmt`Current queue:\n\n${
|
||||||
fmtArray(
|
fmtArray(
|
||||||
queue.map((job, index) =>
|
queue.map((job, index) =>
|
||||||
|
@ -81,22 +145,141 @@ bot.command("queue", async (ctx) => {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
bot.command("pause", async (ctx) => {
|
bot.command("pause", (ctx) => {
|
||||||
if (!ctx.from?.username) return;
|
if (!ctx.from?.username) return;
|
||||||
|
const config = ctx.session.global;
|
||||||
if (!config.adminUsernames.includes(ctx.from.username)) return;
|
if (!config.adminUsernames.includes(ctx.from.username)) return;
|
||||||
if (config.pausedReason != null) {
|
if (config.pausedReason != null) {
|
||||||
return await ctx.reply(`Already paused: ${config.pausedReason}`);
|
return ctx.reply(`Already paused: ${config.pausedReason}`);
|
||||||
}
|
}
|
||||||
config.pausedReason = ctx.match ?? "No reason given";
|
config.pausedReason = ctx.match ?? "No reason given";
|
||||||
return await ctx.reply("Paused");
|
return ctx.reply("Paused");
|
||||||
});
|
});
|
||||||
|
|
||||||
bot.command("resume", async (ctx) => {
|
bot.command("resume", (ctx) => {
|
||||||
if (!ctx.from?.username) return;
|
if (!ctx.from?.username) return;
|
||||||
|
const config = ctx.session.global;
|
||||||
if (!config.adminUsernames.includes(ctx.from.username)) return;
|
if (!config.adminUsernames.includes(ctx.from.username)) return;
|
||||||
if (config.pausedReason == null) return await ctx.reply("Already running");
|
if (config.pausedReason == null) return ctx.reply("Already running");
|
||||||
config.pausedReason = null;
|
config.pausedReason = null;
|
||||||
return await ctx.reply("Resumed");
|
return ctx.reply("Resumed");
|
||||||
|
});
|
||||||
|
|
||||||
|
bot.command("setsdapiurl", async (ctx) => {
|
||||||
|
if (!ctx.from?.username) return;
|
||||||
|
const config = ctx.session.global;
|
||||||
|
if (!config.adminUsernames.includes(ctx.from.username)) return;
|
||||||
|
if (!ctx.match) return ctx.reply("Please specify an URL");
|
||||||
|
let url: URL;
|
||||||
|
try {
|
||||||
|
url = new URL(ctx.match);
|
||||||
|
} catch {
|
||||||
|
return ctx.reply("Invalid URL");
|
||||||
|
}
|
||||||
|
let resp: Response;
|
||||||
|
try {
|
||||||
|
resp = await fetch(new URL("config", url));
|
||||||
|
} catch (err) {
|
||||||
|
return ctx.reply(`Could not connect: ${err}`);
|
||||||
|
}
|
||||||
|
if (!resp.ok) {
|
||||||
|
return ctx.reply(`Could not connect: ${resp.status} ${resp.statusText}`);
|
||||||
|
}
|
||||||
|
let data: unknown;
|
||||||
|
try {
|
||||||
|
data = await resp.json();
|
||||||
|
} catch {
|
||||||
|
return ctx.reply("Invalid response from API");
|
||||||
|
}
|
||||||
|
if (data != null && typeof data === "object" && "version" in data) {
|
||||||
|
config.sdApiUrl = url.toString();
|
||||||
|
return ctx.reply(`Now using SD at ${url} running version ${data.version}`);
|
||||||
|
} else {
|
||||||
|
return ctx.reply("Invalid response from API");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
bot.command("setsdparam", (ctx) => {
|
||||||
|
if (!ctx.from?.username) return;
|
||||||
|
const config = ctx.session.global;
|
||||||
|
if (!config.adminUsernames.includes(ctx.from.username)) return;
|
||||||
|
let [param = "", value] = ctx.match.split("=", 2).map((s) => s.trim());
|
||||||
|
if (!param) return ctx.reply("Please specify a parameter");
|
||||||
|
if (value == null) return ctx.reply("Please specify a value after the =");
|
||||||
|
param = param.toLowerCase().replace(/[\s_]+/g, "");
|
||||||
|
if (config.defaultParams == null) config.defaultParams = {};
|
||||||
|
switch (param) {
|
||||||
|
case "steps": {
|
||||||
|
const steps = parseInt(value);
|
||||||
|
if (isNaN(steps)) return ctx.reply("Invalid number value");
|
||||||
|
if (steps > 100) return ctx.reply("Steps must be less than 100");
|
||||||
|
if (steps < 10) return ctx.reply("Steps must be greater than 10");
|
||||||
|
config.defaultParams.steps = steps;
|
||||||
|
return ctx.reply("Steps set to " + steps);
|
||||||
|
}
|
||||||
|
case "detail":
|
||||||
|
case "cfgscale": {
|
||||||
|
const detail = parseInt(value);
|
||||||
|
if (isNaN(detail)) return ctx.reply("Invalid number value");
|
||||||
|
if (detail > 20) return ctx.reply("Detail must be less than 20");
|
||||||
|
if (detail < 1) return ctx.reply("Detail must be greater than 1");
|
||||||
|
config.defaultParams.cfg_scale = detail;
|
||||||
|
return ctx.reply("Detail set to " + detail);
|
||||||
|
}
|
||||||
|
case "niter":
|
||||||
|
case "niters": {
|
||||||
|
const nIter = parseInt(value);
|
||||||
|
if (isNaN(nIter)) return ctx.reply("Invalid number value");
|
||||||
|
if (nIter > 10) return ctx.reply("Iterations must be less than 10");
|
||||||
|
if (nIter < 1) return ctx.reply("Iterations must be greater than 1");
|
||||||
|
config.defaultParams.n_iter = nIter;
|
||||||
|
return ctx.reply("Iterations set to " + nIter);
|
||||||
|
}
|
||||||
|
case "batchsize": {
|
||||||
|
const batchSize = parseInt(value);
|
||||||
|
if (isNaN(batchSize)) return ctx.reply("Invalid number value");
|
||||||
|
if (batchSize > 8) return ctx.reply("Batch size must be less than 8");
|
||||||
|
if (batchSize < 1) return ctx.reply("Batch size must be greater than 1");
|
||||||
|
config.defaultParams.batch_size = batchSize;
|
||||||
|
return ctx.reply("Batch size set to " + batchSize);
|
||||||
|
}
|
||||||
|
case "size": {
|
||||||
|
let [width, height] = value.split("x", 2).map((s) => parseInt(s.trim()));
|
||||||
|
if (!width || !height || isNaN(width) || isNaN(height)) {
|
||||||
|
return ctx.reply("Invalid size value");
|
||||||
|
}
|
||||||
|
if (width > 2048) return ctx.reply("Width must be less than 2048");
|
||||||
|
if (height > 2048) return ctx.reply("Height must be less than 2048");
|
||||||
|
// find closest multiple of 64
|
||||||
|
width = Math.round(width / 64) * 64;
|
||||||
|
height = Math.round(height / 64) * 64;
|
||||||
|
if (width <= 0) return ctx.reply("Width too small");
|
||||||
|
if (height <= 0) return ctx.reply("Height too small");
|
||||||
|
config.defaultParams.width = width;
|
||||||
|
config.defaultParams.height = height;
|
||||||
|
return ctx.reply(`Size set to ${width}x${height}`);
|
||||||
|
}
|
||||||
|
case "negativeprompt": {
|
||||||
|
config.defaultParams.negative_prompt = value;
|
||||||
|
return ctx.reply(`Negative prompt set to: ${value}`);
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
return ctx.reply("Invalid parameter");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
bot.command("sdparams", (ctx) => {
|
||||||
|
if (!ctx.from?.username) return;
|
||||||
|
const config = ctx.session.global;
|
||||||
|
return ctx.replyFmt(fmt`Current config:\n\n${
|
||||||
|
fmtArray(
|
||||||
|
Object.entries(config.defaultParams ?? {}).map(([key, value]) =>
|
||||||
|
fmt`${bold(key)} = ${String(value)}`
|
||||||
|
),
|
||||||
|
"\n",
|
||||||
|
)
|
||||||
|
}`);
|
||||||
});
|
});
|
||||||
|
|
||||||
bot.catch((err) => {
|
bot.catch((err) => {
|
||||||
|
|
15
config.ts
15
config.ts
|
@ -1,15 +0,0 @@
|
||||||
export const config: Config = {
|
|
||||||
adminUsernames: (Deno.env.get("ADMIN_USERNAMES") ?? "").split(",").filter(Boolean),
|
|
||||||
pausedReason: null,
|
|
||||||
sdApiUrl: Deno.env.get("SD_API_URL") ?? "http://127.0.0.1:7860/",
|
|
||||||
maxUserJobs: 3,
|
|
||||||
maxJobs: 20,
|
|
||||||
};
|
|
||||||
|
|
||||||
interface Config {
|
|
||||||
adminUsernames: string[];
|
|
||||||
pausedReason: string | null;
|
|
||||||
sdApiUrl: string;
|
|
||||||
maxUserJobs: number;
|
|
||||||
maxJobs: number;
|
|
||||||
}
|
|
|
@ -1,7 +1,7 @@
|
||||||
{
|
{
|
||||||
"tasks": {
|
"tasks": {
|
||||||
"dev": "deno run --watch --allow-env --allow-read --allow-net main.ts",
|
"dev": "deno run --watch --unstable --allow-env --allow-read --allow-write --allow-net main.ts",
|
||||||
"start": "deno run --allow-env --allow-read --allow-net main.ts"
|
"start": "deno run --unstable --allow-env --allow-read --allow-write --allow-net main.ts"
|
||||||
},
|
},
|
||||||
"fmt": {
|
"fmt": {
|
||||||
"lineWidth": 100
|
"lineWidth": 100
|
||||||
|
|
2
deps.ts
2
deps.ts
|
@ -1,4 +1,6 @@
|
||||||
export * from "https://deno.land/x/grammy@v1.18.1/mod.ts";
|
export * from "https://deno.land/x/grammy@v1.18.1/mod.ts";
|
||||||
export * from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts";
|
export * from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts";
|
||||||
export * from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts";
|
export * from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts";
|
||||||
|
export * from "https://deno.land/x/grammy_storages@v2.3.1/denokv/src/mod.ts";
|
||||||
export { autoRetry } from "https://esm.sh/@grammyjs/auto-retry@1.1.1";
|
export { autoRetry } from "https://esm.sh/@grammyjs/auto-retry@1.1.1";
|
||||||
|
export * from "https://deno.land/x/zod/mod.ts";
|
||||||
|
|
22
queue.ts
22
queue.ts
|
@ -1,6 +1,5 @@
|
||||||
import { InputFile, InputMediaBuilder } from "./deps.ts";
|
import { InputFile, InputMediaBuilder } from "./deps.ts";
|
||||||
import { config } from "./config.ts";
|
import { bot, getGlobalSession } from "./bot.ts";
|
||||||
import { bot } from "./bot.ts";
|
|
||||||
import { formatOrdinal } from "./intl.ts";
|
import { formatOrdinal } from "./intl.ts";
|
||||||
import { SdProgressResponse, SdRequest, txt2img } from "./sd.ts";
|
import { SdProgressResponse, SdRequest, txt2img } from "./sd.ts";
|
||||||
import { extFromMimeType, mimeTypeFromBase64 } from "./mimeType.ts";
|
import { extFromMimeType, mimeTypeFromBase64 } from "./mimeType.ts";
|
||||||
|
@ -56,12 +55,12 @@ export async function processQueue() {
|
||||||
)
|
)
|
||||||
.catch(() => {});
|
.catch(() => {});
|
||||||
};
|
};
|
||||||
|
const config = await getGlobalSession();
|
||||||
const response = await txt2img(
|
const response = await txt2img(
|
||||||
config.sdApiUrl,
|
config.sdApiUrl,
|
||||||
{ ...defaultParams, ...job.params },
|
{ ...config.defaultParams, ...job.params },
|
||||||
onProgress,
|
onProgress,
|
||||||
);
|
);
|
||||||
|
|
||||||
console.log(
|
console.log(
|
||||||
`Generated ${response.images.length} images (${
|
`Generated ${response.images.length} images (${
|
||||||
response.images
|
response.images
|
||||||
|
@ -69,11 +68,11 @@ export async function processQueue() {
|
||||||
.join(", ")
|
.join(", ")
|
||||||
}) for ${job.userName} in ${job.chatName}: ${job.params.prompt?.replace(/\s+/g, " ")}`,
|
}) for ${job.userName} in ${job.chatName}: ${job.params.prompt?.replace(/\s+/g, " ")}`,
|
||||||
);
|
);
|
||||||
bot.api.editMessageText(
|
await bot.api.editMessageText(
|
||||||
job.chatId,
|
job.chatId,
|
||||||
progressMessage.message_id,
|
progressMessage.message_id,
|
||||||
`Uploading your images...`,
|
`Uploading your images...`,
|
||||||
);
|
).catch(() => {});
|
||||||
const inputFiles = await Promise.all(
|
const inputFiles = await Promise.all(
|
||||||
response.images.map(async (imageBase64, idx) => {
|
response.images.map(async (imageBase64, idx) => {
|
||||||
const mimeType = mimeTypeFromBase64(imageBase64);
|
const mimeType = mimeTypeFromBase64(imageBase64);
|
||||||
|
@ -110,14 +109,3 @@ export async function processQueue() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultParams: Partial<SdRequest> = {
|
|
||||||
batch_size: 1,
|
|
||||||
n_iter: 1,
|
|
||||||
width: 128 * 2,
|
|
||||||
height: 128 * 3,
|
|
||||||
steps: 20,
|
|
||||||
cfg_scale: 9,
|
|
||||||
send_images: true,
|
|
||||||
negative_prompt: "boring_e621_fluffyrock_v4 boring_e621_v4",
|
|
||||||
};
|
|
||||||
|
|
Loading…
Reference in New Issue