[bugfix] use the general method to keep the dimension of sum reduce to 1
pengyang1 committed Jun 7, 2024
1 parent 3dd5b06 commit 209dbc4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/flag_gems/ops/layernorm.py
@@ -132,8 +132,8 @@ def layer_norm_backward_kernel(
 dx_part2 += dx_hat
 dx_part3 += dx_hat * x_hat

-dx_2 = tl.sum(dx_part2, axis=1, keep_dims=True)
-dx_3 = tl.sum(dx_part3, axis=1, keep_dims=True)
+dx_2 = tl.sum(dx_part2, axis=1)[:, None]
+dx_3 = tl.sum(dx_part3, axis=1)[:, None]

 for off in range(0, N, BLOCK_COL_SIZE):
     cols = off + tl.arange(0, BLOCK_COL_SIZE)
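
The change swaps the `keep_dims=True` argument for a plain reduction followed by `[:, None]`, presumably so the kernel does not depend on the reduction accepting that argument. As a rough illustration only, the sketch below uses NumPy as a stand-in for Triton tensors (NumPy spells the argument `keepdims`, and the array contents are made up, not taken from the kernel). It shows that both forms yield a column of row sums with a trailing length-1 axis:

```python
import numpy as np

# Hypothetical stand-in for a (rows, BLOCK_COL_SIZE) accumulator such as dx_part2.
dx_part2 = np.arange(12, dtype=np.float32).reshape(3, 4)

# Old form: ask the reduction itself to keep the reduced axis.
kept = np.sum(dx_part2, axis=1, keepdims=True)

# New form: reduce to 1-D, then re-insert a length-1 axis by indexing with None.
reindexed = np.sum(dx_part2, axis=1)[:, None]

# Both results are (3, 1) and hold the same row sums.
assert kept.shape == reindexed.shape == (3, 1)
assert np.array_equal(kept, reindexed)
```

Either result broadcasts against a `(rows, BLOCK_COL_SIZE)` block in the same way, which is what the loop following the reduction needs.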