diff --git a/CHANGELOG.md b/CHANGELOG.md
index a91cc0530..6cf7b67d8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,25 @@
 # Change Log for SD.Next
 
+## Update for 2024-10-24
+
+Improvements:
+- SD3 loader enhancements
+  - report when loading incomplete model
+  - handle missing model components
+  - handle component preloading
+- OpenVINO: add accuracy option
+- ZLUDA: guess GPU arch
+
+Fixes:
+- fix send-to-control
+- fix k-diffusion
+- fix sd3 img2img and hires
+- fix ipadapter supported model detection
+- fix t2iadapter auto-download
+- fix omnigen dynamic attention
+- handle a1111 prompt scheduling
+- handle omnigen image placeholder in prompt
+
 ## Update for 2024-10-23
 
 ### Highlights for 2024-10-23
diff --git a/modules/control/run.py b/modules/control/run.py
index 01779119e..e3c7bbf76 100644
--- a/modules/control/run.py
+++ b/modules/control/run.py
@@ -219,7 +219,7 @@ def control_run(state: str = '',
     p_extra_args = {}
 
     if shared.sd_model is None:
-        shared.log.warning('Model not loaded')
+        shared.log.warning('Aborted: op=control model not loaded')
         return [], '', '', 'Error: model not loaded'
 
     unit_type = unit_type.strip().lower() if unit_type is not None else ''
diff --git a/modules/control/units/t2iadapter.py b/modules/control/units/t2iadapter.py
index 81e78c6ac..6e15abe3d 100644
--- a/modules/control/units/t2iadapter.py
+++ b/modules/control/units/t2iadapter.py
@@ -11,17 +11,17 @@
 debug = log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None
 debug('Trace: CONTROL')
 predefined_sd15 = {
-    'Segment': 'TencentARC/t2iadapter_seg_sd14v1',
-    'Zoe Depth': 'TencentARC/t2iadapter_zoedepth_sd15v1',
-    'OpenPose': 'TencentARC/t2iadapter_openpose_sd14v1',
-    'KeyPose': 'TencentARC/t2iadapter_keypose_sd14v1',
-    'Color': 'TencentARC/t2iadapter_color_sd14v1',
-    'Depth v1': 'TencentARC/t2iadapter_depth_sd14v1',
-    'Depth v2': 'TencentARC/t2iadapter_depth_sd15v2',
-    'Canny v1': 'TencentARC/t2iadapter_canny_sd14v1',
-    'Canny v2': 'TencentARC/t2iadapter_canny_sd15v2',
-    'Sketch v1': 'TencentARC/t2iadapter_sketch_sd14v1',
-    'Sketch v2': 'TencentARC/t2iadapter_sketch_sd15v2',
+    'Segment': ('TencentARC/t2iadapter_seg_sd14v1', {}),
+    'Zoe Depth': ('TencentARC/t2iadapter_zoedepth_sd15v1', {}),
+    'OpenPose': ('TencentARC/t2iadapter_openpose_sd14v1', {}),
+    'KeyPose': ('TencentARC/t2iadapter_keypose_sd14v1', {}),
+    'Color': ('TencentARC/t2iadapter_color_sd14v1', {}),
+    'Depth v1': ('TencentARC/t2iadapter_depth_sd14v1', {}),
+    'Depth v2': ('TencentARC/t2iadapter_depth_sd15v2', {}),
+    'Canny v1': ('TencentARC/t2iadapter_canny_sd14v1', {}),
+    'Canny v2': ('TencentARC/t2iadapter_canny_sd15v2', {}),
+    'Sketch v1': ('TencentARC/t2iadapter_sketch_sd14v1', {}),
+    'Sketch v2': ('TencentARC/t2iadapter_sketch_sd15v2', {}),
     # 'Coadapter Canny': 'TencentARC/T2I-Adapter/models/coadapter-canny-sd15v1.pth',
     # 'Coadapter Color': 'TencentARC/T2I-Adapter/models/coadapter-color-sd15v1.pth',
     # 'Coadapter Depth': 'TencentARC/T2I-Adapter/models/coadapter-depth-sd15v1.pth',
@@ -30,12 +30,12 @@
     # 'Coadapter Style': 'TencentARC/T2I-Adapter/models/coadapter-style-sd15v1.pth',
 }
 predefined_sdxl = {
-    'Canny XL': 'TencentARC/t2i-adapter-canny-sdxl-1.0',
-    'LineArt XL': 'TencentARC/t2i-adapter-lineart-sdxl-1.0',
-    'Sketch XL': 'TencentARC/t2i-adapter-sketch-sdxl-1.0',
-    'Zoe Depth XL': 'TencentARC/t2i-adapter-depth-zoe-sdxl-1.0',
-    'OpenPose XL': 'TencentARC/t2i-adapter-openpose-sdxl-1.0',
-    'Midas Depth XL': 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0',
+    'Canny XL': ('TencentARC/t2i-adapter-canny-sdxl-1.0', { 'use_safetensors': True, 'variant': 'fp16' }),
+    'LineArt XL': ('TencentARC/t2i-adapter-lineart-sdxl-1.0', { 'use_safetensors': True, 'variant': 'fp16' }),
+    'Sketch XL': ('TencentARC/t2i-adapter-sketch-sdxl-1.0', { 'use_safetensors': True, 'variant': 'fp16' }),
+    'Zoe Depth XL': ('TencentARC/t2i-adapter-depth-zoe-sdxl-1.0', { 'use_safetensors': True, 'variant': 'fp16' }),
+    'OpenPose XL': ('TencentARC/t2i-adapter-openpose-sdxl-1.0', { 'use_safetensors': True }),
+    'Midas Depth XL': ('TencentARC/t2i-adapter-depth-midas-sdxl-1.0', { 'use_safetensors': True, 'variant': 'fp16' }),
 }
 models = {}
@@ -96,7 +96,8 @@ def load(self, model_id: str = None, force: bool = True) -> str:
         if model_id not in all_models:
             log.error(f'Control {what} unknown model: id="{model_id}" available={list(all_models)}')
             return
-        model_path = all_models[model_id]
+        model_path, model_args = all_models[model_id]
+        self.load_config.update(model_args)
         if model_path is None:
             log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id')
             return
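The adapter tables above now map each name to a (repo_id, load_args) tuple, and load() merges the per-model args into its loader config before downloading. A minimal sketch of how one of these tuples could be consumed, assuming the diffusers T2IAdapter loader and a hypothetical base load_config (this is illustrative, not code from the diff):

```python
# Illustrative only: mirrors 'model_path, model_args = all_models[model_id]'
# followed by 'self.load_config.update(model_args)' from the hunk above.
from diffusers import T2IAdapter

repo_id, model_args = ('TencentARC/t2i-adapter-canny-sdxl-1.0', {'use_safetensors': True, 'variant': 'fp16'})
load_config = {'cache_dir': 'models/Diffusers'}  # hypothetical base loader config
load_config.update(model_args)                   # per-model overrides win
adapter = T2IAdapter.from_pretrained(repo_id, **load_config)
```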
diff --git a/modules/devices.py b/modules/devices.py
index c23f4a256..490d2a54d 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -46,6 +46,16 @@ def has_xpu() -> bool:
     return bool(hasattr(torch, 'xpu') and torch.xpu.is_available())
 
 
+def has_zluda() -> bool:
+    if not cuda_ok:
+        return False
+    try:
+        device = torch.device("cuda")
+        return torch.cuda.get_device_name(device).endswith("[ZLUDA]")
+    except Exception:
+        return False
+
+
 def get_backend(shared_cmd_opts):
     global args # pylint: disable=global-statement
     args = shared_cmd_opts
@@ -55,6 +65,8 @@ def get_backend(shared_cmd_opts):
         name = 'directml'
     elif has_xpu():
         name = 'ipex'
+    elif has_zluda():
+        name = 'zluda'
     elif torch.cuda.is_available() and torch.version.cuda:
         name = 'cuda'
     elif torch.cuda.is_available() and torch.version.hip:
         name = 'rocm'
@@ -109,7 +121,7 @@ def get_package_version(pkg: str):
             'device': f'{torch.xpu.get_device_name(torch.xpu.current_device())} n={torch.xpu.device_count()}',
             'ipex': get_package_version('intel-extension-for-pytorch'),
         }
-    elif backend == 'cuda':
+    elif backend == 'cuda' or backend == 'zluda':
         return {
             'device': f'{torch.cuda.get_device_name(torch.cuda.current_device())} n={torch.cuda.device_count()} arch={torch.cuda.get_arch_list()[-1]} capability={torch.cuda.get_device_capability(device)}',
             'cuda': torch.version.cuda,
@@ -267,9 +279,22 @@ def test_bf16():
     global bf16_ok # pylint: disable=global-statement
     if bf16_ok is not None:
         return bf16_ok
-    if sys.platform == "darwin" or backend == 'openvino' or backend == 'directml': # override
-        bf16_ok = False
-        return bf16_ok
+    if opts.cuda_dtype != 'BF16': # don't override if the user sets it
+        if sys.platform == "darwin" or backend == 'openvino' or backend == 'directml': # override
+            bf16_ok = False
+            return bf16_ok
+        elif backend == 'zluda':
+            device_name = torch.cuda.get_device_name(device)
+            if device_name.startswith("AMD Radeon RX "): # only force AMD
+                device_name = device_name.replace("AMD Radeon RX ", "").split(" ", maxsplit=1)[0]
+                if len(device_name) == 4 and device_name[0] in {"5", "6"}: # RDNA 1 and 2
+                    bf16_ok = False
+                    return bf16_ok
+        elif backend == 'rocm':
+            gcn_arch = getattr(torch.cuda.get_device_properties(device), "gcnArchName", "gfx0000")[3:7]
+            if len(gcn_arch) == 4 and gcn_arch[0:2] == "10": # RDNA 1 and 2
+                bf16_ok = False
+                return bf16_ok
     try:
         import torch.nn.functional as F
         image = torch.randn(1, 4, 32, 32).to(device=device, dtype=torch.bfloat16)
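The two checks introduced above work together: ZLUDA is detected by the "[ZLUDA]" suffix it appends to the reported CUDA device name, and bf16 is then force-disabled on RX 5xxx/6xxx (RDNA 1/2) boards. A hypothetical stand-alone rewrite of that logic, shown only to illustrate the name parsing (not code from the diff):

```python
# Hypothetical helper mirroring has_zluda() and the RDNA 1/2 branch of test_bf16() above.
def zluda_bf16_disabled(device_name: str) -> bool:
    if not device_name.endswith("[ZLUDA]"):        # ZLUDA devices report '... [ZLUDA]'
        return False
    if device_name.startswith("AMD Radeon RX "):   # only force-disable on AMD Radeon RX boards
        model = device_name.replace("AMD Radeon RX ", "").split(" ", maxsplit=1)[0]
        return len(model) == 4 and model[0] in {"5", "6"}  # RX 5xxx / 6xxx = RDNA 1/2, no native bf16
    return False

assert zluda_bf16_disabled("AMD Radeon RX 6800 XT [ZLUDA]") is True
assert zluda_bf16_disabled("AMD Radeon RX 7900 XTX [ZLUDA]") is False
```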
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index 9ee2ece46..a574e8469 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -91,7 +91,7 @@ def activate(p, extra_network_data, step=0):
         try:
             extra_network.activate(p, extra_network_args, step=step)
         except Exception as e:
-            errors.display(e, f"activating extra network: name={extra_network_name} args:{extra_network_args}")
+            errors.display(e, f"Activating network: type={extra_network_name} args:{extra_network_args}")
 
     for extra_network_name, extra_network in extra_network_registry.items():
         args = extra_network_data.get(extra_network_name, None)
@@ -100,7 +100,7 @@ def activate(p, extra_network_data, step=0):
             try:
                 extra_network.activate(p, [])
             except Exception as e:
-                errors.display(e, f"activating extra network: name={extra_network_name}")
+                errors.display(e, f"Activating network: type={extra_network_name}")
 
     if stepwise:
         p.extra_network_data = extra_network_data
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index eacab607d..cf4278e80 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -145,7 +145,7 @@ def connect_paste_params_buttons():
         if binding.source_text_component is not None and fields is not None:
             connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
         if binding.source_tabname is not None and fields is not None and binding.source_tabname in paste_fields:
-            paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
+            paste_field_names = ['Prompt', 'Negative prompt', 'Steps'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
             if "fields" in paste_fields[binding.source_tabname] and paste_fields[binding.source_tabname]["fields"] is not None:
                 binding.paste_button.click(
                     fn=lambda *x: x,
diff --git a/modules/img2img.py b/modules/img2img.py
index 2417c01b3..faf65161a 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -141,7 +141,7 @@ def img2img(id_task: str, state: str, mode: int, *args): # pylint: disable=unused-argument
 
     if shared.sd_model is None:
-        shared.log.warning('Model not loaded')
+        shared.log.warning('Aborted: op=img model not loaded')
         return [], '', '', 'Error: model not loaded'
 
     debug(f'img2img: id_task={id_task}|mode={mode}|prompt={prompt}|negative_prompt={negative_prompt}|prompt_styles={prompt_styles}|init_img={init_img}|sketch={sketch}|init_img_with_mask={init_img_with_mask}|inpaint_color_sketch={inpaint_color_sketch}|inpaint_color_sketch_orig={inpaint_color_sketch_orig}|init_img_inpaint={init_img_inpaint}|init_mask_inpaint={init_mask_inpaint}|steps={steps}|sampler_index={sampler_index}||mask_blur={mask_blur}|mask_alpha={mask_alpha}|inpainting_fill={inpainting_fill}|full_quality={full_quality}|detailer={detailer}|tiling={tiling}|hidiffusion={hidiffusion}|n_iter={n_iter}|batch_size={batch_size}|cfg_scale={cfg_scale}|image_cfg_scale={image_cfg_scale}|clip_skip={clip_skip}|denoising_strength={denoising_strength}|seed={seed}|subseed{subseed}|subseed_strength={subseed_strength}|seed_resize_from_h={seed_resize_from_h}|seed_resize_from_w={seed_resize_from_w}|selected_scale_tab={selected_scale_tab}|height={height}|width={width}|scale_by={scale_by}|resize_mode={resize_mode}|resize_name={resize_name}|resize_context={resize_context}|inpaint_full_res={inpaint_full_res}|inpaint_full_res_padding={inpaint_full_res_padding}|inpainting_mask_invert={inpainting_mask_invert}|img2img_batch_files={img2img_batch_files}|img2img_batch_input_dir={img2img_batch_input_dir}|img2img_batch_output_dir={img2img_batch_output_dir}|img2img_batch_inpaint_mask_dir={img2img_batch_inpaint_mask_dir}|override_settings_texts={override_settings_texts}')
diff --git a/modules/intel/ipex/attention.py b/modules/intel/ipex/attention.py
index dead035e0..22c74a78b 100644
--- a/modules/intel/ipex/attention.py
+++ b/modules/intel/ipex/attention.py
@@ -136,6 +136,11 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
     if do_split:
         batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
         hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
+        if attn_mask is not None and attn_mask.shape != query.shape:
+            if len(query.shape) == 4:
+                attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2], 1))
+            else:
+                attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2]))
         for i in range(batch_size_attention // split_slice_size):
             start_idx = i * split_slice_size
             end_idx = (i + 1) * split_slice_size
diff --git a/modules/intel/openvino/__init__.py b/modules/intel/openvino/__init__.py
index 975a97672..157d26d96 100644
--- a/modules/intel/openvino/__init__.py
+++ b/modules/intel/openvino/__init__.py
@@ -7,6 +7,7 @@
 from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
 from openvino.frontend.pytorch.torchdynamo.partition import Partitioner
 from openvino.runtime import Core, Type, PartialShape, serialize
+from openvino.properties import hint as ov_hints
 
 from torch._dynamo.backends.common import fake_tensor_unsupported
 from torch._dynamo.backends.registry import register_backend
@@ -156,7 +157,6 @@ def openvino_compile(gm: GraphModule, *example_inputs, model_hash_str: str = Non
     core = Core()
     device = get_device()
-    cache_root = shared.opts.openvino_cache_path
     global dont_use_4bit_nncf
     global dont_use_nncf
     global dont_use_quant
@@ -233,9 +233,14 @@ def openvino_compile(gm: GraphModule, *example_inputs, model_hash_str: str = Non
         else:
             om = nncf.compress_weights(om, mode=getattr(nncf.CompressWeightsMode, shared.opts.nncf_compress_weights_mode), group_size=8, ratio=shared.opts.nncf_compress_weights_raito)
-
+    hints = {}
+    if shared.opts.openvino_accuracy == "performance":
+        hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.PERFORMANCE
+    elif shared.opts.openvino_accuracy == "accuracy":
+        hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.ACCURACY
     if model_hash_str is not None:
-        core.set_property({'CACHE_DIR': cache_root + '/blob'})
+        hints['CACHE_DIR'] = shared.opts.openvino_cache_path + '/blob'
+    core.set_property(hints)
     dont_use_nncf = False
     dont_use_quant = False
     dont_use_4bit_nncf = False
@@ -286,7 +291,12 @@ def openvino_compile_cached_model(cached_model_path, *example_inputs):
         else:
             om = nncf.compress_weights(om, mode=getattr(nncf.CompressWeightsMode, shared.opts.nncf_compress_weights_mode), group_size=8, ratio=shared.opts.nncf_compress_weights_raito)
-    core.set_property({'CACHE_DIR': shared.opts.openvino_cache_path + '/blob'})
+    hints = {'CACHE_DIR': shared.opts.openvino_cache_path + '/blob'}
+    if shared.opts.openvino_accuracy == "performance":
+        hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.PERFORMANCE
+    elif shared.opts.openvino_accuracy == "accuracy":
+        hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.ACCURACY
+    core.set_property(hints)
     dont_use_nncf = False
     dont_use_quant = False
     dont_use_4bit_nncf = False
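For reference, a minimal sketch of the new accuracy/performance hint in isolation, assuming an OpenVINO release where openvino.properties.hint exposes execution_mode (as imported above); the option value and cache path are stand-ins for the shared.opts settings used in the diff:

```python
# Hypothetical standalone use of the execution-mode hint set by openvino_compile() above.
from openvino.runtime import Core
from openvino.properties import hint as ov_hints

core = Core()
openvino_accuracy = "accuracy"          # stand-in for shared.opts.openvino_accuracy
hints = {"CACHE_DIR": "cache/blob"}     # stand-in for shared.opts.openvino_cache_path + '/blob'
if openvino_accuracy == "performance":
    hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.PERFORMANCE
elif openvino_accuracy == "accuracy":
    hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.ACCURACY
core.set_property(hints)
```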
diff --git a/modules/ipadapter.py b/modules/ipadapter.py
index 80c1a9b7f..3f85f3b65 100644
--- a/modules/ipadapter.py
+++ b/modules/ipadapter.py
@@ -113,6 +113,9 @@ def unapply(pipe): # pylint: disable=arguments-differ
 
 def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapter_scales=[1.0], adapter_crops=[False], adapter_starts=[0.0], adapter_ends=[1.0], adapter_images=[]):
     global clip_loaded # pylint: disable=global-statement
+    if shared.sd_model_type != 'sd' and shared.sd_model_type != 'sdxl':
+        shared.log.error(f'IP adapter: model={shared.sd_model_type} class={pipe.__class__.__name__} not supported')
+        return False
     # overrides
     if hasattr(p, 'ip_adapter_names'):
         if isinstance(p.ip_adapter_names, str):
@@ -183,9 +186,6 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt
     if not hasattr(pipe, 'load_ip_adapter'):
         shared.log.error(f'IP adapter: pipeline not supported: {pipe.__class__.__name__}')
         return False
-    if shared.sd_model_type != 'sd' and shared.sd_model_type != 'sdxl':
-        shared.log.error(f'IP adapter: unsupported model type: {shared.sd_model_type}')
-        return False
 
     for adapter_name in adapter_names:
         # which clip to use
diff --git a/modules/model_flux.py b/modules/model_flux.py
index d696f7df6..38207f73b 100644
--- a/modules/model_flux.py
+++ b/modules/model_flux.py
@@ -41,7 +41,7 @@ def load_flux_quanto(checkpoint_info):
         except Exception:
             shared.log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer.dtype}")
     except Exception as e:
-        shared.log.error(f"Load model: type=FLUX Failed to load Quanto transformer: {e}")
+        shared.log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}")
         if debug:
             from modules import errors
             errors.display(e, 'FLUX Quanto:')
@@ -68,7 +68,7 @@ def load_flux_quanto(checkpoint_info):
         except Exception:
             shared.log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2.dtype}")
     except Exception as e:
-        shared.log.error(f"Load model: type=FLUX Failed to load Quanto text encoder: {e}")
+        shared.log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}")
         if debug:
             from modules import errors
             errors.display(e, 'FLUX Quanto:')
@@ -100,7 +100,7 @@ def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unu
         else:
             transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
     except Exception as e:
-        shared.log.error(f"Load model: type=FLUX Failed to load BnB transformer: {e}")
+        shared.log.error(f"Load model: type=FLUX failed to load BnB transformer: {e}")
         transformer, text_encoder_2 = None, None
         if debug:
             from modules import errors
@@ -222,7 +222,7 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch
             shared.opts.sd_unet = 'None'
             sd_unet.failed_unet.append(shared.opts.sd_unet)
     except Exception as e:
-        shared.log.error(f"Load model: type=FLUX Failed to load UNet: {e}")
+        shared.log.error(f"Load model: type=FLUX failed to load UNet: {e}")
         shared.opts.sd_unet = 'None'
         if debug:
             from modules import errors
@@ -236,7 +236,7 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch
         else:
            text_encoder_2 = load_t5(name=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)
     except Exception as e:
-        shared.log.error(f"Load model: type=FLUX Failed to load T5: {e}")
+        shared.log.error(f"Load model: type=FLUX failed to load T5: {e}")
         shared.opts.sd_text_encoder = 'None'
         if debug:
             from modules import errors
@@ -251,7 +251,7 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch
             vae_config = os.path.join('configs', 'flux', 'vae', 'config.json')
             vae = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, **diffusers_load_config)
     except Exception as e:
-        shared.log.error(f"Load model: type=FLUX Failed to load VAE: {e}")
+        shared.log.error(f"Load model: type=FLUX failed to load VAE: {e}")
         shared.opts.sd_vae = 'None'
         if debug:
             from modules import errors
@@ -267,7 +267,7 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch
             if _text_encoder is not None:
                 text_encoder_2 = _text_encoder
         except Exception as e:
-            shared.log.error(f"Load model: type=FLUX Failed to load NF4 components: {e}")
+            shared.log.error(f"Load model: type=FLUX failed to load NF4 components: {e}")
             if debug:
                 from modules import errors
                 errors.display(e, 'FLUX NF4:')
@@ -279,7 +279,7 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch
             if _text_encoder is not None:
                 text_encoder_2 = _text_encoder
         except Exception as e:
-            shared.log.error(f"Load model: type=FLUX Failed to load Quanto components: {e}")
+            shared.log.error(f"Load model: type=FLUX failed to load Quanto components: {e}")
             if debug:
                 from modules import errors
                 errors.display(e, 'FLUX Quanto:')
diff --git a/modules/model_flux_nf4.py b/modules/model_flux_nf4.py
index a1b46fd54..d023907d6 100644
--- a/modules/model_flux_nf4.py
+++ b/modules/model_flux_nf4.py
@@ -200,7 +200,7 @@ def load_flux_nf4(checkpoint_info):
                 create_quantized_param(transformer, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=True)
     except Exception as e:
         transformer, text_encoder_2 = None, None
-        shared.log.error(f"Load model: type=FLUX Failed to load UNET: {e}")
+        shared.log.error(f"Load model: type=FLUX failed to load UNET: {e}")
         if debug:
             from modules import errors
             errors.display(e, 'FLUX:')
diff --git a/modules/model_sd3.py b/modules/model_sd3.py
index 96f194c66..72f3a0c32 100644
--- a/modules/model_sd3.py
+++ b/modules/model_sd3.py
@@ -1,56 +1,49 @@
 import os
 import diffusers
 import transformers
+from modules import shared, devices, sd_models, sd_unet
 
-default_repo_id = 'stabilityai/stable-diffusion-3-medium'
+
+def load_overrides(kwargs, cache_dir):
+    if shared.opts.sd_unet != 'None':
+        try:
+            fn = sd_unet.unet_dict[shared.opts.sd_unet]
+            kwargs['transformer'] = diffusers.SD3Transformer2DModel.from_single_file(fn, cache_dir=cache_dir, torch_dtype=devices.dtype)
+            shared.log.debug(f'Load model: type=SD3 unet="{shared.opts.sd_unet}"')
+        except Exception as e:
+            shared.log.error(f"Load model: type=SD3 failed to load UNet: {e}")
+            shared.opts.sd_unet = 'None'
+            sd_unet.failed_unet.append(shared.opts.sd_unet)
+    if shared.opts.sd_text_encoder != 'None':
+        try:
+            from modules.model_te import load_t5, load_vit_l, load_vit_g
+            if 'vit-l' in shared.opts.sd_text_encoder.lower():
+                kwargs['text_encoder'] = load_vit_l()
+                shared.log.debug(f'Load model: type=SD3 variant="vit-l" te="{shared.opts.sd_text_encoder}"')
+            elif 'vit-g' in shared.opts.sd_text_encoder.lower():
+                kwargs['text_encoder_2'] = load_vit_g()
+                shared.log.debug(f'Load model: type=SD3 variant="vit-g" te="{shared.opts.sd_text_encoder}"')
+            else:
+                kwargs['text_encoder_3'] = load_t5(name=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)
+                shared.log.debug(f'Load model: type=SD3 variant="t5" te="{shared.opts.sd_text_encoder}"')
+        except Exception as e:
+            shared.log.error(f"Load model: type=SD3 failed to load T5: {e}")
+            shared.opts.sd_text_encoder = 'None'
+    if shared.opts.sd_vae != 'None' and shared.opts.sd_vae != 'Automatic':
+        try:
+            from modules import sd_vae
+            vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
+            if os.path.exists(vae_file):
+                vae_config = os.path.join('configs', 'flux', 'vae', 'config.json')
+                kwargs['vae'] = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
+            shared.log.debug(f'Load model: type=SD3 vae="{shared.opts.sd_vae}"')
+        except Exception as e:
+            shared.log.error(f"Load model: type=SD3 failed to load VAE: {e}")
+            shared.opts.sd_vae = 'None'
+    return kwargs
 
-def load_sd3(checkpoint_info, cache_dir=None, config=None):
-    from modules import shared, devices, modelloader, sd_models
-    repo_id = sd_models.path_to_repo(checkpoint_info.name)
-    dtype = devices.dtype
-    kwargs = {}
-    if checkpoint_info.path is not None and checkpoint_info.path.endswith('.safetensors') and os.path.exists(checkpoint_info.path):
-        loader = diffusers.StableDiffusion3Pipeline.from_single_file
-        fn_size = os.path.getsize(checkpoint_info.path)
-        if fn_size < 5e9:
-            kwargs = {
-                'text_encoder': transformers.CLIPTextModelWithProjection.from_pretrained(
-                    default_repo_id,
-                    subfolder='text_encoder',
-                    cache_dir=cache_dir,
-                    torch_dtype=dtype,
-                ),
-                'text_encoder_2': transformers.CLIPTextModelWithProjection.from_pretrained(
-                    default_repo_id,
-                    subfolder='text_encoder_2',
-                    cache_dir=cache_dir,
-                    torch_dtype=dtype,
-                ),
-                'tokenizer': transformers.CLIPTokenizer.from_pretrained(
-                    default_repo_id,
-                    subfolder='tokenizer',
-                    cache_dir=cache_dir,
-                ),
-                'tokenizer_2': transformers.CLIPTokenizer.from_pretrained(
-                    default_repo_id,
-                    subfolder='tokenizer_2',
-                    cache_dir=cache_dir,
-                ),
-                'text_encoder_3': None,
-            }
-        elif fn_size < 1e10: # if model is below 10gb it does not have te3
-            kwargs = {
-                'text_encoder_3': None,
-            }
-        else:
-            kwargs = {}
-    else:
-        modelloader.hf_login()
-        loader = diffusers.StableDiffusion3Pipeline.from_pretrained
-        kwargs['variant'] = 'fp16'
-
+
+def load_quants(kwargs, repo_id, cache_dir):
     if len(shared.opts.bnb_quantization) > 0:
         from modules.model_quant import load_bnb
         load_bnb('Load model: type=SD3')
@@ -61,18 +54,57 @@ def load_sd3(checkpoint_info, cache_dir=None, config=None):
             bnb_4bit_quant_type=shared.opts.bnb_quantization_type,
             bnb_4bit_compute_dtype=devices.dtype
         )
-        if 'Model' in shared.opts.bnb_quantization:
-            transformer = diffusers.SD3Transformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=cache_dir, quantization_config=bnb_config, torch_dtype=devices.dtype)
+        if 'Model' in shared.opts.bnb_quantization and 'transformer' not in kwargs:
+            kwargs['transformer'] = diffusers.SD3Transformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=cache_dir, quantization_config=bnb_config, torch_dtype=devices.dtype)
             shared.log.debug(f'Quantization: module=transformer type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
-            kwargs['transformer'] = transformer
-        if 'Text Encoder' in shared.opts.bnb_quantization:
-            te3 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_3", variant='fp16', cache_dir=cache_dir, quantization_config=bnb_config, torch_dtype=devices.dtype)
+        if 'Text Encoder' in shared.opts.bnb_quantization and 'text_encoder_3' not in kwargs:
+            kwargs['text_encoder_3'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_3", variant='fp16', cache_dir=cache_dir, quantization_config=bnb_config, torch_dtype=devices.dtype)
             shared.log.debug(f'Quantization: module=t5 type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
-            kwargs['text_encoder_3'] = te3
+    return kwargs
+
+
+def load_missing(kwargs, fn, cache_dir):
+    keys = sd_models.get_safetensor_keys(fn)
+    size = os.stat(fn).st_size // 1024 // 1024
+    if size > 15000:
+        repo_id = 'stabilityai/stable-diffusion-3.5-large'
+    else:
+        repo_id = 'stabilityai/stable-diffusion-3-medium'
+    if 'text_encoder' not in kwargs and 'text_encoder' not in keys:
+        kwargs['text_encoder'] = transformers.CLIPTextModelWithProjection.from_pretrained(repo_id, subfolder='text_encoder', cache_dir=cache_dir, torch_dtype=devices.dtype)
+        shared.log.debug(f'Load model: type=SD3 missing=te1 repo="{repo_id}"')
+    if 'text_encoder_2' not in kwargs and 'text_encoder_2' not in keys:
+        kwargs['text_encoder_2'] = transformers.CLIPTextModelWithProjection.from_pretrained(repo_id, subfolder='text_encoder_2', cache_dir=cache_dir, torch_dtype=devices.dtype)
+        shared.log.debug(f'Load model: type=SD3 missing=te2 repo="{repo_id}"')
+    if 'text_encoder_3' not in kwargs and 'text_encoder_3' not in keys:
+        kwargs['text_encoder_3'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_3", variant='fp16', cache_dir=cache_dir, torch_dtype=devices.dtype)
+        shared.log.debug(f'Load model: type=SD3 missing=te3 repo="{repo_id}"')
+    # if 'transformer' not in kwargs and 'transformer' not in keys:
+    #     kwargs['transformer'] = diffusers.SD3Transformer2DModel.from_pretrained(default_repo_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=devices.dtype)
+    return kwargs
+
+
+def load_sd3(checkpoint_info, cache_dir=None, config=None):
+    repo_id = sd_models.path_to_repo(checkpoint_info.name)
+    fn = checkpoint_info.path
+
+    kwargs = {}
+    kwargs = load_overrides(kwargs, cache_dir)
+    kwargs = load_quants(kwargs, repo_id, cache_dir)
+
+    if fn is not None and fn.endswith('.safetensors') and os.path.exists(fn):
+        kwargs = load_missing(kwargs, fn, cache_dir)
+        loader = diffusers.StableDiffusion3Pipeline.from_single_file
+        repo_id = fn
+    else:
+        loader = diffusers.StableDiffusion3Pipeline.from_pretrained
+        kwargs['variant'] = 'fp16'
+
+    shared.log.debug(f'Load model: type=SD3 preloaded={list(kwargs)}')
     pipe = loader(
         repo_id,
-        torch_dtype=dtype,
+        torch_dtype=devices.dtype,
         cache_dir=cache_dir,
         config=config,
         **kwargs,
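The load_missing() helper above decides which text encoders to preload from a reference repo by inspecting the checkpoint's safetensors keys and its file size. A simplified, hypothetical sketch of that idea, assuming the component names appear as key prefixes (the exact key matching in get_safetensor_keys()/load_missing() may differ):

```python
# Hypothetical sketch: detect which SD3 text encoders a .safetensors checkpoint is missing
# without loading any weights, then pick a reference repo based on file size.
import os
import safetensors.torch

def missing_sd3_text_encoders(filename):
    with safetensors.torch.safe_open(filename, framework="pt", device="cpu") as f:
        keys = list(f.keys())
    size_mb = os.stat(filename).st_size // 1024 // 1024
    repo_id = 'stabilityai/stable-diffusion-3.5-large' if size_mb > 15000 else 'stabilityai/stable-diffusion-3-medium'
    wanted = ('text_encoder', 'text_encoder_2', 'text_encoder_3')
    missing = [name for name in wanted if not any(k == name or k.startswith(name + '.') for k in keys)]
    return repo_id, missing  # components to preload from repo_id
```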
diff --git a/modules/postprocess/yolo.py b/modules/postprocess/yolo.py
index 2f0e12086..b162240e0 100644
--- a/modules/postprocess/yolo.py
+++ b/modules/postprocess/yolo.py
@@ -56,7 +56,7 @@ def enumerate(self):
                 name = os.path.splitext(os.path.basename(f))[0]
                 if name not in files:
                     self.list[name] = os.path.join(shared.opts.yolo_dir, f)
-        shared.log.info(f'Available Yolo: path="{shared.opts.yolo_dir} items={len(list(self.list))} downloaded={downloaded}')
+        shared.log.info(f'Available Yolo: path="{shared.opts.yolo_dir}" items={len(list(self.list))} downloaded={downloaded}')
         return self.list
 
     def dependencies(self):
diff --git a/modules/processing_args.py b/modules/processing_args.py
index 0e03ec1ba..1ea91fb08 100644
--- a/modules/processing_args.py
+++ b/modules/processing_args.py
@@ -127,7 +127,7 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2
     if 'prompt' in possible:
         if 'OmniGen' in model.__class__.__name__:
-            p.prompts = [p.replace('|image|', '<|image_1|>') for p in prompts]
+            prompts = [p.replace('|image|', '<|image_1|>') for p in prompts]
         if hasattr(model, 'text_encoder') and 'prompt_embeds' in possible and len(p.prompt_embeds) > 0 and p.prompt_embeds[0] is not None:
             args['prompt_embeds'] = p.prompt_embeds[0]
         if 'StableCascade' in model.__class__.__name__ and len(getattr(p, 'negative_pooleds', [])) > 0:
@@ -256,12 +256,13 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2
     # handle missing resolution
     if args.get('image', None) is not None and ('width' not in args or 'height' not in args):
-        if isinstance(args['image'], torch.Tensor) or isinstance(args['image'], np.ndarray):
-            args['width'] = 8 * args['image'].shape[-1]
-            args['height'] = 8 * args['image'].shape[-2]
-        else:
-            args['width'] = 8 * math.ceil(args['image'][0].width / 8)
-            args['height'] = 8 * math.ceil(args['image'][0].height / 8)
+        if 'width' in possible and 'height' in possible:
+            if isinstance(args['image'], torch.Tensor) or isinstance(args['image'], np.ndarray):
+                args['width'] = 8 * args['image'].shape[-1]
+                args['height'] = 8 * args['image'].shape[-2]
+            else:
+                args['width'] = 8 * math.ceil(args['image'][0].width / 8)
+                args['height'] = 8 * math.ceil(args['image'][0].height / 8)
 
     # handle implicit controlnet
     if 'control_image' in possible and 'control_image' not in args and 'image' in args:
diff --git a/modules/processing_diffusers.py b/modules/processing_diffusers.py
index b76376fd1..b6549cdac 100644
--- a/modules/processing_diffusers.py
+++ b/modules/processing_diffusers.py
@@ -121,12 +121,19 @@ def process_base(p: processing.StableDiffusionProcessing):
             shared.log.info(e)
         except ValueError as e:
             shared.state.interrupted = True
-            shared.log.error(f'Processing: args={base_args} {e}')
+            err_args = base_args.copy()
+            for k, v in base_args.items():
+                if isinstance(v, torch.Tensor):
+                    err_args[k] = f'{v.device}:{v.dtype}:{v.shape}'
+            shared.log.error(f'Processing: args={err_args} {e}')
             if shared.cmd_opts.debug:
                 errors.display(e, 'Processing')
         except RuntimeError as e:
             shared.state.interrupted = True
-            shared.log.error(f'Processing: step=base args={base_args} {e}')
+            err_args = base_args.copy()
+            for k, v in base_args.items():
+                if isinstance(v, torch.Tensor):
+                    err_args[k] = f'{v.device}:{v.dtype}:{v.shape}'
+            shared.log.error(f'Processing: step=base args={err_args} {e}')
             errors.display(e, 'Processing')
             modelstats.analyze()
diff --git a/modules/prompt_parser_diffusers.py b/modules/prompt_parser_diffusers.py
index a7c5c296c..cc814f379 100644
--- a/modules/prompt_parser_diffusers.py
+++ b/modules/prompt_parser_diffusers.py
@@ -108,6 +108,12 @@ def expand_textual_inversion_token_ids_if_necessary(self, token_ids: typing.List
 def get_prompt_schedule(prompt, steps):
     t0 = time.time()
+    if shared.native:
+        # TODO prompt scheduling
+        # prompt schedule returns array of prompts which would require that each prompt is fed to the model per-step
+        # prompt scheduling should instead interpolate between each prompt in schedule
+        # this temporarily disables prompt scheduling
+        return [prompt], False
     temp = []
     schedule = prompt_parser.get_learned_conditioning_prompt_schedules([prompt], steps)[0]
     if all(x == schedule[0] for x in schedule):
@@ -213,6 +219,9 @@ def encode_prompts(pipe, p, prompts: list, negative_prompts: list, steps: int, c
             if negative_pooled is not None:
                 negative_pooleds.append(negative_pooled)
         last_prompt, last_negative = prompt, negative
+        # TODO prompt scheduling
+        # interpolation should happen here and then we can re-enable prompt scheduling
+        # ive tried simple torch.mean and its not good-enough
 
     def fix_length(embeds):
         max_len = max([e.shape[1] for e in embeds if e is not None])
diff --git a/modules/sd_hijack_dynamic_atten.py b/modules/sd_hijack_dynamic_atten.py
index 9ee7d72ac..cb64482a5 100644
--- a/modules/sd_hijack_dynamic_atten.py
+++ b/modules/sd_hijack_dynamic_atten.py
@@ -57,6 +57,11 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo
     if do_split:
         batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
         hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
+        if attn_mask is not None and attn_mask.shape != query.shape:
+            if len(query.shape) == 4:
+                attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2], 1))
+            else:
+                attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2]))
         for i in range(batch_size_attention // split_slice_size):
             start_idx = i * split_slice_size
             end_idx = (i + 1) * split_slice_size
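The mask handling added to both sliced-attention paths above tiles attn_mask until it matches the query shape before entering the per-slice loop. A small numeric illustration (shapes are made up for the example, not taken from the diff):

```python
# Toy shapes only: show that the repeat() call produces a mask matching the query.
import torch

query = torch.randn(8, 4096, 64)   # (batch*heads, query_tokens, head_dim)
attn_mask = torch.zeros(2, 1, 64)  # broadcastable mask with smaller leading dims
if attn_mask.shape != query.shape:
    attn_mask = attn_mask.repeat((query.shape[0] // attn_mask.shape[0],
                                  query.shape[1] // attn_mask.shape[1],
                                  query.shape[2] // attn_mask.shape[2]))
assert attn_mask.shape == query.shape  # (8, 4096, 64)
```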
diff --git a/modules/sd_models.py b/modules/sd_models.py
index bfb5e565a..71a389c8a 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -417,6 +417,16 @@ def read_state_dict(checkpoint_file, map_location=None, what:str='model'): # pyl
     return sd
 
 
+def get_safetensor_keys(filename):
+    keys = []
+    try:
+        with safetensors.torch.safe_open(filename, framework="pt", device="cpu") as f:
+            keys = f.keys()
+    except Exception as e:
+        shared.log.error(f'Load dict: path="{filename}" {e}')
+    return keys
+
+
 def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
     if not os.path.isfile(checkpoint_info.filename):
         return None
@@ -1088,7 +1098,7 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='
         sd_model = load_flux(checkpoint_info, diffusers_load_config)
     elif model_type in ['Stable Diffusion 3']:
         from modules.model_sd3 import load_sd3
-        shared.log.debug(f'Load {op}: model="Stable Diffusion 3" variant=medium')
+        shared.log.debug(f'Load {op}: model="Stable Diffusion 3"')
         shared.opts.scheduler = 'Default'
         sd_model = load_sd3(checkpoint_info, cache_dir=shared.opts.diffusers_dir, config=diffusers_load_config.get('config', None))
     elif model_type in ['Meissonic']: # forced pipeline
@@ -1210,7 +1220,7 @@ def load_diffuser_file(model_type, pipeline, checkpoint_info, diffusers_load_con
             diffusers_load_config['cache_dir'] = shared.opts.hfcache_dir
             sd_model = pipeline.from_ckpt(checkpoint_info.path, **diffusers_load_config)
         else:
-            shared.log.error(f'Diffusers {op} cannot load safetensor model: {checkpoint_info.path} {shared.opts.diffusers_pipeline}')
+            shared.log.error(f'Load {op}: file="{checkpoint_info.path}" {shared.opts.diffusers_pipeline} cannot load safetensor model')
             return None
         if shared.opts.diffusers_vae_upcast != 'default' and model_type in ['Stable Diffusion', 'Stable Diffusion XL']:
             diffusers_load_config['force_upcast'] = True if shared.opts.diffusers_vae_upcast == 'true' else False
@@ -1224,8 +1234,11 @@ def load_diffuser_file(model_type, pipeline, checkpoint_info, diffusers_load_con
             diffusers_load_config.pop('local_files_only', None)
             shared.log.debug(f'Setting {op}: pipeline={sd_model.__class__.__name__} config={diffusers_load_config}') # pylint: disable=protected-access
     except Exception as e:
-        shared.log.error(f'Diffusers failed loading: {op}={checkpoint_info.path} pipeline={shared.opts.diffusers_pipeline}/{sd_model.__class__.__name__} config={diffusers_load_config} {e}')
-        errors.display(e, f'loading {op}={checkpoint_info.path} pipeline={shared.opts.diffusers_pipeline}/{sd_model.__class__.__name__}')
+        shared.log.error(f'Load {op}: file="{checkpoint_info.path}" pipeline={shared.opts.diffusers_pipeline}/{sd_model.__class__.__name__} config={diffusers_load_config} {e}')
+        if 'Weights for this component appear to be missing in the checkpoint' in str(e):
+            shared.log.error(f'Load {op}: file="{checkpoint_info.path}" is not a complete model')
+        else:
+            errors.display(e, 'Load')
         return None
     return sd_model
@@ -1299,7 +1312,7 @@ def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=No
         sd_model = load_diffuser_file(model_type, pipeline, checkpoint_info, diffusers_load_config, op)
 
     if sd_model is None:
-        shared.log.error('Diffuser model not loaded')
+        shared.log.error(f'Load {op}: no model loaded')
         return
     sd_model.sd_model_hash = checkpoint_info.calculate_shorthash() # pylint: disable=attribute-defined-outside-init
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 89b5a8a79..82171e0b7 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -77,7 +77,7 @@ def create_sampler(name, model):
     if 'Lumina' in model.__class__.__name__:
         shared.log.warning(f'AlphaVLLM-Lumina: sampler="{name}" unsupported')
         return None
-    if 'StableDiffusion3Pipeline' in model.__class__.__name__:
+    if 'StableDiffusion3' in model.__class__.__name__:
         if sampler.name != 'Heun FlowMatch':
             return None
         return None
diff --git a/modules/shared.py b/modules/shared.py
index b2ce7ebaa..e184e53ed 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -19,7 +19,6 @@
 from modules.paths import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # pylint: disable=W0611
 from modules.dml import memory_providers, default_memory_provider, directml_do_hijack
 from modules.onnx_impl import initialize_onnx, execution_providers
-from modules.zluda import initialize_zluda
 from modules.memstats import memory_stats
 import modules.interrogate
 import modules.memmon
@@ -413,8 +412,8 @@ def get_default_modes():
     if devices.backend == "rocm":
         default_sdp_options = ['Memory attention', 'Math attention']
-    #elif devices.backend == "zluda":
-    #    sdp_options_default = ['Math attention']
+    elif devices.backend == "zluda":
+        default_sdp_options = ['Math attention']
     else:
         default_sdp_options = ['Flash attention', 'Memory attention', 'Math attention']
     if (cmd_opts.lowvram or cmd_opts.medvram) and ('Flash attention' not in default_sdp_options):
@@ -496,6 +495,7 @@ def get_default_modes():
     "openvino_sep": OptionInfo("