From d2334960f150c1ae2bee1a96a7215707914d6235 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Wed, 26 Feb 2025 13:09:45 -0800
Subject: [PATCH] modify cast from hp to mx to help inductor fuse

Summary:

Thanks to investigation from @eellison, moving the reshape to the end of the
cast helps inductor fuse the cast into a single kernel. This doesn't yet work
with fp4, but let's unblock fp8 and deal with fp4 later.

Fixes https://github.com/pytorch/ao/issues/1690

Note: in the repro with swizzling from
https://github.com/pytorch/ao/issues/1773, we go from 3 kernels to 2. Further
investigation is needed to determine whether the swizzling can be fused as
well.

Test Plan:

```
pytest test/prototype/mx_formats/test_mx_tensor.py -x -s -k test_to_mx_inductor_single_kernel
```

Reviewers:

Subscribers:

Tasks:

Tags:
---
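Reviewer note (kept below the `---`, outside the commit message): a minimal
standalone sketch of the fusion behavior this patch targets, assuming a CUDA
GPU with capability 8.9+ and a recent PyTorch. The block scaling below is a
deliberately simplified stand-in for `to_mx`, and `cast_reshape_early` /
`cast_reshape_late` are illustrative names only; the printed kernel counts may
vary across PyTorch versions.

```
# Illustrative sketch only: compares generated kernel counts when the reshape
# happens before vs. after the cast to float8. The scaling math is simplified
# and is not the real to_mx logic.
import torch
from torch._inductor.utils import run_and_get_code

BLOCK_SIZE = 32
F8_MAX = torch.finfo(torch.float8_e4m3fn).max


def cast_reshape_early(x_hp: torch.Tensor) -> torch.Tensor:
    # reshape back to the original shape *before* the float8 cast
    # (roughly the old code path)
    orig_shape = x_hp.shape
    x = x_hp.reshape(-1, BLOCK_SIZE)
    scale = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / F8_MAX
    x = torch.clamp(x / scale, min=-F8_MAX, max=F8_MAX).reshape(orig_shape)
    return x.to(torch.float8_e4m3fn)


def cast_reshape_late(x_hp: torch.Tensor) -> torch.Tensor:
    # same math, but reshape back *after* the float8 cast (this patch's approach)
    orig_shape = x_hp.shape
    x = x_hp.reshape(-1, BLOCK_SIZE)
    scale = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / F8_MAX
    x = torch.clamp(x / scale, min=-F8_MAX, max=F8_MAX).to(torch.float8_e4m3fn)
    return x.reshape(orig_shape)


if __name__ == "__main__":
    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
    for fn in (cast_reshape_early, cast_reshape_late):
        compiled = torch.compile(fn, fullgraph=True)
        _, code = run_and_get_code(compiled, x)
        # each generated triton kernel is launched via a `.run(` call site
        print(fn.__name__, code[0].count(".run("))
```

The new test added in this patch checks the same property on the real
`MXTensor.to_mx` by running `FileCheck` over the inductor output code.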
 test/prototype/mx_formats/test_mx_tensor.py | 24 +++++++++++++++++++++
 torchao/prototype/mx_formats/mx_tensor.py   | 11 +++++++++-
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index f5014b7e31..385d0da613 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -6,6 +6,8 @@
 
 import pytest
 import torch
+from torch._inductor.utils import run_and_get_code
+from torch.testing import FileCheck
 
 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
@@ -284,3 +286,25 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         use_fp4_custom_triton_dequant_kernel,
     )
     torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_89(), reason="float8 in triton requires CUDA capability 8.9 or greater"
+)
+def test_to_mx_inductor_single_kernel():
+    """
+    Verify that inductor can fuse the cast of a high precision tensor to mx
+    into a single kernel
+    """
+    # TODO(future PR): add fp4 and fp6 here
+    # TODO(#1773): add swizzled scale format here
+    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
+    block_size = 32
+    to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
+    out, code = run_and_get_code(to_mx_c, x, torch.float8_e4m3fn, block_size)
+    FileCheck().check("def call(").check_count(
+        ".run(", 1, exactly=True
+    ).run(code[0])
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index 6c0a718c78..c25ca175e1 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -205,16 +205,25 @@ def to_mx(
     data_lp = torch.clamp(
         data_hp / scale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
     )
-    data_lp = data_lp.reshape(orig_shape)
 
     # cast to target dtype
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         data_lp = data_lp.to(elem_dtype)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP6_E2M3:
         data_lp = f32_to_f6_e2m3_unpacked(data_lp)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP6_E3M2:
         data_lp = f32_to_f6_e3m2_unpacked(data_lp)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP4:
+        # can't reshape at the end without handling it in the packing code,
+        # punt until later since we'll need to rethink the torch.compile
+        # approach for fp4x2 in any case
+        data_lp = data_lp.reshape(orig_shape)
         data_lp = f32_to_f4_unpacked(data_lp)
         data_lp = pack_uint4(data_lp)
     else:
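
Trailing reviewer note (outside the diff hunks): a short sketch of why the fp4
branch above keeps the reshape before packing. `pack_uint4_like` below is a
stand-in written only for illustration, not the real `pack_uint4`; it just
shows that packing two fp4 values per byte halves the last dimension, so
`reshape(orig_shape)` is only valid on the unpacked data.

```
# Illustrative only: packing halves the last dim, so the "reshape at the end"
# trick from the fp8/fp6 branches does not directly carry over to fp4x2.
import torch


def pack_uint4_like(x: torch.Tensor) -> torch.Tensor:
    # stand-in for pack_uint4: store two 4-bit values per uint8 along the last dim
    return (x[..., ::2] << 4 | x[..., 1::2]).to(torch.uint8)


orig_shape = (4, 64)
blocked = torch.randint(0, 16, (4 * 64 // 32, 32), dtype=torch.uint8)  # (8, 32)

packed = pack_uint4_like(blocked)
print(packed.shape)  # torch.Size([8, 16]): last dim halved
# packed.reshape(orig_shape) would fail: orig_shape needs 256 elements, packed has 128
unpacked = blocked.reshape(orig_shape)  # reshaping back works only before the pack
```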