Compare commits

..

9 Commits

Author SHA1 Message Date
pinks 3f27b4470b chore: update to ikv 0.3 2023-09-18 01:01:09 +02:00
pinks 2d2ffb8588 fix: sd interrupt 2023-09-18 00:35:33 +02:00
pinks 37b4f5d96a remove types 2023-09-18 00:35:12 +02:00
pinks 0517ce1930 feat: add scale parameter to img2img 2023-09-16 13:49:12 +02:00
pinks 8e81f82d8b add todo 2023-09-15 00:40:12 +02:00
pinks 9155d513b5 wait before retrying failed job 2023-09-14 03:23:55 +02:00
pinks ee4c2091f0 interrupt call should be POST 2023-09-14 03:21:40 +02:00
pinks 0cf9dcad04 add vscode config 2023-09-13 11:50:49 +02:00
pinks fcb655ea09 move files 2023-09-13 11:50:22 +02:00
24 changed files with 347 additions and 312 deletions

1
.gitignore vendored
View File

@ -1,4 +1,3 @@
.vscode
.env
app.db*
deno.lock

5
.vscode/extensions.json vendored Normal file
View File

@ -0,0 +1,5 @@
{
"recommendations": [
"denoland.vscode-deno"
]
}

5
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,5 @@
{
"deno.enable": true,
"deno.unstable": true,
"editor.defaultFormatter": "denoland.vscode-deno"
}

View File

