Skip to content

Commit

Permalink
Add benchmarks for ops, eager vs. inductor
Browse files Browse the repository at this point in the history
  • Loading branch information
Xia-Weiwen committed Dec 7, 2023
1 parent e8e6395 commit 7e1df81
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 10 deletions.
118 changes: 118 additions & 0 deletions benchmarks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import argparse
import torch
import sys
sys.path.append(sys.path[0])
from int8_ops import (
double_quant_eager,
double_quant,
transform_eager,
transform,
igemmlt,
mm_dequant_eager,
mm_dequant,
extract_outliers_eager,
extract_outliers,
)
import time

def _build_parser():
    """
    Build the command-line parser for the benchmark script.

    Two integer options control the number of warm-up and timed iterations;
    the remaining options are boolean switches selecting which benchmarks run
    and whether the PyTorch profiler is used.
    """
    cli = argparse.ArgumentParser(description="Benchmarks for bnb int8 ops, eager vs inductor")

    # (flag, default, help) for the integer-valued options.
    int_options = (
        ("--num-active", 20, "number of active iterations for benchmark"),
        ("--num-warmup", 10, "number of warmup iterations for benchmark"),
    )
    for flag, default, help_text in int_options:
        cli.add_argument(flag, default=default, type=int, help=help_text)

    # (flag, help) for the on/off switches; all default to False.
    switches = (
        ("--all", "Run all benchmarks"),
        ("--double-quant", "Run benchmark for the double_quant op"),
        ("--transform", "Run benchmark for the transform op"),
        ("--mm-dequant", "Run benchmark for the mm_dequant op"),
        ("--extract-outliers", "Run benchmark for the extract_outliers op"),
        ("--profile", "Run all benchmarks with PyTorch profiler"),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)

    return cli


parser = _build_parser()
# Parsed once at import time; run_benchmark and the driver code below read `args`.
args = parser.parse_args()


def trace_handler(prof):
    """
    Print the profiler's averaged event table.

    Used as the ``on_trace_ready`` callback of ``torch.profiler.profile``.
    The table is requested sorted by ``cpu_time_total`` with ``row_limit=-1``.
    """
    averaged_events = prof.key_averages()
    report = averaged_events.table(sort_by="cpu_time_total", row_limit=-1)
    print(report)


def run_benchmark(func_name, func_eager, func_inductor, *func_args, **func_kwargs):
    """
    Time an eager op against its compiled counterpart and print a report.

    Both callables are invoked as ``func(*func_args, **func_kwargs)`` under
    ``torch.no_grad()``. Each is first run ``args.num_warmup`` times (untimed)
    and then ``args.num_active`` times (timed); the mean latency per call is
    reported in milliseconds. When ``args.profile`` is set, each callable is
    additionally run four times under the PyTorch CPU profiler and the
    resulting table is printed through ``trace_handler``.

    Args:
        func_name: Label used in the printed report.
        func_eager: The eager implementation.
        func_inductor: The compiled implementation.
        *func_args: Positional arguments forwarded to both callables.
        **func_kwargs: Keyword arguments forwarded to both callables.

    Raises:
        ValueError: If ``args.num_active`` is not positive.
    """
    if args.num_active <= 0:
        # Avoid an opaque ZeroDivisionError when computing the mean latency.
        raise ValueError("--num-active must be a positive integer")

    def _mean_latency_ms(func):
        # Warm-up calls are excluded from the timing (first calls of a compiled
        # function may include compilation).
        for _ in range(args.num_warmup):
            func(*func_args, **func_kwargs)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(args.num_active):
            func(*func_args, **func_kwargs)
        return (time.perf_counter() - start) / args.num_active * 1000

    def _profile(label, func):
        print(f"\nProfiling for {label}")
        with torch.no_grad(), torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            schedule=torch.profiler.schedule(
                wait=0, warmup=3, active=1, repeat=0),
            on_trace_ready=trace_handler
        ) as p:
            # Four steps match the schedule above: 3 warm-up steps + 1 active step.
            for _ in range(4):
                func(*func_args, **func_kwargs)
                p.step()

    with torch.no_grad():
        latency_eager = _mean_latency_ms(func_eager)
        latency_inductor = _mean_latency_ms(func_inductor)

    print(f"--- Benchmark for {func_name} ---")
    if args.profile:
        _profile("eager", func_eager)
        _profile("inductor", func_inductor)

    # Compute the ratio from the unrounded latencies: rounding first could turn
    # a very small latency into 0.0 and make the division fail.
    saved_ms = latency_eager - latency_inductor
    if latency_inductor > 0:
        relative = f"{(latency_eager / latency_inductor - 1) * 100:.2f}%"
    else:
        relative = "n/a"

    print("\n--- Summary ---")
    print(f"Eager latency: {latency_eager:.3f} ms, inductor latency: {latency_inductor:.3f} ms, "
          f"speedup: {saved_ms:.3f} ms ({relative})")
    print("----------\n")


