eris/common/sdApi.ts

301 lines
7.9 KiB
TypeScript
Raw Normal View History

2023-09-13 09:50:22 +00:00
import { Async, AsyncX } from "../deps.ts";
2023-09-10 18:56:17 +00:00
/**
 * Connection settings for a Stable Diffusion API server (the `sdapi/v1/*` endpoints).
 */
export interface SdApi {
  /** Base URL; endpoint paths are resolved relative to it with `new URL(endpoint, url)`. */
  url: string;
  /** Sent verbatim as the `Authorization` header when present. */
  auth?: string;
}
2023-09-13 00:38:09 +00:00
/**
 * Calls an endpoint of the SD API and returns its parsed JSON body.
 *
 * Passing a (non-null) `body` turns the call into a JSON `POST`; otherwise a `GET` is sent.
 * `api.auth`, when set, is sent as the `Authorization` header.
 *
 * @param api Base URL (endpoints are resolved against it) and optional auth header value.
 * @param endpoint Path relative to `api.url`, e.g. `"sdapi/v1/progress"`.
 * @param options `body`: request payload, serialized with `JSON.stringify`.
 *   `timeoutMs`: abort after this many milliseconds (`0` or omitted means no timeout);
 *   the timer covers both the request and reading the response body.
 * @returns The decoded JSON response, typed as `T` without runtime validation.
 * @throws {SdApiError} status `-1` on timeout, `0` on network failure, or the HTTP status
 *   when the body is not valid JSON or the response is not 2xx.
 */
async function fetchSdApi<T>(
  api: SdApi,
  endpoint: string,
  { body, timeoutMs }: { body?: unknown; timeoutMs?: number } = {},
): Promise<T> {
  const controller = new AbortController();
  const timeoutId = timeoutMs ? setTimeout(() => controller.abort(), timeoutMs) : undefined;

  const headers: Record<string, string> = {
    ...(body != null ? { "Content-Type": "application/json" } : {}),
    ...(api.auth ? { Authorization: api.auth } : {}),
  };
  // Always attach the signal so the timeout also applies to unauthenticated GET requests.
  const options: RequestInit = body != null
    ? { method: "POST", headers, body: JSON.stringify(body), signal: controller.signal }
    : { headers, signal: controller.signal };

  try {
    const response = await fetch(new URL(endpoint, api.url), options).catch(() => {
      if (controller.signal.aborted) {
        throw new SdApiError(endpoint, options, -1, "Timed out");
      }
      throw new SdApiError(endpoint, options, 0, "Network error");
    });
    const result = await response.json().catch(() => {
      // The timeout also covers reading the body, so an abort here is a timeout, not bad JSON.
      if (controller.signal.aborted) {
        throw new SdApiError(endpoint, options, -1, "Timed out");
      }
      throw new SdApiError(endpoint, options, response.status, response.statusText, {
        detail: "Invalid JSON",
      });
    });
    if (!response.ok) {
      throw new SdApiError(endpoint, options, response.status, response.statusText, result);
    }
    return result;
  } finally {
    // Clear on every path, not only on success, so a failed request leaves no timer behind.
    clearTimeout(timeoutId);
  }
}
2023-09-12 01:57:44 +00:00
/**
 * Fields shared by the txt2img and img2img request bodies.
 *
 * Callers pass a `Partial` of the derived request types, so any subset of these
 * fields may be sent.
 */
interface SdRequest {
  // Prompting
  prompt: string;
  negative_prompt: string;
  styles: string[];

  // Seeding
  seed: number;
  subseed: number;
  subseed_strength: number;
  seed_resize_from_h: number;
  seed_resize_from_w: number;

  // Output size and batching
  width: number;
  height: number;
  batch_size: number;
  n_iter: number;

  // Sampling
  sampler_name: string;
  sampler_index: string;
  steps: number;
  cfg_scale: number;
  denoising_strength: number;
  eta: number;
  s_min_uncond: number;
  s_churn: number;
  s_tmax: number;
  s_tmin: number;
  s_noise: number;