@ -1,7 +1,7 @@
import { Collections, Grammy, GrammyStatelessQ } from "../deps.ts";
import { formatUserChat } from "../utils.ts";
import { formatUserChat } from "../common/utils.ts";
import { jobStore } from "../db/jobStore.ts";
import { parsePngInfo, PngInfo } from "../sd.ts";
import { parsePngInfo, PngInfo } from "../common/parsePngInfo.ts";
import { Context, logger } from "./mod.ts";
export const img2imgQuestion = new GrammyStatelessQ.StatelessQuestion<Context>(
@ -32,7 +32,7 @@ async function img2img(
return;
}
const jobs = await jobStore.getBy("status.type", "waiting");
const jobs = await jobStore.getBy("status.type", { value: "waiting" });
if (jobs.length >= ctx.session.global.maxJobs) {
await ctx.reply(
`The queue is full. Try again later. (Max queue size: ${ctx.session.global.maxJobs})`,
@ -73,22 +73,10 @@ async function img2img(
const repliedToText = repliedToMsg?.text || repliedToMsg?.caption;
if (includeRepliedTo && repliedToText) {
// TODO: remove bot command from replied to text
const originalParams = parsePngInfo(repliedToText);
params = {
...originalParams,
...params,
prompt: [originalParams.prompt, params.prompt].filter(Boolean).join("\n"),
negative_prompt: [originalParams.negative_prompt, params.negative_prompt]
.filter(Boolean).join("\n"),
};
params = parsePngInfo(repliedToText, params);
}
const messageParams = parsePngInfo(match ?? "");
params = {
...params,
...messageParams,
prompt: [params.prompt, messageParams.prompt].filter(Boolean).join("\n"),
};
params = parsePngInfo(match ?? "", params);
if (!fileId) {
await ctx.reply(

View File

@ -1,5 +1,5 @@
import { Grammy, GrammyAutoQuote, GrammyFiles, GrammyParseMode, Log } from "../deps.ts";
import { formatUserChat } from "../utils.ts";
import { formatUserChat } from "../common/utils.ts";
import { session, SessionFlavor } from "./session.ts";
import { queueCommand } from "./queueCommand.ts";
import { txt2imgCommand, txt2imgQuestion } from "./txt2imgCommand.ts";
@ -32,13 +32,15 @@ bot.api.config.use(async (prev, method, payload, signal) => {
let timedOut = false;
const timeout = setTimeout(() => {
timedOut = true;
// TODO: this sometimes throws with "can't abort a locked stream" and crashes whole process
controller.abort();
}, 30 * 1000);
signal?.addEventListener("abort", () => {
controller.abort();
});
try {
return await prev(method, payload, controller.signal);
const result = await prev(method, payload, controller.signal);
return result;
} finally {
clearTimeout(timeout);
if (timedOut) {

View File

@ -1,7 +1,6 @@
import { Grammy, GrammyParseMode, GrammyStatelessQ } from "../deps.ts";
import { fmt } from "../utils.ts";
import { getPngInfo, parsePngInfo } from "../sd.ts";
import { fmt } from "../common/utils.ts";
import { getPngInfo, parsePngInfo } from "../common/parsePngInfo.ts";
import { Context } from "./mod.ts";
export const pnginfoQuestion = new GrammyStatelessQ.StatelessQuestion<Context>(

View File

@ -1,8 +1,9 @@
import { Grammy, GrammyParseMode } from "../deps.ts";
import { fmt, getFlagEmoji } from "../utils.ts";
import { fmt } from "../common/utils.ts";
import { runningWorkers } from "../tasks/pingWorkers.ts";
import { jobStore } from "../db/jobStore.ts";
import { Context, logger } from "./mod.ts";
import { getFlagEmoji } from "../common/getFlagEmoji.ts";
export async function queueCommand(ctx: Grammy.CommandContext<Context>) {
let formattedMessage = await getMessageText();
@ -10,9 +11,9 @@ export async function queueCommand(ctx: Grammy.CommandContext<Context>) {
handleFutureUpdates().catch((err) => logger().warning(`Updating queue message failed: ${err}`));
async function getMessageText() {
const processingJobs = await jobStore.getBy("status.type", "processing")
const processingJobs = await jobStore.getBy("status.type", { value: "processing" })
.then((jobs) => jobs.map((job) => ({ ...job.value, place: 0 })));
const waitingJobs = await jobStore.getBy("status.type", "waiting")
const waitingJobs = await jobStore.getBy("status.type", { value: "waiting" })
.then((jobs) => jobs.map((job, index) => ({ ...job.value, place: index + 1 })));
const jobs = [...processingJobs, ...waitingJobs];
const { bold } = GrammyParseMode;

View File

@ -1,6 +1,6 @@
import { db } from "../db/db.ts";
import { Grammy, GrammyKvStorage } from "../deps.ts";
import { SdApi, SdTxt2ImgRequest } from "../sd.ts";
import { SdApi, SdTxt2ImgRequest } from "../common/sdApi.ts";
export type SessionFlavor = Grammy.SessionFlavor<SessionData>;

View File

@ -1,7 +1,7 @@
import { Grammy, GrammyStatelessQ } from "../deps.ts";
import { formatUserChat } from "../utils.ts";
import { formatUserChat } from "../common/utils.ts";
import { jobStore } from "../db/jobStore.ts";
import { getPngInfo, parsePngInfo, PngInfo } from "../sd.ts";
import { getPngInfo, parsePngInfo, PngInfo } from "../common/parsePngInfo.ts";
import { Context, logger } from "./mod.ts";
export const txt2imgQuestion = new GrammyStatelessQ.StatelessQuestion<Context>(
@ -27,7 +27,7 @@ async function txt2img(ctx: Context, match: string, includeRepliedTo: boolean):
return;
}
const jobs = await jobStore.getBy("status.type", "waiting");
const jobs = await jobStore.getBy("status.type", { value: "waiting" });
if (jobs.length >= ctx.session.global.maxJobs) {
await ctx.reply(
`The queue is full. Try again later. (Max queue size: ${ctx.session.global.maxJobs})`,
@ -50,35 +50,16 @@ async function txt2img(ctx: Context, match: string, includeRepliedTo: boolean):
if (includeRepliedTo && repliedToMsg?.document?.mime_type === "image/png") {
const file = await ctx.api.getFile(repliedToMsg.document.file_id);
const buffer = await fetch(file.getUrl()).then((resp) => resp.arrayBuffer());
const fileParams = parsePngInfo(getPngInfo(new Uint8Array(buffer)) ?? "");
params = {
...params,
...fileParams,
prompt: [params.prompt, fileParams.prompt].filter(Boolean).join("\n"),
negative_prompt: [params.negative_prompt, fileParams.negative_prompt]
.filter(Boolean).join("\n"),
};
params = parsePngInfo(getPngInfo(new Uint8Array(buffer)) ?? "", params);
}
const repliedToText = repliedToMsg?.text || repliedToMsg?.caption;
if (includeRepliedTo && repliedToText) {
// TODO: remove bot command from replied to text
const originalParams = parsePngInfo(repliedToText);
params = {
...originalParams,
...params,
prompt: [originalParams.prompt, params.prompt].filter(Boolean).join("\n"),
negative_prompt: [originalParams.negative_prompt, params.negative_prompt]
.filter(Boolean).join("\n"),
};
params = parsePngInfo(repliedToText, params);
}
const messageParams = parsePngInfo(match);
params = {
...params,
...messageParams,
prompt: [params.prompt, messageParams.prompt].filter(Boolean).join("\n"),
};
params = parsePngInfo(match, params);
if (!params.prompt) {
await ctx.reply(

51
common/getFlagEmoji.ts Normal file
View File

@ -0,0 +1,51 @@
/** Language to biggest country emoji map */
const languageToFlagMap: Record<string, string> = {
"en": "🇺🇸",
"zh": "🇨🇳",
"es": "🇪🇸",
"hi": "🇮🇳",
"ar": "🇪🇬",
"pt": "🇧🇷",
"bn": "🇧🇩",
"ru": "🇷🇺",
"ja": "🇯🇵",
"pa": "🇮🇳",
"de": "🇩🇪",
"ko": "🇰🇷",
"fr": "🇫🇷",
"tr": "🇹🇷",
"ur": "🇵🇰",
"it": "🇮🇹",
"th": "🇹🇭",
"vi": "🇻🇳",
"pl": "🇵🇱",
"uk": "🇺🇦",
"uz": "🇺🇿",
"su": "🇮🇩",
"sw": "🇹🇿",
"nl": "🇳🇱",
"fi": "🇫🇮",
"el": "🇬🇷",
"da": "🇩🇰",
"cs": "🇨🇿",
"sk": "🇸🇰",
"bg": "🇧🇬",
"sv": "🇸🇪",
"be": "🇧🇾",
"hu": "🇭🇺",
"lt": "🇱🇹",
"lv": "🇱🇻",
"et": "🇪🇪",
"sl": "🇸🇮",
"hr": "🇭🇷",
"zu": "🇿🇦",
"id": "🇮🇩",
"is": "🇮🇸",
"lb": "🇱🇺", // Luxembourgish - Luxembourg
};
export function getFlagEmoji(languageCode?: string): string | undefined {
const language = languageCode?.split("-").pop()?.toLowerCase();
if (!language) return;
return languageToFlagMap[language];
}

View File

@ -3,7 +3,7 @@ import {
assertEquals,
assertMatch,
} from "https://deno.land/std@0.135.0/testing/asserts.ts";
import { parsePngInfo } from "./sd.ts";
import { parsePngInfo } from "./parsePngInfo.ts";
Deno.test("parses pnginfo", async (t) => {
await t.step("1", () => {

142
common/parsePngInfo.ts Normal file
View File

@ -0,0 +1,142 @@
import { pngChunksExtract, pngChunkTextDecode } from "../deps.ts";
export function getPngInfo(pngData: Uint8Array): string | undefined {
return pngChunksExtract(pngData)
.filter((chunk) => chunk.name === "tEXt")
.map((chunk) => pngChunkTextDecode(chunk.data))
.find((textChunk) => textChunk.keyword === "parameters")
?.text;
}
export interface PngInfo {
prompt: string;
negative_prompt: string;
steps: number;
cfg_scale: number;
width: number;
height: number;
sampler_name: string;
seed: number;
denoising_strength: number;
}
interface PngInfoExtra extends PngInfo {
upscale?: number;
}
export function parsePngInfo(pngInfo: string, baseParams?: Partial<PngInfo>): Partial<PngInfo> {
const tags = pngInfo.split(/[,;]+|\.+\s|\n/u);
let part: "prompt" | "negative_prompt" | "params" = "prompt";
const params: Partial<PngInfoExtra> = {};
const prompt: string[] = [];
const negativePrompt: string[] = [];
for (const tag of tags) {
const paramValuePair = tag.trim().match(/^(\w+\s*\w*):\s+(.*)$/u);
if (paramValuePair) {
const [, param, value] = paramValuePair;
switch (param.replace(/\s+/u, "").toLowerCase()) {
case "positiveprompt":
case "positive":
case "prompt":
case "pos":
part = "prompt";
prompt.push(value.trim());
break;
case "negativeprompt":
case "negative":
case "neg":
part = "negative_prompt";
negativePrompt.push(value.trim());
break;
case "steps":
case "cycles": {
part = "params";
const steps = Number(value.trim());
if (steps > 0) params.steps = Math.min(steps, 50);
break;
}
case "cfgscale":
case "cfg":
case "detail": {
part = "params";
const cfgScale = Number(value.trim());
if (cfgScale > 0) params.cfg_scale = Math.min(cfgScale, 20);
break;
}
case "size":
case "resolution": {
part = "params";
const [width, height] = value.trim()
.split(/\s*[x,]\s*/u, 2)
.map((v) => v.trim())
.map(Number);
if (width > 0 && height > 0) {
params.width = Math.min(width, 2048);
params.height = Math.min(height, 2048);
}
break;
}
case "upscale":
case "scale": {
part = "params";
const upscale = Number(value.trim());
if (upscale > 0) params.upscale = Math.min(upscale, 2);
break;
}
case "denoisingstrength":
case "denoising":
case "denoise": {
part = "params";
// allow percent or decimal
let denoisingStrength: number;
if (value.trim().endsWith("%")) {
denoisingStrength = Number(value.trim().slice(0, -1).trim()) / 100;
} else {
denoisingStrength = Number(value.trim());
}
denoisingStrength = Math.min(Math.max(denoisingStrength, 0), 1);
params.denoising_strength = denoisingStrength;
break;
}
case "seed":
case "model":
case "modelhash":
case "modelname":
case "sampler":
part = "params";
// ignore for now
break;
default:
break;
}
} else if (tag.trim().length > 0) {
switch (part) {
case "prompt":
prompt.push(tag.trim());
break;
case "negative_prompt":
negativePrompt.push(tag.trim());
break;
default:
break;
}
}
}
if (prompt.length > 0) params.prompt = prompt.join(", ");
if (negativePrompt.length > 0) params.negative_prompt = negativePrompt.join(", ");
// handle upscale
if (params.upscale && baseParams?.width && baseParams?.height) {
params.width = baseParams.width * params.upscale;
params.height = baseParams.height * params.upscale;
}
return {
...baseParams,
...params,
prompt: [baseParams?.prompt, params.prompt]
.filter(Boolean).join("\n"),
negative_prompt: [baseParams?.negative_prompt, params.negative_prompt]
.filter(Boolean).join("\n"),
};
}

View File

@ -1,4 +1,4 @@
import { Async, AsyncX, PngChunksExtract, PngChunkText } from "./deps.ts";
import { Async, AsyncX } from "../deps.ts";
export interface SdApi {
url: string;
@ -110,7 +110,7 @@ export async function sdTxt2Img(
}
} finally {
if (await AsyncX.promiseState(request) === "pending") {
await fetchSdApi(api, "sdapi/v1/interrupt", { timeoutMs: 10_000 });
await fetchSdApi(api, "sdapi/v1/interrupt", { body: {}, timeoutMs: 10_000 });
}
}
}
@ -155,7 +155,7 @@ export async function sdImg2Img(
}
} finally {
if (await AsyncX.promiseState(request) === "pending") {
await fetchSdApi(api, "sdapi/v1/interrupt", { timeoutMs: 10_000 });
await fetchSdApi(api, "sdapi/v1/interrupt", { body: {}, timeoutMs: 10_000 });
}
}
}
@ -298,119 +298,3 @@ export class SdApiError extends Error {
super(message);
}
}
export function getPngInfo(pngData: Uint8Array): string | undefined {
return PngChunksExtract.default(pngData)
.filter((chunk) => chunk.name === "tEXt")
.map((chunk) => PngChunkText.decode(chunk.data))
.find((textChunk) => textChunk.keyword === "parameters")
?.text;
}
export interface PngInfo {
prompt: string;
negative_prompt: string;
steps: number;
cfg_scale: number;
width: number;
height: number;
sampler_name: string;
seed: number;
denoising_strength: number;
}
export function parsePngInfo(pngInfo: string): Partial<PngInfo> {
const tags = pngInfo.split(/[,;]+|\.+\s|\n/u);
let part: "prompt" | "negative_prompt" | "params" = "prompt";
const params: Partial<PngInfo> = {};
const prompt: string[] = [];
const negativePrompt: string[] = [];
for (const tag of tags) {
const paramValuePair = tag.trim().match(/^(\w+\s*\w*):\s+(.*)$/u);
if (paramValuePair) {
const [, param, value] = paramValuePair;
switch (param.replace(/\s+/u, "").toLowerCase()) {
case "positiveprompt":
case "positive":
case "prompt":
case "pos":
part = "prompt";
prompt.push(value.trim());
break;
case "negativeprompt":
case "negative":
case "neg":
part = "negative_prompt";
negativePrompt.push(value.trim());
break;
case "steps":
case "cycles": {
part = "params";
const steps = Number(value.trim());
if (steps > 0) params.steps = Math.min(steps, 50);
break;
}
case "cfgscale":
case "cfg":
case "detail": {
part = "params";
const cfgScale = Number(value.trim());
if (cfgScale > 0) params.cfg_scale = Math.min(cfgScale, 20);
break;
}
case "size":
case "resolution": {
part = "params";
const [width, height] = value.trim()
.split(/\s*[x,]\s*/u, 2)
.map((v) => v.trim())
.map(Number);
if (width > 0 && height > 0) {
params.width = Math.min(width, 2048);
params.height = Math.min(height, 2048);
}
break;
}
case "denoisingstrength":
case "denoising":
case "denoise": {
part = "params";
// allow percent or decimal
let denoisingStrength: number;
if (value.trim().endsWith("%")) {
denoisingStrength = Number(value.trim().slice(0, -1).trim()) / 100;
} else {
denoisingStrength = Number(value.trim());
}
denoisingStrength = Math.min(Math.max(denoisingStrength, 0), 1);
params.denoising_strength = denoisingStrength;
break;
}
case "seed":
case "model":
case "modelhash":
case "modelname":
case "sampler":
part = "params";
// ignore for now
break;
default:
break;
}
} else if (tag.trim().length > 0) {
switch (part) {
case "prompt":
prompt.push(tag.trim());
break;
case "negative_prompt":
negativePrompt.push(tag.trim());
break;
default:
break;
}
}
}
if (prompt.length > 0) params.prompt = prompt.join(", ");
if (negativePrompt.length > 0) params.negative_prompt = negativePrompt.join(", ");
return params;
}

60
common/utils.ts Normal file
View File

@ -0,0 +1,60 @@
import { GrammyParseMode, GrammyTypes } from "../deps.ts";
export function formatOrdinal(n: number) {
if (n % 100 === 11 || n % 100 === 12 || n % 100 === 13) return `${n}th`;
if (n % 10 === 1) return `${n}st`;
if (n % 10 === 2) return `${n}nd`;
if (n % 10 === 3) return `${n}rd`;
return `${n}th`;
}
export const fmt = (
rawStringParts: TemplateStringsArray | GrammyParseMode.Stringable[],
...stringLikes: GrammyParseMode.Stringable[]
): GrammyParseMode.FormattedString => {
let text = "";
const entities: GrammyTypes.MessageEntity[] = [];
const length = Math.max(rawStringParts.length, stringLikes.length);
for (let i = 0; i < length; i++) {
for (const stringLike of [rawStringParts[i], stringLikes[i]]) {
if (stringLike instanceof GrammyParseMode.FormattedString) {
entities.push(
...stringLike.entities.map((e) => ({
...e,
offset: e.offset + text.length,
})),
);
}
if (stringLike != null) text += stringLike.toString();
}
}
return new GrammyParseMode.FormattedString(text, entities);
};
export function formatUserChat(ctx: { from?: GrammyTypes.User; chat?: GrammyTypes.Chat }) {
const msg: string[] = [];
if (ctx.from) {
msg.push(ctx.from.first_name);
if (ctx.from.last_name) msg.push(ctx.from.last_name);
if (ctx.from.username) msg.push(`(@${ctx.from.username})`);
if (ctx.from.language_code) msg.push(`(${ctx.from.language_code.toUpperCase()})`);
}
if (ctx.chat) {
if (
ctx.chat.type === "group" ||
ctx.chat.type === "supergroup" ||
ctx.chat.type === "channel"
) {
msg.push("in");
msg.push(ctx.chat.title);
if (
(ctx.chat.type === "supergroup" || ctx.chat.type === "channel") &&
ctx.chat.username
) {
msg.push(`(@${ctx.chat.username})`);
}
}
}
return msg.join(" ");
}

View File

@ -1,5 +1,6 @@
import { GrammyTypes, IKV } from "../deps.ts";
import { PngInfo, SdTxt2ImgInfo } from "../sd.ts";
import { SdTxt2ImgInfo } from "../common/sdApi.ts";
import { PngInfo } from "../common/parsePngInfo.ts";
import { db } from "./db.ts";
export interface JobSchema {
@ -20,6 +21,7 @@ export interface JobSchema {
| {
type: "waiting";
message?: GrammyTypes.Message.TextMessage;
lastErrorDate?: Date;
}
| {
type: "processing";
@ -36,7 +38,12 @@ export interface JobSchema {
};
}
export const jobStore = new IKV.Store(db, "job", {
schema: new IKV.Schema<JobSchema>(),
indices: ["status.type"],
type JobIndices = {
"status.type": JobSchema["status"]["type"];
};
export const jobStore = new IKV.Store<JobSchema, JobIndices>(db, "job", {
indices: {
"status.type": { getValue: (job) => job.status.type },
},
});

10
deps.ts
View File

@ -5,7 +5,7 @@ export * as Collections from "https://deno.land/std@0.201.0/collections/mod.ts";
export * as Base64 from "https://deno.land/std@0.201.0/encoding/base64.ts";
export * as AsyncX from "https://deno.land/x/async@v2.0.2/mod.ts";
export * as ULID from "https://deno.land/x/ulid@v0.3.0/mod.ts";
export * as IKV from "https://deno.land/x/indexed_kv@v0.2.0/mod.ts";
export * as IKV from "https://deno.land/x/indexed_kv@v0.3.0/mod.ts";
export * as Grammy from "https://deno.land/x/grammy@v1.18.1/mod.ts";
export * as GrammyTypes from "https://deno.land/x/grammy_types@v3.2.0/mod.ts";
export * as GrammyAutoQuote from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts";
@ -13,8 +13,6 @@ export * as GrammyParseMode from "https://deno.land/x/grammy_parse_mode@1.7.1/mo
export * as GrammyKvStorage from "https://deno.land/x/grammy_storages@v2.3.1/denokv/src/mod.ts";
export * as GrammyStatelessQ from "https://deno.land/x/grammy_stateless_question_alpha@v3.0.3/mod.ts";
export * as GrammyFiles from "https://deno.land/x/grammy_files@v1.0.4/mod.ts";
export * as FileType from "npm:file-type@18.5.0";
// @deno-types="./types/png-chunks-extract.d.ts"
export * as PngChunksExtract from "npm:png-chunks-extract@1.0.0";
// @deno-types="./types/png-chunk-text.d.ts"
export * as PngChunkText from "npm:png-chunk-text@1.0.0";
export * as FileType from "https://esm.sh/file-type@18.5.0";
export { default as pngChunksExtract } from "https://esm.sh/png-chunks-extract@1.0.0";
export { decode as pngChunkTextDecode } from "https://esm.sh/png-chunk-text@1.0.0";

View File

@ -1,6 +1,6 @@
import { Async, Log } from "../deps.ts";
import { getGlobalSession } from "../bot/session.ts";
import { sdGetConfig } from "../sd.ts";
import { sdGetConfig } from "../common/sdApi.ts";
const logger = () => Log.getLogger();

View File

@ -11,8 +11,14 @@ import {
} from "../deps.ts";
import { bot } from "../bot/mod.ts";
import { getGlobalSession, GlobalData, WorkerData } from "../bot/session.ts";
import { fmt, formatUserChat } from "../utils.ts";
import { SdApiError, sdImg2Img, SdProgressResponse, SdResponse, sdTxt2Img } from "../sd.ts";
import { fmt, formatUserChat } from "../common/utils.ts";
import {
SdApiError,
sdImg2Img,
SdProgressResponse,
SdResponse,
sdTxt2Img,
} from "../common/sdApi.ts";
import { JobSchema, jobStore } from "../db/jobStore.ts";
import { runningWorkers } from "./pingWorkers.ts";
@ -27,8 +33,12 @@ export async function processJobs(): Promise<never> {
await new Promise((resolve) => setTimeout(resolve, 1000));
try {
// get first waiting job
const job = await jobStore.getBy("status.type", "waiting").then((jobs) => jobs[0]);
const jobs = await jobStore.getBy("status.type", { value: "waiting" });
// get first waiting job which hasn't errored in last minute
const job = jobs.find((job) =>
job.value.status.type === "waiting" &&
(job.value.status.lastErrorDate?.getTime() ?? 0) < Date.now() - 60_000
);
if (!job) continue;
// find a worker to handle the job
@ -56,13 +66,20 @@ export async function processJobs(): Promise<never> {
logger().error(
`Job failed for ${formatUserChat(job.value)} via ${worker.id}: ${err}`,
);
if (job.value.status.type === "processing" && job.value.status.message) {
await bot.api.deleteMessage(
job.value.status.message.chat.id,
job.value.status.message.message_id,
).catch(() => undefined);
}
if (err instanceof Grammy.GrammyError || err instanceof SdApiError) {
await bot.api.sendMessage(
job.value.chat.id,
`Failed to generate your prompt using ${worker.name}: ${err.message}`,
{ reply_to_message_id: job.value.requestMessageId },
).catch(() => undefined);
await job.update({ status: { type: "waiting" } }).catch(() => undefined);
await job.update({ status: { type: "waiting", lastErrorDate: new Date() } })
.catch(() => undefined);
}
if (
err instanceof SdApiError &&
@ -163,10 +180,9 @@ async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config:
const handleProgress = async (progress: SdProgressResponse) => {
// Important: don't let any errors escape this function
if (job.value.status.type === "processing" && job.value.status.message) {
if (job.value.status.progress === progress.progress) return;
await Promise.all([
bot.api.sendChatAction(job.value.chat.id, "upload_photo", { maxAttempts: 1 }),
bot.api.editMessageText(
progress.progress > job.value.status.progress && bot.api.editMessageText(
job.value.status.message.chat.id,
job.value.status.message.message_id,
`Generating your prompt now... ${
@ -197,7 +213,14 @@ async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config:
case "txt2img":
response = await sdTxt2Img(
worker.api,
{ ...config.defaultParams, ...job.value.task.params, ...size },
{
...config.defaultParams,
...job.value.task.params,
...size,
negative_prompt: job.value.task.params.negative_prompt
? job.value.task.params.negative_prompt
: config.defaultParams?.negative_prompt,
},
handleProgress,
);
break;

View File

@ -1,5 +1,5 @@
import { FmtDuration, Log } from "../deps.ts";
import { formatUserChat } from "../utils.ts";
import { formatUserChat } from "../common/utils.ts";
import { jobStore } from "../db/jobStore.ts";
const logger = () => Log.getLogger();
@ -11,13 +11,19 @@ export async function returnHangedJobs(): Promise<never> {
while (true) {
try {
await new Promise((resolve) => setTimeout(resolve, 5000));
const jobs = await jobStore.getBy("status.type", "processing");
const jobs = await jobStore.getBy("status.type", { value: "processing" });
for (const job of jobs) {
if (job.value.status.type !== "processing") continue;
// if job wasn't updated for 2 minutes, return it to the queue
const timeSinceLastUpdateMs = Date.now() - job.value.status.updatedDate.getTime();
if (timeSinceLastUpdateMs > 2 * 60 * 1000) {
await job.update({ status: { type: "waiting" } });
await job.update((value) => ({
...value,
status: {
type: "waiting",
message: value.status.type !== "done" ? value.status.message : undefined,
},
}));
logger().warning(
`Job for ${formatUserChat(job.value)} was returned to the queue because it hanged for ${
FmtDuration.format(Math.trunc(timeSinceLastUpdateMs / 1000) * 1000, {

View File

@ -1,6 +1,6 @@
import { Log } from "../deps.ts";
import { bot } from "../bot/mod.ts";
import { formatOrdinal } from "../utils.ts";
import { formatOrdinal } from "../common/utils.ts";
import { jobStore } from "../db/jobStore.ts";
const logger = () => Log.getLogger();
@ -12,7 +12,7 @@ export async function updateJobStatusMsgs(): Promise<never> {
while (true) {
try {
await new Promise((resolve) => setTimeout(resolve, 5000));
const jobs = await jobStore.getBy("status.type", "waiting");
const jobs = await jobStore.getBy("status.type", { value: "waiting" });
for (const [index, job] of jobs.entries()) {
if (job.value.status.type !== "waiting" || !job.value.status.message) continue;
await bot.api.editMessageText(

View File

@ -1,2 +0,0 @@
export function decode(chunk: Uint8Array): { keyword: string; text: string };
export function encode(keyword: string, text: string): Uint8Array;

View File

@ -1 +0,0 @@
export default function encode(chunks: Array<{ name: string; data: Uint8Array }>): Uint8Array;

View File

@ -1 +0,0 @@
export default function extract(data: Uint8Array): Array<{ name: string; data: Uint8Array }>;

112
utils.ts
View File

@ -1,112 +0,0 @@
import { GrammyParseMode, GrammyTypes } from "./deps.ts";
export function formatOrdinal(n: number) {
if (n % 100 === 11 || n % 100 === 12 || n % 100 === 13) return `${n}th`;
if (n % 10 === 1) return `${n}st`;
if (n % 10 === 2) return `${n}nd`;
if (n % 10 === 3) return `${n}rd`;
return `${n}th`;
}
export const fmt = (
rawStringParts: TemplateStringsArray | GrammyParseMode.Stringable[],
...stringLikes: GrammyParseMode.Stringable[]
): GrammyParseMode.FormattedString => {
let text = "";
const entities: GrammyTypes.MessageEntity[] = [];
const length = Math.max(rawStringParts.length, stringLikes.length);
for (let i = 0; i < length; i++) {
for (const stringLike of [rawStringParts[i], stringLikes[i]]) {
if (stringLike instanceof GrammyParseMode.FormattedString) {
entities.push(
...stringLike.entities.map((e) => ({
...e,
offset: e.offset + text.length,
})),
);
}
if (stringLike != null) text += stringLike.toString();
}
}
return new GrammyParseMode.FormattedString(text, entities);
};
export function formatUserChat(ctx: { from?: GrammyTypes.User; chat?: GrammyTypes.Chat }) {
const msg: string[] = [];
if (ctx.from) {
msg.push(ctx.from.first_name);
if (ctx.from.last_name) msg.push(ctx.from.last_name);
if (ctx.from.username) msg.push(`(@${ctx.from.username})`);
if (ctx.from.language_code) msg.push(`(${ctx.from.language_code.toUpperCase()})`);
}
if (ctx.chat) {
if (
ctx.chat.type === "group" ||
ctx.chat.type === "supergroup" ||
ctx.chat.type === "channel"
) {
msg.push("in");
msg.push(ctx.chat.title);
if (
(ctx.chat.type === "supergroup" || ctx.chat.type === "channel") &&
ctx.chat.username
) {
msg.push(`(@${ctx.chat.username})`);
}
}
}
return msg.join(" ");
}
/** Language to biggest country emoji map */
const languageToFlagMap: Record<string, string> = {
"en": "🇺🇸", // English - United States
"zh": "🇨🇳", // Chinese - China
"es": "🇪🇸", // Spanish - Spain
"hi": "🇮🇳", // Hindi - India
"ar": "🇪🇬", // Arabic - Egypt
"pt": "🇧🇷", // Portuguese - Brazil
"bn": "🇧🇩", // Bengali - Bangladesh
"ru": "🇷🇺", // Russian - Russia
"ja": "🇯🇵", // Japanese - Japan
"pa": "🇮🇳", // Punjabi - India
"de": "🇩🇪", // German - Germany
"ko": "🇰🇷", // Korean - South Korea
"fr": "🇫🇷", // French - France
"tr": "🇹🇷", // Turkish - Turkey
"ur": "🇵🇰", // Urdu - Pakistan
"it": "🇮🇹", // Italian - Italy
"th": "🇹🇭", // Thai - Thailand
"vi": "🇻🇳", // Vietnamese - Vietnam
"pl": "🇵🇱", // Polish - Poland
"uk": "🇺🇦", // Ukrainian - Ukraine
"uz": "🇺🇿", // Uzbek - Uzbekistan
"su": "🇮🇩", // Sundanese - Indonesia
"sw": "🇹🇿", // Swahili - Tanzania
"nl": "🇳🇱", // Dutch - Netherlands
"fi": "🇫🇮", // Finnish - Finland
"el": "🇬🇷", // Greek - Greece
"da": "🇩🇰", // Danish - Denmark
"cs": "🇨🇿", // Czech - Czech Republic
"sk": "🇸🇰", // Slovak - Slovakia
"bg": "🇧🇬", // Bulgarian - Bulgaria
"sv": "🇸🇪", // Swedish - Sweden
"be": "🇧🇾", // Belarusian - Belarus
"hu": "🇭🇺", // Hungarian - Hungary
"lt": "🇱🇹", // Lithuanian - Lithuania
"lv": "🇱🇻", // Latvian - Latvia
"et": "🇪🇪", // Estonian - Estonia
"sl": "🇸🇮", // Slovenian - Slovenia
"hr": "🇭🇷", // Croatian - Croatia
"zu": "🇿🇦", // Zulu - South Africa
"id": "🇮🇩", // Indonesian - Indonesia
"is": "🇮🇸", // Icelandic - Iceland
"lb": "🇱🇺", // Luxembourgish - Luxembourg
};
export function getFlagEmoji(languageCode?: string): string | undefined {
const language = languageCode?.split("-").pop()?.toLowerCase();
if (!language) return;
return languageToFlagMap[language];
}