From 4a887597185be3e751e0667754121e784e92aa87 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Wed, 26 Feb 2025 10:34:22 -0800
Subject: [PATCH] roofline estimator: simplify

Summary:

1. remove estimating torch.compile limitations (interesting but hasn't been useful)
2. make clearer distinction between roofline and benchmarked values
3. sympy float -> float cast to fix pandas df formatting

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: dacab7f909eb7214a25a068d7d6ebc9a12a7614d
ghstack-comment-id: 2683886949
Pull Request resolved: https://github.com/pytorch/ao/pull/1783
---
 benchmarks/float8/float8_roofline.py     | 164 +++++++++++------------
 torchao/testing/float8/roofline_utils.py |  15 +--
 2 files changed, 83 insertions(+), 96 deletions(-)

diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py
index 5ce9526ca4..d29ee865e6 100644
--- a/benchmarks/float8/float8_roofline.py
+++ b/benchmarks/float8/float8_roofline.py
@@ -6,13 +6,14 @@
 
 """
 This is a script to estimate the benefit from converting a `torch.nn.Linear`
-layer to float8, by estimating the difference in e2e GPU kernel time between:
+layer to float8 given a single saturated GPU, by estimating the difference
+in e2e GPU kernel time between:
 1. bf16 gemms in fwd and bwd, and
 2. float8 gemms in fwd and bwd, and float8 overhead
 
 The gemm times are estimated either from direct measurements via benchmarks,
 or with a roofline estimation based on TOPS and peak compute bandwidth of an
-NVIDIA H100.
+NVIDIA H100 or B200.
 
 The float8 overhead times are estimated by counting memory reads and writes
 based on the specified float8 scaling, and estimating that we can achieve
@@ -31,12 +32,10 @@
   input_t @ grad_output = grad_weight
   KxM @ MxN => KxN
 
-2. we properly model the worst-case of the current torch.compile limitations regarding
-   float8 scaling
-3. assume for float8 activations/gradients that torch.compile will fuse to the
+2. assume for float8 activations/gradients that torch.compile will fuse to the
    preceding op. Note that this is not always true in practice.
-4. assume no AC (TODO model it)
-5. assume no float8 all-gather (TODO model it)
+3. assume no AC (TODO model it)
+4. assume no float8 all-gather (TODO model it)
 """
 
 import copy
@@ -164,68 +163,60 @@ def do_matmul(A, B):
 
 def run(
     outfile: str,
-    gemm_time_strategy: str = "benchmarks",
-    model_torch_compile_limitations: bool = False,
+    do_benchmarks: bool = True,
     shape_gen_name: str = "square",
     gemm_cache_filename: Optional[str] = None,
     n_limit: Optional[int] = None,
 ):
     """
     Args:
-    * `gemm_time_strategy`:
-      - `benchmarks`: use benchmarks for gemm times (more accurate for all shapes)
-      - `roofline`: use roofline model for gemm times (only accurate for large shapes)
+    * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
     * `shape_gen_name`: `llama`, `square`, or `sweep`
     * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
     """
-    print(f"gemm_time_strategy: {gemm_time_strategy}")
+    print(f"do_benchmarks: {do_benchmarks}")
     print(f"shape_gen_name: {shape_gen_name}")
 
-    assert gemm_time_strategy in (
-        "benchmarks",
-        "roofline",
-    ), "`gemm_time_strategy` must be 'benchmarks' or 'roofline'"
-
     M, K, N = sympy.symbols("M K N")
 
-    fp8_mem_time_sympy_dyn_limit = get_float8_mem_sympy(
-        M,
-        K,
-        N,
-        model_torch_compile_limitations=True,
-    )
     fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
         M,
         K,
         N,
-        model_torch_compile_limitations=False,
     )
 
-    if gemm_time_strategy == "roofline":
-        bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
-        print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
-        fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
-        print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
-        print()
-    else:
-        print()
+    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
+    print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
+    fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
+    print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+    print()
 
     headers = [
         "fwd_M",
         "fwd_K",
         "fwd_N",
-        # gemm microbenchmarks
-        "bf16_gemm_s",
-        "fp8_gemm_s",
-        # roofline memory overhead estimates
-        "fp8_oh_estimated",
-        "fp8_oh_ideal",
-        # actual e2e measurements
-        "bf16_s",
-        "fp8_dyn_s",
-        "fp8_dyn_sp",
+        # roofline - gemm time (fwd + bwd, 3 gemms)
+        "r_bf16_gemm_s",
+        "r_fp8_gemm_s",
+        # roofline - fp8 overhead time (by counting reads/writes in the ideal case)
+        "r_fp8_ovhd_s",
+        # roofline - fp8 gemm + fp8 overhead time (does not include LN or sigmoid)
+        "r_fp8_gemm_and_ovhd_s",
+        "r_fp8_gemm_and_ovhd_spdp",
+        # benchmarks - gemm time (fwd + bwd, 3 gemms)
+        "b_bf16_gemm_s",
+        "b_fp8_gemm_s",
+        # benchmarks - e2e LNLinearSigmoid time fwd + bwd
+        "b_bf16_e2e_s",
+        "b_fp8_e2e_s",
+        # note that e2e speedup is not the same as the roofline speedup:
+        # 1. roofline speedup: (bf16_gemm_time) / (fp8_gemm_time + fp8_ovhd_time)
+        # 2. e2e speedup: (ln + bf16_gemm_time + sigmoid) / (ln + fp8_gemm_time + fp8_ovhd_time + sigmoid)
+        # the difference is the fwd+bwd ln and sigmoid terms; for now, to keep things simple,
+        # we don't break them out and don't have a roofline for them.
+        "b_fp8_e2e_spdp",
     ]
 
     results = []
 
@@ -235,7 +226,18 @@ def run(
         if n_limit is not None and idx >= n_limit:
             break
 
-        if gemm_time_strategy == "benchmarks":
+        # use roofline model to estimate gemm time
+        # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
+        r_bf16_gemm_time_s = float(
+            bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+        )
+        r_fp8_gemm_time_s = float(
+            fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+        )
+
+        # if enabled, also measure observed gemm time
+        b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+        if do_benchmarks:
             bf16_g1, f8_g1 = get_gemm_times(
                 M_val, K_val, N_val, True, gemm_cache_filename
             )
@@ -245,60 +247,58 @@ def run(
             bf16_g3, f8_g3 = get_gemm_times(
                 K_val, M_val, N_val, False, gemm_cache_filename
             )
-            bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
-            fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
-        else:
-            assert gemm_time_strategy == "roofline", "unsupported"
-            bf16_time_val = (
-                bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-            )
-            fp8_gemm_time_s = (
-                fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-            )
+            b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
+            b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
 
-        fp8_mem_time_dyn_limit_s = (
-            fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-        )
-        fp8_mem_time_dyn_nolimit_s = (
+        # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
+        r_fp8_ovhd_time_s = float(
             fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
         )
 
-        # create the model
-        m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
-        x = torch.randn(
-            M_val, K_val, dtype=torch.bfloat16, device="cuda"
-        ).requires_grad_()
+        b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
+        if do_benchmarks:
+            # create the model
+            m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            x = torch.randn(
+                M_val, K_val, dtype=torch.bfloat16, device="cuda"
+            ).requires_grad_()
 
-        # get the bf16 gpu kernel time
-        torch._dynamo.reset()
-        m_bf16 = torch.compile(copy.deepcopy(m_orig))
-        bf16_time_actual_s = get_gpu_kernel_time(m_bf16, x)
+            # get the bf16 gpu kernel time
+            torch._dynamo.reset()
+            m_bf16 = torch.compile(copy.deepcopy(m_orig))
+            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)
 
-        # get the float8 dynamic scaling gpu kernel time
+            # get the float8 dynamic scaling gpu kernel time
 
-        torch._dynamo.reset()
-        m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
-        m_fp8_dyn = torch.compile(m_fp8_dyn)
-        fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)
+            torch._dynamo.reset()
+            m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
+            m_fp8_dyn = torch.compile(m_fp8_dyn)
+            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
 
         results.append(
             [
                 M_val,
                 K_val,
                 N_val,
-                # gemm microbenchmarks
-                bf16_time_val,
-                fp8_gemm_time_s,
-                # roofline overhead estimates
-                fp8_mem_time_dyn_limit_s,
-                fp8_mem_time_dyn_nolimit_s,
-                # e2e numbers
-                bf16_time_actual_s,
-                fp8_dyn_time_actual_s,
-                bf16_time_actual_s / fp8_dyn_time_actual_s,
+                # roofline - gemm
+                r_bf16_gemm_time_s,
+                r_fp8_gemm_time_s,
+                # roofline - fp8 overhead
+                r_fp8_ovhd_time_s,
+                # roofline - gemm + overhead, and speedup
+                r_fp8_gemm_time_s + r_fp8_ovhd_time_s,
+                r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s),
+                # benchmarks - gemm
+                b_bf16_gemm_time_s,
+                b_fp8_gemm_time_s,
+                # benchmarks - e2e, and speedup
+                b_bf16_e2e_time_s,
+                b_fp8_e2e_time_s,
+                b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20),
             ]
         )
+
     pd.set_option("display.precision", 2)
     df = pd.DataFrame(results, columns=headers)
     print(df)
     df.to_csv(outfile)
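
Note: below is a minimal usage sketch (not part of this diff) of how the roofline columns above are derived from the two sympy helpers, assuming the post-patch signatures of `get_gemm_time_sympy` and `get_float8_mem_sympy` shown in this PR; the concrete M/K/N values are arbitrary examples.

# sketch only, not part of this PR: derive the roofline columns for one shape
import sympy
import torch

from torchao.testing.float8.roofline_utils import (
    get_float8_mem_sympy,
    get_gemm_time_sympy,
)

M, K, N = sympy.symbols("M K N")
bf16_gemm_time = get_gemm_time_sympy(M, K, N, torch.bfloat16)
fp8_gemm_time = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
fp8_ovhd_time = get_float8_mem_sympy(M, K, N)

# example shape, chosen only for illustration
M_val, K_val, N_val = 16384, 4096, 14336


def to_float(expr):
    # substitute the concrete shape and cast sympy.Float -> float for formatting
    return float(expr.subs(M, M_val).subs(K, K_val).subs(N, N_val))


r_bf16_gemm_s = to_float(bf16_gemm_time)
r_fp8_gemm_s = to_float(fp8_gemm_time)
r_fp8_ovhd_s = to_float(fp8_ovhd_time)

# matches the `r_fp8_gemm_and_ovhd_spdp` column: bf16 gemm time over fp8 gemm + overhead
print(r_bf16_gemm_s / (r_fp8_gemm_s + r_fp8_ovhd_s))

If `do_benchmarks=True`, the `b_*` columns come from measured kernel times and can be compared against these roofline estimates.
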
diff --git a/torchao/testing/float8/roofline_utils.py b/torchao/testing/float8/roofline_utils.py
index 3ff40736ba..458acf8f7b 100644
--- a/torchao/testing/float8/roofline_utils.py
+++ b/torchao/testing/float8/roofline_utils.py
@@ -56,7 +56,6 @@ def get_tensor_memory_traffic_bytes(
     dim0,
     dim1,
     fuse_with_prev=False,
-    model_torch_compile_limitations=False,
 ):
     # assumes input bf16, output f8
     numel = dim0 * dim1
@@ -75,15 +74,7 @@ def get_tensor_memory_traffic_bytes(
     # kernel 3: read in bf16, write twice in float8 (row-major and col-major)
     kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel
 
-    if model_torch_compile_limitations:
-        # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...)
-        # has an extra memory read of the input in fp8
-        # context: https://github.com/pytorch/pytorch/issues/130015
-        tc_adjustment = numel * BYTES_PER_EL_FLOAT8
-    else:
-        tc_adjustment = 0
-
-    return kernel_1_rw + kernel_3_rw + tc_adjustment
+    return kernel_1_rw + kernel_3_rw
 
 
 def get_gemm_time_sympy(M, K, N, dtype):
@@ -101,7 +92,6 @@ def get_float8_mem_sympy(
     M,
     K,
     N,
-    model_torch_compile_limitations: bool = False,
 ):
 
     specs = get_specs()
@@ -123,13 +113,11 @@ def get_float8_mem_sympy(
         M,
         K,
         fuse_with_prev=True,
-        model_torch_compile_limitations=model_torch_compile_limitations,
     )
     fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes(
         K,
         N,
         fuse_with_prev=False,
-        model_torch_compile_limitations=model_torch_compile_limitations,
     )
     fwd_fp8_total_mem = fwd_fp8_input_mem + fwd_fp8_weight_mem
 
@@ -140,7 +128,6 @@ def get_float8_mem_sympy(
         M,
         N,
         fuse_with_prev=True,
-        model_torch_compile_limitations=model_torch_compile_limitations,
     )
     # already casted, assuming that we save weight from fw to bw
     # TODO: model this if FSDP float8 all-gather is on
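
For the memory-traffic model in `roofline_utils.py`, below is a rough standalone sketch (not part of this diff) of how the counted bytes turn into an overhead time estimate; the peak-bandwidth value and achievable fraction are illustrative assumptions, not values taken from this PR.

# sketch only, not part of this PR: convert counted bytes into a time estimate
BYTES_PER_EL_BF16 = 2  # mirrors the constant names used in roofline_utils.py
BYTES_PER_EL_FLOAT8 = 1

# assumed numbers for illustration (roughly H100-class HBM); the real script
# reads hardware specs via get_specs()
PEAK_MEM_BW_BYTES_PER_S = 3.35e12
ACHIEVABLE_FRACTION_OF_PEAK = 0.8


def kernel_3_time_s(dim0: int, dim1: int) -> float:
    # kernel 3 from the diff above: read the tensor in bf16, write it twice
    # in float8 (row-major and col-major)
    numel = dim0 * dim1
    kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel
    return kernel_3_rw / (PEAK_MEM_BW_BYTES_PER_S * ACHIEVABLE_FRACTION_OF_PEAK)


# example: casting a 16384 x 4096 bf16 activation tensor to float8
print(f"{kernel_3_time_s(16384, 4096) * 1e6:.1f} us")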