From a66f34dd22520877713491399c343c48f5da0b5c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?=
Date: Mon, 24 Feb 2025 19:02:22 +0100
Subject: [PATCH] Fix wrong scale eps applied

---
 torchao/dtypes/affine_quantized_tensor.py | 10 +++++++-
 torchao/quantization/utils.py             | 28 +++++++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 715aaeb9ec..d18d442245 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -21,6 +21,9 @@
     quantize_affine,
     quantize_affine_floatx,
 )
+from torchao.quantization.utils import (
+    calculate_scale_eps_for_dtype,
+)
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
     TorchAOBaseTensor,
@@ -350,6 +353,11 @@ def from_hp_to_floatx(
     ):
         """Convert a high precision tensor to a float8 quantized tensor."""
         if target_dtype in FP8_TYPES:
+            eps = (
+                calculate_scale_eps_for_dtype(input_float.dtype)
+                if torch.is_floating_point(input_float)
+                else torch.finfo(torch.float32).eps
+            )
             return cls.from_hp_to_intx(
                 input_float=input_float,
                 mapping_type=MappingType.SYMMETRIC,
@@ -357,7 +365,7 @@ def from_hp_to_floatx(
                 target_dtype=target_dtype,
                 quant_min=math.ceil(torch.finfo(target_dtype).min),
                 quant_max=math.ceil(torch.finfo(target_dtype).max),
-                eps=torch.finfo(torch.float32).eps,
+                eps=eps,
                 scale_dtype=scale_dtype,
                 zero_point_dtype=None,
                 preserve_zero=True,
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index 74c136ad00..1f6440f9f6 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import functools
 import importlib.util
 from typing import Dict, List, Optional
 
@@ -41,6 +42,7 @@
     "per_token_dynamic_quant",
     "get_group_qparams_symmetric",
     "recommended_inductor_config_setter",
+    "calculate_scale_eps_for_dtype",
 ]
 
 _lm_eval_available = importlib.util.find_spec("lm_eval") is not None
@@ -587,3 +589,29 @@ def recommended_inductor_config_setter():
     torch._inductor.config.fx_graph_cache = True
     torch._inductor.config.triton.unique_kernel_names = True
     torch.set_float32_matmul_precision("high")
+
+
+@functools.lru_cache
+def calculate_scale_eps_for_dtype(dtype: torch.dtype):
+    assert torch.is_floating_point(torch.empty(0, dtype=dtype))
+
+    def predecessor(x: torch.Tensor):
+        assert x.numel() == 1
+
+        dtype = x.dtype
+        if dtype == torch.float16:
+            zero = torch.tensor(0, dtype=dtype)
+        else:
+            zero = torch.tensor(0.0, dtype=dtype)
+        return torch.nextafter(x, zero)
+
+    x = torch.tensor(torch.finfo(dtype).max, dtype=dtype)
+    x_rec = 1.0 / x
+    while True:
+        if torch.any(torch.isinf(x_rec.reciprocal())).item():
+            x = predecessor(x)
+            x_rec = 1.0 / x
+        else:
+            break
+
+    return x_rec.item()
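
Why the eps matters: torch.finfo(torch.float32).eps (2**-23, about 1.19e-07) is
representable in float16, but its reciprocal (about 8.4e+06) overflows float16's
maximum of 65504, so clamping a float16 scale with the float32 eps can still
produce inf when 1 / scale is taken. The new helper instead searches for the
smallest eps whose reciprocal stays finite in the given dtype, starting from the
reciprocal of the dtype's largest finite value and stepping toward zero with
torch.nextafter. A minimal sketch of how this can be checked, assuming the patch
above has been applied to torchao:

    import torch
    from torchao.quantization.utils import calculate_scale_eps_for_dtype

    eps = calculate_scale_eps_for_dtype(torch.float16)
    print(eps)  # smallest float16 eps whose reciprocal is still finite

    # The reciprocal of the returned eps stays finite in float16 ...
    assert not torch.isinf(1.0 / torch.tensor(eps, dtype=torch.float16))
    # ... while the reciprocal of the old float32 eps overflows to inf.
    assert torch.isinf(
        1.0 / torch.tensor(torch.finfo(torch.float32).eps, dtype=torch.float16)
    )

Since the helper is wrapped in functools.lru_cache, the search loop runs only
once per dtype.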