feat: img2img

This commit is contained in:
pinks 2023-09-12 03:57:44 +02:00
parent 5b3da79129
commit 5b365c9e17
10 changed files with 353 additions and 129 deletions

View File

@ -33,7 +33,8 @@ You can put these in `.env` file or pass them as environment variables.
- [x] Replying to bot message, conversation in DMs
- [x] Replying to png message to extract png info and generate
- [ ] Banning tags
- [ ] Img2Img + Upscale
- [x] Img2Img + Upscale
- [ ] special param "scale" to change image size preserving aspect ratio
- [ ] Admin WebUI
- [ ] User daily generation limits
- [ ] Querying all generation history, displaying stats

123
bot/img2imgCommand.ts Normal file
View File

@ -0,0 +1,123 @@
import { Collections, Grammy, GrammyStatelessQ } from "../deps.ts";
import { formatUserChat } from "../utils.ts";
import { jobStore } from "../db/jobStore.ts";
import { parsePngInfo, PngInfo } from "../sd.ts";
import { Context, logger } from "./mod.ts";
/**
 * Stateless follow-up question used when /img2img is missing a picture or a
 * prompt: the user's next reply is routed back into img2img. The question
 * state string carries the fileId when one was already provided
 * (see messageSuffixMarkdown(fileId) below).
 */
export const img2imgQuestion = new GrammyStatelessQ.StatelessQuestion<Context>(
  "img2img",
  async (replyCtx, state) => {
    // todo: also save original image size in state
    await img2img(replyCtx, replyCtx.message.text, false, state);
  },
);
/** Entry point for the /img2img bot command. */
export async function img2imgCommand(ctx: Grammy.CommandContext<Context>) {
  const commandArgs = ctx.match;
  await img2img(ctx, commandArgs, true);
}
/**
 * Returns the highest-resolution variant of a Telegram photo.
 * Telegram attaches several downscaled sizes to each photo message.
 */
function getBiggestPhoto<T extends { file_id: string; width: number; height: number }>(
  photos: T[],
): T {
  const biggestPhoto = Collections.maxBy(photos, (p) => p.width * p.height);
  if (!biggestPhoto) throw new Error("Message was a photo but had no photos?");
  return biggestPhoto;
}

/**
 * Validates queue limits, collects the source image and prompt (from this
 * message, the replied-to message, or the stateless-question state) and
 * enqueues an img2img job. Asks a follow-up question when either the image
 * or the prompt is still missing.
 *
 * @param match free-form parameter text (command arguments or reply text)
 * @param includeRepliedTo whether to harvest the photo/params from the replied-to message
 * @param fileId Telegram file id of the source image, when already known
 */
