-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
294 changed files
with
127,954 additions
and
17,329 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -163,7 +163,13 @@ cython_debug/ | |
.vscode/ | ||
!.vscode/settings.json | ||
|
||
.DS_Store | ||
*.log | ||
*.pt | ||
.tmp/ | ||
runs | ||
exps | ||
runs/ | ||
exps/ | ||
wandb | ||
wandb/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
File renamed without changes.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
File renamed without changes.
File renamed without changes.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .config import DiffusionPtqCacheConfig, DiffusionQuantCacheConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Diffusion model quantization cache configuration.""" | ||
|
||
import functools | ||
import re | ||
import typing as tp | ||
from dataclasses import dataclass, field | ||
|
||
from omniconfig import configclass | ||
|
||
from deepcompressor.utils.config.path import BasePathConfig | ||
|
||
from ..nn.struct import DiffusionModelStruct | ||
|
||
__all__ = ["DiffusionQuantCacheConfig", "DiffusionPtqCacheConfig"] | ||
|
||
|
||
@dataclass
class DiffusionQuantCacheConfig(BasePathConfig):
    """Denoising diffusion model quantization cache path.

    Args:
        smooth (`str`, *optional*, default=`""`):
            The smoothing scales cache path.
        branch (`str`, *optional*, default=`""`):
            The low-rank branches cache path.
        wgts (`str`, *optional*, default=`""`):
            The weight quantizers state dict cache path.
        acts (`str`, *optional*, default=`""`):
            The activation quantizers state dict cache path.
    """

    smooth: str = ""
    branch: str = ""
    wgts: str = ""
    acts: str = ""

    @staticmethod
    def simplify_path(path: str, key_map: dict[str, set[str]]) -> str:
        """Simplify the cache path.

        Every ``skip.[k1+k2+...]`` / ``include.[k1+k2+...]`` segment of ``path`` is rewritten:
        the ``+``-separated keys are passed through ``DiffusionModelStruct._simplify_keys`` and
        each returned key is abbreviated to the first character of its ``_``-separated words.
        The rest of the path is left untouched.

        Args:
            path (`str`):
                The cache path to simplify.
            key_map (`dict[str, set[str]]`):
                The key map forwarded to ``DiffusionModelStruct._simplify_keys``.

        Returns:
            `str`:
                The path with all matching segments simplified.
        """

        def _shorten(match: re.Match) -> str:
            # group(1) is the "skip"/"include" prefix; the keys sit between ".[" and "]"
            prefix = match.group(1)
            keys = match.group(0)[len(prefix) + 2 : -1].split("+")
            simplified = DiffusionModelStruct._simplify_keys(keys, key_map=key_map)
            initials = "+".join("".join(word[0] for word in key.split("_")) for key in simplified)
            return f"{prefix}.[{initials}]"

        # A single substitution pass rewrites each matched segment exactly once, so the output
        # of one rewrite can never be picked up and rewritten again by another.
        return re.sub(r"(skip|include)\.\[[a-zA-Z0-9_\+]+\]", _shorten, path)

    def simplify(self, key_map: dict[str, set[str]]) -> tp.Self:
        """Simplify the cache paths.

        Args:
            key_map (`dict[str, set[str]]`):
                The key map forwarded to `simplify_path`.

        Returns:
            `Self`:
                The result of applying `simplify_path` through ``self.apply``.
        """
        return self.apply(functools.partial(self.simplify_path, key_map=key_map))
|
||
|
||
@configclass
@dataclass
class DiffusionPtqCacheConfig:
    """Cache configuration for post-training quantization of a diffusion model.

    Args:
        root (`str`):
            The cache root path.

    Attributes:
        dirpath (`DiffusionQuantCacheConfig`):
            Not an ``__init__`` argument; must be assigned after construction.
        path (`DiffusionQuantCacheConfig`):
            Not an ``__init__`` argument; must be assigned after construction.
    """

    root: str
    # init=False and no default: these two attributes do not exist until they are assigned
    # explicitly after the instance has been created.
    dirpath: DiffusionQuantCacheConfig = field(init=False)
    path: DiffusionQuantCacheConfig = field(init=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Top-level config of post-training quantization for a diffusion model.""" | ||
