Compare commits


No commits in common. "b5dbfc2bca7830c04c5b60834b10eca0475140fa" and "a4394df01492f020aad9e9f8d0ed3114a49acf16" have entirely different histories.

6 changed files with 3353 additions and 3370 deletions


@@ -17,17 +17,19 @@ from fastapi.encoders import jsonable_encoder
 from secrets import compare_digest
 
 import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
 from modules.api import models
 from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
 from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
+from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
-from PIL import PngImagePlugin, Image
+from PIL import PngImagePlugin,Image
+from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
 from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
-from typing import Any
+from typing import Dict, List, Any
 import piexif
 import piexif.helper
 from contextlib import closing
@@ -101,8 +103,7 @@ def decode_base64_to_image(encoding):
 
 def encode_pil_to_base64(image):
     with io.BytesIO() as output_bytes:
-        if isinstance(image, str):
-            return image
         if opts.samples_format.lower() == 'png':
             use_metadata = False
             metadata = PngImagePlugin.PngInfo()
@@ -220,28 +221,28 @@ class Api:
         self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
         self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
         self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
-        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
-        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
-        self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
-        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
-        self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
-        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
-        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
-        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
-        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
+        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
+        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
+        self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
+        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
+        self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
+        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
+        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
+        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
+        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
         self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
         self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
         self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
         self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
         self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
+        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
         self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
         self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
         self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
         self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
-        self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
-        self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
+        self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
 
         if shared.cmd_opts.api_server_stop:
             self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
@@ -472,6 +473,9 @@ class Api:
         return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
 
     def pnginfoapi(self, req: models.PNGInfoRequest):
+        if(not req.image.strip()):
+            return models.PNGInfoResponse(info="")
+
         image = decode_base64_to_image(req.image.strip())
         if image is None:
             return models.PNGInfoResponse(info="")
@@ -480,10 +484,9 @@ class Api:
         if geninfo is None:
             geninfo = ""
 
-        params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
-        script_callbacks.infotext_pasted_callback(geninfo, params)
+        items = {**{'parameters': geninfo}, **items}
 
-        return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
+        return models.PNGInfoResponse(info=geninfo, items=items)
 
     def progressapi(self, req: models.ProgressRequest = Depends()):
         # copy from check_progress_call of ui.py
@@ -538,12 +541,12 @@ class Api:
         return {}
 
     def unloadapi(self):
-        sd_models.unload_model_weights()
+        unload_model_weights()
 
         return {}
 
     def reloadapi(self):
-        sd_models.send_model_to_device(shared.sd_model)
+        reload_model_weights()
 
         return {}
 
@@ -561,9 +564,9 @@ class Api:
 
         return options
 
-    def set_config(self, req: dict[str, Any]):
+    def set_config(self, req: Dict[str, Any]):
         checkpoint_name = req.get("sd_model_checkpoint", None)
-        if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
+        if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
            raise RuntimeError(f"model {checkpoint_name!r} not found")
 
         for k, v in req.items():
@@ -673,6 +676,19 @@ class Api:
         finally:
             shared.state.end()
 
+    def preprocess(self, args: dict):
+        try:
+            shared.state.begin(job="preprocess")
+            preprocess(**args) # quick operation unless blip/booru interrogation is enabled
+            shared.state.end()
+            return models.PreprocessResponse(info='preprocess complete')
+        except KeyError as e:
+            return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
+        except Exception as e:
+            return models.PreprocessResponse(info=f"preprocess error: {e}")
+        finally:
+            shared.state.end()
+
     def train_embedding(self, args: dict):
         try:
             shared.state.begin(job="train_embedding")
@@ -754,25 +770,6 @@ class Api:
             cuda = {'error': f'{err}'}
         return models.MemoryResponse(ram=ram, cuda=cuda)
 
-    def get_extensions_list(self):
-        from modules import extensions
-        extensions.list_extensions()
-        ext_list = []
-        for ext in extensions.extensions:
-            ext: extensions.Extension
-            ext.read_info_from_repo()
-            if ext.remote is not None:
-                ext_list.append({
-                    "name": ext.name,
-                    "remote": ext.remote,
-                    "branch": ext.branch,
-                    "commit_hash":ext.commit_hash,
-                    "commit_date":ext.commit_date,
-                    "version":ext.version,
-                    "enabled":ext.enabled
-                })
-
-        return ext_list
-
     def launch(self, server_name, port, root_path):
         self.app.include_router(self.router)
         uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
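For reference, the a4394df side of this diff keeps the old preprocess API: the /sdapi/v1/preprocess route passes its JSON body straight through as keyword arguments via preprocess(**args), so the accepted keys are exactly the parameters of textual_inversion.preprocess.preprocess() in the checked-out revision. A minimal sketch of calling it, assuming a local webui on the default port; the field names shown are illustrative and must match your version's preprocess() signature:

import requests

# Hypothetical payload: keys are forwarded verbatim to preprocess(**args),
# so any name not in that function's signature fails the request.
payload = {
    "process_src": "/data/raw",        # assumed name: source image directory
    "process_dst": "/data/processed",  # assumed name: destination directory
    "process_width": 512,
    "process_height": 512,
}

resp = requests.post("http://127.0.0.1:7860/sdapi/v1/preprocess", json=payload)
print(resp.json())  # {"info": "preprocess complete"} on success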

File diff suppressed because it is too large


@@ -1,94 +1,92 @@
 import torch.nn
+import ldm.modules.diffusionmodules.openaimodel
 
 from modules import script_callbacks, shared, devices
 
 unet_options = []
 current_unet_option = None
 current_unet = None
-original_forward = None # not used, only left temporarily for compatibility
 
 
 def list_unets():
     new_unets = script_callbacks.list_unets_callback()
 
     unet_options.clear()
     unet_options.extend(new_unets)
 
 
 def get_unet_option(option=None):
     option = option or shared.opts.sd_unet
 
     if option == "None":
         return None
 
     if option == "Automatic":
         name = shared.sd_model.sd_checkpoint_info.model_name
 
         options = [x for x in unet_options if x.model_name == name]
 
         option = options[0].label if options else "None"
 
     return next(iter([x for x in unet_options if x.label == option]), None)
 
 
 def apply_unet(option=None):
     global current_unet_option
     global current_unet
 
     new_option = get_unet_option(option)
     if new_option == current_unet_option:
         return
 
     if current_unet is not None:
         print(f"Dectivating unet: {current_unet.option.label}")
         current_unet.deactivate()
 
     current_unet_option = new_option
     if current_unet_option is None:
         current_unet = None
 
         if not shared.sd_model.lowvram:
             shared.sd_model.model.diffusion_model.to(devices.device)
 
         return
 
     shared.sd_model.model.diffusion_model.to(devices.cpu)
     devices.torch_gc()
 
     current_unet = current_unet_option.create_unet()
     current_unet.option = current_unet_option
     print(f"Activating unet: {current_unet.option.label}")
     current_unet.activate()
 
 
 class SdUnetOption:
     model_name = None
     """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
 
     label = None
     """name of the unet in UI"""
 
     def create_unet(self):
         """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
         raise NotImplementedError()
 
 
 class SdUnet(torch.nn.Module):
     def forward(self, x, timesteps, context, *args, **kwargs):
         raise NotImplementedError()
 
     def activate(self):
         pass
 
     def deactivate(self):
         pass
 
 
-def create_unet_forward(original_forward):
-    def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
-        if current_unet is not None:
-            return current_unet.forward(x, timesteps, context, *args, **kwargs)
+def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
+    if current_unet is not None:
+        return current_unet.forward(x, timesteps, context, *args, **kwargs)
 
-        return original_forward(self, x, timesteps, context, *args, **kwargs)
-
-    return UNetModel_forward
+    return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
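Both revisions of this file expose the same extension point: register an SdUnetOption through the list_unets callback, and hand back an SdUnet from create_unet(). A minimal sketch of a custom backend, assuming the upstream script_callbacks.on_list_unets hook (whose callback, per list_unets_callback above, receives the list to append to):

from modules import script_callbacks, sd_unet

class ExampleUnet(sd_unet.SdUnet):
    def forward(self, x, timesteps, context, *args, **kwargs):
        # a real backend would run its own engine here instead
        raise NotImplementedError("plug an accelerated UNet in here")

    def activate(self):
        print("ExampleUnet activated")

    def deactivate(self):
        print("ExampleUnet deactivated")

class ExampleUnetOption(sd_unet.SdUnetOption):
    model_name = "v1-5-pruned-emaonly"  # assumed checkpoint name, used for "Automatic" matching
    label = "[Example] custom unet"

    def create_unet(self):
        return ExampleUnet()

def on_list_unets(unets):
    unets.append(ExampleUnetOption())

script_callbacks.on_list_unets(on_list_unets)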

File diff suppressed because it is too large


@@ -1,107 +1,101 @@
 import torch.nn
+import ldm.modules.diffusionmodules.openaimodel
 import time
 
 from modules import script_callbacks, shared, devices
 
 unet_options = []
 current_unet_option = None
 current_unet = None
-original_forward = None # not used, only left temporarily for compatibility
 
 
 def list_unets():
     new_unets = script_callbacks.list_unets_callback()
 
     unet_options.clear()
     unet_options.extend(new_unets)
 
 
 def get_unet_option(option=None):
     option = option or shared.opts.sd_unet
 
     if option == "None":
         return None
 
     if option == "Automatic":
         name = shared.sd_model.sd_checkpoint_info.model_name
 
         options = [x for x in unet_options if x.model_name == name]
 
         option = options[0].label if options else "None"
 
     return next(iter([x for x in unet_options if x.label == option]), None)
 
 
 def apply_unet(option=None):
     global current_unet_option
     global current_unet
 
     new_option = get_unet_option(option)
     if new_option == current_unet_option:
         return
 
     if current_unet is not None:
         print(f"Dectivating unet: {current_unet.option.label}")
         current_unet.deactivate()
 
     current_unet_option = new_option
     if current_unet_option is None:
         current_unet = None
 
         if not shared.sd_model.lowvram:
             shared.sd_model.model.diffusion_model.to(devices.device)
 
         return
 
     shared.sd_model.model.diffusion_model.to(devices.cpu)
     devices.torch_gc()
 
     current_unet = current_unet_option.create_unet()
     current_unet.option = current_unet_option
     print(f"Activating unet: {current_unet.option.label}")
     current_unet.activate()
 
 
 class SdUnetOption:
     model_name = None
     """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
 
     label = None
     """name of the unet in UI"""
 
     def create_unet(self):
         """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
         raise NotImplementedError()
 
 
 class SdUnet(torch.nn.Module):
     def forward(self, x, timesteps, context, *args, **kwargs):
         raise NotImplementedError()
 
     def activate(self):
         pass
 
     def deactivate(self):
         pass
 
 
-def create_unet_forward(original_forward):
-    def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
-        if current_unet is not None:
-            return current_unet.forward(x, timesteps, context, *args, **kwargs)
-        try:
-            if current_unet is not None and shared.current_prompt != shared.skip_unet_prompt:
-                if '[TRT]' in shared.opts.sd_unet and '<lora:' in shared.current_prompt:
-                    raise Exception('LoRA unsupported in TRT UNet')
-                f = current_unet.forward(x, timesteps, context, *args, **kwargs)
-                return f
-        except Exception as e:
-            start = time.time()
-            print('[UNet] Skipping TRT UNet for this request:', e, '-', shared.current_prompt)
-            shared.sd_model.model.diffusion_model.to(devices.device)
-            shared.skip_unet_prompt = shared.current_prompt
-            print('[UNet] Used', time.time() - start, 'seconds')
-
-        return original_forward(self, x, timesteps, context, *args, **kwargs)
-
-    return UNetModel_forward
+def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
+    try:
+        if current_unet is not None and shared.current_prompt != shared.skip_unet_prompt:
+            if '[TRT]' in shared.opts.sd_unet and '<lora:' in shared.current_prompt:
+                raise Exception('LoRA unsupported in TRT UNet')
+            f = current_unet.forward(x, timesteps, context, *args, **kwargs)
+            return f
+    except Exception as e:
+        start = time.time()
+        print('[UNet] Skipping TRT UNet for this request:', e, '-', shared.current_prompt)
+        shared.sd_model.model.diffusion_model.to(devices.device)
+        shared.skip_unet_prompt = shared.current_prompt
+        print('[UNet] Used', time.time() - start, 'seconds')
+
+    return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
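The structural difference between the two sides is where the original forward lives: the b5dbfc2 side's create_unet_forward closes over whatever callable the hijack code hands it, while the a4394df side's module-level UNetModel_forward fetches the copy_of_UNetModel_forward_for_webui attribute stashed on the ldm module. A sketch of how such a patch is typically installed; the exact lines live in the webui's hijack code and may differ:

import ldm.modules.diffusionmodules.openaimodel as openaimodel
from modules import sd_unet

# keep a handle on the untouched method so the wrapper can fall back to it
original = openaimodel.UNetModel.forward

# old style: the module-level wrapper looks this attribute up at call time
openaimodel.copy_of_UNetModel_forward_for_webui = original

# new style: the wrapper is a closure over the original; no module attribute needed
openaimodel.UNetModel.forward = sd_unet.create_unet_forward(original)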

api.py

@@ -17,17 +17,19 @@ from fastapi.encoders import jsonable_encoder
 from secrets import compare_digest
 
 import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
 from modules.api import models
 from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
 from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
+from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
-from PIL import PngImagePlugin, Image
+from PIL import PngImagePlugin,Image
+from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
 from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
-from typing import Any
+from typing import Dict, List, Any
 import piexif
 import piexif.helper
 from contextlib import closing
@@ -144,8 +146,7 @@ def decode_base64_to_image(encoding):
 
 def encode_pil_to_base64(image):
     with io.BytesIO() as output_bytes:
-        if isinstance(image, str):
-            return image
         if opts.samples_format.lower() == 'png':
             use_metadata = False
             metadata = PngImagePlugin.PngInfo()
@@ -263,28 +264,28 @@ class Api:
         self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
         self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
         self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
-        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
-        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
-        self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
-        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
-        self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
-        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
-        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
-        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
-        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
+        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
+        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
+        self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
+        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
+        self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
+        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
+        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
+        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
+        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
         self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
         self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
         self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
         self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
         self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
+        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
         self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
         self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
         self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
         self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
-        self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
-        self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
+        self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
 
         if shared.cmd_opts.api_server_stop:
             self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
@@ -461,10 +462,6 @@ class Api:
         if eris_consolelog:
             print('[t2i]', txt2imgreq.width, 'x', txt2imgreq.height, '|', txt2imgreq.prompt)
         # Eris ______
 
         script_runner = scripts.scripts_txt2img
         if not script_runner.scripts:
             script_runner.initialize_scripts(False)
@@ -601,6 +598,7 @@ class Api:
         if eris_consolelog:
             print('[i2i]', img2imgreq.width, 'x', img2imgreq.height, '|', img2imgreq.prompt)
         # Eris ______
 
         init_images = img2imgreq.init_images
         if init_images is None:
             raise HTTPException(status_code=404, detail="Init image not found")
@@ -620,7 +618,6 @@ class Api:
         if eris_imagelog:
             img2imgreq.save_images = True
         # Eris ______
 
         populate = img2imgreq.copy(update={ # Override __init__ params
             "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
             "do_not_save_samples": not img2imgreq.save_images,
@@ -651,7 +648,7 @@ class Api:
             apilogimg2imgtext.replace("\n", " ").replace("\r", " ")
             apilogimg2imgfile.write(f"{apilogimg2imgtext}\n")
         # Eris ______
 
         with self.queue_lock:
             with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
                 p.init_images = [decode_base64_to_image(x) for x in init_images]
@@ -702,6 +699,9 @@ class Api:
         return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
 
     def pnginfoapi(self, req: models.PNGInfoRequest):
+        if(not req.image.strip()):
+            return models.PNGInfoResponse(info="")
+
         image = decode_base64_to_image(req.image.strip())
         if image is None:
             return models.PNGInfoResponse(info="")
@@ -710,10 +710,9 @@ class Api:
         if geninfo is None:
             geninfo = ""
 
-        params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
-        script_callbacks.infotext_pasted_callback(geninfo, params)
+        items = {**{'parameters': geninfo}, **items}
 
-        return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
+        return models.PNGInfoResponse(info=geninfo, items=items)
 
     def progressapi(self, req: models.ProgressRequest = Depends()):
         # copy from check_progress_call of ui.py
@@ -768,12 +767,12 @@ class Api:
         return {}
 
     def unloadapi(self):
-        sd_models.unload_model_weights()
+        unload_model_weights()
 
         return {}
 
     def reloadapi(self):
-        sd_models.send_model_to_device(shared.sd_model)
+        reload_model_weights()
 
         return {}
 
@@ -791,9 +790,9 @@ class Api:
 
         return options
 
-    def set_config(self, req: dict[str, Any]):
+    def set_config(self, req: Dict[str, Any]):
         checkpoint_name = req.get("sd_model_checkpoint", None)
-        if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
+        if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
            raise RuntimeError(f"model {checkpoint_name!r} not found")
 
         for k, v in req.items():
@@ -903,6 +902,19 @@ class Api:
         finally:
             shared.state.end()
 
+    def preprocess(self, args: dict):
+        try:
+            shared.state.begin(job="preprocess")
+            preprocess(**args) # quick operation unless blip/booru interrogation is enabled
+            shared.state.end()
+            return models.PreprocessResponse(info='preprocess complete')
+        except KeyError as e:
+            return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
+        except Exception as e:
+            return models.PreprocessResponse(info=f"preprocess error: {e}")
+        finally:
+            shared.state.end()
+
     def train_embedding(self, args: dict):
         try:
             shared.state.begin(job="train_embedding")
@@ -984,25 +996,6 @@ class Api:
             cuda = {'error': f'{err}'}
         return models.MemoryResponse(ram=ram, cuda=cuda)
 
-    def get_extensions_list(self):
-        from modules import extensions
-        extensions.list_extensions()
-        ext_list = []
-        for ext in extensions.extensions:
-            ext: extensions.Extension
-            ext.read_info_from_repo()
-            if ext.remote is not None:
-                ext_list.append({
-                    "name": ext.name,
-                    "remote": ext.remote,
-                    "branch": ext.branch,
-                    "commit_hash":ext.commit_hash,
-                    "commit_date":ext.commit_date,
-                    "version":ext.version,
-                    "enabled":ext.enabled
-                })
-
-        return ext_list
-
     def launch(self, server_name, port, root_path):
         self.app.include_router(self.router)
         uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
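In both variants, set_config validates sd_model_checkpoint against the known checkpoint aliases before applying any option, so a bad name fails the request without changing state. A minimal sketch of exercising the options endpoints over HTTP, assuming a local instance and an illustrative checkpoint filename:

import requests

base = "http://127.0.0.1:7860"

# read the current options
opts = requests.get(f"{base}/sdapi/v1/options").json()
print(opts.get("sd_model_checkpoint"))

# switch checkpoints; an unknown name makes the server raise
# RuntimeError("model ... not found") before any option is applied
requests.post(f"{base}/sdapi/v1/options", json={
    "sd_model_checkpoint": "v1-5-pruned-emaonly.safetensors",  # assumed name
})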