hitmare-patch-1 #2

Merged
Hitmare merged 7 commits from hitmare-patch-1 into main 2023-12-23 15:10:34 +00:00
6 changed files with 3370 additions and 3353 deletions

View File

@ -17,19 +17,17 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest from secrets import compare_digest
import modules.shared as shared import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
from modules.api import models from modules.api import models
from modules.shared import opts from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image from PIL import PngImagePlugin, Image
from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
from modules.sd_models_config import find_checkpoint_config_near_filename from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models from modules.realesrgan_model import get_realesrgan_models
from modules import devices from modules import devices
from typing import Dict, List, Any from typing import Any
import piexif import piexif
import piexif.helper import piexif.helper
from contextlib import closing from contextlib import closing
@ -103,7 +101,8 @@ def decode_base64_to_image(encoding):
def encode_pil_to_base64(image): def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes: with io.BytesIO() as output_bytes:
if isinstance(image, str):
return image
if opts.samples_format.lower() == 'png': if opts.samples_format.lower() == 'png':
use_metadata = False use_metadata = False
metadata = PngImagePlugin.PngInfo() metadata = PngImagePlugin.PngInfo()
@ -221,28 +220,28 @@ class Api:
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel) self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem]) self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse) self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse) self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse) self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
if shared.cmd_opts.api_server_stop: if shared.cmd_opts.api_server_stop:
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"]) self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
@ -473,9 +472,6 @@ class Api:
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
def pnginfoapi(self, req: models.PNGInfoRequest): def pnginfoapi(self, req: models.PNGInfoRequest):
if(not req.image.strip()):
return models.PNGInfoResponse(info="")
image = decode_base64_to_image(req.image.strip()) image = decode_base64_to_image(req.image.strip())
if image is None: if image is None:
return models.PNGInfoResponse(info="") return models.PNGInfoResponse(info="")
@ -484,9 +480,10 @@ class Api:
if geninfo is None: if geninfo is None:
geninfo = "" geninfo = ""
items = {**{'parameters': geninfo}, **items} params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
script_callbacks.infotext_pasted_callback(geninfo, params)
return models.PNGInfoResponse(info=geninfo, items=items) return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
def progressapi(self, req: models.ProgressRequest = Depends()): def progressapi(self, req: models.ProgressRequest = Depends()):
# copy from check_progress_call of ui.py # copy from check_progress_call of ui.py
@ -541,12 +538,12 @@ class Api:
return {} return {}
def unloadapi(self): def unloadapi(self):
unload_model_weights() sd_models.unload_model_weights()
return {} return {}
def reloadapi(self): def reloadapi(self):
reload_model_weights() sd_models.send_model_to_device(shared.sd_model)
return {} return {}
@ -564,9 +561,9 @@ class Api:
return options return options
def set_config(self, req: Dict[str, Any]): def set_config(self, req: dict[str, Any]):
checkpoint_name = req.get("sd_model_checkpoint", None) checkpoint_name = req.get("sd_model_checkpoint", None)
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases: if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
raise RuntimeError(f"model {checkpoint_name!r} not found") raise RuntimeError(f"model {checkpoint_name!r} not found")
for k, v in req.items(): for k, v in req.items():
@ -676,19 +673,6 @@ class Api:
finally: finally:
shared.state.end() shared.state.end()
def preprocess(self, args: dict):
try:
shared.state.begin(job="preprocess")
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
return models.PreprocessResponse(info='preprocess complete')
except KeyError as e:
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
except Exception as e:
return models.PreprocessResponse(info=f"preprocess error: {e}")
finally:
shared.state.end()
def train_embedding(self, args: dict): def train_embedding(self, args: dict):
try: try:
shared.state.begin(job="train_embedding") shared.state.begin(job="train_embedding")
@ -770,6 +754,25 @@ class Api:
cuda = {'error': f'{err}'} cuda = {'error': f'{err}'}
return models.MemoryResponse(ram=ram, cuda=cuda) return models.MemoryResponse(ram=ram, cuda=cuda)
def get_extensions_list(self):
from modules import extensions
extensions.list_extensions()
ext_list = []
for ext in extensions.extensions:
ext: extensions.Extension
ext.read_info_from_repo()
if ext.remote is not None:
ext_list.append({
"name": ext.name,
"remote": ext.remote,
"branch": ext.branch,
"commit_hash":ext.commit_hash,
"commit_date":ext.commit_date,
"version":ext.version,
"enabled":ext.enabled
})
return ext_list
def launch(self, server_name, port, root_path): def launch(self, server_name, port, root_path):
self.app.include_router(self.router) self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path) uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)

