roofline estimation: delete axiswise scaling, for now
Summary:

Axiswise scaling was not added correctly: it reused the tensorwise scaling
overhead estimates. Deleting it for now; it can be added back later in a
cleaner way.


ghstack-source-id: d9f2d1a40f563fb13a7b3f0b062279b2382024dd
ghstack-comment-id: 2683596591
Pull Request resolved: #1782
vkuzo committed Feb 26, 2025
1 parent 6f5ce64 commit 2a52280
Showing 1 changed file with 7 additions and 52 deletions.
benchmarks/float8/float8_roofline.py (7 additions, 52 deletions)

@@ -58,14 +58,12 @@
 )
 
 from torchao.float8 import (
-    Float8LinearConfig,
     convert_to_float8_training,
 )
 from torchao.float8.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
 )
-from torchao.utils import is_sm_at_least_90, is_sm_at_least_100
 
 
 class LNLinearSigmoid(torch.nn.Module):
@@ -155,21 +153,13 @@ def do_matmul(A, B):
 
     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
 
-    if is_sm_at_least_90() and (not is_sm_at_least_100()):
-        scale_a = torch.ones(M, 1, device=device)
-        scale_b = torch.ones(1, N, device=device)
-        fast_accum = True  # for axiswise
-        f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
-    else:
-        f8_axs_time_s = -1.0
-
     # save to cache if needed
     if cache_filename is not None:
-        cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s]
+        cache[key] = [bf16_time_s, f8_time_s]
         with open(cache_filename, "w") as f:
             json.dump(cache, f)
 
-    return bf16_time_s, f8_time_s, f8_axs_time_s
+    return bf16_time_s, f8_time_s
 
 
 def run(
@@ -229,18 +219,13 @@ def run(
         # gemm microbenchmarks
        "bf16_gemm_s",
         "fp8_gemm_s",
-        "fp8_axs_gemm_time_s",
         # roofline memory overhead estimates
-        "fp8_oh_dyn_limit",
-        "fp8_oh_dyn_nolimit",
+        "fp8_oh_estimated",
+        "fp8_oh_ideal",
         # actual e2e measurements
         "bf16_s",
         "fp8_dyn_s",
-        "fp8_dyn_axs_s",
-        # 'fp8_lw_s',
         "fp8_dyn_sp",
-        "fp8_dyn_axs_sp",
-        # 'fp8_lw_sp',
     ]
     results = []
 
@@ -251,18 +236,17 @@
             break
 
         if gemm_time_strategy == "benchmarks":
-            bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(
+            bf16_g1, f8_g1 = get_gemm_times(
                 M_val, K_val, N_val, True, gemm_cache_filename
             )
-            bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(
+            bf16_g2, f8_g2 = get_gemm_times(
                 M_val, N_val, K_val, False, gemm_cache_filename
             )
-            bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(
+            bf16_g3, f8_g3 = get_gemm_times(
                 K_val, M_val, N_val, False, gemm_cache_filename
             )
             bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
             fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
-            fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs
         else:
             assert gemm_time_strategy == "roofline", "unsupported"
             bf16_time_val = (
@@ -271,8 +255,6 @@
             fp8_gemm_time_s = (
                 fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
             )
-            # for now, assume axiswise gemm is similar to tensorwise
-            fp8_axs_gemm_time_s = fp8_gemm_time_s
 
         fp8_mem_time_dyn_limit_s = (
             fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
@@ -299,28 +281,6 @@
         m_fp8_dyn = torch.compile(m_fp8_dyn)
         fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)
 
-        # get the float8 dynamic axiswise scaling gpu kernel time, if supported
-        # on current hardware
-        if is_sm_at_least_90() and (not is_sm_at_least_100()):
-            torch._dynamo.reset()
-            config = Float8LinearConfig.from_recipe_name("rowwise")
-            m_fp8_dyn_axs = convert_to_float8_training(
-                copy.deepcopy(m_orig), config=config
-            )
-            m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
-            fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
-        else:
-            fp8_dyn_axs_time_actual_s = -1.0
-
-        # get the lw recipe scaling gpu kernel time
-        # TODO(future PR): enable below once basic performance issues
-        # are fixed
-        # torch._dynamo.reset()
-        # config = Float8LinearConfig.from_recipe_name("rowwise_with_gw_hp")
-        # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
-        # m_fp8_lw = torch.compile(m_fp8_lw)
-        # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
-
         results.append(
             [
                 M_val,
@@ -329,18 +289,13 @@
                 # gemm microbenchmarks
                 bf16_time_val,
                 fp8_gemm_time_s,
-                fp8_axs_gemm_time_s,
                 # roofline overhead estimates
                 fp8_mem_time_dyn_limit_s,
                 fp8_mem_time_dyn_nolimit_s,
                 # e2e numbers
                 bf16_time_actual_s,
                 fp8_dyn_time_actual_s,
-                fp8_dyn_axs_time_actual_s,
-                # fp8_lw_time_actual_s,
                 bf16_time_actual_s / fp8_dyn_time_actual_s,
-                bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
-                # bf16_time_actual_s / fp8_lw_time_actual_s,
             ]
         )
 
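Note on the "add back later in a cleaner way" follow-up (illustrative, not part of this PR): the deleted benchmark path built rowwise scales but still timed the same tensorwise `do_matmul`, and the roofline path simply set `fp8_axs_gemm_time_s = fp8_gemm_time_s`, so the axiswise columns duplicated the tensorwise numbers. A cleaner re-add would time a GEMM that actually consumes rowwise scales. Below is a minimal sketch, assuming an SM90 GPU, a PyTorch build where `torch._scaled_mm` accepts rowwise scales of shape `(M, 1)` and `(1, N)` (the shapes used in the deleted code), and a hypothetical helper name `time_rowwise_scaled_gemm_s`:

# Illustrative sketch only -- not part of this PR. Assumes an SM90 GPU and a
# PyTorch build where torch._scaled_mm supports rowwise (axiswise) scales.
import torch

def time_rowwise_scaled_gemm_s(M: int, K: int, N: int, n_iter: int = 100) -> float:
    """Hypothetical helper: time an fp8 GEMM with rowwise scales, in seconds."""
    device = "cuda"
    A = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
    # torch._scaled_mm expects the second operand in column-major layout
    B = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()
    # rowwise scales: one per row of A, one per column of B (same shapes as the
    # deleted code above); dummy values of 1.0 are fine for timing purposes
    scale_a = torch.ones(M, 1, device=device, dtype=torch.float32)
    scale_b = torch.ones(1, N, device=device, dtype=torch.float32)

    def do_rowwise_matmul():
        return torch._scaled_mm(
            A, B, scale_a=scale_a, scale_b=scale_b,
            out_dtype=torch.bfloat16, use_fast_accum=True,
        )

    # warmup, then time with CUDA events
    for _ in range(3):
        do_rowwise_matmul()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iter):
        do_rowwise_matmul()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1e3 / n_iter  # seconds per matmul

A helper along these lines could be wired into `get_gemm_times` so the `f8_axs_*` results are measured independently instead of inheriting the tensorwise overhead estimates.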