  // Image options
  restore_faces: boolean;
  tiling: boolean;

  // Saving and returning images
  do_not_save_samples: boolean;
  do_not_save_grid: boolean;
  send_images: boolean;
  save_images: boolean;

  // Settings overrides and scripts
  override_settings: object;
  override_settings_restore_afterwards: boolean;
  script_name: string;
  script_args: unknown[];
  alwayson_scripts: object;
}
2023-09-10 18:56:17 +00:00
/**
 * Runs a txt2img generation, polling progress roughly every 3 seconds until it settles.
 *
 * If the polling loop is left while the generation is still pending (a progress poll
 * or `onProgress` threw), an interrupt is sent before the error propagates.
 *
 * @param api API connection settings.
 * @param params Request fields sent as the JSON body; any subset may be given.
 * @param onProgress Receives each `sdapi/v1/progress` snapshot. When omitted, progress is not polled.
 * @returns The response with `info` decoded from its JSON-string form.
 */
export async function sdTxt2Img(
  api: SdApi,
  params: Partial<SdTxt2ImgRequest>,
  onProgress?: (progress: SdProgressResponse) => void,
): Promise<SdResponse<SdTxt2ImgRequest>> {
  const generation = fetchSdApi<SdResponse<SdTxt2ImgRequest>>(api, "sdapi/v1/txt2img", {
    body: params,
  }).then((raw) => {
    // The API sends "info" as a JSON string, so it has to be decoded a second time.
    const info = typeof raw.info === "string" ? JSON.parse(raw.info) : raw.info;
    return { ...raw, info };
  });
  const isRunning = async () => (await AsyncX.promiseState(generation)) === "pending";

  try {
    for (;;) {
      // Wake up when the generation settles or after the poll interval, whichever is first.
      await Promise.race([generation, Async.delay(3000)]);
      if (!(await isRunning())) return await generation;
      if (onProgress != null) {
        const snapshot = await fetchSdApi<SdProgressResponse>(api, "sdapi/v1/progress", {
          timeoutMs: 10_000,
        });
        onProgress(snapshot);
      }
    }
  } finally {
    // Still pending here means the loop threw; stop the job instead of leaving it running.
    if (await isRunning()) {
      await fetchSdApi(api, "sdapi/v1/interrupt", { body: {}, timeoutMs: 10_000 });
    }
  }
}
2023-09-12 01:57:44 +00:00
/**
 * Request body for `sdapi/v1/txt2img`: the shared request fields plus the
 * first-phase and `hr_*` options specific to txt2img.
 */
export interface SdTxt2ImgRequest extends SdRequest {
  enable_hr: boolean;

  // First-phase size
  firstphase_width: number;
  firstphase_height: number;

  // `hr_*` options
  hr_scale: number;
  hr_resize_x: number;
  hr_resize_y: number;
  hr_upscaler: string;
  hr_second_pass_steps: number;
  hr_sampler_name: string;
  hr_prompt: string;
  hr_negative_prompt: string;
}
2023-09-12 01:57:44 +00:00
/**
 * Runs an img2img generation, polling progress roughly every 3 seconds until it settles.
 *
 * If the polling loop is left while the generation is still pending (a progress poll
 * or `onProgress` threw), an interrupt is sent before the error propagates.
 *
 * @param api API connection settings.
 * @param params Request fields sent as the JSON body; any subset may be given.
 * @param onProgress Receives each `sdapi/v1/progress` snapshot. When omitted, progress is not polled.
 * @returns The response with `info` decoded from its JSON-string form.
 * @throws Whatever the generation request, a progress poll, `onProgress`, or the interrupt
 *   request rejects with (normally an `SdApiError`).
 */
