roofline estimator: simplify
Summary:

1. remove the modeling of torch.compile limitations (interesting, but it hasn't
   been useful)
2. make a clearer distinction between roofline and benchmarked values (see the
   note after the changed-files summary below)

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 3348fff4abce61a52265e94badbe58426bd342ab
ghstack-comment-id: 2683886949
Pull Request resolved: #1783
vkuzo committed Feb 26, 2025
1 parent 2a52280 commit 6acc083
Showing 2 changed files with 68 additions and 90 deletions.
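
After this change, every roofline-derived column in the results table is prefixed with `r_` and every benchmarked column with `b_`, so the two kinds of numbers cannot be confused. A hedged sketch of how the two could be compared offline, assuming the script writes the results table to `outfile` as a CSV with the header names introduced in this diff and was run with `do_benchmarks=True` (the filename below is hypothetical):

```python
import pandas as pd

df = pd.read_csv("roofline_results.csv")  # hypothetical `outfile` path

# roofline-predicted fp8 speedup vs. the measured end-to-end speedup
df["roofline_speedup"] = df["r_gemm_and_ovhd_fp8_spdp"]
df["measured_speedup"] = df["b_e2e_bf16_s"] / df["b_e2e_fp8_s"]

print(df[["fwd_M", "fwd_K", "fwd_N", "roofline_speedup", "measured_speedup"]])
```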
143 changes: 67 additions & 76 deletions benchmarks/float8/float8_roofline.py
@@ -31,12 +31,10 @@
input_t @ grad_output = grad_weight
KxM @ MxN => KxN
2. we properly model the worst-case of the current torch.compile limitations regarding
float8 scaling
3. assume for float8 activations/gradients that torch.compile will fuse to the
2. assume for float8 activations/gradients that torch.compile will fuse to the
preceding op. Note that this is not always true in practice.
4. assume no AC (TODO model it)
5. assume no float8 all-gather (TODO model it)
3. assume no AC (TODO model it)
4. assume no float8 all-gather (TODO model it)
"""

import copy
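
For context on the gemm part of the roofline model described in the docstring above: a linear layer's forward and backward passes are treated as three gemms (output, grad_input, grad_weight), and the roofline time of each is its FLOPs divided by peak throughput. A minimal sympy sketch of that idea, assuming an H100-class bf16 peak rather than the values `get_gemm_time_sympy()` actually pulls from `get_specs()`:

```python
import sympy

M, K, N = sympy.symbols("M K N")

# 2 * M * K * N FLOPs per gemm (one multiply + one add per MAC);
# three gemms in fwd + bwd
total_gemm_flops = 3 * 2 * M * K * N

PEAK_BF16_FLOP_PER_S = 989e12  # assumed H100 dense bf16 tensor-core peak

gemm_time_s = total_gemm_flops / PEAK_BF16_FLOP_PER_S
print(gemm_time_s.subs({M: 16384, K: 4096, N: 4096}))  # seconds, roofline
```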
@@ -164,68 +162,55 @@ def do_matmul(A, B):

def run(
outfile: str,
gemm_time_strategy: str = "benchmarks",
model_torch_compile_limitations: bool = False,
do_benchmarks: bool = True,
shape_gen_name: str = "square",
gemm_cache_filename: Optional[str] = None,
n_limit: Optional[int] = None,
):
"""
Args:
* `gemm_time_strategy`:
- `benchmarks`: use benchmarks for gemm times (more accurate for all shapes)
- `roofline`: use roofline model for gemm times (only accurate for large shapes)
* `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
* `shape_gen_name`: `llama`, `square`, or `sweep`
* `gemm_cache_filename (optional)`: file to cache gemm benchmark results
* `n_limit (optional)`: if specified, only runs `n_limit` iterations
"""

print(f"gemm_time_strategy: {gemm_time_strategy}")
print(f"do_benchmarks: {do_benchmarks}")
print(f"shape_gen_name: {shape_gen_name}")

assert gemm_time_strategy in (
"benchmarks",
"roofline",
), "`gemm_time_strategy` must be 'benchmarks' or 'roofline'"

M, K, N = sympy.symbols("M K N")

fp8_mem_time_sympy_dyn_limit = get_float8_mem_sympy(
M,
K,
N,
model_torch_compile_limitations=True,
)
fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
M,
K,
N,
model_torch_compile_limitations=False,
)

if gemm_time_strategy == "roofline":
bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
print()
else:
print()
bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
print()

headers = [
"fwd_M",
"fwd_K",
"fwd_N",
# gemm microbenchmarks
"bf16_gemm_s",
"fp8_gemm_s",
# roofline memory overhead estimates
"fp8_oh_estimated",
"fp8_oh_ideal",
# gemm time - roofline
"r_bf16_gemm_s",
"r_fp8_gemm_s",
# gemm time - microbenchmarks
"b_bf16_gemm_s",
"b_fp8_gemm_s",
# memory overhead - roofline
"r_fp8_ovhd",
# gemm + overhead roofline estimations (does not include prev/next ops)
"r_gemm_and_ovhd_fp8_s",
"r_gemm_and_ovhd_fp8_spdp",
# actual e2e measurements
"bf16_s",
"fp8_dyn_s",
"fp8_dyn_sp",
"b_e2e_bf16_s",
"b_e2e_fp8_s",
"b_e2e_fp8_spdp",
]
results = []

@@ -235,7 +220,17 @@ def run(
if n_limit is not None and idx >= n_limit:
break

if gemm_time_strategy == "benchmarks":
# use roofline model to estimate gemm time
r_bf16_gemm_time_s = (
bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)
r_fp8_gemm_time_s = (
fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)

# if enabled, also measure the observed gemm time
b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
if do_benchmarks:
bf16_g1, f8_g1 = get_gemm_times(
M_val, K_val, N_val, True, gemm_cache_filename
)
@@ -245,57 +240,53 @@ def run(
bf16_g3, f8_g3 = get_gemm_times(
K_val, M_val, N_val, False, gemm_cache_filename
)
bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
else:
assert gemm_time_strategy == "roofline", "unsupported"
bf16_time_val = (
bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)
fp8_gemm_time_s = (
fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)
b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3

fp8_mem_time_dyn_limit_s = (
fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)
fp8_mem_time_dyn_nolimit_s = (
r_fp8_mem_time_dyn_nolimit_s = (
fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)

# create the model
m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
x = torch.randn(
M_val, K_val, dtype=torch.bfloat16, device="cuda"
).requires_grad_()
bf16_time_actual_s, fp8_dyn_time_actual_s = 0, 0
if do_benchmarks:
# create the model
m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
x = torch.randn(
M_val, K_val, dtype=torch.bfloat16, device="cuda"
).requires_grad_()

# get the bf16 gpu kernel time
torch._dynamo.reset()
m_bf16 = torch.compile(copy.deepcopy(m_orig))
bf16_time_actual_s = get_gpu_kernel_time(m_bf16, x)
# get the bf16 gpu kernel time
torch._dynamo.reset()
m_bf16 = torch.compile(copy.deepcopy(m_orig))
bf16_time_actual_s = get_gpu_kernel_time(m_bf16, x)

# get the float8 dynamic scaling gpu kernel time
# get the float8 dynamic scaling gpu kernel time

torch._dynamo.reset()
m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
m_fp8_dyn = torch.compile(m_fp8_dyn)
fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)
torch._dynamo.reset()
m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
m_fp8_dyn = torch.compile(m_fp8_dyn)
fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)

results.append(
[
M_val,
K_val,
N_val,
# gemm microbenchmarks
bf16_time_val,
fp8_gemm_time_s,
# gemm roofline
r_bf16_gemm_time_s,
r_fp8_gemm_time_s,
# gemm benchmarks
b_bf16_gemm_time_s,
b_fp8_gemm_time_s,
# roofline overhead estimates
fp8_mem_time_dyn_limit_s,
fp8_mem_time_dyn_nolimit_s,
# e2e numbers
r_fp8_mem_time_dyn_nolimit_s,
# e2e roofline estimations
r_fp8_gemm_time_s + r_fp8_mem_time_dyn_nolimit_s,
r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_mem_time_dyn_nolimit_s),
# e2e benchmark numbers
bf16_time_actual_s,
fp8_dyn_time_actual_s,
bf16_time_actual_s / fp8_dyn_time_actual_s,
bf16_time_actual_s / (fp8_dyn_time_actual_s + 1e-20),
]
)

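The `b_*` gemm columns above come from measurement rather than from the roofline expressions: `get_gemm_times()` is called once per gemm in the fwd + bwd pass and the three timings are summed. A rough sketch of a single bf16 measurement of that kind, using torch.utils.benchmark and hypothetical shapes (the real helper also times the float8 gemm and can cache results in `gemm_cache_filename`):

```python
import torch
from torch.utils import benchmark

def time_bf16_gemm_s(M: int, K: int, N: int) -> float:
    A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    t = benchmark.Timer(stmt="A @ B", globals={"A": A, "B": B})
    return t.blocked_autorange().median  # seconds

# mirror the three get_gemm_times calls: fwd gemm plus the two bwd gemms
b_bf16_gemm_time_s = (
    time_bf16_gemm_s(16384, 4096, 4096)    # output = input @ weight_t
    + time_bf16_gemm_s(16384, 4096, 4096)  # grad_input = grad_output @ weight
    + time_bf16_gemm_s(4096, 16384, 4096)  # grad_weight = input_t @ grad_output
)
print(b_bf16_gemm_time_s)
```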
15 changes: 1 addition & 14 deletions torchao/float8/roofline_utils.py
@@ -56,7 +56,6 @@ def get_tensor_memory_traffic_bytes(
dim0,
dim1,
fuse_with_prev=False,
model_torch_compile_limitations=False,
):
# assumes input bf16, output f8
numel = dim0 * dim1
@@ -75,15 +74,7 @@
# kernel 3: read in bf16, write twice in float8 (row-major and col-major)
kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel

if model_torch_compile_limitations:
# today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...)
# has an extra memory read of the input in fp8
# context: https://github.com/pytorch/pytorch/issues/130015
tc_adjustment = numel * BYTES_PER_EL_FLOAT8
else:
tc_adjustment = 0

return kernel_1_rw + kernel_3_rw + tc_adjustment
return kernel_1_rw + kernel_3_rw


def get_gemm_time_sympy(M, K, N, dtype):
@@ -101,7 +92,6 @@ def get_float8_mem_sympy(
M,
K,
N,
model_torch_compile_limitations: bool = False,
):
specs = get_specs()

@@ -123,13 +113,11 @@
M,
K,
fuse_with_prev=True,
model_torch_compile_limitations=model_torch_compile_limitations,
)
fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes(
K,
N,
fuse_with_prev=False,
model_torch_compile_limitations=model_torch_compile_limitations,
)
fwd_fp8_total_mem = fwd_fp8_input_mem + fwd_fp8_weight_mem

@@ -140,7 +128,6 @@
M,
N,
fuse_with_prev=True,
model_torch_compile_limitations=model_torch_compile_limitations,
)
# already casted, assuming that we save weight from fw to bw
# TODO: model this if FSDP float8 all-gather is on
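To make the simplified traffic model above concrete: with the torch.compile adjustment removed, the estimate for casting one bf16 tensor to float8 is just the reads and writes of the two remaining kernels, and the roofline then divides those bytes by peak memory bandwidth. A hedged back-of-envelope version, with assumed constants rather than the ones defined in torchao/float8/roofline_utils.py:

```python
BYTES_PER_EL_BF16 = 2
BYTES_PER_EL_FLOAT8 = 1
ASSUMED_PEAK_MEM_BW_BYTES_PER_S = 3.35e12  # assumed H100 HBM3 peak bandwidth

def cast_traffic_bytes(dim0: int, dim1: int, fuse_with_prev: bool = False) -> int:
    numel = dim0 * dim1
    # kernel 1 (scale / abs-max): assume one bf16 read of the tensor unless it
    # is fused with the preceding op, in which case it is modeled as free
    kernel_1_rw = 0 if fuse_with_prev else BYTES_PER_EL_BF16 * numel
    # kernel 3: read in bf16, write twice in float8 (row-major and col-major)
    kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel
    return kernel_1_rw + kernel_3_rw

bytes_rw = cast_traffic_bytes(16384, 4096)
print(bytes_rw / ASSUMED_PEAK_MEM_BW_BYTES_PER_S)  # roofline seconds for the cast
```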
