From 209dbc4ac22049e181d40de2ba74aede17eb6e6e Mon Sep 17 00:00:00 2001 From: pengyang1 Date: Fri, 7 Jun 2024 14:36:26 +0800 Subject: [PATCH] [bugfix] use the general method to keep the dimension of sum reduce to 1 --- src/flag_gems/ops/layernorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/flag_gems/ops/layernorm.py b/src/flag_gems/ops/layernorm.py index 492fa329..7a096a4a 100644 --- a/src/flag_gems/ops/layernorm.py +++ b/src/flag_gems/ops/layernorm.py @@ -132,8 +132,8 @@ def layer_norm_backward_kernel( dx_part2 += dx_hat dx_part3 += dx_hat * x_hat - dx_2 = tl.sum(dx_part2, axis=1, keep_dims=True) - dx_3 = tl.sum(dx_part3, axis=1, keep_dims=True) + dx_2 = tl.sum(dx_part2, axis=1)[:, None] + dx_3 = tl.sum(dx_part3, axis=1)[:, None] for off in range(0, N, BLOCK_COL_SIZE): cols = off + tl.arange(0, BLOCK_COL_SIZE)