[QST] About NaNs generated during FP16->FP8 quantization #1766

Open
alexsamardzic opened this issue Feb 23, 2025 · 5 comments · May be fixed by #1770

@alexsamardzic
Collaborator

The reproducer:

import torch
import torchao

from torchao.dtypes import Float8Layout, to_affine_quantized_floatx

x = torch.tensor([[0, 0, 0.1, 0.1]], dtype=torch.float16)

x_aqt = to_affine_quantized_floatx(
    x,
    target_dtype=torch.float8_e5m2,
    block_size=[1, x.shape[1]],
    _layout=Float8Layout(mm_config=None),
)
xq, scale, _ = x_aqt.tensor_impl.get_plain()

print("x =", x)
print("xq =", xq)
print("scale =", scale)

The output is:

x = tensor([[0.0000, 0.0000, 0.1000, 0.1000]], dtype=torch.float16)
xq = tensor([[   nan,    nan, 57344., 57344.]], dtype=torch.float8_e5m2)
scale = tensor([1.7285e-06], dtype=torch.float16)

Basically, the problem is that the quantization code maps the [0, 0.1] range to [0, 57344] (57344 being the maximum value of the torch.float8_e5m2 data type), so the scale gets very small; its reciprocal here then becomes Inf, and 0*Inf produces NaNs as quantized values.

This is all, of course, simply about the range and precision of the data types involved, but I was wondering: is this a known issue? Would it make sense to force 0 as the input * scale.reciprocal() result here, wherever the corresponding input elements are 0?
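
For reference, the failing arithmetic can be reproduced outside torchao; below is a minimal standalone sketch (not the actual quantization code path), with approximate values in the comments:

import torch

amax = torch.tensor(0.1, dtype=torch.float16)
fp8_max = torch.finfo(torch.float8_e5m2).max   # 57344.0 for e5m2
scale = amax / fp8_max                         # ~1.7e-06, a subnormal fp16 value
recip = scale.reciprocal()                     # true value ~5.8e5 > 65504 (fp16 max) -> inf
zero = torch.tensor(0.0, dtype=torch.float16)
print(scale, recip, zero * recip)              # ~1.7e-06, inf, nan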

@vkuzo
Contributor

vkuzo commented Feb 24, 2025

I'd recommend adding an epsilon when calculating the scale. For float8 training, we do so here:

res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)

In the past we had extra logic to adjust the epsilon for float16, but we deleted it since for training we only care about bfloat16.
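
For what it's worth, a minimal sketch of that suggestion, assuming the scale is kept as a multiplier and computed in fp32 as in the float8 training code (the EPS value below is illustrative, not torchao's constant):

import torch

EPS = 1e-12                                    # illustrative epsilon, not torchao's constant
float8_dtype = torch.float8_e5m2
x = torch.tensor([[0, 0, 0.1, 0.1]], dtype=torch.float16)

amax = x.abs().amax().float()                  # compute the scale in fp32
scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
xq = (x.float() * scale).clamp(
    torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max
).to(float8_dtype)
print(xq)                                      # [0., 0., 57344., 57344.] -- no NaNs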

@alexsamardzic
Collaborator Author

alexsamardzic commented Feb 24, 2025

I think a good enough fix for this specific branch of the quantization code would be:

diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 05be8c5..7369b3c 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -425,7 +425,9 @@ def _quantize_affine_no_dtype_cast(
             zero_point is None
         ), "zero_point should be None when zero_point_domain is NONE"
         if _is_float8_type(quant_dtype):
-            quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max)
+            quant = input * scale.to(torch.float).reciprocal()
+            quant[input == 0] = 0
+            quant = torch.clamp(quant, quant_min, quant_max)
         else:
             quant = torch.clamp(
                 torch.round(input * (1.0 / scale)), quant_min, quant_max

Namely, casting the scale to float makes it possible, in most cases, to avoid the scale reciprocal going to Inf, and even if that happens, the next line in the patch avoids NaNs from 0*Inf (while the remaining Infs are taken care of by the clamping). The question is, however: do we make the same change for the other branches of code in _quantize_affine_no_dtype_cast() (as well as for similar code throughout torchao/quantization/quant_primitives.py), where (1.0 / scale) is used instead of scale.reciprocal()?

(I'm confused by the resulting data type of expressions like 2.0 * torch.eye(3, dtype=...): it seems that if the tensor data type is integer, the result is promoted to the scalar data type, i.e. float, but if the tensor data type is floating point, even one narrower than float such as float16/bfloat16, the result keeps the tensor data type rather than the wider one. I did not know about that; for this reason the other branches of code would encounter the same problem as often as this one.)
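
For reference, this is standard PyTorch type promotion: a Python scalar promotes an integer tensor to the default float dtype, but it does not widen a floating-point tensor. A quick check:

import torch

print((2.0 * torch.eye(3, dtype=torch.int32)).dtype)          # torch.float32
print((2.0 * torch.eye(3, dtype=torch.float16)).dtype)        # torch.float16 -- not widened
print(1.0 / torch.tensor([1.7285e-06], dtype=torch.float16))  # inf -- the division overflows fp16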

@vkuzo
Contributor

vkuzo commented Feb 24, 2025

quant[input == 0] = 0

I would recommend using an epsilon to enforce that your scale is valid for your data type, rather than trying to work around having an incorrect scale later on. Code which checks for zero directly also doesn't handle other values in the neighborhood of zero.
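
For instance (a standalone fp16 snippet, not the torchao code path, and without the cast-to-float from the patch above), a tiny but nonzero input still blows up, so the zero mask alone does not make the result correct:

import torch

scale = torch.tensor([1.7285e-06], dtype=torch.float16)
tiny = torch.tensor([1e-04], dtype=torch.float16)
quant = tiny * scale.reciprocal()             # inf, although the exact quotient is ~58
quant[tiny == 0] = 0                          # the zero mask does not apply here
print(torch.clamp(quant, -57344.0, 57344.0))  # 57344 -- a wrong value, though not a NaN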

alexsamardzic linked a pull request Feb 24, 2025 that will close this issue
@alexsamardzic
Collaborator Author

I think we've been talking about slightly different things, but in any case you're completely right: the proper approach to fixing this kind of issue is to clamp the scale from below - a PR fixing the problem I've encountered is linked above.

@alexsamardzic
Collaborator Author

Just as a comment: if this line is commented out, then the following code:

import torch
import torchao

from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig

x = torch.tensor([[0.1, 0, 0, 0.2]], dtype=torch.float16)

xq = hp_tensor_to_float8_dynamic(
    x,
    float8_dtype=torch.float8_e5m2,
    linear_mm_config=LinearMMConfig(),
)

print("xq =", xq)

will pass through the code mentioned in the comment above and will produce NaNs, for the same reason as the reproducer in the first comment on this issue. The point, then, is doing the calculations in the right precision at the right place: the torchao/float8 code does that, and it also keeps the scale as something to multiply the original tensor with, so the conversion to higher precision is confined to a single place. On the other hand, the torchao/quantization code keeps the scale as something to divide the original tensor with, so the change needed to avoid these (unlikely) issues would have to be repeated in a number of places, wherever this division occurs.
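
A rough side-by-side sketch of the two conventions (standalone and only illustrative of the precision point, not the actual torchao code paths):

import torch

x = torch.tensor([[0, 0, 0.1, 0.1]], dtype=torch.float16)
amax = x.abs().amax()
fp8_max = torch.finfo(torch.float8_e5m2).max

# torchao/float8 style: the scale is a multiplier, computed and applied in fp32 in one place.
scale_mul = fp8_max / amax.float()
xq_mul = (x.float() * scale_mul).clamp(-fp8_max, fp8_max).to(torch.float8_e5m2)

# torchao/quantization style: the scale is a divisor; every `input / scale` (or
# `input * scale.reciprocal()`) site has to avoid doing the division in fp16.
scale_div = amax / fp8_max                    # a subnormal fp16 value
xq_div = (x * scale_div.reciprocal()).clamp(-fp8_max, fp8_max).to(torch.float8_e5m2)

print(xq_mul)  # [0., 0., 57344., 57344.]
print(xq_div)  # [nan, nan, 57344., 57344.]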
