feat: img2img
This commit is contained in:
parent
5b3da79129
commit
5b365c9e17
|
@ -33,7 +33,8 @@ You can put these in `.env` file or pass them as environment variables.
|
|||
- [x] Replying to bot message, conversation in DMs
|
||||
- [x] Replying to png message to extract png info nad generate
|
||||
- [ ] Banning tags
|
||||
- [ ] Img2Img + Upscale
|
||||
- [x] Img2Img + Upscale
|
||||
- [ ] special param "scale" to change image size preserving aspect ratio
|
||||
- [ ] Admin WebUI
|
||||
- [ ] User daily generation limits
|
||||
- [ ] Querying all generation history, displaying stats
|
||||
|
|
|
@ -0,0 +1,123 @@
|
|||
import { Collections, Grammy, GrammyStatelessQ } from "../deps.ts";
|
||||
import { formatUserChat } from "../utils.ts";
|
||||
import { jobStore } from "../db/jobStore.ts";
|
||||
import { parsePngInfo, PngInfo } from "../sd.ts";
|
||||
import { Context, logger } from "./mod.ts";
|
||||
|
||||
export const img2imgQuestion = new GrammyStatelessQ.StatelessQuestion<Context>(
|
||||
"img2img",
|
||||
async (ctx, state) => {
|
||||
// todo: also save original image size in state
|
||||
await img2img(ctx, ctx.message.text, false, state);
|
||||
},
|
||||
);
|
||||
|
||||
export async function img2imgCommand(ctx: Grammy.CommandContext<Context>) {
|
||||
await img2img(ctx, ctx.match, true);
|
||||
}
|
||||
|
||||
async function img2img(
|
||||
ctx: Context,
|
||||
match: string | undefined,
|
||||
includeRepliedTo: boolean,
|
||||
fileId?: string,
|
||||
): Promise<void> {
|
||||
if (!ctx.message?.from?.id) {
|
||||
await ctx.reply("I don't know who you are");
|
||||
return;
|
||||
}
|
||||
|
||||
if (ctx.session.global.pausedReason != null) {
|
||||
await ctx.reply(`I'm paused: ${ctx.session.global.pausedReason || "No reason given"}`);
|
||||
return;
|
||||
}
|
||||
|
||||
const jobs = await jobStore.getBy("status.type", "waiting");
|
||||
if (jobs.length >= ctx.session.global.maxJobs) {
|
||||
await ctx.reply(
|
||||
`The queue is full. Try again later. (Max queue size: ${ctx.session.global.maxJobs})`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const userJobs = jobs.filter((job) => job.value.from.id === ctx.message?.from?.id);
|
||||
if (userJobs.length >= ctx.session.global.maxUserJobs) {
|
||||
await ctx.reply(
|
||||
`You already have ${ctx.session.global.maxUserJobs} jobs in queue. Try again later.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let params: Partial<PngInfo> = {};
|
||||
|
||||
const repliedToMsg = ctx.message.reply_to_message;
|
||||
|
||||
if (includeRepliedTo && repliedToMsg?.photo) {
|
||||
const photos = repliedToMsg.photo;
|
||||
const biggestPhoto = Collections.maxBy(photos, (p) => p.width * p.height);
|
||||
if (!biggestPhoto) throw new Error("Message was a photo but had no photos?");
|
||||
fileId = biggestPhoto.file_id;
|
||||
params.width = biggestPhoto.width;
|
||||
params.height = biggestPhoto.height;
|
||||
}
|
||||
|
||||
if (ctx.message.photo) {
|
||||
const photos = ctx.message.photo;
|
||||
const biggestPhoto = Collections.maxBy(photos, (p) => p.width * p.height);
|
||||
if (!biggestPhoto) throw new Error("Message was a photo but had no photos?");
|
||||
fileId = biggestPhoto.file_id;
|
||||
params.width = biggestPhoto.width;
|
||||
params.height = biggestPhoto.height;
|
||||
}
|
||||
|
||||
const repliedToText = repliedToMsg?.text || repliedToMsg?.caption;
|
||||
if (includeRepliedTo && repliedToText) {
|
||||
// TODO: remove bot command from replied to text
|
||||
const originalParams = parsePngInfo(repliedToText);
|
||||
params = {
|
||||
...originalParams,
|
||||
...params,
|
||||
prompt: [originalParams.prompt, params.prompt].filter(Boolean).join("\n"),
|
||||
negative_prompt: [originalParams.negative_prompt, params.negative_prompt]
|
||||
.filter(Boolean).join("\n"),
|
||||
};
|
||||
}
|
||||
|
||||
const messageParams = parsePngInfo(match ?? "");
|
||||
params = {
|
||||
...params,
|
||||
...messageParams,
|
||||
prompt: [params.prompt, messageParams.prompt].filter(Boolean).join("\n"),
|
||||
};
|
||||
|
||||
if (!fileId) {
|
||||
await ctx.reply(
|
||||
"Please show me a picture to repaint." +
|
||||
img2imgQuestion.messageSuffixMarkdown(),
|
||||
{ reply_markup: { force_reply: true, selective: true }, parse_mode: "Markdown" },
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!params.prompt) {
|
||||
await ctx.reply(
|
||||
"Please describe the picture you want to repaint." +
|
||||
img2imgQuestion.messageSuffixMarkdown(fileId),
|
||||
{ reply_markup: { force_reply: true, selective: true }, parse_mode: "Markdown" },
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const replyMessage = await ctx.reply("Accepted. You are now in queue.");
|
||||
|
||||
await jobStore.create({
|
||||
task: { type: "img2img", params, fileId },
|
||||
from: ctx.message.from,
|
||||
chat: ctx.message.chat,
|
||||
requestMessageId: ctx.message.message_id,
|
||||
replyMessageId: replyMessage.message_id,
|
||||
status: { type: "waiting" },
|
||||
});
|
||||
|
||||
logger().debug(`Job enqueued for ${formatUserChat(ctx.message)}`);
|
||||
}
|
|
@ -4,6 +4,7 @@ import { session, SessionFlavor } from "./session.ts";
|
|||
import { queueCommand } from "./queueCommand.ts";
|
||||
import { txt2imgCommand, txt2imgQuestion } from "./txt2imgCommand.ts";
|
||||
import { pnginfoCommand, pnginfoQuestion } from "./pnginfoCommand.ts";
|
||||
import { img2imgCommand, img2imgQuestion } from "./img2imgCommand.ts";
|
||||
|
||||
export const logger = () => Log.getLogger();
|
||||
|
||||
|
@ -74,7 +75,8 @@ bot.api.setMyDescription(
|
|||
"Send /txt2img to generate an image.",
|
||||
);
|
||||
bot.api.setMyCommands([
|
||||
{ command: "txt2img", description: "Generate an image" },
|
||||
{ command: "txt2img", description: "Generate an image from text" },
|
||||
{ command: "img2img", description: "Generate an image based on another image" },
|
||||
{ command: "pnginfo", description: "Show generation parameters of an image" },
|
||||
{ command: "queue", description: "Show the current queue" },
|
||||
]);
|
||||
|
@ -84,6 +86,9 @@ bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to gene
|
|||
bot.command("txt2img", txt2imgCommand);
|
||||
bot.use(txt2imgQuestion.middleware());
|
||||
|
||||
bot.command("img2img", img2imgCommand);
|
||||
bot.use(img2imgQuestion.middleware());
|
||||
|
||||
bot.command("pnginfo", pnginfoCommand);
|
||||
bot.use(pnginfoQuestion.middleware());
|
||||
|
||||
|
|
|
@ -21,16 +21,14 @@ export async function queueCommand(ctx: Grammy.CommandContext<Context>) {
|
|||
...jobs.length > 0
|
||||
? jobs.flatMap((job) => [
|
||||
`${job.place}. `,
|
||||
fmt`${bold(job.request.from.first_name)} `,
|
||||
job.request.from.last_name ? fmt`${bold(job.request.from.last_name)} ` : "",
|
||||
job.request.from.username ? `(@${job.request.from.username}) ` : "",
|
||||
getFlagEmoji(job.request.from.language_code) ?? "",
|
||||
job.request.chat.type === "private"
|
||||
? " in private chat "
|
||||
: ` in ${job.request.chat.title} `,
|
||||
job.request.chat.type !== "private" && job.request.chat.type !== "group" &&
|
||||
job.request.chat.username
|
||||
? `(@${job.request.chat.username}) `
|
||||
fmt`${bold(job.from.first_name)} `,
|
||||
job.from.last_name ? fmt`${bold(job.from.last_name)} ` : "",
|
||||
job.from.username ? `(@${job.from.username}) ` : "",
|
||||
getFlagEmoji(job.from.language_code) ?? "",
|
||||
job.chat.type === "private" ? " in private chat " : ` in ${job.chat.title} `,
|
||||
job.chat.type !== "private" && job.chat.type !== "group" &&
|
||||
job.chat.username
|
||||
? `(@${job.chat.username}) `
|
||||
: "",
|
||||
job.status.type === "processing"
|
||||
? `(${(job.status.progress * 100).toFixed(0)}% using ${job.status.worker}) `
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import { Grammy, GrammyStatelessQ } from "../deps.ts";
|
||||
import { formatUserChat } from "../utils.ts";
|
||||
import { jobStore } from "../db/jobStore.ts";
|
||||
import { getPngInfo, parsePngInfo, SdTxt2ImgRequest } from "../sd.ts";
|
||||
import { getPngInfo, parsePngInfo, PngInfo } from "../sd.ts";
|
||||
import { Context, logger } from "./mod.ts";
|
||||
|
||||
export const txt2imgQuestion = new GrammyStatelessQ.StatelessQuestion<Context>(
|
||||
|
@ -35,7 +35,7 @@ async function txt2img(ctx: Context, match: string, includeRepliedTo: boolean):
|
|||
return;
|
||||
}
|
||||
|
||||
const userJobs = jobs.filter((job) => job.value.request.from.id === ctx.message?.from?.id);
|
||||
const userJobs = jobs.filter((job) => job.value.from.id === ctx.message?.from?.id);
|
||||
if (userJobs.length >= ctx.session.global.maxUserJobs) {
|
||||
await ctx.reply(
|
||||
`You already have ${ctx.session.global.maxUserJobs} jobs in queue. Try again later.`,
|
||||
|
@ -43,7 +43,7 @@ async function txt2img(ctx: Context, match: string, includeRepliedTo: boolean):
|
|||
return;
|
||||
}
|
||||
|
||||
let params: Partial<SdTxt2ImgRequest> = {};
|
||||
let params: Partial<PngInfo> = {};
|
||||
|
||||
const repliedToMsg = ctx.message.reply_to_message;
|
||||
|
||||
|
@ -92,9 +92,11 @@ async function txt2img(ctx: Context, match: string, includeRepliedTo: boolean):
|
|||
const replyMessage = await ctx.reply("Accepted. You are now in queue.");
|
||||
|
||||
await jobStore.create({
|
||||
params,
|
||||
request: ctx.message,
|
||||
reply: replyMessage,
|
||||
task: { type: "txt2img", params },
|
||||
from: ctx.message.from,
|
||||
chat: ctx.message.chat,
|
||||
requestMessageId: ctx.message.message_id,
|
||||
replyMessageId: replyMessage.message_id,
|
||||
status: { type: "waiting" },
|
||||
});
|
||||
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
import { GrammyTypes, IKV } from "../deps.ts";
|
||||
import { SdTxt2ImgInfo, SdTxt2ImgRequest } from "../sd.ts";
|
||||
import { PngInfo, SdTxt2ImgInfo } from "../sd.ts";
|
||||
import { db } from "./db.ts";
|
||||
|
||||
export interface JobSchema {
|
||||
params: Partial<SdTxt2ImgRequest>;
|
||||
request: GrammyTypes.Message & { from: GrammyTypes.User };
|
||||
reply?: GrammyTypes.Message.TextMessage;
|
||||
task:
|
||||
| { type: "txt2img"; params: Partial<PngInfo> }
|
||||
| { type: "img2img"; params: Partial<PngInfo>; fileId: string };
|
||||
from: GrammyTypes.User;
|
||||
chat: GrammyTypes.Chat;
|
||||
requestMessageId: number;
|
||||
replyMessageId?: number;
|
||||
status:
|
||||
| { type: "waiting" }
|
||||
| { type: "processing"; progress: number; worker: string; updatedDate: Date }
|
||||
|
|
171
sd.ts
171
sd.ts
|
@ -37,13 +37,50 @@ async function fetchSdApi<T>(api: SdApi, endpoint: string, body?: unknown): Prom
|
|||
return result;
|
||||
}
|
||||
|
||||
interface SdRequest {
|
||||
prompt: string;
|
||||
denoising_strength: number;
|
||||
styles: string[];
|
||||
negative_prompt: string;
|
||||
seed: number;
|
||||
subseed: number;
|
||||
subseed_strength: number;
|
||||
seed_resize_from_h: number;
|
||||
seed_resize_from_w: number;
|
||||
width: number;
|
||||
height: number;
|
||||
sampler_name: string;
|
||||
batch_size: number;
|
||||
n_iter: number;
|
||||
steps: number;
|
||||
cfg_scale: number;
|
||||
restore_faces: boolean;
|
||||
tiling: boolean;
|
||||
do_not_save_samples: boolean;
|
||||
do_not_save_grid: boolean;
|
||||
eta: number;
|
||||
s_min_uncond: number;
|
||||
s_churn: number;
|
||||
s_tmax: number;
|
||||
s_tmin: number;
|
||||
s_noise: number;
|
||||
override_settings: object;
|
||||
override_settings_restore_afterwards: boolean;
|
||||
script_args: unknown[];
|
||||
sampler_index: string;
|
||||
script_name: string;
|
||||
send_images: boolean;
|
||||
save_images: boolean;
|
||||
alwayson_scripts: object;
|
||||
}
|
||||
|
||||
export async function sdTxt2Img(
|
||||
api: SdApi,
|
||||
params: Partial<SdTxt2ImgRequest>,
|
||||
onProgress?: (progress: SdProgressResponse) => void,
|
||||
signal: AbortSignal = neverSignal,
|
||||
): Promise<SdTxt2ImgResponse> {
|
||||
const request = fetchSdApi<SdTxt2ImgResponse>(api, "sdapi/v1/txt2img", params)
|
||||
): Promise<SdResponse<SdTxt2ImgRequest>> {
|
||||
const request = fetchSdApi<SdResponse<SdTxt2ImgRequest>>(api, "sdapi/v1/txt2img", params)
|
||||
// JSON field "info" is a JSON-serialized string so we need to parse this part second time
|
||||
.then((data) => ({
|
||||
...data,
|
||||
|
@ -63,57 +100,65 @@ export async function sdTxt2Img(
|
|||
}
|
||||
}
|
||||
|
||||
export interface SdTxt2ImgRequest {
|
||||
export interface SdTxt2ImgRequest extends SdRequest {
|
||||
enable_hr: boolean;
|
||||
denoising_strength: number;
|
||||
firstphase_width: number;
|
||||
firstphase_height: number;
|
||||
hr_scale: number;
|
||||
hr_upscaler: unknown;
|
||||
hr_second_pass_steps: number;
|
||||
firstphase_width: number;
|
||||
hr_resize_x: number;
|
||||
hr_resize_y: number;
|
||||
hr_sampler_name: unknown;
|
||||
hr_prompt: string;
|
||||
hr_negative_prompt: string;
|
||||
prompt: string;
|
||||
styles: unknown;
|
||||
seed: number;
|
||||
subseed: number;
|
||||
subseed_strength: number;
|
||||
seed_resize_from_h: number;
|
||||
seed_resize_from_w: number;
|
||||
sampler_name: unknown;
|
||||
batch_size: number;
|
||||
n_iter: number;
|
||||
steps: number;
|
||||
cfg_scale: number;
|
||||
width: number;
|
||||
height: number;
|
||||
restore_faces: boolean;
|
||||
tiling: boolean;
|
||||
do_not_save_samples: boolean;
|
||||
do_not_save_grid: boolean;
|
||||
negative_prompt: string;
|
||||
eta: unknown;
|
||||
s_min_uncond: number;
|
||||
s_churn: number;
|
||||
s_tmax: unknown;
|
||||
s_tmin: number;
|
||||
s_noise: number;
|
||||
override_settings: object;
|
||||
override_settings_restore_afterwards: boolean;
|
||||
script_args: unknown[];
|
||||
sampler_index: string;
|
||||
script_name: unknown;
|
||||
send_images: boolean;
|
||||
save_images: boolean;
|
||||
alwayson_scripts: object;
|
||||
hr_prompt: string;
|
||||
hr_resize_y: number;
|
||||
hr_sampler_name: string;
|
||||
hr_scale: number;
|
||||
hr_second_pass_steps: number;
|
||||
hr_upscaler: string;
|
||||
}
|
||||
|
||||
export interface SdTxt2ImgResponse {
|
||||
export async function sdImg2Img(
|
||||
api: SdApi,
|
||||
params: Partial<SdImg2ImgRequest>,
|
||||
onProgress?: (progress: SdProgressResponse) => void,
|
||||
signal: AbortSignal = neverSignal,
|
||||
): Promise<SdResponse<SdImg2ImgRequest>> {
|
||||
const request = fetchSdApi<SdResponse<SdImg2ImgRequest>>(api, "sdapi/v1/img2img", params)
|
||||
// JSON field "info" is a JSON-serialized string so we need to parse this part second time
|
||||
.then((data) => ({
|
||||
...data,
|
||||
info: typeof data.info === "string" ? JSON.parse(data.info) : data.info,
|
||||
}));
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
await Async.abortable(Promise.race([request, Async.delay(4000)]), signal);
|
||||
if (await AsyncX.promiseState(request) !== "pending") return await request;
|
||||
onProgress?.(await fetchSdApi<SdProgressResponse>(api, "sdapi/v1/progress"));
|
||||
}
|
||||
} finally {
|
||||
if (await AsyncX.promiseState(request) === "pending") {
|
||||
await fetchSdApi(api, "sdapi/v1/interrupt", {});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export interface SdImg2ImgRequest extends SdRequest {
|
||||
image_cfg_scale: number;
|
||||
include_init_images: boolean;
|
||||
init_images: string[];
|
||||
initial_noise_multiplier: number;
|
||||
inpaint_full_res: boolean;
|
||||
inpaint_full_res_padding: number;
|
||||
inpainting_fill: number;
|
||||
inpainting_mask_invert: number;
|
||||
mask: string;
|
||||
mask_blur: number;
|
||||
mask_blur_x: number;
|
||||
mask_blur_y: number;
|
||||
resize_mode: number;
|
||||
}
|
||||
|
||||
export interface SdResponse<T> {
|
||||
images: string[];
|
||||
parameters: SdTxt2ImgRequest;
|
||||
parameters: T;
|
||||
// Warning: raw response from API is a JSON-serialized string
|
||||
info: SdTxt2ImgInfo;
|
||||
}
|
||||
|
@ -242,10 +287,22 @@ export function getPngInfo(pngData: Uint8Array): string | undefined {
|
|||
?.text;
|
||||
}
|
||||
|
||||
export function parsePngInfo(pngInfo: string): Partial<SdTxt2ImgRequest> {
|
||||
export interface PngInfo {
|
||||
prompt: string;
|
||||
negative_prompt: string;
|
||||
steps: number;
|
||||
cfg_scale: number;
|
||||
width: number;
|
||||
height: number;
|
||||
sampler_name: string;
|
||||
seed: number;
|
||||
denoising_strength: number;
|
||||
}
|
||||
|
||||
export function parsePngInfo(pngInfo: string): Partial<PngInfo> {
|
||||
const tags = pngInfo.split(/[,;]+|\.+\s|\n/u);
|
||||
let part: "prompt" | "negative_prompt" | "params" = "prompt";
|
||||
const params: Partial<SdTxt2ImgRequest> = {};
|
||||
const params: Partial<PngInfo> = {};
|
||||
const prompt: string[] = [];
|
||||
const negativePrompt: string[] = [];
|
||||
for (const tag of tags) {
|
||||
|
@ -294,14 +351,26 @@ export function parsePngInfo(pngInfo: string): Partial<SdTxt2ImgRequest> {
|
|||
}
|
||||
break;
|
||||
}
|
||||
case "denoisingstrength":
|
||||
case "denoising":
|
||||
case "denoise": {
|
||||
part = "params";
|
||||
// allow percent or decimal
|
||||
let denoisingStrength: number;
|
||||
if (value.trim().endsWith("%")) {
|
||||
denoisingStrength = Number(value.trim().slice(0, -1).trim()) / 100;
|
||||
} else {
|
||||
denoisingStrength = Number(value.trim());
|
||||
}
|
||||
denoisingStrength = Math.min(Math.max(denoisingStrength, 0), 1);
|
||||
params.denoising_strength = denoisingStrength;
|
||||
break;
|
||||
}
|
||||
case "seed":
|
||||
case "model":
|
||||
case "modelhash":
|
||||
case "modelname":
|
||||
case "sampler":
|
||||
case "denoisingstrength":
|
||||
case "denoising":
|
||||
case "denoise":
|
||||
part = "params";
|
||||
// ignore for now
|
||||
break;
|
||||
|
|
|
@ -12,7 +12,7 @@ import {
|
|||
import { bot } from "../bot/mod.ts";
|
||||
import { getGlobalSession, GlobalData, WorkerData } from "../bot/session.ts";
|
||||
import { fmt, formatUserChat } from "../utils.ts";
|
||||
import { SdApiError, sdTxt2Img } from "../sd.ts";
|
||||
import { SdApiError, sdImg2Img, SdProgressResponse, SdResponse, sdTxt2Img } from "../sd.ts";
|
||||
import { JobSchema, jobStore } from "../db/jobStore.ts";
|
||||
import { runningWorkers } from "./pingWorkers.ts";
|
||||
|
||||
|
@ -48,13 +48,13 @@ export async function processJobs(): Promise<never> {
|
|||
processJob(job, worker, config)
|
||||
.catch(async (err) => {
|
||||
logger().error(
|
||||
`Job failed for ${formatUserChat(job.value.request)} via ${worker.name}: ${err}`,
|
||||
`Job failed for ${formatUserChat(job.value)} via ${worker.name}: ${err}`,
|
||||
);
|
||||
if (err instanceof Grammy.GrammyError || err instanceof SdApiError) {
|
||||
await bot.api.sendMessage(
|
||||
job.value.request.chat.id,
|
||||
job.value.chat.id,
|
||||
`Failed to generate your prompt using ${worker.name}: ${err.message}`,
|
||||
{ reply_to_message_id: job.value.request.message_id },
|
||||
{ reply_to_message_id: job.value.requestMessageId },
|
||||
).catch(() => undefined);
|
||||
await job.update({ status: { type: "waiting" } }).catch(() => undefined);
|
||||
}
|
||||
|
@ -85,21 +85,21 @@ export async function processJobs(): Promise<never> {
|
|||
|
||||
async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config: GlobalData) {
|
||||
logger().debug(
|
||||
`Job started for ${formatUserChat(job.value.request)} using ${worker.name}`,
|
||||
`Job started for ${formatUserChat(job.value)} using ${worker.name}`,
|
||||
);
|
||||
const startDate = new Date();
|
||||
|
||||
// if there is already a status message delete it
|
||||
if (job.value.reply) {
|
||||
await bot.api.deleteMessage(job.value.reply.chat.id, job.value.reply.message_id)
|
||||
if (job.value.replyMessageId) {
|
||||
await bot.api.deleteMessage(job.value.chat.id, job.value.replyMessageId)
|
||||
.catch(() => undefined);
|
||||
}
|
||||
|
||||
// send a new status message
|
||||
const newStatusMessage = await bot.api.sendMessage(
|
||||
job.value.request.chat.id,
|
||||
job.value.chat.id,
|
||||
`Generating your prompt now... 0% using ${worker.name}`,
|
||||
{ reply_to_message_id: job.value.request.message_id },
|
||||
{ reply_to_message_id: job.value.requestMessageId },
|
||||
).catch((err) => {
|
||||
// don't error if the request message was deleted
|
||||
if (err instanceof Grammy.GrammyError && err.message.match(/repl(y|ied)/)) return null;
|
||||
|
@ -109,47 +109,71 @@ async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config:
|
|||
if (!newStatusMessage) {
|
||||
await job.delete();
|
||||
logger().info(
|
||||
`Job cancelled for ${formatUserChat(job.value.request)}`,
|
||||
`Job cancelled for ${formatUserChat(job.value)}`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
await job.update({ reply: newStatusMessage });
|
||||
await job.update({ replyMessageId: newStatusMessage.message_id });
|
||||
|
||||
// reduce size if worker can't handle the resolution
|
||||
const size = limitSize({ ...config.defaultParams, ...job.value.params }, worker.maxResolution);
|
||||
|
||||
// process the job
|
||||
const response = await sdTxt2Img(
|
||||
worker.api,
|
||||
{ ...config.defaultParams, ...job.value.params, ...size },
|
||||
async (progress) => {
|
||||
// important: don't let any errors escape this callback
|
||||
if (job.value.reply) {
|
||||
await bot.api.editMessageText(
|
||||
job.value.reply.chat.id,
|
||||
job.value.reply.message_id,
|
||||
`Generating your prompt now... ${
|
||||
(progress.progress * 100).toFixed(0)
|
||||
}% using ${worker.name}`,
|
||||
{ maxAttempts: 1 },
|
||||
).catch(() => undefined);
|
||||
}
|
||||
await job.update({
|
||||
status: {
|
||||
type: "processing",
|
||||
progress: progress.progress,
|
||||
worker: worker.name,
|
||||
updatedDate: new Date(),
|
||||
},
|
||||
}, { maxAttempts: 1 }).catch(() => undefined);
|
||||
},
|
||||
const size = limitSize(
|
||||
{ ...config.defaultParams, ...job.value.task.params },
|
||||
worker.maxResolution,
|
||||
);
|
||||
|
||||
// process the job
|
||||
const handleProgress = async (progress: SdProgressResponse) => {
|
||||
// important: don't let any errors escape this callback
|
||||
if (job.value.replyMessageId) {
|
||||
await bot.api.editMessageText(
|
||||
job.value.chat.id,
|
||||
job.value.replyMessageId,
|
||||
`Generating your prompt now... ${
|
||||
(progress.progress * 100).toFixed(0)
|
||||
}% using ${worker.name}`,
|
||||
{ maxAttempts: 1 },
|
||||
).catch(() => undefined);
|
||||
}
|
||||
await job.update({
|
||||
status: {
|
||||
type: "processing",
|
||||
progress: progress.progress,
|
||||
worker: worker.name,
|
||||
updatedDate: new Date(),
|
||||
},
|
||||
}, { maxAttempts: 1 }).catch(() => undefined);
|
||||
};
|
||||
let response: SdResponse<unknown>;
|
||||
const taskType = job.value.task.type; // don't narrow this to never pls typescript
|
||||
switch (job.value.task.type) {
|
||||
case "txt2img":
|
||||
response = await sdTxt2Img(
|
||||
worker.api,
|
||||
{ ...config.defaultParams, ...job.value.task.params, ...size },
|
||||
handleProgress,
|
||||
);
|
||||
break;
|
||||
case "img2img": {
|
||||
const file = await bot.api.getFile(job.value.task.fileId);
|
||||
const fileUrl = `https://api.telegram.org/file/bot${bot.token}/${file.file_path}`;
|
||||
const fileBuffer = await fetch(fileUrl).then((resp) => resp.arrayBuffer());
|
||||
const fileBase64 = Base64.encode(fileBuffer);
|
||||
response = await sdImg2Img(
|
||||
worker.api,
|
||||
{ ...config.defaultParams, ...job.value.task.params, ...size, init_images: [fileBase64] },
|
||||
handleProgress,
|
||||
);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw new Error(`Unknown task type: ${taskType}`);
|
||||
}
|
||||
|
||||
// upload the result
|
||||
if (job.value.reply) {
|
||||
if (job.value.replyMessageId) {
|
||||
await bot.api.editMessageText(
|
||||
job.value.reply.chat.id,
|
||||
job.value.reply.message_id,
|
||||
job.value.chat.id,
|
||||
job.value.replyMessageId,
|
||||
`Uploading your images...`,
|
||||
).catch(() => undefined);
|
||||
}
|
||||
|
@ -200,31 +224,31 @@ async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config:
|
|||
|
||||
// send the result to telegram
|
||||
try {
|
||||
resultMessages = await bot.api.sendMediaGroup(job.value.request.chat.id, inputFiles, {
|
||||
reply_to_message_id: job.value.request.message_id,
|
||||
resultMessages = await bot.api.sendMediaGroup(job.value.chat.id, inputFiles, {
|
||||
reply_to_message_id: job.value.requestMessageId,
|
||||
maxAttempts: 5,
|
||||
});
|
||||
break;
|
||||
} catch (err) {
|
||||
logger().warning(`Sending images (attempt ${sendMediaAttempt}) failed: ${err}`);
|
||||
if (sendMediaAttempt >= 5) throw err;
|
||||
await Async.delay(15000);
|
||||
if (sendMediaAttempt >= 6) throw err;
|
||||
await Async.delay(10000);
|
||||
}
|
||||
}
|
||||
|
||||
// send caption in separate message if it couldn't fit
|
||||
if (caption.text.length > 1024 && caption.text.length <= 4096) {
|
||||
await bot.api.sendMessage(job.value.request.chat.id, caption.text, {
|
||||
await bot.api.sendMessage(job.value.chat.id, caption.text, {
|
||||
reply_to_message_id: resultMessages[0].message_id,
|
||||
entities: caption.entities,
|
||||
});
|
||||
}
|
||||
|
||||
// delete the status message
|
||||
if (job.value.reply) {
|
||||
await bot.api.deleteMessage(job.value.reply.chat.id, job.value.reply.message_id)
|
||||
if (job.value.replyMessageId) {
|
||||
await bot.api.deleteMessage(job.value.chat.id, job.value.replyMessageId)
|
||||
.catch(() => undefined)
|
||||
.then(() => job.update({ reply: undefined }))
|
||||
.then(() => job.update({ replyMessageId: undefined }))
|
||||
.catch(() => undefined);
|
||||
}
|
||||
|
||||
|
@ -233,7 +257,7 @@ async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config:
|
|||
status: { type: "done", info: response.info, startDate, endDate: new Date() },
|
||||
});
|
||||
logger().debug(
|
||||
`Job finished for ${formatUserChat(job.value.request)} using ${worker.name}${
|
||||
`Job finished for ${formatUserChat(job.value)} using ${worker.name}${
|
||||
sendMediaAttempt > 1 ? ` after ${sendMediaAttempt} attempts` : ""
|
||||
}`,
|
||||
);
|
||||
|
|
|
@ -19,9 +19,7 @@ export async function returnHangedJobs(): Promise<never> {
|
|||
if (timeSinceLastUpdateMs > 2 * 60 * 1000) {
|
||||
await job.update({ status: { type: "waiting" } });
|
||||
logger().warning(
|
||||
`Job for ${
|
||||
formatUserChat(job.value.request)
|
||||
} was returned to the queue because it hanged for ${
|
||||
`Job for ${formatUserChat(job.value)} was returned to the queue because it hanged for ${
|
||||
FmtDuration.format(Math.trunc(timeSinceLastUpdateMs / 1000) * 1000, {
|
||||
ignoreZero: true,
|
||||
})
|
||||
|
|
|
@ -14,10 +14,10 @@ export async function updateJobStatusMsgs(): Promise<never> {
|
|||
await new Promise((resolve) => setTimeout(resolve, 5000));
|
||||
const jobs = await jobStore.getBy("status.type", "waiting");
|
||||
for (const [index, job] of jobs.entries()) {
|
||||
if (!job.value.reply) continue;
|
||||
if (!job.value.replyMessageId) continue;
|
||||
await bot.api.editMessageText(
|
||||
job.value.reply.chat.id,
|
||||
job.value.reply.message_id,
|
||||
job.value.chat.id,
|
||||
job.value.replyMessageId,
|
||||
`You are ${formatOrdinal(index + 1)} in queue.`,
|
||||
{ maxAttempts: 1 },
|
||||
).catch(() => undefined);
|
||||
|
|
Loading…
Reference in New Issue