diff --git a/TRT-Patch/#Original/sd_unet.py b/TRT-Patch/#Original/sd_unet.py index 5525cfb..1b56744 100644 --- a/TRT-Patch/#Original/sd_unet.py +++ b/TRT-Patch/#Original/sd_unet.py @@ -1,92 +1,94 @@ -import torch.nn -import ldm.modules.diffusionmodules.openaimodel - -from modules import script_callbacks, shared, devices - -unet_options = [] -current_unet_option = None -current_unet = None - - -def list_unets(): - new_unets = script_callbacks.list_unets_callback() - - unet_options.clear() - unet_options.extend(new_unets) - - -def get_unet_option(option=None): - option = option or shared.opts.sd_unet - - if option == "None": - return None - - if option == "Automatic": - name = shared.sd_model.sd_checkpoint_info.model_name - - options = [x for x in unet_options if x.model_name == name] - - option = options[0].label if options else "None" - - return next(iter([x for x in unet_options if x.label == option]), None) - - -def apply_unet(option=None): - global current_unet_option - global current_unet - - new_option = get_unet_option(option) - if new_option == current_unet_option: - return - - if current_unet is not None: - print(f"Dectivating unet: {current_unet.option.label}") - current_unet.deactivate() - - current_unet_option = new_option - if current_unet_option is None: - current_unet = None - - if not shared.sd_model.lowvram: - shared.sd_model.model.diffusion_model.to(devices.device) - - return - - shared.sd_model.model.diffusion_model.to(devices.cpu) - devices.torch_gc() - - current_unet = current_unet_option.create_unet() - current_unet.option = current_unet_option - print(f"Activating unet: {current_unet.option.label}") - current_unet.activate() - - -class SdUnetOption: - model_name = None - """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this""" - - label = None - """name of the unet in UI""" - - def create_unet(self): - """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures""" - raise NotImplementedError() - - -class SdUnet(torch.nn.Module): - def forward(self, x, timesteps, context, *args, **kwargs): - raise NotImplementedError() - - def activate(self): - pass - - def deactivate(self): - pass - - -def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): - if current_unet is not None: - return current_unet.forward(x, timesteps, context, *args, **kwargs) - - return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs) - +import torch.nn + +from modules import script_callbacks, shared, devices + +unet_options = [] +current_unet_option = None +current_unet = None +original_forward = None # not used, only left temporarily for compatibility + +def list_unets(): + new_unets = script_callbacks.list_unets_callback() + + unet_options.clear() + unet_options.extend(new_unets) + + +def get_unet_option(option=None): + option = option or shared.opts.sd_unet + + if option == "None": + return None + + if option == "Automatic": + name = shared.sd_model.sd_checkpoint_info.model_name + + options = [x for x in unet_options if x.model_name == name] + + option = options[0].label if options else "None" + + return next(iter([x for x in unet_options if x.label == option]), None) + + +def apply_unet(option=None): + global current_unet_option + global current_unet + + new_option = get_unet_option(option) + if new_option == current_unet_option: + return + + if current_unet is not None: + print(f"Dectivating unet: {current_unet.option.label}") + current_unet.deactivate() + + current_unet_option = new_option + if current_unet_option is None: + current_unet = None + + if not shared.sd_model.lowvram: + shared.sd_model.model.diffusion_model.to(devices.device) + + return + + shared.sd_model.model.diffusion_model.to(devices.cpu) + devices.torch_gc() + + current_unet = current_unet_option.create_unet() + current_unet.option = current_unet_option + print(f"Activating unet: {current_unet.option.label}") + current_unet.activate() + + +class SdUnetOption: + model_name = None + """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this""" + + label = None + """name of the unet in UI""" + + def create_unet(self): + """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures""" + raise NotImplementedError() + + +class SdUnet(torch.nn.Module): + def forward(self, x, timesteps, context, *args, **kwargs): + raise NotImplementedError() + + def activate(self): + pass + + def deactivate(self): + pass + + +def create_unet_forward(original_forward): + def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): + if current_unet is not None: + return current_unet.forward(x, timesteps, context, *args, **kwargs) + + return original_forward(self, x, timesteps, context, *args, **kwargs) + + return UNetModel_forward +