[1/x] mx roofline: make the script work on NVIDIA B200
Summary:

Makes the roofline estimation script work on NVIDIA B200:
1. skip rowwise scaling (does not work yet on B200)
2. add proper values for peak tensor core flops and peak memory
   bandwidth for B200.

This script still needs a lot of improvements to be broadly useful; those will come in a follow-up stack.
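
For context, a minimal back-of-the-envelope sketch (not part of this commit) of what the new B200 spec values imply for the roofline crossover point. The numbers come from the `gpu_name_to_specs` entries added in `torchao/float8/roofline_utils.py` below; the variable names are illustrative, not torchao APIs.

```
# Illustrative only: rough fp8 roofline ridge point implied by the B200 specs
# added in this commit.
fp8_peak_flops_per_s = 4.5e15      # dense fp8 tensor core peak
peak_mem_bw_bytes_per_s = 8.0e12   # HBM bandwidth
pct_achievable_gemm_tops = 0.6     # copied from H100 for now, per the diff
pct_achievable_mem_bw = 0.92       # copied from H100 for now, per the diff

# arithmetic intensity (FLOPs per byte) above which a kernel is compute-bound
ridge_point = (fp8_peak_flops_per_s * pct_achievable_gemm_tops) / (
    peak_mem_bw_bytes_per_s * pct_achievable_mem_bw
)
print(f"fp8 ridge point: ~{ridge_point:.0f} FLOPs/byte")  # roughly 367
```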

Test Plan:

```
python benchmarks/float8/float8_roofline.py ~/local/tmp/20250225_test.csv
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 3c9b1cdbae19542ca45793114dbab48e89365911
ghstack-comment-id: 2683396978
Pull Request resolved: #1778
vkuzo committed Feb 25, 2025
1 parent 98c4e2e commit ee8282e
Showing 3 changed files with 66 additions and 26 deletions.
31 changes: 21 additions & 10 deletions benchmarks/float8/float8_roofline.py
@@ -65,6 +65,7 @@
     get_float8_mem_sympy,
     get_gemm_time_sympy,
 )
+from torchao.utils import is_sm_at_least_90, is_sm_at_least_100
 
 
 class LNLinearSigmoid(torch.nn.Module):
@@ -154,10 +155,13 @@ def do_matmul(A, B):
 
     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
 
-    scale_a = torch.ones(M, 1, device=device)
-    scale_b = torch.ones(1, N, device=device)
-    fast_accum = True  # for axiswise
-    f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+    if is_sm_at_least_90() and (not is_sm_at_least_100()):
+        scale_a = torch.ones(M, 1, device=device)
+        scale_b = torch.ones(1, N, device=device)
+        fast_accum = True  # for axiswise
+        f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+    else:
+        f8_axs_time_s = -1.0
 
     # save to cache if needed
     if cache_filename is not None:
@@ -298,17 +302,24 @@ def run(
     bf16_time_actual_s = get_gpu_kernel_time(m_bf16, x)
 
     # get the float8 dynamic scaling gpu kernel time
+
     torch._dynamo.reset()
     m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
     m_fp8_dyn = torch.compile(m_fp8_dyn)
     fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)
 
-    # get the float8 dynamic axiswise scaling gpu kernel time
-    torch._dynamo.reset()
-    config = Float8LinearConfig.from_recipe_name("rowwise")
-    m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
-    m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
-    fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
+    # get the float8 dynamic axiswise scaling gpu kernel time, if supported
+    # on current hardware
+    if is_sm_at_least_90() and (not is_sm_at_least_100()):
+        torch._dynamo.reset()
+        config = Float8LinearConfig.from_recipe_name("rowwise")
+        m_fp8_dyn_axs = convert_to_float8_training(
+            copy.deepcopy(m_orig), config=config
+        )
+        m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
+        fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
+    else:
+        fp8_dyn_axs_time_actual_s = -1.0
 
     # get the lw recipe scaling gpu kernel time
     # TODO(future PR): enable below once basic performance issues
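
The `is_sm_at_least_90() and (not is_sm_at_least_100())` guard above restricts the rowwise (axiswise) path to Hopper-class GPUs: H100 reports CUDA compute capability 9.0 while B200 reports 10.0. A minimal stand-in for these helpers (an illustrative assumption, not the actual `torchao.utils` implementation) could look like:

```
import torch

# Illustrative sketch only; torchao.utils ships the real is_sm_at_least_90/100 helpers.
def _is_sm_at_least(major: int) -> bool:
    # get_device_capability() returns (major, minor), e.g. (9, 0) on H100, (10, 0) on B200
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (major, 0)

hopper_only = _is_sm_at_least(9) and not _is_sm_at_least(10)
```
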
2 changes: 2 additions & 0 deletions benchmarks/float8/utils.py
@@ -81,6 +81,8 @@ def profiler_output_to_filtered_time_by_kernel_name(
             continue
         elif e.key == "cudaDeviceSynchronize":
             continue
+        elif e.key == "Activity Buffer Request":
+            continue
 
         kernel_name_to_gpu_time_us[e.key] = e.self_device_time_total
     return kernel_name_to_gpu_time_us
59 changes: 43 additions & 16 deletions torchao/float8/roofline_utils.py
@@ -9,19 +9,43 @@
 BYTES_PER_EL_FLOAT8 = 1
 BYTES_PER_EL_BF16 = 2
 
-# https://www.nvidia.com/en-us/data-center/h100/, divide by 2 because no sparsity
-H100_BF16_PEAK_TOPS = 989e12
-H100_FP8_PEAK_TOPS = 1979e12
+gpu_name_to_specs = {
+    "NVIDIA H100": {
+        # https://www.nvidia.com/en-us/data-center/h100/, divide by 2 because no sparsity
+        "bf16_peak_tops": 989e12,
+        "fp8_peak_tops": 1979e12,
+        # 2.4 TB per second, custom to Meta's H100 variant
+        "peak_mem_bw_bytes_sec": 2.4e12,
+        # based on quick experimental observation with sample large inputs
+        "pct_achievable_gemm_tops": 0.6,
+        # based on previous experience looking at pointwise triton kernels with large inputs,
+        # which would hit about 2.2k GBPS on Meta's H100 variant
+        "pct_achievable_mem_bw": 0.92,
+    },
+    "NVIDIA B200": {
+        # https://resources.nvidia.com/en-us-blackwell-architecture, page 19,
+        # divide by 2 because no sparsity
+        "bf16_peak_tops": 2.25e15,
+        "fp8_peak_tops": 4.5e15,
+        "fp4_peak_tops": 9.0e15,
+        # https://resources.nvidia.com/en-us-blackwell-architecture, page 20
+        # 8.0 TB per second
+        "peak_mem_bw_bytes_sec": 8.0e12,
+        # for now, copy over from H100
+        # TODO(future): measure once we have the hardware
+        "pct_achievable_gemm_tops": 0.6,
+        # for now, copy over from H100
+        # TODO(future): measure once we have the hardware
+        "pct_achievable_mem_bw": 0.92,
+    },
+    # TODO(future): more GPU names
+}
+
+
+def get_specs():
+    gpu_name = torch.cuda.get_device_name(0)
+    return gpu_name_to_specs[gpu_name]
 
-# 2.4 TB per second, custom to Meta's H100 variant
-H100_PEAK_MEM_BW_BYTES_SEC = 2.4e12
-
-# based on quick experimental observation with sample large inputs
-H100_PCT_ACHIEVABLE_GEMM_TOPS = 0.6
-
-# based on previous experience looking at pointwise triton kernels with large inputs,
-# which would hit about 2.2k GBPS on Meta's H100 variant
-H100_PCT_ACHIEVABLE_MEM_BW = 0.92
 
 # Source: run a triton kernel with a single element read/write on an H100 and
 # measure GPU time from the trace
@@ -65,12 +89,13 @@ def get_tensor_memory_traffic_bytes(
 
 
 def get_gemm_time_sympy(M, K, N, dtype):
+    specs = get_specs()
     gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N
     if dtype is torch.bfloat16:
-        peak_tops = H100_BF16_PEAK_TOPS
+        peak_tops = specs["bf16_peak_tops"]
     elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        peak_tops = H100_FP8_PEAK_TOPS
-    gemm_time_s = gemm_ops / peak_tops / H100_PCT_ACHIEVABLE_GEMM_TOPS
+        peak_tops = specs["fp8_peak_tops"]
+    gemm_time_s = gemm_ops / peak_tops / specs["pct_achievable_gemm_tops"]
     return gemm_time_s
 
 
@@ -87,6 +112,8 @@ def get_float8_mem_sympy(
     assert scaling_type_weight in ("dynamic",), "unsupported"
    assert scaling_type_grad_output in ("dynamic",), "unsupported"
 
+    specs = get_specs()
+
     # there are three gemms in the fwd/bwd of a linear:
     #
     # input @ weight_t = output
@@ -148,7 +175,7 @@
     )
     fp8_total_mem = fwd_fp8_total_mem + bwd_fp8_total_mem
     fp8_mem_time_s = (
-        fp8_total_mem / H100_PEAK_MEM_BW_BYTES_SEC / H100_PCT_ACHIEVABLE_MEM_BW
+        fp8_total_mem / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
     )
 
     # Adjust final estimate for small kernel launches
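
As a rough sanity check on the spec plumbing above, here is a worked example of the gemm-time formula in `get_gemm_time_sympy` with the B200 fp8 numbers (illustrative shapes and plain Python floats rather than sympy expressions):

```
# Worked example (illustrative): ideal fp8 gemm time for one linear fwd+bwd
# on B200, using the spec values added in this commit.
M = K = N = 4096
# output, grad_input, and grad_weight gemms: 2*M*K*N FLOPs each
gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N  # ~4.1e11 FLOPs
fp8_peak_tops = 4.5e15
pct_achievable_gemm_tops = 0.6

gemm_time_s = gemm_ops / fp8_peak_tops / pct_achievable_gemm_tops
print(f"{gemm_time_s * 1e6:.0f} us")  # ~153 us
```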
