eris/app/generationQueue.ts

351 lines
11 KiB
TypeScript
Raw Normal View History

2023-09-24 13:08:35 +00:00
import { promiseState } from "async";
import { Chat, Message, User } from "grammy_types";
import { JobData, Queue, Worker } from "kvmq";
import createOpenApiClient from "openapi_fetch";
2023-09-26 10:43:36 +00:00
import { delay } from "std/async/delay.ts";
import { decode, encode } from "std/encoding/base64.ts";
2023-10-15 19:13:38 +00:00
import { debug, error, info } from "std/log/mod.ts";
2023-09-24 13:08:35 +00:00
import { ulid } from "ulid";
2023-09-22 02:59:22 +00:00
import { bot } from "../bot/mod.ts";
2023-09-26 10:43:36 +00:00
import { PngInfo } from "../bot/parsePngInfo.ts";
2023-09-24 13:08:35 +00:00
import { formatOrdinal } from "../utils/formatOrdinal.ts";
2023-09-23 18:49:05 +00:00
import { formatUserChat } from "../utils/formatUserChat.ts";
2023-10-13 11:47:57 +00:00
import { getAuthHeader } from "../utils/getAuthHeader.ts";
2023-10-19 21:37:03 +00:00
import { omitUndef } from "../utils/omitUndef.ts";
2023-09-26 10:43:36 +00:00
import { SdError } from "./SdError.ts";
2023-10-05 09:00:51 +00:00
import { getConfig } from "./config.ts";
import { db, fs } from "./db.ts";
import { SdGenerationInfo } from "./generationStore.ts";
2023-09-26 10:43:36 +00:00
import * as SdApi from "./sdApi.ts";
import { uploadQueue } from "./uploadQueue.ts";
2023-10-13 11:47:57 +00:00
import { workerInstanceStore } from "./workerInstanceStore.ts";
2023-09-22 02:59:22 +00:00
/**
 * State of a single image generation job stored in the generation queue.
 */
interface GenerationJob {
  /**
   * What to generate. `img2img` additionally carries the Telegram file id
   * of the input image, which is resolved through `bot.api.getFile`.
   */
  task:
    | { type: "txt2img"; params: Partial<PngInfo> }
    | { type: "img2img"; params: Partial<PngInfo>; fileId: string };
  /** User the job belongs to. */
  from: User;
  /** Chat the job belongs to; chat actions are sent here while generating. */
  chat: Chat;
  /** Message that requested the generation; failure notices reply to it. */
  requestMessage: Message;
  /** Status message that gets edited as the job progresses. */
  replyMessage: Message;
  /** Key of the worker instance processing the job, set when processing starts. */
  workerInstanceKey?: string;
  /** Last progress value reported by the worker's progress endpoint (starts at 0). */
  progress?: number;
}
2023-09-24 13:08:35 +00:00
/**
 * Queue holding all generation jobs, backed by `db`.
 * NOTE(review): "jobQueue" is presumably the queue's storage key — confirm against kvmq.
 */
export const generationQueue = new Queue<GenerationJob>(db, "jobQueue");
2023-09-22 02:59:22 +00:00
2023-09-24 13:08:35 +00:00
/**
 * Queue workers started by `processGenerationQueue`, keyed by worker instance id.
 * An entry is replaced when a new worker is started for the same instance.
 */
export const activeGenerationWorkers = new Map<string, Worker<GenerationJob>>();
2023-09-22 02:59:22 +00:00
/**
 * Initializes queue workers for each SD instance when they become online.
 *
 * Runs forever. Every 60 seconds it walks all stored worker instances, skips
 * those whose queue worker is already processing, probes the remaining ones
 * through `/sdapi/v1/memory` and starts a new queue worker for every instance
 * that answers. A worker stops itself on the first job error and is picked up
 * again by a later pass of this loop.
 */