View File

@ -142,7 +142,7 @@ class StableDiffusionProcessing:
overlay_images: list = None overlay_images: list = None
eta: float = None eta: float = None
do_not_reload_embeddings: bool = False do_not_reload_embeddings: bool = False
denoising_strength: float = 0 denoising_strength: float = None
ddim_discretize: str = None ddim_discretize: str = None
s_min_uncond: float = None s_min_uncond: float = None
s_churn: float = None s_churn: float = None
@ -296,7 +296,7 @@ class StableDiffusionProcessing:
return conditioning return conditioning
def edit_image_conditioning(self, source_image): def edit_image_conditioning(self, source_image):
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
return conditioning_image return conditioning_image
@ -533,6 +533,7 @@ class Processed:
self.all_seeds = all_seeds or p.all_seeds or [self.seed] self.all_seeds = all_seeds or p.all_seeds or [self.seed]
self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed] self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
self.infotexts = infotexts or [info] self.infotexts = infotexts or [info]
self.version = program_version()
def js(self): def js(self):
obj = { obj = {
@ -567,6 +568,7 @@ class Processed:
"job_timestamp": self.job_timestamp, "job_timestamp": self.job_timestamp,
"clip_skip": self.clip_skip, "clip_skip": self.clip_skip,
"is_using_inpainting_conditioning": self.is_using_inpainting_conditioning, "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
"version": self.version,
} }
return json.dumps(obj) return json.dumps(obj)
@ -677,8 +679,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Size": f"{p.width}x{p.height}", "Size": f"{p.width}x{p.height}",
"Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None, "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
"Model": p.sd_model_name if opts.add_model_name_to_info else None, "Model": p.sd_model_name if opts.add_model_name_to_info else None,
"VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None, "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
"VAE": p.sd_vae_name if opts.add_model_name_to_info else None, "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@ -709,7 +711,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None: if p.scripts is not None:
p.scripts.before_process(p) p.scripts.before_process(p)
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()} stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}
try: try:
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
@ -797,7 +799,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = [] infotexts = []
output_images = [] output_images = []
with torch.no_grad(), p.sd_model.ema_scope(): with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast(): with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds) p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@ -871,7 +872,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else: else:
if opts.sd_vae_decode_method != 'Full': if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.stack(x_samples_ddim).float()
@ -884,6 +884,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc() devices.torch_gc()
state.nextjob()
if p.scripts is not None: if p.scripts is not None:
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n) p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
@ -936,19 +938,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.enable_pnginfo: if opts.enable_pnginfo:
image.info["parameters"] = text image.info["parameters"] = text
output_images.append(image) output_images.append(image)
if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
if opts.return_mask or opts.save_mask:
image_mask = p.mask_for_overlay.convert('RGB') image_mask = p.mask_for_overlay.convert('RGB')
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') if save_samples and opts.save_mask:
if opts.save_mask:
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
if opts.save_mask_composite:
images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
if opts.return_mask: if opts.return_mask:
output_images.append(image_mask) output_images.append(image_mask)
if opts.return_mask_composite or opts.save_mask_composite:
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
if save_samples and opts.save_mask_composite:
images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
if opts.return_mask_composite: if opts.return_mask_composite:
output_images.append(image_mask_composite) output_images.append(image_mask_composite)
@ -956,7 +957,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc() devices.torch_gc()
state.nextjob() if not infotexts:
infotexts.append(Processed(p, []).infotext(p, 0))
p.color_corrections = None p.color_corrections = None
@ -1142,6 +1144,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if not self.enable_hr: if not self.enable_hr:
return samples return samples
devices.torch_gc()
if self.latent_scale_mode is None: if self.latent_scale_mode is None:
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
@ -1151,8 +1154,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
with sd_models.SkipWritingToConfig(): with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=self.hr_checkpoint_info) sd_models.reload_model_weights(info=self.hr_checkpoint_info)
devices.torch_gc()
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts) return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts): def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
@ -1160,7 +1161,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples return samples
self.is_hr_pass = True self.is_hr_pass = True
target_width = self.hr_upscale_to_x target_width = self.hr_upscale_to_x
target_height = self.hr_upscale_to_y target_height = self.hr_upscale_to_y
@ -1249,7 +1249,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
self.is_hr_pass = False self.is_hr_pass = False
return decoded_samples return decoded_samples
def close(self): def close(self):

