big rewrite

This commit is contained in:
pinks 2023-09-10 20:56:17 +02:00
parent ba2afe40ce
commit 6c66a00910
25 changed files with 1013 additions and 934 deletions

View File

@ -11,12 +11,44 @@ Telegram bot for generating images from text.
You can put these in `.env` file or pass them as environment variables.
- `TG_BOT_TOKEN` - Telegram bot token (get yours from [@BotFather](https://t.me/BotFather))
- `SD_API_URL` - URL to Stable Diffusion API (e.g. `http://127.0.0.1:7860/`)
- `ADMIN_USERNAMES` - Comma separated list of usernames of users that can use admin commands
(optional)
- `TG_BOT_TOKEN` - Telegram bot token. Get yours from [@BotFather](https://t.me/BotFather).
Required.
- `SD_API_URL` - URL to Stable Diffusion API. Only used on first run. Default:
`http://127.0.0.1:7860/`
- `TG_ADMIN_USERS` - Comma separated list of usernames of users that can use admin commands. Only
used on first run. Optional.
## Running
- Start stable diffusion webui: `cd sd-webui`, `./webui.sh --api`
- Start bot: `deno task start`
## TODO
- [x] Keep generation history
- [x] Changing params, parsing png info in request
- [x] Cancelling jobs by deleting message
- [x] Multiple parallel workers
- [ ] Replying to another text message to copy prompt and generate
- [ ] Replying to bot message, conversation in DMs
- [ ] Replying to png message to extract png info nad generate
- [ ] Banning tags
- [ ] Img2Img + Upscale
- [ ] Admin WebUI
- [ ] User daily generation limits
- [ ] Querying all generation history, displaying stats
- [ ] Analyzing prompt quality based on tag csv
- [ ] Report aliased/unknown tags based on csv
- [ ] Report unknown loras
- [ ] Investigate "sendMediaGroup failed"
- [ ] Changing sampler without error on unknown sampler
- [ ] Changing model
- [ ] Inpaint using telegram photo edit
- [ ] Outpaint
- [ ] Non-SD (extras) upscale
- [ ] Tiled generation to allow very big images
- [ ] Downloading raw images
- [ ] Extra prompt syntax, fixing `()+++` syntax
- [ ] Translations
- replace fmtDuration usage
- replace formatOrdinal usage

273
bot.ts
View File

@ -1,273 +0,0 @@
import { autoQuote, bold, Bot, Context, hydrateReply, log, ParseModeFlavor } from "./deps.ts";
import { fmt } from "./intl.ts";
import { getAllJobs, pushJob } from "./queue.ts";
import { mySession, MySessionFlavor } from "./session.ts";
const logger = () => log.getLogger();
export type MyContext = ParseModeFlavor<Context> & MySessionFlavor;
export const bot = new Bot<MyContext>(Deno.env.get("TG_BOT_TOKEN") ?? "");
bot.use(autoQuote);
bot.use(hydrateReply);
bot.use(mySession);
// Automatically retry bot requests if we get a 429 error
bot.api.config.use(async (prev, method, payload, signal) => {
let remainingAttempts = 5;
while (true) {
const result = await prev(method, payload, signal);
if (result.ok) return result;
if (result.error_code !== 429 || remainingAttempts <= 0) return result;
remainingAttempts -= 1;
const retryAfterMs = (result.parameters?.retry_after ?? 30) * 1000;
await new Promise((resolve) => setTimeout(resolve, retryAfterMs));
}
});
// if error happened, try to reply to the user with the error
bot.use(async (ctx, next) => {
try {
await next();
} catch (err) {
try {
await ctx.reply(`Handling update failed: ${err}`, {
reply_to_message_id: ctx.message?.message_id,
});
} catch {
throw err;
}
}
});
bot.api.setMyShortDescription("I can generate furry images from text");
bot.api.setMyDescription(
"I can generate furry images from text. " +
"Send /txt2img to generate an image.",
);
bot.api.setMyCommands([
{ command: "txt2img", description: "Generate an image" },
{ command: "queue", description: "Show the current queue" },
{ command: "sdparams", description: "Show the current SD parameters" },
]);
bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to generate an image"));
bot.command("txt2img", async (ctx) => {
if (!ctx.from?.id) {
return ctx.reply("I don't know who you are");
}
const config = ctx.session.global;
if (config.pausedReason != null) {
return ctx.reply(`I'm paused: ${config.pausedReason || "No reason given"}`);
}
const jobs = await getAllJobs();
if (jobs.length >= config.maxJobs) {
return ctx.reply(
`The queue is full. Try again later. (Max queue size: ${config.maxJobs})`,
);
}
const jobCount = jobs.filter((job) => job.user.id === ctx.from.id).length;
if (jobCount >= config.maxUserJobs) {
return ctx.reply(
`You already have ${config.maxUserJobs} jobs in queue. Try again later.`,
);
}
if (!ctx.match) {
return ctx.reply("Please describe what you want to see after the command");
}
const statusMessage = await ctx.reply("Accepted. You are now in queue.");
await pushJob({
params: { prompt: ctx.match },
user: ctx.from,
chat: ctx.chat,
requestMessage: ctx.message,
statusMessage,
status: { type: "idle" },
});
logger().info("Job enqueued", ctx.from.first_name, ctx.chat.type, ctx.match.replace(/\s+/g, " "));
});
bot.command("queue", async (ctx) => {
let jobs = await getAllJobs();
const getMessageText = () => {
if (jobs.length === 0) return fmt`Queue is empty.`;
const sortedJobs = [];
let place = 0;
for (const job of jobs) {
if (job.status.type === "idle") place += 1;
sortedJobs.push({ ...job, place });
}
return fmt`Current queue:\n\n${
sortedJobs.map((job) =>
fmt`${job.place}. ${bold(job.user.first_name)} in ${job.chat.type} chat ${
job.status.type === "processing" ? `(${(job.status.progress * 100).toFixed(0)}%)` : ""
}\n`
)
}`;
};
const message = await ctx.replyFmt(getMessageText());
handleFutureUpdates();
async function handleFutureUpdates() {
for (let idx = 0; idx < 12; idx++) {
await new Promise((resolve) => setTimeout(resolve, 5000));
jobs = await getAllJobs();
const formattedMessage = getMessageText();
await ctx.api.editMessageText(ctx.chat.id, message.message_id, formattedMessage.text, {
entities: formattedMessage.entities,
}).catch(() => undefined);
}
}
});
bot.command("pause", (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
if (!config.adminUsernames.includes(ctx.from.username)) return;
if (config.pausedReason != null) {
return ctx.reply(`Already paused: ${config.pausedReason}`);
}
config.pausedReason = ctx.match ?? "No reason given";
return ctx.reply("Paused");
});
bot.command("resume", (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
if (!config.adminUsernames.includes(ctx.from.username)) return;
if (config.pausedReason == null) return ctx.reply("Already running");
config.pausedReason = null;
return ctx.reply("Resumed");
});
bot.command("setsdapiurl", async (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
if (!config.adminUsernames.includes(ctx.from.username)) return;
if (!ctx.match) return ctx.reply("Please specify an URL");
let url: URL;
try {
url = new URL(ctx.match);
} catch {
return ctx.reply("Invalid URL");
}
let resp: Response;
try {
resp = await fetch(new URL("config", url));
} catch (err) {
return ctx.reply(`Could not connect: ${err}`);
}
if (!resp.ok) {
return ctx.reply(`Could not connect: ${resp.status} ${resp.statusText}`);
}
let data: unknown;
try {
data = await resp.json();
} catch {
return ctx.reply("Invalid response from API");
}
if (data != null && typeof data === "object" && "version" in data) {
config.sdApiUrl = url.toString();
return ctx.reply(`Now using SD at ${url} running version ${data.version}`);
} else {
return ctx.reply("Invalid response from API");
}
});
bot.command("setsdparam", (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
if (!config.adminUsernames.includes(ctx.from.username)) return;
let [param = "", value] = ctx.match.split("=", 2).map((s) => s.trim());
if (!param) return ctx.reply("Please specify a parameter");
if (value == null) return ctx.reply("Please specify a value after the =");
param = param.toLowerCase().replace(/[\s_]+/g, "");
if (config.defaultParams == null) config.defaultParams = {};
switch (param) {
case "steps": {
const steps = parseInt(value);
if (isNaN(steps)) return ctx.reply("Invalid number value");
if (steps > 100) return ctx.reply("Steps must be less than 100");
if (steps < 10) return ctx.reply("Steps must be greater than 10");
config.defaultParams.steps = steps;
return ctx.reply("Steps set to " + steps);
}
case "detail":
case "cfgscale": {
const detail = parseInt(value);
if (isNaN(detail)) return ctx.reply("Invalid number value");
if (detail > 20) return ctx.reply("Detail must be less than 20");
if (detail < 1) return ctx.reply("Detail must be greater than 1");
config.defaultParams.cfg_scale = detail;
return ctx.reply("Detail set to " + detail);
}
case "niter":
case "niters": {
const nIter = parseInt(value);
if (isNaN(nIter)) return ctx.reply("Invalid number value");
if (nIter > 10) return ctx.reply("Iterations must be less than 10");
if (nIter < 1) return ctx.reply("Iterations must be greater than 1");
config.defaultParams.n_iter = nIter;
return ctx.reply("Iterations set to " + nIter);
}
case "batchsize": {
const batchSize = parseInt(value);
if (isNaN(batchSize)) return ctx.reply("Invalid number value");
if (batchSize > 8) return ctx.reply("Batch size must be less than 8");
if (batchSize < 1) return ctx.reply("Batch size must be greater than 1");
config.defaultParams.batch_size = batchSize;
return ctx.reply("Batch size set to " + batchSize);
}
case "size": {
let [width, height] = value.split("x", 2).map((s) => parseInt(s.trim()));
if (!width || !height || isNaN(width) || isNaN(height)) {
return ctx.reply("Invalid size value");
}
if (width > 2048) return ctx.reply("Width must be less than 2048");
if (height > 2048) return ctx.reply("Height must be less than 2048");
// find closest multiple of 64
width = Math.round(width / 64) * 64;
height = Math.round(height / 64) * 64;
if (width <= 0) return ctx.reply("Width too small");
if (height <= 0) return ctx.reply("Height too small");
config.defaultParams.width = width;
config.defaultParams.height = height;
return ctx.reply(`Size set to ${width}x${height}`);
}
case "negativeprompt": {
config.defaultParams.negative_prompt = value;
return ctx.reply(`Negative prompt set to: ${value}`);
}
default: {
return ctx.reply("Invalid parameter");
}
}
});
bot.command("sdparams", (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
return ctx.replyFmt(
fmt`Current config:\n\n${
Object.entries(config.defaultParams ?? {}).map(([key, value]) =>
fmt`${bold(key)} = ${String(value)}\n`
)
}`,
);
});
bot.command("crash", () => {
throw new Error("Crash command used");
});
bot.catch((err) => {
let msg = "Error processing update";
const { from, chat } = err.ctx;
if (from?.first_name) msg += ` from ${from.first_name}`;
if (from?.last_name) msg += ` ${from.last_name}`;
if (from?.username) msg += ` (@${from.username})`;
if (chat?.type === "supergroup" || chat?.type === "group") {
msg += ` in ${chat.title}`;
if (chat.type === "supergroup" && chat.username) msg += ` (@${chat.username})`;
}
logger().error("handling update failed", from?.first_name, chat?.type, err);
});

92
bot/mod.ts Normal file
View File

@ -0,0 +1,92 @@
import { Grammy, GrammyAutoQuote, GrammyParseMode, Log } from "../deps.ts";
import { formatUserChat } from "../utils.ts";
import { session, SessionFlavor } from "./session.ts";
import { queueCommand } from "./queueCommand.ts";
import { txt2imgCommand } from "./txt2imgCommand.ts";
export const logger = () => Log.getLogger();
export type Context = GrammyParseMode.ParseModeFlavor<Grammy.Context> & SessionFlavor;
export const bot = new Grammy.Bot<Context>(Deno.env.get("TG_BOT_TOKEN") ?? "");
bot.use(GrammyAutoQuote.autoQuote);
bot.use(GrammyParseMode.hydrateReply);
bot.use(session);
bot.catch((err) => {
logger().error(`Handling update from ${formatUserChat(err.ctx)} failed: ${err}`);
});
// Automatically retry bot requests if we get a "too many requests" or telegram internal error
bot.api.config.use(async (prev, method, payload, signal) => {
let attempt = 0;
while (true) {
attempt++;
const result = await prev(method, payload, signal);
if (
result.ok ||
![429, 500].includes(result.error_code) ||
attempt >= 5
) {
return result;
}
const retryAfterMs = (result.parameters?.retry_after ?? (attempt * 5)) * 1000;
await new Promise((resolve) => setTimeout(resolve, retryAfterMs));
}
});
// if error happened, try to reply to the user with the error
bot.use(async (ctx, next) => {
try {
await next();
} catch (err) {
try {
await ctx.reply(`Handling update failed: ${err}`, {
reply_to_message_id: ctx.message?.message_id,
});
} catch {
throw err;
}
}
});
bot.api.setMyShortDescription("I can generate furry images from text");
bot.api.setMyDescription(
"I can generate furry images from text. " +
"Send /txt2img to generate an image.",
);
bot.api.setMyCommands([
{ command: "txt2img", description: "Generate an image" },
{ command: "queue", description: "Show the current queue" },
]);
bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to generate an image"));
bot.command("txt2img", txt2imgCommand);
bot.command("queue", queueCommand);
bot.command("pause", (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
if (!config.adminUsernames.includes(ctx.from.username)) return;
if (config.pausedReason != null) {
return ctx.reply(`Already paused: ${config.pausedReason}`);
}
config.pausedReason = ctx.match ?? "No reason given";
logger().warning(`Bot paused by ${ctx.from.first_name} because ${config.pausedReason}`);
return ctx.reply("Paused");
});
bot.command("resume", (ctx) => {
if (!ctx.from?.username) return;
const config = ctx.session.global;
if (!config.adminUsernames.includes(ctx.from.username)) return;
if (config.pausedReason == null) return ctx.reply("Already running");
config.pausedReason = null;
logger().info(`Bot resumed by ${ctx.from.first_name}`);
return ctx.reply("Resumed");
});
bot.command("crash", () => {
throw new Error("Crash command used");
});

67
bot/queueCommand.ts Normal file
View File

@ -0,0 +1,67 @@
import { Grammy, GrammyParseMode } from "../deps.ts";
import { fmt, getFlagEmoji } from "../utils.ts";
import { runningWorkers } from "../tasks/pingWorkers.ts";
import { jobStore } from "../db/jobStore.ts";
import { Context, logger } from "./mod.ts";
export async function queueCommand(ctx: Grammy.CommandContext<Context>) {
let formattedMessage = await getMessageText();
const queueMessage = await ctx.replyFmt(formattedMessage);
handleFutureUpdates().catch((err) => logger().warning(`Updating queue message failed: ${err}`));
async function getMessageText() {
const processingJobs = await jobStore.getBy("status.type", "processing")
.then((jobs) => jobs.map((job) => ({ ...job.value, place: 0 })));
const waitingJobs = await jobStore.getBy("status.type", "waiting")
.then((jobs) => jobs.map((job, index) => ({ ...job.value, place: index + 1 })));
const jobs = [...processingJobs, ...waitingJobs];
const config = ctx.session.global;
const { bold } = GrammyParseMode;
return fmt([
"Current queue:\n",
...jobs.length > 0
? jobs.flatMap((job) => [
`${job.place}. `,
fmt`${bold(job.request.from.first_name)} `,
job.request.from.last_name ? fmt`${bold(job.request.from.last_name)} ` : "",
job.request.from.username ? `(@${job.request.from.username}) ` : "",
getFlagEmoji(job.request.from.language_code) ?? "",
job.request.chat.type === "private"
? " in private chat "
: ` in ${job.request.chat.title} `,
job.request.chat.type !== "private" && job.request.chat.type !== "group" &&
job.request.chat.username
? `(@${job.request.chat.username}) `
: "",
job.status.type === "processing"
? `(${(job.status.progress * 100).toFixed(0)}% using ${job.status.worker}) `
: "",
"\n",
])
: ["Queue is empty.\n"],
"\nActive workers:\n",
...config.workers.flatMap((worker) => [
runningWorkers.has(worker.name) ? "✅ " : "☠️ ",
fmt`${bold(worker.name)} `,
`(max ${(worker.maxResolution / 1000000).toFixed(1)} Mpx) `,
"\n",
]),
]);
}
async function handleFutureUpdates() {
for (let idx = 0; idx < 20; idx++) {
await new Promise((resolve) => setTimeout(resolve, 3000));
const nextFormattedMessage = await getMessageText();
if (nextFormattedMessage.text !== formattedMessage.text) {
await ctx.api.editMessageText(
ctx.chat.id,
queueMessage.message_id,
nextFormattedMessage.text,
{ entities: nextFormattedMessage.entities },
);
formattedMessage = nextFormattedMessage;
}
}
}
}

View File

@ -1,7 +1,7 @@
import { Context, DenoKVAdapter, session, SessionFlavor } from "./deps.ts";
import { SdTxt2ImgRequest } from "./sd.ts";
import { Grammy, GrammyKvStorage } from "../deps.ts";
import { SdApi, SdTxt2ImgRequest } from "../sd.ts";
export type MySessionFlavor = SessionFlavor<SessionData>;
export type SessionFlavor = Grammy.SessionFlavor<SessionData>;
export interface SessionData {
global: GlobalData;
@ -12,45 +12,55 @@ export interface SessionData {
export interface GlobalData {
adminUsernames: string[];
pausedReason: string | null;
sdApiUrl: string;
maxUserJobs: number;
maxJobs: number;
defaultParams?: Partial<SdTxt2ImgRequest>;
workers: WorkerData[];
}
export interface WorkerData {
name: string;
api: SdApi;
auth?: string;
maxResolution: number;
}
export interface ChatData {
language: string;
language?: string;
}
export interface UserData {
steps: number;
detail: number;
batchSize: number;
params?: Partial<SdTxt2ImgRequest>;
}
const globalDb = await Deno.openKv("./app.db");
const globalDbAdapter = new DenoKVAdapter<GlobalData>(globalDb);
const globalDbAdapter = new GrammyKvStorage.DenoKVAdapter<GlobalData>(globalDb);
const getDefaultGlobalData = (): GlobalData => ({
adminUsernames: (Deno.env.get("ADMIN_USERNAMES") ?? "").split(",").filter(Boolean),
adminUsernames: Deno.env.get("TG_ADMIN_USERS")?.split(",") ?? [],
pausedReason: null,
sdApiUrl: Deno.env.get("SD_API_URL") ?? "http://127.0.0.1:7860/",
maxUserJobs: 3,
maxJobs: 20,
defaultParams: {
batch_size: 1,
n_iter: 1,
width: 128 * 2,
height: 128 * 3,
steps: 20,
cfg_scale: 9,
send_images: true,
width: 512,
height: 768,
steps: 30,
cfg_scale: 10,
negative_prompt: "boring_e621_fluffyrock_v4 boring_e621_v4",
},
workers: [
{
name: "local",
api: { url: Deno.env.get("SD_API_URL") ?? "http://127.0.0.1:7860/" },
maxResolution: 1024 * 1024,
},
],
});
export const mySession = session<SessionData, Context & MySessionFlavor>({
export const session = Grammy.session<SessionData, Grammy.Context & SessionFlavor>({
type: "multi",
global: {
getSessionKey: () => "global",
@ -58,17 +68,11 @@ export const mySession = session<SessionData, Context & MySessionFlavor>({
storage: globalDbAdapter,
},
chat: {
initial: () => ({
language: "en",
}),
initial: () => ({}),
},
user: {
getSessionKey: (ctx) => ctx.from?.id.toFixed(),
initial: () => ({
steps: 20,
detail: 8,
batchSize: 2,
}),
initial: () => ({}),
},
});

39
bot/txt2imgCommand.ts Normal file
View File

@ -0,0 +1,39 @@
import { Grammy } from "../deps.ts";
import { formatUserChat } from "../utils.ts";
import { jobStore } from "../db/jobStore.ts";
import { parsePngInfo } from "../sd.ts";
import { Context, logger } from "./mod.ts";
export async function txt2imgCommand(ctx: Grammy.CommandContext<Context>) {
if (!ctx.from?.id) {
return ctx.reply("I don't know who you are");
}
const config = ctx.session.global;
if (config.pausedReason != null) {
return ctx.reply(`I'm paused: ${config.pausedReason || "No reason given"}`);
}
const jobs = await jobStore.getBy("status.type", "waiting");
if (jobs.length >= config.maxJobs) {
return ctx.reply(
`The queue is full. Try again later. (Max queue size: ${config.maxJobs})`,
);
}
const userJobs = jobs.filter((job) => job.value.request.from.id === ctx.from?.id);
if (userJobs.length >= config.maxUserJobs) {
return ctx.reply(
`You already have ${config.maxUserJobs} jobs in queue. Try again later.`,
);
}
const params = parsePngInfo(ctx.match);
if (!params.prompt) {
return ctx.reply("Please describe what you want to see after the command");
}
const reply = await ctx.reply("Accepted. You are now in queue.");
await jobStore.create({
params,
request: ctx.message,
reply,
status: { type: "waiting" },
});
logger().debug(`Job enqueued for ${formatUserChat(ctx)}`);
}

1
db/db.ts Normal file
View File

@ -0,0 +1 @@
export const db = await Deno.openKv("./app.db");

18
db/jobStore.ts Normal file
View File

@ -0,0 +1,18 @@
import { GrammyTypes, IKV } from "../deps.ts";
import { SdTxt2ImgInfo, SdTxt2ImgRequest } from "../sd.ts";
import { db } from "./db.ts";
export interface JobSchema {
params: Partial<SdTxt2ImgRequest>;
request: GrammyTypes.Message.TextMessage & { from: GrammyTypes.User };
reply?: GrammyTypes.Message.TextMessage;
status:
| { type: "waiting" }
| { type: "processing"; progress: number; worker: string; updatedDate: Date }
| { type: "done"; info?: SdTxt2ImgInfo; startDate?: Date; endDate?: Date };
}
export const jobStore = new IKV.Store(db, "job", {
schema: new IKV.Schema<JobSchema>(),
indices: ["status.type"],
});

25
deps.ts
View File

@ -1,7 +1,18 @@
export * from "https://deno.land/x/grammy@v1.18.1/mod.ts";
export * from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts";
export * from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts";
export * from "https://deno.land/x/grammy_storages@v2.3.1/denokv/src/mod.ts";
export * as types from "https://deno.land/x/grammy_types@v3.2.0/mod.ts";
export * from "https://deno.land/x/ulid@v0.3.0/mod.ts";
export * as log from "https://deno.land/std@0.201.0/log/mod.ts";
export * as Log from "https://deno.land/std@0.201.0/log/mod.ts";
export * as Async from "https://deno.land/std@0.201.0/async/mod.ts";
export * as FmtDuration from "https://deno.land/std@0.201.0/fmt/duration.ts";
export * as Collections from "https://deno.land/std@0.201.0/collections/mod.ts";
export * as Base64 from "https://deno.land/std@0.201.0/encoding/base64.ts";
export * as AsyncX from "https://deno.land/x/async@v2.0.2/mod.ts";
export * as ULID from "https://deno.land/x/ulid@v0.3.0/mod.ts";
export * as IKV from "https://deno.land/x/indexed_kv@v0.2.0/mod.ts";
export * as Grammy from "https://deno.land/x/grammy@v1.18.1/mod.ts";
export * as GrammyTypes from "https://deno.land/x/grammy_types@v3.2.0/mod.ts";
export * as GrammyAutoQuote from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts";
export * as GrammyParseMode from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts";
export * as GrammyKvStorage from "https://deno.land/x/grammy_storages@v2.3.1/denokv/src/mod.ts";
export * as FileType from "npm:file-type@18.5.0";
// @deno-types="./types/png-chunks-extract.d.ts"
export * as PngChunksExtract from "npm:png-chunks-extract@1.0.0";
// @deno-types="./types/png-chunk-text.d.ts"
export * as PngChunkText from "npm:png-chunk-text@1.0.0";

43
intl.ts
View File

@ -1,43 +0,0 @@
import { FormattedString } from "./deps.ts";
export function formatOrdinal(n: number) {
if (n % 100 === 11 || n % 100 === 12 || n % 100 === 13) return `${n}th`;
if (n % 10 === 1) return `${n}st`;
if (n % 10 === 2) return `${n}nd`;
if (n % 10 === 3) return `${n}rd`;
return `${n}th`;
}
type DeepArray<T> = Array<T | DeepArray<T>>;
type StringLikes = DeepArray<FormattedString | string | number | null | undefined>;
/**
* Like `fmt` from `grammy_parse_mode` but additionally accepts arrays.
* @see https://deno.land/x/grammy_parse_mode@1.7.1/format.ts?source=#L182
*/
export const fmt = (
rawStringParts: TemplateStringsArray | StringLikes,
...stringLikes: StringLikes
): FormattedString => {
let text = "";
const entities: ConstructorParameters<typeof FormattedString>[1][] = [];
const length = Math.max(rawStringParts.length, stringLikes.length);
for (let i = 0; i < length; i++) {
for (let stringLike of [rawStringParts[i], stringLikes[i]]) {
if (Array.isArray(stringLike)) {
stringLike = fmt(stringLike);
}
if (stringLike instanceof FormattedString) {
entities.push(
...stringLike.entities.map((e) => ({
...e,
offset: e.offset + text.length,
})),
);
}
if (stringLike != null) text += stringLike.toString();
}
}
return new FormattedString(text, entities);
};

23
main.ts
View File

@ -1,24 +1,21 @@
// Load environment variables from .env file
import "https://deno.land/std@0.201.0/dotenv/load.ts";
import { bot } from "./bot.ts";
import { processQueue, returnHangedJobs } from "./queue.ts";
import { log } from "./deps.ts";
log.setup({
// Setup logging
import { Log } from "./deps.ts";
Log.setup({
handlers: {
console: new log.handlers.ConsoleHandler("INFO", {
formatter: (record) =>
`[${record.levelName}] ${record.msg} ${
record.args.map((arg) => JSON.stringify(arg)).join(" ")
} (${record.datetime.toISOString()})`,
}),
console: new Log.handlers.ConsoleHandler("DEBUG"),
},
loggers: {
default: { level: "INFO", handlers: ["console"] },
default: { level: "DEBUG", handlers: ["console"] },
},
});
// Main program logic
import { bot } from "./bot/mod.ts";
import { runAllTasks } from "./tasks/mod.ts";
await Promise.all([
bot.start(),
processQueue(),
returnHangedJobs(),
runAllTasks(),
]);

View File

@ -1,15 +0,0 @@
export function mimeTypeFromBase64(base64: string) {
if (base64.startsWith("/9j/")) return "image/jpeg";
if (base64.startsWith("iVBORw0KGgo")) return "image/png";
if (base64.startsWith("R0lGODlh")) return "image/gif";
if (base64.startsWith("UklGRg")) return "image/webp";
throw new Error("Unknown image type");
}
export function extFromMimeType(mimeType: string) {
if (mimeType === "image/jpeg") return "jpg";
if (mimeType === "image/png") return "png";
if (mimeType === "image/gif") return "gif";
if (mimeType === "image/webp") return "webp";
throw new Error("Unknown image type");
}

181
queue.ts
View File

@ -1,181 +0,0 @@
import { InputFile, InputMediaBuilder, log, types } from "./deps.ts";
import { bot } from "./bot.ts";
import { getGlobalSession } from "./session.ts";
import { formatOrdinal } from "./intl.ts";
import { SdTxt2ImgRequest, SdTxt2ImgResponse, txt2img } from "./sd.ts";
import { extFromMimeType, mimeTypeFromBase64 } from "./mimeType.ts";
import { Model, Schema, Store } from "./store.ts";
const logger = () => log.getLogger();
interface Job {
params: Partial<SdTxt2ImgRequest>;
user: types.User;
chat: types.Chat.PrivateChat | types.Chat.GroupChat | types.Chat.SupergroupChat;
requestMessage: types.Message & types.Message.TextMessage;
statusMessage?: types.Message & types.Message.TextMessage;
status:
| { type: "idle" }
| { type: "processing"; progress: number; updatedDate: Date };
}
const db = await Deno.openKv("./app.db");
const jobStore = new Store(db, "job", {
schema: new Schema<Job>(),
indices: ["status.type", "user.id", "chat.id"],
});
jobStore.getBy("user.id", 123).then(() => {});
export async function pushJob(job: Job) {
await jobStore.create(job);
}
async function takeJob(): Promise<Model<Job> | null> {
const jobs = await jobStore.getAll();
const job = jobs.find((job) => job.value.status.type === "idle");
if (!job) return null;
await job.update({ status: { type: "processing", progress: 0, updatedDate: new Date() } });
return job;
}
export async function getAllJobs(): Promise<Array<Job>> {
return await jobStore.getAll().then((jobs) => jobs.map((job) => job.value));
}
export async function processQueue() {
while (true) {
const job = await takeJob().catch((err) =>
void logger().warning("failed getting job", err.message)
);
if (!job) {
await new Promise((resolve) => setTimeout(resolve, 1000));
continue;
}
let place = 0;
for (const job of await jobStore.getAll().catch(() => [])) {
if (job.value.status.type === "idle") place += 1;
if (place === 0) continue;
const statusMessageText = `You are ${formatOrdinal(place)} in queue.`;
if (!job.value.statusMessage) {
await bot.api.sendMessage(job.value.chat.id, statusMessageText, {
reply_to_message_id: job.value.requestMessage.message_id,
}).catch(() => undefined)
.then((message) => job.update({ statusMessage: message })).catch(() => undefined);
} else {
await bot.api.editMessageText(
job.value.chat.id,
job.value.statusMessage.message_id,
statusMessageText,
)
.catch(() => undefined);
}
}
try {
if (job.value.statusMessage) {
await bot.api
.deleteMessage(job.value.chat.id, job.value.statusMessage?.message_id)
.catch(() => undefined)
.then(() => job.update({ statusMessage: undefined }));
}
await bot.api.sendMessage(
job.value.chat.id,
"Generating your prompt now...",
{ reply_to_message_id: job.value.requestMessage.message_id },
).then((message) => job.update({ statusMessage: message }));
const config = await getGlobalSession();
const response = await txt2img(
config.sdApiUrl,
{ ...config.defaultParams, ...job.value.params },
(progress) => {
job.update({
status: { type: "processing", progress: progress.progress, updatedDate: new Date() },
});
if (job.value.statusMessage) {
bot.api
.editMessageText(
job.value.chat.id,
job.value.statusMessage.message_id,
`Generating your prompt now... ${
Math.round(
progress.progress * 100,
)
}%`,
)
.catch(() => undefined);
}
},
);
const jobCount = (await jobStore.getAll()).filter((job) =>
job.value.status.type !== "processing"
).length;
logger().info("Job finished", job.value.user.first_name, job.value.chat.type, { jobCount });
if (job.value.statusMessage) {
await bot.api.editMessageText(
job.value.chat.id,
job.value.statusMessage.message_id,
`Uploading your images...`,
).catch(() => undefined);
}
const inputFiles = await Promise.all(
response.images.map(async (imageBase64, idx) => {
const mimeType = mimeTypeFromBase64(imageBase64);
const imageBlob = await fetch(`data:${mimeType};base64,${imageBase64}`)
.then((resp) => resp.blob());
return InputMediaBuilder.photo(
new InputFile(imageBlob, `image_${idx}.${extFromMimeType(mimeType)}`),
);
}),
);
if (job.value.statusMessage) {
await bot.api
.deleteMessage(job.value.chat.id, job.value.statusMessage.message_id)
.catch(() => undefined).then(() => job.update({ statusMessage: undefined }));
}
await bot.api.sendMediaGroup(job.value.chat.id, inputFiles, {
reply_to_message_id: job.value.requestMessage.message_id,
});
await job.delete();
} catch (err) {
logger().error("Job failed", job.value.user.first_name, job.value.chat.type, err);
const errorMessage = await bot.api
.sendMessage(job.value.chat.id, err.toString(), {
reply_to_message_id: job.value.requestMessage.message_id,
})
.catch(() => undefined);
if (errorMessage) {
if (job.value.statusMessage) {
await bot.api
.deleteMessage(job.value.chat.id, job.value.statusMessage.message_id)
.then(() => job.update({ statusMessage: undefined }))
.catch(() => void logger().warning("failed deleting status message", err.message));
}
await job.update({ status: { type: "idle" } }).catch((err) =>
void logger().warning("failed returning job", err.message)
);
} else {
await job.delete().catch((err) =>
void logger().warning("failed deleting job", err.message)
);
}
}
}
}
export async function returnHangedJobs() {
while (true) {
await new Promise((resolve) => setTimeout(resolve, 5000));
const jobs = await jobStore.getAll().catch(() => []);
for (const job of jobs) {
if (job.value.status.type !== "processing") continue;
// if job wasn't updated for 1 minute, return it to the queue
if (job.value.status.updatedDate.getTime() < Date.now() - 60 * 1000) {
logger().warning("Hanged job returned", job.value.user.first_name, job.value.chat.type);
await job.update({ status: { type: "idle" } }).catch((err) =>
void logger().warning("failed returning job", err.message)
);
}
}
}
}

294
sd.ts
View File

@ -1,53 +1,88 @@
export async function txt2img(
apiUrl: string,
import { Async, AsyncX, PngChunksExtract, PngChunkText } from "./deps.ts";
const neverSignal = new AbortController().signal;
export interface SdApi {
url: string;
auth?: string;
}
async function fetchSdApi<T>(api: SdApi, endpoint: string, body?: unknown): Promise<T> {
let options: RequestInit | undefined;
if (body != null) {
options = {
method: "POST",
headers: {
"Content-Type": "application/json",
...api.auth ? { Authorization: api.auth } : {},
},
body: JSON.stringify(body),
};
} else if (api.auth) {
options = {
headers: { Authorization: api.auth },
};
}
const response = await fetch(new URL(endpoint, api.url), options).catch(() => {
throw new SdApiError(endpoint, options, 0, "Network error");
});
const result = await response.json().catch(() => {
throw new SdApiError(endpoint, options, response.status, response.statusText, {
detail: "Invalid JSON",
});
});
if (!response.ok) {
throw new SdApiError(endpoint, options, response.status, response.statusText, result);
}
return result;
}
export async function sdTxt2Img(
api: SdApi,
params: Partial<SdTxt2ImgRequest>,
onProgress?: (progress: SdProgressResponse) => void,
signal?: AbortSignal,
signal: AbortSignal = neverSignal,
): Promise<SdTxt2ImgResponse> {
let response: Response | undefined;
let error: unknown;
fetch(new URL("sdapi/v1/txt2img", apiUrl), {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(params),
}).then(
(resp) => (response = resp),
(err) => (error = err),
);
const request = fetchSdApi<SdTxt2ImgResponse>(api, "sdapi/v1/txt2img", params)
// JSON field "info" is a JSON-serialized string so we need to parse this part second time
.then((data) => ({
...data,
info: typeof data.info === "string" ? JSON.parse(data.info) : data.info,
}));
try {
while (true) {
await new Promise((resolve) => setTimeout(resolve, 3000));
const progressRequest = await fetch(new URL("sdapi/v1/progress", apiUrl));
if (progressRequest.ok) {
const progress = (await progressRequest.json()) as SdProgressResponse;
onProgress?.(progress);
}
if (response != null) {
if (response.ok) {
const result = (await response.json()) as SdTxt2ImgResponse;
return result;
} else {
throw new Error(`Request failed: ${response.status} ${response.statusText}`);
}
}
if (error != null) {
throw error;
}
signal?.throwIfAborted();
await Async.abortable(Promise.race([request, Async.delay(3000)]), signal);
if (await AsyncX.promiseState(request) !== "pending") return await request;
onProgress?.(await fetchSdApi<SdProgressResponse>(api, "sdapi/v1/progress"));
}
} finally {
if (!response && !error) {
await fetch(new URL("sdapi/v1/interrupt", apiUrl), { method: "POST" });
if (await AsyncX.promiseState(request) === "pending") {
await fetchSdApi(api, "sdapi/v1/interrupt", {});
}
}
}
export interface SdTxt2ImgRequest {
enable_hr: boolean;
denoising_strength: number;
firstphase_width: number;
firstphase_height: number;
hr_scale: number;
hr_upscaler: unknown;
hr_second_pass_steps: number;
hr_resize_x: number;
hr_resize_y: number;
hr_sampler_name: unknown;
hr_prompt: string;
hr_negative_prompt: string;
prompt: string;
styles: unknown;
seed: number;
subseed: number;
subseed_strength: number;
seed_resize_from_h: number;
seed_resize_from_w: number;
sampler_name: unknown;
batch_size: number;
n_iter: number;
@ -55,16 +90,68 @@ export interface SdTxt2ImgRequest {
cfg_scale: number;
width: number;
height: number;
restore_faces: boolean;
tiling: boolean;
do_not_save_samples: boolean;
do_not_save_grid: boolean;
negative_prompt: string;
eta: unknown;
s_min_uncond: number;
s_churn: number;
s_tmax: unknown;
s_tmin: number;
s_noise: number;
override_settings: object;
override_settings_restore_afterwards: boolean;
script_args: unknown[];
sampler_index: string;
script_name: unknown;
send_images: boolean;
save_images: boolean;
alwayson_scripts: object;
}
export interface SdTxt2ImgResponse {
images: string[];
parameters: SdTxt2ImgRequest;
/** Contains serialized JSON */
info: string;
// Warning: raw response from API is a JSON-serialized string
info: SdTxt2ImgInfo;
}
export interface SdTxt2ImgInfo {
prompt: string;
all_prompts: string[];
negative_prompt: string;
all_negative_prompts: string[];
seed: number;
all_seeds: number[];
subseed: number;
all_subseeds: number[];
subseed_strength: number;
width: number;
height: number;
sampler_name: string;
cfg_scale: number;
steps: number;
batch_size: number;
restore_faces: boolean;
face_restoration_model: unknown;
sd_model_hash: string;
seed_resize_from_w: number;
seed_resize_from_h: number;
denoising_strength: number;
extra_generation_params: SdTxt2ImgInfoExtraParams;
index_of_first_image: number;
infotexts: string[];
styles: unknown[];
job_timestamp: string;
clip_skip: number;
is_using_inpainting_conditioning: boolean;
}
export interface SdTxt2ImgInfoExtraParams {
"Lora hashes": string;
"TI hashes": string;
}
export interface SdProgressResponse {
@ -86,3 +173,138 @@ export interface SdProgressState {
sampling_step: number;
sampling_steps: number;
}
export function sdGetConfig(api: SdApi): Promise<SdConfigResponse> {
return fetchSdApi(api, "config");
}
export interface SdConfigResponse {
/** version with new line at the end for some reason */
version: string;
mode: string;
dev_mode: boolean;
analytics_enabled: boolean;
components: object[];
css: unknown;
title: string;
is_space: boolean;
enable_queue: boolean;
show_error: boolean;
show_api: boolean;
is_colab: boolean;
stylesheets: unknown[];
theme: string;
layout: object;
dependencies: object[];
root: string;
}
export interface SdErrorResponse {
/**
* The HTTP status message or array of invalid fields.
* Can also be empty string.
*/
detail: string | Array<{ loc: string[]; msg: string; type: string }>;
/** Can be e.g. "OutOfMemoryError" or undefined. */
error?: string;
/** Empty string. */
body?: string;
/** Long description of error. */
errors?: string;
}
export class SdApiError extends Error {
constructor(
public readonly endpoint: string,
public readonly options: RequestInit | undefined,
public readonly statusCode: number,
public readonly statusText: string,
public readonly response?: SdErrorResponse,
) {
let message = `${options?.method ?? "GET"} ${endpoint} : ${statusCode} ${statusText}`;
if (response?.error) {
message += `: ${response.error}`;
if (response.errors) message += ` - ${response.errors}`;
} else if (typeof response?.detail === "string" && response.detail.length > 0) {
message += `: ${response.detail}`;
} else if (response?.detail) {
message += `: ${JSON.stringify(response.detail)}`;
}
super(message);
}
}
export function getPngInfo(pngData: Uint8Array): string | undefined {
return PngChunksExtract.default(pngData)
.filter((chunk) => chunk.name === "tEXt")
.map((chunk) => PngChunkText.decode(chunk.data))
.find((textChunk) => textChunk.keyword === "parameters")
?.text;
}
export function parsePngInfo(pngInfo: string): Partial<SdTxt2ImgRequest> {
const tags = pngInfo.split(/[,;]+|\.+\s|\n/u);
let part: "prompt" | "negative_prompt" | "params" = "prompt";
const params: Partial<SdTxt2ImgRequest> = {};
const prompt: string[] = [];
const negativePrompt: string[] = [];
for (const tag of tags) {
const paramValuePair = tag.trim().match(/^(\w+\s*\w*):\s+([\d\w. ]+)\s*$/u);
if (paramValuePair) {
const [, param, value] = paramValuePair;
switch (param.replace(/\s+/u, "").toLowerCase()) {
case "prompt":
part = "prompt";
prompt.push(value.trim());
break;
case "negativeprompt":
part = "negative_prompt";
negativePrompt.push(value.trim());
break;
case "steps":
case "cycles": {
part = "params";
const steps = Number(value.trim());
if (steps > 0) params.steps = Math.min(steps, 50);
break;
}
case "cfgscale":
case "detail": {
part = "params";
const cfgScale = Number(value.trim());
if (cfgScale > 0) params.cfg_scale = Math.min(cfgScale, 20);
break;
}
case "size":
case "resolution": {
part = "params";
const [width, height] = value.trim()
.split(/\s*[x,]\s*/u, 2)
.map((v) => v.trim())
.map(Number);
if (width > 0 && height > 0) {
params.width = Math.min(width, 2048);
params.height = Math.min(height, 2048);
}
break;
}
default:
break;
}
} else if (tag.trim().length > 0) {
switch (part) {
case "prompt":
prompt.push(tag.trim());
break;
case "negative_prompt":
negativePrompt.push(tag.trim());
break;
default:
break;
}
}
}
if (prompt.length > 0) params.prompt = prompt.join(", ");
if (negativePrompt.length > 0) params.negative_prompt = negativePrompt.join(", ");
return params;
}

View File

@ -1,96 +0,0 @@
import { assert } from "https://deno.land/std@0.198.0/assert/assert.ts";
import { Schema, Store } from "./store.ts";
import { log } from "./deps.ts";
const db = await Deno.openKv();
log.setup({
handlers: {
console: new log.handlers.ConsoleHandler("DEBUG", {}),
},
loggers: {
kvStore: { level: "DEBUG", handlers: ["console"] },
},
});
interface PointSchema {
x: number;
y: number;
}
interface JobSchema {
name: string;
params: {
a: number;
b: number | null;
};
status: { type: "idle" } | { type: "processing"; progress: number } | { type: "done" };
lastUpdateDate: Date;
}
const pointStore = new Store(db, "points", {
schema: new Schema<PointSchema>(),
indices: ["x", "y"],
});
const jobStore = new Store(db, "jobs", {
schema: new Schema<JobSchema>(),
indices: ["name", "status.type"],
});
Deno.test("create and delete", async () => {
await pointStore.deleteAll();
const point1 = await pointStore.create({ x: 1, y: 2 });
const point2 = await pointStore.create({ x: 3, y: 4 });
assert((await pointStore.getAll()).length === 2);
const point3 = await pointStore.create({ x: 5, y: 6 });
assert((await pointStore.getAll()).length === 3);
assert((await pointStore.get(point2.id))?.value.y === 4);
await point1.delete();
assert((await pointStore.getAll()).length === 2);
await point2.delete();
await point3.delete();
assert((await pointStore.getAll()).length === 0);
});
Deno.test("list by index", async () => {
await jobStore.deleteAll();
const test = await jobStore.create({
name: "test",
params: { a: 1, b: null },
status: { type: "idle" },
lastUpdateDate: new Date(),
});
assert((await jobStore.getBy("name", "test"))[0].value.params.a === 1);
assert((await jobStore.getBy("status.type", "idle"))[0].value.params.a === 1);
await test.update({ status: { type: "processing", progress: 33 } });
assert((await jobStore.getBy("status.type", "processing"))[0].value.params.a === 1);
await test.update({ status: { type: "done" } });
assert((await jobStore.getBy("status.type", "done"))[0].value.params.a === 1);
assert((await jobStore.getBy("status.type", "processing")).length === 0);
await test.delete();
assert((await jobStore.getBy("status.type", "done")).length === 0);
assert((await jobStore.getBy("name", "test")).length === 0);
});
Deno.test("fail on concurrent update", async () => {
await jobStore.deleteAll();
const test = await jobStore.create({
name: "test",
params: { a: 1, b: null },
status: { type: "idle" },
lastUpdateDate: new Date(),
});
const result = await Promise.all([