Merge pull request #21 from bobboli/main
[minor] Fix some minor issues
synxlin authored Nov 7, 2024
2 parents 58a3a16 + 9abc1d1 commit ca25987
Showing 4 changed files with 19 additions and 9 deletions.
20 changes: 14 additions & 6 deletions lmquant/llm/dataset.py
@@ -139,17 +139,25 @@ def _pre_layer_kwargs_hook(
     kwargs: dict[str, tp.Any],
     kwargs_cache: dict[str, tp.Any],
 ) -> None:
+    def _check_equality(_k, _v, _cached):
+        if isinstance(_v, DynamicCache):
+            assert _cached is None, f"kwargs_cache[{_k}] should be None"
+        elif isinstance(_v, torch.Tensor):
+            assert _v.allclose(_cached), f"kwargs_cache[{_k}] should be the same as kwargs[{_k}]"
+        elif isinstance(_v, tuple):
+            assert len(_v) == len(
+                _cached), f"kwargs_cache[{_k}] is a tuple, and should have the same length as kwargs[{_k}]"
+            for i in range(len(_v)):
+                _check_equality(_k, _v[i], _cached[i])
+        else:
+            assert _v == _cached, f"kwargs_cache[{_k}] should be the same as {_v}"
+
     if kwargs_cache:
         assert len(kwargs_cache) == len(kwargs), "kwargs_cache should have the same length as kwargs"
         for k, v in kwargs.items():
             assert k in kwargs_cache, f"kwargs_cache should have the same keys as kwargs, but missing {k}"
             cached = kwargs_cache[k]
-            if isinstance(v, DynamicCache):
-                assert cached is None, f"kwargs_cache[{k}] should be None"
-            elif isinstance(v, torch.Tensor):
-                assert v.allclose(cached), f"kwargs_cache[{k}] should be the same as kwargs[{k}]"
-            else:
-                assert v == cached, f"kwargs_cache[{k}] should be the same as kwargs[{k}]"
+            _check_equality(k, v, cached)
     else:
         for k, v in kwargs.items():
             if isinstance(v, DynamicCache):
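The new helper recurses because some layer kwargs arrive as tuples of tensors, which the old flat comparison could not handle: a bare v == cached on a tuple of tensors raises, since tensor equality does not reduce to a single bool. A minimal, self-contained sketch of the same pattern (the DynamicCache branch is omitted, and the kwarg name "position_embeddings" is only an illustrative assumption, not taken from this diff):

import torch

def check_equality(key, value, cached):
    # Tensors compare by value; tuples recurse element-wise; everything else uses ==.
    if isinstance(value, torch.Tensor):
        assert value.allclose(cached), f"kwargs_cache[{key}] should match kwargs[{key}]"
    elif isinstance(value, tuple):
        assert len(value) == len(cached), f"kwargs_cache[{key}] should have the same length"
        for v_i, c_i in zip(value, cached):
            check_equality(key, v_i, c_i)
    else:
        assert value == cached, f"kwargs_cache[{key}] should equal {value}"

# Rotary position embeddings are commonly passed as a (cos, sin) tuple of tensors.
cos_sin = (torch.randn(1, 8), torch.randn(1, 8))
check_equality("position_embeddings", cos_sin, tuple(t.clone() for t in cos_sin))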
2 changes: 1 addition & 1 deletion lmquant/llm/run.py
@@ -75,7 +75,7 @@ def run(  # noqa: C901
     # region rotate model
     if needs_rotation:
         logger.info(f"* Building model {config.model.name} from {config.model.path}")
-        model, tokenizer = config.model.build(dtype=torch.float32, cpu=config.model.size > 30)
+        model, tokenizer = config.model.build(dtype=torch.float32)
         model = LlmModelStruct.build(model)
         config.quant.num_hidden_layers = model.config.num_hidden_layers
         if config.quant.develop_dtype is None:
4 changes: 3 additions & 1 deletion lmquant/quant/functional/gptq.py
@@ -98,21 +98,23 @@ def gptq_quantize(  # noqa: C901
     # endregion
     # region step 5: get the inverse of the Hessian matrix
     stable_inv, num_inv_tries = False, 0
+    hessian_inv = None
     while (not stable_inv) and num_inv_tries < gptq_config.num_inv_tries:
         num_inv_tries += 1
         try:
             hessian_inv = torch.linalg.cholesky(hessian)
             hessian_inv = torch.cholesky_inverse(hessian_inv)
             hessian_inv = torch.linalg.cholesky(hessian_inv, upper=True)
         except RuntimeError:
-            hessian_diag += (gptq_config.damp_percentage * 0.1) * hessian_diag_mean
+            hessian_diag += gptq_config.damp_percentage * hessian_diag_mean
             continue
         stable_inv = True
     if num_inv_tries > 1:
         logger = logging.getLogger(f"{__name__}.GPTQ")
         logger.debug(
             " - GPTQ Hessian is not stable %s %d tries.", "until" if stable_inv else "after", num_inv_tries
         )
+    assert stable_inv and hessian_inv is not None, "GPTQ Hessian is not stable! Consider increase damp_percentage."
     assert not hessian_inv.isinf().any(), "Inverse of Hessian matrix contains Inf."
     assert not hessian_inv.isnan().any(), "Inverse of Hessian matrix contains NaN."
     del hessian, hessian_diag, hessian_diag_mean, num_inv_tries
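The three chained calls factor H, invert it from its Cholesky factor, and return the upper Cholesky factor of H^-1, which GPTQ's column-by-column weight update consumes; the new hessian_inv = None plus the final assert guard against the loop exhausting its tries without ever producing a factor. A minimal, self-contained sketch of this retry loop under assumed names (damp_percentage and num_inv_tries mirror the config fields above; the function name and defaults are illustrative):

import torch

def inverse_hessian_cholesky(
    hessian: torch.Tensor, damp_percentage: float = 0.01, num_inv_tries: int = 10
) -> torch.Tensor:
    """Return the upper Cholesky factor of hessian^-1, damping the diagonal on failure."""
    hessian = hessian.clone()          # keep the caller's matrix intact
    hessian_diag = hessian.diagonal()  # a view: in-place updates modify `hessian`
    hessian_diag_mean = hessian_diag.mean()
    for _ in range(num_inv_tries):
        try:
            chol = torch.linalg.cholesky(hessian)                  # H = L L^T
            hessian_inv = torch.cholesky_inverse(chol)             # H^-1 from L
            return torch.linalg.cholesky(hessian_inv, upper=True)  # H^-1 = U^T U
        except RuntimeError:  # factorization failed: H not positive definite yet
            hessian_diag += damp_percentage * hessian_diag_mean
    raise RuntimeError("GPTQ Hessian is not stable! Consider increasing damp_percentage.")

# Usage: a ridge-regularized Gram matrix is positive definite, so one try suffices.
x = torch.randn(64, 16, dtype=torch.float64)
u = inverse_hessian_cholesky(x.T @ x + 1e-3 * torch.eye(16, dtype=torch.float64))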
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -17,7 +17,7 @@ lm_eval = ">= 0.4.2"
 accelerate = ">= 0.26.0"
 datasets = ">= 2.16.0"
 sentencepiece = ">= 0.1.99"
-omniconfig = ">= 0.1.5"
+omniconfig = "== 0.1.5"
 protobuf = ">= 5.26.0"

 [tool.poetry.group.dev.dependencies]