View File

@ -1,12 +1,11 @@
import torch.nn import torch.nn
import ldm.modules.diffusionmodules.openaimodel
from modules import script_callbacks, shared, devices from modules import script_callbacks, shared, devices
unet_options = [] unet_options = []
current_unet_option = None current_unet_option = None
current_unet = None current_unet = None
original_forward = None # not used, only left temporarily for compatibility
def list_unets(): def list_unets():
new_unets = script_callbacks.list_unets_callback() new_unets = script_callbacks.list_unets_callback()
@ -84,9 +83,12 @@ class SdUnet(torch.nn.Module):
pass pass
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): def create_unet_forward(original_forward):
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
if current_unet is not None: if current_unet is not None:
return current_unet.forward(x, timesteps, context, *args, **kwargs) return current_unet.forward(x, timesteps, context, *args, **kwargs)
return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs) return original_forward(self, x, timesteps, context, *args, **kwargs)
return UNetModel_forward

View File

@ -88,6 +88,7 @@ def create_binary_mask(image):
image = image.convert('L') image = image.convert('L')
return image return image
def txt2img_image_conditioning(sd_model, x, width, height): def txt2img_image_conditioning(sd_model, x, width, height):
if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
@ -142,7 +143,7 @@ class StableDiffusionProcessing:
overlay_images: list = None overlay_images: list = None
eta: float = None eta: float = None
do_not_reload_embeddings: bool = False do_not_reload_embeddings: bool = False
denoising_strength: float = 0 denoising_strength: float = None
ddim_discretize: str = None ddim_discretize: str = None
s_min_uncond: float = None s_min_uncond: float = None
s_churn: float = None s_churn: float = None
@ -298,7 +299,7 @@ class StableDiffusionProcessing:
return conditioning return conditioning
def edit_image_conditioning(self, source_image): def edit_image_conditioning(self, source_image):
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
return conditioning_image return conditioning_image
@ -535,6 +536,7 @@ class Processed:
self.all_seeds = all_seeds or p.all_seeds or [self.seed] self.all_seeds = all_seeds or p.all_seeds or [self.seed]
self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed] self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
self.infotexts = infotexts or [info] self.infotexts = infotexts or [info]
self.version = program_version()
def js(self): def js(self):
obj = { obj = {
@ -569,6 +571,7 @@ class Processed:
"job_timestamp": self.job_timestamp, "job_timestamp": self.job_timestamp,
"clip_skip": self.clip_skip, "clip_skip": self.clip_skip,
"is_using_inpainting_conditioning": self.is_using_inpainting_conditioning, "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
"version": self.version,
} }
return json.dumps(obj) return json.dumps(obj)
@ -679,8 +682,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Size": f"{p.width}x{p.height}", "Size": f"{p.width}x{p.height}",
"Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None, "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
"Model": p.sd_model_name if opts.add_model_name_to_info else None, "Model": p.sd_model_name if opts.add_model_name_to_info else None,
"VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None, "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
"VAE": p.sd_vae_name if opts.add_model_name_to_info else None, "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@ -711,7 +714,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None: if p.scripts is not None:
p.scripts.before_process(p) p.scripts.before_process(p)
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()} stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}
try: try:
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
@ -799,7 +802,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = [] infotexts = []
output_images = [] output_images = []
with torch.no_grad(), p.sd_model.ema_scope(): with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast(): with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds) p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@ -873,7 +875,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else: else:
if opts.sd_vae_decode_method != 'Full': if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.stack(x_samples_ddim).float()
@ -886,6 +887,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc() devices.torch_gc()
state.nextjob()
if p.scripts is not None: if p.scripts is not None:
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n) p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
@ -938,19 +941,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.enable_pnginfo: if opts.enable_pnginfo:
image.info["parameters"] = text image.info["parameters"] = text
output_images.append(image) output_images.append(image)
if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
if opts.return_mask or opts.save_mask:
image_mask = p.mask_for_overlay.convert('RGB') image_mask = p.mask_for_overlay.convert('RGB')
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') if save_samples and opts.save_mask:
if opts.save_mask:
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
if opts.save_mask_composite:
images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
if opts.return_mask: if opts.return_mask:
output_images.append(image_mask) output_images.append(image_mask)
if opts.return_mask_composite or opts.save_mask_composite:
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
if save_samples and opts.save_mask_composite:
images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
if opts.return_mask_composite: if opts.return_mask_composite:
output_images.append(image_mask_composite) output_images.append(image_mask_composite)
@ -958,7 +960,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc() devices.torch_gc()
state.nextjob() if not infotexts:
infotexts.append(Processed(p, []).infotext(p, 0))
p.color_corrections = None p.color_corrections = None
@ -1144,6 +1147,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if not self.enable_hr: if not self.enable_hr:
return samples return samples
devices.torch_gc()
if self.latent_scale_mode is None: if self.latent_scale_mode is None:
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
@ -1153,8 +1157,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
with sd_models.SkipWritingToConfig(): with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=self.hr_checkpoint_info) sd_models.reload_model_weights(info=self.hr_checkpoint_info)
devices.torch_gc()
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts) return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts): def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
@ -1162,7 +1164,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples return samples
self.is_hr_pass = True self.is_hr_pass = True
target_width = self.hr_upscale_to_x target_width = self.hr_upscale_to_x
target_height = self.hr_upscale_to_y target_height = self.hr_upscale_to_y
@ -1251,7 +1252,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
self.is_hr_pass = False self.is_hr_pass = False
return decoded_samples return decoded_samples
def close(self): def close(self):