export async function processGenerationQueue() {
  while (true) {
    for await (const workerInstance of workerInstanceStore.listAll()) {
      const activeWorker = activeGenerationWorkers.get(workerInstance.id);
      if (activeWorker?.isProcessing) {
        continue;
      }

      const workerSdClient = createOpenApiClient<SdApi.paths>({
        baseUrl: workerInstance.value.sdUrl,
        headers: getAuthHeader(workerInstance.value.sdAuth),
      });

      // check if worker is up
      const activeWorkerStatus = await workerSdClient.GET("/sdapi/v1/memory", {
        signal: AbortSignal.timeout(10_000),
      })
        .then((response) => {
          if (!response.data) {
            throw new SdError("Failed to get worker status", response.response, response.error);
          }
          return response;
        })
        .catch((err) => {
          workerInstance.update({ lastError: { message: err.message, time: Date.now() } })
            .catch(() => undefined);
          debug(`Worker ${workerInstance.value.key} is down: ${err}`);
        });

      if (!activeWorkerStatus?.data) {
        continue;
      }

      // The GPU model is purely informational, so a failing sysinfo request
      // (timeout, missing endpoint, unexpected payload) must not take down
      // this loop or prevent the worker from starting.
      await workerSdClient.GET("/internal/sysinfo", {
        params: {
          query: {
            attachment: false,
          },
        },
        signal: AbortSignal.timeout(10_000),
      })
        .then(async (response) => {
          if (!response.data) {
            throw new SdError("Failed to get worker sysinfo", response.response, response.error);
          }
          // @ts-ignore there is no schema for /internal/sysinfo endpoint response
          const nvidiaGPUModels = response.data["Torch env info"]?.nvidia_gpu_models ?? null;
          if (nvidiaGPUModels !== null) {
            await workerInstance.update({ gpu: nvidiaGPUModels });
          }
        })
        .catch((err) => {
          debug(`Failed to get sysinfo of worker ${workerInstance.value.key}: ${err}`);
        });

      // create worker
      const newWorker = generationQueue.createWorker(async ({ state }, updateJob) => {
        await processGenerationJob(state, updateJob, workerInstance.id);
      });

      newWorker.addEventListener("error", (e) => {
        error(
          `Generation failed for ${formatUserChat(e.detail.job.state)}: ${e.detail.error}`,
        );
        // tell the user what happened; delivery is best-effort
        bot.api.sendMessage(
          e.detail.job.state.requestMessage.chat.id,
          `Generation failed: ${e.detail.error}\n\n` +
            (e.detail.job.retryCount > 0
              ? `Will retry ${e.detail.job.retryCount} more times.`
              : `Aborting.`),
          {
            reply_to_message_id: e.detail.job.state.requestMessage.message_id,
            allow_sending_without_reply: true,
          },
        ).catch(() => undefined);
        // stop this worker; a later pass of the outer loop re-probes the instance
        newWorker.stopProcessing();
        workerInstance.update({ lastError: { message: e.detail.error.message, time: Date.now() } })
          .catch(() => undefined);
        info(`Stopped worker ${workerInstance.value.key}`);
      });

      newWorker.addEventListener("complete", () => {
        workerInstance.update({ lastOnlineTime: Date.now() }).catch(() => undefined);
      });

      await workerInstance.update({ lastOnlineTime: Date.now() });
      // intentionally not awaited: the worker keeps processing in the background
      newWorker.processJobs();
      activeGenerationWorkers.set(workerInstance.id, newWorker);
      info(`Started worker ${workerInstance.value.key}`);
    }
    await delay(60_000);
  }
}
/**
 * Processes a single job from the queue.
 *
 * Sends the generation request to the given worker instance, polls the
 * worker's progress endpoint while the request is pending (updating the job
 * state and the status message along the way), stores the resulting images
 * and finally pushes an upload job.
 *
 * @param state Job state. Mutated in place: `workerInstanceKey` and `progress` are set here.
 * @param updateJob Persists partial job data back to the queue.
 * @param workerInstanceId Id of the worker instance that should run the job.
 * @throws {Error} If the worker instance is unknown, the bot lacks permissions
 *   in the chat, the input image cannot be downloaded, the task type is
 *   unknown or SD returns no images.
 * @throws {SdError} If the progress request or the generation request fails.
 */