export async function sdImg2Img(
  api: SdApi,
  params: Partial<SdImg2ImgRequest>,
  onProgress?: (progress: SdProgressResponse) => void,
): Promise<SdResponse<SdImg2ImgRequest>> {
  const request = fetchSdApi<SdResponse<SdImg2ImgRequest>>(
    api,
    "sdapi/v1/img2img",
    { body: params },
  )
    // JSON field "info" is a JSON-serialized string so we need to parse this part second time
    .then((data) => ({
      ...data,
      info: typeof data.info === "string" ? JSON.parse(data.info) : data.info,
    }));
  try {
    while (true) {
      // Wake up when the request settles or after the poll interval, whichever is first.
      await Promise.race([request, Async.delay(3000)]);
      if (await AsyncX.promiseState(request) !== "pending") return await request;
      onProgress?.(
        await fetchSdApi<SdProgressResponse>(api, "sdapi/v1/progress", { timeoutMs: 10_000 }),
      );
    }
  } finally {
    // Still pending here means the loop threw; stop the job instead of leaving it running.
    if (await AsyncX.promiseState(request) === "pending") {
      // Send a body so fetchSdApi issues a POST, matching sdTxt2Img's interrupt call;
      // without it a GET was sent to the interrupt endpoint.
      await fetchSdApi(api, "sdapi/v1/interrupt", { body: {}, timeoutMs: 10_000 });
    }
  }
}
/**
 * Request body for `sdapi/v1/img2img`: the shared request fields plus the
 * source-image and mask options specific to img2img.
 */
export interface SdImg2ImgRequest extends SdRequest {
  // Source images (NOTE(review): presumably base64-encoded — confirm against the API)
  init_images: string[];
  include_init_images: boolean;
  resize_mode: number;
  initial_noise_multiplier: number;
  image_cfg_scale: number;

  // Mask and inpainting
  mask: string;
  mask_blur: number;
  mask_blur_x: number;
  mask_blur_y: number;
  inpainting_fill: number;
  inpainting_mask_invert: number;
  inpaint_full_res: boolean;
  inpaint_full_res_padding: number;
}
/**
 * Result of a txt2img/img2img call as returned by `sdTxt2Img`/`sdImg2Img`.
 *
 * @typeParam T Request type used for the `parameters` field.
 */
export interface SdResponse<T> {
  images: string[];
  parameters: T;
  /**
   * Generation metadata. The raw API response carries this as a JSON-serialized
   * string; `sdTxt2Img`/`sdImg2Img` decode it before returning.
   */
  info: SdTxt2ImgInfo;
}
/**
 * Decoded form of the `info` field of a generation response.
 */
export interface SdTxt2ImgInfo {
  // Prompts
  prompt: string;
  all_prompts: string[];
  negative_prompt: string;
  all_negative_prompts: string[];
  styles: unknown[];

  // Seeds
  seed: number;
  all_seeds: number[];
  subseed: number;
  all_subseeds: number[];
  subseed_strength: number;
  seed_resize_from_w: number;
  seed_resize_from_h: number;

  // Generation settings
  width: number;
  height: number;
  sampler_name: string;
  cfg_scale: number;
  steps: number;
  batch_size: number;
  denoising_strength: number;
  clip_skip: number;
  restore_faces: boolean;
  face_restoration_model: unknown;
  is_using_inpainting_conditioning: boolean;
  sd_model_hash: string;
  extra_generation_params: SdTxt2ImgInfoExtraParams;

  // Bookkeeping
  index_of_first_image: number;
  infotexts: string[];
  job_timestamp: string;
}
/**
 * Known keys of `extra_generation_params`. The keys contain spaces, so they are quoted.
 */
export interface SdTxt2ImgInfoExtraParams {
  "TI hashes": string;
  "Lora hashes": string;
}
/**
 * Response of `sdapi/v1/progress`, polled while a generation is running.
 */
export interface SdProgressResponse {
  /** Job bookkeeping for the current task. */
  state: SdProgressState;
  progress: number;
  eta_relative: number;
  /** Base64-encoded preview image, or `null`. */
  current_image: string | null;
  textinfo: string | null;
}
/**
 * The `state` part of a progress response.
 */