View File

@ -1,12 +1,11 @@
import torch.nn import torch.nn
import ldm.modules.diffusionmodules.openaimodel
import time import time
from modules import script_callbacks, shared, devices from modules import script_callbacks, shared, devices
unet_options = [] unet_options = []
current_unet_option = None current_unet_option = None
current_unet = None current_unet = None
original_forward = None # not used, only left temporarily for compatibility
def list_unets(): def list_unets():
new_unets = script_callbacks.list_unets_callback() new_unets = script_callbacks.list_unets_callback()
@ -84,7 +83,10 @@ class SdUnet(torch.nn.Module):
pass pass
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): def create_unet_forward(original_forward):
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
if current_unet is not None:
return current_unet.forward(x, timesteps, context, *args, **kwargs)
try: try:
if current_unet is not None and shared.current_prompt != shared.skip_unet_prompt: if current_unet is not None and shared.current_prompt != shared.skip_unet_prompt:
if '[TRT]' in shared.opts.sd_unet and '<lora:' in shared.current_prompt: if '[TRT]' in shared.opts.sd_unet and '<lora:' in shared.current_prompt:
@ -98,4 +100,8 @@ def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
shared.skip_unet_prompt = shared.current_prompt shared.skip_unet_prompt = shared.current_prompt
print('[UNet] Used', time.time() - start, 'seconds') print('[UNet] Used', time.time() - start, 'seconds')
return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
return original_forward(self, x, timesteps, context, *args, **kwargs)
return UNetModel_forward

87
api.py
View File

