Compare commits

..

7 Commits

Author SHA1 Message Date
pinks 3984ca337c feat: conversation with bot 2023-09-11 01:59:33 +02:00
pinks af9d31f75c change hanged jobs to 2 minutes 2023-09-11 01:58:51 +02:00
pinks 721f3d0b0f mark offline when 401 unauthorized 2023-09-11 01:58:20 +02:00
pinks 13bd410bdc remove from queue when user blocked bot 2023-09-11 01:57:42 +02:00
pinks 97f261d01c configs 2023-09-10 23:16:27 +02:00
pinks 6c66a00910 big rewrite 2023-09-10 20:56:17 +02:00
pinks ba2afe40ce implement indexes in store to keep whole history 2023-09-07 22:43:40 +02:00
27 changed files with 1122 additions and 707 deletions

5
.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
.vscode
.env
app.db*
deno.lock
updateConfig.ts

View File

@ -11,12 +11,44 @@ Telegram bot for generating images from text.
You can put these in `.env` file or pass them as environment variables.
- `TG_BOT_TOKEN` - Telegram bot token (get yours from [@BotFather](https://t.me/BotFather))
- `SD_API_URL` - URL to Stable Diffusion API (e.g. `http://127.0.0.1:7860/`)
- `ADMIN_USERNAMES` - Comma separated list of usernames of users that can use admin commands
(optional)
- `TG_BOT_TOKEN` - Telegram bot token. Get yours from [@BotFather](https://t.me/BotFather).
Required.
- `SD_API_URL` - URL to Stable Diffusion API. Only used on first run. Default:
`http://127.0.0.1:7860/`
- `TG_ADMIN_USERS` - Comma separated list of usernames of users that can use admin commands. Only
used on first run. Optional.
## Running
- Start stable diffusion webui: `cd sd-webui`, `./webui.sh --api`
- Start bot: `deno task start`
## TODO
- [x] Keep generation history
- [x] Changing params, parsing png info in request
- [x] Cancelling jobs by deleting message
- [x] Multiple parallel workers
- [x] Replying to another text message to copy prompt and generate
- [x] Replying to bot message, conversation in DMs
- [ ] Replying to png message to extract png info nad generate
- [ ] Banning tags
- [ ] Img2Img + Upscale
- [ ] Admin WebUI
- [ ] User daily generation limits
- [ ] Querying all generation history, displaying stats
- [ ] Analyzing prompt quality based on tag csv
- [ ] Report aliased/unknown tags based on csv
- [ ] Report unknown loras
- [ ] Investigate "sendMediaGroup failed"
- [ ] Changing sampler without error on unknown sampler
- [ ] Changing model
- [ ] Inpaint using telegram photo edit
- [ ] Outpaint
- [ ] Non-SD (extras) upscale
- [ ] Tiled generation to allow very big images
- [ ] Downloading raw images
- [ ] Extra prompt syntax, fixing `()+++` syntax
- [ ] Translations
- replace fmtDuration usage
- replace formatOrdinal usage

272
bot.ts
View File

@ -1,272 +0,0 @@
import { autoQuote, bold, Bot, Context, hydrateReply, ParseModeFlavor } from "./deps.ts";
import { fmt } from "./intl.ts";
import { getAllJobs, pushJob } from "./queue.ts";
import { mySession, MySessionFlavor } from "./session.ts";
export type MyContext = ParseModeFlavor<Context> & MySessionFlavor;
export const bot = new Bot<MyContext>(Deno.env.get("TG_BOT_TOKEN") ?? "");
bot.use(autoQuote);
bot.use(hydrateReply);
bot.use(mySession);
// Automatically retry bot requests if we get a 429 error
bot.api.config.use(async (prev, method, payload, signal) => {
let remainingAttempts = 5;
while (true) {
const result = await prev(method, payload, signal);
if (result.ok) return result;
if (result.error_code !== 429 || remainingAttempts <= 0) return result;
remainingAttempts -= 1;
const retryAfterMs = (result.parameters?.retry_after ?? 30) * 1000;
await new Promise((resolve) => setTimeout(resolve, retryAfterMs));
}
});
// if error happened, try to reply to the user with the error
bot.use(async (ctx, next) => {
try {
await next();
} catch (err) {
try {
await ctx.reply(`Handling update failed: ${err}`, {
reply_to_message_id: ctx.message?.message_id,
});
} catch {
throw err;
}
}
});
bot.api.setMyShortDescription("I can generate furry images from text");
bot.api.setMyDescription(
"I can generate furry images from text. Send /txt2img to generate an image.",
);
bot.api.setMyCommands([
{ command: "txt2img", description: "Generate an image" },
{ command: "queue", description: "Show the current queue" },
{ command: "sdparams", description: "Show the current SD parameters" },
]);
bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to generate an image"));
bot.command("txt2img", async (ctx) => {
if (!ctx.from?.id) {
return ctx.reply("I don't know who you are");
}
const config = ctx.session.global;
if (config.pausedReason != null) {
return ctx.reply(`I'm paused: ${config.pausedReason || "No reason given"}`);
}
const jobs = await getAllJobs();
if (jobs.length >= config.maxJobs) {
return ctx.reply(
`The queue is full. Try again later. (Max queue size: ${config.maxJobs})`,
);
}
const jobCount = jobs.filter((job) => job.user.id === ctx.from.id).length;
if (jobCount >= config.maxUserJobs) {
return ctx.reply(
`You already have ${config.maxUserJobs} jobs in queue. Try again later.`,
);
}
if (!ctx.match) {
return ctx.reply("Please describe what you want to see after the command");
}
pushJob({
params: { prompt: ctx.match },
user: ctx.from,
chat: ctx.chat,
requestMessage: ctx.message,
status: { type: "idle" },
});
console.log(
`Enqueued job ${jobs.length + 1} for ${ctx.from.first_name} in ${ctx.chat.type} chat:`,
ctx.match.replace(/\s+/g, " "),
"\n",
);
});
bot.command("queue", async (ctx) => {
let jobs = await getAllJobs();
const getMessageText = () => {
if (jobs.length === 0) return fmt`Queue is empty.`;
const sortedJobs = [];
let place = 0;
for (const job of jobs) {
if (job.status.type === "idle") place += 1;
sortedJobs.push({ ...job, place });
}
return fmt`Current queue:\n\n${
sortedJobs.map((job) =>
fmt`${job.place}. ${bold(job.user.first_name)} in ${job.chat.type} chat ${
job.status.type === "processing" ? `(${(job.status.progress * 100).toFixed(0)}%)` : ""
}\n`
)
}`;
};
const message = await ctx.replyFmt(getMessageText());
handleFutureUpdates();
async function handleFutureUpdates() {
for (let idx = 0; idx < 12; idx++) {
await new Promise((resolve) => setTimeout(resolve, 5000));
jobs = await getAllJobs();
const formattedMessage = getMessageText();
await ctx.api.editMessageText(ctx.chat.id, message.message_id, formattedMessage.text, {
entities: formattedMessage.entities,
}).catch(() => undefined);
}
}
});
bot.command("pause", (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
if (!config.adminUsernames.includes(ctx.from.username)) return;
if (config.pausedReason != null) {
return ctx.reply(`Already paused: ${config.pausedReason}`);
}
config.pausedReason = ctx.match ?? "No reason given";
return ctx.reply("Paused");
});
bot.command("resume", (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
if (!config.adminUsernames.includes(ctx.from.username)) return;
if (config.pausedReason == null) return ctx.reply("Already running");
config.pausedReason = null;
return ctx.reply("Resumed");
});
bot.command("setsdapiurl", async (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
if (!config.adminUsernames.includes(ctx.from.username)) return;
if (!ctx.match) return ctx.reply("Please specify an URL");
let url: URL;
try {
url = new URL(ctx.match);
} catch {
return ctx.reply("Invalid URL");
}
let resp: Response;
try {
resp = await fetch(new URL("config", url));
} catch (err) {
return ctx.reply(`Could not connect: ${err}`);
}
if (!resp.ok) {
return ctx.reply(`Could not connect: ${resp.status} ${resp.statusText}`);
}
let data: unknown;
try {
data = await resp.json();
} catch {
return ctx.reply("Invalid response from API");
}
if (data != null && typeof data === "object" && "version" in data) {
config.sdApiUrl = url.toString();
return ctx.reply(`Now using SD at ${url} running version ${data.version}`);
} else {
return ctx.reply("Invalid response from API");
}
});
bot.command("setsdparam", (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
if (!config.adminUsernames.includes(ctx.from.username)) return;
let [param = "", value] = ctx.match.split("=", 2).map((s) => s.trim());
if (!param) return ctx.reply("Please specify a parameter");
if (value == null) return ctx.reply("Please specify a value after the =");
param = param.toLowerCase().replace(/[\s_]+/g, "");
if (config.defaultParams == null) config.defaultParams = {};
switch (param) {
case "steps": {
const steps = parseInt(value);
if (isNaN(steps)) return ctx.reply("Invalid number value");
if (steps > 100) return ctx.reply("Steps must be less than 100");
if (steps < 10) return ctx.reply("Steps must be greater than 10");
config.defaultParams.steps = steps;
return ctx.reply("Steps set to " + steps);
}
case "detail":
case "cfgscale": {
const detail = parseInt(value);
if (isNaN(detail)) return ctx.reply("Invalid number value");
if (detail > 20) return ctx.reply("Detail must be less than 20");
if (detail < 1) return ctx.reply("Detail must be greater than 1");
config.defaultParams.cfg_scale = detail;
return ctx.reply("Detail set to " + detail);
}
case "niter":
case "niters": {
const nIter = parseInt(value);
if (isNaN(nIter)) return ctx.reply("Invalid number value");
if (nIter > 10) return ctx.reply("Iterations must be less than 10");
if (nIter < 1) return ctx.reply("Iterations must be greater than 1");
config.defaultParams.n_iter = nIter;
return ctx.reply("Iterations set to " + nIter);
}
case "batchsize": {
const batchSize = parseInt(value);
if (isNaN(batchSize)) return ctx.reply("Invalid number value");
if (batchSize > 8) return ctx.reply("Batch size must be less than 8");
if (batchSize < 1) return ctx.reply("Batch size must be greater than 1");
config.defaultParams.batch_size = batchSize;
return ctx.reply("Batch size set to " + batchSize);
}
case "size": {
let [width, height] = value.split("x", 2).map((s) => parseInt(s.trim()));
if (!width || !height || isNaN(width) || isNaN(height)) {
return ctx.reply("Invalid size value");
}
if (width > 2048) return ctx.reply("Width must be less than 2048");
if (height > 2048) return ctx.reply("Height must be less than 2048");
// find closest multiple of 64
width = Math.round(width / 64) * 64;
height = Math.round(height / 64) * 64;
if (width <= 0) return ctx.reply("Width too small");
if (height <= 0) return ctx.reply("Height too small");
config.defaultParams.width = width;
config.defaultParams.height = height;
return ctx.reply(`Size set to ${width}x${height}`);
}
case "negativeprompt": {
config.defaultParams.negative_prompt = value;
return ctx.reply(`Negative prompt set to: ${value}`);
}
default: {
return ctx.reply("Invalid parameter");
}
}
});
bot.command("sdparams", (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
return ctx.replyFmt(
fmt`Current config:\n\n${
Object.entries(config.defaultParams ?? {}).map(([key, value]) =>
fmt`${bold(key)} = ${String(value)}\n`
)
}`,
);
});
bot.command("crash", () => {
throw new Error("Crash command used");
});
bot.catch((err) => {
let msg = "Error processing update";
const { from, chat } = err.ctx;
if (from?.first_name) msg += ` from ${from.first_name}`;
if (from?.last_name) msg += ` ${from.last_name}`;
if (from?.username) msg += ` (@${from.username})`;
if (chat?.type === "supergroup" || chat?.type === "group") {
msg += ` in ${chat.title}`;
if (chat.type === "supergroup" && chat.username) msg += ` (@${chat.username})`;
}
console.error(msg, err.error);
});

93
bot/mod.ts Normal file
View File

@ -0,0 +1,93 @@
import { Grammy, GrammyAutoQuote, GrammyParseMode, Log } from "../deps.ts";
import { formatUserChat } from "../utils.ts";
import { session, SessionFlavor } from "./session.ts";
import { queueCommand } from "./queueCommand.ts";
import { txt2imgCommand, txt2imgQuestion } from "./txt2imgCommand.ts";
export const logger = () => Log.getLogger();
export type Context = GrammyParseMode.ParseModeFlavor<Grammy.Context> & SessionFlavor;
export const bot = new Grammy.Bot<Context>(Deno.env.get("TG_BOT_TOKEN") ?? "");
bot.use(GrammyAutoQuote.autoQuote);
bot.use(GrammyParseMode.hydrateReply);
bot.use(session);
bot.catch((err) => {
logger().error(`Handling update from ${formatUserChat(err.ctx)} failed: ${err}`);
});
// Automatically retry bot requests if we get a "too many requests" or telegram internal error
bot.api.config.use(async (prev, method, payload, signal) => {
let attempt = 0;
while (true) {
attempt++;
const result = await prev(method, payload, signal);
if (
result.ok ||
![429, 500].includes(result.error_code) ||
attempt >= 5
) {
return result;
}
const retryAfterMs = (result.parameters?.retry_after ?? (attempt * 5)) * 1000;
await new Promise((resolve) => setTimeout(resolve, retryAfterMs));
}
});
// if error happened, try to reply to the user with the error
bot.use(async (ctx, next) => {
try {
await next();
} catch (err) {
try {
await ctx.reply(`Handling update failed: ${err}`, {
reply_to_message_id: ctx.message?.message_id,
});
} catch {
throw err;
}
}
});
bot.api.setMyShortDescription("I can generate furry images from text");
bot.api.setMyDescription(
"I can generate furry images from text. " +
"Send /txt2img to generate an image.",
);
bot.api.setMyCommands([
{ command: "txt2img", description: "Generate an image" },
{ command: "queue", description: "Show the current queue" },
]);
bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to generate an image"));
bot.command("txt2img", txt2imgCommand);
bot.use(txt2imgQuestion.middleware() as any);
bot.command("queue", queueCommand);
bot.command("pause", (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
if (!config.adminUsernames.includes(ctx.from.username)) return;
if (config.pausedReason != null) {
return ctx.reply(`Already paused: ${config.pausedReason}`);
}
config.pausedReason = ctx.match ?? "No reason given";
logger().warning(`Bot paused by ${ctx.from.first_name} because ${config.pausedReason}`);
return ctx.reply("Paused");
});
bot.command("resume", (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
if (!config.adminUsernames.includes(ctx.from.username)) return;
if (config.pausedReason == null) return ctx.reply("Already running");
config.pausedReason = null;
logger().info(`Bot resumed by ${ctx.from.first_name}`);
return ctx.reply("Resumed");
});
bot.command("crash", () => {
throw new Error("Crash command used");
});

67
bot/queueCommand.ts Normal file
View File

@ -0,0 +1,67 @@
import { Grammy, GrammyParseMode } from "../deps.ts";
import { fmt, getFlagEmoji } from "../utils.ts";
import { runningWorkers } from "../tasks/pingWorkers.ts";
import { jobStore } from "../db/jobStore.ts";
import { Context, logger } from "./mod.ts";
export async function queueCommand(ctx: Grammy.CommandContext<Context>) {
let formattedMessage = await getMessageText();
const queueMessage = await ctx.replyFmt(formattedMessage);
handleFutureUpdates().catch((err) => logger().warning(`Updating queue message failed: ${err}`));
async function getMessageText() {
const processingJobs = await jobStore.getBy("status.type", "processing")
.then((jobs) => jobs.map((job) => ({ ...job.value, place: 0 })));
const waitingJobs = await jobStore.getBy("status.type", "waiting")
.then((jobs) => jobs.map((job, index) => ({ ...job.value, place: index + 1 })));
const jobs = [...processingJobs, ...waitingJobs];
const config = ctx.session.global;
const { bold } = GrammyParseMode;
return fmt([
"Current queue:\n",
...jobs.length > 0
? jobs.flatMap((job) => [
`${job.place}. `,
fmt`${bold(job.request.from.first_name)} `,
job.request.from.last_name ? fmt`${bold(job.request.from.last_name)} ` : "",
job.request.from.username ? `(@${job.request.from.username}) ` : "",
getFlagEmoji(job.request.from.language_code) ?? "",
job.request.chat.type === "private"
? " in private chat "
: ` in ${job.request.chat.title} `,
job.request.chat.type !== "private" && job.request.chat.type !== "group" &&
job.request.chat.username
? `(@${job.request.chat.username}) `
: "",
job.status.type === "processing"
? `(${(job.status.progress * 100).toFixed(0)}% using ${job.status.worker}) `
: "",
"\n",
])
: ["Queue is empty.\n"],
"\nActive workers:\n",
...config.workers.flatMap((worker) => [
runningWorkers.has(worker.name) ? "✅ " : "☠️ ",
fmt`${bold(worker.name)} `,
`(max ${(worker.maxResolution / 1000000).toFixed(1)} Mpx) `,
"\n",
]),
]);
}
async function handleFutureUpdates() {
for (let idx = 0; idx < 20; idx++) {
await new Promise((resolve) => setTimeout(resolve, 3000));
const nextFormattedMessage = await getMessageText();
if (nextFormattedMessage.text !== formattedMessage.text) {
await ctx.api.editMessageText(
ctx.chat.id,
queueMessage.message_id,
nextFormattedMessage.text,
{ entities: nextFormattedMessage.entities },
);
formattedMessage = nextFormattedMessage;
}
}
}
}

81
bot/session.ts Normal file
View File

@ -0,0 +1,81 @@
import { db } from "../db/db.ts";
import { Grammy, GrammyKvStorage } from "../deps.ts";
import { SdApi, SdTxt2ImgRequest } from "../sd.ts";
export type SessionFlavor = Grammy.SessionFlavor<SessionData>;
export interface SessionData {
global: GlobalData;
chat: ChatData;
user: UserData;
}
export interface GlobalData {
adminUsernames: string[];
pausedReason: string | null;
maxUserJobs: number;
maxJobs: number;
defaultParams?: Partial<SdTxt2ImgRequest>;
workers: WorkerData[];
}
export interface WorkerData {
name: string;
api: SdApi;
auth?: string;
maxResolution: number;
}
export interface ChatData {
language?: string;
}
export interface UserData {
params?: Partial<SdTxt2ImgRequest>;
}
const globalDbAdapter = new GrammyKvStorage.DenoKVAdapter<GlobalData>(db);
const getDefaultGlobalData = (): GlobalData => ({
adminUsernames: Deno.env.get("TG_ADMIN_USERS")?.split(",") ?? [],
pausedReason: null,
maxUserJobs: 3,
maxJobs: 20,
defaultParams: {
batch_size: 1,
n_iter: 1,
width: 512,
height: 768,
steps: 30,
cfg_scale: 10,
negative_prompt: "boring_e621_fluffyrock_v4 boring_e621_v4",
},
workers: [
{
name: "local",
api: { url: Deno.env.get("SD_API_URL") ?? "http://127.0.0.1:7860/" },
maxResolution: 1024 * 1024,
},
],
});
export const session = Grammy.session<SessionData, Grammy.Context & SessionFlavor>({
type: "multi",
global: {
getSessionKey: () => "global",
initial: getDefaultGlobalData,
storage: globalDbAdapter,
},
chat: {
initial: () => ({}),
},
user: {
getSessionKey: (ctx) => ctx.from?.id.toFixed(),
initial: () => ({}),
},
});
export async function getGlobalSession(): Promise<GlobalData> {
const data = await globalDbAdapter.read("global");
return data ?? getDefaultGlobalData();
}

72
bot/txt2imgCommand.ts Normal file
View File

@ -0,0 +1,72 @@
import { Grammy, GrammyStatelessQ } from "../deps.ts";
import { formatUserChat } from "../utils.ts";
import { jobStore } from "../db/jobStore.ts";
import { parsePngInfo } from "../sd.ts";
import { Context, logger } from "./mod.ts";
export const txt2imgQuestion = new GrammyStatelessQ.StatelessQuestion(
"txt2img",
async (ctx) => {
if (!ctx.message.text) return;
await txt2img(ctx as any, ctx.message.text, false);
},
);
export async function txt2imgCommand(ctx: Grammy.CommandContext<Context>) {
await txt2img(ctx, ctx.match, true);
}
async function txt2img(ctx: Context, match: string, includeRepliedTo: boolean): Promise<void> {
if (!ctx.message?.from?.id) {
return void ctx.reply("I don't know who you are");
}
const config = ctx.session.global;
if (config.pausedReason != null) {
return void ctx.reply(`I'm paused: ${config.pausedReason || "No reason given"}`);
}
const jobs = await jobStore.getBy("status.type", "waiting");
if (jobs.length >= config.maxJobs) {
return void ctx.reply(
`The queue is full. Try again later. (Max queue size: ${config.maxJobs})`,
);
}
const userJobs = jobs.filter((job) => job.value.request.from.id === ctx.message?.from?.id);
if (userJobs.length >= config.maxUserJobs) {
return void ctx.reply(
`You already have ${config.maxUserJobs} jobs in queue. Try again later.`,
);
}
let params = parsePngInfo(match);
const repliedToMsg = ctx.message.reply_to_message;
const repliedToText = repliedToMsg?.text || repliedToMsg?.caption;
if (includeRepliedTo && repliedToText) {
const originalParams = parsePngInfo(repliedToText);
params = {
...originalParams,
...params,
prompt: [originalParams.prompt, params.prompt].filter(Boolean).join("\n"),
};
}
if (!params.prompt) {
return void ctx.reply(
"Please tell me what you want to see." +
txt2imgQuestion.messageSuffixMarkdown(),
{ reply_markup: { force_reply: true, selective: true }, parse_mode: "Markdown" },
);
}
const replyMessage = await ctx.reply("Accepted. You are now in queue.");
await jobStore.create({
params,
request: ctx.message,
reply: replyMessage,
status: { type: "waiting" },
});
logger().debug(`Job enqueued for ${formatUserChat(ctx.message)}`);
}

1
db/db.ts Normal file
View File

@ -0,0 +1 @@
export const db = await Deno.openKv("./app.db");

18
db/jobStore.ts Normal file
View File

@ -0,0 +1,18 @@
import { GrammyTypes, IKV } from "../deps.ts";
import { SdTxt2ImgInfo, SdTxt2ImgRequest } from "../sd.ts";
import { db } from "./db.ts";
export interface JobSchema {
params: Partial<SdTxt2ImgRequest>;
request: GrammyTypes.Message & { from: GrammyTypes.User };
reply?: GrammyTypes.Message.TextMessage;
status:
| { type: "waiting" }
| { type: "processing"; progress: number; worker: string; updatedDate: Date }
| { type: "done"; info?: SdTxt2ImgInfo; startDate?: Date; endDate?: Date };
}
export const jobStore = new IKV.Store(db, "job", {
schema: new IKV.Schema<JobSchema>(),
indices: ["status.type"],
});

View File

@ -1,6 +1,5 @@
{
"tasks": {
"dev": "deno run --watch --unstable --allow-env --allow-read --allow-write --allow-net main.ts",
"start": "deno run --unstable --allow-env --allow-read --allow-write --allow-net main.ts"
},
"fmt": {

25
deps.ts
View File

@ -1,6 +1,19 @@
export * from "https://deno.land/x/grammy@v1.18.1/mod.ts";
export * from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts";
export * from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts";
export * from "https://deno.land/x/grammy_storages@v2.3.1/denokv/src/mod.ts";
export * as types from "https://deno.land/x/grammy_types@v3.2.0/mod.ts";
export * from "https://deno.land/x/ulid@v0.3.0/mod.ts";
export * as Log from "https://deno.land/std@0.201.0/log/mod.ts";
export * as Async from "https://deno.land/std@0.201.0/async/mod.ts";
export * as FmtDuration from "https://deno.land/std@0.201.0/fmt/duration.ts";
export * as Collections from "https://deno.land/std@0.201.0/collections/mod.ts";
export * as Base64 from "https://deno.land/std@0.201.0/encoding/base64.ts";
export * as AsyncX from "https://deno.land/x/async@v2.0.2/mod.ts";
export * as ULID from "https://deno.land/x/ulid@v0.3.0/mod.ts";
export * as IKV from "https://deno.land/x/indexed_kv@v0.2.0/mod.ts";
export * as Grammy from "https://deno.land/x/grammy@v1.18.1/mod.ts";
export * as GrammyTypes from "https://deno.land/x/grammy_types@v3.2.0/mod.ts";
export * as GrammyAutoQuote from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts";
export * as GrammyParseMode from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts";
export * as GrammyKvStorage from "https://deno.land/x/grammy_storages@v2.3.1/denokv/src/mod.ts";
export * as GrammyStatelessQ from "npm:@grammyjs/stateless-question";
export * as FileType from "npm:file-type@18.5.0";
// @deno-types="./types/png-chunks-extract.d.ts"
export * as PngChunksExtract from "npm:png-chunks-extract@1.0.0";
// @deno-types="./types/png-chunk-text.d.ts"
export * as PngChunkText from "npm:png-chunk-text@1.0.0";

43
intl.ts
View File

@ -1,43 +0,0 @@
import { FormattedString } from "./deps.ts";
export function formatOrdinal(n: number) {
if (n % 100 === 11 || n % 100 === 12 || n % 100 === 13) return `${n}th`;
if (n % 10 === 1) return `${n}st`;
if (n % 10 === 2) return `${n}nd`;
if (n % 10 === 3) return `${n}rd`;
return `${n}th`;
}
type DeepArray<T> = Array<T | DeepArray<T>>;
type StringLikes = DeepArray<FormattedString | string | number | null | undefined>;
/**
* Like `fmt` from `grammy_parse_mode` but additionally accepts arrays.
* @see https://deno.land/x/grammy_parse_mode@1.7.1/format.ts?source=#L182
*/
export const fmt = (
rawStringParts: TemplateStringsArray | StringLikes,
...stringLikes: StringLikes
): FormattedString => {
let text = "";
const entities: ConstructorParameters<typeof FormattedString>[1][] = [];
const length = Math.max(rawStringParts.length, stringLikes.length);
for (let i = 0; i < length; i++) {
for (let stringLike of [rawStringParts[i], stringLikes[i]]) {
if (Array.isArray(stringLike)) {
stringLike = fmt(stringLike);
}
if (stringLike instanceof FormattedString) {
entities.push(
...stringLike.entities.map((e) => ({
...e,
offset: e.offset + text.length,
})),
);
}
if (stringLike != null) text += stringLike.toString();
}
}
return new FormattedString(text, entities);
};

20
main.ts
View File

@ -1,9 +1,21 @@
// Load environment variables from .env file
import "https://deno.land/std@0.201.0/dotenv/load.ts";
import { bot } from "./bot.ts";
import { processQueue, returnHangedJobs } from "./queue.ts";
// Setup logging
import { Log } from "./deps.ts";
Log.setup({
handlers: {
console: new Log.handlers.ConsoleHandler("DEBUG"),
},
loggers: {
default: { level: "DEBUG", handlers: ["console"] },
},
});
// Main program logic
import { bot } from "./bot/mod.ts";
import { runAllTasks } from "./tasks/mod.ts";
await Promise.all([
bot.start(),
processQueue(),
returnHangedJobs(),
runAllTasks(),
]);

View File

@ -1,15 +0,0 @@
export function mimeTypeFromBase64(base64: string) {
if (base64.startsWith("/9j/")) return "image/jpeg";
if (base64.startsWith("iVBORw0KGgo")) return "image/png";
if (base64.startsWith("R0lGODlh")) return "image/gif";
if (base64.startsWith("UklGRg")) return "image/webp";
throw new Error("Unknown image type");
}
export function extFromMimeType(mimeType: string) {
if (mimeType === "image/jpeg") return "jpg";
if (mimeType === "image/png") return "png";
if (mimeType === "image/gif") return "gif";
if (mimeType === "image/webp") return "webp";
throw new Error("Unknown image type");
}

167
queue.ts
View File

@ -1,167 +0,0 @@
import { InputFile, InputMediaBuilder, types } from "./deps.ts";
import { bot } from "./bot.ts";
import { getGlobalSession } from "./session.ts";
import { formatOrdinal } from "./intl.ts";
import { SdRequest, txt2img } from "./sd.ts";
import { extFromMimeType, mimeTypeFromBase64 } from "./mimeType.ts";
import { Model, Store } from "./store.ts";
interface Job {
params: Partial<SdRequest>;
user: types.User;
chat: types.Chat.PrivateChat | types.Chat.GroupChat | types.Chat.SupergroupChat;
requestMessage: types.Message & types.Message.TextMessage;
statusMessage?: types.Message & types.Message.TextMessage;
status: { type: "idle" } | { type: "processing"; progress: number; updatedDate: Date };
}
const db = await Deno.openKv("./app.db");
const jobStore = new Store<Job>(db, "job");
export async function pushJob(job: Job) {
await jobStore.create(job);
}
async function takeJob(): Promise<Model<Job> | null> {
const jobs = await jobStore.list();
const job = jobs.find((job) => job.value.status.type === "idle");
if (!job) return null;
await job.update({ status: { type: "processing", progress: 0, updatedDate: new Date() } });
return job;
}
export async function getAllJobs(): Promise<Array<Job>> {
return await jobStore.list().then((jobs) => jobs.map((job) => job.value));
}
export async function processQueue() {
while (true) {
const job = await takeJob();
if (!job) {
await new Promise((resolve) => setTimeout(resolve, 1000));
continue;
}
let place = 0;
for (const job of await jobStore.list()) {
if (job.value.status.type === "idle") place += 1;
if (place === 0) continue;
const statusMessageText = `You are ${formatOrdinal(place)} in queue.`;
if (!job.value.statusMessage) {
await bot.api.sendMessage(job.value.chat.id, statusMessageText, {
reply_to_message_id: job.value.requestMessage.message_id,
}).catch(() => undefined)
.then((message) => job.update({ statusMessage: message }));
} else {
await bot.api.editMessageText(
job.value.chat.id,
job.value.statusMessage.message_id,
statusMessageText,
)
.catch(() => undefined);
}
}
try {
if (job.value.statusMessage) {
await bot.api
.deleteMessage(job.value.chat.id, job.value.statusMessage?.message_id)
.catch(() => undefined)
.then(() => job.update({ statusMessage: undefined }));
}
await bot.api.sendMessage(
job.value.chat.id,
"Generating your prompt now...",
{ reply_to_message_id: job.value.requestMessage.message_id },
).then((message) => job.update({ statusMessage: message }));
const config = await getGlobalSession();
const response = await txt2img(
config.sdApiUrl,
{ ...config.defaultParams, ...job.value.params },
(progress) => {
job.update({
status: { type: "processing", progress: progress.progress, updatedDate: new Date() },
});
if (job.value.statusMessage) {
bot.api
.editMessageText(
job.value.chat.id,
job.value.statusMessage.message_id,
`Generating your prompt now... ${
Math.round(
progress.progress * 100,
)
}%`,
)
.catch(() => undefined);
}
},
);
console.log(
`Finished job for ${job.value.user.first_name} in ${job.value.chat.type} chat`,
);
if (job.value.statusMessage) {
await bot.api.editMessageText(
job.value.chat.id,
job.value.statusMessage.message_id,
`Uploading your images...`,
).catch(() => undefined);
}
const inputFiles = await Promise.all(
response.images.map(async (imageBase64, idx) => {
const mimeType = mimeTypeFromBase64(imageBase64);
const imageBlob = await fetch(`data:${mimeType};base64,${imageBase64}`)
.then((resp) => resp.blob());
return InputMediaBuilder.photo(
new InputFile(imageBlob, `image_${idx}.${extFromMimeType(mimeType)}`),
);
}),
);
if (job.value.statusMessage) {
await bot.api
.deleteMessage(job.value.chat.id, job.value.statusMessage.message_id)
.catch(() => undefined).then(() => job.update({ statusMessage: undefined }));
}
await bot.api.sendMediaGroup(job.value.chat.id, inputFiles, {
reply_to_message_id: job.value.requestMessage.message_id,
});
await job.delete();
} catch (err) {
console.error(
`Failed to generate an image for ${job.value.user.first_name} in ${job.value.chat.type} chat: ${err}`,
);
const errorMessage = await bot.api
.sendMessage(job.value.chat.id, err.toString(), {
reply_to_message_id: job.value.requestMessage.message_id,
})
.catch(() => undefined);
if (errorMessage) {
if (job.value.statusMessage) {
await bot.api
.deleteMessage(job.value.chat.id, job.value.statusMessage.message_id)
.catch(() => undefined)
.then(() => job.update({ statusMessage: undefined }));
}
job.update({ status: { type: "idle" } });
} else {
await job.delete();
}
}
}
}
export async function returnHangedJobs() {
while (true) {
await new Promise((resolve) => setTimeout(resolve, 5000));
const jobs = await jobStore.list();
for (const job of jobs) {
if (job.value.status.type === "idle") continue;
// if job wasn't updated for 1 minute, return it to the queue
if (job.value.status.updatedDate.getTime() < Date.now() - 60 * 1000) {
console.log(
`Returning hanged job for ${job.value.user.first_name} in ${job.value.chat.type} chat`,
);
await job.update({ status: { type: "idle" } });
}
}
}
}

304
sd.ts
View File

@ -1,53 +1,88 @@
export async function txt2img(
apiUrl: string,
params: Partial<SdRequest>,
onProgress?: (progress: SdProgressResponse) => void,
signal?: AbortSignal,
): Promise<SdResponse> {
let response: Response | undefined;
let error: unknown;
import { Async, AsyncX, PngChunksExtract, PngChunkText } from "./deps.ts";
fetch(new URL("sdapi/v1/txt2img", apiUrl), {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(params),
}).then(
(resp) => (response = resp),
(err) => (error = err),
);
const neverSignal = new AbortController().signal;
export interface SdApi {
url: string;
auth?: string;
}
async function fetchSdApi<T>(api: SdApi, endpoint: string, body?: unknown): Promise<T> {
let options: RequestInit | undefined;
if (body != null) {
options = {
method: "POST",
headers: {
"Content-Type": "application/json",
...api.auth ? { Authorization: api.auth } : {},
},
body: JSON.stringify(body),
};
} else if (api.auth) {
options = {
headers: { Authorization: api.auth },
};
}
const response = await fetch(new URL(endpoint, api.url), options).catch(() => {
throw new SdApiError(endpoint, options, 0, "Network error");
});
const result = await response.json().catch(() => {
throw new SdApiError(endpoint, options, response.status, response.statusText, {
detail: "Invalid JSON",
});
});
if (!response.ok) {
throw new SdApiError(endpoint, options, response.status, response.statusText, result);
}
return result;
}
export async function sdTxt2Img(
api: SdApi,
params: Partial<SdTxt2ImgRequest>,
onProgress?: (progress: SdProgressResponse) => void,
signal: AbortSignal = neverSignal,
): Promise<SdTxt2ImgResponse> {
const request = fetchSdApi<SdTxt2ImgResponse>(api, "sdapi/v1/txt2img", params)
// JSON field "info" is a JSON-serialized string so we need to parse this part second time
.then((data) => ({
...data,
info: typeof data.info === "string" ? JSON.parse(data.info) : data.info,
}));
try {
while (true) {
await new Promise((resolve) => setTimeout(resolve, 3000));
const progressRequest = await fetch(new URL("sdapi/v1/progress", apiUrl));
if (progressRequest.ok) {
const progress = (await progressRequest.json()) as SdProgressResponse;
onProgress?.(progress);
}
if (response != null) {
if (response.ok) {
const result = (await response.json()) as SdResponse;
return result;
} else {
throw new Error(`Request failed: ${response.status} ${response.statusText}`);
}
}
if (error != null) {
throw error;
}
signal?.throwIfAborted();
await Async.abortable(Promise.race([request, Async.delay(3000)]), signal);
if (await AsyncX.promiseState(request) !== "pending") return await request;
onProgress?.(await fetchSdApi<SdProgressResponse>(api, "sdapi/v1/progress"));
}
} finally {
if (!response && !error) {
await fetch(new URL("sdapi/v1/interrupt", apiUrl), { method: "POST" });
if (await AsyncX.promiseState(request) === "pending") {
await fetchSdApi(api, "sdapi/v1/interrupt", {});
}
}
}
export interface SdRequest {
export interface SdTxt2ImgRequest {
enable_hr: boolean;
denoising_strength: number;
firstphase_width: number;
firstphase_height: number;
hr_scale: number;
hr_upscaler: unknown;
hr_second_pass_steps: number;
hr_resize_x: number;
hr_resize_y: number;
hr_sampler_name: unknown;
hr_prompt: string;
hr_negative_prompt: string;
prompt: string;
styles: unknown;
seed: number;
subseed: number;
subseed_strength: number;
seed_resize_from_h: number;
seed_resize_from_w: number;
sampler_name: unknown;
batch_size: number;
n_iter: number;
@ -55,16 +90,68 @@ export interface SdRequest {
cfg_scale: number;
width: number;
height: number;
restore_faces: boolean;
tiling: boolean;
do_not_save_samples: boolean;
do_not_save_grid: boolean;
negative_prompt: string;
eta: unknown;
s_min_uncond: number;
s_churn: number;
s_tmax: unknown;
s_tmin: number;
s_noise: number;
override_settings: object;
override_settings_restore_afterwards: boolean;
script_args: unknown[];
sampler_index: string;
script_name: unknown;
send_images: boolean;
save_images: boolean;
alwayson_scripts: object;
}
export interface SdResponse {
export interface SdTxt2ImgResponse {
images: string[];
parameters: SdRequest;
/** Contains serialized JSON */
info: string;
parameters: SdTxt2ImgRequest;
// Warning: raw response from API is a JSON-serialized string
info: SdTxt2ImgInfo;
}
export interface SdTxt2ImgInfo {
prompt: string;
all_prompts: string[];
negative_prompt: string;
all_negative_prompts: string[];
seed: number;
all_seeds: number[];
subseed: number;
all_subseeds: number[];
subseed_strength: number;
width: number;
height: number;
sampler_name: string;
cfg_scale: number;
steps: number;
batch_size: number;
restore_faces: boolean;
face_restoration_model: unknown;
sd_model_hash: string;
seed_resize_from_w: number;
seed_resize_from_h: number;
denoising_strength: number;
extra_generation_params: SdTxt2ImgInfoExtraParams;
index_of_first_image: number;
infotexts: string[];
styles: unknown[];
job_timestamp: string;
clip_skip: number;
is_using_inpainting_conditioning: boolean;
}
export interface SdTxt2ImgInfoExtraParams {
"Lora hashes": string;
"TI hashes": string;
}
export interface SdProgressResponse {
@ -86,3 +173,138 @@ export interface SdProgressState {
sampling_step: number;
sampling_steps: number;
}
export function sdGetConfig(api: SdApi): Promise<SdConfigResponse> {
return fetchSdApi(api, "config");
}
export interface SdConfigResponse {
/** version with new line at the end for some reason */
version: string;
mode: string;
dev_mode: boolean;
analytics_enabled: boolean;
components: object[];
css: unknown;
title: string;
is_space: boolean;
enable_queue: boolean;
show_error: boolean;
show_api: boolean;
is_colab: boolean;
stylesheets: unknown[];
theme: string;
layout: object;
dependencies: object[];
root: string;
}
export interface SdErrorResponse {
/**
* The HTTP status message or array of invalid fields.
* Can also be empty string.
*/
detail: string | Array<{ loc: string[]; msg: string; type: string }>;
/** Can be e.g. "OutOfMemoryError" or undefined. */
error?: string;
/** Empty string. */
body?: string;
/** Long description of error. */
errors?: string;
}
export class SdApiError extends Error {
constructor(
public readonly endpoint: string,
public readonly options: RequestInit | undefined,
public readonly statusCode: number,
public readonly statusText: string,
public readonly response?: SdErrorResponse,
) {
let message = `${options?.method ?? "GET"} ${endpoint} : ${statusCode} ${statusText}`;
if (response?.error) {
message += `: ${response.error}`;
if (response.errors) message += ` - ${response.errors}`;
} else if (typeof response?.detail === "string" && response.detail.length > 0) {
message += `: ${response.detail}`;
} else if (response?.detail) {
message += `: ${JSON.stringify(response.detail)}`;
}
super(message);
}
}
export function getPngInfo(pngData: Uint8Array): string | undefined {
return PngChunksExtract.default(pngData)
.filter((chunk) => chunk.name === "tEXt")
.map((chunk) => PngChunkText.decode(chunk.data))
.find((textChunk) => textChunk.keyword === "parameters")
?.text;
}
export function parsePngInfo(pngInfo: string): Partial<SdTxt2ImgRequest> {
const tags = pngInfo.split(/[,;]+|\.+\s|\n/u);
let part: "prompt" | "negative_prompt" | "params" = "prompt";
const params: Partial<SdTxt2ImgRequest> = {};
const prompt: string[] = [];
const negativePrompt: string[] = [];
for (const tag of tags) {
const paramValuePair = tag.trim().match(/^(\w+\s*\w*):\s+([\d\w. ]+)\s*$/u);
if (paramValuePair) {
const [, param, value] = paramValuePair;
switch (param.replace(/\s+/u, "").toLowerCase()) {
case "prompt":
part = "prompt";
prompt.push(value.trim());
break;
case "negativeprompt":
part = "negative_prompt";
negativePrompt.push(value.trim());
break;
case "steps":
case "cycles": {
part = "params";
const steps = Number(value.trim());
if (steps > 0) params.steps = Math.min(steps, 50);
break;
}
case "cfgscale":
case "detail": {
part = "params";
const cfgScale = Number(value.trim());
if (cfgScale > 0) params.cfg_scale = Math.min(cfgScale, 20);
break;
}
case "size":
case "resolution": {
part = "params";
const [width, height] = value.trim()
.split(/\s*[x,]\s*/u, 2)
.map((v) => v.trim())
.map(Number);
if (width > 0 && height > 0) {
params.width = Math.min(width, 2048);
params.height = Math.min(height, 2048);
}
break;
}
default:
break;
}
} else if (tag.trim().length > 0) {
switch (part) {
case "prompt":
prompt.push(tag.trim());
break;
case "negative_prompt":
negativePrompt.push(tag.trim());
break;
default:
break;
}
}
}
if (prompt.length > 0) params.prompt = prompt.join(", ");
if (negativePrompt.length > 0) params.negative_prompt = negativePrompt.join(", ");
return params;
}

View File

@ -1,78 +0,0 @@
import { Context, DenoKVAdapter, session, SessionFlavor } from "./deps.ts";
import { SdRequest } from "./sd.ts";
export type MySessionFlavor = SessionFlavor<SessionData>;
export interface SessionData {
global: GlobalData;
chat: ChatData;
user: UserData;
}
export interface GlobalData {
adminUsernames: string[];
pausedReason: string | null;
sdApiUrl: string;
maxUserJobs: number;
maxJobs: number;
defaultParams?: Partial<SdRequest>;
}
export interface ChatData {
language: string;
}
export interface UserData {
steps: number;
detail: number;
batchSize: number;
}
const globalDb = await Deno.openKv("./app.db");
const globalDbAdapter = new DenoKVAdapter<GlobalData>(globalDb);
const getDefaultGlobalData = (): GlobalData => ({
adminUsernames: (Deno.env.get("ADMIN_USERNAMES") ?? "").split(",").filter(Boolean),
pausedReason: null,
sdApiUrl: Deno.env.get("SD_API_URL") ?? "http://127.0.0.1:7860/",
maxUserJobs: 3,
maxJobs: 20,
defaultParams: {
batch_size: 1,
n_iter: 1,
width: 128 * 2,
height: 128 * 3,
steps: 20,
cfg_scale: 9,
send_images: true,
negative_prompt: "boring_e621_fluffyrock_v4 boring_e621_v4",
},
});
export const mySession = session<SessionData, Context & MySessionFlavor>({
type: "multi",
global: {
getSessionKey: () => "global",
initial: getDefaultGlobalData,
storage: globalDbAdapter,
},
chat: {
initial: () => ({
language: "en",
}),
},
user: {
getSessionKey: (ctx) => ctx.from?.id.toFixed(),
initial: () => ({
steps: 20,
detail: 8,
batchSize: 2,
}),
},
});
export async function getGlobalSession(): Promise<GlobalData> {
const data = await globalDbAdapter.read("global");
return data ?? getDefaultGlobalData();
}

View File

@ -1,76 +0,0 @@
import { ulid } from "./deps.ts";
export class Store<T extends object> {
constructor(
private readonly db: Deno.Kv,
private readonly storeKey: Deno.KvKeyPart,
) {
}
async create(value: T): Promise<Model<T>> {
const id = ulid();
await this.db.set([this.storeKey, id], value);
return new Model(this.db, this.storeKey, id, value);
}
async get(id: Deno.KvKeyPart): Promise<Model<T> | null> {
const entry = await this.db.get<T>([this.storeKey, id]);
if (entry.versionstamp == null) return null;
return new Model(this.db, this.storeKey, id, entry.value);
}
async list(): Promise<Array<Model<T>>> {
const models: Array<Model<T>> = [];
for await (const entry of this.db.list<T>({ prefix: [this.storeKey] })) {
models.push(new Model(this.db, this.storeKey, entry.key[1], entry.value));
}
return models;
}
}
export class Model<T extends object> {
#value: T;
constructor(
private readonly db: Deno.Kv,
private readonly storeKey: Deno.KvKeyPart,
private readonly entryKey: Deno.KvKeyPart,
value: T,
) {
this.#value = value;
}
get value(): T {
return this.#value;
}
async get(): Promise<T | null> {
const entry = await this.db.get<T>([this.storeKey, this.entryKey]);
if (entry.versionstamp == null) return null;
this.#value = entry.value;
return entry.value;
}
async set(value: T): Promise<T> {
await this.db.set([this.storeKey, this.entryKey], value);
this.#value = value;
return value;
}
async update(value: Partial<T> | ((value: T) => T)): Promise<T | null> {
const entry = await this.db.get<T>([this.storeKey, this.entryKey]);
if (entry.versionstamp == null) return null;
if (typeof value === "function") {
entry.value = value(entry.value);
} else {
entry.value = { ...entry.value, ...value };
}
await this.db.set([this.storeKey, this.entryKey], entry.value);
this.#value = entry.value;
return entry.value;
}
async delete(): Promise<void> {
await this.db.delete([this.storeKey, this.entryKey]);
}
}

13
tasks/mod.ts Normal file
View File

@ -0,0 +1,13 @@
import { pingWorkers } from "./pingWorkers.ts";
import { processJobs } from "./processJobs.ts";
import { returnHangedJobs } from "./returnHangedJobs.ts";
import { updateJobStatusMsgs } from "./updateJobStatusMsgs.ts";
export async function runAllTasks() {
await Promise.all([
processJobs(),
updateJobStatusMsgs(),
returnHangedJobs(),
pingWorkers(),
]);
}

32
tasks/pingWorkers.ts Normal file
View File

@ -0,0 +1,32 @@
import { Async, Log } from "../deps.ts";
import { getGlobalSession } from "../bot/session.ts";
import { sdGetConfig } from "../sd.ts";
const logger = () => Log.getLogger();
export const runningWorkers = new Set<string>();
/**
* Periodically ping the workers to see if they are alive.
*/
export async function pingWorkers(): Promise<never> {
while (true) {
try {
const config = await getGlobalSession();
for (const worker of config.workers) {
const status = await sdGetConfig(worker.api).catch(() => null);
const wasRunning = runningWorkers.has(worker.name);
if (status) {
runningWorkers.add(worker.name);
if (!wasRunning) logger().info(`Worker ${worker.name} is online`);
} else {
runningWorkers.delete(worker.name);
if (wasRunning) logger().warning(`Worker ${worker.name} went offline`);
}
}
await Async.delay(60 * 1000);
} catch (err) {
logger().warning(`Pinging workers failed: ${err}`);
}
}
}

227
tasks/processJobs.ts Normal file
View File

@ -0,0 +1,227 @@
import { Base64, FileType, FmtDuration, Grammy, GrammyParseMode, IKV, Log } from "../deps.ts";
import { bot } from "../bot/mod.ts";
import { getGlobalSession, GlobalData, WorkerData } from "../bot/session.ts";
import { fmt, formatUserChat } from "../utils.ts";
import { SdApiError, sdTxt2Img } from "../sd.ts";
import { JobSchema, jobStore } from "../db/jobStore.ts";
import { runningWorkers } from "./pingWorkers.ts";
const logger = () => Log.getLogger();
/**
* Sends waiting jobs to workers.
*/
export async function processJobs(): Promise<never> {
const busyWorkers = new Set<string>();
while (true) {
await new Promise((resolve) => setTimeout(resolve, 1000));
try {
// get first waiting job
const job = await jobStore.getBy("status.type", "waiting").then((jobs) => jobs[0]);
if (!job) continue;
// find a worker to handle the job
const config = await getGlobalSession();
const worker = config.workers?.find((worker) =>
runningWorkers.has(worker.name) &&
!busyWorkers.has(worker.name)
);
if (!worker) continue;
// process the job
await job.update({
status: { type: "processing", progress: 0, worker: worker.name, updatedDate: new Date() },
});
busyWorkers.add(worker.name);
processJob(job, worker, config)
.catch(async (err) => {
logger().error(
`Job failed for ${formatUserChat(job.value.request)} via ${worker.name}: ${err}`,
);
if (err instanceof Grammy.GrammyError || err instanceof SdApiError) {
await bot.api.sendMessage(
job.value.request.chat.id,
`Failed to generate your prompt: ${err.message}`,
{ reply_to_message_id: job.value.request.message_id },
).catch(() => undefined);
await job.update({ status: { type: "waiting" } }).catch(() => undefined);
}
if (
err instanceof SdApiError &&
(
err.statusCode === 0 /* Network error */ ||
err.statusCode === 404 ||
err.statusCode === 401
)
) {
runningWorkers.delete(worker.name);
logger().warning(
`Worker ${worker.name} was marked as offline because of network error`,
);
}
await job.delete().catch(() => undefined);
if (!(err instanceof Grammy.GrammyError) || err.error_code !== 403 /* blocked bot */) {
await jobStore.create(job.value);
}
})
.finally(() => busyWorkers.delete(worker.name));
} catch (err) {
logger().warning(`Processing jobs failed: ${err}`);
}
}
}
async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config: GlobalData) {
logger().debug(
`Job started for ${formatUserChat(job.value.request)} using ${worker.name}`,
);
const startDate = new Date();
// if there is already a status message delete it
if (job.value.reply) {
await bot.api.deleteMessage(job.value.reply.chat.id, job.value.reply.message_id)
.catch(() => undefined);
}
// send a new status message
const newStatusMessage = await bot.api.sendMessage(
job.value.request.chat.id,
`Generating your prompt now... 0% using ${worker.name}`,
{ reply_to_message_id: job.value.request.message_id },
).catch((err) => {
// don't error if the request message was deleted
if (err instanceof Grammy.GrammyError && err.message.match(/repl(y|ied)/)) return null;
else throw err;
});
// if the request message was deleted, cancel the job
if (!newStatusMessage) {
await job.delete();
logger().info(
`Job cancelled for ${formatUserChat(job.value.request)}`,
);
return;
}
await job.update({ reply: newStatusMessage });
// reduce size if worker can't handle the resolution
const size = limitSize({ ...config.defaultParams, ...job.value.params }, worker.maxResolution);
// process the job
const response = await sdTxt2Img(
worker.api,
{ ...config.defaultParams, ...job.value.params, ...size },
async (progress) => {
// important: don't let any errors escape this callback
if (job.value.reply) {
await bot.api.editMessageText(
job.value.reply.chat.id,
job.value.reply.message_id,
`Generating your prompt now... ${
(progress.progress * 100).toFixed(0)
}% using ${worker.name}`,
).catch(() => undefined);
}
await job.update({
status: {
type: "processing",
progress: progress.progress,
worker: worker.name,
updatedDate: new Date(),
},
}).catch(() => undefined);
},
);
// upload the result
if (job.value.reply) {
await bot.api.editMessageText(
job.value.reply.chat.id,
job.value.reply.message_id,
`Uploading your images...`,
).catch(() => undefined);
}
// render the caption
// const detailedReply = Object.keys(job.value.params).filter((key) => key !== "prompt").length > 0;
const detailedReply = true;
const jobDurationMs = Math.trunc((Date.now() - startDate.getTime()) / 1000) * 1000;
const { bold } = GrammyParseMode;
const caption = fmt([
`${response.info.prompt}\n`,
...detailedReply
? [
response.info.negative_prompt
? fmt`${bold("Negative prompt:")} ${response.info.negative_prompt}\n`
: "",
fmt`${bold("Steps:")} ${response.info.steps}, `,
fmt`${bold("Sampler:")} ${response.info.sampler_name}, `,
fmt`${bold("CFG scale:")} ${response.info.cfg_scale}, `,
fmt`${bold("Seed:")} ${response.info.seed}, `,
fmt`${bold("Size")}: ${response.info.width}x${response.info.height}, `,
fmt`${bold("Worker")}: ${worker.name}, `,
fmt`${bold("Time taken")}: ${FmtDuration.format(jobDurationMs, { ignoreZero: true })}`,
]
: [],
]);
// parse files from reply JSON
const inputFiles = await Promise.all(
response.images.map(async (imageBase64, idx) => {
const imageBuffer = Base64.decode(imageBase64);
const imageType = await FileType.fileTypeFromBuffer(imageBuffer);
if (!imageType) throw new Error("Unknown file type returned from worker");
return Grammy.InputMediaBuilder.photo(
new Grammy.InputFile(imageBuffer, `image${idx}.${imageType.ext}`),
// if it can fit, add caption for first photo
idx === 0 && caption.text.length <= 1024
? { caption: caption.text, caption_entities: caption.entities }
: undefined,
);
}),
);
// send the result to telegram
const resultMessage = await bot.api.sendMediaGroup(job.value.request.chat.id, inputFiles, {
reply_to_message_id: job.value.request.message_id,
});
// send caption in separate message if it couldn't fit
if (caption.text.length > 1024 && caption.text.length <= 4096) {
await bot.api.sendMessage(job.value.request.chat.id, caption.text, {
reply_to_message_id: resultMessage[0].message_id,
entities: caption.entities,
});
}
// delete the status message
if (job.value.reply) {
await bot.api.deleteMessage(job.value.reply.chat.id, job.value.reply.message_id)
.catch(() => undefined)
.then(() => job.update({ reply: undefined }))
.catch(() => undefined);
}
// update job to status done
await job.update({
status: { type: "done", info: response.info, startDate, endDate: new Date() },
});
logger().debug(
`Job finished for ${formatUserChat(job.value.request)} using ${worker.name}`,
);
}
function limitSize(
{ width, height }: { width?: number; height?: number },
maxResolution: number,
): { width?: number; height?: number } {
if (!width || !height) return {};
const ratio = width / height;
if (width * height > maxResolution) {
return {
width: Math.trunc(Math.sqrt(maxResolution * ratio)),
height: Math.trunc(Math.sqrt(maxResolution / ratio)),
};
}
return { width, height };
}

36
tasks/returnHangedJobs.ts Normal file
View File

@ -0,0 +1,36 @@
import { FmtDuration, Log } from "../deps.ts";
import { formatUserChat } from "../utils.ts";
import { jobStore } from "../db/jobStore.ts";
const logger = () => Log.getLogger();
/**
* Returns hanged jobs to the queue.
*/
export async function returnHangedJobs(): Promise<never> {
while (true) {
try {
await new Promise((resolve) => setTimeout(resolve, 5000));
const jobs = await jobStore.getBy("status.type", "processing");
for (const job of jobs) {
if (job.value.status.type !== "processing") continue;
// if job wasn't updated for 2 minutes, return it to the queue
const timeSinceLastUpdateMs = Date.now() - job.value.status.updatedDate.getTime();
if (timeSinceLastUpdateMs > 2 * 60 * 1000) {
await job.update({ status: { type: "waiting" } });
logger().warning(
`Job for ${
formatUserChat(job.value.request)
} was returned to the queue because it hanged for ${
FmtDuration.format(Math.trunc(timeSinceLastUpdateMs / 1000) * 1000, {
ignoreZero: true,
})
}`,
);
}
}
} catch (err) {
logger().warning(`Returning hanged jobs failed: ${err}`);
}
}
}

View File

@ -0,0 +1,28 @@
import { Log } from "../deps.ts";
import { bot } from "../bot/mod.ts";
import { formatOrdinal } from "../utils.ts";
import { jobStore } from "../db/jobStore.ts";
const logger = () => Log.getLogger();
/**
* Updates status messages for jobs in the queue.
*/
export async function updateJobStatusMsgs(): Promise<never> {
while (true) {
try {
await new Promise((resolve) => setTimeout(resolve, 5000));
const jobs = await jobStore.getBy("status.type", "waiting");
for (const [index, job] of jobs.entries()) {
if (!job.value.reply) continue;
await bot.api.editMessageText(
job.value.reply.chat.id,
job.value.reply.message_id,
`You are ${formatOrdinal(index + 1)} in queue.`,
).catch(() => undefined);
}
} catch (err) {
logger().warning(`Updating job status messages failed: ${err}`);
}
}
}

2
types/png-chunk-text.d.ts vendored Normal file
View File

@ -0,0 +1,2 @@
export function decode(chunk: Uint8Array): { keyword: string; text: string };
export function encode(keyword: string, text: string): Uint8Array;

1
types/png-chunks-encode.d.ts vendored Normal file
View File

@ -0,0 +1 @@
export default function encode(chunks: Array<{ name: string; data: Uint8Array }>): Uint8Array;

1
types/png-chunks-extract.d.ts vendored Normal file
View File

@ -0,0 +1 @@
export default function extract(data: Uint8Array): Array<{ name: string; data: Uint8Array }>;

111
utils.ts Normal file
View File

@ -0,0 +1,111 @@
import { GrammyParseMode, GrammyTypes } from "./deps.ts";
export function formatOrdinal(n: number) {
if (n % 100 === 11 || n % 100 === 12 || n % 100 === 13) return `${n}th`;
if (n % 10 === 1) return `${n}st`;
if (n % 10 === 2) return `${n}nd`;
if (n % 10 === 3) return `${n}rd`;
return `${n}th`;
}
export const fmt = (
rawStringParts: TemplateStringsArray | GrammyParseMode.Stringable[],
...stringLikes: GrammyParseMode.Stringable[]
): GrammyParseMode.FormattedString => {
let text = "";
const entities: GrammyTypes.MessageEntity[] = [];
const length = Math.max(rawStringParts.length, stringLikes.length);
for (let i = 0; i < length; i++) {
for (const stringLike of [rawStringParts[i], stringLikes[i]]) {
if (stringLike instanceof GrammyParseMode.FormattedString) {
entities.push(
...stringLike.entities.map((e) => ({
...e,
offset: e.offset + text.length,
})),
);
}
if (stringLike != null) text += stringLike.toString();
}
}
return new GrammyParseMode.FormattedString(text, entities);
};
export function formatUserChat(ctx: { from?: GrammyTypes.User; chat?: GrammyTypes.Chat }) {
const msg: string[] = [];
if (ctx.from) {
msg.push(ctx.from.first_name);
if (ctx.from.last_name) msg.push(ctx.from.last_name);
if (ctx.from.username) msg.push(`(@${ctx.from.username})`);
if (ctx.from.language_code) msg.push(`(${ctx.from.language_code.toUpperCase()})`);
}
if (ctx.chat) {
if (
ctx.chat.type === "group" ||
ctx.chat.type === "supergroup" ||
ctx.chat.type === "channel"
) {
msg.push("in");
msg.push(ctx.chat.title);
if (
(ctx.chat.type === "supergroup" || ctx.chat.type === "channel") &&
ctx.chat.username
) {
msg.push(`(@${ctx.chat.username})`);
}
}
}
return msg.join(" ");
}
/** Language to biggest country emoji map */
const languageToFlagMap: Record<string, string> = {
"en": "🇺🇸", // English - United States
"zh": "🇨🇳", // Chinese - China
"es": "🇪🇸", // Spanish - Spain
"hi": "🇮🇳", // Hindi - India
"ar": "🇪🇬", // Arabic - Egypt
"pt": "🇧🇷", // Portuguese - Brazil
"bn": "🇧🇩", // Bengali - Bangladesh
"ru": "🇷🇺", // Russian - Russia
"ja": "🇯🇵", // Japanese - Japan
"pa": "🇮🇳", // Punjabi - India
"de": "🇩🇪", // German - Germany
"ko": "🇰🇷", // Korean - South Korea
"fr": "🇫🇷", // French - France
"tr": "🇹🇷", // Turkish - Turkey
"ur": "🇵🇰", // Urdu - Pakistan
"it": "🇮🇹", // Italian - Italy
"th": "🇹🇭", // Thai - Thailand
"vi": "🇻🇳", // Vietnamese - Vietnam
"pl": "🇵🇱", // Polish - Poland
"uk": "🇺🇦", // Ukrainian - Ukraine
"uz": "🇺🇿", // Uzbek - Uzbekistan
"su": "🇮🇩", // Sundanese - Indonesia
"sw": "🇹🇿", // Swahili - Tanzania
"nl": "🇳🇱", // Dutch - Netherlands
"fi": "🇫🇮", // Finnish - Finland
"el": "🇬🇷", // Greek - Greece
"da": "🇩🇰", // Danish - Denmark
"cs": "🇨🇿", // Czech - Czech Republic
"sk": "🇸🇰", // Slovak - Slovakia
"bg": "🇧🇬", // Bulgarian - Bulgaria
"sv": "🇸🇪", // Swedish - Sweden
"be": "🇧🇾", // Belarusian - Belarus
"hu": "🇭🇺", // Hungarian - Hungary
"lt": "🇱🇹", // Lithuanian - Lithuania
"lv": "🇱🇻", // Latvian - Latvia
"et": "🇪🇪", // Estonian - Estonia
"sl": "🇸🇮", // Slovenian - Slovenia
"hr": "🇭🇷", // Croatian - Croatia
"zu": "🇿🇦", // Zulu - South Africa
"id": "🇮🇩", // Indonesian - Indonesia
"is": "🇮🇸", // Icelandic - Iceland
"lb": "🇱🇺", // Luxembourgish - Luxembourg
};
export function getFlagEmoji(countryCode?: string): string | undefined {
if (!countryCode) return;
return languageToFlagMap[countryCode];
}