Compare commits

..

No commits in common. "1c55ae70af672a1dbd2a54056dccb3288cfd5f42" and "5722238c06db1338c1ec5e8f462e7ea243eb6565" have entirely different histories.

8 changed files with 201 additions and 398 deletions

192
bot.ts
View File

@ -1,41 +1,83 @@
import { autoQuote, bold, Bot, Context, hydrateReply, ParseModeFlavor } from "./deps.ts"; import {
import { fmt } from "./intl.ts"; autoQuote,
import { getAllJobs, pushJob } from "./queue.ts"; autoRetry,
import { mySession, MySessionFlavor } from "./session.ts"; bold,
Bot,
Context,
DenoKVAdapter,
fmt,
hydrateReply,
ParseModeFlavor,
session,
SessionFlavor,
} from "./deps.ts";
import { fmtArray, formatOrdinal } from "./intl.ts";
import { queue } from "./queue.ts";
import { SdRequest } from "./sd.ts";
export type MyContext = ParseModeFlavor<Context> & MySessionFlavor; type AppContext = ParseModeFlavor<Context> & SessionFlavor<SessionData>;
export const bot = new Bot<MyContext>(Deno.env.get("TG_BOT_TOKEN") ?? "");
interface SessionData {
global: {
adminUsernames: string[];
pausedReason: string | null;
sdApiUrl: string;
maxUserJobs: number;
maxJobs: number;
defaultParams?: Partial<SdRequest>;
};
user: {
steps: number;
detail: number;
batchSize: number;
};
}
export const bot = new Bot<AppContext>(Deno.env.get("TG_BOT_TOKEN") ?? "");
bot.use(autoQuote); bot.use(autoQuote);
bot.use(hydrateReply); bot.use(hydrateReply);
bot.use(mySession); bot.api.config.use(autoRetry({ maxRetryAttempts: 5, maxDelaySeconds: 60 }));
// Automatically retry bot requests if we get a 429 error const db = await Deno.openKv("./app.db");
bot.api.config.use(async (prev, method, payload, signal) => {
let remainingAttempts = 5; const getDefaultGlobalSession = (): SessionData["global"] => ({
while (true) { adminUsernames: (Deno.env.get("ADMIN_USERNAMES") ?? "").split(",").filter(Boolean),
const result = await prev(method, payload, signal); pausedReason: null,
if (result.ok) return result; sdApiUrl: Deno.env.get("SD_API_URL") ?? "http://127.0.0.1:7860/",
if (result.error_code !== 429 || remainingAttempts <= 0) return result; maxUserJobs: 3,
remainingAttempts -= 1; maxJobs: 20,
const retryAfterMs = (result.parameters?.retry_after ?? 30) * 1000; defaultParams: {
await new Promise((resolve) => setTimeout(resolve, retryAfterMs)); batch_size: 1,
} n_iter: 1,
width: 128 * 2,
height: 128 * 3,
steps: 20,
cfg_scale: 9,
send_images: true,
negative_prompt: "boring_e621_fluffyrock_v4 boring_e621_v4",
},
}); });
// if error happened, try to reply to the user with the error bot.use(session<SessionData, AppContext>({
bot.use(async (ctx, next) => { type: "multi",
try { global: {
await next(); getSessionKey: () => "global",
} catch (err) { initial: getDefaultGlobalSession,
try { storage: new DenoKVAdapter(db),
await ctx.reply(`Handling update failed: ${err}`, { },
reply_to_message_id: ctx.message?.message_id, user: {
}); initial: () => ({
} catch { steps: 20,
throw err; detail: 8,
} batchSize: 2,
} }),
}); },
}));
export async function getGlobalSession(): Promise<SessionData["global"]> {
const entry = await db.get<SessionData["global"]>(["sessions", "global"]);
return entry.value ?? getDefaultGlobalSession();
}
bot.api.setMyShortDescription("I can generate furry images from text"); bot.api.setMyShortDescription("I can generate furry images from text");
bot.api.setMyDescription( bot.api.setMyDescription(
@ -57,13 +99,12 @@ bot.command("txt2img", async (ctx) => {
if (config.pausedReason != null) { if (config.pausedReason != null) {
return ctx.reply(`I'm paused: ${config.pausedReason || "No reason given"}`); return ctx.reply(`I'm paused: ${config.pausedReason || "No reason given"}`);
} }
const jobs = await getAllJobs(); if (queue.length >= config.maxJobs) {
if (jobs.length >= config.maxJobs) {
return ctx.reply( return ctx.reply(
`The queue is full. Try again later. (Max queue size: ${config.maxJobs})`, `The queue is full. Try again later. (Max queue size: ${config.maxJobs})`,
); );
} }
const jobCount = jobs.filter((job) => job.user.id === ctx.from.id).length; const jobCount = queue.filter((job) => job.userId === ctx.from.id).length;
if (jobCount >= config.maxUserJobs) { if (jobCount >= config.maxUserJobs) {
return ctx.reply( return ctx.reply(
`You already have ${config.maxUserJobs} jobs in queue. Try again later.`, `You already have ${config.maxUserJobs} jobs in queue. Try again later.`,
@ -72,50 +113,36 @@ bot.command("txt2img", async (ctx) => {
if (!ctx.match) { if (!ctx.match) {
return ctx.reply("Please describe what you want to see after the command"); return ctx.reply("Please describe what you want to see after the command");
} }
pushJob({ const place = queue.length + 1;
const queueMessage = await ctx.reply(`You are ${formatOrdinal(place)} in queue.`);
const userName = [ctx.from.first_name, ctx.from.last_name].filter(Boolean).join(" ");
const chatName = ctx.chat.type === "supergroup" || ctx.chat.type === "group"
? ctx.chat.title
: "private chat";
queue.push({
params: { prompt: ctx.match }, params: { prompt: ctx.match },
user: ctx.from, userId: ctx.from.id,
chat: ctx.chat, userName,
requestMessage: ctx.message, chatId: ctx.chat.id,
status: { type: "idle" }, chatName,
requestMessageId: ctx.message.message_id,
statusMessageId: queueMessage.message_id,
}); });
console.log( console.log(`Enqueued job for ${userName} in chat ${chatName}`);
`Enqueued job ${jobs.length + 1} for ${ctx.from.first_name} in ${ctx.chat.type} chat:`,
ctx.match.replace(/\s+/g, " "),
"\n",
);
}); });
bot.command("queue", async (ctx) => { bot.command("queue", (ctx) => {
let jobs = await getAllJobs(); if (queue.length === 0) return ctx.reply("Queue is empty");
const getMessageText = () => { return ctx.replyFmt(
if (jobs.length === 0) return fmt`Queue is empty.`; fmt`Current queue:\n\n${
const sortedJobs = []; fmtArray(
let place = 0; queue.map((job, index) =>
for (const job of jobs) { fmt`${bold(index + 1)}. ${bold(job.userName)} in ${bold(job.chatName)}`
if (job.status.type === "idle") place += 1; ),
sortedJobs.push({ ...job, place }); "\n",
}
return fmt`Current queue:\n\n${
sortedJobs.map((job) =>
fmt`${job.place}. ${bold(job.user.first_name)} in ${job.chat.type} chat ${
job.status.type === "processing" ? `(${(job.status.progress * 100).toFixed(0)}%)` : ""
}\n`
) )
}`; }`,
}; );
const message = await ctx.replyFmt(getMessageText());
handleFutureUpdates();
async function handleFutureUpdates() {
for (let idx = 0; idx < 12; idx++) {
await new Promise((resolve) => setTimeout(resolve, 5000));
jobs = await getAllJobs();
const formattedMessage = getMessageText();
await ctx.api.editMessageText(ctx.chat.id, message.message_id, formattedMessage.text, {
entities: formattedMessage.entities,
}).catch(() => undefined);
}
}
}); });
bot.command("pause", (ctx) => { bot.command("pause", (ctx) => {
@ -245,17 +272,14 @@ bot.command("setsdparam", (ctx) => {
bot.command("sdparams", (ctx) => { bot.command("sdparams", (ctx) => {
if (!ctx.from?.username) return; if (!ctx.from?.username) return;
const config = ctx.session.global; const config = ctx.session.global;
return ctx.replyFmt( return ctx.replyFmt(fmt`Current config:\n\n${
fmt`Current config:\n\n${ fmtArray(
Object.entries(config.defaultParams ?? {}).map(([key, value]) => Object.entries(config.defaultParams ?? {}).map(([key, value]) =>
fmt`${bold(key)} = ${String(value)}\n` fmt`${bold(key)} = ${String(value)}`
) ),
}`, "\n",
); )
}); }`);
bot.command("crash", () => {
throw new Error("Crash command used");
}); });
bot.catch((err) => { bot.catch((err) => {

View File

@ -2,5 +2,5 @@ export * from "https://deno.land/x/grammy@v1.18.1/mod.ts";
export * from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts"; export * from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts";
export * from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts"; export * from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts";
export * from "https://deno.land/x/grammy_storages@v2.3.1/denokv/src/mod.ts"; export * from "https://deno.land/x/grammy_storages@v2.3.1/denokv/src/mod.ts";
export * as types from "https://deno.land/x/grammy_types@v3.2.0/mod.ts"; export { autoRetry } from "https://esm.sh/@grammyjs/auto-retry@1.1.1";
export * from "https://deno.land/x/ulid@v0.3.0/mod.ts"; export * from "https://deno.land/x/zod/mod.ts";

0
fmtArray.ts Normal file
View File

44
intl.ts
View File

@ -8,36 +8,26 @@ export function formatOrdinal(n: number) {
return `${n}th`; return `${n}th`;
} }
type DeepArray<T> = Array<T | DeepArray<T>>;
type StringLikes = DeepArray<FormattedString | string | number | null | undefined>;
/** /**
* Like `fmt` from `grammy_parse_mode` but additionally accepts arrays. * Like `fmt` from `grammy_parse_mode` but accepts an array instead of template string.
* @see https://deno.land/x/grammy_parse_mode@1.7.1/format.ts?source=#L182 * @see https://deno.land/x/grammy_parse_mode@1.7.1/format.ts?source=#L182
*/ */
export const fmt = ( export function fmtArray(
rawStringParts: TemplateStringsArray | StringLikes, stringLikes: FormattedString[],
...stringLikes: StringLikes separator = "",
): FormattedString => { ): FormattedString {
let text = ""; let text = "";
const entities: ConstructorParameters<typeof FormattedString>[1][] = []; const entities: ConstructorParameters<typeof FormattedString>[1] = [];
for (let i = 0; i < stringLikes.length; i++) {
const length = Math.max(rawStringParts.length, stringLikes.length); const stringLike = stringLikes[i];
for (let i = 0; i < length; i++) { entities.push(
for (let stringLike of [rawStringParts[i], stringLikes[i]]) { ...stringLike.entities.map((e) => ({
if (Array.isArray(stringLike)) { ...e,
stringLike = fmt(stringLike); offset: e.offset + text.length,
} })),
if (stringLike instanceof FormattedString) { );
entities.push( text += stringLike.toString();
...stringLike.entities.map((e) => ({ if (i < stringLikes.length - 1) text += separator;
...e,
offset: e.offset + text.length,
})),
);
}
if (stringLike != null) text += stringLike.toString();
}
} }
return new FormattedString(text, entities); return new FormattedString(text, entities);
}; }

View File

@ -1,9 +1,8 @@
import "https://deno.land/std@0.201.0/dotenv/load.ts"; import "https://deno.land/std@0.201.0/dotenv/load.ts";
import { bot } from "./bot.ts"; import { bot } from "./bot.ts";
import { processQueue, returnHangedJobs } from "./queue.ts"; import { processQueue } from "./queue.ts";
await Promise.all([ await Promise.all([
bot.start(), bot.start(),
processQueue(), processQueue(),
returnHangedJobs(),
]); ]);

202
queue.ts
View File

@ -1,167 +1,111 @@
import { InputFile, InputMediaBuilder, types } from "./deps.ts"; import { InputFile, InputMediaBuilder } from "./deps.ts";
import { bot } from "./bot.ts"; import { bot, getGlobalSession } from "./bot.ts";
import { getGlobalSession } from "./session.ts";
import { formatOrdinal } from "./intl.ts"; import { formatOrdinal } from "./intl.ts";
import { SdRequest, txt2img } from "./sd.ts"; import { SdProgressResponse, SdRequest, txt2img } from "./sd.ts";
import { extFromMimeType, mimeTypeFromBase64 } from "./mimeType.ts"; import { extFromMimeType, mimeTypeFromBase64 } from "./mimeType.ts";
import { Model, Store } from "./store.ts";
export const queue: Job[] = [];
interface Job { interface Job {
params: Partial<SdRequest>; params: Partial<SdRequest>;
user: types.User; userId: number;
chat: types.Chat.PrivateChat | types.Chat.GroupChat | types.Chat.SupergroupChat; userName: string;
requestMessage: types.Message & types.Message.TextMessage; chatId: number;
statusMessage?: types.Message & types.Message.TextMessage; chatName: string;
status: { type: "idle" } | { type: "processing"; progress: number; updatedDate: Date }; requestMessageId: number;
} statusMessageId: number;
const db = await Deno.openKv("./app.db");
const jobStore = new Store<Job>(db, "job");
export async function pushJob(job: Job) {
await jobStore.create(job);
}
async function takeJob(): Promise<Model<Job> | null> {
const jobs = await jobStore.list();
const job = jobs.find((job) => job.value.status.type === "idle");
if (!job) return null;
await job.update({ status: { type: "processing", progress: 0, updatedDate: new Date() } });
return job;
}
export async function getAllJobs(): Promise<Array<Job>> {
return await jobStore.list().then((jobs) => jobs.map((job) => job.value));
} }
export async function processQueue() { export async function processQueue() {
while (true) { while (true) {
const job = await takeJob(); const job = queue.shift();
if (!job) { if (!job) {
await new Promise((resolve) => setTimeout(resolve, 1000)); await new Promise((resolve) => setTimeout(resolve, 1000));
continue; continue;
} }
let place = 0; for (const [index, job] of queue.entries()) {
for (const job of await jobStore.list()) { const place = index + 1;
if (job.value.status.type === "idle") place += 1; await bot.api
if (place === 0) continue; .editMessageText(
const statusMessageText = `You are ${formatOrdinal(place)} in queue.`; job.chatId,
if (!job.value.statusMessage) { job.statusMessageId,
await bot.api.sendMessage(job.value.chat.id, statusMessageText, { `You are ${formatOrdinal(place)} in queue.`,
reply_to_message_id: job.value.requestMessage.message_id,
}).catch(() => undefined)
.then((message) => job.update({ statusMessage: message }));
} else {
await bot.api.editMessageText(
job.value.chat.id,
job.value.statusMessage.message_id,
statusMessageText,
) )
.catch(() => undefined); .catch(() => {});
}
} }
try { try {
if (job.value.statusMessage) { await bot.api
await bot.api .deleteMessage(job.chatId, job.statusMessageId)
.deleteMessage(job.value.chat.id, job.value.statusMessage?.message_id) .catch(() => {});
.catch(() => undefined) const progressMessage = await bot.api.sendMessage(
.then(() => job.update({ statusMessage: undefined })); job.chatId,
}
await bot.api.sendMessage(
job.value.chat.id,
"Generating your prompt now...", "Generating your prompt now...",
{ reply_to_message_id: job.value.requestMessage.message_id }, { reply_to_message_id: job.requestMessageId },
).then((message) => job.update({ statusMessage: message })); );
const onProgress = (progress: SdProgressResponse) => {
bot.api
.editMessageText(
job.chatId,
progressMessage.message_id,
`Generating your prompt now... ${
Math.round(
progress.progress * 100,
)
}%`,
)
.catch(() => {});
};
const config = await getGlobalSession(); const config = await getGlobalSession();
const response = await txt2img( const response = await txt2img(
config.sdApiUrl, config.sdApiUrl,
{ ...config.defaultParams, ...job.value.params }, { ...config.defaultParams, ...job.params },
(progress) => { onProgress,
job.update({
status: { type: "processing", progress: progress.progress, updatedDate: new Date() },
});
if (job.value.statusMessage) {
bot.api
.editMessageText(
job.value.chat.id,
job.value.statusMessage.message_id,
`Generating your prompt now... ${
Math.round(
progress.progress * 100,
)
}%`,
)
.catch(() => undefined);
}
},
); );
console.log( console.log(
`Finished job for ${job.value.user.first_name} in ${job.value.chat.type} chat`, `Generated ${response.images.length} images (${
response.images
.map((image) => (image.length / 1024).toFixed(0) + "kB")
.join(", ")
}) for ${job.userName} in ${job.chatName}: ${job.params.prompt?.replace(/\s+/g, " ")}`,
); );
if (job.value.statusMessage) { await bot.api.editMessageText(
await bot.api.editMessageText( job.chatId,
job.value.chat.id, progressMessage.message_id,
job.value.statusMessage.message_id, `Uploading your images...`,
`Uploading your images...`, ).catch(() => {});
).catch(() => undefined);
}
const inputFiles = await Promise.all( const inputFiles = await Promise.all(
response.images.map(async (imageBase64, idx) => { response.images.map(async (imageBase64, idx) => {
const mimeType = mimeTypeFromBase64(imageBase64); const mimeType = mimeTypeFromBase64(imageBase64);
const imageBlob = await fetch(`data:${mimeType};base64,${imageBase64}`) const imageBlob = await fetch(`data:${mimeType};base64,${imageBase64}`).then((resp) =>
.then((resp) => resp.blob()); resp.blob()
);
console.log(
`Uploading image ${idx + 1} of ${response.images.length} (${
(imageBlob.size / 1024).toFixed(0)
}kB)`,
);
return InputMediaBuilder.photo( return InputMediaBuilder.photo(
new InputFile(imageBlob, `image_${idx}.${extFromMimeType(mimeType)}`), new InputFile(imageBlob, `${idx}.${extFromMimeType(mimeType)}`),
); );
}), }),
); );
if (job.value.statusMessage) { await bot.api.sendMediaGroup(job.chatId, inputFiles, {
await bot.api reply_to_message_id: job.requestMessageId,
.deleteMessage(job.value.chat.id, job.value.statusMessage.message_id)
.catch(() => undefined).then(() => job.update({ statusMessage: undefined }));
}
await bot.api.sendMediaGroup(job.value.chat.id, inputFiles, {
reply_to_message_id: job.value.requestMessage.message_id,
}); });
await job.delete(); await bot.api
.deleteMessage(job.chatId, progressMessage.message_id)
.catch(() => {});
console.log(`${queue.length} jobs remaining`);
} catch (err) { } catch (err) {
console.error( console.error(
`Failed to generate an image for ${job.value.user.first_name} in ${job.value.chat.type} chat: ${err}`, `Failed to generate image for ${job.userName} in ${job.chatName}: ${job.params.prompt} - ${err}`,
); );
const errorMessage = await bot.api await bot.api
.sendMessage(job.value.chat.id, err.toString(), { .sendMessage(job.chatId, err.toString(), {
reply_to_message_id: job.value.requestMessage.message_id, reply_to_message_id: job.requestMessageId,
}) })
.catch(() => undefined); .catch(() => bot.api.sendMessage(job.chatId, err.toString()))
if (errorMessage) { .catch(() => {});
if (job.value.statusMessage) {
await bot.api
.deleteMessage(job.value.chat.id, job.value.statusMessage.message_id)
.catch(() => undefined)
.then(() => job.update({ statusMessage: undefined }));
}
job.update({ status: { type: "idle" } });
} else {
await job.delete();
}
}
}
}
export async function returnHangedJobs() {
while (true) {
await new Promise((resolve) => setTimeout(resolve, 5000));
const jobs = await jobStore.list();
for (const job of jobs) {
if (job.value.status.type === "idle") continue;
// if job wasn't updated for 1 minute, return it to the queue
if (job.value.status.updatedDate.getTime() < Date.now() - 60 * 1000) {
console.log(
`Returning hanged job for ${job.value.user.first_name} in ${job.value.chat.type} chat`,
);
await job.update({ status: { type: "idle" } });
}
} }
} }
} }

View File

@ -1,78 +0,0 @@
import { Context, DenoKVAdapter, session, SessionFlavor } from "./deps.ts";
import { SdRequest } from "./sd.ts";
export type MySessionFlavor = SessionFlavor<SessionData>;
export interface SessionData {
global: GlobalData;
chat: ChatData;
user: UserData;
}
export interface GlobalData {
adminUsernames: string[];
pausedReason: string | null;
sdApiUrl: string;
maxUserJobs: number;
maxJobs: number;
defaultParams?: Partial<SdRequest>;
}
export interface ChatData {
language: string;
}
export interface UserData {
steps: number;
detail: number;
batchSize: number;
}
const globalDb = await Deno.openKv("./app.db");
const globalDbAdapter = new DenoKVAdapter<GlobalData>(globalDb);
const getDefaultGlobalData = (): GlobalData => ({
adminUsernames: (Deno.env.get("ADMIN_USERNAMES") ?? "").split(",").filter(Boolean),
pausedReason: null,
sdApiUrl: Deno.env.get("SD_API_URL") ?? "http://127.0.0.1:7860/",
maxUserJobs: 3,
maxJobs: 20,
defaultParams: {
batch_size: 1,
n_iter: 1,
width: 128 * 2,
height: 128 * 3,
steps: 20,
cfg_scale: 9,
send_images: true,
negative_prompt: "boring_e621_fluffyrock_v4 boring_e621_v4",
},
});
export const mySession = session<SessionData, Context & MySessionFlavor>({
type: "multi",
global: {
getSessionKey: () => "global",
initial: getDefaultGlobalData,
storage: globalDbAdapter,
},
chat: {
initial: () => ({
language: "en",
}),
},
user: {
getSessionKey: (ctx) => ctx.from?.id.toFixed(),
initial: () => ({
steps: 20,
detail: 8,
batchSize: 2,
}),
},
});
export async function getGlobalSession(): Promise<GlobalData> {
const data = await globalDbAdapter.read("global");
return data ?? getDefaultGlobalData();
}

View File

@ -1,76 +0,0 @@
import { ulid } from "./deps.ts";
export class Store<T extends object> {
constructor(
private readonly db: Deno.Kv,
private readonly storeKey: Deno.KvKeyPart,
) {
}
async create(value: T): Promise<Model<T>> {
const id = ulid();
await this.db.set([this.storeKey, id], value);
return new Model(this.db, this.storeKey, id, value);
}
async get(id: Deno.KvKeyPart): Promise<Model<T> | null> {
const entry = await this.db.get<T>([this.storeKey, id]);
if (entry.versionstamp == null) return null;
return new Model(this.db, this.storeKey, id, entry.value);
}
async list(): Promise<Array<Model<T>>> {
const models: Array<Model<T>> = [];
for await (const entry of this.db.list<T>({ prefix: [this.storeKey] })) {
models.push(new Model(this.db, this.storeKey, entry.key[1], entry.value));
}
return models;
}
}
export class Model<T extends object> {
#value: T;
constructor(
private readonly db: Deno.Kv,
private readonly storeKey: Deno.KvKeyPart,
private readonly entryKey: Deno.KvKeyPart,
value: T,
) {
this.#value = value;
}
get value(): T {
return this.#value;
}
async get(): Promise<T | null> {
const entry = await this.db.get<T>([this.storeKey, this.entryKey]);
if (entry.versionstamp == null) return null;
this.#value = entry.value;
return entry.value;
}
async set(value: T): Promise<T> {
await this.db.set([this.storeKey, this.entryKey], value);
this.#value = value;
return value;
}
async update(value: Partial<T> | ((value: T) => T)): Promise<T | null> {
const entry = await this.db.get<T>([this.storeKey, this.entryKey]);
if (entry.versionstamp == null) return null;
if (typeof value === "function") {
entry.value = value(entry.value);
} else {
entry.value = { ...entry.value, ...value };
}
await this.db.set([this.storeKey, this.entryKey], entry.value);
this.#value = entry.value;
return entry.value;
}
async delete(): Promise<void> {
await this.db.delete([this.storeKey, this.entryKey]);
}
}