diff --git a/CHANGELOG.md b/CHANGELOG.md index 198500b9a..60e21ff6d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Change Log for SD.Next -## Update for 2025-01-12 +## Update for 2025-01-13 - [Allegro Video](https://huggingface.co/rhymes-ai/Allegro) - optimizations: full offload and quantization support @@ -60,6 +60,7 @@ - control restore pipeline before running hires - restore args after batch run - flux controlnet + - zluda installer ## Update for 2024-12-31 diff --git a/modules/control/run.py b/modules/control/run.py index 94d1d0375..03c13240e 100644 --- a/modules/control/run.py +++ b/modules/control/run.py @@ -67,7 +67,10 @@ def set_pipe(p, has_models, unit_type, selected_models, active_model, active_str shared.log.warning('Control: T2I-Adapter does not support separate init image') elif unit_type == 'controlnet' and has_models: p.extra_generation_params["Control type"] = 'ControlNet' - p.task_args['controlnet_conditioning_scale'] = [control_conditioning] + if shared.sd_model_type == 'f1': + p.task_args['controlnet_conditioning_scale'] = control_conditioning if isinstance(control_conditioning, list) else [control_conditioning] + else: + p.task_args['controlnet_conditioning_scale'] = control_conditioning p.task_args['control_guidance_start'] = control_guidance_start p.task_args['control_guidance_end'] = control_guidance_end p.task_args['guess_mode'] = p.guess_mode diff --git a/modules/control/units/controlnet.py b/modules/control/units/controlnet.py index 6510226f3..4390461ee 100644 --- a/modules/control/units/controlnet.py +++ b/modules/control/units/controlnet.py @@ -10,8 +10,8 @@ what = 'ControlNet' -debug = log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None -debug('Trace: CONTROL') +debug = os.environ.get('SD_CONTROL_DEBUG', None) is not None +debug_log = log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None predefined_sd15 = { 'Canny': "lllyasviel/control_v11p_sd15_canny", 'Depth': "lllyasviel/control_v11f1p_sd15_depth", @@ -156,7 +156,7 @@ def list_models(refresh=False): else: log.warning(f'Control {what} model list failed: unknown model type') models = ['None'] + sorted(predefined_sd15) + sorted(predefined_sdxl) + sorted(predefined_f1) + sorted(predefined_sd3) + sorted(find_models()) - debug(f'Control list {what}: path={cache_dir} models={models}') + debug_log(f'Control list {what}: path={cache_dir} models={models}') return models @@ -174,7 +174,7 @@ def __init__(self, model_id: str = None, device = None, dtype = None, load_confi def reset(self): if self.model is not None: - debug(f'Control {what} model unloaded') + debug_log(f'Control {what} model unloaded') self.model = None self.model_id = None @@ -233,7 +233,7 @@ def load_safetensors(self, model_id, model_path): self.load_config['original_config_file '] = config_path cls, config = self.get_class(model_id) if cls is None: - log.error(f'Control {what} model load failed: unknown base model') + log.error(f'Control {what} model load: unknown base model') else: self.model = cls.from_single_file(model_path, config=config, **self.load_config) @@ -246,13 +246,13 @@ def load(self, model_id: str = None, force: bool = True) -> str: self.reset() return if model_id not in all_models: - log.error(f'Control {what} unknown model: id="{model_id}" available={list(all_models)}') + log.error(f'Control {what}: id="{model_id}" available={list(all_models)} unknown model') return model_path = all_models[model_id] if model_path == '': return 
if model_path is None: - log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id') + log.error(f'Control {what} model load: id="{model_id}" unknown model id') return if 'lora' in model_id.lower(): self.model = model_path @@ -269,12 +269,19 @@ def load(self, model_id: str = None, force: bool = True) -> str: if '/bin' in model_path: model_path = model_path.replace('/bin', '') self.load_config['use_safetensors'] = False + else: + self.load_config['use_safetensors'] = True if cls is None: - log.error(f'Control {what} model load failed: id="{model_id}" unknown base model') + log.error(f'Control {what} model load: id="{model_id}" unknown base model') return if variants.get(model_id, None) is not None: kwargs['variant'] = variants[model_id] - self.model = cls.from_pretrained(model_path, **self.load_config, **kwargs) + try: + self.model = cls.from_pretrained(model_path, **self.load_config, **kwargs) + except Exception as e: + log.error(f'Control {what} model load: id="{model_id}" {e}') + if debug: + errors.display(e, 'Control') if self.model is None: return if self.dtype is not None: @@ -287,7 +294,7 @@ def load(self, model_id: str = None, force: bool = True) -> str: from modules.sd_models_compile import nncf_compress_model self.model = nncf_compress_model(self.model) except Exception as e: - log.error(f'Control {what} model NNCF Compression failed: id="{model_id}" error={e}') + log.error(f'Control {what} model NNCF Compression failed: id="{model_id}" {e}') elif "ControlNet" in opts.optimum_quanto_weights: try: log.debug(f'Control {what} model Optimum Quanto: id="{model_id}"') @@ -295,7 +302,7 @@ def load(self, model_id: str = None, force: bool = True) -> str: from modules.sd_models_compile import optimum_quanto_model self.model = optimum_quanto_model(self.model) except Exception as e: - log.error(f'Control {what} model Optimum Quanto failed: id="{model_id}" error={e}') + log.error(f'Control {what} model Optimum Quanto: id="{model_id}" {e}') if self.device is not None: self.model.to(self.device) t1 = time.time() @@ -303,7 +310,7 @@ def load(self, model_id: str = None, force: bool = True) -> str: log.info(f'Control {what} model loaded: id="{model_id}" path="{model_path}" cls={cls.__name__} time={t1-t0:.2f}') return f'{what} loaded model: {model_id}' except Exception as e: - log.error(f'Control {what} model load failed: id="{model_id}" error={e}') + log.error(f'Control {what} model load: id="{model_id}" {e}') errors.display(e, f'Control {what} load') return f'{what} failed to load model: {model_id}' diff --git a/modules/sd_models.py b/modules/sd_models.py index 9cbdc67dc..1629f2dd7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -13,13 +13,13 @@ from rich import progress # pylint: disable=redefined-builtin import torch import safetensors.torch -import accelerate from omegaconf import OmegaConf from modules import paths, shared, shared_state, modelloader, devices, script_callbacks, sd_vae, sd_unet, errors, sd_models_config, sd_models_compile, sd_hijack_accelerate, sd_detect from modules.timer import Timer, process as process_timer from modules.memstats import memory_stats from modules.modeldata import model_data from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoints_list, checkpoint_titles, get_closet_checkpoint_match, model_hash, update_model_hashes, setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import +from modules.sd_offload import set_diffuser_offload, apply_balanced_offload, 
set_accelerate # pylint: disable=unused-import model_dir = "Stable-diffusion" @@ -33,8 +33,6 @@ debug_process = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None diffusers_version = int(diffusers.__version__.split('.')[1]) checkpoint_tiles = checkpoint_titles # legacy compatibility -should_offload = ['sc', 'sd3', 'f1', 'hunyuandit', 'auraflow', 'omnigen'] -offload_hook_instance = None class NoWatermark: @@ -306,250 +304,6 @@ def eval_model(model, op=None, sd_model=None): # pylint: disable=unused-argument set_diffuser_offload(sd_model, op) -def set_accelerate_to_module(model): - if hasattr(model, "pipe"): - set_accelerate_to_module(model.pipe) - if hasattr(model, "_internal_dict"): - for k in model._internal_dict.keys(): # pylint: disable=protected-access - component = getattr(model, k, None) - if isinstance(component, torch.nn.Module): - component.has_accelerate = True - - -def set_accelerate(sd_model): - sd_model.has_accelerate = True - set_accelerate_to_module(sd_model) - if hasattr(sd_model, "prior_pipe"): - set_accelerate_to_module(sd_model.prior_pipe) - if hasattr(sd_model, "decoder_pipe"): - set_accelerate_to_module(sd_model.decoder_pipe) - - -def set_diffuser_offload(sd_model, op: str = 'model'): - t0 = time.time() - if not shared.native: - shared.log.warning('Attempting to use offload with backend=original') - return - if sd_model is None: - shared.log.warning(f'{op} is not loaded') - return - if not (hasattr(sd_model, "has_accelerate") and sd_model.has_accelerate): - sd_model.has_accelerate = False - if shared.opts.diffusers_offload_mode == "none": - if shared.sd_model_type in should_offload: - shared.log.warning(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} type={shared.sd_model.__class__.__name__} large model') - else: - shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') - if hasattr(sd_model, 'maybe_free_model_hooks'): - sd_model.maybe_free_model_hooks() - sd_model.has_accelerate = False - if shared.opts.diffusers_offload_mode == "model" and hasattr(sd_model, "enable_model_cpu_offload"): - try: - shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') - if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner: - shared.opts.diffusers_move_base = False - shared.opts.diffusers_move_unet = False - shared.opts.diffusers_move_refiner = False - shared.log.warning(f'Disabling {op} "Move model to CPU" since "Model CPU offload" is enabled') - if not hasattr(sd_model, "_all_hooks") or len(sd_model._all_hooks) == 0: # pylint: disable=protected-access - sd_model.enable_model_cpu_offload(device=devices.device) - else: - sd_model.maybe_free_model_hooks() - set_accelerate(sd_model) - except Exception as e: - shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}') - if shared.opts.diffusers_offload_mode == "sequential" and hasattr(sd_model, "enable_sequential_cpu_offload"): - try: - shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') - if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner: - shared.opts.diffusers_move_base = False - shared.opts.diffusers_move_unet = False - shared.opts.diffusers_move_refiner = False - shared.log.warning(f'Disabling {op} "Move model to CPU" since "Sequential CPU 
offload" is enabled') - if sd_model.has_accelerate: - if op == "vae": # reapply sequential offload to vae - from accelerate import cpu_offload - sd_model.vae.to("cpu") - cpu_offload(sd_model.vae, devices.device, offload_buffers=len(sd_model.vae._parameters) > 0) # pylint: disable=protected-access - else: - pass # do nothing if offload is already applied - else: - sd_model.enable_sequential_cpu_offload(device=devices.device) - set_accelerate(sd_model) - except Exception as e: - shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}') - if shared.opts.diffusers_offload_mode == "balanced": - sd_model = apply_balanced_offload(sd_model) - process_timer.add('offload', time.time() - t0) - - -class OffloadHook(accelerate.hooks.ModelHook): - def __init__(self, checkpoint_name): - if shared.opts.diffusers_offload_max_gpu_memory > 1: - shared.opts.diffusers_offload_max_gpu_memory = 0.75 - if shared.opts.diffusers_offload_max_cpu_memory > 1: - shared.opts.diffusers_offload_max_cpu_memory = 0.75 - self.checkpoint_name = checkpoint_name - self.min_watermark = shared.opts.diffusers_offload_min_gpu_memory - self.max_watermark = shared.opts.diffusers_offload_max_gpu_memory - self.cpu_watermark = shared.opts.diffusers_offload_max_cpu_memory - self.gpu = int(shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory * 1024*1024*1024) - self.cpu = int(shared.cpu_memory * shared.opts.diffusers_offload_max_cpu_memory * 1024*1024*1024) - self.offload_map = {} - self.param_map = {} - gpu = f'{shared.gpu_memory * shared.opts.diffusers_offload_min_gpu_memory:.3f}-{shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory}:{shared.gpu_memory}' - shared.log.info(f'Offload: type=balanced op=init watermark={self.min_watermark}-{self.max_watermark} gpu={gpu} cpu={shared.cpu_memory:.3f} limit={shared.opts.cuda_mem_fraction:.2f}') - self.validate() - super().__init__() - - def validate(self): - if shared.opts.diffusers_offload_mode != 'balanced': - return - if shared.opts.diffusers_offload_min_gpu_memory < 0 or shared.opts.diffusers_offload_min_gpu_memory > 1: - shared.opts.diffusers_offload_min_gpu_memory = 0.25 - shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} invalid value') - if shared.opts.diffusers_offload_max_gpu_memory < 0.1 or shared.opts.diffusers_offload_max_gpu_memory > 1: - shared.opts.diffusers_offload_max_gpu_memory = 0.75 - shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} invalid value') - if shared.opts.diffusers_offload_min_gpu_memory > shared.opts.diffusers_offload_max_gpu_memory: - shared.opts.diffusers_offload_min_gpu_memory = shared.opts.diffusers_offload_max_gpu_memory - shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} reset') - if shared.opts.diffusers_offload_max_gpu_memory * shared.gpu_memory < 4: - shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} low memory') - - def model_size(self): - return sum(self.offload_map.values()) - - def init_hook(self, module): - return module - - def pre_forward(self, module, *args, **kwargs): - if devices.normalize_device(module.device) != devices.normalize_device(devices.device): - device_index = torch.device(devices.device).index - if device_index is None: - device_index = 0 - max_memory = { device_index: self.gpu, "cpu": self.cpu } - device_map = 
getattr(module, "balanced_offload_device_map", None) - if device_map is None or max_memory != getattr(module, "balanced_offload_max_memory", None): - device_map = accelerate.infer_auto_device_map(module, max_memory=max_memory) - offload_dir = getattr(module, "offload_dir", os.path.join(shared.opts.accelerate_offload_path, module.__class__.__name__)) - module = accelerate.dispatch_model(module, device_map=device_map, offload_dir=offload_dir) - module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access - module.balanced_offload_device_map = device_map - module.balanced_offload_max_memory = max_memory - return args, kwargs - - def post_forward(self, module, output): - return output - - def detach_hook(self, module): - return module - - -def apply_balanced_offload(sd_model, exclude=[]): - global offload_hook_instance # pylint: disable=global-statement - if shared.opts.diffusers_offload_mode != "balanced": - return sd_model - t0 = time.time() - excluded = ['OmniGenPipeline'] - if sd_model.__class__.__name__ in excluded: - return sd_model - cached = True - checkpoint_name = sd_model.sd_checkpoint_info.name if getattr(sd_model, "sd_checkpoint_info", None) is not None else None - if checkpoint_name is None: - checkpoint_name = sd_model.__class__.__name__ - if offload_hook_instance is None or offload_hook_instance.min_watermark != shared.opts.diffusers_offload_min_gpu_memory or offload_hook_instance.max_watermark != shared.opts.diffusers_offload_max_gpu_memory or checkpoint_name != offload_hook_instance.checkpoint_name: - cached = False - offload_hook_instance = OffloadHook(checkpoint_name) - - def get_pipe_modules(pipe): - if hasattr(pipe, "_internal_dict"): - modules_names = pipe._internal_dict.keys() # pylint: disable=protected-access - else: - modules_names = get_signature(pipe).keys() - modules_names = [m for m in modules_names if m not in exclude and not m.startswith('_')] - modules = {} - for module_name in modules_names: - module_size = offload_hook_instance.offload_map.get(module_name, None) - if module_size is None: - module = getattr(pipe, module_name, None) - if not isinstance(module, torch.nn.Module): - continue - try: - module_size = sum(p.numel() * p.element_size() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024 - param_num = sum(p.numel() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024 - except Exception as e: - shared.log.error(f'Offload: type=balanced op=calc module={module_name} {e}') - module_size = 0 - offload_hook_instance.offload_map[module_name] = module_size - offload_hook_instance.param_map[module_name] = param_num - modules[module_name] = module_size - modules = sorted(modules.items(), key=lambda x: x[1], reverse=True) - return modules - - def apply_balanced_offload_to_module(pipe): - used_gpu, used_ram = devices.torch_gc(fast=True) - if hasattr(pipe, "pipe"): - apply_balanced_offload_to_module(pipe.pipe) - if hasattr(pipe, "_internal_dict"): - keys = pipe._internal_dict.keys() # pylint: disable=protected-access - else: - keys = get_signature(pipe).keys() - keys = [k for k in keys if k not in exclude and not k.startswith('_')] - for module_name, module_size in get_pipe_modules(pipe): # pylint: disable=protected-access - module = getattr(pipe, module_name, None) - if module is None: - continue - network_layer_name = getattr(module, "network_layer_name", None) - device_map = getattr(module, "balanced_offload_device_map", None) - max_memory = getattr(module, "balanced_offload_max_memory", None) - module = 
accelerate.hooks.remove_hook_from_module(module, recurse=True) - perc_gpu = used_gpu / shared.gpu_memory - try: - prev_gpu = used_gpu - do_offload = (perc_gpu > shared.opts.diffusers_offload_min_gpu_memory) and (module.device != devices.cpu) - if do_offload: - module = module.to(devices.cpu, non_blocking=True) - used_gpu -= module_size - if not cached: - shared.log.debug(f'Model module={module_name} type={module.__class__.__name__} dtype={module.dtype} quant={getattr(module, "quantization_method", None)} params={offload_hook_instance.param_map[module_name]:.3f} size={offload_hook_instance.offload_map[module_name]:.3f}') - debug_move(f'Offload: type=balanced op={"move" if do_offload else "skip"} gpu={prev_gpu:.3f}:{used_gpu:.3f} perc={perc_gpu:.2f} ram={used_ram:.3f} current={module.device} dtype={module.dtype} quant={getattr(module, "quantization_method", None)} module={module.__class__.__name__} size={module_size:.3f}') - except Exception as e: - if 'out of memory' in str(e): - devices.torch_gc(fast=True, force=True, reason='oom') - elif 'bitsandbytes' in str(e): - pass - else: - shared.log.error(f'Offload: type=balanced op=apply module={module_name} {e}') - if os.environ.get('SD_MOVE_DEBUG', None): - errors.display(e, f'Offload: type=balanced op=apply module={module_name}') - module.offload_dir = os.path.join(shared.opts.accelerate_offload_path, checkpoint_name, module_name) - module = accelerate.hooks.add_hook_to_module(module, offload_hook_instance, append=True) - module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access - if network_layer_name: - module.network_layer_name = network_layer_name - if device_map and max_memory: - module.balanced_offload_device_map = device_map - module.balanced_offload_max_memory = max_memory - devices.torch_gc(fast=True, force=True, reason='offload') - - apply_balanced_offload_to_module(sd_model) - if hasattr(sd_model, "pipe"): - apply_balanced_offload_to_module(sd_model.pipe) - if hasattr(sd_model, "prior_pipe"): - apply_balanced_offload_to_module(sd_model.prior_pipe) - if hasattr(sd_model, "decoder_pipe"): - apply_balanced_offload_to_module(sd_model.decoder_pipe) - set_accelerate(sd_model) - t = time.time() - t0 - process_timer.add('offload', t) - fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access - debug_move(f'Apply offload: time={t:.2f} type=balanced fn={fn}') - if not cached: - shared.log.info(f'Model class={sd_model.__class__.__name__} modules={len(offload_hook_instance.offload_map)} size={offload_hook_instance.model_size():.3f}') - return sd_model - - def move_model(model, device=None, force=False): if model is None or device is None: return diff --git a/modules/sd_offload.py b/modules/sd_offload.py new file mode 100644 index 000000000..33aea13a3 --- /dev/null +++ b/modules/sd_offload.py @@ -0,0 +1,261 @@ +import os +import sys +import time +import inspect +import torch +import accelerate +from modules import shared, devices, errors +from modules.timer import process as process_timer + + +debug_move = shared.log.trace if os.environ.get('SD_MOVE_DEBUG', None) is not None else lambda *args, **kwargs: None +should_offload = ['sc', 'sd3', 'f1', 'hunyuandit', 'auraflow', 'omnigen'] +offload_hook_instance = None + + +def get_signature(cls): + signature = inspect.signature(cls.__init__, follow_wrapped=True, eval_str=True) + return signature.parameters + + +def set_accelerate(sd_model): + def set_accelerate_to_module(model): + if hasattr(model, "pipe"): + 
set_accelerate_to_module(model.pipe) + if hasattr(model, "_internal_dict"): + for k in model._internal_dict.keys(): # pylint: disable=protected-access + component = getattr(model, k, None) + if isinstance(component, torch.nn.Module): + component.has_accelerate = True + + sd_model.has_accelerate = True + set_accelerate_to_module(sd_model) + if hasattr(sd_model, "prior_pipe"): + set_accelerate_to_module(sd_model.prior_pipe) + if hasattr(sd_model, "decoder_pipe"): + set_accelerate_to_module(sd_model.decoder_pipe) + + +def set_diffuser_offload(sd_model, op: str = 'model'): + t0 = time.time() + if not shared.native: + shared.log.warning('Attempting to use offload with backend=original') + return + if sd_model is None: + shared.log.warning(f'{op} is not loaded') + return + if not (hasattr(sd_model, "has_accelerate") and sd_model.has_accelerate): + sd_model.has_accelerate = False + if shared.opts.diffusers_offload_mode == "none": + if shared.sd_model_type in should_offload: + shared.log.warning(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} type={shared.sd_model.__class__.__name__} large model') + else: + shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') + if hasattr(sd_model, 'maybe_free_model_hooks'): + sd_model.maybe_free_model_hooks() + sd_model.has_accelerate = False + if shared.opts.diffusers_offload_mode == "model" and hasattr(sd_model, "enable_model_cpu_offload"): + try: + shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') + if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner: + shared.opts.diffusers_move_base = False + shared.opts.diffusers_move_unet = False + shared.opts.diffusers_move_refiner = False + shared.log.warning(f'Disabling {op} "Move model to CPU" since "Model CPU offload" is enabled') + if not hasattr(sd_model, "_all_hooks") or len(sd_model._all_hooks) == 0: # pylint: disable=protected-access + sd_model.enable_model_cpu_offload(device=devices.device) + else: + sd_model.maybe_free_model_hooks() + set_accelerate(sd_model) + except Exception as e: + shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}') + if shared.opts.diffusers_offload_mode == "sequential" and hasattr(sd_model, "enable_sequential_cpu_offload"): + try: + shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') + if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner: + shared.opts.diffusers_move_base = False + shared.opts.diffusers_move_unet = False + shared.opts.diffusers_move_refiner = False + shared.log.warning(f'Disabling {op} "Move model to CPU" since "Sequential CPU offload" is enabled') + if sd_model.has_accelerate: + if op == "vae": # reapply sequential offload to vae + from accelerate import cpu_offload + sd_model.vae.to("cpu") + cpu_offload(sd_model.vae, devices.device, offload_buffers=len(sd_model.vae._parameters) > 0) # pylint: disable=protected-access + else: + pass # do nothing if offload is already applied + else: + sd_model.enable_sequential_cpu_offload(device=devices.device) + set_accelerate(sd_model) + except Exception as e: + shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}') + if shared.opts.diffusers_offload_mode == "balanced": + sd_model = apply_balanced_offload(sd_model) + process_timer.add('offload', time.time() 
- t0) + + +class OffloadHook(accelerate.hooks.ModelHook): + def __init__(self, checkpoint_name): + if shared.opts.diffusers_offload_max_gpu_memory > 1: + shared.opts.diffusers_offload_max_gpu_memory = 0.75 + if shared.opts.diffusers_offload_max_cpu_memory > 1: + shared.opts.diffusers_offload_max_cpu_memory = 0.75 + self.checkpoint_name = checkpoint_name + self.min_watermark = shared.opts.diffusers_offload_min_gpu_memory + self.max_watermark = shared.opts.diffusers_offload_max_gpu_memory + self.cpu_watermark = shared.opts.diffusers_offload_max_cpu_memory + self.gpu = int(shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory * 1024*1024*1024) + self.cpu = int(shared.cpu_memory * shared.opts.diffusers_offload_max_cpu_memory * 1024*1024*1024) + self.offload_map = {} + self.param_map = {} + gpu = f'{shared.gpu_memory * shared.opts.diffusers_offload_min_gpu_memory:.3f}-{shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory}:{shared.gpu_memory}' + shared.log.info(f'Offload: type=balanced op=init watermark={self.min_watermark}-{self.max_watermark} gpu={gpu} cpu={shared.cpu_memory:.3f} limit={shared.opts.cuda_mem_fraction:.2f}') + self.validate() + super().__init__() + + def validate(self): + if shared.opts.diffusers_offload_mode != 'balanced': + return + if shared.opts.diffusers_offload_min_gpu_memory < 0 or shared.opts.diffusers_offload_min_gpu_memory > 1: + shared.opts.diffusers_offload_min_gpu_memory = 0.25 + shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} invalid value') + if shared.opts.diffusers_offload_max_gpu_memory < 0.1 or shared.opts.diffusers_offload_max_gpu_memory > 1: + shared.opts.diffusers_offload_max_gpu_memory = 0.75 + shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} invalid value') + if shared.opts.diffusers_offload_min_gpu_memory > shared.opts.diffusers_offload_max_gpu_memory: + shared.opts.diffusers_offload_min_gpu_memory = shared.opts.diffusers_offload_max_gpu_memory + shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} reset') + if shared.opts.diffusers_offload_max_gpu_memory * shared.gpu_memory < 4: + shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} low memory') + + def model_size(self): + return sum(self.offload_map.values()) + + def init_hook(self, module): + return module + + def pre_forward(self, module, *args, **kwargs): + if devices.normalize_device(module.device) != devices.normalize_device(devices.device): + device_index = torch.device(devices.device).index + if device_index is None: + device_index = 0 + max_memory = { device_index: self.gpu, "cpu": self.cpu } + device_map = getattr(module, "balanced_offload_device_map", None) + if device_map is None or max_memory != getattr(module, "balanced_offload_max_memory", None): + device_map = accelerate.infer_auto_device_map(module, max_memory=max_memory) + offload_dir = getattr(module, "offload_dir", os.path.join(shared.opts.accelerate_offload_path, module.__class__.__name__)) + module = accelerate.dispatch_model(module, device_map=device_map, offload_dir=offload_dir) + module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access + module.balanced_offload_device_map = device_map + module.balanced_offload_max_memory = max_memory + return args, kwargs + + def post_forward(self, module, output): 
+ return output + + def detach_hook(self, module): + return module + + +def apply_balanced_offload(sd_model, exclude=[]): + global offload_hook_instance # pylint: disable=global-statement + if shared.opts.diffusers_offload_mode != "balanced": + return sd_model + t0 = time.time() + excluded = ['OmniGenPipeline'] + if sd_model.__class__.__name__ in excluded: + return sd_model + cached = True + checkpoint_name = sd_model.sd_checkpoint_info.name if getattr(sd_model, "sd_checkpoint_info", None) is not None else None + if checkpoint_name is None: + checkpoint_name = sd_model.__class__.__name__ + if offload_hook_instance is None or offload_hook_instance.min_watermark != shared.opts.diffusers_offload_min_gpu_memory or offload_hook_instance.max_watermark != shared.opts.diffusers_offload_max_gpu_memory or checkpoint_name != offload_hook_instance.checkpoint_name: + cached = False + offload_hook_instance = OffloadHook(checkpoint_name) + + def get_pipe_modules(pipe): + if hasattr(pipe, "_internal_dict"): + modules_names = pipe._internal_dict.keys() # pylint: disable=protected-access + else: + modules_names = get_signature(pipe).keys() + modules_names = [m for m in modules_names if m not in exclude and not m.startswith('_')] + modules = {} + for module_name in modules_names: + module_size = offload_hook_instance.offload_map.get(module_name, None) + if module_size is None: + module = getattr(pipe, module_name, None) + if not isinstance(module, torch.nn.Module): + continue + try: + module_size = sum(p.numel() * p.element_size() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024 + param_num = sum(p.numel() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024 + except Exception as e: + shared.log.error(f'Offload: type=balanced op=calc module={module_name} {e}') + module_size = 0 + offload_hook_instance.offload_map[module_name] = module_size + offload_hook_instance.param_map[module_name] = param_num + modules[module_name] = module_size + modules = sorted(modules.items(), key=lambda x: x[1], reverse=True) + return modules + + def apply_balanced_offload_to_module(pipe): + used_gpu, used_ram = devices.torch_gc(fast=True) + if hasattr(pipe, "pipe"): + apply_balanced_offload_to_module(pipe.pipe) + if hasattr(pipe, "_internal_dict"): + keys = pipe._internal_dict.keys() # pylint: disable=protected-access + else: + keys = get_signature(pipe).keys() + keys = [k for k in keys if k not in exclude and not k.startswith('_')] + for module_name, module_size in get_pipe_modules(pipe): # pylint: disable=protected-access + module = getattr(pipe, module_name, None) + if module is None: + continue + network_layer_name = getattr(module, "network_layer_name", None) + device_map = getattr(module, "balanced_offload_device_map", None) + max_memory = getattr(module, "balanced_offload_max_memory", None) + module = accelerate.hooks.remove_hook_from_module(module, recurse=True) + perc_gpu = used_gpu / shared.gpu_memory + try: + prev_gpu = used_gpu + do_offload = (perc_gpu > shared.opts.diffusers_offload_min_gpu_memory) and (module.device != devices.cpu) + if do_offload: + module = module.to(devices.cpu, non_blocking=True) + used_gpu -= module_size + if not cached: + shared.log.debug(f'Model module={module_name} type={module.__class__.__name__} dtype={module.dtype} quant={getattr(module, "quantization_method", None)} params={offload_hook_instance.param_map[module_name]:.3f} size={offload_hook_instance.offload_map[module_name]:.3f}') + debug_move(f'Offload: type=balanced op={"move" if do_offload else "skip"} 
gpu={prev_gpu:.3f}:{used_gpu:.3f} perc={perc_gpu:.2f} ram={used_ram:.3f} current={module.device} dtype={module.dtype} quant={getattr(module, "quantization_method", None)} module={module.__class__.__name__} size={module_size:.3f}') + except Exception as e: + if 'out of memory' in str(e): + devices.torch_gc(fast=True, force=True, reason='oom') + elif 'bitsandbytes' in str(e): + pass + else: + shared.log.error(f'Offload: type=balanced op=apply module={module_name} {e}') + if os.environ.get('SD_MOVE_DEBUG', None): + errors.display(e, f'Offload: type=balanced op=apply module={module_name}') + module.offload_dir = os.path.join(shared.opts.accelerate_offload_path, checkpoint_name, module_name) + module = accelerate.hooks.add_hook_to_module(module, offload_hook_instance, append=True) + module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access + if network_layer_name: + module.network_layer_name = network_layer_name + if device_map and max_memory: + module.balanced_offload_device_map = device_map + module.balanced_offload_max_memory = max_memory + devices.torch_gc(fast=True, force=True, reason='offload') + + apply_balanced_offload_to_module(sd_model) + if hasattr(sd_model, "pipe"): + apply_balanced_offload_to_module(sd_model.pipe) + if hasattr(sd_model, "prior_pipe"): + apply_balanced_offload_to_module(sd_model.prior_pipe) + if hasattr(sd_model, "decoder_pipe"): + apply_balanced_offload_to_module(sd_model.decoder_pipe) + set_accelerate(sd_model) + t = time.time() - t0 + process_timer.add('offload', t) + fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access + debug_move(f'Apply offload: time={t:.2f} type=balanced fn={fn}') + if not cached: + shared.log.info(f'Model class={sd_model.__class__.__name__} modules={len(offload_hook_instance.offload_map)} size={offload_hook_instance.model_size():.3f}') + return sd_model
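Note for reviewers: the offload helpers removed from modules/sd_models.py above are moved into the new modules/sd_offload.py with their public names unchanged, and the updated import in sd_models.py (marked pylint unused-import) effectively re-exports them so existing callers keep working. A minimal sketch of the expected call pattern, assuming a loaded diffusers pipeline in `pipe`; the wrapper function and variable names below are illustrative and not part of this change:

# illustrative sketch only: how the relocated helpers are consumed after this change;
# `pipe` is any loaded diffusers pipeline, `finish_loading` is a hypothetical caller
from modules import shared
from modules.sd_offload import set_diffuser_offload, apply_balanced_offload, set_accelerate  # pylint: disable=unused-import

def finish_loading(pipe):
    # configures "model"/"sequential" CPU offload according to shared.opts.diffusers_offload_mode
    # and dispatches to apply_balanced_offload() itself when the mode is "balanced"
    set_diffuser_offload(pipe, op='model')
    return pipe

# balanced offload can also be re-applied directly, e.g. between pipeline stages;
# apply_balanced_offload() returns the pipeline unchanged unless
# shared.opts.diffusers_offload_mode == "balanced"
# pipe = apply_balanced_offload(pipe)

Keeping the entry-point names identical makes the module split transparent to scripts that previously imported them from modules.sd_models.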