flux bnb on-the-fly for unet-only
Signed-off-by: Vladimir Mandic <[email protected]>
vladmandic committed Jan 7, 2025
1 parent 685386a commit 25babb7
Showing 5 changed files with 29 additions and 28 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -51,6 +51,7 @@
- apply settings skip hidden settings
- lora diffusers method apply only once
- lora diffusers method set prompt tags and metadata
- flux support on-the-fly quantization for bnb of unet only

## Update for 2024-12-31

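The changelog entry above ("flux support on-the-fly quantization for bnb of unet only") means the Flux transformer (unet) is quantized with bitsandbytes while it is being loaded, while the text encoders and VAE stay in full precision. As a point of reference only, a minimal sketch of that end state through the public diffusers API; the repo id and dtype are illustrative, and the commit itself implements the quantization through modules/model_flux_nf4.py rather than through a quantization_config (assumes a recent diffusers build with bitsandbytes installed):

```python
# Hedged sketch: on-the-fly nf4 quantization of only the Flux transformer via diffusers.
# Repo id is illustrative; requires diffusers with bitsandbytes support.
import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",        # illustrative repo id
    subfolder="transformer",
    quantization_config=bnb_config,        # only this module gets quantized
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,               # text encoders and VAE stay unquantized
    torch_dtype=torch.bfloat16,
)
```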
2 changes: 1 addition & 1 deletion modules/k-diffusion
34 changes: 22 additions & 12 deletions modules/model_flux.py
@@ -84,7 +84,7 @@ def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unu
repo_path = checkpoint_info
else:
repo_path = checkpoint_info.path
model_quant.load_bnb('Load model: type=T5')
model_quant.load_bnb('Load model: type=FLUX')
quant = model_quant.get_quant(repo_path)
try:
if quant == 'fp8':
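The branch above keys off model_quant.get_quant(repo_path), which is not shown in this diff. A purely hypothetical approximation of such a filename-based detection, for illustration only:

```python
# Hypothetical illustration only -- the real detection lives in modules/model_quant.py
# and may use different rules.
def guess_quant_from_path(path: str) -> str:
    name = path.lower()
    if name.endswith('.gguf'):
        return 'gguf'
    for tag in ('qint8', 'qint4', 'nf4', 'fp4', 'fp8'):
        if tag in name:
            return tag
    return 'none'
```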
@@ -203,7 +203,8 @@ def load_transformer(file_path): # triggered by opts.sd_unet change
"torch_dtype": devices.dtype,
"cache_dir": shared.opts.hfcache_dir,
}
shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant={quant} dtype={devices.dtype}')
if quant is not None and quant != 'none':
shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} prequant={quant} dtype={devices.dtype}')
if 'gguf' in file_path.lower():
# _transformer, _text_encoder_2 = load_flux_gguf(file_path)
from modules import ggml
@@ -214,25 +215,34 @@ def load_transformer(file_path): # triggered by opts.sd_unet change
_transformer, _text_encoder_2 = load_flux_quanto(file_path)
if _transformer is not None:
transformer = _transformer
elif quant == 'fp8' or quant == 'fp4' or quant == 'nf4' or 'Model' in shared.opts.bnb_quantization:
elif quant == 'fp8' or quant == 'fp4' or quant == 'nf4':
_transformer, _text_encoder_2 = load_flux_bnb(file_path, diffusers_load_config)
if _transformer is not None:
transformer = _transformer
elif 'nf4' in quant: # TODO flux: fix loader for civitai nf4 models
from modules.model_flux_nf4 import load_flux_nf4
_transformer, _text_encoder_2 = load_flux_nf4(file_path)
_transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=True)
if _transformer is not None:
transformer = _transformer
else:
quant_args = {}
quant_args = model_quant.create_bnb_config(quant_args)
quant_args = model_quant.create_bnb_config({})
if quant_args:
model_quant.load_bnb(f'Load model: type=Sana quant={quant_args}')
if not quant_args:
quant_args = model_quant.create_ao_config(quant_args)
if quant_args:
model_quant.load_torchao(f'Load model: type=Sana quant={quant_args}')
transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **diffusers_load_config, **quant_args)
model_quant.load_bnb(f'Load model: type=FLUX quant={quant_args}')
shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=bnb dtype={devices.dtype}')
from modules.model_flux_nf4 import load_flux_nf4
transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=False)
if transformer is not None:
return transformer
quant_args = model_quant.create_ao_config({})
if quant_args:
model_quant.load_torchao(f'Load model: type=FLUX quant={quant_args}')
shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=torchao dtype={devices.dtype}')
transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **diffusers_load_config, **quant_args)
if transformer is not None:
return transformer
shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=none dtype={devices.dtype}')
# shared.log.warning('Load module: type=UNet/Transformer does not support load-time quantization') # TODO flux transformer from-single-file with quant
transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **diffusers_load_config)
if transformer is None:
shared.log.error('Failed to load UNet model')
shared.opts.sd_unet = 'None'
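Taken together, the hunk above reorders load_transformer so that on-the-fly quantization is attempted only when no pre-quantized checkpoint is detected. A condensed, non-authoritative summary of the resulting branch order; the identifiers come from the diff, and logging, offload handling and failure fallthrough are omitted:

```python
# Condensed summary of the new load_transformer branch order; assumes the sd.next
# source tree for these imports and skips logging, offload and error handling.
import diffusers
from modules import model_quant
from modules.model_flux import load_flux_quanto, load_flux_bnb
from modules.model_flux_nf4 import load_flux_nf4

def load_transformer_summary(file_path, quant, load_config):
    transformer = None
    if 'gguf' in file_path.lower():
        ...                                                            # GGUF branch via modules.ggml (call not shown in this hunk)
    elif quant in ('qint8', 'qint4'):
        transformer, _ = load_flux_quanto(file_path)                   # pre-quantized quanto checkpoint
    elif quant in ('fp8', 'fp4', 'nf4'):
        transformer, _ = load_flux_bnb(file_path, load_config)         # pre-quantized bnb checkpoint
    elif 'nf4' in quant:
        transformer, _ = load_flux_nf4(file_path, prequantized=True)   # civitai-style pre-quantized nf4
    elif model_quant.create_bnb_config({}):                            # bnb quantization enabled in settings
        transformer, _ = load_flux_nf4(file_path, prequantized=False)  # new path: quantize the unet on the fly
    elif (ao_args := model_quant.create_ao_config({})):                # torchao quantization enabled in settings
        transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **load_config, **ao_args)
    else:
        transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **load_config)  # unquantized
    return transformer
```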
18 changes: 4 additions & 14 deletions modules/model_flux_nf4.py
@@ -24,7 +24,6 @@ def _replace_with_bnb_linear(
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
"""
bnb = model_quant.load_bnb('Load model: type=FLUX')
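_replace_with_bnb_linear (truncated above) swaps eligible torch.nn.Linear modules for their bitsandbytes 4-bit counterparts before the state dict is loaded. A small self-contained sketch of that replacement step, assuming bitsandbytes is installed; the real recursion also handles exclusions and nested-module details:

```python
# Minimal sketch of Linear -> Linear4bit replacement; illustrative only, the real
# implementation in modules/model_flux_nf4.py handles modules_to_not_convert and more.
import torch
import bitsandbytes as bnb

def replace_linear_with_4bit(module: torch.nn.Module) -> torch.nn.Module:
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            new_layer = bnb.nn.Linear4bit(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                compute_dtype=torch.bfloat16,
                quant_type='nf4',
            )
            setattr(module, name, new_layer)   # weights are filled in and quantized later
        else:
            replace_linear_with_4bit(child)
    return module
```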
@@ -106,7 +105,6 @@ def create_quantized_param(
new_value = old_value.to(target_device)
else:
new_value = param_value.to(target_device)

new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
module._parameters[tensor_name] = new_value # pylint: disable=protected-access
return
@@ -121,13 +119,8 @@ def create_quantized_param(
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")

if pre_quantized:
if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
):
raise ValueError(
f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
)

if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (param_name + ".quant_state.bitsandbytes__nf4" not in state_dict):
raise ValueError(f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components.")
quantized_stats = {}
for k, v in state_dict.items():
# `startswith` to counter for edge cases where `param_name`
@@ -136,23 +129,20 @@
quantized_stats[k] = v
if unexpected_keys is not None and k in unexpected_keys:
unexpected_keys.remove(k)

new_value = bnb.nn.Params4bit.from_prequantized(
data=param_value,
quantized_stats=quantized_stats,
requires_grad=False,
device=target_device,
)

else:
new_value = param_value.to("cpu")
kwargs = old_value.__dict__
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)

module._parameters[tensor_name] = new_value # pylint: disable=protected-access
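The pre_quantized flag selects between restoring packed 4-bit data from a checkpoint and quantizing a full-precision tensor at load time, which is what the new load_flux_nf4 signature below threads through. An illustrative contrast of the two bitsandbytes paths, assuming bitsandbytes and a CUDA device; the tensor shape is arbitrary:

```python
# Illustrative contrast of the two Params4bit paths; requires bitsandbytes and CUDA.
import torch
import bitsandbytes as bnb

weight = torch.randn(128, 256, dtype=torch.bfloat16)   # stand-in for one transformer weight

# pre_quantized=False: wrap the full-precision tensor and quantize when moving it to the GPU
on_the_fly = bnb.nn.Params4bit(weight, requires_grad=False, quant_type='nf4').to('cuda')

# pre_quantized=True: rebuild the parameter from packed data plus its saved quant state,
# which nf4 checkpoints store under '<param>.quant_state.bitsandbytes__nf4' keys
stats = on_the_fly.quant_state.as_dict(packed=True)
restored = bnb.nn.Params4bit.from_prequantized(
    data=on_the_fly.data,
    quantized_stats=stats,
    requires_grad=False,
    device='cuda',
)
```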


def load_flux_nf4(checkpoint_info):
def load_flux_nf4(checkpoint_info, prequantized: bool = True):
transformer = None
text_encoder_2 = None
if isinstance(checkpoint_info, str):
@@ -197,7 +187,7 @@ def load_flux_nf4(checkpoint_info):
if not check_quantized_param(transformer, param_name):
set_module_tensor_to_device(transformer, param_name, device=0, value=param)
else:
create_quantized_param(transformer, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=True)
create_quantized_param(transformer, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=prequantized)
except Exception as e:
transformer, text_encoder_2 = None, None
shared.log.error(f"Load model: type=FLUX failed to load UNET: {e}")
2 changes: 1 addition & 1 deletion modules/sd_models.py
@@ -1613,7 +1613,7 @@ def disable_offload(sd_model):


def clear_caches():
shared.log.debug('Cache clear')
# shared.log.debug('Cache clear')
if not shared.opts.lora_legacy:
from modules.lora import networks
networks.loaded_networks.clear()
