From cc119297bacf44fc470796665c6583ceb558b4de Mon Sep 17 00:00:00 2001 From: Hitmare Date: Sat, 23 Dec 2023 15:03:35 +0000 Subject: [PATCH] Update for 1.7.0 Update of the TRT Patch for the 1.7.0 version of the original file --- TRT-Patch/modules/sd_unet.py | 208 ++++++++++++++++++----------------- 1 file changed, 107 insertions(+), 101 deletions(-) diff --git a/TRT-Patch/modules/sd_unet.py b/TRT-Patch/modules/sd_unet.py index 49daf1d..3109c5d 100644 --- a/TRT-Patch/modules/sd_unet.py +++ b/TRT-Patch/modules/sd_unet.py @@ -1,101 +1,107 @@ -import torch.nn -import ldm.modules.diffusionmodules.openaimodel - -import time -from modules import script_callbacks, shared, devices -unet_options = [] -current_unet_option = None -current_unet = None - - -def list_unets(): - new_unets = script_callbacks.list_unets_callback() - - unet_options.clear() - unet_options.extend(new_unets) - - -def get_unet_option(option=None): - option = option or shared.opts.sd_unet - - if option == "None": - return None - - if option == "Automatic": - name = shared.sd_model.sd_checkpoint_info.model_name - - options = [x for x in unet_options if x.model_name == name] - - option = options[0].label if options else "None" - - return next(iter([x for x in unet_options if x.label == option]), None) - - -def apply_unet(option=None): - global current_unet_option - global current_unet - - new_option = get_unet_option(option) - if new_option == current_unet_option: - return - - if current_unet is not None: - print(f"Dectivating unet: {current_unet.option.label}") - current_unet.deactivate() - - current_unet_option = new_option - if current_unet_option is None: - current_unet = None - - if not shared.sd_model.lowvram: - shared.sd_model.model.diffusion_model.to(devices.device) - - return - - shared.sd_model.model.diffusion_model.to(devices.cpu) - devices.torch_gc() - - current_unet = current_unet_option.create_unet() - current_unet.option = current_unet_option - print(f"Activating unet: {current_unet.option.label}") - current_unet.activate() - - -class SdUnetOption: - model_name = None - """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this""" - - label = None - """name of the unet in UI""" - - def create_unet(self): - """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures""" - raise NotImplementedError() - - -class SdUnet(torch.nn.Module): - def forward(self, x, timesteps, context, *args, **kwargs): - raise NotImplementedError() - - def activate(self): - pass - - def deactivate(self): - pass - - -def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): - try: - if current_unet is not None and shared.current_prompt != shared.skip_unet_prompt: - if '[TRT]' in shared.opts.sd_unet and '