export interface SdProgressState {
  job: string;
  job_no: number;
  job_count: number;
  job_timestamp: string;
  sampling_step: number;
  sampling_steps: number;
  skipped: boolean;
  interrupted: boolean;
}
2023-09-10 18:56:17 +00:00
/**
 * Fetches the server configuration from the `config` endpoint.
 *
 * @param api API connection settings.
 * @returns The parsed config; the request is made with a 10 second `timeoutMs`.
 */
export function sdGetConfig(api: SdApi): Promise<SdConfigResponse> {
  const requestOptions = { timeoutMs: 10_000 };
  return fetchSdApi<SdConfigResponse>(api, "config", requestOptions);
}
/**
 * Response of the `config` endpoint (see `sdGetConfig`).
 */
export interface SdConfigResponse {
  /** Version string; it arrives with a trailing newline. */
  version: string;
  title: string;
  mode: string;
  root: string;

  // Flags
  dev_mode: boolean;
  analytics_enabled: boolean;
  enable_queue: boolean;
  show_error: boolean;
  show_api: boolean;
  is_space: boolean;
  is_colab: boolean;

  // Layout and styling
  theme: string;
  css: unknown;
  stylesheets: unknown[];
  layout: object;
  components: object[];
  dependencies: object[];
}
/**
 * Error body of a failed SD API request. `fetchSdApi` also synthesizes
 * `{ detail: "Invalid JSON" }` when a response body cannot be parsed.
 */
export interface SdErrorResponse {
  /** Error class name such as "OutOfMemoryError"; may be missing. */
  error?: string;
  /** Long-form description of the error. */
  errors?: string;
  /**
   * Either the HTTP status message (possibly an empty string) or an array
   * describing each invalid field.
   */
  detail: string | Array<{ loc: string[]; msg: string; type: string }>;
  /** Observed to be an empty string. */
  body?: string;
}
/**
 * Error raised by `fetchSdApi` for timeouts, network failures, unparsable bodies
 * and non-2xx responses.
 *
 * The message has the form `<METHOD> <endpoint> : <status> <statusText>`, followed by
 * the most specific detail available in `response` (`error`/`errors`, else `detail`).
 */
export class SdApiError extends Error {
  /**
   * @param endpoint Endpoint path the request was sent to.
   * @param options Request options of the call; a missing `method` is shown as `GET`.
   * @param statusCode HTTP status, or a sentinel used by `fetchSdApi` (`0` network error, `-1` timeout).
   * @param statusText HTTP status text or a short description of the failure.
   * @param response Parsed error body, when one was received.
   */
  constructor(
    public readonly endpoint: string,
    public readonly options: RequestInit | undefined,
    public readonly statusCode: number,
    public readonly statusText: string,
    public readonly response?: SdErrorResponse,
  ) {
    super(SdApiError.formatMessage(endpoint, options, statusCode, statusText, response));
    // Without this, stack traces and String(err) label the error as a plain "Error".
    this.name = "SdApiError";
  }

  /** Builds the human-readable message; static so it can be evaluated as the `super()` argument. */
  private static formatMessage(
    endpoint: string,
    options: RequestInit | undefined,
    statusCode: number,
    statusText: string,
    response: SdErrorResponse | undefined,
  ): string {
    let message = `${options?.method ?? "GET"} ${endpoint} : ${statusCode} ${statusText}`;
    if (response?.error) {
      message += `: ${response.error}`;
      if (response.errors) message += ` - ${response.errors}`;
    } else if (typeof response?.detail === "string" && response.detail.length > 0) {
      message += `: ${response.detail}`;
    } else if (response?.detail) {
      // Non-string detail (e.g. a list of invalid fields) is shown as JSON.
      message += `: ${JSON.stringify(response.detail)}`;
    }
    return message;
  }
}