Commit 5e51996

Update

[ghstack-poisoned]

vkuzo committed Feb 24, 2025
1 parent 8d38814 commit 5e51996
Showing 2 changed files with 28 additions and 6 deletions.
33 changes: 27 additions & 6 deletions benchmarks/float8/profile_lowp_training.py
@@ -48,6 +48,7 @@
 from torchao.prototype.mx_formats.config import MXLinearConfig
 from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.utils import to_blocked

 # don't truncate long kernel names
 pd.options.display.max_colwidth = 100
@@ -298,11 +299,15 @@ def main(
         "lowp",
         "ref",
     ), "experiment_filter must be one of `both`, `lowp`, `ref`"
-    assert mode_filter in (
-        "fwd_bwd",
-        "fwd",
-        "cast_only",
-    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`"
+    assert (
+        mode_filter
+        in (
+            "fwd_bwd",
+            "fwd",
+            "cast_only",
+            "cast_with_to_blocked",
+        )
+    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`"
     if mode_filter == "cast_only":
         assert experiment_filter == "lowp", "unsupported"

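For reference, a minimal sketch of the mode validation after this change. The VALID_MODES tuple and check_mode helper below are illustrative only and not part of the script; the accepted values themselves come straight from the diff.

# Hypothetical helper, not in the diff: the accepted modes after this commit.
VALID_MODES = ("fwd_bwd", "fwd", "cast_only", "cast_with_to_blocked")

def check_mode(mode_filter: str) -> None:
    # mirrors the assert in main(): reject anything outside the known modes
    assert mode_filter in VALID_MODES, (
        f"mode_filter must be one of {VALID_MODES}, got {mode_filter!r}"
    )

check_mode("cast_with_to_blocked")  # newly accepted by this commit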
@@ -378,14 +383,26 @@ def main(
     # this function is only used for cast_only
     to_mx_func = MXTensor.to_mx

+    # this function is used for cast_with_to_blocked
+    def cast_with_to_blocked(x_hp):
+        x_mx = MXTensor.to_mx(
+            x_hp,
+            config.elem_dtype,
+            config.block_size,
+            gemm_kernel_choice=config.gemm_kernel_choice,
+        )
+        m, k = x_hp.shape
+        scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // config.block_size))
+        return x_mx._data, scale_blocked
+
     print("m_ref", m_ref)
     print("m_lowp", m_lowp)
     print("input_tensor.shape", input_tensor.shape)
     print("grad_output.shape", grad_output.shape)
     print()

     def ref_forw_backward(x):
-        assert mode_filter != "cast_only", "unsupported"
+        assert mode_filter not in ("cast_only", "cast_with_to_blocked"), "unsupported"
         if enable_activation_checkpointing:
             out = checkpoint(m_ref, x, use_reentrant=False, context_fn=context_fn)
         else:
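A standalone sketch of what the new cast_with_to_blocked path exercises, outside the benchmark harness. Assumptions not spelled out in the diff: fp8 e4m3 as the element dtype, a block size of 32, a CUDA device, and gemm_kernel_choice left at its default; the names MXTensor.to_mx, _scale_e8m0, _data, and to_blocked come from the diff itself.

import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked

# assumed shapes/dtype/device for illustration; m and k must be multiples of block_size
m, k, block_size = 4096, 4096, 32
x_hp = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")

# cast high precision -> MX: low-precision payload plus one e8m0 scale per block
x_mx = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, block_size)

# rearrange the (m, k // block_size) scales into the blocked layout expected by
# the CUTLASS/cuBLAS MX gemms, mirroring cast_with_to_blocked in the diff
scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // block_size))
print(x_mx._data.shape, scale_blocked.shape)

Timing this helper versus the cast alone is what separates the new cast_with_to_blocked mode from cast_only.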
@@ -403,6 +420,9 @@ def lowp_forw_backward_wrapper(x):
                 gemm_kernel_choice=config.gemm_kernel_choice,
             )
             return
+        elif mode_filter == "cast_with_to_blocked":
+            _input_tensor_mx, scale = cast_with_to_blocked(input_tensor)
+            return

         if enable_activation_checkpointing:
             out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn)
@@ -416,6 +436,7 @@ def lowp_forw_backward_wrapper(x):
         m_ref = torch.compile(m_ref, fullgraph=True)
         m_lowp = torch.compile(m_lowp, fullgraph=True)
         to_mx_func = torch.compile(to_mx_func, fullgraph=True)
+        cast_with_to_blocked = torch.compile(cast_with_to_blocked, fullgraph=True)

     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
1 change: 1 addition & 0 deletions torchao/prototype/mx_formats/mx_ops.py
@@ -74,6 +74,7 @@ def mx_mm(aten_op, args, kwargs=None):
         # real MX gemm backed by torchao's CUTLASS kernels
         M, K, N = a.shape[0], a.shape[1], b.shape[1]
         assert b._data.t().is_contiguous()
+        # TODO(future PR): use block_size instead of hardcoding 32
         a_scale = a._scale_e8m0.view(M, K // 32)
         b_scale = b._scale_e8m0.view(N, K // 32)
         a_scale_block = to_blocked(a_scale)
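A small worked example of the scale-shape bookkeeping behind the hardcoded 32 that the new TODO points at; the concrete M, K, N values are made up for illustration, and the arithmetic only restates the views in the hunk above.

# shapes only, no torchao needed; block_size == 32 is the value the TODO says
# should eventually replace the hardcoded 32 in the .view(...) calls above
M, K, N = 256, 512, 1024
block_size = 32

# every group of block_size elements along K shares one e8m0 scale, so the
# A operand's scales are viewed as (M, K // block_size) and B's as (N, K // block_size)
a_scale_shape = (M, K // block_size)  # (256, 16)
b_scale_shape = (N, K // block_size)  # (1024, 16)
print(a_scale_shape, b_scale_shape)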