if args.all or args.double_quant:
    # double_quant benchmark: a 4096x4096 float input (torch.rand scaled by 3.0)
    # with the op's `threshold` keyword set to 3.0.
    run_benchmark(
        'double_quant',
        double_quant_eager,
        double_quant,
        torch.rand(4096, 4096) * 3.0,
        threshold=3.0,
    )

if args.all or args.transform:
    # transform benchmark: a 4096x4096 float input, called with transpose=True.
    matrix = torch.rand(4096, 4096)
    run_benchmark('transform', transform_eager, transform, matrix, transpose=True)

if args.all or args.mm_dequant:
    def _absmax_quantize(mat):
        """
        Quantize `mat` to int8 along its last dimension.

        Returns the per-row statistic max(row_max, -row_min) and the int8
        tensor obtained by dividing each row by (statistic / 127) and rounding.
        """
        row_min, row_max = mat.aminmax(dim=-1)
        stats = torch.max(row_max, row_min.neg())
        scale = stats / 127
        quantized = torch.round(mat / scale.unsqueeze(-1)).to(torch.int8)
        return stats, quantized

    shapeA, shapeB = (4096, 4096), (4096, 4096)

    # Random tensors are drawn in the same order as before: A, B, then bias.
    A = torch.rand(shapeA)
    A_stats, A_int8 = _absmax_quantize(A)
    B = torch.randn(shapeB)
    B_stats, B_int8 = _absmax_quantize(B)
    bias = torch.randn(shapeB[0])

    # The int32 result of igemmlt is the input to mm_dequant.
    C_i32, _ = igemmlt(A_int8, B_int8)
    run_benchmark(
        'mm_dequant',
        mm_dequant_eager,
        mm_dequant,
        C_i32,
        None,
        A_stats,
        B_stats,
        bias=bias,
    )

if args.extract_outliers or args.all:
    # extract_outliers benchmark: pick columns `idx` out of a random int8 matrix.
    shapeA = (4096, 4096 * 4)
    # Ten random column indices, de-duplicated by torch.unique and cast to int32.
    idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int())
    # torch.randint's upper bound is exclusive, so use 128 (not 127) to cover
    # the full int8 range [-128, 127].
    A = torch.randint(-128, 128, size=shapeA).to(torch.int8)
    func_args = (A, None, idx)
    func_kwargs = {}
    run_benchmark('extract_outliers', extract_outliers_eager, extract_outliers, *func_args, **func_kwargs)
22 changes: 12 additions & 10 deletions int8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ def assert_on_cpu(tensors):
return on_cpu


@torch.compile(dynamic=True, options={"fx_graph_cache": True})
def double_quant(
def double_quant_eager(
A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
):
"""
Expand Down Expand Up @@ -84,8 +83,7 @@ def quant_to_int8(A, stats):
return out_row, out_col, row_stats, col_stats, coo_tensor


@torch.compile(dynamic=True, options={"fx_graph_cache": True})
def transform(A, to_order=None, from_order='row', out=None, transpose=False, state=None, ld=None):
def transform_eager(A, to_order=None, from_order='row', out=None, transpose=False, state=None, ld=None):
"""
Transform tensor A to to_order. It is originally designed for CUDA.
For CPU, it returns the original tensor if transpose=False.
Expand Down Expand Up @@ -161,8 +159,7 @@ def igemmlt(A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32):
return out, Sout


@torch.compile(dynamic=True, options={"fx_graph_cache": True})
def mm_dequant(
def mm_dequant_eager(
A,
quant_state,
row_stats,
Expand Down Expand Up @@ -190,7 +187,7 @@ def mm_dequant(
assert_on_cpu([A, row_stats, col_stats, out, bias])
assert A.dtype == torch.int32
compute_dtype = torch.float
output_dtype = mm_dequant.output_dtype
output_dtype = mm_dequant_eager.output_dtype
out_shape = A.shape
if len(out_shape) == 3:
out_shape = (out_shape[0] * out_shape[1], out_shape[2])
Expand All @@ -204,13 +201,18 @@ def mm_dequant(
out = out.to(output_dtype)
return out

mm_dequant.output_dtype = torch.bfloat16
mm_dequant_eager.output_dtype = torch.bfloat16


@torch.compile(dynamic=True, options={"fx_graph_cache": True})
def extract_outliers(A, SA, idx):
def extract_outliers_eager(A, SA, idx):
"""
Extract columns of A by idx
"""
assert_on_cpu([A])
return A[:, idx].contiguous()


def _compile_op(eager_fn):
    """
    Wrap an eager op with torch.compile using the settings shared by all ops
    in this module (dynamic=True, fx_graph_cache enabled).
    """
    return torch.compile(eager_fn, dynamic=True, options={"fx_graph_cache": True})


# Compiled counterparts of the eager implementations defined above.
double_quant = _compile_op(double_quant_eager)
transform = _compile_op(transform_eager)
mm_dequant = _compile_op(mm_dequant_eager)
extract_outliers = _compile_op(extract_outliers_eager)

0 comments on commit 7e1df81

Please sign in to comment.