switch gguf loader
Signed-off-by: Vladimir Mandic <[email protected]>
vladmandic committed Dec 21, 2024
1 parent d9a320b commit 76755c6
Showing 6 changed files with 82 additions and 29 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -7,7 +7,7 @@
### SD.Next Xmass edition: *What's new?*

While we have several new supported models, workflows and tools, this release is primarily about *quality-of-life improvements*:
- New memory management engine: list of changes that went into this one is long: changes to GPU offloading, LoRA loader, system memory management, on-the-fly quantization, etc.
- New memory management engine: the list of changes that went into this one is long: GPU offloading, a brand-new LoRA loader, system memory management, on-the-fly quantization, an improved GGUF loader, etc.
but the main goal is enabling modern large models to run on standard consumer GPUs
without the performance hits typically associated with aggressive memory swapping and the need for constant manual tweaks
- New [documentation website](https://vladmandic.github.io/sdnext-docs/)
@@ -165,6 +165,7 @@ All-in-all, we're around ~160 commits worth of updates, check changelog for full
- `BitsAndBytes` with 3 float-based quantization schemes
- `Optimum.Quanto` with 3 int-based and 2 float-based quantization schemes
- `GGUF` with pre-quantized weights
- Switch `GGUF` loader from custom to diffusers-native
- **IPEX**: update to IPEX 2.5.10+xpu
- **OpenVINO**: update to 2024.5.0
- **Sampler** improvements
@@ -202,6 +203,8 @@ All-in-all, we're around ~160 commits worth of updates, check changelog for full
- uninstall conflicting `wandb` package
- don't skip diffusers version check if quick is specified
- notify on torch install
- detect pipeline for diffusers folder-style model
- do not recast flux quants

## Update for 2024-11-21

45 changes: 40 additions & 5 deletions modules/ggml/__init__.py
@@ -1,11 +1,35 @@
from pathlib import Path
import os
import time
import torch
import gguf
from .gguf_utils import TORCH_COMPATIBLE_QTYPES
from .gguf_tensor import GGMLTensor
import diffusers
import transformers


def load_gguf_state_dict(path: str, compute_dtype: torch.dtype) -> dict[str, GGMLTensor]:
def install_gguf():
# pip install git+https://github.com/junejae/transformers@feature/t5-gguf
# https://github.com/ggerganov/llama.cpp/issues/9566
from installer import install
install('gguf', quiet=True)
    import importlib.metadata
import gguf
from modules import shared
scripts_dir = os.path.join(os.path.dirname(gguf.__file__), '..', 'scripts')
if os.path.exists(scripts_dir):
os.rename(scripts_dir, scripts_dir + str(time.time()))
    # monkey patch transformers/diffusers so they detect the newly installed gguf package correctly
ver = importlib.metadata.version('gguf')
transformers.utils.import_utils._is_gguf_available = True # pylint: disable=protected-access
transformers.utils.import_utils._gguf_version = ver # pylint: disable=protected-access
diffusers.utils.import_utils._is_gguf_available = True # pylint: disable=protected-access
diffusers.utils.import_utils._gguf_version = ver # pylint: disable=protected-access
shared.log.debug(f'Load GGUF: version={ver}')
return gguf


def load_gguf_state_dict(path: str, compute_dtype: torch.dtype) -> tuple[dict, dict]:
gguf = install_gguf()
from .gguf_utils import TORCH_COMPATIBLE_QTYPES
from .gguf_tensor import GGMLTensor
sd: dict[str, GGMLTensor] = {}
stats = {}
reader = gguf.GGUFReader(path)
@@ -19,3 +43,14 @@ def load_gguf_state_dict(path: str, compute_dtype: torch.dtype) -> dict[str, GGMLTensor]:
stats[tensor.tensor_type.name] = 0
stats[tensor.tensor_type.name] += 1
return sd, stats


def load_gguf(path, cls, compute_dtype: torch.dtype):
_gguf = install_gguf()
module = cls.from_single_file(
path,
        quantization_config=diffusers.GGUFQuantizationConfig(compute_dtype=compute_dtype),
torch_dtype=compute_dtype,
)
module.gguf = 'gguf'
return module
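
The new loader is a thin wrapper over the diffusers-native GGUF path: `from_single_file` plus a `GGUFQuantizationConfig` carrying the compute dtype. A minimal usage sketch, not part of the commit (the checkpoint path is a hypothetical example):

import torch
import diffusers
from modules import ggml

# load a pre-quantized GGUF checkpoint into a diffusers model class;
# dequantized compute runs in the requested dtype
transformer = ggml.load_gguf(
    '/models/unet/flux1-dev-Q4_K_S.gguf',  # hypothetical local GGUF file
    cls=diffusers.FluxTransformer2DModel,
    compute_dtype=torch.bfloat16,
)
assert transformer.gguf == 'gguf'  # marker set by load_gguf for downstream checks
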
18 changes: 14 additions & 4 deletions modules/model_flux.py
@@ -160,9 +160,10 @@ def load_quants(kwargs, repo_id, cache_dir):
return kwargs


"""
def load_flux_gguf(file_path):
transformer = None
model_te.install_gguf()
ggml.install_gguf()
from accelerate import init_empty_weights
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
from modules import ggml, sd_hijack_accelerate
Expand All @@ -180,9 +181,11 @@ def load_flux_gguf(file_path):
continue
applied += 1
sd_hijack_accelerate.hijack_set_module_tensor_simple(transformer, tensor_name=param_name, value=param, device=0)
transformer.gguf = 'gguf'
state_dict[param_name] = None
shared.log.debug(f'Load model: type=Unet/Transformer applied={applied} skipped={skipped} stats={stats}')
return transformer, None
"""


def load_transformer(file_path): # triggered by opts.sd_unet change
@@ -197,7 +200,9 @@ def load_transformer(file_path): # triggered by opts.sd_unet change
}
shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant={quant} dtype={devices.dtype}')
if 'gguf' in file_path.lower():
_transformer, _text_encoder_2 = load_flux_gguf(file_path)
# _transformer, _text_encoder_2 = load_flux_gguf(file_path)
from modules import ggml
_transformer = ggml.load_gguf(file_path, cls=diffusers.FluxTransformer2DModel, compute_dtype=devices.dtype)
if _transformer is not None:
transformer = _transformer
elif quant == 'qint8' or quant == 'qint4':
@@ -336,9 +341,14 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch
cls = diffusers.FluxPipeline
shared.log.debug(f'Load model: type=FLUX cls={cls.__name__} preloaded={list(kwargs)} revision={diffusers_load_config.get("revision", None)}')
for c in kwargs:
if getattr(kwargs[c], 'quantization_method', None) is not None or getattr(kwargs[c], 'gguf', None) is not None:
shared.log.debug(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} quant={getattr(kwargs[c], "quantization_method", None) or getattr(kwargs[c], "gguf", None)}')
if kwargs[c].dtype == torch.float32 and devices.dtype != torch.float32:
shared.log.warning(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast')
kwargs[c] = kwargs[c].to(dtype=devices.dtype)
try:
kwargs[c] = kwargs[c].to(dtype=devices.dtype)
shared.log.warning(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast')
except Exception:
pass

allow_quant = 'gguf' not in (sd_unet.loaded_unet or '')
fn = checkpoint_info.path
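
For context, `load_flux` passes the GGUF-loaded transformer into the pipeline via kwargs; the same flow expressed with plain diffusers, as a sketch assuming the standard FLUX.1-dev base repo (the GGUF path is illustrative):

import torch
import diffusers

transformer = diffusers.FluxTransformer2DModel.from_single_file(
    '/models/unet/flux1-dev-Q4_K_S.gguf',  # hypothetical pre-quantized GGUF checkpoint
    quantization_config=diffusers.GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = diffusers.FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',  # remaining components come from the base repo
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
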
13 changes: 10 additions & 3 deletions modules/model_sd3.py
@@ -13,7 +13,9 @@ def load_overrides(kwargs, cache_dir):
sd_unet.loaded_unet = shared.opts.sd_unet
shared.log.debug(f'Load model: type=SD3 unet="{shared.opts.sd_unet}" fmt=safetensors')
elif fn.endswith('.gguf'):
kwargs = load_gguf(kwargs, fn)
from modules import ggml
# kwargs = load_gguf(kwargs, fn)
kwargs['transformer'] = ggml.load_gguf(fn, cls=diffusers.SD3Transformer2DModel, compute_dtype=devices.dtype)
sd_unet.loaded_unet = shared.opts.sd_unet
shared.log.debug(f'Load model: type=SD3 unet="{shared.opts.sd_unet}" fmt=gguf')
except Exception as e:
@@ -90,8 +92,9 @@ def load_missing(kwargs, fn, cache_dir):
return kwargs


"""
def load_gguf(kwargs, fn):
model_te.install_gguf()
ggml.install_gguf()
from accelerate import init_empty_weights
from diffusers.loaders.single_file_utils import convert_sd3_transformer_checkpoint_to_diffusers
from modules import ggml, sd_hijack_accelerate
@@ -108,10 +111,12 @@ def load_gguf(kwargs, fn):
continue
applied += 1
sd_hijack_accelerate.hijack_set_module_tensor_simple(transformer, tensor_name=param_name, value=param, device=0)
transformer.gguf = 'gguf'
state_dict[param_name] = None
shared.log.debug(f'Load model: type=Unet/Transformer applied={applied} skipped={skipped} stats={stats} compute={devices.dtype}')
kwargs['transformer'] = transformer
return kwargs
"""


def load_sd3(checkpoint_info, cache_dir=None, config=None):
@@ -139,7 +144,9 @@ def load_sd3(checkpoint_info, cache_dir=None, config=None):
# kwargs = load_missing(kwargs, fn, cache_dir)
repo_id = fn
elif fn.endswith('.gguf'):
kwargs = load_gguf(kwargs, fn)
from modules import ggml
kwargs['transformer'] = ggml.load_gguf(fn, cls=diffusers.SD3Transformer2DModel, compute_dtype=devices.dtype)
# kwargs = load_gguf(kwargs, fn)
kwargs = load_missing(kwargs, fn, cache_dir)
kwargs['variant'] = 'fp16'
else:
19 changes: 3 additions & 16 deletions modules/model_te.py
@@ -12,20 +12,6 @@
loaded_te = None


def install_gguf():
# pip install git+https://github.com/junejae/transformers@feature/t5-gguf
install('gguf', quiet=True)
# https://github.com/ggerganov/llama.cpp/issues/9566
import gguf
scripts_dir = os.path.join(os.path.dirname(gguf.__file__), '..', 'scripts')
if os.path.exists(scripts_dir):
os.rename(scripts_dir, scripts_dir + '_gguf')
# monkey patch transformers so they detect gguf package correctly
import importlib
transformers.utils.import_utils._is_gguf_available = True # pylint: disable=protected-access
transformers.utils.import_utils._gguf_version = importlib.metadata.version('gguf') # pylint: disable=protected-access


def load_t5(name=None, cache_dir=None):
global loaded_te # pylint: disable=global-statement
if name is None:
@@ -34,8 +20,9 @@ def load_t5(name=None, cache_dir=None):
modelloader.hf_login()
repo_id = 'stabilityai/stable-diffusion-3-medium-diffusers'
fn = te_dict.get(name) if name in te_dict else None
if fn is not None and 'gguf' in name.lower():
install_gguf()
if fn is not None and name.lower().endswith('gguf'):
from modules import ggml
ggml.install_gguf()
with open(os.path.join('configs', 'flux', 'text_encoder_2', 'config.json'), encoding='utf8') as f:
t5_config = transformers.T5Config(**json.load(f))
t5 = transformers.T5EncoderModel.from_pretrained(None, gguf_file=fn, config=t5_config, device_map="auto", cache_dir=cache_dir, torch_dtype=devices.dtype)
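
The T5 text encoder stays on transformers' own GGUF support rather than the diffusers path; a reduced sketch of the call above (the GGUF file name is a hypothetical example, the config path is the one used in the diff):

import json
import torch
import transformers

with open('configs/flux/text_encoder_2/config.json', encoding='utf8') as f:
    t5_config = transformers.T5Config(**json.load(f))
t5 = transformers.T5EncoderModel.from_pretrained(
    None,  # no hub repo: all weights come from the GGUF file
    gguf_file='/models/text_encoder/t5-v1_1-xxl-Q5_K_M.gguf',  # hypothetical GGUF file
    config=t5_config,
    device_map='auto',
    torch_dtype=torch.bfloat16,
)
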
11 changes: 11 additions & 0 deletions modules/sd_detect.py
@@ -96,6 +96,17 @@ def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False):
guess = 'FLUX'
if size > 11000 and size < 16000:
warn(f'Model detected as FLUX UNET model, but attempting to load a base model: {op}={f} size={size} MB')
# guess for diffusers
index = os.path.join(f, 'model_index.json')
if os.path.exists(index) and os.path.isfile(index):
index = shared.readfile(index, silent=True)
cls = index.get('_class_name', None)
if cls is not None:
pipeline = getattr(diffusers, cls)
if 'Flux' in pipeline.__name__:
guess = 'FLUX'
if 'StableDiffusion3' in pipeline.__name__:
guess = 'Stable Diffusion 3'
# switch for specific variant
if guess == 'Stable Diffusion' and 'inpaint' in f.lower():
guess = 'Stable Diffusion Inpaint'
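
The folder-style detection keys off the `_class_name` field diffusers writes into `model_index.json`; a standalone sketch of the same branch, simplified to skip SD.Next's `shared.readfile` helper:

import json
import os
import diffusers

def guess_from_folder(path: str):
    index = os.path.join(path, 'model_index.json')
    if not os.path.isfile(index):
        return None
    with open(index, encoding='utf8') as f:
        cls = json.load(f).get('_class_name', None)  # e.g. 'FluxPipeline'
    if cls is None or not hasattr(diffusers, cls):
        return None
    pipeline = getattr(diffusers, cls)
    if 'Flux' in pipeline.__name__:
        return 'FLUX'
    if 'StableDiffusion3' in pipeline.__name__:
        return 'Stable Diffusion 3'
    return None
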