|
||
import os | ||
from dataclasses import dataclass, field | ||
|
||
import diffusers.training_utils | ||
import omniconfig | ||
import torch | ||
from omniconfig import ConfigParser, configclass | ||
|
||
from deepcompressor.app.llm.config import LlmCacheConfig, LlmQuantConfig | ||
from deepcompressor.data.utils import ScaleUtils | ||
from deepcompressor.utils.config.output import OutputConfig | ||
|
||
from .cache import DiffusionPtqCacheConfig, DiffusionQuantCacheConfig | ||
from .eval import DiffusionEvalConfig | ||
from .nn.struct import DiffusionModelStruct | ||
from .pipeline import DiffusionPipelineConfig | ||
from .quant import DiffusionQuantConfig | ||
|
||
__all__ = [ | ||
"DiffusionPtqRunConfig", | ||
"DiffusionPtqCacheConfig", | ||
"DiffusionQuantCacheConfig", | ||
"DiffusionEvalConfig", | ||
"DiffusionPipelineConfig", | ||
"DiffusionQuantConfig", | ||
] | ||
|
||
|
||
@configclass
@dataclass
class DiffusionPtqRunConfig:
    """Top-level config of post-training quantization for a diffusion model.

    Args:
        cache (`DiffusionPtqCacheConfig` or `None`):
            The cache configuration.
        output (`OutputConfig`):
            The output directory configuration.
        pipeline (`DiffusionPipelineConfig`):
            The diffusion pipeline configuration.
        eval (`DiffusionEvalConfig`):
            The evaluation configuration.
        quant (`DiffusionQuantConfig`):
            The post-training quantization configuration.
        text (`LlmQuantConfig` or `None`, *optional*, defaults to `None`):
            The text encoder quantization configuration.
        text_cache (`LlmCacheConfig`, *optional*, defaults to `LlmCacheConfig()`):
            The text encoder quantization cache configuration.
        seed (`int`, *optional*, defaults to `12345`):
            The seed for reproducibility.
        skip_gen (`bool`, *optional*, defaults to `False`):
            Whether to skip generation.
        skip_eval (`bool`, *optional*, defaults to `False`):
            Whether to skip evaluation.
        load_from (`str`, *optional*, defaults to `""`):
            Directory path to load the model checkpoint.
        save_model (`str`, *optional*, defaults to `""`):
            Directory path to save the model checkpoint.
        copy_on_save (`bool`, *optional*, defaults to `False`):
            Whether to copy the quantization cache on save.
    """

    cache: DiffusionPtqCacheConfig | None
    output: OutputConfig
    pipeline: DiffusionPipelineConfig
    eval: DiffusionEvalConfig
    quant: DiffusionQuantConfig = field(metadata={omniconfig.ARGPARSE_KWARGS: {"prefix": ""}})
    text: LlmQuantConfig | None = None
    text_cache: LlmCacheConfig = field(default_factory=LlmCacheConfig)
    seed: int = 12345
    skip_gen: bool = False
    skip_eval: bool = False
    load_from: str = ""
    save_model: str = ""
    copy_on_save: bool = False

    def __post_init__(self):
        """Resolve derived settings: scale dtypes, batch sizes, paths, cache/output dirs and seed."""
        # region set text encoder quantization scale default dtype
        if self.text is not None:
            for kind in ("wgts", "ipts", "opts"):
                if getattr(self.text, f"enabled_{kind}"):
                    quantizer = getattr(self.text, kind)
                    quantizer.scale_dtypes = tuple(
                        ScaleUtils.infer_scale_dtypes(quantizer.scale_dtypes, default_dtype=self.pipeline.dtype)
                    )
        # endregion
        # region setup evaluation batch size
        # never request more GPUs than are visible
        self.eval.num_gpus = min(torch.cuda.device_count(), self.eval.num_gpus)
        if self.eval.batch_size_per_gpu is None:
            # NOTE(review): num_gpus is 0 when no CUDA device is visible, which makes this
            # division raise ZeroDivisionError -- confirm whether CPU-only runs must be supported.
            self.eval.batch_size_per_gpu = max(1, self.eval.batch_size // self.eval.num_gpus)
        # the global batch size is always re-derived from the per-GPU batch size
        self.eval.batch_size = self.eval.batch_size_per_gpu * self.eval.num_gpus
        # endregion
        # region setup calib dataset path
        self.quant.calib.path = self.quant.calib.path.format(
            dtype=self.pipeline.dtype,
            family=self.pipeline.family,
            model=self.pipeline.name,
            protocol=self.eval.protocol,
            data=self.quant.calib.data,
        )
        if self.quant.calib.path:
            self.quant.calib.path = os.path.abspath(os.path.expanduser(self.quant.calib.path))
        # endregion
        # region setup eval reference root
        self.eval.ref_root = self.eval.ref_root.format(
            dtype=self.pipeline.dtype,
            family=self.pipeline.family,
            model=self.pipeline.name,
            protocol=self.eval.protocol,
        )
        if self.eval.ref_root:
            self.eval.ref_root = os.path.abspath(os.path.expanduser(self.eval.ref_root))
        # endregion
        # region setup cache directory
        if self.cache is not None:
            if self.quant.enabled_wgts or self.quant.enabled_ipts or self.quant.enabled_opts:
                self.cache.dirpath = self.quant.generate_cache_dirpath(
                    root=self.cache.root, shift=self.pipeline.shift_activations, default_dtype=self.pipeline.dtype
                )
                self.cache.path = self.cache.dirpath.clone().add_children(f"{self.pipeline.name}.pt")
            else:
                self.cache.dirpath = self.cache.path = None
        if self.text is not None and self.text.is_enabled():
            if not self.text_cache.root:
                # NOTE(review): self.cache.root is read here although `cache` may be None
                # (see the guard above) -- confirm that `cache` or `text_cache.root` is always
                # provided when text encoder quantization is enabled.
                self.text_cache.root = os.path.join(self.cache.root, "diffusion")
            self.text_cache.dirpath = self.text.generate_cache_dirpath(root=self.text_cache.root, seed=self.seed)
            self.text_cache.path = self.text_cache.dirpath.clone().add_children(f"{self.pipeline.name}.pt")
        # endregion
        # region setup output directory
        if self.output.dirname == "reference":
            # a "reference" run writes straight into the reference root, so the reference root
            # is cleared and the generation root points at the output directory instead
            assert self.eval.ref_root
            self.output.job = f"run-{self.eval.num_samples}"
            self.output.dirpath = self.eval.ref_root
            self.eval.ref_root = ""
            self.eval.gen_root = "{output}"
        else:
            if self.output.dirname == "default":
                self.output.dirname = self.generate_default_dirname()
            calib_dirname = self.quant.generate_calib_dirname() or "-"
            self.output.dirpath = os.path.join(
                self.output.root,
                "diffusion",
                self.pipeline.family,
                self.pipeline.name,
                *self.quant.generate_dirnames(default_dtype=self.pipeline.dtype)[:-1],
                calib_dirname,
                self.output.dirname,
            )
        if (self.eval.chunk_start > 0 or self.eval.chunk_step > 1) and not self.eval.chunk_only:
            self.output.job += f".c{self.eval.chunk_start}.{self.eval.chunk_step}"
        # endregion
        diffusers.training_utils.set_seed(self.seed)

    def generate_default_dirname(self) -> str:
        """Generate the default output directory name from the run settings.

        The name is assembled from dash-prefixed components (activation shift, quantization,
        text encoder quantization, image size, sampling settings, number of samples), with an
        optional chunk suffix; the leading dash is stripped from the result.

        Returns:
            `str`:
                The default output directory name.

        Raises:
            AssertionError:
                If the assembled name does not start with a dash, e.g. because no component
                contributed to it.
        """
        name = "-shift" if self.pipeline.shift_activations else ""
        if self.quant.is_enabled():
            name += f"-{self.quant.generate_default_dirname()}"
        if self.text is not None and self.text.is_enabled():
            name += f"-text-{self.text.generate_default_dirname()}"
        size_name = ""
        if self.eval.height:
            size_name += f".h{self.eval.height}"
        if self.eval.width:
            size_name += f".w{self.eval.width}"
        if size_name:
            name += f"-{size_name[1:]}"
        sampling_name = ""
        if self.eval.num_steps is not None:
            sampling_name += f".t{self.eval.num_steps}"
        if self.eval.guidance_scale is not None:
            sampling_name += f".g{self.eval.guidance_scale}"
        if sampling_name:
            name += f"-{sampling_name[1:]}"
        if self.eval.num_samples != -1:
            name += f"-s{self.eval.num_samples}"
        if self.eval.chunk_only:
            name += f".c{self.eval.chunk_start}.{self.eval.chunk_step}"
        # startswith() also covers the empty name, which would otherwise surface as an
        # unrelated IndexError from name[0]
        assert name.startswith("-"), f"cannot derive a default directory name from {name!r}"
        return name[1:]

    @classmethod
    def get_parser(cls) -> ConfigParser:
        """Get a parser for post-training quantization of a diffusion model.

        Returns:
            `ConfigParser`:
                A parser for post-training quantization of a diffusion model.
        """
        parser = ConfigParser("Diffusion Run configuration")
        DiffusionQuantConfig.set_key_map(DiffusionModelStruct._get_default_key_map())
        parser.add_config(cls)
        return parser
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
from .base import DiffusionDataset | ||
from .calib import DiffusionCalibCacheLoader, DiffusionCalibCacheLoaderConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Dataset for diffusion models.""" | ||
|
||
import os | ||
import random | ||
import typing as tp | ||
|
||
import numpy as np | ||
import torch | ||
import torch.utils.data | ||
from torch.nn import functional as F | ||
|
||
from deepcompressor.utils.common import tree_collate, tree_map | ||
|
||
__all__ = ["DiffusionDataset"] | ||
|
||
|
||
class DiffusionDataset(torch.utils.data.Dataset):
    """Dataset of cached diffusion model inputs stored as one file per sample.

    Args:
        path (`str`):
            Directory holding the sample files. If it contains an entry named ``caches``,
            the sample files are read from that sub-directory instead.
        num_samples (`int`, *optional*, defaults to `-1`):
            Maximum number of samples to keep. A non-positive value, or a value not smaller
            than the number of available files, keeps all of them.
        seed (`int`, *optional*, defaults to `0`):
            Seed of the shuffle used to pick the subset when `num_samples` applies.
        ext (`str`, *optional*, defaults to `".npy"`):
            File name suffix of the sample files.

    Raises:
        ValueError:
            If `path` is not an existing directory.
    """

    path: str
    filenames: list[str]
    filepaths: list[str]

    def __init__(self, path: str, num_samples: int = -1, seed: int = 0, ext: str = ".npy") -> None:
        # isdir() instead of exists(): a regular file is not a valid data path either and
        # must be rejected here rather than failing later inside os.listdir()
        if not os.path.isdir(path):
            raise ValueError(f"Invalid data path: {path}")
        # keep the original directory: __getitem__ resolves "latents"/"text_embs" against it
        self.path = path
        if "caches" in os.listdir(path):
            path = os.path.join(path, "caches")
        filenames = sorted(f for f in os.listdir(path) if f.endswith(ext))
        if 0 < num_samples < len(filenames):
            # seeded shuffle of the sorted list gives a reproducible random subset
            random.Random(seed).shuffle(filenames)
            filenames = sorted(filenames[:num_samples])
        self.filenames = filenames
        self.filepaths = [os.path.join(path, f) for f in filenames]

    def __len__(self) -> int:
        """Return the number of selected sample files."""
        return len(self.filepaths)

    def __getitem__(self, idx) -> dict[str, tp.Any]:
        """Load the sample at `idx` and convert its arrays to tensors.

        String placeholders for the first positional input and for ``encoder_hidden_states``
        are replaced by the arrays loaded from the ``latents`` and ``text_embs``
        sub-directories of the dataset directory.

        Returns:
            `dict[str, Any]`:
                The sample, with arrays mapped through ``torch.from_numpy``.
        """
        # NOTE: allow_pickle=True lets numpy unpickle arbitrary objects stored in the file;
        # only load cache files that come from a trusted source.
        data = np.load(self.filepaths[idx], allow_pickle=True).item()
        if isinstance(data["input_args"][0], str):
            name = data["input_args"][0]
            latent = np.load(os.path.join(self.path, "latents", name))
            data["input_args"][0] = latent
        if isinstance(data["input_kwargs"]["encoder_hidden_states"], str):
            name = data["input_kwargs"]["encoder_hidden_states"]
            text_emb = np.load(os.path.join(self.path, "text_embs", name))
            data["input_kwargs"]["encoder_hidden_states"] = text_emb
        data = tree_map(lambda x: torch.from_numpy(x), data)

        # Pad encoder_hidden_states at the end of its second-to-last dimension so that its
        # shape[1] matches encoder_attention_mask.shape[1] (e.g. 300 for pixart).
        # NOTE(review): assumes encoder_hidden_states is 3-dimensional -- confirm.
        if "encoder_attention_mask" in data["input_kwargs"]:
            encoder_attention_mask = data["input_kwargs"]["encoder_attention_mask"]
            encoder_hidden_states = data["input_kwargs"]["encoder_hidden_states"]
            encoder_hidden_states = F.pad(
                encoder_hidden_states,
                (0, 0, 0, encoder_attention_mask.shape[1] - encoder_hidden_states.shape[1]),
            )
            data["input_kwargs"]["encoder_hidden_states"] = encoder_hidden_states

        return data

    def build_loader(self, **kwargs) -> torch.utils.data.DataLoader:
        """Build a ``DataLoader`` over this dataset that collates with ``tree_collate``.

        Args:
            **kwargs:
                Extra keyword arguments forwarded to ``torch.utils.data.DataLoader``.
        """
        return torch.utils.data.DataLoader(self, collate_fn=tree_collate, **kwargs)
Oops, something went wrong.