Fix wrong scale eps applied
alexsamardzic committed Feb 24, 2025
1 parent 2a3fbff commit a66f34d
Showing 2 changed files with 37 additions and 1 deletion.
10 changes: 9 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -21,6 +21,9 @@
     quantize_affine,
     quantize_affine_floatx,
 )
+from torchao.quantization.utils import (
+    calculate_scale_eps_for_dtype,
+)
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
     TorchAOBaseTensor,
@@ -350,14 +353,19 @@ def from_hp_to_floatx(
     ):
         """Convert a high precision tensor to a float8 quantized tensor."""
         if target_dtype in FP8_TYPES:
+            eps = (
+                calculate_scale_eps_for_dtype(input_float.dtype)
+                if torch.is_floating_point(input_float)
+                else torch.finfo(torch.float32).eps
+            )
             return cls.from_hp_to_intx(
                 input_float=input_float,
                 mapping_type=MappingType.SYMMETRIC,
                 block_size=block_size,
                 target_dtype=target_dtype,
                 quant_min=math.ceil(torch.finfo(target_dtype).min),
                 quant_max=math.ceil(torch.finfo(target_dtype).max),
-                eps=torch.finfo(torch.float32).eps,
+                eps=eps,
                 scale_dtype=scale_dtype,
                 zero_point_dtype=None,
                 preserve_zero=True,
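The motivation, as a brief standalone sketch (an illustration of the failure mode, not code from this commit): the old code always clamped the scale with torch.finfo(torch.float32).eps, but when the high-precision input is float16 that eps is so small that its reciprocal overflows the float16 range.

    import torch

    # float32 eps (2**-23 ~= 1.19e-07) fits in float16 only as a tiny subnormal,
    # and its reciprocal (~8.4e6) is far above the float16 maximum of 65504.
    scale = torch.tensor(torch.finfo(torch.float32).eps, dtype=torch.float16)
    print(scale)               # tensor(1.1921e-07, dtype=torch.float16)
    print(scale.reciprocal())  # tensor(inf, dtype=torch.float16)

With the fix, the clamping eps is derived from the input's own dtype, so 1/scale stays finite.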
28 changes: 28 additions & 0 deletions torchao/quantization/utils.py
@@ -3,6 +3,7 @@

 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import functools
 import importlib.util
 from typing import Dict, List, Optional

@@ -41,6 +42,7 @@
     "per_token_dynamic_quant",
     "get_group_qparams_symmetric",
     "recommended_inductor_config_setter",
+    "calculate_scale_eps_for_dtype",
 ]

 _lm_eval_available = importlib.util.find_spec("lm_eval") is not None
@@ -587,3 +589,29 @@ def recommended_inductor_config_setter():
     torch._inductor.config.fx_graph_cache = True
     torch._inductor.config.triton.unique_kernel_names = True
     torch.set_float32_matmul_precision("high")
+
+
+@functools.lru_cache
+def calculate_scale_eps_for_dtype(dtype: torch.dtype):
+    assert torch.is_floating_point(torch.empty(0, dtype=dtype))
+
+    def predecessor(x: torch.Tensor):
+        assert x.numel() == 1
+
+        dtype = x.dtype
+        if dtype == torch.float16:
+            zero = torch.tensor(0, dtype=dtype)
+        else:
+            zero = torch.tensor(0.0, dtype=dtype)
+        return torch.nextafter(x, zero)
+
+    x = torch.tensor(torch.finfo(dtype).max, dtype=dtype)
+    x_rec = 1.0 / x
+    while True:
+        if torch.any(torch.isinf(x_rec.reciprocal())).item():
+            x = predecessor(x)
+            x_rec = 1.0 / x
+        else:
+            break
+
+    return x_rec.item()

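A quick sanity check of the new helper (a sketch, assuming torchao is installed with this change so that calculate_scale_eps_for_dtype is importable from torchao.quantization.utils): by construction, the returned value is a small positive scale whose reciprocal is still finite in the given dtype.

    import torch

    from torchao.quantization.utils import calculate_scale_eps_for_dtype

    for dtype in (torch.float16, torch.float32):
        eps = calculate_scale_eps_for_dtype(dtype)
        recip = torch.tensor(eps, dtype=dtype).reciprocal()
        # Both the clamped scale and its reciprocal stay representable in `dtype`.
        assert eps > 0 and torch.isfinite(recip).item()
        print(f"{dtype}: eps={eps:.4e}, 1/eps={recip.item():.4e}")

For float16 this lands near 1/65504 rather than float32's ~1.2e-7, so the reciprocal no longer overflows.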