From 25babb7075258239c1aaae5bd104be8f7ca6e17b Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Tue, 7 Jan 2025 17:13:14 -0500
Subject: [PATCH] flux bnb on-the-fly for unet-only

Signed-off-by: Vladimir Mandic
---
 CHANGELOG.md              |  1 +
 modules/k-diffusion       |  2 +-
 modules/model_flux.py     | 34 ++++++++++++++++++++++------------
 modules/model_flux_nf4.py | 18 ++++--------------
 modules/sd_models.py      |  2 +-
 5 files changed, 29 insertions(+), 28 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 59993200f..4edcf4603 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -51,6 +51,7 @@
   - apply settings skip hidden settings
   - lora diffusers method apply only once
   - lora diffusers method set prompt tags and metadata
+  - flux support on-the-fly quantization for bnb of unet only
 
 ## Update for 2024-12-31
 
diff --git a/modules/k-diffusion b/modules/k-diffusion
index 21d12c91a..8018de0b4 160000
--- a/modules/k-diffusion
+++ b/modules/k-diffusion
@@ -1 +1 @@
-Subproject commit 21d12c91ad4550e8fcf3308ff9fe7116b3f19a08
+Subproject commit 8018de0b43da8d66617f3ef10d3f2a41c1d78836
diff --git a/modules/model_flux.py b/modules/model_flux.py
index 9b42de3e7..f8d112953 100644
--- a/modules/model_flux.py
+++ b/modules/model_flux.py
@@ -84,7 +84,7 @@ def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unu
         repo_path = checkpoint_info
     else:
         repo_path = checkpoint_info.path
-    model_quant.load_bnb('Load model: type=T5')
+    model_quant.load_bnb('Load model: type=FLUX')
     quant = model_quant.get_quant(repo_path)
     try:
         if quant == 'fp8':
@@ -203,7 +203,8 @@ def load_transformer(file_path): # triggered by opts.sd_unet change
         "torch_dtype": devices.dtype,
         "cache_dir": shared.opts.hfcache_dir,
     }
-    shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant={quant} dtype={devices.dtype}')
+    if quant is not None and quant != 'none':
+        shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} prequant={quant} dtype={devices.dtype}')
     if 'gguf' in file_path.lower():
         # _transformer, _text_encoder_2 = load_flux_gguf(file_path)
         from modules import ggml
@@ -214,25 +215,34 @@ def load_transformer(file_path): # triggered by opts.sd_unet change
         _transformer, _text_encoder_2 = load_flux_quanto(file_path)
         if _transformer is not None:
             transformer = _transformer
-    elif quant == 'fp8' or quant == 'fp4' or quant == 'nf4' or 'Model' in shared.opts.bnb_quantization:
+    elif quant == 'fp8' or quant == 'fp4' or quant == 'nf4':
         _transformer, _text_encoder_2 = load_flux_bnb(file_path, diffusers_load_config)
         if _transformer is not None:
             transformer = _transformer
     elif 'nf4' in quant: # TODO flux: fix loader for civitai nf4 models
         from modules.model_flux_nf4 import load_flux_nf4
-        _transformer, _text_encoder_2 = load_flux_nf4(file_path)
+        _transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=True)
         if _transformer is not None:
             transformer = _transformer
     else:
-        quant_args = {}
-        quant_args = model_quant.create_bnb_config(quant_args)
+        quant_args = model_quant.create_bnb_config({})
         if quant_args:
-            model_quant.load_bnb(f'Load model: type=Sana quant={quant_args}')
-        if not quant_args:
-            quant_args = model_quant.create_ao_config(quant_args)
-        if quant_args:
-            model_quant.load_torchao(f'Load model: type=Sana quant={quant_args}')
-        transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **diffusers_load_config, **quant_args)
+            model_quant.load_bnb(f'Load model: type=FLUX quant={quant_args}')
+            shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=bnb dtype={devices.dtype}')
+            from modules.model_flux_nf4 import load_flux_nf4
+            transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=False)
+            if transformer is not None:
+                return transformer
+        quant_args = model_quant.create_ao_config({})
+        if quant_args:
+            model_quant.load_torchao(f'Load model: type=FLUX quant={quant_args}')
+            shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=torchao dtype={devices.dtype}')
+            transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **diffusers_load_config, **quant_args)
+            if transformer is not None:
+                return transformer
+        shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=none dtype={devices.dtype}')
+        # shared.log.warning('Load module: type=UNet/Transformer does not support load-time quantization') # TODO flux transformer from-single-file with quant
+        transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **diffusers_load_config)
     if transformer is None:
         shared.log.error('Failed to load UNet model')
         shared.opts.sd_unet = 'None'
diff --git a/modules/model_flux_nf4.py b/modules/model_flux_nf4.py
index d023907d6..b00c3320e 100644
--- a/modules/model_flux_nf4.py
+++ b/modules/model_flux_nf4.py
@@ -24,7 +24,6 @@ def _replace_with_bnb_linear(
 ):
     """
     Private method that wraps the recursion for module replacement.
 
-    Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
     """
     bnb = model_quant.load_bnb('Load model: type=FLUX')
@@ -106,7 +105,6 @@ def create_quantized_param(
             new_value = old_value.to(target_device)
         else:
             new_value = param_value.to(target_device)
-
         new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
         module._parameters[tensor_name] = new_value # pylint: disable=protected-access
         return
@@ -121,13 +119,8 @@ def create_quantized_param(
         raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")
 
     if pre_quantized:
-        if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
-            param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
-        ):
-            raise ValueError(
-                f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
-            )
-
+        if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (param_name + ".quant_state.bitsandbytes__nf4" not in state_dict):
+            raise ValueError(f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components.")
         quantized_stats = {}
         for k, v in state_dict.items():
             # `startswith` to counter for edge cases where `param_name`
             # substring can be present in multiple places in the state_dict
             if param_name + "." in k and k.startswith(param_name):
                 quantized_stats[k] = v
                 if unexpected_keys is not None and k in unexpected_keys:
                     unexpected_keys.remove(k)
-
         new_value = bnb.nn.Params4bit.from_prequantized(
             data=param_value,
             quantized_stats=quantized_stats,
             requires_grad=False,
             device=target_device,
         )
-
     else:
         new_value = param_value.to("cpu")
         kwargs = old_value.__dict__
         new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
-
     module._parameters[tensor_name] = new_value # pylint: disable=protected-access
 
 
-def load_flux_nf4(checkpoint_info):
+def load_flux_nf4(checkpoint_info, prequantized: bool = True):
     transformer = None
     text_encoder_2 = None
     if isinstance(checkpoint_info, str):
@@ -197,7 +187,7 @@ def load_flux_nf4(checkpoint_info):
                 if not check_quantized_param(transformer, param_name):
                     set_module_tensor_to_device(transformer, param_name, device=0, value=param)
                 else:
-                    create_quantized_param(transformer, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=True)
+                    create_quantized_param(transformer, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=prequantized)
             except Exception as e:
                 transformer, text_encoder_2 = None, None
                 shared.log.error(f"Load model: type=FLUX failed to load UNET: {e}")
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 8b0eb0a89..1916fa7ec 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1613,7 +1613,7 @@ def disable_offload(sd_model):
 
 
 def clear_caches():
-    shared.log.debug('Cache clear')
+    # shared.log.debug('Cache clear')
     if not shared.opts.lora_legacy:
         from modules.lora import networks
         networks.loaded_networks.clear()
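Note on the on-the-fly quantization path added above: when `load_flux_nf4` is called with `prequantized=False`, each full-precision transformer weight is wrapped in `bnb.nn.Params4bit` and the subsequent move to the target device performs the actual 4-bit quantization. The sketch below shows that mechanism in isolation; it is a minimal, hypothetical example that assumes `bitsandbytes` and a CUDA device are available, and the tensor shape is illustrative rather than taken from a FLUX checkpoint.

```python
import torch
import bitsandbytes as bnb

# Illustrative stand-in for one fp16 weight tensor read from the single-file checkpoint
w = torch.randn(3072, 64, dtype=torch.float16)

# Wrap the full-precision tensor; no quantization has happened yet
w4 = bnb.nn.Params4bit(w, requires_grad=False, quant_type='nf4')

# Moving the parameter to the GPU performs the nf4 quantization on the fly,
# mirroring the .to(target_device) call in the prequantized=False branch of create_quantized_param
w4 = w4.to('cuda')
print(w4.dtype, tuple(w4.shape))                    # packed uint8 storage
print(w4.quant_state.dtype, w4.quant_state.shape)   # original dtype and shape kept in the quant state

# Round-trip to show the weight is recoverable with a small quantization error
w_restored = bnb.functional.dequantize_4bit(w4.data, w4.quant_state)
print((w_restored - w.to('cuda')).abs().mean().item())
```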