Merge pull request #3519 from vladmandic/dev
refresh master
vladmandic authored Oct 24, 2024
2 parents 3ba9ebc + 0d5d2c2 commit 06dd9e5
Showing 4 changed files with 51 additions and 14 deletions.
6 changes: 3 additions & 3 deletions modules/ipadapter.py
@@ -113,9 +113,6 @@ def unapply(pipe): # pylint: disable=arguments-differ

 def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapter_scales=[1.0], adapter_crops=[False], adapter_starts=[0.0], adapter_ends=[1.0], adapter_images=[]):
     global clip_loaded # pylint: disable=global-statement
-    if shared.sd_model_type != 'sd' and shared.sd_model_type != 'sdxl':
-        shared.log.error(f'IP adapter: model={shared.sd_model_type} class={pipe.__class__.__name__} not supported')
-        return False
     # overrides
     if hasattr(p, 'ip_adapter_names'):
         if isinstance(p.ip_adapter_names, str):
@@ -132,6 +129,9 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapter_scales=[1.0], adapter_crops=[False], adapter_starts=[0.0], adapter_ends=[1.0], adapter_images=[]):
         if hasattr(p, 'ip_adapter_images'):
             del p.ip_adapter_images
         return False
+    if shared.sd_model_type != 'sd' and shared.sd_model_type != 'sdxl':
+        shared.log.error(f'IP adapter: model={shared.sd_model_type} class={pipe.__class__.__name__} not supported')
+        return False
     if hasattr(p, 'ip_adapter_scales'):
         adapter_scales = p.ip_adapter_scales
     if hasattr(p, 'ip_adapter_crops'):
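Note on this hunk: the supported-model guard now runs after the override and unload handling, so a request to unload IP adapters is honored even when the active model type is unsupported. A minimal sketch of the resulting control flow, with resolve_overrides() as a hypothetical stand-in for the inlined overrides block:

def apply_control_flow(pipe, p, shared):
    # hypothetical helper folding the 'overrides' block above; returns True when
    # the request resolves to unloading adapters and has already cleaned up p
    if resolve_overrides(p):
        return False
    # support check second: the unload path above works for any model type
    if shared.sd_model_type not in ('sd', 'sdxl'):
        shared.log.error(f'IP adapter: model={shared.sd_model_type} class={pipe.__class__.__name__} not supported')
        return False
    # ... proceed to apply adapters as before ...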
53 changes: 44 additions & 9 deletions modules/model_sd3.py
@@ -1,15 +1,21 @@
 import os
 import diffusers
 import transformers
-from modules import shared, devices, sd_models, sd_unet
+from modules import shared, devices, sd_models, sd_unet, model_te


 def load_overrides(kwargs, cache_dir):
     if shared.opts.sd_unet != 'None':
         try:
             fn = sd_unet.unet_dict[shared.opts.sd_unet]
-            kwargs['transformer'] = diffusers.SD3Transformer2DModel.from_single_file(fn, cache_dir=cache_dir, torch_dtype=devices.dtype)
-            shared.log.debug(f'Load model: type=SD3 unet="{shared.opts.sd_unet}"')
+            if fn.endswith('.safetensors'):
+                kwargs['transformer'] = diffusers.SD3Transformer2DModel.from_single_file(fn, cache_dir=cache_dir, torch_dtype=devices.dtype)
+                sd_unet.loaded_unet = shared.opts.sd_unet
+                shared.log.debug(f'Load model: type=SD3 unet="{shared.opts.sd_unet}" fmt=safetensors')
+            elif fn.endswith('.gguf'):
+                kwargs = load_gguf(kwargs, fn)
+                sd_unet.loaded_unet = shared.opts.sd_unet
+                shared.log.debug(f'Load model: type=SD3 unet="{shared.opts.sd_unet}" fmt=gguf')
         except Exception as e:
             shared.log.error(f"Load model: type=SD3 failed to load UNet: {e}")
             shared.opts.sd_unet = 'None'
@@ -84,6 +90,30 @@ def load_missing(kwargs, fn, cache_dir):
     return kwargs


+def load_gguf(kwargs, fn):
+    model_te.install_gguf()
+    from accelerate import init_empty_weights
+    from diffusers.loaders.single_file_utils import convert_sd3_transformer_checkpoint_to_diffusers
+    from modules import ggml, sd_hijack_accelerate
+    with init_empty_weights():
+        config = diffusers.SD3Transformer2DModel.load_config(os.path.join('configs', 'flux'), subfolder="transformer")
+        transformer = diffusers.SD3Transformer2DModel.from_config(config).to(devices.dtype)
+        expected_state_dict_keys = list(transformer.state_dict().keys())
+    state_dict, stats = ggml.load_gguf_state_dict(fn, devices.dtype)
+    state_dict = convert_sd3_transformer_checkpoint_to_diffusers(state_dict)
+    applied, skipped = 0, 0
+    for param_name, param in state_dict.items():
+        if param_name not in expected_state_dict_keys:
+            skipped += 1
+            continue
+        applied += 1
+        sd_hijack_accelerate.hijack_set_module_tensor_simple(transformer, tensor_name=param_name, value=param, device=0)
+        state_dict[param_name] = None
+    shared.log.debug(f'Load model: type=Unet/Transformer applied={applied} skipped={skipped} stats={stats} compute={devices.dtype}')
+    kwargs['transformer'] = transformer
+    return kwargs
+
+
 def load_sd3(checkpoint_info, cache_dir=None, config=None):
     repo_id = sd_models.path_to_repo(checkpoint_info.name)
     fn = checkpoint_info.path
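For context: load_gguf() follows the usual accelerate pattern of building the model skeleton on the meta device (no memory allocated) and then materializing only the tensors the checkpoint provides; sd_hijack_accelerate.hijack_set_module_tensor_simple appears to be the repo's local analogue of accelerate's stock helper. A minimal self-contained sketch of the same pattern, with model_cls assumed to be any diffusers model class:

from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

def materialize_from_state_dict(model_cls, config, state_dict, device='cpu'):
    with init_empty_weights():                  # parameters created on the meta device, no RAM used
        model = model_cls.from_config(config)
    expected = set(model.state_dict().keys())
    applied = skipped = 0
    for name, tensor in state_dict.items():
        if name not in expected:                # checkpoint key with no matching parameter
            skipped += 1
            continue
        set_module_tensor_to_device(model, name, device, value=tensor)  # replaces the meta tensor
        applied += 1
    return model, applied, skipped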
@@ -92,15 +122,20 @@ def load_sd3(checkpoint_info, cache_dir=None, config=None):
     kwargs = load_overrides(kwargs, cache_dir)
     kwargs = load_quants(kwargs, repo_id, cache_dir)

-    if fn is not None and fn.endswith('.safetensors') and os.path.exists(fn):
-        kwargs = load_missing(kwargs, fn, cache_dir)
-        loader = diffusers.StableDiffusion3Pipeline.from_single_file
-        repo_id = fn
+    loader = diffusers.StableDiffusion3Pipeline.from_pretrained
+    if fn is not None and os.path.exists(fn):
+        if fn.endswith('.safetensors'):
+            loader = diffusers.StableDiffusion3Pipeline.from_single_file
+            kwargs = load_missing(kwargs, fn, cache_dir)
+            repo_id = fn
+        elif fn.endswith('.gguf'):
+            kwargs = load_gguf(kwargs, fn)
+            kwargs = load_missing(kwargs, fn, cache_dir)
+            kwargs['variant'] = 'fp16'
     else:
-        loader = diffusers.StableDiffusion3Pipeline.from_pretrained
         kwargs['variant'] = 'fp16'

-    shared.log.debug(f'Load model: type=FLUX preloaded={list(kwargs)}')
+    shared.log.debug(f'Load model: type=SD3 preloaded={list(kwargs)}')

     pipe = loader(
         repo_id,
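Downstream effect, sketched: a local .safetensors file switches the loader to from_single_file, while a local .gguf file keeps from_pretrained for the rest of the pipeline and injects the pre-built transformer through kwargs. A hypothetical usage example (path and repo id are illustrative only):

# build the quantized transformer, then hand it to the stock pipeline loader
transformer = load_gguf({}, '/models/unet/sd3-medium-q8_0.gguf')['transformer']
pipe = diffusers.StableDiffusion3Pipeline.from_pretrained(
    'stabilityai/stable-diffusion-3-medium-diffusers',  # illustrative repo id
    transformer=transformer,
    variant='fp16',
    torch_dtype=devices.dtype,
)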
1 change: 1 addition & 0 deletions modules/processing_diffusers.py
@@ -130,6 +130,7 @@ def process_base(p: processing.StableDiffusionProcessing):
         errors.display(e, 'Processing')
     except RuntimeError as e:
         shared.state.interrupted = True
+        err_args = base_args.copy()
         for k, v in base_args.items():
             if isinstance(v, torch.Tensor):
                 err_args[k] = f'{v.device}:{v.dtype}:{v.shape}'
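The added line initializes err_args as a copy of base_args before the loop assigns into it; without it, the handler would fail (or write into the wrong dict) while reporting the original error. The pattern in isolation, as a sketch:

import torch

def loggable_args(base_args: dict) -> dict:
    err_args = base_args.copy()                 # copy first, then overwrite tensor entries
    for k, v in base_args.items():
        if isinstance(v, torch.Tensor):
            err_args[k] = f'{v.device}:{v.dtype}:{v.shape}'  # short description instead of the full tensor
    return err_args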
5 changes: 3 additions & 2 deletions modules/sd_unet.py
@@ -34,8 +34,9 @@ def load_unet(model):
         if prior_text_encoder is not None:
             model.prior_pipe.text_encoder = None # Prevent OOM
             model.prior_pipe.text_encoder = prior_text_encoder.to(devices.device, dtype=devices.dtype)
-    elif "Flux" in model.__class__.__name__:
-        sd_models.load_diffuser() # TODO forcing reloading entire flux as loading transformers only leads to massive memory usage
+    elif "Flux" in model.__class__.__name__ or "StableDiffusion3" in model.__class__.__name__:
+        loaded_unet = shared.opts.sd_unet
+        sd_models.load_diffuser() # TODO forcing reloading entire model as loading transformers only leads to massive memory usage
     """
     from modules.model_flux import load_transformer
     transformer = load_transformer(unet_dict[shared.opts.sd_unet])
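Setting loaded_unet before calling sd_models.load_diffuser() matters because the full reload runs the UNet selection logic again; marking the selection as already handled prevents a reload loop. A minimal sketch of that guard pattern (names follow the diff; the surrounding control flow is assumed):

loaded_unet = None  # module-level record of the last UNet selection applied

def apply_unet_selection(opts, reload_model):
    global loaded_unet
    if opts.sd_unet == loaded_unet:   # already handled; avoids re-entry on reload
        return
    loaded_unet = opts.sd_unet        # mark as handled *before* triggering the reload
    reload_model()                    # the reload path may consult loaded_unet again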