@ -17,19 +17,17 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest from secrets import compare_digest
import modules.shared as shared import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
from modules.api import models from modules.api import models
from modules.shared import opts from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image from PIL import PngImagePlugin, Image
from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
from modules.sd_models_config import find_checkpoint_config_near_filename from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models from modules.realesrgan_model import get_realesrgan_models
from modules import devices from modules import devices
from typing import Dict, List, Any from typing import Any
import piexif import piexif
import piexif.helper import piexif.helper
from contextlib import closing from contextlib import closing
@ -146,7 +144,8 @@ def decode_base64_to_image(encoding):
def encode_pil_to_base64(image): def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes: with io.BytesIO() as output_bytes:
if isinstance(image, str):
return image
if opts.samples_format.lower() == 'png': if opts.samples_format.lower() == 'png':
use_metadata = False use_metadata = False
metadata = PngImagePlugin.PngInfo() metadata = PngImagePlugin.PngInfo()
@ -264,28 +263,28 @@ class Api:
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel) self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem]) self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse) self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse) self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse) self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
if shared.cmd_opts.api_server_stop: if shared.cmd_opts.api_server_stop:
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"]) self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
@ -462,6 +461,10 @@ class Api:
if eris_consolelog: if eris_consolelog:
print('[t2i]', txt2imgreq.width, 'x', txt2imgreq.height, '|', txt2imgreq.prompt) print('[t2i]', txt2imgreq.width, 'x', txt2imgreq.height, '|', txt2imgreq.prompt)
# Eris ______ # Eris ______
script_runner = scripts.scripts_txt2img script_runner = scripts.scripts_txt2img
if not script_runner.scripts: if not script_runner.scripts:
script_runner.initialize_scripts(False) script_runner.initialize_scripts(False)
@ -598,7 +601,6 @@ class Api:
if eris_consolelog: if eris_consolelog:
print('[i2i]', img2imgreq.width, 'x', img2imgreq.height, '|', img2imgreq.prompt) print('[i2i]', img2imgreq.width, 'x', img2imgreq.height, '|', img2imgreq.prompt)
# Eris ______ # Eris ______
init_images = img2imgreq.init_images init_images = img2imgreq.init_images
if init_images is None: if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found") raise HTTPException(status_code=404, detail="Init image not found")
@ -618,6 +620,7 @@ class Api:
if eris_imagelog: if eris_imagelog:
img2imgreq.save_images = True img2imgreq.save_images = True
# Eris ______ # Eris ______
populate = img2imgreq.copy(update={ # Override __init__ params populate = img2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
"do_not_save_samples": not img2imgreq.save_images, "do_not_save_samples": not img2imgreq.save_images,
@ -699,9 +702,6 @@ class Api:
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
def pnginfoapi(self, req: models.PNGInfoRequest): def pnginfoapi(self, req: models.PNGInfoRequest):
if(not req.image.strip()):
return models.PNGInfoResponse(info="")
image = decode_base64_to_image(req.image.strip()) image = decode_base64_to_image(req.image.strip())
if image is None: if image is None:
return models.PNGInfoResponse(info="") return models.PNGInfoResponse(info="")
@ -710,9 +710,10 @@ class Api:
if geninfo is None: if geninfo is None:
geninfo = "" geninfo = ""
items = {**{'parameters': geninfo}, **items} params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
script_callbacks.infotext_pasted_callback(geninfo, params)
return models.PNGInfoResponse(info=geninfo, items=items) return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
def progressapi(self, req: models.ProgressRequest = Depends()): def progressapi(self, req: models.ProgressRequest = Depends()):
# copy from check_progress_call of ui.py # copy from check_progress_call of ui.py
@ -767,12 +768,12 @@ class Api:
return {} return {}
def unloadapi(self): def unloadapi(self):
unload_model_weights() sd_models.unload_model_weights()
return {} return {}
def reloadapi(self): def reloadapi(self):
reload_model_weights() sd_models.send_model_to_device(shared.sd_model)
return {} return {}
@ -790,9 +791,9 @@ class Api:
return options return options
def set_config(self, req: Dict[str, Any]): def set_config(self, req: dict[str, Any]):
checkpoint_name = req.get("sd_model_checkpoint", None) checkpoint_name = req.get("sd_model_checkpoint", None)
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases: if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
raise RuntimeError(f"model {checkpoint_name!r} not found") raise RuntimeError(f"model {checkpoint_name!r} not found")
for k, v in req.items(): for k, v in req.items():
@ -902,19 +903,6 @@ class Api:
finally: finally:
shared.state.end() shared.state.end()
def preprocess(self, args: dict):
try:
shared.state.begin(job="preprocess")
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
return models.PreprocessResponse(info='preprocess complete')
except KeyError as e:
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
except Exception as e:
return models.PreprocessResponse(info=f"preprocess error: {e}")
finally:
shared.state.end()
def train_embedding(self, args: dict): def train_embedding(self, args: dict):
try: try:
shared.state.begin(job="train_embedding") shared.state.begin(job="train_embedding")
@ -996,6 +984,25 @@ class Api:
cuda = {'error': f'{err}'} cuda = {'error': f'{err}'}
return models.MemoryResponse(ram=ram, cuda=cuda) return models.MemoryResponse(ram=ram, cuda=cuda)
def get_extensions_list(self):
from modules import extensions
extensions.list_extensions()
ext_list = []
for ext in extensions.extensions:
ext: extensions.Extension
ext.read_info_from_repo()
if ext.remote is not None:
ext_list.append({
"name": ext.name,
"remote": ext.remote,
"branch": ext.branch,
"commit_hash":ext.commit_hash,
"commit_date":ext.commit_date,
"version":ext.version,
"enabled":ext.enabled
})
return ext_list
def launch(self, server_name, port, root_path): def launch(self, server_name, port, root_path):
self.app.include_router(self.router) self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path) uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)