async function img2img(
  ctx: Context,
  match: string | undefined,
  includeRepliedTo: boolean,
  fileId?: string,
): Promise<void> {
  // We need a sender id for the per-user queue limit below.
  if (!ctx.message?.from?.id) {
    await ctx.reply("I don't know who you are");
    return;
  }
  if (ctx.session.global.pausedReason != null) {
    await ctx.reply(`I'm paused: ${ctx.session.global.pausedReason || "No reason given"}`);
    return;
  }
  // Global queue cap.
  const jobs = await jobStore.getBy("status.type", "waiting");
  if (jobs.length >= ctx.session.global.maxJobs) {
    await ctx.reply(
      `The queue is full. Try again later. (Max queue size: ${ctx.session.global.maxJobs})`,
    );
    return;
  }
  // Per-user queue cap.
  const userJobs = jobs.filter((job) => job.value.from.id === ctx.message?.from?.id);
  if (userJobs.length >= ctx.session.global.maxUserJobs) {
    await ctx.reply(
      `You already have ${ctx.session.global.maxUserJobs} jobs in queue. Try again later.`,
    );
    return;
  }
  let params: Partial<PngInfo> = {};
  const repliedToMsg = ctx.message.reply_to_message;
  // Take the source image from the replied-to message (command flow only).
  if (includeRepliedTo && repliedToMsg?.photo) {
    const biggestPhoto = getBiggestPhoto(repliedToMsg.photo);
    fileId = biggestPhoto.file_id;
    params.width = biggestPhoto.width;
    params.height = biggestPhoto.height;
  }
  // A photo attached to the triggering message itself wins over the reply.
  if (ctx.message.photo) {
    const biggestPhoto = getBiggestPhoto(ctx.message.photo);
    fileId = biggestPhoto.file_id;
    params.width = biggestPhoto.width;
    params.height = biggestPhoto.height;
  }
  // Merge generation params parsed from the replied-to caption/text, keeping
  // the photo dimensions gathered above and concatenating both prompts.
  const repliedToText = repliedToMsg?.text || repliedToMsg?.caption;
  if (includeRepliedTo && repliedToText) {
    // TODO: remove bot command from replied to text
    const originalParams = parsePngInfo(repliedToText);
    params = {
      ...originalParams,
      ...params,
      prompt: [originalParams.prompt, params.prompt].filter(Boolean).join("\n"),
      negative_prompt: [originalParams.negative_prompt, params.negative_prompt]
        .filter(Boolean).join("\n"),
    };
  }
  // Params given in this message override everything gathered so far;
  // prompts are concatenated rather than replaced.
  const messageParams = parsePngInfo(match ?? "");
  params = {
    ...params,
    ...messageParams,
    prompt: [params.prompt, messageParams.prompt].filter(Boolean).join("\n"),
  };
  // No source image yet — ask for one via a stateless question.
  if (!fileId) {
    await ctx.reply(
      "Please show me a picture to repaint." +
        img2imgQuestion.messageSuffixMarkdown(),
      { reply_markup: { force_reply: true, selective: true }, parse_mode: "Markdown" },
    );
    return;
  }
  // No prompt yet — ask for one, stashing the fileId in the question state.
  if (!params.prompt) {
    await ctx.reply(
      "Please describe the picture you want to repaint." +
        img2imgQuestion.messageSuffixMarkdown(fileId),
      { reply_markup: { force_reply: true, selective: true }, parse_mode: "Markdown" },
    );
    return;
  }
  // Enqueue the job and acknowledge.
  const replyMessage = await ctx.reply("Accepted. You are now in queue.");
  await jobStore.create({
    task: { type: "img2img", params, fileId },
    from: ctx.message.from,
    chat: ctx.message.chat,
    requestMessageId: ctx.message.message_id,
    replyMessageId: replyMessage.message_id,
    status: { type: "waiting" },
  });
  logger().debug(`Job enqueued for ${formatUserChat(ctx.message)}`);
}

View File

@ -4,6 +4,7 @@ import { session, SessionFlavor } from "./session.ts";
import { queueCommand } from "./queueCommand.ts";
import { txt2imgCommand, txt2imgQuestion } from "./txt2imgCommand.ts";
import { pnginfoCommand, pnginfoQuestion } from "./pnginfoCommand.ts";
import { img2imgCommand, img2imgQuestion } from "./img2imgCommand.ts";
export const logger = () => Log.getLogger();
@ -74,7 +75,8 @@ bot.api.setMyDescription(
"Send /txt2img to generate an image.",
);
bot.api.setMyCommands([
{ command: "txt2img", description: "Generate an image" },
{ command: "txt2img", description: "Generate an image from text" },
{ command: "img2img", description: "Generate an image based on another image" },
{ command: "pnginfo", description: "Show generation parameters of an image" },
{ command: "queue", description: "Show the current queue" },
]);
@ -84,6 +86,9 @@ bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to gene
bot.command("txt2img", txt2imgCommand);
bot.use(txt2imgQuestion.middleware());
bot.command("img2img", img2imgCommand);
bot.use(img2imgQuestion.middleware());
bot.command("pnginfo", pnginfoCommand);
bot.use(pnginfoQuestion.middleware());

View File

