
Commit acc907a
Update
[ghstack-poisoned]
vkuzo committed Feb 26, 2025
2 parents 4bed185 + ac27fdd commit acc907a
Showing 3 changed files with 59 additions and 2 deletions.
26 changes: 25 additions & 1 deletion benchmarks/float8/profile_lowp_training.py
@@ -306,8 +306,9 @@ def main(
"fwd",
"cast_only",
"cast_with_to_blocked",
"cast_only_dim0_dim1",
)
), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`"
), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`, `cast_only_dim0_dim1`"
if mode_filter == "cast_only":
assert experiment_filter == "lowp", "unsupported"

@@ -395,6 +396,23 @@ def cast_with_to_blocked(x_hp):
scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // config.block_size))
return x_mx._data, scale_blocked

# this function is used for cast_only_dim0_dim1
def cast_only_dim0_dim1(x_hp):
x_hp_t_c = x_hp.t().contiguous()
x_mx_dim0 = MXTensor.to_mx(
x_hp,
config.elem_dtype,
config.block_size,
gemm_kernel_choice=config.gemm_kernel_choice,
)
x_mx_dim1 = MXTensor.to_mx(
x_hp_t_c,
config.elem_dtype,
config.block_size,
gemm_kernel_choice=config.gemm_kernel_choice,
)
return x_mx_dim0, x_mx_dim1

print("m_ref", m_ref)
print("m_lowp", m_lowp)
print("input_tensor.shape", input_tensor.shape)
@@ -423,6 +441,11 @@ def lowp_forw_backward_wrapper(x):
elif mode_filter == "cast_with_to_blocked":
_input_tensor_mx, scale = cast_with_to_blocked(input_tensor)
return
elif mode_filter == "cast_only_dim0_dim1":
_input_tensor_mx_dim0, _input_tensor_mx_dim1 = cast_only_dim0_dim1(
input_tensor,
)
return

if enable_activation_checkpointing:
out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn)
@@ -437,6 +460,7 @@ def lowp_forw_backward_wrapper(x):
m_lowp = torch.compile(m_lowp, fullgraph=True)
to_mx_func = torch.compile(to_mx_func, fullgraph=True)
cast_with_to_blocked = torch.compile(cast_with_to_blocked, fullgraph=True)
cast_only_dim0_dim1 = torch.compile(cast_only_dim0_dim1, fullgraph=True)

# if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
# to populate triton kernel bandwidth further down in the script
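As a side note, here is a minimal standalone sketch of the new dim0/dim1 cast path outside the profiling script, assuming the MXTensor import path below and the default gemm kernel choice; the helper name and shapes are illustrative, not part of this commit:

import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor


def cast_dim0_dim1(x_hp, elem_dtype, block_size):
    # dim0: cast the tensor as-is; dim1: cast a transposed-contiguous copy so
    # the data is scaled in blocks along the other dimension
    x_mx_dim0 = MXTensor.to_mx(x_hp, elem_dtype, block_size)
    x_mx_dim1 = MXTensor.to_mx(x_hp.t().contiguous(), elem_dtype, block_size)
    return x_mx_dim0, x_mx_dim1


x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
cast_c = torch.compile(cast_dim0_dim1, fullgraph=True)
x_mx_dim0, x_mx_dim1 = cast_c(x, torch.float8_e4m3fn, 32)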
24 changes: 24 additions & 0 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -6,6 +6,8 @@

import pytest
import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.constants import (
@@ -284,3 +286,25 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
use_fp4_custom_triton_dequant_kernel,
)
torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
@pytest.mark.skipif(
not is_sm_at_least_89(),
reason="float8 in triton requires CUDA capability 8.9 or greater",
)
def test_to_mx_inductor_single_kernel():
"""
Verify that inductor can fuse the cast of a high precision tensor to mx
into a single kernel
"""
# TODO(future PR): add fp4 and fp6 here
# TODO(#1773): add swizzled scale format here
x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
block_size = 32
to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
out, code = run_and_get_code(to_mx_c, x, torch.float8_e4m3fn, block_size)
FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])
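For context, the run_and_get_code plus FileCheck pattern used above can be applied to any compiled function to count the kernel launches inductor generates; a minimal sketch on a toy pointwise function (illustrative only, not part of this commit):

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck


def f(x):
    # two pointwise ops that inductor is expected to fuse into one triton kernel
    return (x * 2.0).relu()


f_c = torch.compile(f, fullgraph=True)
x = torch.randn(1024, 1024, device="cuda")
_, code = run_and_get_code(f_c, x)
# exactly one `.run(` call in the generated wrapper means a single kernel launch
FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])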
11 changes: 10 additions & 1 deletion torchao/prototype/mx_formats/mx_tensor.py
@@ -205,16 +205,25 @@ def to_mx(
data_lp = torch.clamp(
data_hp / scale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
)
data_lp = data_lp.reshape(orig_shape)

# cast to target dtype
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
data_lp = data_lp.to(elem_dtype)
# need to reshape at the end to help inductor fuse things
data_lp = data_lp.reshape(orig_shape)
elif elem_dtype == DTYPE_FP6_E2M3:
data_lp = f32_to_f6_e2m3_unpacked(data_lp)
# need to reshape at the end to help inductor fuse things
data_lp = data_lp.reshape(orig_shape)
elif elem_dtype == DTYPE_FP6_E3M2:
data_lp = f32_to_f6_e3m2_unpacked(data_lp)
# need to reshape at the end to help inductor fuse things
data_lp = data_lp.reshape(orig_shape)
elif elem_dtype == DTYPE_FP4:
# can't reshape at the end without handling it in the packing code,
# punt until later since we'll need to rethink the torch.compile
# approach for fp4x2 in any case
data_lp = data_lp.reshape(orig_shape)
data_lp = f32_to_f4_unpacked(data_lp)
data_lp = pack_uint4(data_lp)
else:
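The intent of the reordering above, as a simplified sketch (the scale computation is approximated here and this is not the library's exact to_mx implementation): keep the data in the blocked shape through scaling, clamping, and the low precision cast, and reshape back to the original shape as the last op, which lets inductor fuse the whole sequence into one kernel.

import torch


def cast_fp8_blockwise(data_hp: torch.Tensor, block_size: int = 32):
    # simplified stand-in for the fp8 branch of MXTensor.to_mx
    orig_shape = data_hp.shape
    data_hp = data_hp.reshape(-1, block_size)
    scale = data_hp.abs().amax(dim=1, keepdim=True).clamp(min=1e-4)
    max_pos = torch.finfo(torch.float8_e4m3fn).max
    data_lp = torch.clamp(data_hp / scale, min=-max_pos, max=max_pos)
    data_lp = data_lp.to(torch.float8_e4m3fn)  # cast to the target dtype first
    data_lp = data_lp.reshape(orig_shape)  # reshape last, to help inductor fuse
    return data_lp, scale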
