From 0517ce193081dcaed13bfc7917c54551fdaac884 Mon Sep 17 00:00:00 2001 From: pinks Date: Sat, 16 Sep 2023 13:49:12 +0200 Subject: [PATCH] feat: add scale parameter to img2img --- bot/img2imgCommand.ts | 16 ++-------------- bot/txt2imgCommand.ts | 25 +++---------------------- common/parsePngInfo.ts | 31 ++++++++++++++++++++++++++++--- tasks/processJobs.ts | 9 ++++++++- 4 files changed, 41 insertions(+), 40 deletions(-) diff --git a/bot/img2imgCommand.ts b/bot/img2imgCommand.ts index 0376a65..d3ef9c6 100644 --- a/bot/img2imgCommand.ts +++ b/bot/img2imgCommand.ts @@ -73,22 +73,10 @@ async function img2img( const repliedToText = repliedToMsg?.text || repliedToMsg?.caption; if (includeRepliedTo && repliedToText) { // TODO: remove bot command from replied to text - const originalParams = parsePngInfo(repliedToText); - params = { - ...originalParams, - ...params, - prompt: [originalParams.prompt, params.prompt].filter(Boolean).join("\n"), - negative_prompt: [originalParams.negative_prompt, params.negative_prompt] - .filter(Boolean).join("\n"), - }; + params = parsePngInfo(repliedToText, params); } - const messageParams = parsePngInfo(match ?? ""); - params = { - ...params, - ...messageParams, - prompt: [params.prompt, messageParams.prompt].filter(Boolean).join("\n"), - }; + params = parsePngInfo(match ?? "", params); if (!fileId) { await ctx.reply( diff --git a/bot/txt2imgCommand.ts b/bot/txt2imgCommand.ts index 0f088a5..12db715 100644 --- a/bot/txt2imgCommand.ts +++ b/bot/txt2imgCommand.ts @@ -50,35 +50,16 @@ async function txt2img(ctx: Context, match: string, includeRepliedTo: boolean): if (includeRepliedTo && repliedToMsg?.document?.mime_type === "image/png") { const file = await ctx.api.getFile(repliedToMsg.document.file_id); const buffer = await fetch(file.getUrl()).then((resp) => resp.arrayBuffer()); - const fileParams = parsePngInfo(getPngInfo(new Uint8Array(buffer)) ?? ""); - params = { - ...params, - ...fileParams, - prompt: [params.prompt, fileParams.prompt].filter(Boolean).join("\n"), - negative_prompt: [params.negative_prompt, fileParams.negative_prompt] - .filter(Boolean).join("\n"), - }; + params = parsePngInfo(getPngInfo(new Uint8Array(buffer)) ?? "", params); } const repliedToText = repliedToMsg?.text || repliedToMsg?.caption; if (includeRepliedTo && repliedToText) { // TODO: remove bot command from replied to text - const originalParams = parsePngInfo(repliedToText); - params = { - ...originalParams, - ...params, - prompt: [originalParams.prompt, params.prompt].filter(Boolean).join("\n"), - negative_prompt: [originalParams.negative_prompt, params.negative_prompt] - .filter(Boolean).join("\n"), - }; + params = parsePngInfo(repliedToText, params); } - const messageParams = parsePngInfo(match); - params = { - ...params, - ...messageParams, - prompt: [params.prompt, messageParams.prompt].filter(Boolean).join("\n"), - }; + params = parsePngInfo(match, params); if (!params.prompt) { await ctx.reply( diff --git a/common/parsePngInfo.ts b/common/parsePngInfo.ts index b324b58..49d9bda 100644 --- a/common/parsePngInfo.ts +++ b/common/parsePngInfo.ts @@ -20,10 +20,14 @@ export interface PngInfo { denoising_strength: number; } -export function parsePngInfo(pngInfo: string): Partial { +interface PngInfoExtra extends PngInfo { + upscale?: number; +} + +export function parsePngInfo(pngInfo: string, baseParams?: Partial): Partial { const tags = pngInfo.split(/[,;]+|\.+\s|\n/u); let part: "prompt" | "negative_prompt" | "params" = "prompt"; - const params: Partial = {}; + const params: Partial = {}; const prompt: string[] = []; const negativePrompt: string[] = []; for (const tag of tags) { @@ -72,6 +76,13 @@ export function parsePngInfo(pngInfo: string): Partial { } break; } + case "upscale": + case "scale": { + part = "params"; + const upscale = Number(value.trim()); + if (upscale > 0) params.upscale = Math.min(upscale, 2); + break; + } case "denoisingstrength": case "denoising": case "denoise": { @@ -113,5 +124,19 @@ export function parsePngInfo(pngInfo: string): Partial { } if (prompt.length > 0) params.prompt = prompt.join(", "); if (negativePrompt.length > 0) params.negative_prompt = negativePrompt.join(", "); - return params; + + // handle upscale + if (params.upscale && baseParams?.width && baseParams?.height) { + params.width = baseParams.width * params.upscale; + params.height = baseParams.height * params.upscale; + } + + return { + ...baseParams, + ...params, + prompt: [baseParams?.prompt, params.prompt] + .filter(Boolean).join("\n"), + negative_prompt: [baseParams?.negative_prompt, params.negative_prompt] + .filter(Boolean).join("\n"), + }; } diff --git a/tasks/processJobs.ts b/tasks/processJobs.ts index e158a18..78ed6b0 100644 --- a/tasks/processJobs.ts +++ b/tasks/processJobs.ts @@ -213,7 +213,14 @@ async function processJob(job: IKV.Model, worker: WorkerData, config: case "txt2img": response = await sdTxt2Img( worker.api, - { ...config.defaultParams, ...job.value.task.params, ...size }, + { + ...config.defaultParams, + ...job.value.task.params, + ...size, + negative_prompt: job.value.task.params.negative_prompt + ? job.value.task.params.negative_prompt + : config.defaultParams?.negative_prompt, + }, handleProgress, ); break;