forked from pinks/eris
big rewrite
This commit is contained in:
parent
ba2afe40ce
commit
6c66a00910
40
README.md
40
README.md
|
@ -11,12 +11,44 @@ Telegram bot for generating images from text.
|
|||
|
||||
You can put these in `.env` file or pass them as environment variables.
|
||||
|
||||
- `TG_BOT_TOKEN` - Telegram bot token (get yours from [@BotFather](https://t.me/BotFather))
|
||||
- `SD_API_URL` - URL to Stable Diffusion API (e.g. `http://127.0.0.1:7860/`)
|
||||
- `ADMIN_USERNAMES` - Comma separated list of usernames of users that can use admin commands
|
||||
(optional)
|
||||
- `TG_BOT_TOKEN` - Telegram bot token. Get yours from [@BotFather](https://t.me/BotFather).
|
||||
Required.
|
||||
- `SD_API_URL` - URL to Stable Diffusion API. Only used on first run. Default:
|
||||
`http://127.0.0.1:7860/`
|
||||
- `TG_ADMIN_USERS` - Comma separated list of usernames of users that can use admin commands. Only
|
||||
used on first run. Optional.
|
||||
|
||||
## Running
|
||||
|
||||
- Start stable diffusion webui: `cd sd-webui`, `./webui.sh --api`
|
||||
- Start bot: `deno task start`
|
||||
|
||||
## TODO
|
||||
|
||||
- [x] Keep generation history
|
||||
- [x] Changing params, parsing png info in request
|
||||
- [x] Cancelling jobs by deleting message
|
||||
- [x] Multiple parallel workers
|
||||
- [ ] Replying to another text message to copy prompt and generate
|
||||
- [ ] Replying to bot message, conversation in DMs
|
||||
- [ ] Replying to png message to extract png info nad generate
|
||||
- [ ] Banning tags
|
||||
- [ ] Img2Img + Upscale
|
||||
- [ ] Admin WebUI
|
||||
- [ ] User daily generation limits
|
||||
- [ ] Querying all generation history, displaying stats
|
||||
- [ ] Analyzing prompt quality based on tag csv
|
||||
- [ ] Report aliased/unknown tags based on csv
|
||||
- [ ] Report unknown loras
|
||||
- [ ] Investigate "sendMediaGroup failed"
|
||||
- [ ] Changing sampler without error on unknown sampler
|
||||
- [ ] Changing model
|
||||
- [ ] Inpaint using telegram photo edit
|
||||
- [ ] Outpaint
|
||||
- [ ] Non-SD (extras) upscale
|
||||
- [ ] Tiled generation to allow very big images
|
||||
- [ ] Downloading raw images
|
||||
- [ ] Extra prompt syntax, fixing `()+++` syntax
|
||||
- [ ] Translations
|
||||
- replace fmtDuration usage
|
||||
- replace formatOrdinal usage
|
||||
|
|
273
bot.ts
273
bot.ts
|
@ -1,273 +0,0 @@
|
|||
import { autoQuote, bold, Bot, Context, hydrateReply, log, ParseModeFlavor } from "./deps.ts";
|
||||
import { fmt } from "./intl.ts";
|
||||
import { getAllJobs, pushJob } from "./queue.ts";
|
||||
import { mySession, MySessionFlavor } from "./session.ts";
|
||||
|
||||
const logger = () => log.getLogger();
|
||||
|
||||
export type MyContext = ParseModeFlavor<Context> & MySessionFlavor;
|
||||
export const bot = new Bot<MyContext>(Deno.env.get("TG_BOT_TOKEN") ?? "");
|
||||
bot.use(autoQuote);
|
||||
bot.use(hydrateReply);
|
||||
bot.use(mySession);
|
||||
|
||||
// Automatically retry bot requests if we get a 429 error
|
||||
bot.api.config.use(async (prev, method, payload, signal) => {
|
||||
let remainingAttempts = 5;
|
||||
while (true) {
|
||||
const result = await prev(method, payload, signal);
|
||||
if (result.ok) return result;
|
||||
if (result.error_code !== 429 || remainingAttempts <= 0) return result;
|
||||
remainingAttempts -= 1;
|
||||
const retryAfterMs = (result.parameters?.retry_after ?? 30) * 1000;
|
||||
await new Promise((resolve) => setTimeout(resolve, retryAfterMs));
|
||||
}
|
||||
});
|
||||
|
||||
// if error happened, try to reply to the user with the error
|
||||
bot.use(async (ctx, next) => {
|
||||
try {
|
||||
await next();
|
||||
} catch (err) {
|
||||
try {
|
||||
await ctx.reply(`Handling update failed: ${err}`, {
|
||||
reply_to_message_id: ctx.message?.message_id,
|
||||
});
|
||||
} catch {
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
bot.api.setMyShortDescription("I can generate furry images from text");
|
||||
bot.api.setMyDescription(
|
||||
"I can generate furry images from text. " +
|
||||
"Send /txt2img to generate an image.",
|
||||
);
|
||||
bot.api.setMyCommands([
|
||||
{ command: "txt2img", description: "Generate an image" },
|
||||
{ command: "queue", description: "Show the current queue" },
|
||||
{ command: "sdparams", description: "Show the current SD parameters" },
|
||||
]);
|
||||
|
||||
bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to generate an image"));
|
||||
|
||||
bot.command("txt2img", async (ctx) => {
|
||||
if (!ctx.from?.id) {
|
||||
return ctx.reply("I don't know who you are");
|
||||
}
|
||||
const config = ctx.session.global;
|
||||
if (config.pausedReason != null) {
|
||||
return ctx.reply(`I'm paused: ${config.pausedReason || "No reason given"}`);
|
||||
}
|
||||
const jobs = await getAllJobs();
|
||||
if (jobs.length >= config.maxJobs) {
|
||||
return ctx.reply(
|
||||
`The queue is full. Try again later. (Max queue size: ${config.maxJobs})`,
|
||||
);
|
||||
}
|
||||
const jobCount = jobs.filter((job) => job.user.id === ctx.from.id).length;
|
||||
if (jobCount >= config.maxUserJobs) {
|
||||
return ctx.reply(
|
||||
`You already have ${config.maxUserJobs} jobs in queue. Try again later.`,
|
||||
);
|
||||
}
|
||||
if (!ctx.match) {
|
||||
return ctx.reply("Please describe what you want to see after the command");
|
||||
}
|
||||
const statusMessage = await ctx.reply("Accepted. You are now in queue.");
|
||||
await pushJob({
|
||||
params: { prompt: ctx.match },
|
||||
user: ctx.from,
|
||||
chat: ctx.chat,
|
||||
requestMessage: ctx.message,
|
||||
statusMessage,
|
||||
status: { type: "idle" },
|
||||
});
|
||||
logger().info("Job enqueued", ctx.from.first_name, ctx.chat.type, ctx.match.replace(/\s+/g, " "));
|
||||
});
|
||||
|
||||
bot.command("queue", async (ctx) => {
|
||||
let jobs = await getAllJobs();
|
||||
const getMessageText = () => {
|
||||
if (jobs.length === 0) return fmt`Queue is empty.`;
|
||||
const sortedJobs = [];
|
||||
let place = 0;
|
||||
for (const job of jobs) {
|
||||
if (job.status.type === "idle") place += 1;
|
||||
sortedJobs.push({ ...job, place });
|
||||
}
|
||||
return fmt`Current queue:\n\n${
|
||||
sortedJobs.map((job) =>
|
||||
fmt`${job.place}. ${bold(job.user.first_name)} in ${job.chat.type} chat ${
|
||||
job.status.type === "processing" ? `(${(job.status.progress * 100).toFixed(0)}%)` : ""
|
||||
}\n`
|
||||
)
|
||||
}`;
|
||||
};
|
||||
const message = await ctx.replyFmt(getMessageText());
|
||||
handleFutureUpdates();
|
||||
async function handleFutureUpdates() {
|
||||
for (let idx = 0; idx < 12; idx++) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 5000));
|
||||
jobs = await getAllJobs();
|
||||
const formattedMessage = getMessageText();
|
||||
await ctx.api.editMessageText(ctx.chat.id, message.message_id, formattedMessage.text, {
|
||||
entities: formattedMessage.entities,
|
||||
}).catch(() => undefined);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
bot.command("pause", (ctx) => {
|
||||
if (!ctx.from?.username) return;
|
||||
const config = ctx.session.global;
|
||||
if (!config.adminUsernames.includes(ctx.from.username)) return;
|
||||
if (config.pausedReason != null) {
|
||||
return ctx.reply(`Already paused: ${config.pausedReason}`);
|
||||
}
|
||||
config.pausedReason = ctx.match ?? "No reason given";
|
||||
return ctx.reply("Paused");
|
||||
});
|
||||
|
||||
bot.command("resume", (ctx) => {
|
||||
if (!ctx.from?.username) return;
|
||||
const config = ctx.session.global;
|
||||
if (!config.adminUsernames.includes(ctx.from.username)) return;
|
||||
if (config.pausedReason == null) return ctx.reply("Already running");
|
||||
config.pausedReason = null;
|
||||
return ctx.reply("Resumed");
|
||||
});
|
||||
|
||||
bot.command("setsdapiurl", async (ctx) => {
|
||||
if (!ctx.from?.username) return;
|
||||
const config = ctx.session.global;
|
||||
if (!config.adminUsernames.includes(ctx.from.username)) return;
|
||||
if (!ctx.match) return ctx.reply("Please specify an URL");
|
||||
let url: URL;
|
||||
try {
|
||||
url = new URL(ctx.match);
|
||||
} catch {
|
||||
return ctx.reply("Invalid URL");
|
||||
}
|
||||
let resp: Response;
|
||||
try {
|
||||
resp = await fetch(new URL("config", url));
|
||||
} catch (err) {
|
||||
return ctx.reply(`Could not connect: ${err}`);
|
||||
}
|
||||
if (!resp.ok) {
|
||||
return ctx.reply(`Could not connect: ${resp.status} ${resp.statusText}`);
|
||||
}
|
||||
let data: unknown;
|
||||
try {
|
||||
data = await resp.json();
|
||||
} catch {
|
||||
return ctx.reply("Invalid response from API");
|
||||
}
|
||||
if (data != null && typeof data === "object" && "version" in data) {
|
||||
config.sdApiUrl = url.toString();
|
||||
return ctx.reply(`Now using SD at ${url} running version ${data.version}`);
|
||||
} else {
|
||||
return ctx.reply("Invalid response from API");
|
||||
}
|
||||
});
|
||||
|
||||
bot.command("setsdparam", (ctx) => {
|
||||
if (!ctx.from?.username) return;
|
||||
const config = ctx.session.global;
|
||||
if (!config.adminUsernames.includes(ctx.from.username)) return;
|
||||
let [param = "", value] = ctx.match.split("=", 2).map((s) => s.trim());
|
||||
if (!param) return ctx.reply("Please specify a parameter");
|
||||
if (value == null) return ctx.reply("Please specify a value after the =");
|
||||
param = param.toLowerCase().replace(/[\s_]+/g, "");
|
||||
if (config.defaultParams == null) config.defaultParams = {};
|
||||
switch (param) {
|
||||
case "steps": {
|
||||
const steps = parseInt(value);
|
||||
if (isNaN(steps)) return ctx.reply("Invalid number value");
|
||||
if (steps > 100) return ctx.reply("Steps must be less than 100");
|
||||
if (steps < 10) return ctx.reply("Steps must be greater than 10");
|
||||
config.defaultParams.steps = steps;
|
||||
return ctx.reply("Steps set to " + steps);
|
||||
}
|
||||
case "detail":
|
||||
case "cfgscale": {
|
||||
const detail = parseInt(value);
|
||||
if (isNaN(detail)) return ctx.reply("Invalid number value");
|
||||
if (detail > 20) return ctx.reply("Detail must be less than 20");
|
||||
if (detail < 1) return ctx.reply("Detail must be greater than 1");
|
||||
config.defaultParams.cfg_scale = detail;
|
||||
return ctx.reply("Detail set to " + detail);
|
||||
}
|
||||
case "niter":
|
||||
case "niters": {
|
||||
const nIter = parseInt(value);
|
||||
if (isNaN(nIter)) return ctx.reply("Invalid number value");
|
||||
if (nIter > 10) return ctx.reply("Iterations must be less than 10");
|
||||
if (nIter < 1) return ctx.reply("Iterations must be greater than 1");
|
||||
config.defaultParams.n_iter = nIter;
|
||||
return ctx.reply("Iterations set to " + nIter);
|
||||
}
|
||||
case "batchsize": {
|
||||
const batchSize = parseInt(value);
|
||||
if (isNaN(batchSize)) return ctx.reply("Invalid number value");
|
||||
if (batchSize > 8) return ctx.reply("Batch size must be less than 8");
|
||||
if (batchSize < 1) return ctx.reply("Batch size must be greater than 1");
|
||||
config.defaultParams.batch_size = batchSize;
|
||||
return ctx.reply("Batch size set to " + batchSize);
|
||||
}
|
||||
case "size": {
|
||||
let [width, height] = value.split("x", 2).map((s) => parseInt(s.trim()));
|
||||
if (!width || !height || isNaN(width) || isNaN(height)) {
|
||||
return ctx.reply("Invalid size value");
|
||||
}
|
||||
if (width > 2048) return ctx.reply("Width must be less than 2048");
|
||||
if (height > 2048) return ctx.reply("Height must be less than 2048");
|
||||
// find closest multiple of 64
|
||||
width = Math.round(width / 64) * 64;
|
||||
height = Math.round(height / 64) * 64;
|
||||
if (width <= 0) return ctx.reply("Width too small");
|
||||
if (height <= 0) return ctx.reply("Height too small");
|
||||
config.defaultParams.width = width;
|
||||
config.defaultParams.height = height;
|
||||
return ctx.reply(`Size set to ${width}x${height}`);
|
||||
}
|
||||
case "negativeprompt": {
|
||||
config.defaultParams.negative_prompt = value;
|
||||
return ctx.reply(`Negative prompt set to: ${value}`);
|
||||
}
|
||||
default: {
|
||||
return ctx.reply("Invalid parameter");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
bot.command("sdparams", (ctx) => {
|
||||
if (!ctx.from?.username) return;
|
||||
const config = ctx.session.global;
|
||||
return ctx.replyFmt(
|
||||
fmt`Current config:\n\n${
|
||||
Object.entries(config.defaultParams ?? {}).map(([key, value]) =>
|
||||
fmt`${bold(key)} = ${String(value)}\n`
|
||||
)
|
||||
}`,
|
||||
);
|
||||
});
|
||||
|
||||
bot.command("crash", () => {
|
||||
throw new Error("Crash command used");
|
||||
});
|
||||
|
||||
bot.catch((err) => {
|
||||
let msg = "Error processing update";
|
||||
const { from, chat } = err.ctx;
|
||||
if (from?.first_name) msg += ` from ${from.first_name}`;
|
||||
if (from?.last_name) msg += ` ${from.last_name}`;
|
||||
if (from?.username) msg += ` (@${from.username})`;
|
||||
if (chat?.type === "supergroup" || chat?.type === "group") {
|
||||
msg += ` in ${chat.title}`;
|
||||
if (chat.type === "supergroup" && chat.username) msg += ` (@${chat.username})`;
|
||||
}
|
||||
logger().error("handling update failed", from?.first_name, chat?.type, err);
|
||||
});
|
|
@ -0,0 +1,92 @@
|
|||
import { Grammy, GrammyAutoQuote, GrammyParseMode, Log } from "../deps.ts";
|
||||
import { formatUserChat } from "../utils.ts";
|
||||
import { session, SessionFlavor } from "./session.ts";
|
||||
import { queueCommand } from "./queueCommand.ts";
|
||||
import { txt2imgCommand } from "./txt2imgCommand.ts";
|
||||
|
||||
export const logger = () => Log.getLogger();
|
||||
|
||||
export type Context = GrammyParseMode.ParseModeFlavor<Grammy.Context> & SessionFlavor;
|
||||
export const bot = new Grammy.Bot<Context>(Deno.env.get("TG_BOT_TOKEN") ?? "");
|
||||
bot.use(GrammyAutoQuote.autoQuote);
|
||||
bot.use(GrammyParseMode.hydrateReply);
|
||||
bot.use(session);
|
||||
|
||||
bot.catch((err) => {
|
||||
logger().error(`Handling update from ${formatUserChat(err.ctx)} failed: ${err}`);
|
||||
});
|
||||
|
||||
// Automatically retry bot requests if we get a "too many requests" or telegram internal error
|
||||
bot.api.config.use(async (prev, method, payload, signal) => {
|
||||
let attempt = 0;
|
||||
while (true) {
|
||||
attempt++;
|
||||
const result = await prev(method, payload, signal);
|
||||
if (
|
||||
result.ok ||
|
||||
![429, 500].includes(result.error_code) ||
|
||||
attempt >= 5
|
||||
) {
|
||||
return result;
|
||||
}
|
||||
const retryAfterMs = (result.parameters?.retry_after ?? (attempt * 5)) * 1000;
|
||||
await new Promise((resolve) => setTimeout(resolve, retryAfterMs));
|
||||
}
|
||||
});
|
||||
|
||||
// if error happened, try to reply to the user with the error
|
||||
bot.use(async (ctx, next) => {
|
||||
try {
|
||||
await next();
|
||||
} catch (err) {
|
||||
try {
|
||||
await ctx.reply(`Handling update failed: ${err}`, {
|
||||
reply_to_message_id: ctx.message?.message_id,
|
||||
});
|
||||
} catch {
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
bot.api.setMyShortDescription("I can generate furry images from text");
|
||||
bot.api.setMyDescription(
|
||||
"I can generate furry images from text. " +
|
||||
"Send /txt2img to generate an image.",
|
||||
);
|
||||
bot.api.setMyCommands([
|
||||
{ command: "txt2img", description: "Generate an image" },
|
||||
{ command: "queue", description: "Show the current queue" },
|
||||
]);
|
||||
|
||||
bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to generate an image"));
|
||||
|
||||
bot.command("txt2img", txt2imgCommand);
|
||||
|
||||
bot.command("queue", queueCommand);
|
||||
|
||||
bot.command("pause", (ctx) => {
|
||||
if (!ctx.from?.username) return;
|
||||
const config = ctx.session.global;
|
||||
if (!config.adminUsernames.includes(ctx.from.username)) return;
|
||||
if (config.pausedReason != null) {
|
||||
return ctx.reply(`Already paused: ${config.pausedReason}`);
|
||||
}
|
||||
config.pausedReason = ctx.match ?? "No reason given";
|
||||
logger().warning(`Bot paused by ${ctx.from.first_name} because ${config.pausedReason}`);
|
||||
return ctx.reply("Paused");
|
||||
});
|
||||
|
||||
bot.command("resume", (ctx) => {
|
||||
if (!ctx.from?.username) return;
|
||||
const config = ctx.session.global;
|
||||
if (!config.adminUsernames.includes(ctx.from.username)) return;
|
||||
if (config.pausedReason == null) return ctx.reply("Already running");
|
||||
config.pausedReason = null;
|
||||
logger().info(`Bot resumed by ${ctx.from.first_name}`);
|
||||
return ctx.reply("Resumed");
|
||||
});
|
||||
|
||||
bot.command("crash", () => {
|
||||
throw new Error("Crash command used");
|
||||
});
|
|
@ -0,0 +1,67 @@
|
|||
import { Grammy, GrammyParseMode } from "../deps.ts";
|
||||
import { fmt, getFlagEmoji } from "../utils.ts";
|
||||
import { runningWorkers } from "../tasks/pingWorkers.ts";
|
||||
import { jobStore } from "../db/jobStore.ts";
|
||||
import { Context, logger } from "./mod.ts";
|
||||
|
||||
export async function queueCommand(ctx: Grammy.CommandContext<Context>) {
|
||||
let formattedMessage = await getMessageText();
|
||||
const queueMessage = await ctx.replyFmt(formattedMessage);
|
||||
handleFutureUpdates().catch((err) => logger().warning(`Updating queue message failed: ${err}`));
|
||||
|
||||
async function getMessageText() {
|
||||
const processingJobs = await jobStore.getBy("status.type", "processing")
|
||||
.then((jobs) => jobs.map((job) => ({ ...job.value, place: 0 })));
|
||||
const waitingJobs = await jobStore.getBy("status.type", "waiting")
|
||||
.then((jobs) => jobs.map((job, index) => ({ ...job.value, place: index + 1 })));
|
||||
const jobs = [...processingJobs, ...waitingJobs];
|
||||
const config = ctx.session.global;
|
||||
const { bold } = GrammyParseMode;
|
||||
return fmt([
|
||||
"Current queue:\n",
|
||||
...jobs.length > 0
|
||||
? jobs.flatMap((job) => [
|
||||
`${job.place}. `,
|
||||
fmt`${bold(job.request.from.first_name)} `,
|
||||
job.request.from.last_name ? fmt`${bold(job.request.from.last_name)} ` : "",
|
||||
job.request.from.username ? `(@${job.request.from.username}) ` : "",
|
||||
getFlagEmoji(job.request.from.language_code) ?? "",
|
||||
job.request.chat.type === "private"
|
||||
? " in private chat "
|
||||
: ` in ${job.request.chat.title} `,
|
||||
job.request.chat.type !== "private" && job.request.chat.type !== "group" &&
|
||||
job.request.chat.username
|
||||
? `(@${job.request.chat.username}) `
|
||||
: "",
|
||||
job.status.type === "processing"
|
||||
? `(${(job.status.progress * 100).toFixed(0)}% using ${job.status.worker}) `
|
||||
: "",
|
||||
"\n",
|
||||
])
|
||||
: ["Queue is empty.\n"],
|
||||
"\nActive workers:\n",
|
||||
...config.workers.flatMap((worker) => [
|
||||
runningWorkers.has(worker.name) ? "✅ " : "☠️ ",
|
||||
fmt`${bold(worker.name)} `,
|
||||
`(max ${(worker.maxResolution / 1000000).toFixed(1)} Mpx) `,
|
||||
"\n",
|
||||
]),
|
||||
]);
|
||||
}
|
||||
|
||||
async function handleFutureUpdates() {
|
||||
for (let idx = 0; idx < 20; idx++) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 3000));
|
||||
const nextFormattedMessage = await getMessageText();
|
||||
if (nextFormattedMessage.text !== formattedMessage.text) {
|
||||
await ctx.api.editMessageText(
|
||||
ctx.chat.id,
|
||||
queueMessage.message_id,
|
||||
nextFormattedMessage.text,
|
||||
{ entities: nextFormattedMessage.entities },
|
||||
);
|
||||
formattedMessage = nextFormattedMessage;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,7 +1,7 @@
|
|||
import { Context, DenoKVAdapter, session, SessionFlavor } from "./deps.ts";
|
||||
import { SdTxt2ImgRequest } from "./sd.ts";
|
||||
import { Grammy, GrammyKvStorage } from "../deps.ts";
|
||||
import { SdApi, SdTxt2ImgRequest } from "../sd.ts";
|
||||
|
||||
export type MySessionFlavor = SessionFlavor<SessionData>;
|
||||
export type SessionFlavor = Grammy.SessionFlavor<SessionData>;
|
||||
|
||||
export interface SessionData {
|
||||
global: GlobalData;
|
||||
|
@ -12,45 +12,55 @@ export interface SessionData {
|
|||
export interface GlobalData {
|
||||
adminUsernames: string[];
|
||||
pausedReason: string | null;
|
||||
sdApiUrl: string;
|
||||
maxUserJobs: number;
|
||||
maxJobs: number;
|
||||
defaultParams?: Partial<SdTxt2ImgRequest>;
|
||||
workers: WorkerData[];
|
||||
}
|
||||
|
||||
export interface WorkerData {
|
||||
name: string;
|
||||
api: SdApi;
|
||||
auth?: string;
|
||||
maxResolution: number;
|
||||
}
|
||||
|
||||
export interface ChatData {
|
||||
language: string;
|
||||
language?: string;
|
||||
}
|
||||
|
||||
export interface UserData {
|
||||
steps: number;
|
||||
detail: number;
|
||||
batchSize: number;
|
||||
params?: Partial<SdTxt2ImgRequest>;
|
||||
}
|
||||
|
||||
const globalDb = await Deno.openKv("./app.db");
|
||||
|
||||
const globalDbAdapter = new DenoKVAdapter<GlobalData>(globalDb);
|
||||
const globalDbAdapter = new GrammyKvStorage.DenoKVAdapter<GlobalData>(globalDb);
|
||||
|
||||
const getDefaultGlobalData = (): GlobalData => ({
|
||||
adminUsernames: (Deno.env.get("ADMIN_USERNAMES") ?? "").split(",").filter(Boolean),
|
||||
adminUsernames: Deno.env.get("TG_ADMIN_USERS")?.split(",") ?? [],
|
||||
pausedReason: null,
|
||||
sdApiUrl: Deno.env.get("SD_API_URL") ?? "http://127.0.0.1:7860/",
|
||||
maxUserJobs: 3,
|
||||
maxJobs: 20,
|
||||
defaultParams: {
|
||||
batch_size: 1,
|
||||
n_iter: 1,
|
||||
width: 128 * 2,
|
||||
height: 128 * 3,
|
||||
steps: 20,
|
||||
cfg_scale: 9,
|
||||
send_images: true,
|
||||
width: 512,
|
||||
height: 768,
|
||||
steps: 30,
|
||||
cfg_scale: 10,
|
||||
negative_prompt: "boring_e621_fluffyrock_v4 boring_e621_v4",
|
||||
},
|
||||
workers: [
|
||||
{
|
||||
name: "local",
|
||||
api: { url: Deno.env.get("SD_API_URL") ?? "http://127.0.0.1:7860/" },
|
||||
maxResolution: 1024 * 1024,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
export const mySession = session<SessionData, Context & MySessionFlavor>({
|
||||
export const session = Grammy.session<SessionData, Grammy.Context & SessionFlavor>({
|
||||
type: "multi",
|
||||
global: {
|
||||
getSessionKey: () => "global",
|
||||
|
@ -58,17 +68,11 @@ export const mySession = session<SessionData, Context & MySessionFlavor>({
|
|||
storage: globalDbAdapter,
|
||||
},
|
||||
chat: {
|
||||
initial: () => ({
|
||||
language: "en",
|
||||
}),
|
||||
initial: () => ({}),
|
||||
},
|
||||
user: {
|
||||
getSessionKey: (ctx) => ctx.from?.id.toFixed(),
|
||||
initial: () => ({
|
||||
steps: 20,
|
||||
detail: 8,
|
||||
batchSize: 2,
|
||||
}),
|
||||
initial: () => ({}),
|
||||
},
|
||||
});
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
import { Grammy } from "../deps.ts";
|
||||
import { formatUserChat } from "../utils.ts";
|
||||
import { jobStore } from "../db/jobStore.ts";
|
||||
import { parsePngInfo } from "../sd.ts";
|
||||
import { Context, logger } from "./mod.ts";
|
||||
|
||||
export async function txt2imgCommand(ctx: Grammy.CommandContext<Context>) {
|
||||
if (!ctx.from?.id) {
|
||||
return ctx.reply("I don't know who you are");
|
||||
}
|
||||
const config = ctx.session.global;
|
||||
if (config.pausedReason != null) {
|
||||
return ctx.reply(`I'm paused: ${config.pausedReason || "No reason given"}`);
|
||||
}
|
||||
const jobs = await jobStore.getBy("status.type", "waiting");
|
||||
if (jobs.length >= config.maxJobs) {
|
||||
return ctx.reply(
|
||||
`The queue is full. Try again later. (Max queue size: ${config.maxJobs})`,
|
||||
);
|
||||
}
|
||||
const userJobs = jobs.filter((job) => job.value.request.from.id === ctx.from?.id);
|
||||
if (userJobs.length >= config.maxUserJobs) {
|
||||
return ctx.reply(
|
||||
`You already have ${config.maxUserJobs} jobs in queue. Try again later.`,
|
||||
);
|
||||
}
|
||||
const params = parsePngInfo(ctx.match);
|
||||
if (!params.prompt) {
|
||||
return ctx.reply("Please describe what you want to see after the command");
|
||||
}
|
||||
const reply = await ctx.reply("Accepted. You are now in queue.");
|
||||
await jobStore.create({
|
||||
params,
|
||||
request: ctx.message,
|
||||
reply,
|
||||
status: { type: "waiting" },
|
||||
});
|
||||
logger().debug(`Job enqueued for ${formatUserChat(ctx)}`);
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
import { GrammyTypes, IKV } from "../deps.ts";
|
||||
import { SdTxt2ImgInfo, SdTxt2ImgRequest } from "../sd.ts";
|
||||
import { db } from "./db.ts";
|
||||
|
||||
export interface JobSchema {
|
||||
params: Partial<SdTxt2ImgRequest>;
|
||||
request: GrammyTypes.Message.TextMessage & { from: GrammyTypes.User };
|
||||
reply?: GrammyTypes.Message.TextMessage;
|
||||
status:
|
||||
| { type: "waiting" }
|
||||
| { type: "processing"; progress: number; worker: string; updatedDate: Date }
|
||||
| { type: "done"; info?: SdTxt2ImgInfo; startDate?: Date; endDate?: Date };
|
||||
}
|
||||
|
||||
export const jobStore = new IKV.Store(db, "job", {
|
||||
schema: new IKV.Schema<JobSchema>(),
|
||||
indices: ["status.type"],
|
||||
});
|
25
deps.ts
25
deps.ts
|
@ -1,7 +1,18 @@
|
|||
export * from "https://deno.land/x/grammy@v1.18.1/mod.ts";
|
||||
export * from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts";
|
||||
export * from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts";
|
||||
export * from "https://deno.land/x/grammy_storages@v2.3.1/denokv/src/mod.ts";
|
||||
export * as types from "https://deno.land/x/grammy_types@v3.2.0/mod.ts";
|
||||
export * from "https://deno.land/x/ulid@v0.3.0/mod.ts";
|
||||
export * as log from "https://deno.land/std@0.201.0/log/mod.ts";
|
||||
export * as Log from "https://deno.land/std@0.201.0/log/mod.ts";
|
||||
export * as Async from "https://deno.land/std@0.201.0/async/mod.ts";
|
||||
export * as FmtDuration from "https://deno.land/std@0.201.0/fmt/duration.ts";
|
||||
export * as Collections from "https://deno.land/std@0.201.0/collections/mod.ts";
|
||||
export * as Base64 from "https://deno.land/std@0.201.0/encoding/base64.ts";
|
||||
export * as AsyncX from "https://deno.land/x/async@v2.0.2/mod.ts";
|
||||
export * as ULID from "https://deno.land/x/ulid@v0.3.0/mod.ts";
|
||||
export * as IKV from "https://deno.land/x/indexed_kv@v0.2.0/mod.ts";
|
||||
export * as Grammy from "https://deno.land/x/grammy@v1.18.1/mod.ts";
|
||||
export * as GrammyTypes from "https://deno.land/x/grammy_types@v3.2.0/mod.ts";
|
||||
export * as GrammyAutoQuote from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts";
|
||||
export * as GrammyParseMode from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts";
|
||||
export * as GrammyKvStorage from "https://deno.land/x/grammy_storages@v2.3.1/denokv/src/mod.ts";
|
||||
export * as FileType from "npm:file-type@18.5.0";
|
||||
// @deno-types="./types/png-chunks-extract.d.ts"
|
||||
export * as PngChunksExtract from "npm:png-chunks-extract@1.0.0";
|
||||
// @deno-types="./types/png-chunk-text.d.ts"
|
||||
export * as PngChunkText from "npm:png-chunk-text@1.0.0";
|
||||
|
|
43
intl.ts
43
intl.ts
|
@ -1,43 +0,0 @@
|
|||
import { FormattedString } from "./deps.ts";
|
||||
|
||||
export function formatOrdinal(n: number) {
|
||||
if (n % 100 === 11 || n % 100 === 12 || n % 100 === 13) return `${n}th`;
|
||||
if (n % 10 === 1) return `${n}st`;
|
||||
if (n % 10 === 2) return `${n}nd`;
|
||||
if (n % 10 === 3) return `${n}rd`;
|
||||
return `${n}th`;
|
||||
}
|
||||
|
||||
type DeepArray<T> = Array<T | DeepArray<T>>;
|
||||
type StringLikes = DeepArray<FormattedString | string | number | null | undefined>;
|
||||
|
||||
/**
|
||||
* Like `fmt` from `grammy_parse_mode` but additionally accepts arrays.
|
||||
* @see https://deno.land/x/grammy_parse_mode@1.7.1/format.ts?source=#L182
|
||||
*/
|
||||
export const fmt = (
|
||||
rawStringParts: TemplateStringsArray | StringLikes,
|
||||
...stringLikes: StringLikes
|
||||
): FormattedString => {
|
||||
let text = "";
|
||||
const entities: ConstructorParameters<typeof FormattedString>[1][] = [];
|
||||
|
||||
const length = Math.max(rawStringParts.length, stringLikes.length);
|
||||
for (let i = 0; i < length; i++) {
|
||||
for (let stringLike of [rawStringParts[i], stringLikes[i]]) {
|
||||
if (Array.isArray(stringLike)) {
|
||||
stringLike = fmt(stringLike);
|
||||
}
|
||||
if (stringLike instanceof FormattedString) {
|
||||
entities.push(
|
||||
...stringLike.entities.map((e) => ({
|
||||
...e,
|
||||
offset: e.offset + text.length,
|
||||
})),
|
||||
);
|
||||
}
|
||||
if (stringLike != null) text += stringLike.toString();
|
||||
}
|
||||
}
|
||||
return new FormattedString(text, entities);
|
||||
};
|
23
main.ts
23
main.ts
|
@ -1,24 +1,21 @@
|
|||
// Load environment variables from .env file
|
||||
import "https://deno.land/std@0.201.0/dotenv/load.ts";
|
||||
import { bot } from "./bot.ts";
|
||||
import { processQueue, returnHangedJobs } from "./queue.ts";
|
||||
import { log } from "./deps.ts";
|
||||
|
||||
log.setup({
|
||||
// Setup logging
|
||||
import { Log } from "./deps.ts";
|
||||
Log.setup({
|
||||
handlers: {
|
||||
console: new log.handlers.ConsoleHandler("INFO", {
|
||||
formatter: (record) =>
|
||||
`[${record.levelName}] ${record.msg} ${
|
||||
record.args.map((arg) => JSON.stringify(arg)).join(" ")
|
||||
} (${record.datetime.toISOString()})`,
|
||||
}),
|
||||
console: new Log.handlers.ConsoleHandler("DEBUG"),
|
||||
},
|
||||
loggers: {
|
||||
default: { level: "INFO", handlers: ["console"] },
|
||||
default: { level: "DEBUG", handlers: ["console"] },
|
||||
},
|
||||
});
|
||||
|
||||
// Main program logic
|
||||
import { bot } from "./bot/mod.ts";
|
||||
import { runAllTasks } from "./tasks/mod.ts";
|
||||
await Promise.all([
|
||||
bot.start(),
|
||||
processQueue(),
|
||||
returnHangedJobs(),
|
||||
runAllTasks(),
|
||||
]);
|
||||
|
|
15
mimeType.ts
15
mimeType.ts
|
@ -1,15 +0,0 @@
|
|||
export function mimeTypeFromBase64(base64: string) {
|
||||
if (base64.startsWith("/9j/")) return "image/jpeg";
|
||||
if (base64.startsWith("iVBORw0KGgo")) return "image/png";
|
||||
if (base64.startsWith("R0lGODlh")) return "image/gif";
|
||||
if (base64.startsWith("UklGRg")) return "image/webp";
|
||||
throw new Error("Unknown image type");
|
||||
}
|
||||
|
||||
export function extFromMimeType(mimeType: string) {
|
||||
if (mimeType === "image/jpeg") return "jpg";
|
||||
if (mimeType === "image/png") return "png";
|
||||
if (mimeType === "image/gif") return "gif";
|
||||
if (mimeType === "image/webp") return "webp";
|
||||
throw new Error("Unknown image type");
|
||||
}
|
181
queue.ts
181
queue.ts
|
@ -1,181 +0,0 @@
|
|||
import { InputFile, InputMediaBuilder, log, types } from "./deps.ts";
|
||||
import { bot } from "./bot.ts";
|
||||
import { getGlobalSession } from "./session.ts";
|
||||
import { formatOrdinal } from "./intl.ts";
|
||||
import { SdTxt2ImgRequest, SdTxt2ImgResponse, txt2img } from "./sd.ts";
|
||||
import { extFromMimeType, mimeTypeFromBase64 } from "./mimeType.ts";
|
||||
import { Model, Schema, Store } from "./store.ts";
|
||||
|
||||
const logger = () => log.getLogger();
|
||||
|
||||
interface Job {
|
||||
params: Partial<SdTxt2ImgRequest>;
|
||||
user: types.User;
|
||||
chat: types.Chat.PrivateChat | types.Chat.GroupChat | types.Chat.SupergroupChat;
|
||||
requestMessage: types.Message & types.Message.TextMessage;
|
||||
statusMessage?: types.Message & types.Message.TextMessage;
|
||||
status:
|
||||
| { type: "idle" }
|
||||
| { type: "processing"; progress: number; updatedDate: Date };
|
||||
}
|
||||
|
||||
const db = await Deno.openKv("./app.db");
|
||||
|
||||
const jobStore = new Store(db, "job", {
|
||||
schema: new Schema<Job>(),
|
||||
indices: ["status.type", "user.id", "chat.id"],
|
||||
});
|
||||
|
||||
jobStore.getBy("user.id", 123).then(() => {});
|
||||
|
||||
export async function pushJob(job: Job) {
|
||||
await jobStore.create(job);
|
||||
}
|
||||
|
||||
async function takeJob(): Promise<Model<Job> | null> {
|
||||
const jobs = await jobStore.getAll();
|
||||
const job = jobs.find((job) => job.value.status.type === "idle");
|
||||
if (!job) return null;
|
||||
await job.update({ status: { type: "processing", progress: 0, updatedDate: new Date() } });
|
||||
return job;
|
||||
}
|
||||
|
||||
export async function getAllJobs(): Promise<Array<Job>> {
|
||||
return await jobStore.getAll().then((jobs) => jobs.map((job) => job.value));
|
||||
}
|
||||
|
||||
export async function processQueue() {
|
||||
while (true) {
|
||||
const job = await takeJob().catch((err) =>
|
||||
void logger().warning("failed getting job", err.message)
|
||||
);
|
||||
if (!job) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 1000));
|
||||
continue;
|
||||
}
|
||||
let place = 0;
|
||||
for (const job of await jobStore.getAll().catch(() => [])) {
|
||||
if (job.value.status.type === "idle") place += 1;
|
||||
if (place === 0) continue;
|
||||
const statusMessageText = `You are ${formatOrdinal(place)} in queue.`;
|
||||
if (!job.value.statusMessage) {
|
||||
await bot.api.sendMessage(job.value.chat.id, statusMessageText, {
|
||||
reply_to_message_id: job.value.requestMessage.message_id,
|
||||
}).catch(() => undefined)
|
||||
.then((message) => job.update({ statusMessage: message })).catch(() => undefined);
|
||||
} else {
|
||||
await bot.api.editMessageText(
|
||||
job.value.chat.id,
|
||||
job.value.statusMessage.message_id,
|
||||
statusMessageText,
|
||||
)
|
||||
.catch(() => undefined);
|
||||
}
|
||||
}
|
||||
try {
|
||||
if (job.value.statusMessage) {
|
||||
await bot.api
|
||||
.deleteMessage(job.value.chat.id, job.value.statusMessage?.message_id)
|
||||
.catch(() => undefined)
|
||||
.then(() => job.update({ statusMessage: undefined }));
|
||||
}
|
||||
await bot.api.sendMessage(
|
||||
job.value.chat.id,
|
||||
"Generating your prompt now...",
|
||||
{ reply_to_message_id: job.value.requestMessage.message_id },
|
||||
).then((message) => job.update({ statusMessage: message }));
|
||||
const config = await getGlobalSession();
|
||||
const response = await txt2img(
|
||||
config.sdApiUrl,
|
||||
{ ...config.defaultParams, ...job.value.params },
|
||||
(progress) => {
|
||||
job.update({
|
||||
status: { type: "processing", progress: progress.progress, updatedDate: new Date() },
|
||||
});
|
||||
if (job.value.statusMessage) {
|
||||
bot.api
|
||||
.editMessageText(
|
||||
job.value.chat.id,
|
||||
job.value.statusMessage.message_id,
|
||||
`Generating your prompt now... ${
|
||||
Math.round(
|
||||
progress.progress * 100,
|
||||
)
|
||||
}%`,
|
||||
)
|
||||
.catch(() => undefined);
|
||||
}
|
||||
},
|
||||
);
|
||||
const jobCount = (await jobStore.getAll()).filter((job) =>
|
||||
job.value.status.type !== "processing"
|
||||
).length;
|
||||
logger().info("Job finished", job.value.user.first_name, job.value.chat.type, { jobCount });
|
||||
if (job.value.statusMessage) {
|
||||
await bot.api.editMessageText(
|
||||
job.value.chat.id,
|
||||
job.value.statusMessage.message_id,
|
||||
`Uploading your images...`,
|
||||
).catch(() => undefined);
|
||||
}
|
||||
const inputFiles = await Promise.all(
|
||||
response.images.map(async (imageBase64, idx) => {
|
||||
const mimeType = mimeTypeFromBase64(imageBase64);
|
||||
const imageBlob = await fetch(`data:${mimeType};base64,${imageBase64}`)
|
||||
.then((resp) => resp.blob());
|
||||
return InputMediaBuilder.photo(
|
||||
new InputFile(imageBlob, `image_${idx}.${extFromMimeType(mimeType)}`),
|
||||
);
|
||||
}),
|
||||
);
|
||||
if (job.value.statusMessage) {
|
||||
await bot.api
|
||||
.deleteMessage(job.value.chat.id, job.value.statusMessage.message_id)
|
||||
.catch(() => undefined).then(() => job.update({ statusMessage: undefined }));
|
||||
}
|
||||
await bot.api.sendMediaGroup(job.value.chat.id, inputFiles, {
|
||||
reply_to_message_id: job.value.requestMessage.message_id,
|
||||
});
|
||||
await job.delete();
|
||||
} catch (err) {
|
||||
logger().error("Job failed", job.value.user.first_name, job.value.chat.type, err);
|
||||
const errorMessage = await bot.api
|
||||
.sendMessage(job.value.chat.id, err.toString(), {
|
||||
reply_to_message_id: job.value.requestMessage.message_id,
|
||||
})
|
||||
.catch(() => undefined);
|
||||
if (errorMessage) {
|
||||
if (job.value.statusMessage) {
|
||||
await bot.api
|
||||
.deleteMessage(job.value.chat.id, job.value.statusMessage.message_id)
|
||||
.then(() => job.update({ statusMessage: undefined }))
|
||||
.catch(() => void logger().warning("failed deleting status message", err.message));
|
||||
}
|
||||
await job.update({ status: { type: "idle" } }).catch((err) =>
|
||||
void logger().warning("failed returning job", err.message)
|
||||
);
|
||||
} else {
|
||||
await job.delete().catch((err) =>
|
||||
void logger().warning("failed deleting job", err.message)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export async function returnHangedJobs() {
|
||||
while (true) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 5000));
|
||||
const jobs = await jobStore.getAll().catch(() => []);
|
||||
for (const job of jobs) {
|
||||
if (job.value.status.type !== "processing") continue;
|
||||
// if job wasn't updated for 1 minute, return it to the queue
|
||||
if (job.value.status.updatedDate.getTime() < Date.now() - 60 * 1000) {
|
||||
logger().warning("Hanged job returned", job.value.user.first_name, job.value.chat.type);
|
||||
await job.update({ status: { type: "idle" } }).catch((err) =>
|
||||
void logger().warning("failed returning job", err.message)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
294
sd.ts
294
sd.ts
|
@ -1,53 +1,88 @@
|
|||
export async function txt2img(
|
||||
apiUrl: string,
|
||||
import { Async, AsyncX, PngChunksExtract, PngChunkText } from "./deps.ts";
|
||||
|
||||
const neverSignal = new AbortController().signal;
|
||||
|
||||
export interface SdApi {
|
||||
url: string;
|
||||
auth?: string;
|
||||
}
|
||||
|
||||
async function fetchSdApi<T>(api: SdApi, endpoint: string, body?: unknown): Promise<T> {
|
||||
let options: RequestInit | undefined;
|
||||
if (body != null) {
|
||||
options = {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
...api.auth ? { Authorization: api.auth } : {},
|
||||
},
|
||||
body: JSON.stringify(body),
|
||||
};
|
||||
} else if (api.auth) {
|
||||
options = {
|
||||
headers: { Authorization: api.auth },
|
||||
};
|
||||
}
|
||||
const response = await fetch(new URL(endpoint, api.url), options).catch(() => {
|
||||
throw new SdApiError(endpoint, options, 0, "Network error");
|
||||
});
|
||||
const result = await response.json().catch(() => {
|
||||
throw new SdApiError(endpoint, options, response.status, response.statusText, {
|
||||
detail: "Invalid JSON",
|
||||
});
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new SdApiError(endpoint, options, response.status, response.statusText, result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
export async function sdTxt2Img(
|
||||
api: SdApi,
|
||||
params: Partial<SdTxt2ImgRequest>,
|
||||
onProgress?: (progress: SdProgressResponse) => void,
|
||||
signal?: AbortSignal,
|
||||
signal: AbortSignal = neverSignal,
|
||||
): Promise<SdTxt2ImgResponse> {
|
||||
let response: Response | undefined;
|
||||
let error: unknown;
|
||||
|
||||
fetch(new URL("sdapi/v1/txt2img", apiUrl), {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(params),
|
||||
}).then(
|
||||
(resp) => (response = resp),
|
||||
(err) => (error = err),
|
||||
);
|
||||
const request = fetchSdApi<SdTxt2ImgResponse>(api, "sdapi/v1/txt2img", params)
|
||||
// JSON field "info" is a JSON-serialized string so we need to parse this part second time
|
||||
.then((data) => ({
|
||||
...data,
|
||||
info: typeof data.info === "string" ? JSON.parse(data.info) : data.info,
|
||||
}));
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 3000));
|
||||
const progressRequest = await fetch(new URL("sdapi/v1/progress", apiUrl));
|
||||
if (progressRequest.ok) {
|
||||
const progress = (await progressRequest.json()) as SdProgressResponse;
|
||||
onProgress?.(progress);
|
||||
}
|
||||
if (response != null) {
|
||||
if (response.ok) {
|
||||
const result = (await response.json()) as SdTxt2ImgResponse;
|
||||
return result;
|
||||
} else {
|
||||
throw new Error(`Request failed: ${response.status} ${response.statusText}`);
|
||||
}
|
||||
}
|
||||
if (error != null) {
|
||||
throw error;
|
||||
}
|
||||
signal?.throwIfAborted();
|
||||
await Async.abortable(Promise.race([request, Async.delay(3000)]), signal);
|
||||
if (await AsyncX.promiseState(request) !== "pending") return await request;
|
||||
onProgress?.(await fetchSdApi<SdProgressResponse>(api, "sdapi/v1/progress"));
|
||||
}
|
||||
} finally {
|
||||
if (!response && !error) {
|
||||
await fetch(new URL("sdapi/v1/interrupt", apiUrl), { method: "POST" });
|
||||
if (await AsyncX.promiseState(request) === "pending") {
|
||||
await fetchSdApi(api, "sdapi/v1/interrupt", {});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export interface SdTxt2ImgRequest {
|
||||
enable_hr: boolean;
|
||||
denoising_strength: number;
|
||||
firstphase_width: number;
|
||||
firstphase_height: number;
|
||||
hr_scale: number;
|
||||
hr_upscaler: unknown;
|
||||
hr_second_pass_steps: number;
|
||||
hr_resize_x: number;
|
||||
hr_resize_y: number;
|
||||
hr_sampler_name: unknown;
|
||||
hr_prompt: string;
|
||||
hr_negative_prompt: string;
|
||||
prompt: string;
|
||||
styles: unknown;
|
||||
seed: number;
|
||||
subseed: number;
|
||||
subseed_strength: number;
|
||||
seed_resize_from_h: number;
|
||||
seed_resize_from_w: number;
|
||||
sampler_name: unknown;
|
||||
batch_size: number;
|
||||
n_iter: number;
|
||||
|
@ -55,16 +90,68 @@ export interface SdTxt2ImgRequest {
|
|||
cfg_scale: number;
|
||||
width: number;
|
||||
height: number;
|
||||
restore_faces: boolean;
|
||||
tiling: boolean;
|
||||
do_not_save_samples: boolean;
|
||||
do_not_save_grid: boolean;
|
||||
negative_prompt: string;
|
||||
eta: unknown;
|
||||
s_min_uncond: number;
|
||||
s_churn: number;
|
||||
s_tmax: unknown;
|
||||
s_tmin: number;
|
||||
s_noise: number;
|
||||
override_settings: object;
|
||||
override_settings_restore_afterwards: boolean;
|
||||
script_args: unknown[];
|
||||
sampler_index: string;
|
||||
script_name: unknown;
|
||||
send_images: boolean;
|
||||
save_images: boolean;
|
||||
alwayson_scripts: object;
|
||||
}
|
||||
|
||||
export interface SdTxt2ImgResponse {
|
||||
images: string[];
|
||||
parameters: SdTxt2ImgRequest;
|
||||
/** Contains serialized JSON */
|
||||
info: string;
|
||||
// Warning: raw response from API is a JSON-serialized string
|
||||
info: SdTxt2ImgInfo;
|
||||
}
|
||||
|
||||
export interface SdTxt2ImgInfo {
|
||||
prompt: string;
|
||||
all_prompts: string[];
|
||||
negative_prompt: string;
|
||||
all_negative_prompts: string[];
|
||||
seed: number;
|
||||
all_seeds: number[];
|
||||
subseed: number;
|
||||
all_subseeds: number[];
|
||||
subseed_strength: number;
|
||||
width: number;
|
||||
height: number;
|
||||
sampler_name: string;
|
||||
cfg_scale: number;
|
||||
steps: number;
|
||||
batch_size: number;
|
||||
restore_faces: boolean;
|
||||
face_restoration_model: unknown;
|
||||
sd_model_hash: string;
|
||||
seed_resize_from_w: number;
|
||||
seed_resize_from_h: number;
|
||||
denoising_strength: number;
|
||||
extra_generation_params: SdTxt2ImgInfoExtraParams;
|
||||
index_of_first_image: number;
|
||||
infotexts: string[];
|
||||
styles: unknown[];
|
||||
job_timestamp: string;
|
||||
clip_skip: number;
|
||||
is_using_inpainting_conditioning: boolean;
|
||||
}
|
||||
|
||||
export interface SdTxt2ImgInfoExtraParams {
|
||||
"Lora hashes": string;
|
||||
"TI hashes": string;
|
||||
}
|
||||
|
||||
export interface SdProgressResponse {
|
||||
|
@ -86,3 +173,138 @@ export interface SdProgressState {
|
|||
sampling_step: number;
|
||||
sampling_steps: number;
|
||||
}
|
||||
|
||||
export function sdGetConfig(api: SdApi): Promise<SdConfigResponse> {
|
||||
return fetchSdApi(api, "config");
|
||||
}
|
||||
|
||||
export interface SdConfigResponse {
|
||||
/** version with new line at the end for some reason */
|
||||
version: string;
|
||||
mode: string;
|
||||
dev_mode: boolean;
|
||||
analytics_enabled: boolean;
|
||||
components: object[];
|
||||
css: unknown;
|
||||
title: string;
|
||||
is_space: boolean;
|
||||
enable_queue: boolean;
|
||||
show_error: boolean;
|
||||
show_api: boolean;
|
||||
is_colab: boolean;
|
||||
stylesheets: unknown[];
|
||||
theme: string;
|
||||
layout: object;
|
||||
dependencies: object[];
|
||||
root: string;
|
||||
}
|
||||
|
||||
export interface SdErrorResponse {
|
||||
/**
|
||||
* The HTTP status message or array of invalid fields.
|
||||
* Can also be empty string.
|
||||
*/
|
||||
detail: string | Array<{ loc: string[]; msg: string; type: string }>;
|
||||
/** Can be e.g. "OutOfMemoryError" or undefined. */
|
||||
error?: string;
|
||||
/** Empty string. */
|
||||
body?: string;
|
||||
/** Long description of error. */
|
||||
errors?: string;
|
||||
}
|
||||
|
||||
export class SdApiError extends Error {
|
||||
constructor(
|
||||
public readonly endpoint: string,
|
||||
public readonly options: RequestInit | undefined,
|
||||
public readonly statusCode: number,
|
||||
public readonly statusText: string,
|
||||
public readonly response?: SdErrorResponse,
|
||||
) {
|
||||
let message = `${options?.method ?? "GET"} ${endpoint} : ${statusCode} ${statusText}`;
|
||||
if (response?.error) {
|
||||
message += `: ${response.error}`;
|
||||
if (response.errors) message += ` - ${response.errors}`;
|
||||
} else if (typeof response?.detail === "string" && response.detail.length > 0) {
|
||||
message += `: ${response.detail}`;
|
||||
} else if (response?.detail) {
|
||||
message += `: ${JSON.stringify(response.detail)}`;
|
||||
}
|
||||
super(message);
|
||||
}
|
||||
}
|
||||
|
||||
export function getPngInfo(pngData: Uint8Array): string | undefined {
|
||||
return PngChunksExtract.default(pngData)
|
||||
.filter((chunk) => chunk.name === "tEXt")
|
||||
.map((chunk) => PngChunkText.decode(chunk.data))
|
||||
.find((textChunk) => textChunk.keyword === "parameters")
|
||||
?.text;
|
||||
}
|
||||
|
||||
export function parsePngInfo(pngInfo: string): Partial<SdTxt2ImgRequest> {
|
||||
const tags = pngInfo.split(/[,;]+|\.+\s|\n/u);
|
||||
let part: "prompt" | "negative_prompt" | "params" = "prompt";
|
||||
const params: Partial<SdTxt2ImgRequest> = {};
|
||||
const prompt: string[] = [];
|
||||
const negativePrompt: string[] = [];
|
||||
for (const tag of tags) {
|
||||
const paramValuePair = tag.trim().match(/^(\w+\s*\w*):\s+([\d\w. ]+)\s*$/u);
|
||||
if (paramValuePair) {
|
||||
const [, param, value] = paramValuePair;
|
||||
switch (param.replace(/\s+/u, "").toLowerCase()) {
|
||||
case "prompt":
|
||||
part = "prompt";
|
||||
prompt.push(value.trim());
|
||||
break;
|
||||
case "negativeprompt":
|
||||
part = "negative_prompt";
|
||||
negativePrompt.push(value.trim());
|
||||
break;
|
||||
case "steps":
|
||||
case "cycles": {
|
||||
part = "params";
|
||||
const steps = Number(value.trim());
|
||||
if (steps > 0) params.steps = Math.min(steps, 50);
|
||||
break;
|
||||
}
|
||||
case "cfgscale":
|
||||
case "detail": {
|
||||
part = "params";
|
||||
const cfgScale = Number(value.trim());
|
||||
if (cfgScale > 0) params.cfg_scale = Math.min(cfgScale, 20);
|
||||
break;
|
||||
}
|
||||
case "size":
|
||||
case "resolution": {
|
||||
part = "params";
|
||||
const [width, height] = value.trim()
|
||||
.split(/\s*[x,]\s*/u, 2)
|
||||
.map((v) => v.trim())
|
||||
.map(Number);
|
||||
if (width > 0 && height > 0) {
|
||||
params.width = Math.min(width, 2048);
|
||||
params.height = Math.min(height, 2048);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
} else if (tag.trim().length > 0) {
|
||||
switch (part) {
|
||||
case "prompt":
|
||||
prompt.push(tag.trim());
|
||||
break;
|
||||
case "negative_prompt":
|
||||
negativePrompt.push(tag.trim());
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (prompt.length > 0) params.prompt = prompt.join(", ");
|
||||
if (negativePrompt.length > 0) params.negative_prompt = negativePrompt.join(", ");
|
||||
return params;
|
||||
}
|
||||
|
|
|
@ -1,96 +0,0 @@
|
|||
import { assert } from "https://deno.land/std@0.198.0/assert/assert.ts";
|
||||
import { Schema, Store } from "./store.ts";
|
||||
import { log } from "./deps.ts";
|
||||
|
||||
const db = await Deno.openKv();
|
||||
|
||||
log.setup({
|
||||
handlers: {
|
||||
console: new log.handlers.ConsoleHandler("DEBUG", {}),
|
||||
},
|
||||
loggers: {
|
||||
kvStore: { level: "DEBUG", handlers: ["console"] },
|
||||
},
|
||||
});
|
||||
|
||||
interface PointSchema {
|
||||
x: number;
|
||||
y: number;
|
||||
}
|
||||
|
||||
interface JobSchema {
|
||||
name: string;
|
||||
params: {
|
||||
a: number;
|
||||
b: number | null;
|
||||
};
|
||||
status: { type: "idle" } | { type: "processing"; progress: number } | { type: "done" };
|
||||
lastUpdateDate: Date;
|
||||
}
|
||||
|
||||
const pointStore = new Store(db, "points", {
|
||||
schema: new Schema<PointSchema>(),
|
||||
indices: ["x", "y"],
|
||||
});
|
||||
const jobStore = new Store(db, "jobs", {
|
||||
schema: new Schema<JobSchema>(),
|
||||
indices: ["name", "status.type"],
|
||||
});
|
||||
|
||||
Deno.test("create and delete", async () => {
|
||||
await pointStore.deleteAll();
|
||||
const point1 = await pointStore.create({ x: 1, y: 2 });
|
||||
const point2 = await pointStore.create({ x: 3, y: 4 });
|
||||
assert((await pointStore.getAll()).length === 2);
|
||||
const point3 = await pointStore.create({ x: 5, y: 6 });
|
||||
assert((await pointStore.getAll()).length === 3);
|
||||
assert((await pointStore.get(point2.id))?.value.y === 4);
|
||||
await point1.delete();
|
||||
assert((await pointStore.getAll()).length === 2);
|
||||
await point2.delete();
|
||||
await point3.delete();
|
||||
assert((await pointStore.getAll()).length === 0);
|
||||
});
|
||||
|
||||
Deno.test("list by index", async () => {
|
||||
await jobStore.deleteAll();
|
||||
|
||||
const test = await jobStore.create({
|
||||
name: "test",
|
||||
params: { a: 1, b: null },
|
||||
status: { type: "idle" },
|
||||
lastUpdateDate: new Date(),
|
||||
});
|
||||
assert((await jobStore.getBy("name", "test"))[0].value.params.a === 1);
|
||||
assert((await jobStore.getBy("status.type", "idle"))[0].value.params.a === 1);
|
||||
|
||||
await test.update({ status: { type: "processing", progress: 33 } });
|
||||
assert((await jobStore.getBy("status.type", "processing"))[0].value.params.a === 1);
|
||||
|
||||
await test.update({ status: { type: "done" } });
|
||||
assert((await jobStore.getBy("status.type", "done"))[0].value.params.a === 1);
|
||||
assert((await jobStore.getBy("status.type", "processing")).length === 0);
|
||||
|
||||
await test.delete();
|
||||
assert((await jobStore.getBy("status.type", "done")).length === 0);
|
||||
assert((await jobStore.getBy("name", "test")).length === 0);
|
||||
});
|
||||
|
||||
Deno.test("fail on concurrent update", async () => {
|
||||
await jobStore.deleteAll();
|
||||
|
||||
const test = await jobStore.create({
|
||||
name: "test",
|
||||
params: { a: 1, b: null },
|
||||
status: { type: "idle" },
|
||||
lastUpdateDate: new Date(),
|
||||
});
|
||||
|
||||
const result = await Promise.all([
|
||||
test.update({ status: { type: "processing", progress: 33 } }),
|
||||
test.update({ status: { type: "done" } }),
|
||||
]).catch(() => true);
|
||||
assert(result === true);
|
||||
|
||||
await test.delete();
|
||||
});
|
241
store.ts
241
store.ts
|
@ -1,241 +0,0 @@
|
|||
import { log, ulid } from "./deps.ts";
|
||||
|
||||
const logger = () => log.getLogger("kvStore");
|
||||
|
||||
export type validIndexKey<T> = {
|
||||
[K in keyof T]: K extends string ? (T[K] extends Deno.KvKeyPart ? K
|
||||
: T[K] extends readonly unknown[] ? never
|
||||
: T[K] extends object ? `${K}.${validIndexKey<T[K]>}`
|
||||
: never)
|
||||
: never;
|
||||
}[keyof T];
|
||||
|
||||
export type indexValue<T, I extends validIndexKey<T>> = I extends `${infer K}.${infer Rest}`
|
||||
? K extends keyof T ? Rest extends validIndexKey<T[K]> ? indexValue<T[K], Rest>
|
||||
: never
|
||||
: never
|
||||
: I extends keyof T ? T[I]
|
||||
: never;
|
||||
|
||||
export class Schema<T> {}
|
||||
|
||||
interface StoreOptions<T, I> {
|
||||
readonly schema: Schema<T>;
|
||||
readonly indices: readonly I[];
|
||||
}
|
||||
|
||||
export class Store<T, I extends validIndexKey<T>> {
|
||||
readonly #db: Deno.Kv;
|
||||
readonly #key: Deno.KvKeyPart;
|
||||
readonly #indices: readonly I[];
|
||||
|
||||
constructor(db: Deno.Kv, key: Deno.KvKeyPart, options: StoreOptions<T, I>) {
|
||||
this.#db = db;
|
||||
this.#key = key;
|
||||
this.#indices = options.indices;
|
||||
}
|
||||
|
||||
async create(value: T): Promise<Model<T>> {
|
||||
const id = ulid();
|
||||
await this.#db.set([this.#key, "id", id], value);
|
||||
logger().debug(["created", this.#key, "id", id].join(" "));
|
||||
for (const index of this.#indices) {
|
||||
const indexValue: Deno.KvKeyPart = index
|
||||
.split(".")
|
||||
.reduce((value, key) => value[key], value as any);
|
||||
await this.#db.set([this.#key, index, indexValue, id], value);
|
||||
logger().debug(["created", this.#key, index, indexValue, id].join(" "));
|
||||
}
|
||||
return new Model(this.#db, this.#key, this.#indices, id, value);
|
||||
}
|
||||
|
||||
async get(id: Deno.KvKeyPart): Promise<Model<T> | null> {
|
||||
const entry = await this.#db.get<T>([this.#key, "id", id]);
|
||||
if (entry.versionstamp == null) return null;
|
||||
return new Model(this.#db, this.#key, this.#indices, id, entry.value);
|
||||
}
|
||||
|
||||
async getBy<J extends I>(
|
||||
index: J,
|
||||
value: indexValue<T, J>,
|
||||
options?: Deno.KvListOptions,
|
||||
): Promise<Array<Model<T>>> {
|
||||
const models: Model<T>[] = [];
|
||||
for await (
|
||||
const entry of this.#db.list<T>(
|
||||
{ prefix: [this.#key, index, value as Deno.KvKeyPart] },
|
||||
options,
|
||||
)
|
||||
) {
|
||||
models.push(new Model(this.#db, this.#key, this.#indices, entry.key[3], entry.value));
|
||||
}
|
||||
return models;
|
||||
}
|
||||
|
||||
async getAll(
|
||||
opts?: { limit?: number; reverse?: boolean },
|
||||
): Promise<Array<Model<T>>> {
|
||||
const { limit, reverse } = opts ?? {};
|
||||
const models: Array<Model<T>> = [];
|
||||
for await (
|
||||
const entry of this.#db.list<T>({
|
||||
prefix: [this.#key, "id"],
|
||||
}, { limit, reverse })
|
||||
) {
|
||||
models.push(new Model(this.#db, this.#key, this.#indices, entry.key[2], entry.value));
|
||||
}
|
||||
return models;
|
||||
}
|
||||
|
||||
async deleteAll(): Promise<void> {
|
||||
for await (const entry of this.#db.list({ prefix: [this.#key] })) {
|
||||
await this.#db.delete(entry.key);
|
||||
logger().debug(["deleted", ...entry.key].join(" "));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class Model<T> {
|
||||
readonly #db: Deno.Kv;
|
||||
readonly #key: Deno.KvKeyPart;
|
||||
readonly #indices: readonly string[];
|
||||
readonly #id: Deno.KvKeyPart;
|
||||
value: T;
|
||||
|
||||
constructor(
|
||||
db: Deno.Kv,
|
||||
key: Deno.KvKeyPart,
|
||||
indices: readonly string[],
|
||||
id: Deno.KvKeyPart,
|
||||
value: T,
|
||||
) {
|
||||
this.#db = db;
|
||||
this.#key = key;
|
||||
this.#indices = indices;
|
||||
this.#id = id;
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
get id(): Deno.KvKeyPart {
|
||||
return this.#id;
|
||||
}
|
||||
|
||||
async update(updater: Partial<T> | ((value: T) => T)): Promise<T | null> {
|
||||
// get current main entry
|
||||
const oldEntry = await this.#db.get<T>([this.#key, "id", this.#id]);
|
||||
|
||||
// get all current index entries
|
||||
const oldIndexEntries: Record<string, Deno.KvEntryMaybe<T>> = {};
|
||||
for (const index of this.#indices) {
|
||||
const indexKey: Deno.KvKeyPart = index
|
||||
.split(".")
|
||||
.reduce((value, key) => value[key], oldEntry.value as any);
|
||||
oldIndexEntries[index] = await this.#db.get<T>([this.#key, index, indexKey, this.#id]);
|
||||
}
|
||||
|
||||
// compute new value
|
||||
if (typeof updater === "function") {
|
||||
this.value = updater(this.value);
|
||||
} else {
|
||||
this.value = { ...this.value, ...updater };
|
||||
}
|
||||
|
||||
// begin transaction
|
||||
const transaction = this.#db.atomic();
|
||||
|
||||
// set the main entry
|
||||
transaction
|
||||
.check(oldEntry)
|
||||
.set([this.#key, "id", this.#id], this.value);
|
||||
logger().debug(["updated", this.#key, "id", this.#id].join(" "));
|
||||
|
||||
// delete and create all changed index entries
|
||||
for (const index of this.#indices) {
|
||||
const oldIndexKey: Deno.KvKeyPart = index
|
||||
.split(".")
|
||||
.reduce((value, key) => value[key], oldIndexEntries[index].value as any);
|
||||
const newIndexKey: Deno.KvKeyPart = index
|
||||
.split(".")
|
||||
.reduce((value, key) => value[key], this.value as any);
|
||||
if (newIndexKey !== oldIndexKey) {
|
||||
transaction
|
||||
.check(oldIndexEntries[index])
|
||||
.delete([this.#key, index, oldIndexKey, this.#id])
|
||||
.set([this.#key, index, newIndexKey, this.#id], this.value);
|
||||
logger().debug(["deleted", this.#key, index, oldIndexKey, this.#id].join(" "));
|
||||
logger().debug(["created", this.#key, index, newIndexKey, this.#id].join(" "));
|
||||
}
|
||||
}
|
||||
|
||||
// commit
|
||||
const result = await transaction.commit();
|
||||
if (!result.ok) throw new Error(`Failed to update ${this.#key} ${this.#id}`);
|
||||
return this.value;
|
||||
}
|
||||
|
||||
async delete(): Promise<void> {
|
||||
// get current main entry
|
||||
const entry = await this.#db.get<T>([this.#key, "id", this.#id]);
|
||||
|
||||
// begin transaction
|
||||
const transaction = this.#db.atomic();
|
||||
|
||||
// delete main entry
|
||||
transaction
|
||||
.check(entry)
|
||||
.delete([this.#key, "id", this.#id]);
|
||||
logger().debug(["deleted", this.#key, "id", this.#id].join(" "));
|
||||
|
||||
// delete all index entries
|
||||
for (const index of this.#indices) {
|
||||
const indexKey: Deno.KvKeyPart = index
|
||||
.split(".")
|
||||
.reduce((value, key) => value[key], entry.value as any);
|
||||
transaction
|
||||
.delete([this.#key, index, indexKey, this.#id]);
|
||||
logger().debug(["deleted", this.#key, index, indexKey, this.#id].join(" "));
|
||||
}
|
||||
|
||||
// commit
|
||||
const result = await transaction.commit();
|
||||
if (!result.ok) throw new Error(`Failed to delete ${this.#key} ${this.#id}`);
|
||||
}
|
||||
}
|
||||
|
||||
export async function retry<T>(
|
||||
fn: () => Promise<T>,
|
||||
options: { maxAttempts?: number; delayMs?: number } = {},
|
||||
): Promise<T> {
|
||||
const { maxAttempts = 3, delayMs = 1000 } = options;
|
||||
let error: unknown;
|
||||
for (let attempt = 0; attempt < maxAttempts; attempt++) {
|
||||
try {
|
||||
return await fn();
|
||||
} catch (err) {
|
||||
error = err;
|
||||
await new Promise((resolve) => setTimeout(resolve, delayMs));
|
||||
}
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
|
||||
export async function collectIterator<T>(
|
||||
iterator: AsyncIterableIterator<T>,
|
||||
options: { maxItems?: number; timeoutMs?: number } = {},
|
||||
): Promise<T[]> {
|
||||
const { maxItems = 1000, timeoutMs = 2000 } = options;
|
||||
const result: T[] = [];
|
||||
const timeout = setTimeout(() => iterator.return?.(), timeoutMs);
|
||||
try {
|
||||
for await (const item of iterator) {
|
||||
result.push(item);
|
||||
if (result.length >= maxItems) {
|
||||
iterator.return?.();
|
||||
break;
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
clearTimeout(timeout);
|
||||
}
|
||||
return result;
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
import { pingWorkers } from "./pingWorkers.ts";
|
||||
import { processJobs } from "./processJobs.ts";
|
||||
import { returnHangedJobs } from "./returnHangedJobs.ts";
|
||||
import { updateJobStatusMsgs } from "./updateJobStatusMsgs.ts";
|
||||
|
||||
export async function runAllTasks() {
|
||||
await Promise.all([
|
||||
processJobs(),
|
||||
updateJobStatusMsgs(),
|
||||
returnHangedJobs(),
|
||||
pingWorkers(),
|
||||
]);
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
import { Async, Log } from "../deps.ts";
|
||||
import { getGlobalSession } from "../bot/session.ts";
|
||||
import { sdGetConfig } from "../sd.ts";
|
||||
|
||||
const logger = () => Log.getLogger();
|
||||
|
||||
export const runningWorkers = new Set<string>();
|
||||
|
||||
/**
|
||||
* Periodically ping the workers to see if they are alive.
|
||||
*/
|
||||
export async function pingWorkers(): Promise<never> {
|
||||
while (true) {
|
||||
try {
|
||||
const config = await getGlobalSession();
|
||||
for (const worker of config.workers) {
|
||||
const status = await sdGetConfig(worker.api).catch(() => null);
|
||||
const wasRunning = runningWorkers.has(worker.name);
|
||||
if (status) {
|
||||
runningWorkers.add(worker.name);
|
||||
if (!wasRunning) logger().info(`Worker ${worker.name} is online`);
|
||||
} else {
|
||||
runningWorkers.delete(worker.name);
|
||||
if (wasRunning) logger().warning(`Worker ${worker.name} went offline`);
|
||||
}
|
||||
}
|
||||
await Async.delay(60 * 1000);
|
||||
} catch (err) {
|
||||
logger().warning(`Pinging workers failed: ${err}`);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,221 @@
|
|||
import { Base64, FileType, FmtDuration, Grammy, GrammyParseMode, IKV, Log } from "../deps.ts";
|
||||
import { bot } from "../bot/mod.ts";
|
||||
import { getGlobalSession, GlobalData, WorkerData } from "../bot/session.ts";
|
||||
import { fmt, formatUserChat } from "../utils.ts";
|
||||
import { SdApiError, sdTxt2Img } from "../sd.ts";
|
||||
import { JobSchema, jobStore } from "../db/jobStore.ts";
|
||||
import { runningWorkers } from "./pingWorkers.ts";
|
||||
|
||||
const logger = () => Log.getLogger();
|
||||
|
||||
/**
|
||||
* Sends waiting jobs to workers.
|
||||
*/
|
||||
export async function processJobs(): Promise<never> {
|
||||
const busyWorkers = new Set<string>();
|
||||
while (true) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 1000));
|
||||
|
||||
try {
|
||||
// get first waiting job
|
||||
const job = await jobStore.getBy("status.type", "waiting").then((jobs) => jobs[0]);
|
||||
if (!job) continue;
|
||||
|
||||
// find a worker to handle the job
|
||||
const config = await getGlobalSession();
|
||||
const worker = config.workers?.find((worker) =>
|
||||
runningWorkers.has(worker.name) &&
|
||||
!busyWorkers.has(worker.name)
|
||||
);
|
||||
if (!worker) continue;
|
||||
|
||||
// process the job
|
||||
await job.update({
|
||||
status: { type: "processing", progress: 0, worker: worker.name, updatedDate: new Date() },
|
||||
});
|
||||
|
||||
busyWorkers.add(worker.name);
|
||||
processJob(job, worker, config)
|
||||
.catch(async (err) => {
|
||||
logger().error(
|
||||
`Job failed for ${formatUserChat(job.value.request)} via ${worker.name}: ${err}`,
|
||||
);
|
||||
if (err instanceof Grammy.GrammyError || err instanceof SdApiError) {
|
||||
await bot.api.sendMessage(
|
||||
job.value.request.chat.id,
|
||||
`Failed to generate your prompt: ${err.message}`,
|
||||
{ reply_to_message_id: job.value.request.message_id },
|
||||
).catch(() => undefined);
|
||||
await job.update({ status: { type: "waiting" } }).catch(() => undefined);
|
||||
}
|
||||
if (
|
||||
err instanceof SdApiError &&
|
||||
(err.statusCode === 0 /* Network error */ || err.statusCode === 404)
|
||||
) {
|
||||
runningWorkers.delete(worker.name);
|
||||
logger().warning(
|
||||
`Worker ${worker.name} was marked as offline because of network error`,
|
||||
);
|
||||
}
|
||||
await job.delete().catch(() => undefined);
|
||||
await jobStore.create(job.value);
|
||||
})
|
||||
.finally(() => busyWorkers.delete(worker.name));
|
||||
} catch (err) {
|
||||
logger().warning(`Processing jobs failed: ${err}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config: GlobalData) {
|
||||
logger().debug(
|
||||
`Job started for ${formatUserChat(job.value.request)} using ${worker.name}`,
|
||||
);
|
||||
const startDate = new Date();
|
||||
|
||||
// if there is already a status message delete it
|
||||
if (job.value.reply) {
|
||||
await bot.api.deleteMessage(job.value.reply.chat.id, job.value.reply.message_id)
|
||||
.catch(() => undefined);
|
||||
}
|
||||
|
||||
// send a new status message
|
||||
const newStatusMessage = await bot.api.sendMessage(
|
||||
job.value.request.chat.id,
|
||||
`Generating your prompt now... 0% using ${worker.name}`,
|
||||
{ reply_to_message_id: job.value.request.message_id },
|
||||
).catch((err) => {
|
||||
// don't error if the request message was deleted
|
||||
if (err instanceof Grammy.GrammyError && err.message.match(/repl(y|ied)/)) return null;
|
||||
else throw err;
|
||||
});
|
||||
// if the request message was deleted, cancel the job
|
||||
if (!newStatusMessage) {
|
||||
await job.delete();
|
||||
logger().info(
|
||||
`Job cancelled for ${formatUserChat(job.value.request)}`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
await job.update({ reply: newStatusMessage });
|
||||
|
||||
// reduce size if worker can't handle the resolution
|
||||
const size = limitSize({ ...config.defaultParams, ...job.value.params }, worker.maxResolution);
|
||||
|
||||
// process the job
|
||||
const response = await sdTxt2Img(
|
||||
worker.api,
|
||||
{ ...config.defaultParams, ...job.value.params, ...size },
|
||||
async (progress) => {
|
||||
// important: don't let any errors escape this callback
|
||||
if (job.value.reply) {
|
||||
await bot.api.editMessageText(
|
||||
job.value.reply.chat.id,
|
||||
job.value.reply.message_id,
|
||||
`Generating your prompt now... ${
|
||||
(progress.progress * 100).toFixed(0)
|
||||
}% using ${worker.name}`,
|
||||
).catch(() => undefined);
|
||||
}
|
||||
await job.update({
|
||||
status: {
|
||||
type: "processing",
|
||||
progress: progress.progress,
|
||||
worker: worker.name,
|
||||
updatedDate: new Date(),
|
||||
},
|
||||
}).catch(() => undefined);
|
||||
},
|
||||
);
|
||||
|
||||
// upload the result
|
||||
if (job.value.reply) {
|
||||
await bot.api.editMessageText(
|
||||
job.value.reply.chat.id,
|
||||
job.value.reply.message_id,
|
||||
`Uploading your images...`,
|
||||
).catch(() => undefined);
|
||||
}
|
||||
|
||||
// render the caption
|
||||
// const detailedReply = Object.keys(job.value.params).filter((key) => key !== "prompt").length > 0;
|
||||
const detailedReply = true;
|
||||
const jobDurationMs = Math.trunc((Date.now() - startDate.getTime()) / 1000) * 1000;
|
||||
const { bold } = GrammyParseMode;
|
||||
const caption = fmt([
|
||||
`${response.info.prompt}\n`,
|
||||
...detailedReply
|
||||
? [
|
||||
response.info.negative_prompt
|
||||
? fmt`${bold("Negative prompt:")} ${response.info.negative_prompt}\n`
|
||||
: "",
|
||||
fmt`${bold("Steps:")} ${response.info.steps}, `,
|
||||
fmt`${bold("Sampler:")} ${response.info.sampler_name}, `,
|
||||
fmt`${bold("CFG scale:")} ${response.info.cfg_scale}, `,
|
||||
fmt`${bold("Seed:")} ${response.info.seed}, `,
|
||||
fmt`${bold("Size")}: ${response.info.width}x${response.info.height}, `,
|
||||
fmt`${bold("Worker")}: ${worker.name}, `,
|
||||
fmt`${bold("Time taken")}: ${FmtDuration.format(jobDurationMs, { ignoreZero: true })}`,
|
||||
]
|
||||
: [],
|
||||
]);
|
||||
|
||||
// parse files from reply JSON
|
||||
const inputFiles = await Promise.all(
|
||||
response.images.map(async (imageBase64, idx) => {
|
||||
const imageBuffer = Base64.decode(imageBase64);
|
||||
const imageType = await FileType.fileTypeFromBuffer(imageBuffer);
|
||||
if (!imageType) throw new Error("Unknown file type returned from worker");
|
||||
return Grammy.InputMediaBuilder.photo(
|
||||
new Grammy.InputFile(imageBuffer, `image${idx}.${imageType.ext}`),
|
||||
// if it can fit, add caption for first photo
|
||||
idx === 0 && caption.text.length <= 1024
|
||||
? { caption: caption.text, caption_entities: caption.entities }
|
||||
: undefined,
|
||||
);
|
||||
}),
|
||||
);
|
||||
|
||||
// send the result to telegram
|
||||
const resultMessage = await bot.api.sendMediaGroup(job.value.request.chat.id, inputFiles, {
|
||||
reply_to_message_id: job.value.request.message_id,
|
||||
});
|
||||
// send caption in separate message if it couldn't fit
|
||||
if (caption.text.length > 1024 && caption.text.length <= 4096) {
|
||||
await bot.api.sendMessage(job.value.request.chat.id, caption.text, {
|
||||
reply_to_message_id: resultMessage[0].message_id,
|
||||
entities: caption.entities,
|
||||
});
|
||||
}
|
||||
|
||||
// delete the status message
|
||||
if (job.value.reply) {
|
||||
await bot.api.deleteMessage(job.value.reply.chat.id, job.value.reply.message_id)
|
||||
.catch(() => undefined)
|
||||
.then(() => job.update({ reply: undefined }))
|
||||
.catch(() => undefined);
|
||||
}
|
||||
|
||||
// update job to status done
|
||||
await job.update({
|
||||
status: { type: "done", info: response.info, startDate, endDate: new Date() },
|
||||
});
|
||||
logger().debug(
|
||||
`Job finished for ${formatUserChat(job.value.request)} using ${worker.name}`,
|
||||
);
|
||||
}
|
||||
|
||||
function limitSize(
|
||||
{ width, height }: { width?: number; height?: number },
|
||||
maxResolution: number,
|
||||
): { width?: number; height?: number } {
|
||||
if (!width || !height) return {};
|
||||
const ratio = width / height;
|
||||
if (width * height > maxResolution) {
|
||||
return {
|
||||
width: Math.trunc(Math.sqrt(maxResolution * ratio)),
|
||||
height: Math.trunc(Math.sqrt(maxResolution / ratio)),
|
||||
};
|
||||
}
|
||||
return { width, height };
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
import { FmtDuration, Log } from "../deps.ts";
|
||||
import { formatUserChat } from "../utils.ts";
|
||||
import { jobStore } from "../db/jobStore.ts";
|
||||
|
||||
const logger = () => Log.getLogger();
|
||||
|
||||
/**
|
||||
* Returns hanged jobs to the queue.
|
||||
*/
|
||||
export async function returnHangedJobs(): Promise<never> {
|
||||
while (true) {
|
||||
try {
|
||||
await new Promise((resolve) => setTimeout(resolve, 5000));
|
||||
const jobs = await jobStore.getBy("status.type", "processing");
|
||||
for (const job of jobs) {
|
||||
if (job.value.status.type !== "processing") continue;
|
||||
// if job wasn't updated for 1 minute, return it to the queue
|
||||
const timeSinceLastUpdateMs = Date.now() - job.value.status.updatedDate.getTime();
|
||||
if (timeSinceLastUpdateMs > 60 * 1000) {
|
||||
await job.update({ status: { type: "waiting" } });
|
||||
logger().warning(
|
||||
`Job for ${
|
||||
formatUserChat(job.value.request)
|
||||
} was returned to the queue because it hanged for ${
|
||||
FmtDuration.format(Math.trunc(timeSinceLastUpdateMs / 1000) * 1000, {
|
||||
ignoreZero: true,
|
||||
})
|
||||
}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
logger().warning(`Returning hanged jobs failed: ${err}`);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
import { Log } from "../deps.ts";
|
||||
import { bot } from "../bot/mod.ts";
|
||||
import { formatOrdinal } from "../utils.ts";
|
||||
import { jobStore } from "../db/jobStore.ts";
|
||||
|
||||
const logger = () => Log.getLogger();
|
||||
|
||||
/**
|
||||
* Updates status messages for jobs in the queue.
|
||||
*/
|
||||
export async function updateJobStatusMsgs(): Promise<never> {
|
||||
while (true) {
|
||||
try {
|
||||
await new Promise((resolve) => setTimeout(resolve, 5000));
|
||||
const jobs = await jobStore.getBy("status.type", "waiting");
|
||||
for (const [index, job] of jobs.entries()) {
|
||||
if (!job.value.reply) continue;
|
||||
await bot.api.editMessageText(
|
||||
job.value.reply.chat.id,
|
||||
job.value.reply.message_id,
|
||||
`You are ${formatOrdinal(index + 1)} in queue.`,
|
||||
).catch(() => undefined);
|
||||
}
|
||||
} catch (err) {
|
||||
logger().warning(`Updating job status messages failed: ${err}`);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
export function decode(chunk: Uint8Array): { keyword: string; text: string };
|
||||
export function encode(keyword: string, text: string): Uint8Array;
|
|
@ -0,0 +1 @@
|
|||
export default function encode(chunks: Array<{ name: string; data: Uint8Array }>): Uint8Array;
|
|
@ -0,0 +1 @@
|
|||
export default function extract(data: Uint8Array): Array<{ name: string; data: Uint8Array }>;
|
|
@ -0,0 +1,111 @@
|
|||
import { GrammyParseMode, GrammyTypes } from "./deps.ts";
|
||||
|
||||
export function formatOrdinal(n: number) {
|
||||
if (n % 100 === 11 || n % 100 === 12 || n % 100 === 13) return `${n}th`;
|
||||
if (n % 10 === 1) return `${n}st`;
|
||||
if (n % 10 === 2) return `${n}nd`;
|
||||
if (n % 10 === 3) return `${n}rd`;
|
||||
return `${n}th`;
|
||||
}
|
||||
|
||||
export const fmt = (
|
||||
rawStringParts: TemplateStringsArray | GrammyParseMode.Stringable[],
|
||||
...stringLikes: GrammyParseMode.Stringable[]
|
||||
): GrammyParseMode.FormattedString => {
|
||||
let text = "";
|
||||
const entities: GrammyTypes.MessageEntity[] = [];
|
||||
|
||||
const length = Math.max(rawStringParts.length, stringLikes.length);
|
||||
for (let i = 0; i < length; i++) {
|
||||
for (const stringLike of [rawStringParts[i], stringLikes[i]]) {
|
||||
if (stringLike instanceof GrammyParseMode.FormattedString) {
|
||||
entities.push(
|
||||
...stringLike.entities.map((e) => ({
|
||||
...e,
|
||||
offset: e.offset + text.length,
|
||||
})),
|
||||
);
|
||||
}
|
||||
if (stringLike != null) text += stringLike.toString();
|
||||
}
|
||||
}
|
||||
return new GrammyParseMode.FormattedString(text, entities);
|
||||
};
|
||||
|
||||
export function formatUserChat(ctx: { from?: GrammyTypes.User; chat?: GrammyTypes.Chat }) {
|
||||
const msg: string[] = [];
|
||||
if (ctx.from) {
|
||||
msg.push(ctx.from.first_name);
|
||||
if (ctx.from.last_name) msg.push(ctx.from.last_name);
|
||||
if (ctx.from.username) msg.push(`(@${ctx.from.username})`);
|
||||
if (ctx.from.language_code) msg.push(`(${ctx.from.language_code.toUpperCase()})`);
|
||||
}
|
||||
if (ctx.chat) {
|
||||
if (
|
||||
ctx.chat.type === "group" ||
|
||||
ctx.chat.type === "supergroup" ||
|
||||
ctx.chat.type === "channel"
|
||||
) {
|
||||
msg.push("in");
|
||||
msg.push(ctx.chat.title);
|
||||
if (
|
||||
(ctx.chat.type === "supergroup" || ctx.chat.type === "channel") &&
|
||||
ctx.chat.username
|
||||
) {
|
||||
msg.push(`(@${ctx.chat.username})`);
|
||||
}
|
||||
}
|
||||
}
|
||||
return msg.join(" ");
|
||||
}
|
||||
|
||||
/** Language to biggest country emoji map */
|
||||
const languageToFlagMap: Record<string, string> = {
|
||||
"en": "🇺🇸", // English - United States
|
||||
"zh": "🇨🇳", // Chinese - China
|
||||
"es": "🇪🇸", // Spanish - Spain
|
||||
"hi": "🇮🇳", // Hindi - India
|
||||
"ar": "🇪🇬", // Arabic - Egypt
|
||||
"pt": "🇧🇷", // Portuguese - Brazil
|
||||
"bn": "🇧🇩", // Bengali - Bangladesh
|
||||
"ru": "🇷🇺", // Russian - Russia
|
||||
"ja": "🇯🇵", // Japanese - Japan
|
||||
"pa": "🇮🇳", // Punjabi - India
|
||||
"de": "🇩🇪", // German - Germany
|
||||
"ko": "🇰🇷", // Korean - South Korea
|
||||
"fr": "🇫🇷", // French - France
|
||||
"tr": "🇹🇷", // Turkish - Turkey
|
||||
"ur": "🇵🇰", // Urdu - Pakistan
|
||||
"it": "🇮🇹", // Italian - Italy
|
||||
"th": "🇹🇭", // Thai - Thailand
|
||||
"vi": "🇻🇳", // Vietnamese - Vietnam
|
||||
"pl": "🇵🇱", // Polish - Poland
|
||||
"uk": "🇺🇦", // Ukrainian - Ukraine
|
||||
"uz": "🇺🇿", // Uzbek - Uzbekistan
|
||||
"su": "🇮🇩", // Sundanese - Indonesia
|
||||
"sw": "🇹🇿", // Swahili - Tanzania
|
||||
"nl": "🇳🇱", // Dutch - Netherlands
|
||||
"fi": "🇫🇮", // Finnish - Finland
|
||||
"el": "🇬🇷", // Greek - Greece
|
||||
"da": "🇩🇰", // Danish - Denmark
|
||||
"cs": "🇨🇿", // Czech - Czech Republic
|
||||
"sk": "🇸🇰", // Slovak - Slovakia
|
||||
"bg": "🇧🇬", // Bulgarian - Bulgaria
|
||||
"sv": "🇸🇪", // Swedish - Sweden
|
||||
"be": "🇧🇾", // Belarusian - Belarus
|
||||
"hu": "🇭🇺", // Hungarian - Hungary
|
||||
"lt": "🇱🇹", // Lithuanian - Lithuania
|
||||
"lv": "🇱🇻", // Latvian - Latvia
|
||||
"et": "🇪🇪", // Estonian - Estonia
|
||||
"sl": "🇸🇮", // Slovenian - Slovenia
|
||||
"hr": "🇭🇷", // Croatian - Croatia
|
||||
"zu": "🇿🇦", // Zulu - South Africa
|
||||
"id": "🇮🇩", // Indonesian - Indonesia
|
||||
"is": "🇮🇸", // Icelandic - Iceland
|
||||
"lb": "🇱🇺", // Luxembourgish - Luxembourg
|
||||
};
|
||||
|
||||
export function getFlagEmoji(countryCode?: string): string | undefined {
|
||||
if (!countryCode) return;
|
||||
return languageToFlagMap[countryCode];
|
||||
}
|
Loading…
Reference in New Issue