diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py
index ee9aa61d5b..1ed0f5e5b0 100644
--- a/benchmarks/float8/float8_roofline.py
+++ b/benchmarks/float8/float8_roofline.py
@@ -58,14 +58,12 @@
 )

 from torchao.float8 import (
-    Float8LinearConfig,
     convert_to_float8_training,
 )
 from torchao.float8.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
 )
-from torchao.utils import is_sm_at_least_90, is_sm_at_least_100


 class LNLinearSigmoid(torch.nn.Module):
@@ -155,21 +153,13 @@ def do_matmul(A, B):

     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)

-    if is_sm_at_least_90() and (not is_sm_at_least_100()):
-        scale_a = torch.ones(M, 1, device=device)
-        scale_b = torch.ones(1, N, device=device)
-        fast_accum = True  # for axiswise
-        f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
-    else:
-        f8_axs_time_s = -1.0
-
     # save to cache if needed
     if cache_filename is not None:
-        cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s]
+        cache[key] = [bf16_time_s, f8_time_s]
         with open(cache_filename, "w") as f:
             json.dump(cache, f)

-    return bf16_time_s, f8_time_s, f8_axs_time_s
+    return bf16_time_s, f8_time_s


 def run(
@@ -229,18 +219,13 @@ def run(
         # gemm microbenchmarks
         "bf16_gemm_s",
         "fp8_gemm_s",
-        "fp8_axs_gemm_time_s",
         # roofline memory overhead estimates
-        "fp8_oh_dyn_limit",
-        "fp8_oh_dyn_nolimit",
+        "fp8_oh_estimated",
+        "fp8_oh_ideal",
         # actual e2e measurements
         "bf16_s",
         "fp8_dyn_s",
-        "fp8_dyn_axs_s",
-        # 'fp8_lw_s',
         "fp8_dyn_sp",
-        "fp8_dyn_axs_sp",
-        # 'fp8_lw_sp',
     ]
     results = []

@@ -251,18 +236,17 @@ def run(
             break

         if gemm_time_strategy == "benchmarks":
-            bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(
+            bf16_g1, f8_g1 = get_gemm_times(
                 M_val, K_val, N_val, True, gemm_cache_filename
             )
-            bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(
+            bf16_g2, f8_g2 = get_gemm_times(
                 M_val, N_val, K_val, False, gemm_cache_filename
             )
-            bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(
+            bf16_g3, f8_g3 = get_gemm_times(
                 K_val, M_val, N_val, False, gemm_cache_filename
             )
             bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
             fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
-            fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs
         else:
             assert gemm_time_strategy == "roofline", "unsupported"
             bf16_time_val = (
@@ -271,8 +255,6 @@ def run(
             fp8_gemm_time_s = (
                 fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
             )
-            # for now, assume axiswise gemm is similar to tensorwise
-            fp8_axs_gemm_time_s = fp8_gemm_time_s

         fp8_mem_time_dyn_limit_s = (
             fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
@@ -299,28 +281,6 @@ def run(
         m_fp8_dyn = torch.compile(m_fp8_dyn)
         fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)

-        # get the float8 dynamic axiswise scaling gpu kernel time, if supported
-        # on current hardware
-        if is_sm_at_least_90() and (not is_sm_at_least_100()):
-            torch._dynamo.reset()
-            config = Float8LinearConfig.from_recipe_name("rowwise")
-            m_fp8_dyn_axs = convert_to_float8_training(
-                copy.deepcopy(m_orig), config=config
-            )
-            m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
-            fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
-        else:
-            fp8_dyn_axs_time_actual_s = -1.0
-
-        # get the lw recipe scaling gpu kernel time
-        # TODO(future PR): enable below once basic performance issues
-        # are fixed
-        # torch._dynamo.reset()
-        # config = Float8LinearConfig.from_recipe_name("rowwise_with_gw_hp")
-        # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
-        # m_fp8_lw = torch.compile(m_fp8_lw)
-        # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
-
         results.append(
             [
                 M_val,
                 K_val,
                 N_val,
                 # gemm microbenchmarks
                 bf16_time_val,
                 fp8_gemm_time_s,
-                fp8_axs_gemm_time_s,
                 # roofline overhead estimates
                 fp8_mem_time_dyn_limit_s,
                 fp8_mem_time_dyn_nolimit_s,
                 # e2e numbers
                 bf16_time_actual_s,
                 fp8_dyn_time_actual_s,
-                fp8_dyn_axs_time_actual_s,
-                # fp8_lw_time_actual_s,
                 bf16_time_actual_s / fp8_dyn_time_actual_s,
-                bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
-                # bf16_time_actual_s / fp8_lw_time_actual_s,
             ]
         )
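
For context on the kernels this benchmark times: the retained `f8_time_s` path measures a tensorwise-scaled float8 GEMM against its bf16 counterpart. Below is a minimal, self-contained sketch of that comparison, not taken from this file; the shapes, the unit scales, and the dtype choices are assumptions for illustration only.

```python
import torch

# Illustrative only: a bf16 GEMM next to a tensorwise-scaled float8 GEMM via
# torch._scaled_mm, similar in spirit to the do_matmul helper benchmarked by
# get_gemm_times. Requires a GPU that supports float8 matmul.
M, K, N = 4096, 4096, 4096  # assumed shapes; multiples of 16 as _scaled_mm requires
device = "cuda"

A_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16)
B_bf16 = torch.randn(K, N, device=device, dtype=torch.bfloat16)
out_bf16 = torch.mm(A_bf16, B_bf16)

# torch._scaled_mm expects a row-major A, a column-major B, and fp32 scales.
A_fp8 = A_bf16.to(torch.float8_e4m3fn)
B_fp8 = B_bf16.to(torch.float8_e4m3fn).t().contiguous().t()
scale_a = torch.tensor(1.0, device=device)  # unit scales for the sketch
scale_b = torch.tensor(1.0, device=device)
out_fp8 = torch._scaled_mm(
    A_fp8, B_fp8, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=False
)
```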