[QST] About NaNs generated during FP16->FP8 quantization #1766
I'd recommend adding an epsilon when calculating the scale. For float8 training, we do so here: ao/torchao/float8/float8_utils.py, line 47 (commit 2a3fbff).
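A minimal sketch of that idea (illustrative epsilon value and scale convention, not the exact torchao code): clamp the absolute maximum from below before deriving the scale, so an all-zero or near-zero tensor can never produce a scale whose reciprocal overflows.

```python
import torch

EPS = 1e-12  # illustrative epsilon, not necessarily the value torchao uses

def amax_to_scale_sketch(amax: torch.Tensor, float8_dtype=torch.float8_e5m2) -> torch.Tensor:
    # Clamp amax from below so the derived scale is never (near) zero.
    amax = torch.clamp(amax.to(torch.float32), min=EPS)
    # Scale convention assumed here: quant = input / scale.
    return amax / torch.finfo(float8_dtype).max
```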
I think a good enough fix for this specific branch of quantization code would be:

```diff
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 05be8c5..7369b3c 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -425,7 +425,9 @@ def _quantize_affine_no_dtype_cast(
             zero_point is None
         ), "zero_point should be None when zero_point_domain is NONE"
         if _is_float8_type(quant_dtype):
-            quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max)
+            quant = input * scale.to(torch.float).reciprocal()
+            quant[input == 0] = 0
+            quant = torch.clamp(quant, quant_min, quant_max)
         else:
             quant = torch.clamp(
                 torch.round(input * (1.0 / scale)), quant_min, quant_max
```

Namely, casting `scale` to `torch.float` before taking the reciprocal, and zeroing the result wherever the input is 0 (I'm confused by the resulting data type of expressions like `input * scale.reciprocal()`).
I would recommend using an epsilon to enforce that your scale is valid for your data type, rather than trying to work around an incorrect scale later on. The code that checks for zero directly also doesn't handle other values in the neighborhood of zero.
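A small sketch of that point, assuming the float16 scale reciprocal has already overflowed to inf as in the reproducer: the `input == 0` mask repairs exact zeros only.

```python
import torch

scale = torch.tensor(0.1 / 57344, dtype=torch.float16)  # reciprocal overflows to inf in float16
x = torch.tensor([0.0, 1e-6, 0.1], dtype=torch.float16)
quant = x * scale.reciprocal()  # tensor([nan, inf, inf])
quant[x == 0] = 0               # only the exact zero is fixed: tensor([0., inf, inf])
# 1e-6 slips past the mask and lands at quant_max after clamping, whereas
# clamping the scale (or amax) from below with an epsilon prevents the overflow entirely.
```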
I think we've been talking about slightly different things, but in any case you're completely right: the proper approach to fixing this kind of issue is to clamp the scale from below. A PR fixing the problem I encountered is linked above.
Just as a comment: if this line is commented out, then the following code:

```python
import torch
import torchao
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig

x = torch.tensor([[0.1, 0, 0, 0.2]], dtype=torch.float16)
xq = hp_tensor_to_float8_dynamic(
    x,
    float8_dtype=torch.float8_e5m2,
    linear_mm_config=LinearMMConfig(),
)
print("xq =", xq)
```

will pass through the code mentioned in the above comment and will produce NaNs, for the same reason as the reproducer from the first comment on the issue. Thus the point is doing the calculations in the right precision at the right place.
The reproducer:
The output is:
Basically, the problem is that the quantization code maps the [0, 0.1] range to [0, 57344] (here, 57344 is the maximum value of the `torch.float8_e5m2` data type), so the scale gets very small, and then its reciprocal here becomes Inf, and then 0 * Inf produces NaNs as quantized values.

This is all, of course, simply about the range and precision of the data types involved, but I was just wondering: is this a known issue? Would it make sense to force 0 as the `input * scale.reciprocal()` result here, wherever the corresponding input elements are 0?
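A minimal sketch of the mechanism described here (not the original reproducer), assuming the scale is kept in float16:

```python
import torch

quant_max = torch.finfo(torch.float8_e5m2).max          # 57344.0
x = torch.tensor([0.0, 0.05, 0.1], dtype=torch.float16)
scale = (x.abs().amax() / quant_max).to(torch.float16)  # 0.1 / 57344, very small
quant = torch.clamp(x * scale.reciprocal(), -quant_max, quant_max)
print(quant)  # the zero element becomes NaN: scale.reciprocal() is inf, and 0 * inf -> nan
```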