forked from Hitmare/Eris_api_tensor_patch
Merge pull request 'hitmare-patch-1' (#2) from hitmare-patch-1 into main
Reviewed-on: Hitmare/Eris_api_tensor_patch#2 code was checked and changed parts were copied. SD was able to start without errors
This commit is contained in:
commit
b5dbfc2bca
|
@ -17,19 +17,17 @@ from fastapi.encoders import jsonable_encoder
|
||||||
from secrets import compare_digest
|
from secrets import compare_digest
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
|
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
|
||||||
from modules.api import models
|
from modules.api import models
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||||
from modules.textual_inversion.preprocess import preprocess
|
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin,Image
|
from PIL import PngImagePlugin, Image
|
||||||
from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
|
|
||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
from typing import Dict, List, Any
|
from typing import Any
|
||||||
import piexif
|
import piexif
|
||||||
import piexif.helper
|
import piexif.helper
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
|
@ -103,7 +101,8 @@ def decode_base64_to_image(encoding):
|
||||||
|
|
||||||
def encode_pil_to_base64(image):
|
def encode_pil_to_base64(image):
|
||||||
with io.BytesIO() as output_bytes:
|
with io.BytesIO() as output_bytes:
|
||||||
|
if isinstance(image, str):
|
||||||
|
return image
|
||||||
if opts.samples_format.lower() == 'png':
|
if opts.samples_format.lower() == 'png':
|
||||||
use_metadata = False
|
use_metadata = False
|
||||||
metadata = PngImagePlugin.PngInfo()
|
metadata = PngImagePlugin.PngInfo()
|
||||||
|
@ -221,28 +220,28 @@ class Api:
|
||||||
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
||||||
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
||||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
|
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
|
||||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
|
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
|
||||||
self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
|
self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
|
||||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
|
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
|
||||||
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
|
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
|
||||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
|
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
|
||||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
|
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
|
||||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
|
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
|
||||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
|
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
|
||||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
|
|
||||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
|
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
|
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
||||||
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
||||||
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
|
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
|
||||||
|
self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
|
||||||
|
|
||||||
if shared.cmd_opts.api_server_stop:
|
if shared.cmd_opts.api_server_stop:
|
||||||
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
|
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
|
||||||
|
@ -473,9 +472,6 @@ class Api:
|
||||||
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||||
|
|
||||||
def pnginfoapi(self, req: models.PNGInfoRequest):
|
def pnginfoapi(self, req: models.PNGInfoRequest):
|
||||||
if(not req.image.strip()):
|
|
||||||
return models.PNGInfoResponse(info="")
|
|
||||||
|
|
||||||
image = decode_base64_to_image(req.image.strip())
|
image = decode_base64_to_image(req.image.strip())
|
||||||
if image is None:
|
if image is None:
|
||||||
return models.PNGInfoResponse(info="")
|
return models.PNGInfoResponse(info="")
|
||||||
|
@ -484,9 +480,10 @@ class Api:
|
||||||
if geninfo is None:
|
if geninfo is None:
|
||||||
geninfo = ""
|
geninfo = ""
|
||||||
|
|
||||||
items = {**{'parameters': geninfo}, **items}
|
params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
|
||||||
|
script_callbacks.infotext_pasted_callback(geninfo, params)
|
||||||
|
|
||||||
return models.PNGInfoResponse(info=geninfo, items=items)
|
return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
|
||||||
|
|
||||||
def progressapi(self, req: models.ProgressRequest = Depends()):
|
def progressapi(self, req: models.ProgressRequest = Depends()):
|
||||||
# copy from check_progress_call of ui.py
|
# copy from check_progress_call of ui.py
|
||||||
|
@ -541,12 +538,12 @@ class Api:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def unloadapi(self):
|
def unloadapi(self):
|
||||||
unload_model_weights()
|
sd_models.unload_model_weights()
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def reloadapi(self):
|
def reloadapi(self):
|
||||||
reload_model_weights()
|
sd_models.send_model_to_device(shared.sd_model)
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@ -564,9 +561,9 @@ class Api:
|
||||||
|
|
||||||
return options
|
return options
|
||||||
|
|
||||||
def set_config(self, req: Dict[str, Any]):
|
def set_config(self, req: dict[str, Any]):
|
||||||
checkpoint_name = req.get("sd_model_checkpoint", None)
|
checkpoint_name = req.get("sd_model_checkpoint", None)
|
||||||
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
|
if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
|
||||||
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
||||||
|
|
||||||
for k, v in req.items():
|
for k, v in req.items():
|
||||||
|
@ -676,19 +673,6 @@ class Api:
|
||||||
finally:
|
finally:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
|
|
||||||
def preprocess(self, args: dict):
|
|
||||||
try:
|
|
||||||
shared.state.begin(job="preprocess")
|
|
||||||
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
|
||||||
shared.state.end()
|
|
||||||
return models.PreprocessResponse(info='preprocess complete')
|
|
||||||
except KeyError as e:
|
|
||||||
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
return models.PreprocessResponse(info=f"preprocess error: {e}")
|
|
||||||
finally:
|
|
||||||
shared.state.end()
|
|
||||||
|
|
||||||
def train_embedding(self, args: dict):
|
def train_embedding(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin(job="train_embedding")
|
shared.state.begin(job="train_embedding")
|
||||||
|
@ -770,6 +754,25 @@ class Api:
|
||||||
cuda = {'error': f'{err}'}
|
cuda = {'error': f'{err}'}
|
||||||
return models.MemoryResponse(ram=ram, cuda=cuda)
|
return models.MemoryResponse(ram=ram, cuda=cuda)
|
||||||
|
|
||||||
|
def get_extensions_list(self):
|
||||||
|
from modules import extensions
|
||||||
|
extensions.list_extensions()
|
||||||
|
ext_list = []
|
||||||
|
for ext in extensions.extensions:
|
||||||
|
ext: extensions.Extension
|
||||||
|
ext.read_info_from_repo()
|
||||||
|
if ext.remote is not None:
|
||||||
|
ext_list.append({
|
||||||
|
"name": ext.name,
|
||||||
|
"remote": ext.remote,
|
||||||
|
"branch": ext.branch,
|
||||||
|
"commit_hash":ext.commit_hash,
|
||||||
|
"commit_date":ext.commit_date,
|
||||||
|
"version":ext.version,
|
||||||
|
"enabled":ext.enabled
|
||||||
|
})
|
||||||
|
return ext_list
|
||||||
|
|
||||||
def launch(self, server_name, port, root_path):
|
def launch(self, server_name, port, root_path):
|
||||||
self.app.include_router(self.router)
|
self.app.include_router(self.router)
|
||||||
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
|
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,92 +1,94 @@
|
||||||
import torch.nn
|
import torch.nn
|
||||||
import ldm.modules.diffusionmodules.openaimodel
|
|
||||||
|
from modules import script_callbacks, shared, devices
|
||||||
from modules import script_callbacks, shared, devices
|
|
||||||
|
unet_options = []
|
||||||
unet_options = []
|
current_unet_option = None
|
||||||
current_unet_option = None
|
current_unet = None
|
||||||
current_unet = None
|
original_forward = None # not used, only left temporarily for compatibility
|
||||||
|
|
||||||
|
def list_unets():
|
||||||
def list_unets():
|
new_unets = script_callbacks.list_unets_callback()
|
||||||
new_unets = script_callbacks.list_unets_callback()
|
|
||||||
|
unet_options.clear()
|
||||||
unet_options.clear()
|
unet_options.extend(new_unets)
|
||||||
unet_options.extend(new_unets)
|
|
||||||
|
|
||||||
|
def get_unet_option(option=None):
|
||||||
def get_unet_option(option=None):
|
option = option or shared.opts.sd_unet
|
||||||
option = option or shared.opts.sd_unet
|
|
||||||
|
if option == "None":
|
||||||
if option == "None":
|
return None
|
||||||
return None
|
|
||||||
|
if option == "Automatic":
|
||||||
if option == "Automatic":
|
name = shared.sd_model.sd_checkpoint_info.model_name
|
||||||
name = shared.sd_model.sd_checkpoint_info.model_name
|
|
||||||
|
options = [x for x in unet_options if x.model_name == name]
|
||||||
options = [x for x in unet_options if x.model_name == name]
|
|
||||||
|
option = options[0].label if options else "None"
|
||||||
option = options[0].label if options else "None"
|
|
||||||
|
return next(iter([x for x in unet_options if x.label == option]), None)
|
||||||
return next(iter([x for x in unet_options if x.label == option]), None)
|
|
||||||
|
|
||||||
|
def apply_unet(option=None):
|
||||||
def apply_unet(option=None):
|
global current_unet_option
|
||||||
global current_unet_option
|
global current_unet
|
||||||
global current_unet
|
|
||||||
|
new_option = get_unet_option(option)
|
||||||
new_option = get_unet_option(option)
|
if new_option == current_unet_option:
|
||||||
if new_option == current_unet_option:
|
return
|
||||||
return
|
|
||||||
|
if current_unet is not None:
|
||||||
if current_unet is not None:
|
print(f"Dectivating unet: {current_unet.option.label}")
|
||||||
print(f"Dectivating unet: {current_unet.option.label}")
|
current_unet.deactivate()
|
||||||
current_unet.deactivate()
|
|
||||||
|
current_unet_option = new_option
|
||||||
current_unet_option = new_option
|
if current_unet_option is None:
|
||||||
if current_unet_option is None:
|
current_unet = None
|
||||||
current_unet = None
|
|
||||||
|
if not shared.sd_model.lowvram:
|
||||||
if not shared.sd_model.lowvram:
|
shared.sd_model.model.diffusion_model.to(devices.device)
|
||||||
shared.sd_model.model.diffusion_model.to(devices.device)
|
|
||||||
|
return
|
||||||
return
|
|
||||||
|
shared.sd_model.model.diffusion_model.to(devices.cpu)
|
||||||
shared.sd_model.model.diffusion_model.to(devices.cpu)
|
devices.torch_gc()
|
||||||
devices.torch_gc()
|
|
||||||
|
current_unet = current_unet_option.create_unet()
|
||||||
current_unet = current_unet_option.create_unet()
|
current_unet.option = current_unet_option
|
||||||
current_unet.option = current_unet_option
|
print(f"Activating unet: {current_unet.option.label}")
|
||||||
print(f"Activating unet: {current_unet.option.label}")
|
current_unet.activate()
|
||||||
current_unet.activate()
|
|
||||||
|
|
||||||
|
class SdUnetOption:
|
||||||
class SdUnetOption:
|
model_name = None
|
||||||
model_name = None
|
"""name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
|
||||||
"""name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
|
|
||||||
|
label = None
|
||||||
label = None
|
"""name of the unet in UI"""
|
||||||
"""name of the unet in UI"""
|
|
||||||
|
def create_unet(self):
|
||||||
def create_unet(self):
|
"""returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
|
||||||
"""returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
|
raise NotImplementedError()
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
class SdUnet(torch.nn.Module):
|
||||||
class SdUnet(torch.nn.Module):
|
def forward(self, x, timesteps, context, *args, **kwargs):
|
||||||
def forward(self, x, timesteps, context, *args, **kwargs):
|
raise NotImplementedError()
|
||||||
raise NotImplementedError()
|
|
||||||
|
def activate(self):
|
||||||
def activate(self):
|
pass
|
||||||
pass
|
|
||||||
|
def deactivate(self):
|
||||||
def deactivate(self):
|
pass
|
||||||
pass
|
|
||||||
|
|
||||||
|
def create_unet_forward(original_forward):
|
||||||
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
|
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
|
||||||
if current_unet is not None:
|
if current_unet is not None:
|
||||||
return current_unet.forward(x, timesteps, context, *args, **kwargs)
|
return current_unet.forward(x, timesteps, context, *args, **kwargs)
|
||||||
|
|
||||||
return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
|
return original_forward(self, x, timesteps, context, *args, **kwargs)
|
||||||
|
|
||||||
|
return UNetModel_forward
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,101 +1,107 @@
|
||||||
import torch.nn
|
import torch.nn
|
||||||
import ldm.modules.diffusionmodules.openaimodel
|
import time
|
||||||
|
from modules import script_callbacks, shared, devices
|
||||||
import time
|
|
||||||
from modules import script_callbacks, shared, devices
|
unet_options = []
|
||||||
unet_options = []
|
current_unet_option = None
|
||||||
current_unet_option = None
|
current_unet = None
|
||||||
current_unet = None
|
original_forward = None # not used, only left temporarily for compatibility
|
||||||
|
|
||||||
|
def list_unets():
|
||||||
def list_unets():
|
new_unets = script_callbacks.list_unets_callback()
|
||||||
new_unets = script_callbacks.list_unets_callback()
|
|
||||||
|
unet_options.clear()
|
||||||
unet_options.clear()
|
unet_options.extend(new_unets)
|
||||||
unet_options.extend(new_unets)
|
|
||||||
|
|
||||||
|
def get_unet_option(option=None):
|
||||||
def get_unet_option(option=None):
|
option = option or shared.opts.sd_unet
|
||||||
option = option or shared.opts.sd_unet
|
|
||||||
|
if option == "None":
|
||||||
if option == "None":
|
return None
|
||||||
return None
|
|
||||||
|
if option == "Automatic":
|
||||||
if option == "Automatic":
|
name = shared.sd_model.sd_checkpoint_info.model_name
|
||||||
name = shared.sd_model.sd_checkpoint_info.model_name
|
|
||||||
|
options = [x for x in unet_options if x.model_name == name]
|
||||||
options = [x for x in unet_options if x.model_name == name]
|
|
||||||
|
option = options[0].label if options else "None"
|
||||||
option = options[0].label if options else "None"
|
|
||||||
|
return next(iter([x for x in unet_options if x.label == option]), None)
|
||||||
return next(iter([x for x in unet_options if x.label == option]), None)
|
|
||||||
|
|
||||||
|
def apply_unet(option=None):
|
||||||
def apply_unet(option=None):
|
global current_unet_option
|
||||||
global current_unet_option
|
global current_unet
|
||||||
global current_unet
|
|
||||||
|
new_option = get_unet_option(option)
|
||||||
new_option = get_unet_option(option)
|
if new_option == current_unet_option:
|
||||||
if new_option == current_unet_option:
|
return
|
||||||
return
|
|
||||||
|
if current_unet is not None:
|
||||||
if current_unet is not None:
|
print(f"Dectivating unet: {current_unet.option.label}")
|
||||||
print(f"Dectivating unet: {current_unet.option.label}")
|
current_unet.deactivate()
|
||||||
current_unet.deactivate()
|
|
||||||
|
current_unet_option = new_option
|
||||||
current_unet_option = new_option
|
if current_unet_option is None:
|
||||||
if current_unet_option is None:
|
current_unet = None
|
||||||
current_unet = None
|
|
||||||
|
if not shared.sd_model.lowvram:
|
||||||
if not shared.sd_model.lowvram:
|
shared.sd_model.model.diffusion_model.to(devices.device)
|
||||||
shared.sd_model.model.diffusion_model.to(devices.device)
|
|
||||||
|
return
|
||||||
return
|
|
||||||
|
shared.sd_model.model.diffusion_model.to(devices.cpu)
|
||||||
shared.sd_model.model.diffusion_model.to(devices.cpu)
|
devices.torch_gc()
|
||||||
devices.torch_gc()
|
|
||||||
|
current_unet = current_unet_option.create_unet()
|
||||||
current_unet = current_unet_option.create_unet()
|
current_unet.option = current_unet_option
|
||||||
current_unet.option = current_unet_option
|
print(f"Activating unet: {current_unet.option.label}")
|
||||||
print(f"Activating unet: {current_unet.option.label}")
|
current_unet.activate()
|
||||||
current_unet.activate()
|
|
||||||
|
|
||||||
|
class SdUnetOption:
|
||||||
class SdUnetOption:
|
model_name = None
|
||||||
model_name = None
|
"""name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
|
||||||
"""name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
|
|
||||||
|
label = None
|
||||||
label = None
|
"""name of the unet in UI"""
|
||||||
"""name of the unet in UI"""
|
|
||||||
|
def create_unet(self):
|
||||||
def create_unet(self):
|
"""returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
|
||||||
"""returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
|
raise NotImplementedError()
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
class SdUnet(torch.nn.Module):
|
||||||
class SdUnet(torch.nn.Module):
|
def forward(self, x, timesteps, context, *args, **kwargs):
|
||||||
def forward(self, x, timesteps, context, *args, **kwargs):
|
raise NotImplementedError()
|
||||||
raise NotImplementedError()
|
|
||||||
|
def activate(self):
|
||||||
def activate(self):
|
pass
|
||||||
pass
|
|
||||||
|
def deactivate(self):
|
||||||
def deactivate(self):
|
pass
|
||||||
pass
|
|
||||||
|
|
||||||
|
def create_unet_forward(original_forward):
|
||||||
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
|
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
|
||||||
try:
|
if current_unet is not None:
|
||||||
if current_unet is not None and shared.current_prompt != shared.skip_unet_prompt:
|
return current_unet.forward(x, timesteps, context, *args, **kwargs)
|
||||||
if '[TRT]' in shared.opts.sd_unet and '<lora:' in shared.current_prompt:
|
try:
|
||||||
raise Exception('LoRA unsupported in TRT UNet')
|
if current_unet is not None and shared.current_prompt != shared.skip_unet_prompt:
|
||||||
f = current_unet.forward(x, timesteps, context, *args, **kwargs)
|
if '[TRT]' in shared.opts.sd_unet and '<lora:' in shared.current_prompt:
|
||||||
return f
|
raise Exception('LoRA unsupported in TRT UNet')
|
||||||
except Exception as e:
|
f = current_unet.forward(x, timesteps, context, *args, **kwargs)
|
||||||
start = time.time()
|
return f
|
||||||
print('[UNet] Skipping TRT UNet for this request:', e, '-', shared.current_prompt)
|
except Exception as e:
|
||||||
shared.sd_model.model.diffusion_model.to(devices.device)
|
start = time.time()
|
||||||
shared.skip_unet_prompt = shared.current_prompt
|
print('[UNet] Skipping TRT UNet for this request:', e, '-', shared.current_prompt)
|
||||||
print('[UNet] Used', time.time() - start, 'seconds')
|
shared.sd_model.model.diffusion_model.to(devices.device)
|
||||||
|
shared.skip_unet_prompt = shared.current_prompt
|
||||||
return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
|
print('[UNet] Used', time.time() - start, 'seconds')
|
||||||
|
|
||||||
|
|
||||||
|
return original_forward(self, x, timesteps, context, *args, **kwargs)
|
||||||
|
|
||||||
|
return UNetModel_forward
|
||||||
|
|
||||||
|
|
89
api.py
89
api.py
|
@ -17,19 +17,17 @@ from fastapi.encoders import jsonable_encoder
|
||||||
from secrets import compare_digest
|
from secrets import compare_digest
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
|
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
|
||||||
from modules.api import models
|
from modules.api import models
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||||
from modules.textual_inversion.preprocess import preprocess
|
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin,Image
|
from PIL import PngImagePlugin, Image
|
||||||
from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
|
|
||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
from typing import Dict, List, Any
|
from typing import Any
|
||||||
import piexif
|
import piexif
|
||||||
import piexif.helper
|
import piexif.helper
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
|
@ -146,7 +144,8 @@ def decode_base64_to_image(encoding):
|
||||||
|
|
||||||
def encode_pil_to_base64(image):
|
def encode_pil_to_base64(image):
|
||||||
with io.BytesIO() as output_bytes:
|
with io.BytesIO() as output_bytes:
|
||||||
|
if isinstance(image, str):
|
||||||
|
return image
|
||||||
if opts.samples_format.lower() == 'png':
|
if opts.samples_format.lower() == 'png':
|
||||||
use_metadata = False
|
use_metadata = False
|
||||||
metadata = PngImagePlugin.PngInfo()
|
metadata = PngImagePlugin.PngInfo()
|
||||||
|
@ -264,28 +263,28 @@ class Api:
|
||||||
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
||||||
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
||||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
|
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
|
||||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
|
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
|
||||||
self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
|
self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
|
||||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
|
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
|
||||||
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
|
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
|
||||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
|
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
|
||||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
|
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
|
||||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
|
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
|
||||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
|
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
|
||||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
|
|
||||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
|
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
|
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
||||||
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
||||||
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
|
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
|
||||||
|
self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
|
||||||
|
|
||||||
if shared.cmd_opts.api_server_stop:
|
if shared.cmd_opts.api_server_stop:
|
||||||
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
|
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
|
||||||
|
@ -462,6 +461,10 @@ class Api:
|
||||||
if eris_consolelog:
|
if eris_consolelog:
|
||||||
print('[t2i]', txt2imgreq.width, 'x', txt2imgreq.height, '|', txt2imgreq.prompt)
|
print('[t2i]', txt2imgreq.width, 'x', txt2imgreq.height, '|', txt2imgreq.prompt)
|
||||||
# Eris ______
|
# Eris ______
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
script_runner = scripts.scripts_txt2img
|
script_runner = scripts.scripts_txt2img
|
||||||
if not script_runner.scripts:
|
if not script_runner.scripts:
|
||||||
script_runner.initialize_scripts(False)
|
script_runner.initialize_scripts(False)
|
||||||
|
@ -598,7 +601,6 @@ class Api:
|
||||||
if eris_consolelog:
|
if eris_consolelog:
|
||||||
print('[i2i]', img2imgreq.width, 'x', img2imgreq.height, '|', img2imgreq.prompt)
|
print('[i2i]', img2imgreq.width, 'x', img2imgreq.height, '|', img2imgreq.prompt)
|
||||||
# Eris ______
|
# Eris ______
|
||||||
|
|
||||||
init_images = img2imgreq.init_images
|
init_images = img2imgreq.init_images
|
||||||
if init_images is None:
|
if init_images is None:
|
||||||
raise HTTPException(status_code=404, detail="Init image not found")
|
raise HTTPException(status_code=404, detail="Init image not found")
|
||||||
|
@ -618,6 +620,7 @@ class Api:
|
||||||
if eris_imagelog:
|
if eris_imagelog:
|
||||||
img2imgreq.save_images = True
|
img2imgreq.save_images = True
|
||||||
# Eris ______
|
# Eris ______
|
||||||
|
|
||||||
populate = img2imgreq.copy(update={ # Override __init__ params
|
populate = img2imgreq.copy(update={ # Override __init__ params
|
||||||
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
|
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
|
||||||
"do_not_save_samples": not img2imgreq.save_images,
|
"do_not_save_samples": not img2imgreq.save_images,
|
||||||
|
@ -648,7 +651,7 @@ class Api:
|
||||||
apilogimg2imgtext.replace("\n", " ").replace("\r", " ")
|
apilogimg2imgtext.replace("\n", " ").replace("\r", " ")
|
||||||
apilogimg2imgfile.write(f"{apilogimg2imgtext}\n")
|
apilogimg2imgfile.write(f"{apilogimg2imgtext}\n")
|
||||||
# Eris ______
|
# Eris ______
|
||||||
|
|
||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
|
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
|
||||||
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
||||||
|
@ -699,9 +702,6 @@ class Api:
|
||||||
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||||
|
|
||||||
def pnginfoapi(self, req: models.PNGInfoRequest):
|
def pnginfoapi(self, req: models.PNGInfoRequest):
|
||||||
if(not req.image.strip()):
|
|
||||||
return models.PNGInfoResponse(info="")
|
|
||||||
|
|
||||||
image = decode_base64_to_image(req.image.strip())
|
image = decode_base64_to_image(req.image.strip())
|
||||||
if image is None:
|
if image is None:
|
||||||
return models.PNGInfoResponse(info="")
|
return models.PNGInfoResponse(info="")
|
||||||
|
@ -710,9 +710,10 @@ class Api:
|
||||||
if geninfo is None:
|
if geninfo is None:
|
||||||
geninfo = ""
|
geninfo = ""
|
||||||
|
|
||||||
items = {**{'parameters': geninfo}, **items}
|
params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
|
||||||
|
script_callbacks.infotext_pasted_callback(geninfo, params)
|
||||||
|
|
||||||
return models.PNGInfoResponse(info=geninfo, items=items)
|
return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
|
||||||
|
|
||||||
def progressapi(self, req: models.ProgressRequest = Depends()):
|
def progressapi(self, req: models.ProgressRequest = Depends()):
|
||||||
# copy from check_progress_call of ui.py
|
# copy from check_progress_call of ui.py
|
||||||
|
@ -767,12 +768,12 @@ class Api:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def unloadapi(self):
|
def unloadapi(self):
|
||||||
unload_model_weights()
|
sd_models.unload_model_weights()
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def reloadapi(self):
|
def reloadapi(self):
|
||||||
reload_model_weights()
|
sd_models.send_model_to_device(shared.sd_model)
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@ -790,9 +791,9 @@ class Api:
|
||||||
|
|
||||||
return options
|
return options
|
||||||
|
|
||||||
def set_config(self, req: Dict[str, Any]):
|
def set_config(self, req: dict[str, Any]):
|
||||||
checkpoint_name = req.get("sd_model_checkpoint", None)
|
checkpoint_name = req.get("sd_model_checkpoint", None)
|
||||||
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
|
if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
|
||||||
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
||||||
|
|
||||||
for k, v in req.items():
|
for k, v in req.items():
|
||||||
|
@ -902,19 +903,6 @@ class Api:
|
||||||
finally:
|
finally:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
|
|
||||||
def preprocess(self, args: dict):
|
|
||||||
try:
|
|
||||||
shared.state.begin(job="preprocess")
|
|
||||||
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
|
||||||
shared.state.end()
|
|
||||||
return models.PreprocessResponse(info='preprocess complete')
|
|
||||||
except KeyError as e:
|
|
||||||
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
return models.PreprocessResponse(info=f"preprocess error: {e}")
|
|
||||||
finally:
|
|
||||||
shared.state.end()
|
|
||||||
|
|
||||||
def train_embedding(self, args: dict):
|
def train_embedding(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin(job="train_embedding")
|
shared.state.begin(job="train_embedding")
|
||||||
|
@ -996,6 +984,25 @@ class Api:
|
||||||
cuda = {'error': f'{err}'}
|
cuda = {'error': f'{err}'}
|
||||||
return models.MemoryResponse(ram=ram, cuda=cuda)
|
return models.MemoryResponse(ram=ram, cuda=cuda)
|
||||||
|
|
||||||
|
def get_extensions_list(self):
|
||||||
|
from modules import extensions
|
||||||
|
extensions.list_extensions()
|
||||||
|
ext_list = []
|
||||||
|
for ext in extensions.extensions:
|
||||||
|
ext: extensions.Extension
|
||||||
|
ext.read_info_from_repo()
|
||||||
|
if ext.remote is not None:
|
||||||
|
ext_list.append({
|
||||||
|
"name": ext.name,
|
||||||
|
"remote": ext.remote,
|
||||||
|
"branch": ext.branch,
|
||||||
|
"commit_hash":ext.commit_hash,
|
||||||
|
"commit_date":ext.commit_date,
|
||||||
|
"version":ext.version,
|
||||||
|
"enabled":ext.enabled
|
||||||
|
})
|
||||||
|
return ext_list
|
||||||
|
|
||||||
def launch(self, server_name, port, root_path):
|
def launch(self, server_name, port, root_path):
|
||||||
self.app.include_router(self.router)
|
self.app.include_router(self.router)
|
||||||
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
|
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
|
||||||
|
|
Loading…
Reference in New Issue