async function processGenerationJob(
  state: GenerationJob,
  updateJob: (job: Partial<JobData<GenerationJob>>) => Promise<void>,
  workerInstanceId: string,
) {
  const startDate = new Date();
  const config = await getConfig();
  const workerInstance = await workerInstanceStore.getById(workerInstanceId);
  if (!workerInstance) {
    throw new Error(`Unknown workerInstanceId: ${workerInstanceId}`);
  }

  const workerSdClient = createOpenApiClient<SdApi.paths>({
    baseUrl: workerInstance.value.sdUrl,
    headers: getAuthHeader(workerInstance.value.sdAuth),
  });

  state.workerInstanceKey = workerInstance.value.key;
  state.progress = 0;
  debug(`Generation started for ${formatUserChat(state)}`);
  await updateJob({ state: state });

  // check if bot can post messages in this chat
  const chat = await bot.api.getChat(state.chat.id);
  if (
    (chat.type === "group" || chat.type === "supergroup") &&
    (!chat.permissions?.can_send_messages || !chat.permissions?.can_send_photos)
  ) {
    throw new Error("Bot doesn't have permissions to send photos in this chat");
  }

  // edit the existing status message
  await bot.api.editMessageText(
    state.replyMessage.chat.id,
    state.replyMessage.message_id,
    `Generating your prompt now... 0% using ${
      workerInstance.value.name || workerInstance.value.key
    }`,
    { maxAttempts: 1 },
  ).catch(() => undefined);

  // reduce size if worker can't handle the resolution
  const size = limitSize(
    omitUndef({ ...config.defaultParams, ...state.task.params }),
    1024 * 1024,
  );

  // start generating the image
  const responsePromise = state.task.type === "txt2img"
    ? workerSdClient.POST("/sdapi/v1/txt2img", {
      body: omitUndef({
        ...config.defaultParams,
        ...state.task.params,
        ...size,
        negative_prompt: state.task.params.negative_prompt
          ? state.task.params.negative_prompt
          : config.defaultParams?.negative_prompt,
      }),
    })
    : state.task.type === "img2img"
    ? workerSdClient.POST("/sdapi/v1/img2img", {
      body: omitUndef({
        ...config.defaultParams,
        ...state.task.params,
        ...size,
        negative_prompt: state.task.params.negative_prompt
          ? state.task.params.negative_prompt
          : config.defaultParams?.negative_prompt,
        // the input image is downloaded before the request is sent
        init_images: [encode(await downloadTelegramFile(state.task.fileId))],
      }),
    })
    : undefined;

  if (!responsePromise) {
    throw new Error(`Unknown task type: ${state.task.type}`);
  }

  // we await the promise only after it finishes
  // so we need to add catch callback to not crash the process before that
  responsePromise.catch(() => undefined);

  // poll for progress while the generation request is pending
  do {
    const progressResponse = await workerSdClient.GET("/sdapi/v1/progress", {
      params: {},
      signal: AbortSignal.timeout(15000),
    });
    if (!progressResponse.data) {
      throw new SdError(
        "Progress request failed",
        progressResponse.response,
        progressResponse.error,
      );
    }

    // only touch the job and the status message when progress actually advanced
    if (progressResponse.data.progress > state.progress) {
      state.progress = progressResponse.data.progress;
      await updateJob({ state: state });
      await bot.api.sendChatAction(state.chat.id, "upload_photo", { maxAttempts: 1 })
        .catch(() => undefined);
      await bot.api.editMessageText(
        state.replyMessage.chat.id,
        state.replyMessage.message_id,
        `Generating your prompt now... ${
          (progressResponse.data.progress * 100).toFixed(0)
        }% using ${workerInstance.value.name || workerInstance.value.key}`,
        { maxAttempts: 1 },
      ).catch(() => undefined);
    }

    // wait 2 seconds between polls, but wake up early when generation settles
    await Promise.race([delay(2_000), responsePromise]).catch(() => undefined);
  } while (await promiseState(responsePromise) === "pending");

  // check response
  const response = await responsePromise;
  if (!response.data) {
    throw new SdError(`${state.task.type} failed`, response.response, response.error);
  }
  if (!response.data.images?.length) {
    throw new Error("No images returned from SD");
  }

  // info field is a json serialized string so we need to parse it
  // (named generationInfo so the imported `info` logger is not shadowed)
  const generationInfo: SdGenerationInfo = JSON.parse(response.data.info);

  // save images to db; entries expire after 30 minutes
  const imageKeys: Deno.KvKey[] = [];
  for (const imageBase64 of response.data.images) {
    const imageBuffer = decode(imageBase64);
    const imageKey = ["images", "upload", ulid()];
    await fs.set(imageKey, imageBuffer, { expireIn: 30 * 60 * 1000 });
    imageKeys.push(imageKey);
  }

  // create a new upload job
  await uploadQueue.pushJob({
    chat: state.chat,
    from: state.from,
    requestMessage: state.requestMessage,
    replyMessage: state.replyMessage,
    workerInstanceKey: workerInstance.value.key,
    startDate,
    endDate: new Date(),
    imageKeys,
    info: generationInfo,
  }, { retryCount: 5, retryDelayMs: 10000 });

  // change status message to uploading images
  await bot.api.editMessageText(
    state.replyMessage.chat.id,
    state.replyMessage.message_id,
    `Uploading your images...`,
    { maxAttempts: 1 },
  ).catch(() => undefined);

  debug(`Generation finished for ${formatUserChat(state)}`);
}

