Compare commits

..

3 Commits

10 changed files with 563 additions and 389 deletions

View File

@ -5,7 +5,7 @@ Telegram bot for generating images from text.
## Requirements
- [Deno](https://deno.land/)
- [Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui/)
- [Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui/)
## Options
@ -13,7 +13,8 @@ You can put these in `.env` file or pass them as environment variables.
- `TG_BOT_TOKEN` - Telegram bot token (get yours from [@BotFather](https://t.me/BotFather))
- `SD_API_URL` - URL to Stable Diffusion API (e.g. `http://127.0.0.1:7860/`)
- `ADMIN_USERNAMES` - Comma separated list of usernames of users that can use admin commands (optional)
- `ADMIN_USERNAMES` - Comma separated list of usernames of users that can use admin commands
(optional)
## Running

296
bot.ts Normal file
View File

@ -0,0 +1,296 @@
import {
autoQuote,
autoRetry,
bold,
Bot,
Context,
DenoKVAdapter,
fmt,
hydrateReply,
ParseModeFlavor,
session,
SessionFlavor,
} from "./deps.ts";
import { fmtArray, formatOrdinal } from "./intl.ts";
import { queue } from "./queue.ts";
import { SdRequest } from "./sd.ts";
type AppContext = ParseModeFlavor<Context> & SessionFlavor<SessionData>;
interface SessionData {
global: {
adminUsernames: string[];
pausedReason: string | null;
sdApiUrl: string;
maxUserJobs: number;
maxJobs: number;
defaultParams?: Partial<SdRequest>;
};
user: {
steps: number;
detail: number;
batchSize: number;
};
}
export const bot = new Bot<AppContext>(Deno.env.get("TG_BOT_TOKEN") ?? "");
bot.use(autoQuote);
bot.use(hydrateReply);
bot.api.config.use(autoRetry({ maxRetryAttempts: 5, maxDelaySeconds: 60 }));
const db = await Deno.openKv("./app.db");
const getDefaultGlobalSession = (): SessionData["global"] => ({
adminUsernames: (Deno.env.get("ADMIN_USERNAMES") ?? "").split(",").filter(Boolean),
pausedReason: null,
sdApiUrl: Deno.env.get("SD_API_URL") ?? "http://127.0.0.1:7860/",
maxUserJobs: 3,
maxJobs: 20,
defaultParams: {
batch_size: 1,
n_iter: 1,
width: 128 * 2,
height: 128 * 3,
steps: 20,
cfg_scale: 9,
send_images: true,
negative_prompt: "boring_e621_fluffyrock_v4 boring_e621_v4",
},
});
bot.use(session<SessionData, AppContext>({
type: "multi",
global: {
getSessionKey: () => "global",
initial: getDefaultGlobalSession,
storage: new DenoKVAdapter(db),
},
user: {
initial: () => ({
steps: 20,
detail: 8,
batchSize: 2,
}),
},
}));
export async function getGlobalSession(): Promise<SessionData["global"]> {
const entry = await db.get<SessionData["global"]>(["sessions", "global"]);
return entry.value ?? getDefaultGlobalSession();
}
bot.api.setMyShortDescription("I can generate furry images from text");
bot.api.setMyDescription(
"I can generate furry images from text. Send /txt2img to generate an image.",
);
bot.api.setMyCommands([
{ command: "txt2img", description: "Generate an image" },
{ command: "queue", description: "Show the current queue" },
{ command: "sdparams", description: "Show the current SD parameters" },
]);
bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to generate an image"));
bot.command("txt2img", async (ctx) => {
if (!ctx.from?.id) {
return ctx.reply("I don't know who you are");
}
const config = ctx.session.global;
if (config.pausedReason != null) {
return ctx.reply(`I'm paused: ${config.pausedReason || "No reason given"}`);
}
if (queue.length >= config.maxJobs) {
return ctx.reply(
`The queue is full. Try again later. (Max queue size: ${config.maxJobs})`,
);
}
const jobCount = queue.filter((job) => job.userId === ctx.from.id).length;
if (jobCount >= config.maxUserJobs) {
return ctx.reply(
`You already have ${config.maxUserJobs} jobs in queue. Try again later.`,
);
}
if (!ctx.match) {
return ctx.reply("Please describe what you want to see after the command");
}
const place = queue.length + 1;
const queueMessage = await ctx.reply(`You are ${formatOrdinal(place)} in queue.`);
const userName = [ctx.from.first_name, ctx.from.last_name].filter(Boolean).join(" ");
const chatName = ctx.chat.type === "supergroup" || ctx.chat.type === "group"
? ctx.chat.title
: "private chat";
queue.push({
params: { prompt: ctx.match },
userId: ctx.from.id,
userName,
chatId: ctx.chat.id,
chatName,
requestMessageId: ctx.message.message_id,
statusMessageId: queueMessage.message_id,
});
console.log(`Enqueued job for ${userName} in chat ${chatName}`);
});
bot.command("queue", (ctx) => {
if (queue.length === 0) return ctx.reply("Queue is empty");
return ctx.replyFmt(
fmt`Current queue:\n\n${
fmtArray(
queue.map((job, index) =>
fmt`${bold(index + 1)}. ${bold(job.userName)} in ${bold(job.chatName)}`
),
"\n",
)
}`,
);
});
bot.command("pause", (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
if (!config.adminUsernames.includes(ctx.from.username)) return;
if (config.pausedReason != null) {
return ctx.reply(`Already paused: ${config.pausedReason}`);
}
config.pausedReason = ctx.match ?? "No reason given";
return ctx.reply("Paused");
});
bot.command("resume", (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
if (!config.adminUsernames.includes(ctx.from.username)) return;
if (config.pausedReason == null) return ctx.reply("Already running");
config.pausedReason = null;
return ctx.reply("Resumed");
});
bot.command("setsdapiurl", async (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
if (!config.adminUsernames.includes(ctx.from.username)) return;
if (!ctx.match) return ctx.reply("Please specify an URL");
let url: URL;
try {
url = new URL(ctx.match);
} catch {
return ctx.reply("Invalid URL");
}
let resp: Response;
try {
resp = await fetch(new URL("config", url));
} catch (err) {
return ctx.reply(`Could not connect: ${err}`);
}
if (!resp.ok) {
return ctx.reply(`Could not connect: ${resp.status} ${resp.statusText}`);
}
let data: unknown;
try {
data = await resp.json();
} catch {
return ctx.reply("Invalid response from API");
}
if (data != null && typeof data === "object" && "version" in data) {
config.sdApiUrl = url.toString();
return ctx.reply(`Now using SD at ${url} running version ${data.version}`);
} else {
return ctx.reply("Invalid response from API");
}
});
bot.command("setsdparam", (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
if (!config.adminUsernames.includes(ctx.from.username)) return;
let [param = "", value] = ctx.match.split("=", 2).map((s) => s.trim());
if (!param) return ctx.reply("Please specify a parameter");
if (value == null) return ctx.reply("Please specify a value after the =");
param = param.toLowerCase().replace(/[\s_]+/g, "");
if (config.defaultParams == null) config.defaultParams = {};
switch (param) {
case "steps": {
const steps = parseInt(value);
if (isNaN(steps)) return ctx.reply("Invalid number value");
if (steps > 100) return ctx.reply("Steps must be less than 100");
if (steps < 10) return ctx.reply("Steps must be greater than 10");
config.defaultParams.steps = steps;
return ctx.reply("Steps set to " + steps);
}
case "detail":
case "cfgscale": {
const detail = parseInt(value);
if (isNaN(detail)) return ctx.reply("Invalid number value");
if (detail > 20) return ctx.reply("Detail must be less than 20");
if (detail < 1) return ctx.reply("Detail must be greater than 1");
config.defaultParams.cfg_scale = detail;
return ctx.reply("Detail set to " + detail);
}
case "niter":
case "niters": {
const nIter = parseInt(value);
if (isNaN(nIter)) return ctx.reply("Invalid number value");
if (nIter > 10) return ctx.reply("Iterations must be less than 10");
if (nIter < 1) return ctx.reply("Iterations must be greater than 1");
config.defaultParams.n_iter = nIter;
return ctx.reply("Iterations set to " + nIter);
}
case "batchsize": {
const batchSize = parseInt(value);
if (isNaN(batchSize)) return ctx.reply("Invalid number value");
if (batchSize > 8) return ctx.reply("Batch size must be less than 8");
if (batchSize < 1) return ctx.reply("Batch size must be greater than 1");
config.defaultParams.batch_size = batchSize;
return ctx.reply("Batch size set to " + batchSize);
}
case "size": {
let [width, height] = value.split("x", 2).map((s) => parseInt(s.trim()));
if (!width || !height || isNaN(width) || isNaN(height)) {
return ctx.reply("Invalid size value");
}
if (width > 2048) return ctx.reply("Width must be less than 2048");
if (height > 2048) return ctx.reply("Height must be less than 2048");
// find closest multiple of 64
width = Math.round(width / 64) * 64;
height = Math.round(height / 64) * 64;
if (width <= 0) return ctx.reply("Width too small");
if (height <= 0) return ctx.reply("Height too small");
config.defaultParams.width = width;
config.defaultParams.height = height;
return ctx.reply(`Size set to ${width}x${height}`);
}
case "negativeprompt": {
config.defaultParams.negative_prompt = value;
return ctx.reply(`Negative prompt set to: ${value}`);
}
default: {
return ctx.reply("Invalid parameter");
}
}
});
bot.command("sdparams", (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
return ctx.replyFmt(fmt`Current config:\n\n${
fmtArray(
Object.entries(config.defaultParams ?? {}).map(([key, value]) =>
fmt`${bold(key)} = ${String(value)}`
),
"\n",
)
}`);
});
bot.catch((err) => {
let msg = "Error processing update";
const { from, chat } = err.ctx;
if (from?.first_name) msg += ` from ${from.first_name}`;
if (from?.last_name) msg += ` ${from.last_name}`;
if (from?.username) msg += ` (@${from.username})`;
if (chat?.type === "supergroup" || chat?.type === "group") {
msg += ` in ${chat.title}`;
if (chat.type === "supergroup" && chat.username) msg += ` (@${chat.username})`;
}
console.error(msg, err.error);
});

View File

@ -1,6 +1,9 @@
{
"tasks": {
"dev": "deno run --watch --allow-env --allow-read --allow-net main.ts",
"start": "deno run --allow-env --allow-read --allow-net main.ts"
"dev": "deno run --watch --unstable --allow-env --allow-read --allow-write --allow-net main.ts",
"start": "deno run --unstable --allow-env --allow-read --allow-write --allow-net main.ts"
},
"fmt": {
"lineWidth": 100
}
}

6
deps.ts Normal file
View File

@ -0,0 +1,6 @@
export * from "https://deno.land/x/grammy@v1.18.1/mod.ts";
export * from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts";
export * from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts";
export * from "https://deno.land/x/grammy_storages@v2.3.1/denokv/src/mod.ts";
export { autoRetry } from "https://esm.sh/@grammyjs/auto-retry@1.1.1";
export * from "https://deno.land/x/zod/mod.ts";

0
fmtArray.ts Normal file
View File

33
intl.ts Normal file
View File

@ -0,0 +1,33 @@
import { FormattedString } from "./deps.ts";
export function formatOrdinal(n: number) {
if (n % 100 === 11 || n % 100 === 12 || n % 100 === 13) return `${n}th`;
if (n % 10 === 1) return `${n}st`;
if (n % 10 === 2) return `${n}nd`;
if (n % 10 === 3) return `${n}rd`;
return `${n}th`;
}
/**
* Like `fmt` from `grammy_parse_mode` but accepts an array instead of template string.
* @see https://deno.land/x/grammy_parse_mode@1.7.1/format.ts?source=#L182
*/
export function fmtArray(
stringLikes: FormattedString[],
separator = "",
): FormattedString {
let text = "";
const entities: ConstructorParameters<typeof FormattedString>[1] = [];
for (let i = 0; i < stringLikes.length; i++) {
const stringLike = stringLikes[i];
entities.push(
...stringLike.entities.map((e) => ({
...e,
offset: e.offset + text.length,
})),
);
text += stringLike.toString();
if (i < stringLikes.length - 1) text += separator;
}
return new FormattedString(text, entities);
}

391
main.ts
View File

@ -1,387 +1,8 @@
import {
Bot,
Context,
InputFile,
InputMediaBuilder,
} from "https://deno.land/x/grammy@v1.18.1/mod.ts";
import { autoQuote } from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts";
import {
fmt,
hydrateReply,
ParseModeFlavor,
} from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts";
import "https://deno.land/x/dotenv@v3.2.2/load.ts";
import {
FormattedString,
bold,
} from "https://deno.land/x/grammy_parse_mode@1.7.1/format.ts";
import { autoRetry } from "https://esm.sh/@grammyjs/auto-retry";
import { MessageEntity } from "https://deno.land/x/grammy@v1.18.1/types.ts";
import "https://deno.land/std@0.201.0/dotenv/load.ts";
import { bot } from "./bot.ts";
import { processQueue } from "./queue.ts";
const maxUserJobs = 3;
const maxJobs = 10;
let isRunning = true;
const sdApiUrl = Deno.env.get("SD_API_URL");
if (!sdApiUrl) throw new Error("SD_API_URL not set");
console.log("Using SD API URL:", sdApiUrl);
const sdConfigUrl = new URL("/config", sdApiUrl);
const sdConfigRequest = await fetch(sdConfigUrl);
if (!sdConfigRequest.ok)
throw new Error(
`Failed to fetch SD config from ${sdConfigUrl}: ${sdConfigRequest.statusText}`
);
const sdConfig = await sdConfigRequest.json();
console.log("Using SD WebUI version:", String(sdConfig.version).trim());
const adminUsernames = (Deno.env.get("ADMIN_USERNAMES") ?? "")
.split(",")
.filter(Boolean);
const tgBotToken = Deno.env.get("TG_BOT_TOKEN");
if (!tgBotToken) throw new Error("TG_BOT_TOKEN not set");
const bot = new Bot<ParseModeFlavor<Context>>(tgBotToken);
bot.api.config.use(autoRetry({ maxRetryAttempts: 5, maxDelaySeconds: 30 }));
bot.api.setMyShortDescription("I can generate furry images from text");
bot.api.setMyDescription(
"I can generate furry images from text. Send /txt2img to generate an image."
);
bot.api.setMyCommands([
{ command: "txt2img", description: "Generate an image" },
{ command: "queue", description: "Show the current queue" },
await Promise.all([
bot.start(),
processQueue(),
]);
bot.use(autoQuote);
bot.use(hydrateReply);
bot.command("start", (ctx) =>
ctx.reply("Hello! Use the /txt2img command to generate an image")
);
bot.command("txt2img", async (ctx) => {
if (!ctx.from?.id) {
return ctx.reply("I don't know who you are");
}
if (!isRunning) {
return ctx.reply("I'm currently paused. Try again later.");
}
if (queue.length >= maxJobs) {
return ctx.reply(
`The queue is full. Try again later. (Max queue size: ${maxJobs})`
);
}
const jobCount = queue.filter((job) => job.userId === ctx.from.id).length;
if (jobCount >= maxUserJobs) {
return ctx.reply(
`You already have ${maxUserJobs} jobs in queue. Try again later.`
);
}
if (!ctx.match) {
return ctx.reply("Please describe what you want to see");
}
const place = queue.length + 1;
const queueMessage = await ctx.reply(
`You are ${formatOrdinal(place)} in queue.`
);
const userName = [ctx.from.first_name, ctx.from.last_name]
.filter(Boolean)
.join(" ");
const chatName =
ctx.chat.type === "supergroup" || ctx.chat.type === "group"
? ctx.chat.title
: "private chat";
queue.push({
params: { prompt: ctx.match },
userId: ctx.from.id,
userName,
chatId: ctx.chat.id,
chatName,
requestMessageId: ctx.message.message_id,
statusMessageId: queueMessage.message_id,
});
console.log(`Enqueued job for ${userName} in chat ${chatName}`);
});
bot.command("queue", async (ctx) => {
if (queue.length === 0) return ctx.reply("Queue is empty");
return await ctx.replyFmt(
fmt`Current queue:\n\n${fmtArray(
queue.map(
(job, index) =>
fmt`${bold(index + 1)}. ${bold(job.userName)} in ${bold(
job.chatName
)}`
),
"\n"
)}`
);
});
bot.command("pause", async (ctx) => {
if (!ctx.from?.username) return;
if (!adminUsernames.includes(ctx.from.username)) return;
if (!isRunning) return await ctx.reply("Already paused");
isRunning = false;
return await ctx.reply("Paused");
});
bot.command("resume", async (ctx) => {
if (!ctx.from?.username) return;
if (!adminUsernames.includes(ctx.from.username)) return;
if (isRunning) return await ctx.reply("Already running");
isRunning = true;
return await ctx.reply("Resumed");
});
bot.catch((err) => {
let msg = "Error processing update";
const { from, chat } = err.ctx;
if (from?.first_name) msg += ` from ${from.first_name}`;
if (from?.last_name) msg += ` ${from.last_name}`;
if (from?.username) msg += ` (@${from.username})`;
if (chat?.type === "supergroup" || chat?.type === "group") {
msg += ` in ${chat.title}`;
if (chat.type === "supergroup" && chat.username)
msg += ` (@${chat.username})`;
}
console.error(msg, err.error);
});
const queue: Job[] = [];
interface Job {
params: Partial<SdRequest>;
userId: number;
userName: string;
chatId: number;
chatName: string;
requestMessageId: number;
statusMessageId: number;
}
async function processQueue() {
while (true) {
const job = queue.shift();
if (!job) {
await new Promise((resolve) => setTimeout(resolve, 1000));
continue;
}
for (const [index, job] of queue.entries()) {
const place = index + 1;
await bot.api
.editMessageText(
job.chatId,
job.statusMessageId,
`You are ${formatOrdinal(place)} in queue.`
)
.catch(() => {});
}
try {
await bot.api.deleteMessage(job.chatId, job.statusMessageId);
const progressMessage = await bot.api.sendMessage(
job.chatId,
"Generating your prompt now...",
{ reply_to_message_id: job.requestMessageId }
);
const onProgress = (progress: SdProgressResponse) => {
bot.api
.editMessageText(
job.chatId,
progressMessage.message_id,
`Generating your prompt now... ${Math.round(
progress.progress * 100
)}%`
)
.catch(() => {});
};
const response = await txt2img(
{ ...defaultParams, ...job.params },
onProgress
);
console.log(
`Generated image for ${job.userName} in ${job.chatName}: ${job.params.prompt}`
);
bot.api.editMessageText(
job.chatId,
progressMessage.message_id,
`Uploading your images...`
);
const inputFiles = await Promise.all(
response.images.slice(1).map(async (imageBase64) => {
const imageBlob = await fetch(
`data:${mimeTypeFromBase64(imageBase64)};base64,${imageBase64}`
).then((resp) => resp.blob());
return InputMediaBuilder.photo(new InputFile(imageBlob));
})
);
await bot.api.sendMediaGroup(job.chatId, inputFiles, {
reply_to_message_id: job.requestMessageId,
});
await bot.api.deleteMessage(job.chatId, progressMessage.message_id);
console.log(`${queue.length} jobs remaining`);
} catch (err) {
console.error(
`Failed to generate image for ${job.userName} in ${job.chatName}: ${job.params.prompt} - ${err}`
);
await bot.api
.sendMessage(job.chatId, err.toString(), {
reply_to_message_id: job.requestMessageId,
})
.catch(() => {});
}
}
}
function formatOrdinal(n: number) {
if (n % 100 === 11 || n % 100 === 12 || n % 100 === 13) return `${n}th`;
if (n % 10 === 1) return `${n}st`;
if (n % 10 === 2) return `${n}nd`;
if (n % 10 === 3) return `${n}rd`;
return `${n}th`;
}
const defaultParams: Partial<SdRequest> = {
batch_size: 3,
n_iter: 1,
width: 128 * 5,
height: 128 * 7,
steps: 40,
cfg_scale: 9,
send_images: true,
save_images: true,
negative_prompt:
"id210 boring_e621_fluffyrock_v4 boring_e621_v4 easynegative ng_deepnegative_v1_75t",
};
function mimeTypeFromBase64(base64: string) {
if (base64.startsWith("/9j/")) {
return "image/jpeg";
}
if (base64.startsWith("iVBORw0KGgo")) {
return "image/png";
}
if (base64.startsWith("R0lGODlh")) {
return "image/gif";
}
if (base64.startsWith("UklGRg")) {
return "image/webp";
}
throw new Error("Unknown image type");
}
async function txt2img(
params: Partial<SdRequest>,
onProgress?: (progress: SdProgressResponse) => void,
signal?: AbortSignal
): Promise<SdResponse> {
let response: Response | undefined;
let error: unknown;
fetch(new URL("sdapi/v1/txt2img", sdApiUrl), {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(params),
}).then(
(resp) => (response = resp),
(err) => (error = err)
);
try {
while (true) {
await new Promise((resolve) => setTimeout(resolve, 3000));
const progressRequest = await fetch(
new URL("sdapi/v1/progress", sdApiUrl)
);
if (progressRequest.ok) {
const progress = (await progressRequest.json()) as SdProgressResponse;
onProgress?.(progress);
}
if (response != null) {
if (response.ok) {
const result = (await response.json()) as SdResponse;
return result;
} else {
throw new Error(
`Request failed: ${response.status} ${response.statusText}`
);
}
}
if (error != null) {
throw error;
}
signal?.throwIfAborted();
}
} finally {
if (!response && !error)
await fetch(new URL("sdapi/v1/interrupt", sdApiUrl), { method: "POST" });
}
}
interface SdRequest {
denoising_strength: number;
prompt: string;
seed: number;
sampler_name: unknown;
batch_size: number;
n_iter: number;
steps: number;
cfg_scale: number;
width: number;
height: number;
negative_prompt: string;
send_images: boolean;
save_images: boolean;
}
interface SdResponse {
images: string[];
parameters: SdRequest;
/** Contains serialized JSON */
info: string;
}
interface SdProgressResponse {
progress: number;
eta_relative: number;
state: SdProgressState;
/** base64 encoded preview */
current_image: string | null;
textinfo: string | null;
}
interface SdProgressState {
skipped: boolean;
interrupted: boolean;
job: string;
job_count: number;
job_timestamp: string;
job_no: number;
sampling_step: number;
sampling_steps: number;
}
/** Like {@link fmt} but accepts an array instead of template string. */
function fmtArray(
stringLikes: FormattedString[],
separator = ""
): FormattedString {
let text = "";
const entities: MessageEntity[] = [];
for (let i = 0; i < stringLikes.length; i++) {
const stringLike = stringLikes[i];
entities.push(
...stringLike.entities.map((e) => ({
...e,
offset: e.offset + text.length,
}))
);
text += stringLike.toString();
if (i < stringLikes.length - 1) text += separator;
}
return new FormattedString(text, entities);
}
await Promise.all([bot.start(), processQueue()]);

15
mimeType.ts Normal file
View File

@ -0,0 +1,15 @@
export function mimeTypeFromBase64(base64: string) {
if (base64.startsWith("/9j/")) return "image/jpeg";
if (base64.startsWith("iVBORw0KGgo")) return "image/png";
if (base64.startsWith("R0lGODlh")) return "image/gif";
if (base64.startsWith("UklGRg")) return "image/webp";
throw new Error("Unknown image type");
}
export function extFromMimeType(mimeType: string) {
if (mimeType === "image/jpeg") return "jpg";
if (mimeType === "image/png") return "png";
if (mimeType === "image/gif") return "gif";
if (mimeType === "image/webp") return "webp";
throw new Error("Unknown image type");
}

111
queue.ts Normal file
View File

@ -0,0 +1,111 @@
import { InputFile, InputMediaBuilder } from "./deps.ts";
import { bot, getGlobalSession } from "./bot.ts";
import { formatOrdinal } from "./intl.ts";
import { SdProgressResponse, SdRequest, txt2img } from "./sd.ts";
import { extFromMimeType, mimeTypeFromBase64 } from "./mimeType.ts";
export const queue: Job[] = [];
interface Job {
params: Partial<SdRequest>;
userId: number;
userName: string;
chatId: number;
chatName: string;
requestMessageId: number;
statusMessageId: number;
}
export async function processQueue() {
while (true) {
const job = queue.shift();
if (!job) {
await new Promise((resolve) => setTimeout(resolve, 1000));
continue;
}
for (const [index, job] of queue.entries()) {
const place = index + 1;
await bot.api
.editMessageText(
job.chatId,
job.statusMessageId,
`You are ${formatOrdinal(place)} in queue.`,
)
.catch(() => {});
}
try {
await bot.api
.deleteMessage(job.chatId, job.statusMessageId)
.catch(() => {});
const progressMessage = await bot.api.sendMessage(
job.chatId,
"Generating your prompt now...",
{ reply_to_message_id: job.requestMessageId },
);
const onProgress = (progress: SdProgressResponse) => {
bot.api
.editMessageText(
job.chatId,
progressMessage.message_id,
`Generating your prompt now... ${
Math.round(
progress.progress * 100,
)
}%`,
)
.catch(() => {});
};
const config = await getGlobalSession();
const response = await txt2img(
config.sdApiUrl,
{ ...config.defaultParams, ...job.params },
onProgress,
);
console.log(
`Generated ${response.images.length} images (${
response.images
.map((image) => (image.length / 1024).toFixed(0) + "kB")
.join(", ")
}) for ${job.userName} in ${job.chatName}: ${job.params.prompt?.replace(/\s+/g, " ")}`,
);
await bot.api.editMessageText(
job.chatId,
progressMessage.message_id,
`Uploading your images...`,
).catch(() => {});
const inputFiles = await Promise.all(
response.images.map(async (imageBase64, idx) => {
const mimeType = mimeTypeFromBase64(imageBase64);
const imageBlob = await fetch(`data:${mimeType};base64,${imageBase64}`).then((resp) =>
resp.blob()
);
console.log(
`Uploading image ${idx + 1} of ${response.images.length} (${
(imageBlob.size / 1024).toFixed(0)
}kB)`,
);
return InputMediaBuilder.photo(
new InputFile(imageBlob, `${idx}.${extFromMimeType(mimeType)}`),
);
}),
);
await bot.api.sendMediaGroup(job.chatId, inputFiles, {
reply_to_message_id: job.requestMessageId,
});
await bot.api
.deleteMessage(job.chatId, progressMessage.message_id)
.catch(() => {});
console.log(`${queue.length} jobs remaining`);
} catch (err) {
console.error(
`Failed to generate image for ${job.userName} in ${job.chatName}: ${job.params.prompt} - ${err}`,
);
await bot.api
.sendMessage(job.chatId, err.toString(), {
reply_to_message_id: job.requestMessageId,
})
.catch(() => bot.api.sendMessage(job.chatId, err.toString()))
.catch(() => {});
}
}
}

88
sd.ts Normal file
View File

@ -0,0 +1,88 @@
export async function txt2img(
apiUrl: string,
params: Partial<SdRequest>,
onProgress?: (progress: SdProgressResponse) => void,
signal?: AbortSignal,
): Promise<SdResponse> {
let response: Response | undefined;
let error: unknown;
fetch(new URL("sdapi/v1/txt2img", apiUrl), {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(params),
}).then(
(resp) => (response = resp),
(err) => (error = err),
);
try {
while (true) {
await new Promise((resolve) => setTimeout(resolve, 3000));
const progressRequest = await fetch(new URL("sdapi/v1/progress", apiUrl));
if (progressRequest.ok) {
const progress = (await progressRequest.json()) as SdProgressResponse;
onProgress?.(progress);
}
if (response != null) {
if (response.ok) {
const result = (await response.json()) as SdResponse;
return result;
} else {
throw new Error(`Request failed: ${response.status} ${response.statusText}`);
}
}
if (error != null) {
throw error;
}
signal?.throwIfAborted();
}
} finally {
if (!response && !error) {
await fetch(new URL("sdapi/v1/interrupt", apiUrl), { method: "POST" });
}
}
}
export interface SdRequest {
denoising_strength: number;
prompt: string;
seed: number;
sampler_name: unknown;
batch_size: number;
n_iter: number;
steps: number;
cfg_scale: number;
width: number;
height: number;
negative_prompt: string;
send_images: boolean;
save_images: boolean;
}
export interface SdResponse {
images: string[];
parameters: SdRequest;
/** Contains serialized JSON */
info: string;
}
export interface SdProgressResponse {
progress: number;
eta_relative: number;
state: SdProgressState;
/** base64 encoded preview */
current_image: string | null;
textinfo: string | null;
}
export interface SdProgressState {
skipped: boolean;
interrupted: boolean;
job: string;
job_count: number;
job_timestamp: string;
job_no: number;
sampling_step: number;
sampling_steps: number;
}