Fix wrong scale eps applied #1770

Open · wants to merge 1 commit into main
10 changes: 9 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -21,6 +21,9 @@
     quantize_affine,
     quantize_affine_floatx,
 )
+from torchao.quantization.utils import (
+    calculate_scale_eps_for_dtype,
+)
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
     TorchAOBaseTensor,
@@ -350,14 +353,19 @@ def from_hp_to_floatx(
     ):
         """Convert a high precision tensor to a float8 quantized tensor."""
         if target_dtype in FP8_TYPES:
+            eps = (
+                calculate_scale_eps_for_dtype(input_float.dtype)
+                if torch.is_floating_point(input_float)
+                else torch.finfo(torch.float32).eps
+            )
             return cls.from_hp_to_intx(
                 input_float=input_float,
                 mapping_type=MappingType.SYMMETRIC,
                 block_size=block_size,
                 target_dtype=target_dtype,
                 quant_min=math.ceil(torch.finfo(target_dtype).min),
                 quant_max=math.ceil(torch.finfo(target_dtype).max),
-                eps=torch.finfo(torch.float32).eps,
+                eps=eps,
                 scale_dtype=scale_dtype,
                 zero_point_dtype=None,
                 preserve_zero=True,
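
For context, a minimal sketch (not part of the diff) of the problem this change targets, assuming the reciprocal of the scale is later taken in the input dtype: torch.finfo(torch.float32).eps is a subnormal value in float16, so flooring a float16 scale with it lets 1 / scale overflow to inf.

import torch

# float32 eps is ~1.19e-7; in float16 that is subnormal, and its reciprocal
# (~8.4e6) exceeds the float16 max (~65504), so it overflows to inf.
f32_eps = torch.finfo(torch.float32).eps
scale_fp16 = torch.tensor(f32_eps, dtype=torch.float16)
print(scale_fp16.reciprocal())  # tensor(inf, dtype=torch.float16)

Deriving the floor from the input dtype, as the new calculate_scale_eps_for_dtype helper below does, keeps the reciprocal finite.
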
28 changes: 28 additions & 0 deletions torchao/quantization/utils.py
@@ -3,6 +3,7 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import functools
 import importlib.util
 from typing import Dict, List, Optional
 
@@ -41,6 +42,7 @@
     "per_token_dynamic_quant",
     "get_group_qparams_symmetric",
     "recommended_inductor_config_setter",
+    "calculate_scale_eps_for_dtype",
 ]
 
 _lm_eval_available = importlib.util.find_spec("lm_eval") is not None
@@ -587,3 +589,29 @@ def recommended_inductor_config_setter():
     torch._inductor.config.fx_graph_cache = True
     torch._inductor.config.triton.unique_kernel_names = True
     torch.set_float32_matmul_precision("high")
+
+
+@functools.lru_cache
+def calculate_scale_eps_for_dtype(dtype: torch.dtype):
+    assert torch.is_floating_point(torch.empty(0, dtype=dtype))
+
+    def predecessor(x: torch.Tensor):
+        assert x.numel() == 1
+
+        dtype = x.dtype
+        if dtype == torch.float16:
+            zero = torch.tensor(0, dtype=dtype)
+        else:
+            zero = torch.tensor(0.0, dtype=dtype)
+        return torch.nextafter(x, zero)
+
+    x = torch.tensor(torch.finfo(dtype).max, dtype=dtype)
+    x_rec = 1.0 / x
+    while True:
+        if torch.any(torch.isinf(x_rec.reciprocal())).item():
+            x = predecessor(x)
+            x_rec = 1.0 / x
+        else:
+            break
+
+    return x_rec.item()
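
For context (not part of the diff), a usage sketch of the new helper: it steps x down from the dtype's maximum with nextafter until the reciprocal of 1 / x is finite, and returns that 1 / x, so the cached result is a scale floor whose reciprocal does not overflow in the given dtype.

import torch
from torchao.quantization.utils import calculate_scale_eps_for_dtype

# Sketch: the returned eps should survive a round trip through the dtype
# with a finite reciprocal, unlike the float32 eps used before this change.
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    eps = calculate_scale_eps_for_dtype(dtype)
    recip = torch.tensor(eps, dtype=dtype).reciprocal()
    print(dtype, eps, torch.isinf(recip).item())  # expect isinf == False
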