/**
 * Scales a size down to at most `maxResolution` pixels while keeping its aspect ratio.
 *
 * Returns an empty object when width or height is missing (or zero), so that
 * spreading the result does not override anything.
 */
function limitSize(
  { width, height }: { width?: number; height?: number },
  maxResolution: number,
): { width?: number; height?: number } {
  if (!width || !height) return {};
  const ratio = width / height;
  if (width * height > maxResolution) {
    return {
      width: Math.trunc(Math.sqrt(maxResolution * ratio)),
      height: Math.trunc(Math.sqrt(maxResolution / ratio)),
    };
  }
  return { width, height };
}

/**
 * Downloads a file from Telegram by its file id.
 *
 * Fails loudly instead of handing an error page to SD as the input image.
 * The download URL contains the bot token, so it is never put into error messages.
 */
async function downloadTelegramFile(fileId: string): Promise<ArrayBuffer> {
  const file = await bot.api.getFile(fileId);
  if (!file.file_path) {
    throw new Error("Failed to download input image: Telegram returned no file path");
  }
  const resp = await fetch(`https://api.telegram.org/file/bot${bot.token}/${file.file_path}`);
  if (!resp.ok) {
    // release the unread body before bailing out
    await resp.body?.cancel().catch(() => undefined);
    throw new Error(`Failed to download input image: ${resp.status} ${resp.statusText}`);
  }
  return await resp.arrayBuffer();
}
/**
 * Handles queue updates and updates the status message.
 *
 * Runs forever: fetches all jobs, edits the status message of every job that
 * is not currently being processed to show its place in the queue, then
 * waits 10 seconds before the next round.
 */
export async function updateGenerationQueue() {
  for (;;) {
    const queuedJobs = await generationQueue.getAllJobs();
    await Promise.all(queuedJobs.map(async (queuedJob) => {
      // jobs that are being processed get their status message from the worker
      if (queuedJob.status !== "processing") {
        // spread the updates in time randomly
        const jitterMs = Math.random() * 3_000;
        await delay(jitterMs);

        const { replyMessage } = queuedJob.state;
        await bot.api.editMessageText(
          replyMessage.chat.id,
          replyMessage.message_id,
          `You are ${formatOrdinal(queuedJob.place)} in queue.`,
          { maxAttempts: 1 },
        ).catch(() => undefined);
      }
    }));
    await delay(10_000);
  }
}