Compare commits
No commits in common. "b5dbfc2bca7830c04c5b60834b10eca0475140fa" and "a4394df01492f020aad9e9f8d0ed3114a49acf16" have entirely different histories.
b5dbfc2bca...a4394df014
modules/sd_unet.py
@@ -1,11 +1,12 @@
 import torch.nn
+import ldm.modules.diffusionmodules.openaimodel
 
 from modules import script_callbacks, shared, devices
 
 unet_options = []
 current_unet_option = None
 current_unet = None
-original_forward = None # not used, only left temporarily for compatibility
 
 def list_unets():
     new_unets = script_callbacks.list_unets_callback()
@@ -83,12 +84,9 @@ class SdUnet(torch.nn.Module):
         pass
 
 
-def create_unet_forward(original_forward):
-    def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
-        if current_unet is not None:
-            return current_unet.forward(x, timesteps, context, *args, **kwargs)
-
-        return original_forward(self, x, timesteps, context, *args, **kwargs)
-
-    return UNetModel_forward
+def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
+    if current_unet is not None:
+        return current_unet.forward(x, timesteps, context, *args, **kwargs)
+
+    return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
modules/processing.py
@@ -88,7 +88,6 @@ def create_binary_mask(image):
         image = image.convert('L')
     return image
 
-
 def txt2img_image_conditioning(sd_model, x, width, height):
     if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
@@ -143,7 +142,7 @@ class StableDiffusionProcessing:
     overlay_images: list = None
     eta: float = None
     do_not_reload_embeddings: bool = False
-    denoising_strength: float = None
+    denoising_strength: float = 0
     ddim_discretize: str = None
     s_min_uncond: float = None
     s_churn: float = None
@@ -299,7 +298,7 @@ class StableDiffusionProcessing:
         return conditioning
 
     def edit_image_conditioning(self, source_image):
-        conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
+        conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
 
         return conditioning_image
 
@@ -536,7 +535,6 @@ class Processed:
         self.all_seeds = all_seeds or p.all_seeds or [self.seed]
         self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
         self.infotexts = infotexts or [info]
-        self.version = program_version()
 
     def js(self):
         obj = {
@@ -571,7 +569,6 @@ class Processed:
             "job_timestamp": self.job_timestamp,
             "clip_skip": self.clip_skip,
             "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
-            "version": self.version,
         }
 
         return json.dumps(obj)
@@ -682,8 +679,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Size": f"{p.width}x{p.height}",
         "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
         "Model": p.sd_model_name if opts.add_model_name_to_info else None,
-        "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
-        "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
+        "VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None,
+        "VAE": p.sd_vae_name if opts.add_model_name_to_info else None,
         "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
         "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -714,7 +711,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
     if p.scripts is not None:
         p.scripts.before_process(p)
 
-    stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}
+    stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
 
     try:
         # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
@@ -802,6 +799,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     infotexts = []
     output_images = []
 
     with torch.no_grad(), p.sd_model.ema_scope():
         with devices.autocast():
             p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -875,6 +873,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             else:
                 if opts.sd_vae_decode_method != 'Full':
                     p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
 
                 x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
 
             x_samples_ddim = torch.stack(x_samples_ddim).float()
@@ -887,8 +886,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
 
             devices.torch_gc()
 
-            state.nextjob()
-
             if p.scripts is not None:
                 p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
@@ -941,27 +938,27 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 if opts.enable_pnginfo:
                     image.info["parameters"] = text
                 output_images.append(image)
-                if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
-                    if opts.return_mask or opts.save_mask:
-                        image_mask = p.mask_for_overlay.convert('RGB')
-                        if save_samples and opts.save_mask:
-                            images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
-                        if opts.return_mask:
-                            output_images.append(image_mask)
-
-                    if opts.return_mask_composite or opts.save_mask_composite:
-                        image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
-                        if save_samples and opts.save_mask_composite:
-                            images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
-                        if opts.return_mask_composite:
-                            output_images.append(image_mask_composite)
+                if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
+                    image_mask = p.mask_for_overlay.convert('RGB')
+                    image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+
+                    if opts.save_mask:
+                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
+
+                    if opts.save_mask_composite:
+                        images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
+
+                    if opts.return_mask:
+                        output_images.append(image_mask)
+
+                    if opts.return_mask_composite:
+                        output_images.append(image_mask_composite)
 
             del x_samples_ddim
 
             devices.torch_gc()
 
-        if not infotexts:
-            infotexts.append(Processed(p, []).infotext(p, 0))
+            state.nextjob()
 
         p.color_corrections = None
@@ -1147,7 +1144,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
         if not self.enable_hr:
             return samples
-        devices.torch_gc()
 
         if self.latent_scale_mode is None:
             decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
@@ -1157,6 +1153,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             with sd_models.SkipWritingToConfig():
                 sd_models.reload_model_weights(info=self.hr_checkpoint_info)
 
+        devices.torch_gc()
+
         return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
 
     def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
@@ -1164,6 +1162,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             return samples
 
         self.is_hr_pass = True
 
         target_width = self.hr_upscale_to_x
         target_height = self.hr_upscale_to_y
@@ -1252,6 +1251,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
 
         self.is_hr_pass = False
 
         return decoded_samples
 
     def close(self):
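The stored_opts line is the one behavioral change in process_images above: the b5dbfc2bca side only snapshots override keys that already exist in opts.data (with a fallback to opts.get_default), while the a4394df014 side indexes opts.data directly, so an unknown override_settings key raises KeyError. A standalone sketch of that difference, with plain dicts standing in for opts (illustrative names, not webui code):

# Toy stand-ins for opts.data and p.override_settings; the values are made up.
data = {"CLIP_stop_at_last_layers": 1, "sd_vae": "Automatic"}
override_settings = {"CLIP_stop_at_last_layers": 2, "not_a_real_option": True}

# b5dbfc2bca side: only keys already present in data are snapshotted
stored_safe = {k: data[k] for k in override_settings if k in data}
print(stored_safe)  # {'CLIP_stop_at_last_layers': 1}

# a4394df014 side: direct indexing, so an unknown override raises KeyError
try:
    stored_strict = {k: data[k] for k in override_settings}
except KeyError as err:
    print("unknown override setting:", err)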
modules/sd_unet.py
@@ -1,11 +1,12 @@
 import torch.nn
+import ldm.modules.diffusionmodules.openaimodel
 
 import time
 from modules import script_callbacks, shared, devices
 
 unet_options = []
 current_unet_option = None
 current_unet = None
-original_forward = None # not used, only left temporarily for compatibility
 
 def list_unets():
     new_unets = script_callbacks.list_unets_callback()
@@ -83,25 +84,18 @@ class SdUnet(torch.nn.Module):
         pass
 
 
-def create_unet_forward(original_forward):
-    def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
-        if current_unet is not None:
-            return current_unet.forward(x, timesteps, context, *args, **kwargs)
-        try:
-            if current_unet is not None and shared.current_prompt != shared.skip_unet_prompt:
-                if '[TRT]' in shared.opts.sd_unet and '<lora:' in shared.current_prompt:
-                    raise Exception('LoRA unsupported in TRT UNet')
-                f = current_unet.forward(x, timesteps, context, *args, **kwargs)
-                return f
-        except Exception as e:
-            start = time.time()
-            print('[UNet] Skipping TRT UNet for this request:', e, '-', shared.current_prompt)
-            shared.sd_model.model.diffusion_model.to(devices.device)
-            shared.skip_unet_prompt = shared.current_prompt
-            print('[UNet] Used', time.time() - start, 'seconds')
-
-        return original_forward(self, x, timesteps, context, *args, **kwargs)
-
-    return UNetModel_forward
+def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
+    try:
+        if current_unet is not None and shared.current_prompt != shared.skip_unet_prompt:
+            if '[TRT]' in shared.opts.sd_unet and '<lora:' in shared.current_prompt:
+                raise Exception('LoRA unsupported in TRT UNet')
+            f = current_unet.forward(x, timesteps, context, *args, **kwargs)
+            return f
+    except Exception as e:
+        start = time.time()
+        print('[UNet] Skipping TRT UNet for this request:', e, '-', shared.current_prompt)
+        shared.sd_model.model.diffusion_model.to(devices.device)
+        shared.skip_unet_prompt = shared.current_prompt
+        print('[UNet] Used', time.time() - start, 'seconds')
+
+    return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
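Both sd_unet.py variants above override the UNet forward pass, but they reach the fallback differently: the b5dbfc2bca side wraps a captured original_forward in a closure (create_unet_forward), while the a4394df014 side is a module-level function that falls back to ldm's copy_of_UNetModel_forward_for_webui and skips the TensorRT UNet for prompts that previously failed. A minimal sketch of the closure pattern, with toy stand-ins rather than the webui classes:

# Toy sketch of the create_unet_forward pattern; ToyUNet is a stand-in, not webui code.
current_unet = None  # would be set by an extension (for example a TensorRT UNet)

def create_unet_forward(original_forward):
    def unet_forward(self, x, timesteps=None, context=None, *args, **kwargs):
        if current_unet is not None:
            return current_unet.forward(x, timesteps, context, *args, **kwargs)
        return original_forward(self, x, timesteps, context, *args, **kwargs)
    return unet_forward

class ToyUNet:
    def forward(self, x, timesteps=None, context=None):
        return ("original", x)

ToyUNet.forward = create_unet_forward(ToyUNet.forward)
print(ToyUNet().forward([1.0]))  # ('original', [1.0]) while current_unet is None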
modules/api/api.py (87 changed lines)
@@ -17,17 +17,19 @@ from fastapi.encoders import jsonable_encoder
 from secrets import compare_digest
 
 import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
 from modules.api import models
 from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
 from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
+from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
-from PIL import PngImagePlugin, Image
+from PIL import PngImagePlugin,Image
+from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
 from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
-from typing import Any
+from typing import Dict, List, Any
 import piexif
 import piexif.helper
 from contextlib import closing
@@ -144,8 +146,7 @@ def decode_base64_to_image(encoding):
 
 def encode_pil_to_base64(image):
     with io.BytesIO() as output_bytes:
-        if isinstance(image, str):
-            return image
         if opts.samples_format.lower() == 'png':
             use_metadata = False
             metadata = PngImagePlugin.PngInfo()
@@ -263,28 +264,28 @@ class Api:
         self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
         self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
         self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
-        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
-        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
-        self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
-        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
-        self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
-        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
-        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
-        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
-        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
+        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
+        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
+        self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
+        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
+        self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
+        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
+        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
+        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
+        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
         self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
         self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
         self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
         self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
         self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
+        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
         self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
         self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
         self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
         self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
-        self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
-        self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
+        self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
 
         if shared.cmd_opts.api_server_stop:
             self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
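The response_model churn above is only the typing.List aliases versus the built-in generics (list[...] needs Python 3.9+ at runtime). A minimal FastAPI illustration; SamplerItem and the paths are placeholders, not the actual webui models:

# Placeholder model and routes, only to show the two response_model spellings.
from typing import List
from fastapi import FastAPI
from pydantic import BaseModel

class SamplerItem(BaseModel):
    name: str

app = FastAPI()

@app.get("/samplers-legacy", response_model=List[SamplerItem])  # a4394df014 style
def samplers_legacy():
    return [{"name": "Euler a"}]

@app.get("/samplers", response_model=list[SamplerItem])  # b5dbfc2bca style, Python 3.9+
def samplers():
    return [{"name": "Euler a"}]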
@@ -461,10 +462,6 @@ class Api:
         if eris_consolelog:
             print('[t2i]', txt2imgreq.width, 'x', txt2imgreq.height, '|', txt2imgreq.prompt)
         # Eris ______
-
-
-
-
         script_runner = scripts.scripts_txt2img
         if not script_runner.scripts:
             script_runner.initialize_scripts(False)
@@ -601,6 +598,7 @@ class Api:
         if eris_consolelog:
             print('[i2i]', img2imgreq.width, 'x', img2imgreq.height, '|', img2imgreq.prompt)
         # Eris ______
+
         init_images = img2imgreq.init_images
         if init_images is None:
             raise HTTPException(status_code=404, detail="Init image not found")
@@ -620,7 +618,6 @@ class Api:
         if eris_imagelog:
             img2imgreq.save_images = True
         # Eris ______
-
         populate = img2imgreq.copy(update={ # Override __init__ params
             "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
             "do_not_save_samples": not img2imgreq.save_images,
@@ -702,6 +699,9 @@ class Api:
         return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
 
     def pnginfoapi(self, req: models.PNGInfoRequest):
+        if(not req.image.strip()):
+            return models.PNGInfoResponse(info="")
+
         image = decode_base64_to_image(req.image.strip())
         if image is None:
             return models.PNGInfoResponse(info="")
@@ -710,10 +710,9 @@ class Api:
         if geninfo is None:
             geninfo = ""
 
-        params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
-        script_callbacks.infotext_pasted_callback(geninfo, params)
+        items = {**{'parameters': geninfo}, **items}
 
-        return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
+        return models.PNGInfoResponse(info=geninfo, items=items)
 
     def progressapi(self, req: models.ProgressRequest = Depends()):
         # copy from check_progress_call of ui.py
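The pnginfoapi change is visible from a client: the b5dbfc2bca side parses the infotext and returns it as a separate parameters field, while the a4394df014 side only merges the raw string into items. A hypothetical call against a running instance; host, port and the standard /sdapi/v1/png-info route are assumptions and the route itself is not shown in this compare:

# Hypothetical client call; host and port are assumptions.
import base64
import requests

with open("sample.png", "rb") as f:  # any image previously generated by the webui
    encoded = base64.b64encode(f.read()).decode()

r = requests.post("http://127.0.0.1:7860/sdapi/v1/png-info", json={"image": encoded})
data = r.json()
print(data["info"])            # raw infotext string
print(data.get("parameters"))  # parsed fields on the b5dbfc2bca side, absent on a4394df014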
@@ -768,12 +767,12 @@ class Api:
         return {}
 
     def unloadapi(self):
-        sd_models.unload_model_weights()
+        unload_model_weights()
 
         return {}
 
     def reloadapi(self):
-        sd_models.send_model_to_device(shared.sd_model)
+        reload_model_weights()
 
         return {}
 
@@ -791,9 +790,9 @@ class Api:
 
         return options
 
-    def set_config(self, req: dict[str, Any]):
+    def set_config(self, req: Dict[str, Any]):
         checkpoint_name = req.get("sd_model_checkpoint", None)
-        if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
+        if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
             raise RuntimeError(f"model {checkpoint_name!r} not found")
 
         for k, v in req.items():
@@ -903,6 +902,19 @@ class Api:
         finally:
             shared.state.end()
 
+    def preprocess(self, args: dict):
+        try:
+            shared.state.begin(job="preprocess")
+            preprocess(**args) # quick operation unless blip/booru interrogation is enabled
+            shared.state.end()
+            return models.PreprocessResponse(info='preprocess complete')
+        except KeyError as e:
+            return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
+        except Exception as e:
+            return models.PreprocessResponse(info=f"preprocess error: {e}")
+        finally:
+            shared.state.end()
+
     def train_embedding(self, args: dict):
         try:
             shared.state.begin(job="train_embedding")
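The re-added /sdapi/v1/preprocess route forwards its JSON body straight into textual_inversion.preprocess(**args). A hypothetical client call; host, port and the payload keys are assumptions (they mirror preprocess arguments and can differ between webui versions), not taken from this compare:

# Hypothetical client call; the payload keys are assumed, not confirmed by this diff.
import requests

payload = {
    "process_src": "inputs/raw",         # assumed argument names
    "process_dst": "outputs/processed",
    "process_width": 512,
    "process_height": 512,
}
r = requests.post("http://127.0.0.1:7860/sdapi/v1/preprocess", json=payload)
print(r.json()["info"])  # 'preprocess complete' on success, an error string otherwise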
@@ -984,25 +996,6 @@ class Api:
             cuda = {'error': f'{err}'}
         return models.MemoryResponse(ram=ram, cuda=cuda)
 
-    def get_extensions_list(self):
-        from modules import extensions
-        extensions.list_extensions()
-        ext_list = []
-        for ext in extensions.extensions:
-            ext: extensions.Extension
-            ext.read_info_from_repo()
-            if ext.remote is not None:
-                ext_list.append({
-                    "name": ext.name,
-                    "remote": ext.remote,
-                    "branch": ext.branch,
-                    "commit_hash":ext.commit_hash,
-                    "commit_date":ext.commit_date,
-                    "version":ext.version,
-                    "enabled":ext.enabled
-                })
-        return ext_list
-
     def launch(self, server_name, port, root_path):
         self.app.include_router(self.router)
         uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
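On the b5dbfc2bca side the extension inventory removed above is served at /sdapi/v1/extensions. A hypothetical client call, with host and port as assumptions; the printed fields mirror the dict built in get_extensions_list:

# Hypothetical client call; host and port are assumptions.
import requests

r = requests.get("http://127.0.0.1:7860/sdapi/v1/extensions")
for ext in r.json():
    print(ext["name"], ext["branch"], ext["commit_hash"], "enabled" if ext["enabled"] else "disabled")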