feat: img2img
This commit is contained in:
parent
5b3da79129
commit
5b365c9e17
|
@ -33,7 +33,8 @@ You can put these in `.env` file or pass them as environment variables.
|
||||||
- [x] Replying to bot message, conversation in DMs
|
- [x] Replying to bot message, conversation in DMs
|
||||||
- [x] Replying to png message to extract png info nad generate
|
- [x] Replying to png message to extract png info nad generate
|
||||||
- [ ] Banning tags
|
- [ ] Banning tags
|
||||||
- [ ] Img2Img + Upscale
|
- [x] Img2Img + Upscale
|
||||||
|
- [ ] special param "scale" to change image size preserving aspect ratio
|
||||||
- [ ] Admin WebUI
|
- [ ] Admin WebUI
|
||||||
- [ ] User daily generation limits
|
- [ ] User daily generation limits
|
||||||
- [ ] Querying all generation history, displaying stats
|
- [ ] Querying all generation history, displaying stats
|
||||||
|
|
|
@ -0,0 +1,123 @@
|
||||||
|
import { Collections, Grammy, GrammyStatelessQ } from "../deps.ts";
|
||||||
|
import { formatUserChat } from "../utils.ts";
|
||||||
|
import { jobStore } from "../db/jobStore.ts";
|
||||||
|
import { parsePngInfo, PngInfo } from "../sd.ts";
|
||||||
|
import { Context, logger } from "./mod.ts";
|
||||||
|
|
||||||
|
export const img2imgQuestion = new GrammyStatelessQ.StatelessQuestion<Context>(
|
||||||
|
"img2img",
|
||||||
|
async (ctx, state) => {
|
||||||
|
// todo: also save original image size in state
|
||||||
|
await img2img(ctx, ctx.message.text, false, state);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
export async function img2imgCommand(ctx: Grammy.CommandContext<Context>) {
|
||||||
|
await img2img(ctx, ctx.match, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
async function img2img(
|
||||||
|
ctx: Context,
|
||||||
|
match: string | undefined,
|
||||||
|
includeRepliedTo: boolean,
|
||||||
|
fileId?: string,
|
||||||
|
): Promise<void> {
|
||||||
|
if (!ctx.message?.from?.id) {
|
||||||
|
await ctx.reply("I don't know who you are");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ctx.session.global.pausedReason != null) {
|
||||||
|
await ctx.reply(`I'm paused: ${ctx.session.global.pausedReason || "No reason given"}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const jobs = await jobStore.getBy("status.type", "waiting");
|
||||||
|
if (jobs.length >= ctx.session.global.maxJobs) {
|
||||||
|
await ctx.reply(
|
||||||
|
`The queue is full. Try again later. (Max queue size: ${ctx.session.global.maxJobs})`,
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const userJobs = jobs.filter((job) => job.value.from.id === ctx.message?.from?.id);
|
||||||
|
if (userJobs.length >= ctx.session.global.maxUserJobs) {
|
||||||
|
await ctx.reply(
|
||||||
|
`You already have ${ctx.session.global.maxUserJobs} jobs in queue. Try again later.`,
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let params: Partial<PngInfo> = {};
|
||||||
|
|
||||||
|
const repliedToMsg = ctx.message.reply_to_message;
|
||||||
|
|
||||||
|
if (includeRepliedTo && repliedToMsg?.photo) {
|
||||||
|
const photos = repliedToMsg.photo;
|
||||||
|
const biggestPhoto = Collections.maxBy(photos, (p) => p.width * p.height);
|
||||||
|
if (!biggestPhoto) throw new Error("Message was a photo but had no photos?");
|
||||||
|
fileId = biggestPhoto.file_id;
|
||||||
|
params.width = biggestPhoto.width;
|
||||||
|
params.height = biggestPhoto.height;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ctx.message.photo) {
|
||||||
|
const photos = ctx.message.photo;
|
||||||
|
const biggestPhoto = Collections.maxBy(photos, (p) => p.width * p.height);
|
||||||
|
if (!biggestPhoto) throw new Error("Message was a photo but had no photos?");
|
||||||
|
fileId = biggestPhoto.file_id;
|
||||||
|
params.width = biggestPhoto.width;
|
||||||
|
params.height = biggestPhoto.height;
|
||||||
|
}
|
||||||
|
|
||||||
|
const repliedToText = repliedToMsg?.text || repliedToMsg?.caption;
|
||||||
|
if (includeRepliedTo && repliedToText) {
|
||||||
|
// TODO: remove bot command from replied to text
|
||||||
|
const originalParams = parsePngInfo(repliedToText);
|
||||||
|
params = {
|
||||||
|
...originalParams,
|
||||||
|
...params,
|
||||||
|
prompt: [originalParams.prompt, params.prompt].filter(Boolean).join("\n"),
|
||||||
|
negative_prompt: [originalParams.negative_prompt, params.negative_prompt]
|
||||||
|
.filter(Boolean).join("\n"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const messageParams = parsePngInfo(match ?? "");
|
||||||
|
params = {
|
||||||
|
...params,
|
||||||
|
...messageParams,
|
||||||
|
prompt: [params.prompt, messageParams.prompt].filter(Boolean).join("\n"),
|
||||||
|
};
|
||||||
|
|
||||||
|
if (!fileId) {
|
||||||
|
await ctx.reply(
|
||||||
|
"Please show me a picture to repaint." +
|
||||||
|
img2imgQuestion.messageSuffixMarkdown(),
|
||||||
|
{ reply_markup: { force_reply: true, selective: true }, parse_mode: "Markdown" },
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!params.prompt) {
|
||||||
|
await ctx.reply(
|
||||||
|
"Please describe the picture you want to repaint." +
|
||||||
|
img2imgQuestion.messageSuffixMarkdown(fileId),
|
||||||
|
{ reply_markup: { force_reply: true, selective: true }, parse_mode: "Markdown" },
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const replyMessage = await ctx.reply("Accepted. You are now in queue.");
|
||||||
|
|
||||||
|
await jobStore.create({
|
||||||
|
task: { type: "img2img", params, fileId },
|
||||||
|
from: ctx.message.from,
|
||||||
|
chat: ctx.message.chat,
|
||||||
|
requestMessageId: ctx.message.message_id,
|
||||||
|
replyMessageId: replyMessage.message_id,
|
||||||
|
status: { type: "waiting" },
|
||||||
|
});
|
||||||
|
|
||||||
|
logger().debug(`Job enqueued for ${formatUserChat(ctx.message)}`);
|
||||||
|
}
|
|
@ -4,6 +4,7 @@ import { session, SessionFlavor } from "./session.ts";
|
||||||
import { queueCommand } from "./queueCommand.ts";
|
import { queueCommand } from "./queueCommand.ts";
|
||||||
import { txt2imgCommand, txt2imgQuestion } from "./txt2imgCommand.ts";
|
import { txt2imgCommand, txt2imgQuestion } from "./txt2imgCommand.ts";
|
||||||
import { pnginfoCommand, pnginfoQuestion } from "./pnginfoCommand.ts";
|
import { pnginfoCommand, pnginfoQuestion } from "./pnginfoCommand.ts";
|
||||||
|
import { img2imgCommand, img2imgQuestion } from "./img2imgCommand.ts";
|
||||||
|
|
||||||
export const logger = () => Log.getLogger();
|
export const logger = () => Log.getLogger();
|
||||||
|
|
||||||
|
@ -74,7 +75,8 @@ bot.api.setMyDescription(
|
||||||
"Send /txt2img to generate an image.",
|
"Send /txt2img to generate an image.",
|
||||||
);
|
);
|
||||||
bot.api.setMyCommands([
|
bot.api.setMyCommands([
|
||||||
{ command: "txt2img", description: "Generate an image" },
|
{ command: "txt2img", description: "Generate an image from text" },
|
||||||
|
{ command: "img2img", description: "Generate an image based on another image" },
|
||||||
{ command: "pnginfo", description: "Show generation parameters of an image" },
|
{ command: "pnginfo", description: "Show generation parameters of an image" },
|
||||||
{ command: "queue", description: "Show the current queue" },
|
{ command: "queue", description: "Show the current queue" },
|
||||||
]);
|
]);
|
||||||
|
@ -84,6 +86,9 @@ bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to gene
|
||||||
bot.command("txt2img", txt2imgCommand);
|
bot.command("txt2img", txt2imgCommand);
|
||||||
bot.use(txt2imgQuestion.middleware());
|
bot.use(txt2imgQuestion.middleware());
|
||||||
|
|
||||||
|
bot.command("img2img", img2imgCommand);
|
||||||
|
bot.use(img2imgQuestion.middleware());
|
||||||
|
|
||||||
bot.command("pnginfo", pnginfoCommand);
|
bot.command("pnginfo", pnginfoCommand);
|
||||||
bot.use(pnginfoQuestion.middleware());
|
bot.use(pnginfoQuestion.middleware());
|
||||||
|
|
||||||
|
|
|
@ -21,16 +21,14 @@ export async function queueCommand(ctx: Grammy.CommandContext<Context>) {
|
||||||
...jobs.length > 0
|
...jobs.length > 0
|
||||||
? jobs.flatMap((job) => [
|
? jobs.flatMap((job) => [
|
||||||
`${job.place}. `,
|
`${job.place}. `,
|
||||||
fmt`${bold(job.request.from.first_name)} `,
|
fmt`${bold(job.from.first_name)} `,
|
||||||
job.request.from.last_name ? fmt`${bold(job.request.from.last_name)} ` : "",
|
job.from.last_name ? fmt`${bold(job.from.last_name)} ` : "",
|
||||||
job.request.from.username ? `(@${job.request.from.username}) ` : "",
|
job.from.username ? `(@${job.from.username}) ` : "",
|
||||||
getFlagEmoji(job.request.from.language_code) ?? "",
|
getFlagEmoji(job.from.language_code) ?? "",
|
||||||
job.request.chat.type === "private"
|
job.chat.type === "private" ? " in private chat " : ` in ${job.chat.title} `,
|
||||||
? " in private chat "
|
job.chat.type !== "private" && job.chat.type !== "group" &&
|
||||||
: ` in ${job.request.chat.title} `,
|
job.chat.username
|
||||||
job.request.chat.type !== "private" && job.request.chat.type !== "group" &&
|
? `(@${job.chat.username}) `
|
||||||
job.request.chat.username
|
|
||||||
? `(@${job.request.chat.username}) `
|
|
||||||
: "",
|
: "",
|
||||||
job.status.type === "processing"
|
job.status.type === "processing"
|
||||||
? `(${(job.status.progress * 100).toFixed(0)}% using ${job.status.worker}) `
|
? `(${(job.status.progress * 100).toFixed(0)}% using ${job.status.worker}) `
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import { Grammy, GrammyStatelessQ } from "../deps.ts";
|
import { Grammy, GrammyStatelessQ } from "../deps.ts";
|
||||||
import { formatUserChat } from "../utils.ts";
|
import { formatUserChat } from "../utils.ts";
|
||||||
import { jobStore } from "../db/jobStore.ts";
|
import { jobStore } from "../db/jobStore.ts";
|
||||||
import { getPngInfo, parsePngInfo, SdTxt2ImgRequest } from "../sd.ts";
|
import { getPngInfo, parsePngInfo, PngInfo } from "../sd.ts";
|
||||||
import { Context, logger } from "./mod.ts";
|
import { Context, logger } from "./mod.ts";
|
||||||
|
|
||||||
export const txt2imgQuestion = new GrammyStatelessQ.StatelessQuestion<Context>(
|
export const txt2imgQuestion = new GrammyStatelessQ.StatelessQuestion<Context>(
|
||||||
|
@ -35,7 +35,7 @@ async function txt2img(ctx: Context, match: string, includeRepliedTo: boolean):
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const userJobs = jobs.filter((job) => job.value.request.from.id === ctx.message?.from?.id);
|
const userJobs = jobs.filter((job) => job.value.from.id === ctx.message?.from?.id);
|
||||||
if (userJobs.length >= ctx.session.global.maxUserJobs) {
|
if (userJobs.length >= ctx.session.global.maxUserJobs) {
|
||||||
await ctx.reply(
|
await ctx.reply(
|
||||||
`You already have ${ctx.session.global.maxUserJobs} jobs in queue. Try again later.`,
|
`You already have ${ctx.session.global.maxUserJobs} jobs in queue. Try again later.`,
|
||||||
|
@ -43,7 +43,7 @@ async function txt2img(ctx: Context, match: string, includeRepliedTo: boolean):
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let params: Partial<SdTxt2ImgRequest> = {};
|
let params: Partial<PngInfo> = {};
|
||||||
|
|
||||||
const repliedToMsg = ctx.message.reply_to_message;
|
const repliedToMsg = ctx.message.reply_to_message;
|
||||||
|
|
||||||
|
@ -92,9 +92,11 @@ async function txt2img(ctx: Context, match: string, includeRepliedTo: boolean):
|
||||||
const replyMessage = await ctx.reply("Accepted. You are now in queue.");
|
const replyMessage = await ctx.reply("Accepted. You are now in queue.");
|
||||||
|
|
||||||
await jobStore.create({
|
await jobStore.create({
|
||||||
params,
|
task: { type: "txt2img", params },
|
||||||
request: ctx.message,
|
from: ctx.message.from,
|
||||||
reply: replyMessage,
|
chat: ctx.message.chat,
|
||||||
|
requestMessageId: ctx.message.message_id,
|
||||||
|
replyMessageId: replyMessage.message_id,
|
||||||
status: { type: "waiting" },
|
status: { type: "waiting" },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,15 @@
|
||||||
import { GrammyTypes, IKV } from "../deps.ts";
|
import { GrammyTypes, IKV } from "../deps.ts";
|
||||||
import { SdTxt2ImgInfo, SdTxt2ImgRequest } from "../sd.ts";
|
import { PngInfo, SdTxt2ImgInfo } from "../sd.ts";
|
||||||
import { db } from "./db.ts";
|
import { db } from "./db.ts";
|
||||||
|
|
||||||
export interface JobSchema {
|
export interface JobSchema {
|
||||||
params: Partial<SdTxt2ImgRequest>;
|
task:
|
||||||
request: GrammyTypes.Message & { from: GrammyTypes.User };
|
| { type: "txt2img"; params: Partial<PngInfo> }
|
||||||
reply?: GrammyTypes.Message.TextMessage;
|
| { type: "img2img"; params: Partial<PngInfo>; fileId: string };
|
||||||
|
from: GrammyTypes.User;
|
||||||
|
chat: GrammyTypes.Chat;
|
||||||
|
requestMessageId: number;
|
||||||
|
replyMessageId?: number;
|
||||||
status:
|
status:
|
||||||
| { type: "waiting" }
|
| { type: "waiting" }
|
||||||
| { type: "processing"; progress: number; worker: string; updatedDate: Date }
|
| { type: "processing"; progress: number; worker: string; updatedDate: Date }
|
||||||
|
|
171
sd.ts
171
sd.ts
|
@ -37,13 +37,50 @@ async function fetchSdApi<T>(api: SdApi, endpoint: string, body?: unknown): Prom
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface SdRequest {
|
||||||
|
prompt: string;
|
||||||
|
denoising_strength: number;
|
||||||
|
styles: string[];
|
||||||
|
negative_prompt: string;
|
||||||
|
seed: number;
|
||||||
|
subseed: number;
|
||||||
|
subseed_strength: number;
|
||||||
|
seed_resize_from_h: number;
|
||||||
|
seed_resize_from_w: number;
|
||||||
|
width: number;
|
||||||
|
height: number;
|
||||||
|
sampler_name: string;
|
||||||
|
batch_size: number;
|
||||||
|
n_iter: number;
|
||||||
|
steps: number;
|
||||||
|
cfg_scale: number;
|
||||||
|
restore_faces: boolean;
|
||||||
|
tiling: boolean;
|
||||||
|
do_not_save_samples: boolean;
|
||||||
|
do_not_save_grid: boolean;
|
||||||
|
eta: number;
|
||||||
|
s_min_uncond: number;
|
||||||
|
s_churn: number;
|
||||||
|
s_tmax: number;
|
||||||
|
s_tmin: number;
|
||||||
|
s_noise: number;
|
||||||
|
override_settings: object;
|
||||||
|
override_settings_restore_afterwards: boolean;
|
||||||
|
script_args: unknown[];
|
||||||
|
sampler_index: string;
|
||||||
|
script_name: string;
|
||||||
|
send_images: boolean;
|
||||||
|
save_images: boolean;
|
||||||
|
alwayson_scripts: object;
|
||||||
|
}
|
||||||
|
|
||||||
export async function sdTxt2Img(
|
export async function sdTxt2Img(
|
||||||
api: SdApi,
|
api: SdApi,
|
||||||
params: Partial<SdTxt2ImgRequest>,
|
params: Partial<SdTxt2ImgRequest>,
|
||||||
onProgress?: (progress: SdProgressResponse) => void,
|
onProgress?: (progress: SdProgressResponse) => void,
|
||||||
signal: AbortSignal = neverSignal,
|
signal: AbortSignal = neverSignal,
|
||||||
): Promise<SdTxt2ImgResponse> {
|
): Promise<SdResponse<SdTxt2ImgRequest>> {
|
||||||
const request = fetchSdApi<SdTxt2ImgResponse>(api, "sdapi/v1/txt2img", params)
|
const request = fetchSdApi<SdResponse<SdTxt2ImgRequest>>(api, "sdapi/v1/txt2img", params)
|
||||||
// JSON field "info" is a JSON-serialized string so we need to parse this part second time
|
// JSON field "info" is a JSON-serialized string so we need to parse this part second time
|
||||||
.then((data) => ({
|
.then((data) => ({
|
||||||
...data,
|
...data,
|
||||||
|
@ -63,57 +100,65 @@ export async function sdTxt2Img(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface SdTxt2ImgRequest {
|
export interface SdTxt2ImgRequest extends SdRequest {
|
||||||
enable_hr: boolean;
|
enable_hr: boolean;
|
||||||
denoising_strength: number;
|
|
||||||
firstphase_width: number;
|
|
||||||
firstphase_height: number;
|
firstphase_height: number;
|
||||||
hr_scale: number;
|
firstphase_width: number;
|
||||||
hr_upscaler: unknown;
|
|
||||||
hr_second_pass_steps: number;
|
|
||||||
hr_resize_x: number;
|
hr_resize_x: number;
|
||||||
hr_resize_y: number;
|
|
||||||
hr_sampler_name: unknown;
|
|
||||||
hr_prompt: string;
|
|
||||||
hr_negative_prompt: string;
|
hr_negative_prompt: string;
|
||||||
prompt: string;
|
hr_prompt: string;
|
||||||
styles: unknown;
|
hr_resize_y: number;
|
||||||
seed: number;
|
hr_sampler_name: string;
|
||||||
subseed: number;
|
hr_scale: number;
|
||||||
subseed_strength: number;
|
hr_second_pass_steps: number;
|
||||||
seed_resize_from_h: number;
|
hr_upscaler: string;
|
||||||
seed_resize_from_w: number;
|
|
||||||
sampler_name: unknown;
|
|
||||||
batch_size: number;
|
|
||||||
n_iter: number;
|
|
||||||
steps: number;
|
|
||||||
cfg_scale: number;
|
|
||||||
width: number;
|
|
||||||
height: number;
|
|
||||||
restore_faces: boolean;
|
|
||||||
tiling: boolean;
|
|
||||||
do_not_save_samples: boolean;
|
|
||||||
do_not_save_grid: boolean;
|
|
||||||
negative_prompt: string;
|
|
||||||
eta: unknown;
|
|
||||||
s_min_uncond: number;
|
|
||||||
s_churn: number;
|
|
||||||
s_tmax: unknown;
|
|
||||||
s_tmin: number;
|
|
||||||
s_noise: number;
|
|
||||||
override_settings: object;
|
|
||||||
override_settings_restore_afterwards: boolean;
|
|
||||||
script_args: unknown[];
|
|
||||||
sampler_index: string;
|
|
||||||
script_name: unknown;
|
|
||||||
send_images: boolean;
|
|
||||||
save_images: boolean;
|
|
||||||
alwayson_scripts: object;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface SdTxt2ImgResponse {
|
export async function sdImg2Img(
|
||||||
|
api: SdApi,
|
||||||
|
params: Partial<SdImg2ImgRequest>,
|
||||||
|
onProgress?: (progress: SdProgressResponse) => void,
|
||||||
|
signal: AbortSignal = neverSignal,
|
||||||
|
): Promise<SdResponse<SdImg2ImgRequest>> {
|
||||||
|
const request = fetchSdApi<SdResponse<SdImg2ImgRequest>>(api, "sdapi/v1/img2img", params)
|
||||||
|
// JSON field "info" is a JSON-serialized string so we need to parse this part second time
|
||||||
|
.then((data) => ({
|
||||||
|
...data,
|
||||||
|
info: typeof data.info === "string" ? JSON.parse(data.info) : data.info,
|
||||||
|
}));
|
||||||
|
|
||||||
|
try {
|
||||||
|
while (true) {
|
||||||
|
await Async.abortable(Promise.race([request, Async.delay(4000)]), signal);
|
||||||
|
if (await AsyncX.promiseState(request) !== "pending") return await request;
|
||||||
|
onProgress?.(await fetchSdApi<SdProgressResponse>(api, "sdapi/v1/progress"));
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
if (await AsyncX.promiseState(request) === "pending") {
|
||||||
|
await fetchSdApi(api, "sdapi/v1/interrupt", {});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SdImg2ImgRequest extends SdRequest {
|
||||||
|
image_cfg_scale: number;
|
||||||
|
include_init_images: boolean;
|
||||||
|
init_images: string[];
|
||||||
|
initial_noise_multiplier: number;
|
||||||
|
inpaint_full_res: boolean;
|
||||||
|
inpaint_full_res_padding: number;
|
||||||
|
inpainting_fill: number;
|
||||||
|
inpainting_mask_invert: number;
|
||||||
|
mask: string;
|
||||||
|
mask_blur: number;
|
||||||
|
mask_blur_x: number;
|
||||||
|
mask_blur_y: number;
|
||||||
|
resize_mode: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SdResponse<T> {
|
||||||
images: string[];
|
images: string[];
|
||||||
parameters: SdTxt2ImgRequest;
|
parameters: T;
|
||||||
// Warning: raw response from API is a JSON-serialized string
|
// Warning: raw response from API is a JSON-serialized string
|
||||||
info: SdTxt2ImgInfo;
|
info: SdTxt2ImgInfo;
|
||||||
}
|
}
|
||||||
|
@ -242,10 +287,22 @@ export function getPngInfo(pngData: Uint8Array): string | undefined {
|
||||||
?.text;
|
?.text;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function parsePngInfo(pngInfo: string): Partial<SdTxt2ImgRequest> {
|
export interface PngInfo {
|
||||||
|
prompt: string;
|
||||||
|
negative_prompt: string;
|
||||||
|
steps: number;
|
||||||
|
cfg_scale: number;
|
||||||
|
width: number;
|
||||||
|
height: number;
|
||||||
|
sampler_name: string;
|
||||||
|
seed: number;
|
||||||
|
denoising_strength: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function parsePngInfo(pngInfo: string): Partial<PngInfo> {
|
||||||
const tags = pngInfo.split(/[,;]+|\.+\s|\n/u);
|
const tags = pngInfo.split(/[,;]+|\.+\s|\n/u);
|
||||||
let part: "prompt" | "negative_prompt" | "params" = "prompt";
|
let part: "prompt" | "negative_prompt" | "params" = "prompt";
|
||||||
const params: Partial<SdTxt2ImgRequest> = {};
|
const params: Partial<PngInfo> = {};
|
||||||
const prompt: string[] = [];
|
const prompt: string[] = [];
|
||||||
const negativePrompt: string[] = [];
|
const negativePrompt: string[] = [];
|
||||||
for (const tag of tags) {
|
for (const tag of tags) {
|
||||||
|
@ -294,14 +351,26 @@ export function parsePngInfo(pngInfo: string): Partial<SdTxt2ImgRequest> {
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case "denoisingstrength":
|
||||||
|
case "denoising":
|
||||||
|
case "denoise": {
|
||||||
|
part = "params";
|
||||||
|
// allow percent or decimal
|
||||||
|
let denoisingStrength: number;
|
||||||
|
if (value.trim().endsWith("%")) {
|
||||||
|
denoisingStrength = Number(value.trim().slice(0, -1).trim()) / 100;
|
||||||
|
} else {
|
||||||
|
denoisingStrength = Number(value.trim());
|
||||||
|
}
|
||||||
|
denoisingStrength = Math.min(Math.max(denoisingStrength, 0), 1);
|
||||||
|
params.denoising_strength = denoisingStrength;
|
||||||
|
break;
|
||||||
|
}
|
||||||
case "seed":
|
case "seed":
|
||||||
case "model":
|
case "model":
|
||||||
case "modelhash":
|
case "modelhash":
|
||||||
case "modelname":
|
case "modelname":
|
||||||
case "sampler":
|
case "sampler":
|
||||||
case "denoisingstrength":
|
|
||||||
case "denoising":
|
|
||||||
case "denoise":
|
|
||||||
part = "params";
|
part = "params";
|
||||||
// ignore for now
|
// ignore for now
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -12,7 +12,7 @@ import {
|
||||||
import { bot } from "../bot/mod.ts";
|
import { bot } from "../bot/mod.ts";
|
||||||
import { getGlobalSession, GlobalData, WorkerData } from "../bot/session.ts";
|
import { getGlobalSession, GlobalData, WorkerData } from "../bot/session.ts";
|
||||||
import { fmt, formatUserChat } from "../utils.ts";
|
import { fmt, formatUserChat } from "../utils.ts";
|
||||||
import { SdApiError, sdTxt2Img } from "../sd.ts";
|
import { SdApiError, sdImg2Img, SdProgressResponse, SdResponse, sdTxt2Img } from "../sd.ts";
|
||||||
import { JobSchema, jobStore } from "../db/jobStore.ts";
|
import { JobSchema, jobStore } from "../db/jobStore.ts";
|
||||||
import { runningWorkers } from "./pingWorkers.ts";
|
import { runningWorkers } from "./pingWorkers.ts";
|
||||||
|
|
||||||
|
@ -48,13 +48,13 @@ export async function processJobs(): Promise<never> {
|
||||||
processJob(job, worker, config)
|
processJob(job, worker, config)
|
||||||
.catch(async (err) => {
|
.catch(async (err) => {
|
||||||
logger().error(
|
logger().error(
|
||||||
`Job failed for ${formatUserChat(job.value.request)} via ${worker.name}: ${err}`,
|
`Job failed for ${formatUserChat(job.value)} via ${worker.name}: ${err}`,
|
||||||
);
|
);
|
||||||
if (err instanceof Grammy.GrammyError || err instanceof SdApiError) {
|
if (err instanceof Grammy.GrammyError || err instanceof SdApiError) {
|
||||||
await bot.api.sendMessage(
|
await bot.api.sendMessage(
|
||||||
job.value.request.chat.id,
|
job.value.chat.id,
|
||||||
`Failed to generate your prompt using ${worker.name}: ${err.message}`,
|
`Failed to generate your prompt using ${worker.name}: ${err.message}`,
|
||||||
{ reply_to_message_id: job.value.request.message_id },
|
{ reply_to_message_id: job.value.requestMessageId },
|
||||||
).catch(() => undefined);
|
).catch(() => undefined);
|
||||||
await job.update({ status: { type: "waiting" } }).catch(() => undefined);
|
await job.update({ status: { type: "waiting" } }).catch(() => undefined);
|
||||||
}
|
}
|
||||||
|
@ -85,21 +85,21 @@ export async function processJobs(): Promise<never> {
|
||||||
|
|
||||||
async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config: GlobalData) {
|
async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config: GlobalData) {
|
||||||
logger().debug(
|
logger().debug(
|
||||||
`Job started for ${formatUserChat(job.value.request)} using ${worker.name}`,
|
`Job started for ${formatUserChat(job.value)} using ${worker.name}`,
|
||||||
);
|
);
|
||||||
const startDate = new Date();
|
const startDate = new Date();
|
||||||
|
|
||||||
// if there is already a status message delete it
|
// if there is already a status message delete it
|
||||||
if (job.value.reply) {
|
if (job.value.replyMessageId) {
|
||||||
await bot.api.deleteMessage(job.value.reply.chat.id, job.value.reply.message_id)
|
await bot.api.deleteMessage(job.value.chat.id, job.value.replyMessageId)
|
||||||
.catch(() => undefined);
|
.catch(() => undefined);
|
||||||
}
|
}
|
||||||
|
|
||||||
// send a new status message
|
// send a new status message
|
||||||
const newStatusMessage = await bot.api.sendMessage(
|
const newStatusMessage = await bot.api.sendMessage(
|
||||||
job.value.request.chat.id,
|
job.value.chat.id,
|
||||||
`Generating your prompt now... 0% using ${worker.name}`,
|
`Generating your prompt now... 0% using ${worker.name}`,
|
||||||
{ reply_to_message_id: job.value.request.message_id },
|
{ reply_to_message_id: job.value.requestMessageId },
|
||||||
).catch((err) => {
|
).catch((err) => {
|
||||||
// don't error if the request message was deleted
|
// don't error if the request message was deleted
|
||||||
if (err instanceof Grammy.GrammyError && err.message.match(/repl(y|ied)/)) return null;
|
if (err instanceof Grammy.GrammyError && err.message.match(/repl(y|ied)/)) return null;
|
||||||
|
@ -109,47 +109,71 @@ async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config:
|
||||||
if (!newStatusMessage) {
|
if (!newStatusMessage) {
|
||||||
await job.delete();
|
await job.delete();
|
||||||
logger().info(
|
logger().info(
|
||||||
`Job cancelled for ${formatUserChat(job.value.request)}`,
|
`Job cancelled for ${formatUserChat(job.value)}`,
|
||||||
);
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
await job.update({ reply: newStatusMessage });
|
await job.update({ replyMessageId: newStatusMessage.message_id });
|
||||||
|
|
||||||
// reduce size if worker can't handle the resolution
|
// reduce size if worker can't handle the resolution
|
||||||
const size = limitSize({ ...config.defaultParams, ...job.value.params }, worker.maxResolution);
|
const size = limitSize(
|
||||||
|
{ ...config.defaultParams, ...job.value.task.params },
|
||||||
// process the job
|
worker.maxResolution,
|
||||||
const response = await sdTxt2Img(
|
|
||||||
worker.api,
|
|
||||||
{ ...config.defaultParams, ...job.value.params, ...size },
|
|
||||||
async (progress) => {
|
|
||||||
// important: don't let any errors escape this callback
|
|
||||||
if (job.value.reply) {
|
|
||||||
await bot.api.editMessageText(
|
|
||||||
job.value.reply.chat.id,
|
|
||||||
job.value.reply.message_id,
|
|
||||||
`Generating your prompt now... ${
|
|
||||||
(progress.progress * 100).toFixed(0)
|
|
||||||
}% using ${worker.name}`,
|
|
||||||
{ maxAttempts: 1 },
|
|
||||||
).catch(() => undefined);
|
|
||||||
}
|
|
||||||
await job.update({
|
|
||||||
status: {
|
|
||||||
type: "processing",
|
|
||||||
progress: progress.progress,
|
|
||||||
worker: worker.name,
|
|
||||||
updatedDate: new Date(),
|
|
||||||
},
|
|
||||||
}, { maxAttempts: 1 }).catch(() => undefined);
|
|
||||||
},
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// process the job
|
||||||
|
const handleProgress = async (progress: SdProgressResponse) => {
|
||||||
|
// important: don't let any errors escape this callback
|
||||||
|
if (job.value.replyMessageId) {
|
||||||
|
await bot.api.editMessageText(
|
||||||
|
job.value.chat.id,
|
||||||
|
job.value.replyMessageId,
|
||||||
|
`Generating your prompt now... ${
|
||||||
|
(progress.progress * 100).toFixed(0)
|
||||||
|
}% using ${worker.name}`,
|
||||||
|
{ maxAttempts: 1 },
|
||||||
|
).catch(() => undefined);
|
||||||
|
}
|
||||||
|
await job.update({
|
||||||
|
status: {
|
||||||
|
type: "processing",
|
||||||
|
progress: progress.progress,
|
||||||
|
worker: worker.name,
|
||||||
|
updatedDate: new Date(),
|
||||||
|
},
|
||||||
|
}, { maxAttempts: 1 }).catch(() => undefined);
|
||||||
|
};
|
||||||
|
let response: SdResponse<unknown>;
|
||||||
|
const taskType = job.value.task.type; // don't narrow this to never pls typescript
|
||||||
|
switch (job.value.task.type) {
|
||||||
|
case "txt2img":
|
||||||
|
response = await sdTxt2Img(
|
||||||
|
worker.api,
|
||||||
|
{ ...config.defaultParams, ...job.value.task.params, ...size },
|
||||||
|
handleProgress,
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
case "img2img": {
|
||||||
|
const file = await bot.api.getFile(job.value.task.fileId);
|
||||||
|
const fileUrl = `https://api.telegram.org/file/bot${bot.token}/${file.file_path}`;
|
||||||
|
const fileBuffer = await fetch(fileUrl).then((resp) => resp.arrayBuffer());
|
||||||
|
const fileBase64 = Base64.encode(fileBuffer);
|
||||||
|
response = await sdImg2Img(
|
||||||
|
worker.api,
|
||||||
|
{ ...config.defaultParams, ...job.value.task.params, ...size, init_images: [fileBase64] },
|
||||||
|
handleProgress,
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
throw new Error(`Unknown task type: ${taskType}`);
|
||||||
|
}
|
||||||
|
|
||||||
// upload the result
|
// upload the result
|
||||||
if (job.value.reply) {
|
if (job.value.replyMessageId) {
|
||||||
await bot.api.editMessageText(
|
await bot.api.editMessageText(
|
||||||
job.value.reply.chat.id,
|
job.value.chat.id,
|
||||||
job.value.reply.message_id,
|
job.value.replyMessageId,
|
||||||
`Uploading your images...`,
|
`Uploading your images...`,
|
||||||
).catch(() => undefined);
|
).catch(() => undefined);
|
||||||
}
|
}
|
||||||
|
@ -200,31 +224,31 @@ async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config:
|
||||||
|
|
||||||
// send the result to telegram
|
// send the result to telegram
|
||||||
try {
|
try {
|
||||||
resultMessages = await bot.api.sendMediaGroup(job.value.request.chat.id, inputFiles, {
|
resultMessages = await bot.api.sendMediaGroup(job.value.chat.id, inputFiles, {
|
||||||
reply_to_message_id: job.value.request.message_id,
|
reply_to_message_id: job.value.requestMessageId,
|
||||||
maxAttempts: 5,
|
maxAttempts: 5,
|
||||||
});
|
});
|
||||||
break;
|
break;
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
logger().warning(`Sending images (attempt ${sendMediaAttempt}) failed: ${err}`);
|
logger().warning(`Sending images (attempt ${sendMediaAttempt}) failed: ${err}`);
|
||||||
if (sendMediaAttempt >= 5) throw err;
|
if (sendMediaAttempt >= 6) throw err;
|
||||||
await Async.delay(15000);
|
await Async.delay(10000);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// send caption in separate message if it couldn't fit
|
// send caption in separate message if it couldn't fit
|
||||||
if (caption.text.length > 1024 && caption.text.length <= 4096) {
|
if (caption.text.length > 1024 && caption.text.length <= 4096) {
|
||||||
await bot.api.sendMessage(job.value.request.chat.id, caption.text, {
|
await bot.api.sendMessage(job.value.chat.id, caption.text, {
|
||||||
reply_to_message_id: resultMessages[0].message_id,
|
reply_to_message_id: resultMessages[0].message_id,
|
||||||
entities: caption.entities,
|
entities: caption.entities,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete the status message
|
// delete the status message
|
||||||
if (job.value.reply) {
|
if (job.value.replyMessageId) {
|
||||||
await bot.api.deleteMessage(job.value.reply.chat.id, job.value.reply.message_id)
|
await bot.api.deleteMessage(job.value.chat.id, job.value.replyMessageId)
|
||||||
.catch(() => undefined)
|
.catch(() => undefined)
|
||||||
.then(() => job.update({ reply: undefined }))
|
.then(() => job.update({ replyMessageId: undefined }))
|
||||||
.catch(() => undefined);
|
.catch(() => undefined);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -233,7 +257,7 @@ async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config:
|
||||||
status: { type: "done", info: response.info, startDate, endDate: new Date() },
|
status: { type: "done", info: response.info, startDate, endDate: new Date() },
|
||||||
});
|
});
|
||||||
logger().debug(
|
logger().debug(
|
||||||
`Job finished for ${formatUserChat(job.value.request)} using ${worker.name}${
|
`Job finished for ${formatUserChat(job.value)} using ${worker.name}${
|
||||||
sendMediaAttempt > 1 ? ` after ${sendMediaAttempt} attempts` : ""
|
sendMediaAttempt > 1 ? ` after ${sendMediaAttempt} attempts` : ""
|
||||||
}`,
|
}`,
|
||||||
);
|
);
|
||||||
|
|
|
@ -19,9 +19,7 @@ export async function returnHangedJobs(): Promise<never> {
|
||||||
if (timeSinceLastUpdateMs > 2 * 60 * 1000) {
|
if (timeSinceLastUpdateMs > 2 * 60 * 1000) {
|
||||||
await job.update({ status: { type: "waiting" } });
|
await job.update({ status: { type: "waiting" } });
|
||||||
logger().warning(
|
logger().warning(
|
||||||
`Job for ${
|
`Job for ${formatUserChat(job.value)} was returned to the queue because it hanged for ${
|
||||||
formatUserChat(job.value.request)
|
|
||||||
} was returned to the queue because it hanged for ${
|
|
||||||
FmtDuration.format(Math.trunc(timeSinceLastUpdateMs / 1000) * 1000, {
|
FmtDuration.format(Math.trunc(timeSinceLastUpdateMs / 1000) * 1000, {
|
||||||
ignoreZero: true,
|
ignoreZero: true,
|
||||||
})
|
})
|
||||||
|
|
|
@ -14,10 +14,10 @@ export async function updateJobStatusMsgs(): Promise<never> {
|
||||||
await new Promise((resolve) => setTimeout(resolve, 5000));
|
await new Promise((resolve) => setTimeout(resolve, 5000));
|
||||||
const jobs = await jobStore.getBy("status.type", "waiting");
|
const jobs = await jobStore.getBy("status.type", "waiting");
|
||||||
for (const [index, job] of jobs.entries()) {
|
for (const [index, job] of jobs.entries()) {
|
||||||
if (!job.value.reply) continue;
|
if (!job.value.replyMessageId) continue;
|
||||||
await bot.api.editMessageText(
|
await bot.api.editMessageText(
|
||||||
job.value.reply.chat.id,
|
job.value.chat.id,
|
||||||
job.value.reply.message_id,
|
job.value.replyMessageId,
|
||||||
`You are ${formatOrdinal(index + 1)} in queue.`,
|
`You are ${formatOrdinal(index + 1)} in queue.`,
|
||||||
{ maxAttempts: 1 },
|
{ maxAttempts: 1 },
|
||||||
).catch(() => undefined);
|
).catch(() => undefined);
|
||||||
|
|
Loading…
Reference in New Issue