diff --git a/README.md b/README.md index eb5e343..a53e3f9 100644 --- a/README.md +++ b/README.md @@ -29,8 +29,8 @@ You can put these in `.env` file or pass them as environment variables. - [x] Changing params, parsing png info in request - [x] Cancelling jobs by deleting message - [x] Multiple parallel workers -- [ ] Replying to another text message to copy prompt and generate -- [ ] Replying to bot message, conversation in DMs +- [x] Replying to another text message to copy prompt and generate +- [x] Replying to bot message, conversation in DMs - [ ] Replying to png message to extract png info nad generate - [ ] Banning tags - [ ] Img2Img + Upscale diff --git a/bot/mod.ts b/bot/mod.ts index 9ef9106..71e8184 100644 --- a/bot/mod.ts +++ b/bot/mod.ts @@ -2,7 +2,7 @@ import { Grammy, GrammyAutoQuote, GrammyParseMode, Log } from "../deps.ts"; import { formatUserChat } from "../utils.ts"; import { session, SessionFlavor } from "./session.ts"; import { queueCommand } from "./queueCommand.ts"; -import { txt2imgCommand } from "./txt2imgCommand.ts"; +import { txt2imgCommand, txt2imgQuestion } from "./txt2imgCommand.ts"; export const logger = () => Log.getLogger(); @@ -62,6 +62,7 @@ bot.api.setMyCommands([ bot.command("start", (ctx) => ctx.reply("Hello! Use the /txt2img command to generate an image")); bot.command("txt2img", txt2imgCommand); +bot.use(txt2imgQuestion.middleware() as any); bot.command("queue", queueCommand); diff --git a/bot/session.ts b/bot/session.ts index 2f17df1..62c50be 100644 --- a/bot/session.ts +++ b/bot/session.ts @@ -1,3 +1,4 @@ +import { db } from "../db/db.ts"; import { Grammy, GrammyKvStorage } from "../deps.ts"; import { SdApi, SdTxt2ImgRequest } from "../sd.ts"; @@ -33,9 +34,7 @@ export interface UserData { params?: Partial; } -const globalDb = await Deno.openKv("./app.db"); - -const globalDbAdapter = new GrammyKvStorage.DenoKVAdapter(globalDb); +const globalDbAdapter = new GrammyKvStorage.DenoKVAdapter(db); const getDefaultGlobalData = (): GlobalData => ({ adminUsernames: Deno.env.get("TG_ADMIN_USERS")?.split(",") ?? [], diff --git a/bot/txt2imgCommand.ts b/bot/txt2imgCommand.ts index 40f205b..e0687fe 100644 --- a/bot/txt2imgCommand.ts +++ b/bot/txt2imgCommand.ts @@ -1,39 +1,72 @@ -import { Grammy } from "../deps.ts"; +import { Grammy, GrammyStatelessQ } from "../deps.ts"; import { formatUserChat } from "../utils.ts"; import { jobStore } from "../db/jobStore.ts"; import { parsePngInfo } from "../sd.ts"; import { Context, logger } from "./mod.ts"; +export const txt2imgQuestion = new GrammyStatelessQ.StatelessQuestion( + "txt2img", + async (ctx) => { + if (!ctx.message.text) return; + await txt2img(ctx as any, ctx.message.text, false); + }, +); + export async function txt2imgCommand(ctx: Grammy.CommandContext) { - if (!ctx.from?.id) { - return ctx.reply("I don't know who you are"); + await txt2img(ctx, ctx.match, true); +} + +async function txt2img(ctx: Context, match: string, includeRepliedTo: boolean): Promise { + if (!ctx.message?.from?.id) { + return void ctx.reply("I don't know who you are"); } + const config = ctx.session.global; if (config.pausedReason != null) { - return ctx.reply(`I'm paused: ${config.pausedReason || "No reason given"}`); + return void ctx.reply(`I'm paused: ${config.pausedReason || "No reason given"}`); } + const jobs = await jobStore.getBy("status.type", "waiting"); if (jobs.length >= config.maxJobs) { - return ctx.reply( + return void ctx.reply( `The queue is full. Try again later. (Max queue size: ${config.maxJobs})`, ); } - const userJobs = jobs.filter((job) => job.value.request.from.id === ctx.from?.id); + + const userJobs = jobs.filter((job) => job.value.request.from.id === ctx.message?.from?.id); if (userJobs.length >= config.maxUserJobs) { - return ctx.reply( + return void ctx.reply( `You already have ${config.maxUserJobs} jobs in queue. Try again later.`, ); } - const params = parsePngInfo(ctx.match); - if (!params.prompt) { - return ctx.reply("Please describe what you want to see after the command"); + + let params = parsePngInfo(match); + const repliedToMsg = ctx.message.reply_to_message; + const repliedToText = repliedToMsg?.text || repliedToMsg?.caption; + if (includeRepliedTo && repliedToText) { + const originalParams = parsePngInfo(repliedToText); + params = { + ...originalParams, + ...params, + prompt: [originalParams.prompt, params.prompt].filter(Boolean).join("\n"), + }; } - const reply = await ctx.reply("Accepted. You are now in queue."); + if (!params.prompt) { + return void ctx.reply( + "Please tell me what you want to see." + + txt2imgQuestion.messageSuffixMarkdown(), + { reply_markup: { force_reply: true, selective: true }, parse_mode: "Markdown" }, + ); + } + + const replyMessage = await ctx.reply("Accepted. You are now in queue."); + await jobStore.create({ params, request: ctx.message, - reply, + reply: replyMessage, status: { type: "waiting" }, }); - logger().debug(`Job enqueued for ${formatUserChat(ctx)}`); + + logger().debug(`Job enqueued for ${formatUserChat(ctx.message)}`); } diff --git a/db/jobStore.ts b/db/jobStore.ts index 6b74717..f4274a2 100644 --- a/db/jobStore.ts +++ b/db/jobStore.ts @@ -4,7 +4,7 @@ import { db } from "./db.ts"; export interface JobSchema { params: Partial; - request: GrammyTypes.Message.TextMessage & { from: GrammyTypes.User }; + request: GrammyTypes.Message & { from: GrammyTypes.User }; reply?: GrammyTypes.Message.TextMessage; status: | { type: "waiting" } diff --git a/deps.ts b/deps.ts index 2a15398..5d89d00 100644 --- a/deps.ts +++ b/deps.ts @@ -11,6 +11,7 @@ export * as GrammyTypes from "https://deno.land/x/grammy_types@v3.2.0/mod.ts"; export * as GrammyAutoQuote from "https://deno.land/x/grammy_autoquote@v1.1.2/mod.ts"; export * as GrammyParseMode from "https://deno.land/x/grammy_parse_mode@1.7.1/mod.ts"; export * as GrammyKvStorage from "https://deno.land/x/grammy_storages@v2.3.1/denokv/src/mod.ts"; +export * as GrammyStatelessQ from "npm:@grammyjs/stateless-question"; export * as FileType from "npm:file-type@18.5.0"; // @deno-types="./types/png-chunks-extract.d.ts" export * as PngChunksExtract from "npm:png-chunks-extract@1.0.0";