modify cast from hp to mx to help inductor fuse

Summary:

Thanks to investigation from @eellison, moving the reshape to after the
cast helps inductor fuse the cast into a single kernel. This doesn't work
with fp4 yet, but let's unblock fp8 and deal with fp4 later.

Fixes #1690

Note: in the repro with swizzling from
#1773, we go from 3 to 2 kernels.
Further investigation is needed into whether the swizzling can also be fused.
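
For intuition, here is a minimal sketch of the fp8 path after this change
(the function name and signature below are made up for illustration, they
are not the real `to_mx` code): the reshape back to the original shape now
happens after the clamp and the cast, so inductor sees a single elementwise
chain that it can fuse into one kernel.

```
import torch


def cast_blocked_to_fp8(data_hp, scale_fp, orig_shape, max_pos):
    # hypothetical standalone version of the fp8 branch of to_mx, for
    # illustration only
    # scale and clamp in the blocked (num_blocks, block_size) layout
    data_lp = torch.clamp(data_hp / scale_fp.unsqueeze(1), min=-max_pos, max=max_pos)
    # cast to the low precision dtype first ...
    data_lp = data_lp.to(torch.float8_e4m3fn)
    # ... and reshape back to the original shape last; previously the reshape
    # sat between the clamp and the cast, which blocked fusion
    return data_lp.reshape(orig_shape)
```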

Test Plan:

```
pytest test/prototype/mx_formats/test_mx_tensor.py -x -s -k test_to_mx_inductor_single_kernel
```

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo committed Feb 26, 2025
1 parent d00ee41 commit d233496
Showing 2 changed files with 34 additions and 1 deletion.
test/prototype/mx_formats/test_mx_tensor.py: 24 additions & 0 deletions

```
@@ -6,6 +6,8 @@
 
 import pytest
 import torch
+from torch._inductor.utils import run_and_get_code
+from torch.testing import FileCheck
 
 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
@@ -284,3 +286,25 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         use_fp4_custom_triton_dequant_kernel,
     )
     torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_89(), reason="float8 in triton requires CUDA capability 8.9 or greater"
+)
+def test_to_mx_inductor_single_kernel():
+    """
+    Verify that inductor can fuse the cast of a high precision tensor to mx
+    into a single kernel
+    """
+    # TODO(future PR): add fp4 and fp6 here
+    # TODO(#1773): add swizzled scale format here
+    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
+    block_size = 32
+    to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
+    out, code = run_and_get_code(to_mx_c, x, torch.float8_e4m3fn, block_size)
+    FileCheck().check("def call(").check_count(
+        ".run(", 1, exactly=True
+    ).run(code[0])
```
torchao/prototype/mx_formats/mx_tensor.py: 10 additions & 1 deletion

```
@@ -205,16 +205,25 @@ def to_mx(
     data_lp = torch.clamp(
         data_hp / scale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
     )
-    data_lp = data_lp.reshape(orig_shape)
 
     # cast to target dtype
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         data_lp = data_lp.to(elem_dtype)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP6_E2M3:
         data_lp = f32_to_f6_e2m3_unpacked(data_lp)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP6_E3M2:
         data_lp = f32_to_f6_e3m2_unpacked(data_lp)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP4:
+        # can't reshape at the end without handling it in the packing code,
+        # punt until later since we'll need to rethink the torch.compile
+        # approach for fp4x2 in any case
+        data_lp = data_lp.reshape(orig_shape)
         data_lp = f32_to_f4_unpacked(data_lp)
         data_lp = pack_uint4(data_lp)
     else:
```
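
As a note on the fp4 branch above, a small sketch of why its reshape stays
before the packing (the shapes are illustrative, and the sketch assumes
`pack_uint4` packs two 4-bit values per byte along the last dimension):

```
import math

import torch

orig_shape = (2048, 2048)
block_size = 32
# blocked layout used during scaling: (num_blocks, block_size)
data_lp = torch.randn(math.prod(orig_shape) // block_size, block_size)

# fp8/fp6 casts are elementwise, so the element count is unchanged and the
# reshape back to orig_shape can happen after the cast
assert data_lp.to(torch.float8_e4m3fn).reshape(orig_shape).shape == orig_shape

# fp4x2 packing halves the last dimension, leaving numel // 2 bytes, so a
# reshape to orig_shape is no longer valid after packing; the fp4 branch
# therefore reshapes first and packs afterwards
packed_numel = data_lp.numel() // 2
assert packed_numel != math.prod(orig_shape)
```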