@ -21,16 +21,14 @@ export async function queueCommand(ctx: Grammy.CommandContext<Context>) {
...jobs.length > 0
? jobs.flatMap((job) => [
`${job.place}. `,
fmt`${bold(job.request.from.first_name)} `,
job.request.from.last_name ? fmt`${bold(job.request.from.last_name)} ` : "",
job.request.from.username ? `(@${job.request.from.username}) ` : "",
getFlagEmoji(job.request.from.language_code) ?? "",
job.request.chat.type === "private"
? " in private chat "
: ` in ${job.request.chat.title} `,
job.request.chat.type !== "private" && job.request.chat.type !== "group" &&
job.request.chat.username
? `(@${job.request.chat.username}) `
fmt`${bold(job.from.first_name)} `,
job.from.last_name ? fmt`${bold(job.from.last_name)} ` : "",
job.from.username ? `(@${job.from.username}) ` : "",
getFlagEmoji(job.from.language_code) ?? "",
job.chat.type === "private" ? " in private chat " : ` in ${job.chat.title} `,
job.chat.type !== "private" && job.chat.type !== "group" &&
job.chat.username
? `(@${job.chat.username}) `
: "",
job.status.type === "processing"
? `(${(job.status.progress * 100).toFixed(0)}% using ${job.status.worker}) `

View File

@ -1,7 +1,7 @@
import { Grammy, GrammyStatelessQ } from "../deps.ts";
import { formatUserChat } from "../utils.ts";
import { jobStore } from "../db/jobStore.ts";
import { getPngInfo, parsePngInfo, SdTxt2ImgRequest } from "../sd.ts";
import { getPngInfo, parsePngInfo, PngInfo } from "../sd.ts";
import { Context, logger } from "./mod.ts";
export const txt2imgQuestion = new GrammyStatelessQ.StatelessQuestion<Context>(
@ -35,7 +35,7 @@ async function txt2img(ctx: Context, match: string, includeRepliedTo: boolean):
return;
}
const userJobs = jobs.filter((job) => job.value.request.from.id === ctx.message?.from?.id);
const userJobs = jobs.filter((job) => job.value.from.id === ctx.message?.from?.id);
if (userJobs.length >= ctx.session.global.maxUserJobs) {
await ctx.reply(
`You already have ${ctx.session.global.maxUserJobs} jobs in queue. Try again later.`,
@ -43,7 +43,7 @@ async function txt2img(ctx: Context, match: string, includeRepliedTo: boolean):
return;
}
let params: Partial<SdTxt2ImgRequest> = {};
let params: Partial<PngInfo> = {};
const repliedToMsg = ctx.message.reply_to_message;
@ -92,9 +92,11 @@ async function txt2img(ctx: Context, match: string, includeRepliedTo: boolean):
const replyMessage = await ctx.reply("Accepted. You are now in queue.");
await jobStore.create({
params,
request: ctx.message,
reply: replyMessage,
task: { type: "txt2img", params },
from: ctx.message.from,
chat: ctx.message.chat,
requestMessageId: ctx.message.message_id,
replyMessageId: replyMessage.message_id,
status: { type: "waiting" },
});

View File

@ -1,11 +1,15 @@
import { GrammyTypes, IKV } from "../deps.ts";
import { SdTxt2ImgInfo, SdTxt2ImgRequest } from "../sd.ts";
import { PngInfo, SdTxt2ImgInfo } from "../sd.ts";
import { db } from "./db.ts";
export interface JobSchema {
params: Partial<SdTxt2ImgRequest>;
request: GrammyTypes.Message & { from: GrammyTypes.User };
reply?: GrammyTypes.Message.TextMessage;
task:
| { type: "txt2img"; params: Partial<PngInfo> }
| { type: "img2img"; params: Partial<PngInfo>; fileId: string };
from: GrammyTypes.User;
chat: GrammyTypes.Chat;
requestMessageId: number;
replyMessageId?: number;
status:
| { type: "waiting" }
| { type: "processing"; progress: number; worker: string; updatedDate: Date }

171
sd.ts
View File

@ -37,13 +37,50 @@ async function fetchSdApi<T>(api: SdApi, endpoint: string, body?: unknown): Prom
return result;
}
/**
 * Request fields common to the WebUI txt2img and img2img endpoints
 * (extended by SdTxt2ImgRequest and SdImg2ImgRequest). Field semantics
 * mirror the Stable Diffusion WebUI API schema.
 */
interface SdRequest {
  prompt: string;
  // 0..1 — how far the sampler may deviate from the source image
  // (parsePngInfo clamps user input to this range).
  denoising_strength: number;
  styles: string[];
  negative_prompt: string;
  // Seed / subseed controls.
  seed: number;
  subseed: number;
  subseed_strength: number;
  seed_resize_from_h: number;
  seed_resize_from_w: number;
  // Output size and sampling controls.
  width: number;
  height: number;
  sampler_name: string;
  batch_size: number;
  n_iter: number;
  steps: number;
  cfg_scale: number;
  restore_faces: boolean;
  tiling: boolean;
  do_not_save_samples: boolean;
  do_not_save_grid: boolean;
  // Sampler fine-tuning knobs.
  eta: number;
  s_min_uncond: number;
  s_churn: number;
  s_tmax: number;
  s_tmin: number;
  s_noise: number;
  // Server-side settings overrides and scripting hooks.
  override_settings: object;
  override_settings_restore_afterwards: boolean;
  script_args: unknown[];
  sampler_index: string;
  script_name: string;
  send_images: boolean;
  save_images: boolean;
  alwayson_scripts: object;
}
export async function sdTxt2Img(
api: SdApi,
params: Partial<SdTxt2ImgRequest>,
onProgress?: (progress: SdProgressResponse) => void,
signal: AbortSignal = neverSignal,
): Promise<SdTxt2ImgResponse> {
const request = fetchSdApi<SdTxt2ImgResponse>(api, "sdapi/v1/txt2img", params)
): Promise<SdResponse<SdTxt2ImgRequest>> {
const request = fetchSdApi<SdResponse<SdTxt2ImgRequest>>(api, "sdapi/v1/txt2img", params)
// JSON field "info" is a JSON-serialized string so we need to parse this part second time
.then((data) => ({
...data,
@ -63,57 +100,65 @@ export async function sdTxt2Img(
}
}
export interface SdTxt2ImgRequest {
export interface SdTxt2ImgRequest extends SdRequest {
enable_hr: boolean;
denoising_strength: number;
firstphase_width: number;
firstphase_height: number;
hr_scale: number;
hr_upscaler: unknown;
hr_second_pass_steps: number;
firstphase_width: number;
hr_resize_x: number;
hr_resize_y: number;
hr_sampler_name: unknown;
hr_prompt: string;
hr_negative_prompt: string;
prompt: string;
styles: unknown;
seed: number;
subseed: number;
subseed_strength: number;
seed_resize_from_h: number;
seed_resize_from_w: number;
sampler_name: unknown;
batch_size: number;
n_iter: number;
steps: number;
cfg_scale: number;
width: number;
height: number;
restore_faces: boolean;
tiling: boolean;
do_not_save_samples: boolean;
do_not_save_grid: boolean;
negative_prompt: string;
eta: unknown;
s_min_uncond: number;
s_churn: number;
s_tmax: unknown;
s_tmin: number;
s_noise: number;
override_settings: object;
override_settings_restore_afterwards: boolean;
script_args: unknown[];
sampler_index: string;
script_name: unknown;
send_images: boolean;
save_images: boolean;
alwayson_scripts: object;
hr_prompt: string;
hr_resize_y: number;
hr_sampler_name: string;
hr_scale: number;
hr_second_pass_steps: number;
hr_upscaler: string;
}
export interface SdTxt2ImgResponse {
/**
 * Runs an img2img generation on a Stable Diffusion WebUI instance,
 * reporting progress roughly every 4 seconds until the request settles.
 * If we bail out (abort or error) while the API is still busy, the
 * generation is interrupted server-side.
 */
export async function sdImg2Img(
  api: SdApi,
  params: Partial<SdImg2ImgRequest>,
  onProgress?: (progress: SdProgressResponse) => void,
  signal: AbortSignal = neverSignal,
): Promise<SdResponse<SdImg2ImgRequest>> {
  // The API's "info" field is itself a JSON-serialized string, so it needs
  // a second parse once the outer response has been decoded.
  const generation = fetchSdApi<SdResponse<SdImg2ImgRequest>>(api, "sdapi/v1/img2img", params)
    .then((data) => {
      const info = typeof data.info === "string" ? JSON.parse(data.info) : data.info;
      return { ...data, info };
    });
  try {
    while (true) {
      // Wait until the generation settles or 4 seconds pass, whichever is first.
      await Async.abortable(Promise.race([generation, Async.delay(4000)]), signal);
      if (await AsyncX.promiseState(generation) !== "pending") return await generation;
      onProgress?.(await fetchSdApi<SdProgressResponse>(api, "sdapi/v1/progress"));
    }
  } finally {
    // Left the loop early while the API is still working — ask it to stop.
    if (await AsyncX.promiseState(generation) === "pending") {
      await fetchSdApi(api, "sdapi/v1/interrupt", {});
    }
  }
}
/**
 * Request fields specific to the WebUI img2img endpoint, on top of the
 * shared SdRequest fields. Mirrors the Stable Diffusion WebUI API schema.
 */
export interface SdImg2ImgRequest extends SdRequest {
  image_cfg_scale: number;
  include_init_images: boolean;
  // Source images, base64-encoded.
  init_images: string[];
  initial_noise_multiplier: number;
  // Inpainting controls — NOTE(review): semantics assumed from upstream
  // API field names; confirm against the WebUI API docs.
  inpaint_full_res: boolean;
  inpaint_full_res_padding: number;
  inpainting_fill: number;
  inpainting_mask_invert: number;
  mask: string;
  mask_blur: number;
  mask_blur_x: number;
  mask_blur_y: number;
  resize_mode: number;
}
export interface SdResponse<T> {
images: string[];
parameters: SdTxt2ImgRequest;
parameters: T;
// Warning: raw response from API is a JSON-serialized string
info: SdTxt2ImgInfo;
}
@ -242,10 +287,22 @@ export function getPngInfo(pngData: Uint8Array): string | undefined {
?.text;
}
export function parsePngInfo(pngInfo: string): Partial<SdTxt2ImgRequest> {
/**
 * Generation parameters extracted by parsePngInfo from a "parameters"
 * info string (the text embedded in generated PNGs or typed by the user).
 */
export interface PngInfo {
  prompt: string;
  negative_prompt: string;
  steps: number;
  cfg_scale: number;
  width: number;
  height: number;
  sampler_name: string;
  seed: number;
  // Always clamped to 0..1 by parsePngInfo (accepts "50%" or "0.5" input).
  denoising_strength: number;
}
export function parsePngInfo(pngInfo: string): Partial<PngInfo> {
const tags = pngInfo.split(/[,;]+|\.+\s|\n/u);
let part: "prompt" | "negative_prompt" | "params" = "prompt";
const params: Partial<SdTxt2ImgRequest> = {};
const params: Partial<PngInfo> = {};
const prompt: string[] = [];
const negativePrompt: string[] = [];
for (const tag of tags) {
@ -294,14 +351,26 @@ export function parsePngInfo(pngInfo: string): Partial<SdTxt2ImgRequest> {
}
break;
}
case "denoisingstrength":
case "denoising":
case "denoise": {
part = "params";
// allow percent or decimal
let denoisingStrength: number;
if (value.trim().endsWith("%")) {
denoisingStrength = Number(value.trim().slice(0, -1).trim()) / 100;
} else {
denoisingStrength = Number(value.trim());
}
denoisingStrength = Math.min(Math.max(denoisingStrength, 0), 1);
params.denoising_strength = denoisingStrength;
break;
}
case "seed":
case "model":
case "modelhash":
case "modelname":
case "sampler":
case "denoisingstrength":
case "denoising":
case "denoise":
part = "params";
// ignore for now
break;

View File

@ -12,7 +12,7 @@ import {
import { bot } from "../bot/mod.ts";
import { getGlobalSession, GlobalData, WorkerData } from "../bot/session.ts";
import { fmt, formatUserChat } from "../utils.ts";
import { SdApiError, sdTxt2Img } from "../sd.ts";
import { SdApiError, sdImg2Img, SdProgressResponse, SdResponse, sdTxt2Img } from "../sd.ts";
import { JobSchema, jobStore } from "../db/jobStore.ts";
import { runningWorkers } from "./pingWorkers.ts";
@ -48,13 +48,13 @@ export async function processJobs(): Promise<never> {
processJob(job, worker, config)
.catch(async (err) => {
logger().error(
`Job failed for ${formatUserChat(job.value.request)} via ${worker.name}: ${err}`,
`Job failed for ${formatUserChat(job.value)} via ${worker.name}: ${err}`,
);
if (err instanceof Grammy.GrammyError || err instanceof SdApiError) {
await bot.api.sendMessage(
job.value.request.chat.id,
job.value.chat.id,
`Failed to generate your prompt using ${worker.name}: ${err.message}`,
{ reply_to_message_id: job.value.request.message_id },
{ reply_to_message_id: job.value.requestMessageId },
).catch(() => undefined);
await job.update({ status: { type: "waiting" } }).catch(() => undefined);
}
@ -85,21 +85,21 @@ export async function processJobs(): Promise<never> {
async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config: GlobalData) {
logger().debug(
`Job started for ${formatUserChat(job.value.request)} using ${worker.name}`,
`Job started for ${formatUserChat(job.value)} using ${worker.name}`,
);
const startDate = new Date();
// if there is already a status message delete it
if (job.value.reply) {
await bot.api.deleteMessage(job.value.reply.chat.id, job.value.reply.message_id)
if (job.value.replyMessageId) {
await bot.api.deleteMessage(job.value.chat.id, job.value.replyMessageId)
.catch(() => undefined);
}
// send a new status message
const newStatusMessage = await bot.api.sendMessage(
job.value.request.chat.id,
job.value.chat.id,
`Generating your prompt now... 0% using ${worker.name}`,
{ reply_to_message_id: job.value.request.message_id },
{ reply_to_message_id: job.value.requestMessageId },
).catch((err) => {
// don't error if the request message was deleted
if (err instanceof Grammy.GrammyError && err.message.match(/repl(y|ied)/)) return null;
@ -109,25 +109,25 @@ async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config:
if (!newStatusMessage) {
await job.delete();
logger().info(
`Job cancelled for ${formatUserChat(job.value.request)}`,
`Job cancelled for ${formatUserChat(job.value)}`,
);
return;
}
await job.update({ reply: newStatusMessage });
await job.update({ replyMessageId: newStatusMessage.message_id });
// reduce size if worker can't handle the resolution
const size = limitSize({ ...config.defaultParams, ...job.value.params }, worker.maxResolution);
const size = limitSize(
{ ...config.defaultParams, ...job.value.task.params },
worker.maxResolution,
);
// process the job
const response = await sdTxt2Img(
worker.api,
{ ...config.defaultParams, ...job.value.params, ...size },
async (progress) => {
const handleProgress = async (progress: SdProgressResponse) => {
// important: don't let any errors escape this callback
if (job.value.reply) {
if (job.value.replyMessageId) {
await bot.api.editMessageText(
job.value.reply.chat.id,
job.value.reply.message_id,
job.value.chat.id,
job.value.replyMessageId,
`Generating your prompt now... ${
(progress.progress * 100).toFixed(0)
}% using ${worker.name}`,
@ -142,14 +142,38 @@ async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config:
updatedDate: new Date(),
},
}, { maxAttempts: 1 }).catch(() => undefined);
},
};
let response: SdResponse<unknown>;
const taskType = job.value.task.type; // don't narrow this to never pls typescript
switch (job.value.task.type) {
case "txt2img":
response = await sdTxt2Img(
worker.api,
{ ...config.defaultParams, ...job.value.task.params, ...size },
handleProgress,
);
break;
case "img2img": {
const file = await bot.api.getFile(job.value.task.fileId);
const fileUrl = `https://api.telegram.org/file/bot${bot.token}/${file.file_path}`;
const fileBuffer = await fetch(fileUrl).then((resp) => resp.arrayBuffer());
const fileBase64 = Base64.encode(fileBuffer);
response = await sdImg2Img(
worker.api,
{ ...config.defaultParams, ...job.value.task.params, ...size, init_images: [fileBase64] },
handleProgress,
);
break;
}
default:
throw new Error(`Unknown task type: ${taskType}`);
}
// upload the result
if (job.value.reply) {
if (job.value.replyMessageId) {
await bot.api.editMessageText(
job.value.reply.chat.id,
job.value.reply.message_id,
job.value.chat.id,
job.value.replyMessageId,
`Uploading your images...`,
).catch(() => undefined);
}
@ -200,31 +224,31 @@ async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config:
// send the result to telegram
try {
resultMessages = await bot.api.sendMediaGroup(job.value.request.chat.id, inputFiles, {
reply_to_message_id: job.value.request.message_id,
resultMessages = await bot.api.sendMediaGroup(job.value.chat.id, inputFiles, {
reply_to_message_id: job.value.requestMessageId,
maxAttempts: 5,
});
break;
} catch (err) {
logger().warning(`Sending images (attempt ${sendMediaAttempt}) failed: ${err}`);
if (sendMediaAttempt >= 5) throw err;
await Async.delay(15000);
if (sendMediaAttempt >= 6) throw err;
await Async.delay(10000);
}
}
// send caption in separate message if it couldn't fit
if (caption.text.length > 1024 && caption.text.length <= 4096) {
await bot.api.sendMessage(job.value.request.chat.id, caption.text, {
await bot.api.sendMessage(job.value.chat.id, caption.text, {
reply_to_message_id: resultMessages[0].message_id,
entities: caption.entities,
});
}
// delete the status message
if (job.value.reply) {
await bot.api.deleteMessage(job.value.reply.chat.id, job.value.reply.message_id)
if (job.value.replyMessageId) {
await bot.api.deleteMessage(job.value.chat.id, job.value.replyMessageId)
.catch(() => undefined)
.then(() => job.update({ reply: undefined }))
.then(() => job.update({ replyMessageId: undefined }))
.catch(() => undefined);
}
@ -233,7 +257,7 @@ async function processJob(job: IKV.Model<JobSchema>, worker: WorkerData, config:
status: { type: "done", info: response.info, startDate, endDate: new Date() },
});
logger().debug(
`Job finished for ${formatUserChat(job.value.request)} using ${worker.name}${
`Job finished for ${formatUserChat(job.value)} using ${worker.name}${
sendMediaAttempt > 1 ? ` after ${sendMediaAttempt} attempts` : ""
}`,
);

View File

@ -19,9 +19,7 @@ export async function returnHangedJobs(): Promise<never> {
if (timeSinceLastUpdateMs > 2 * 60 * 1000) {
await job.update({ status: { type: "waiting" } });
logger().warning(
`Job for ${
formatUserChat(job.value.request)
} was returned to the queue because it hanged for ${
`Job for ${formatUserChat(job.value)} was returned to the queue because it hanged for ${
FmtDuration.format(Math.trunc(timeSinceLastUpdateMs / 1000) * 1000, {
ignoreZero: true,
})

View File

@ -14,10 +14,10 @@ export async function updateJobStatusMsgs(): Promise<never> {
await new Promise((resolve) => setTimeout(resolve, 5000));
const jobs = await jobStore.getBy("status.type", "waiting");
for (const [index, job] of jobs.entries()) {
if (!job.value.reply) continue;
if (!job.value.replyMessageId) continue;
await bot.api.editMessageText(
job.value.reply.chat.id,
job.value.reply.message_id,
job.value.chat.id,
job.value.replyMessageId,
`You are ${formatOrdinal(index + 1)} in queue.`,
{ maxAttempts: 1 },
).catch(() => undefined);