lint
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
manman-ren committed Dec 3, 2024
1 parent 5eb2302 commit 6545a85
Showing 2 changed files with 32 additions and 35 deletions.
65 changes: 32 additions & 33 deletions tritonbench/kernels/triton_fused_attention.py
@@ -458,10 +458,10 @@ def _attn_fwd_inner_ws(
             num_warps=w,
         )
     )
-    for BM in [128] #64, 128]
-    for BN in [128] #64, 128]
-    for s in [3] #3, 4, 7]
-    for w in [8] #4, 8]
+    for BM in [128]  # 64, 128]
+    for BN in [128]  # 64, 128]
+    for s in [3]  # 3, 4, 7]
+    for w in [8]  # 4, 8]
 ]
 # TMA, WS, and CompPipe
 configsTmaWS = [
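
The change in this hunk is purely cosmetic: the linter inserts a space after # so inline comments satisfy PEP 8 (E262/E265). The surrounding code is the autotune sweep for the TMA variant. As a minimal, self-contained sketch of the pattern (the config name and kwargs here are illustrative, not the file's exact definitions):

    import triton

    # Sketch of the sweep pattern above: one triton.Config per combination.
    # Only a single point is left enabled; the swept alternatives live in
    # the trailing comments, exactly as in the hunk.
    configs = [
        triton.Config(
            {"BLOCK_M": BM, "BLOCK_N": BN},
            num_stages=s,
            num_warps=w,
        )
        for BM in [128]  # 64, 128
        for BN in [128]  # 64, 128
        for s in [3]  # 3, 4, 7
        for w in [8]  # 4, 8
    ]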
@@ -1637,11 +1637,11 @@ def _attn_bwd_dkdv_ws(
         # Load m before computing qk to reduce pipeline stall.
         offs_m = curr_m + tl.arange(0, BLOCK_M1)
         m = tl.load(M + offs_m)
-        #with tl.async_task([0]):
+        # with tl.async_task([0]):
         #    do = tl.load(do_ptrs)
         with tl.async_task([1, 2]):
             qkT = tl.dot(k, qT)
-            #dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
+            # dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
             pT = tl.math.exp2(qkT - m[None, :])
             # Autoregressive masking.
             if MASK:
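
The context lines above compute the softmax numerator with tl.math.exp2 rather than exp. This is the standard FlashAttention trick: the scores are pre-scaled by 1/ln(2) (the RCP_LN2 constant that appears further down in this diff), and GPUs expose a fast base-2 exponential. A quick plain-Python check of the identity this relies on:

    import math

    # exp(x) == 2 ** (x * log2(e)); pre-scaling by RCP_LN2 = 1/ln(2)
    # lets a kernel use the fast exp2 path instead of exp.
    RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2) == log2(e)
    x = 0.37
    assert abs(math.exp(x) - 2.0 ** (x * RCP_LN2)) < 1e-12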
Expand Down Expand Up @@ -1746,10 +1746,10 @@ def keep2(conf):
"BLOCK_M2": BN,
"BLOCK_N2": BM,
},
num_stages=s, #0 or s,
num_stages=s, # 0 or s,
num_warps=w,
num_buffers_warp_spec=0, #0 or 2,
num_consumer_groups=0, #0 or 1,
num_buffers_warp_spec=0, # 0 or 2,
num_consumer_groups=0, # 0 or 1,
)
if has_warp_spec
else triton.Config(
@@ -1763,10 +1763,10 @@ def keep2(conf):
             num_warps=w,
         )
     )
-    for BM in [64] #32, 64] # BLOCK_N1 % BLOCK_M1 == 0
-    for BN in [128] #64, 128]
-    for s in [3]#, 4, 7]
-    for w in [4]#, 8]
+    for BM in [64]  # 32, 64] # BLOCK_N1 % BLOCK_M1 == 0
+    for BN in [128]  # 64, 128]
+    for s in [3]  # , 4, 7]
+    for w in [4]  # , 8]
 ]
 configsBwdWs = [
     (
@@ -1794,10 +1794,10 @@ def keep2(conf):
             num_warps=w,
         )
     )
-    for BM in [64] #32, 64] # BLOCK_N1 % BLOCK_M1 == 0
-    for BN in [128] #[64] #64, 128]
-    for s in [3]#, 4, 7]
-    for w in [4]#, 8]
+    for BM in [64]  # 32, 64] # BLOCK_N1 % BLOCK_M1 == 0
+    for BN in [128]  # [64] #64, 128]
+    for s in [3]  # , 4, 7]
+    for w in [4]  # , 8]
 ]
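
Both backward-config hunks sit inside keep2(conf), and the BM comment calls out a BLOCK_N1 % BLOCK_M1 == 0 constraint. keep2's body is not part of this diff; purely as a hypothetical illustration, a pruning predicate of that shape would look like:

    # Hypothetical predicate (keep2's actual body is not shown in this diff):
    # drop autotune configs whose block shapes break the divisibility
    # constraint noted in the "# BLOCK_N1 % BLOCK_M1 == 0" comment.
    def keep_divisible(conf):
        kw = conf.kwargs
        return kw["BLOCK_N1"] % kw["BLOCK_M1"] == 0

    # configsBwd = [c for c in configsBwd if keep_divisible(c)]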


@@ -2698,8 +2698,8 @@ def backward(ctx, do):
         BATCH, N_HEAD, N_CTX = q.shape[:3]
         PRE_BLOCK = 128

-        #NUM_WARPS, NUM_STAGES = 4, 5
-        #BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
+        # NUM_WARPS, NUM_STAGES = 4, 5
+        # BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32

         BLK_SLICE_FACTOR = 2
         RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
@@ -2720,8 +2720,7 @@ def backward(ctx, do):
             HEAD_DIM=ctx.HEAD_DIM,  #
         )
         grid = lambda args: (N_CTX // args["BLOCK_N1"], 1, BATCH * N_HEAD)
-        #grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
-        print(q.stride(0), q.stride(1), q.stride(2), q.stride(3))
+        # grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
         if ctx.bwdVariant == "base":
             _attn_bwd[grid](
                 q,
@@ -2740,14 +2739,14 @@ def backward(ctx, do):
                 q.stride(3),  #
                 N_HEAD,
                 N_CTX,  #
-                #BLOCK_M1=BLOCK_M1,
-                #BLOCK_N1=BLOCK_N1, #
-                #BLOCK_M2=BLOCK_M2,
-                #BLOCK_N2=BLOCK_N2, #
+                # BLOCK_M1=BLOCK_M1,
+                # BLOCK_N1=BLOCK_N1, #
+                # BLOCK_M2=BLOCK_M2,
+                # BLOCK_N2=BLOCK_N2, #
                 BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
                 HEAD_DIM=ctx.HEAD_DIM,  #
-                #num_warps=NUM_WARPS, #
-                #num_stages=NUM_STAGES, #
+                # num_warps=NUM_WARPS, #
+                # num_stages=NUM_STAGES, #
             )
         elif ctx.bwdVariant == "ws":
             _attn_bwd_ws[grid](
@@ -2767,14 +2766,14 @@ def backward(ctx, do):
                 q.stride(3),  #
                 N_HEAD,
                 N_CTX,  #
-                #BLOCK_M1=BLOCK_M1,
-                #BLOCK_N1=BLOCK_N1, #
-                #BLOCK_M2=BLOCK_M2,
-                #BLOCK_N2=BLOCK_N2, #
+                # BLOCK_M1=BLOCK_M1,
+                # BLOCK_N1=BLOCK_N1, #
+                # BLOCK_M2=BLOCK_M2,
+                # BLOCK_N2=BLOCK_N2, #
                 BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
                 HEAD_DIM=ctx.HEAD_DIM,  #
-                #num_warps=NUM_WARPS, #
-                #num_stages=NUM_STAGES, #
+                # num_warps=NUM_WARPS, #
+                # num_stages=NUM_STAGES, #
             )

         return dq, dk, dv, None, None, None, None
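
Note that grid in the launch code above is a callable, so the launch geometry can depend on whichever BLOCK_N1 the autotuner ends up selecting; Triton evaluates it against the chosen meta-parameters at launch time. A small illustration with made-up shapes:

    # Made-up shapes, purely to show how the grid callable resolves.
    BATCH, N_HEAD, N_CTX = 4, 16, 4096
    grid = lambda args: (N_CTX // args["BLOCK_N1"], 1, BATCH * N_HEAD)

    # Triton calls the callable with the launch meta-parameters:
    print(grid({"BLOCK_N1": 128}))  # -> (32, 1, 64)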
2 changes: 0 additions & 2 deletions tritonbench/operators/flash_attention/operator.py
@@ -470,8 +470,6 @@ def tflops(
         BATCH, H, N_CTX, D_HEAD = q.shape
         flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
         tflops = 2 * flops_per_matmul
-        print("causal, mode: ", self.causal, self.mode)
-        print("fn_name: ", fn_name, metrics.latency)
         if self.causal:
             tflops *= 0.5
         if self.mode == BenchmarkMode.BWD:
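
The operator.py change simply drops two debug prints from the tflops metric. For reference, the count being computed: attention performs two N_CTX x N_CTX matmuls per (batch, head), each costing 2 * N_CTX^2 * D_HEAD FLOPs, and causal masking halves the work. A worked instance with made-up shapes (the BenchmarkMode.BWD multiplier is truncated by the diff, so it is omitted here):

    # Worked instance of the tflops() count (shapes are made up; the
    # BWD-mode branch is cut off in this diff and omitted).
    BATCH, H, N_CTX, D_HEAD = 4, 16, 4096, 64
    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD  # QK^T or P@V
    total = 2 * flops_per_matmul  # forward attention: two matmuls
    causal = True
    if causal:
        total *= 0.5  # causal mask touches only half the score tiles
    print(total / 1e12, "TFLOPs")  # -> ~0.137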
