Skip to content

Commit

Permalink
small fix
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
manman-ren committed Dec 5, 2024
1 parent 757db43 commit 2867e2f
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions tritonbench/kernels/triton_fused_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1634,12 +1634,10 @@ def _attn_bwd_dkdv_ws(
for blk_idx in range(num_steps):
with tl.async_task([0]):
qT = tl.load(qT_ptrs)
# Load m before computing qk to reduce pipeline stall.
offs_m = curr_m + tl.arange(0, BLOCK_M1)
m = tl.load(M + offs_m)
# with tl.async_task([0]):
# do = tl.load(do_ptrs)
with tl.async_task([1, 2]):
# Load m before computing qk to reduce pipeline stall.
offs_m = curr_m + tl.arange(0, BLOCK_M1)
m = tl.load(M + offs_m)
qkT = tl.dot(k, qT)
# dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
pT = tl.math.exp2(qkT - m[None, :])
Expand All @@ -1661,10 +1659,11 @@ def _attn_bwd_dkdv_ws(
dsT = pT * (dpT - Di[None, :])
dsT = dsT.to(tl.bfloat16)
dk += tl.dot(dsT, tl.trans(qT))
# Increment pointers.
curr_m += step_m
qT_ptrs += step_m * stride_tok
do_ptrs += step_m * stride_tok
# Increment pointers.
curr_m += step_m
with tl.async_task([0]):
qT_ptrs += step_m * stride_tok
do_ptrs += step_m * stride_tok
return dk, dv


Expand Down Expand Up @@ -1722,10 +1721,11 @@ def _attn_bwd_dq_ws(
# Compute dQ.
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
dq += tl.dot(ds, tl.trans(kT))
# Increment pointers.
curr_n += step_n
kT_ptrs += step_n * stride_tok
vT_ptrs += step_n * stride_tok
# Increment pointers.
curr_n += step_n
with tl.async_task([0]):
kT_ptrs += step_n * stride_tok
vT_ptrs += step_n * stride_tok
return dq


Expand Down

0 comments on commit 2867e2f

Please sign in to comment.