forked from Hitmare/Eris_api_tensor_patch
Upload files to "TRT-Patch/modules"
This commit is contained in:
parent
7f443b287e
commit
0dfdc4f931
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,101 @@
|
|||
import torch.nn
|
||||
import ldm.modules.diffusionmodules.openaimodel
|
||||
|
||||
import time
|
||||
from modules import script_callbacks, shared, devices
|
||||
unet_options = []
|
||||
current_unet_option = None
|
||||
current_unet = None
|
||||
|
||||
|
||||
def list_unets():
|
||||
new_unets = script_callbacks.list_unets_callback()
|
||||
|
||||
unet_options.clear()
|
||||
unet_options.extend(new_unets)
|
||||
|
||||
|
||||
def get_unet_option(option=None):
|
||||
option = option or shared.opts.sd_unet
|
||||
|
||||
if option == "None":
|
||||
return None
|
||||
|
||||
if option == "Automatic":
|
||||
name = shared.sd_model.sd_checkpoint_info.model_name
|
||||
|
||||
options = [x for x in unet_options if x.model_name == name]
|
||||
|
||||
option = options[0].label if options else "None"
|
||||
|
||||
return next(iter([x for x in unet_options if x.label == option]), None)
|
||||
|
||||
|
||||
def apply_unet(option=None):
|
||||
global current_unet_option
|
||||
global current_unet
|
||||
|
||||
new_option = get_unet_option(option)
|
||||
if new_option == current_unet_option:
|
||||
return
|
||||
|
||||
if current_unet is not None:
|
||||
print(f"Dectivating unet: {current_unet.option.label}")
|
||||
current_unet.deactivate()
|
||||
|
||||
current_unet_option = new_option
|
||||
if current_unet_option is None:
|
||||
current_unet = None
|
||||
|
||||
if not shared.sd_model.lowvram:
|
||||
shared.sd_model.model.diffusion_model.to(devices.device)
|
||||
|
||||
return
|
||||
|
||||
shared.sd_model.model.diffusion_model.to(devices.cpu)
|
||||
devices.torch_gc()
|
||||
|
||||
current_unet = current_unet_option.create_unet()
|
||||
current_unet.option = current_unet_option
|
||||
print(f"Activating unet: {current_unet.option.label}")
|
||||
current_unet.activate()
|
||||
|
||||
|
||||
class SdUnetOption:
|
||||
model_name = None
|
||||
"""name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
|
||||
|
||||
label = None
|
||||
"""name of the unet in UI"""
|
||||
|
||||
def create_unet(self):
|
||||
"""returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SdUnet(torch.nn.Module):
|
||||
def forward(self, x, timesteps, context, *args, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
def activate(self):
|
||||
pass
|
||||
|
||||
def deactivate(self):
|
||||
pass
|
||||
|
||||
|
||||
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
|
||||
try:
|
||||
if current_unet is not None and shared.current_prompt != shared.skip_unet_prompt:
|
||||
if '[TRT]' in shared.opts.sd_unet and '<lora:' in shared.current_prompt:
|
||||
raise Exception('LoRA unsupported in TRT UNet')
|
||||
f = current_unet.forward(x, timesteps, context, *args, **kwargs)
|
||||
return f
|
||||
except Exception as e:
|
||||
start = time.time()
|
||||
print('[UNet] Skipping TRT UNet for this request:', e, '-', shared.current_prompt)
|
||||
shared.sd_model.model.diffusion_model.to(devices.device)
|
||||
shared.skip_unet_prompt = shared.current_prompt
|
||||
print('[UNet] Used', time.time() - start, 'seconds')
|
||||
|
||||
return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
|
|
@ -0,0 +1,90 @@
|
|||
import sys
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types
|
||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||
from modules import util
|
||||
|
||||
cmd_opts = shared_cmd_options.cmd_opts
|
||||
parser = shared_cmd_options.parser
|
||||
|
||||
batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
|
||||
parallel_processing_allowed = True
|
||||
styles_filename = cmd_opts.styles_file
|
||||
config_filename = cmd_opts.ui_settings_file
|
||||
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
||||
|
||||
demo = None
|
||||
|
||||
device = None
|
||||
|
||||
weight_load_location = None
|
||||
|
||||
xformers_available = False
|
||||
|
||||
hypernetworks = {}
|
||||
|
||||
loaded_hypernetworks = []
|
||||
|
||||
state = None
|
||||
|
||||
prompt_styles = None
|
||||
|
||||
interrogator = None
|
||||
|
||||
face_restorers = []
|
||||
|
||||
options_templates = None
|
||||
opts = None
|
||||
restricted_opts = None
|
||||
|
||||
sd_model: sd_models_types.WebuiSdModel = None
|
||||
|
||||
settings_components = None
|
||||
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
|
||||
|
||||
tab_names = []
|
||||
|
||||
latent_upscale_default_mode = "Latent"
|
||||
latent_upscale_modes = {
|
||||
"Latent": {"mode": "bilinear", "antialias": False},
|
||||
"Latent (antialiased)": {"mode": "bilinear", "antialias": True},
|
||||
"Latent (bicubic)": {"mode": "bicubic", "antialias": False},
|
||||
"Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
|
||||
"Latent (nearest)": {"mode": "nearest", "antialias": False},
|
||||
"Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False},
|
||||
}
|
||||
|
||||
sd_upscalers = []
|
||||
|
||||
clip_model = None
|
||||
|
||||
progress_print_out = sys.stdout
|
||||
|
||||
gradio_theme = gr.themes.Base()
|
||||
|
||||
total_tqdm = None
|
||||
|
||||
mem_mon = None
|
||||
|
||||
options_section = options.options_section
|
||||
OptionInfo = options.OptionInfo
|
||||
OptionHTML = options.OptionHTML
|
||||
|
||||
natural_sort_key = util.natural_sort_key
|
||||
listfiles = util.listfiles
|
||||
html_path = util.html_path
|
||||
html = util.html
|
||||
walk_files = util.walk_files
|
||||
ldm_print = util.ldm_print
|
||||
|
||||
reload_gradio_theme = shared_gradio_themes.reload_gradio_theme
|
||||
|
||||
list_checkpoint_tiles = shared_items.list_checkpoint_tiles
|
||||
refresh_checkpoints = shared_items.refresh_checkpoints
|
||||
list_samplers = shared_items.list_samplers
|
||||
reload_hypernetworks = shared_items.reload_hypernetworks
|
||||
|
||||
current_prompt = ''
|
||||
skip_unet_prompt = ''
|
Loading…
Reference in New Issue