diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py
index c38c5235..330b6c54 100644
--- a/benchmark/test_reduction_perf.py
+++ b/benchmark/test_reduction_perf.py
@@ -167,6 +167,23 @@ def layer_norm_args(dtype, batch, size):
     bench.run()
 
 
+def test_perf_weightnorm():
+    def weight_norm_args(dtype, batch, size):
+        v = torch.randn([batch, size], dtype=dtype, device="cuda")
+        g = torch.randn([batch, size], dtype=dtype, device="cuda")
+        return (v, g, 0)
+
+    bench = Benchmark(
+        op_name="weightnorm",
+        torch_op=torch._weight_norm,
+        arg_func=weight_norm_args,
+        dtypes=[torch.float16, torch.bfloat16],
+        batch=REDUCTION_BATCH,
+        sizes=SIZES,
+    )
+    bench.run()
+
+
 def test_perf_log_softmax():
     bench = Benchmark(
         op_name="log_softmax",
diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index 7fa23b43..2d8133d7 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -47,6 +47,7 @@ def enable(lib=aten_lib):
     lib.impl("ge.Scalar", ge_scalar, "CUDA")
     lib.impl("gelu", gelu, "AutogradCUDA")
     lib.impl("native_group_norm", group_norm, "AutogradCUDA")
+    lib.impl("_weight_norm", weight_norm, "CUDA")
     lib.impl("gt.Tensor", gt, "CUDA")
     lib.impl("gt.Scalar", gt_scalar, "CUDA")
     lib.impl("isfinite", isfinite, "CUDA")
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index 208208b0..f4a0b360 100644
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -85,6 +85,7 @@
 from .unique import _unique2
 from .var_mean import var_mean
 from .vector_norm import vector_norm
+from .weightnorm import weight_norm
 from .where import where_scalar_other, where_scalar_self, where_self
 from .zeros import zeros
 from .zeros_like import zeros_like
@@ -141,6 +142,7 @@
     "isinf",
     "isnan",
     "layer_norm",
+    "weight_norm",
     "le",
     "le_scalar",
     "lt",
diff --git a/src/flag_gems/ops/weightnorm.py b/src/flag_gems/ops/weightnorm.py
new file mode 100644
index 00000000..b295579c
--- /dev/null
+++ b/src/flag_gems/ops/weightnorm.py
@@ -0,0 +1,169 @@
+import logging
+import math
+
+import torch
+import triton
+import triton.language as tl
+
+from ..utils import libentry
+
+try:
+    from triton.language.extra.cuda.libdevice import rsqrt
+except ImportError:
+    try:
+        from triton.language.math import rsqrt
+    except ImportError:
+        from triton.language.libdevice import rsqrt
+
+
+def cfggen():
+    block_m = [1, 2, 4]
+    block_n = [1024, 2048, 4096]
+    warps = [4, 8, 16]
+    configs = [
+        triton.Config({"BLOCK_ROW_SIZE": m, "BLOCK_COL_SIZE": n}, num_warps=w)
+        for m in block_m
+        for n in block_n
+        for w in warps
+    ]
+    return configs
+
+
+@libentry()
+@triton.autotune(configs=cfggen(), key=["o_shape0", "o_shape1", "o_shape2"])
+@triton.jit(do_not_specialize=["eps"])
+def weight_norm_kernel(
+    output,
+    v,
+    v_broad,
+    g_broad,
+    o_shape0,
+    o_shape1,
+    o_shape2,
+    v_shape0,
+    v_shape2,
+    vb_shape0,
+    vb_shape1,
+    vb_shape2,
+    gb_shape0,
+    gb_shape1,
+    gb_shape2,
+    v_stride0,
+    v_stride1,
+    v_stride2,
+    vb_stride0,
+    vb_stride1,
+    vb_stride2,
+    gb_stride0,
+    gb_stride1,
+    gb_stride2,
+    eps,
+    BLOCK_ROW_SIZE: tl.constexpr,
+    BLOCK_COL_SIZE: tl.constexpr,
+):
+    # One norm per index along dim; each program covers BLOCK_ROW_SIZE such indices.
+    tid_m = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
+    pid = tl.program_id(axis=0) * BLOCK_ROW_SIZE
+    row_offset = pid + tid_m
+    row_mask = row_offset < o_shape1
+
+    tid_n = tl.arange(0, BLOCK_COL_SIZE)[None, :]
+    # Accumulate the sum of squares of v over every dimension except dim.
+    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
+    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
+        col_offset = base + tid_n
+        m_idx = col_offset // v_shape2
+        n_idx = row_offset
+        k_idx = col_offset % v_shape2
+
+        mask = (m_idx < v_shape0) & row_mask
+
+        v_offsets = m_idx * v_stride0 + n_idx * v_stride1 + k_idx * v_stride2
+        v_value = tl.load(v + v_offsets, mask=mask)
+        v_block += v_value * v_value
+    # Squared norm of each slice; eps keeps rsqrt finite for all-zero slices.
+    v_sum = tl.sum(v_block, axis=1) + eps
+
+    # Normalize v and rescale by g, both broadcast to the output shape.
+    for base in range(0, o_shape0 * o_shape2, BLOCK_COL_SIZE):
+        col_offset = base + tid_n
+        m_idx = col_offset // o_shape2
+        n_idx = row_offset
+        k_idx = col_offset % o_shape2
+
+        mask = (m_idx < o_shape0) & row_mask
+
+        v_offsets = (
+            (m_idx % vb_shape0) * vb_stride0
+            + (n_idx % vb_shape1) * vb_stride1
+            + (k_idx % vb_shape2) * vb_stride2
+        )
+        v_value = tl.load(v_broad + v_offsets, mask=mask)
+        v_vec = rsqrt(v_sum[:, None]) * v_value
+
+        g_offset = (
+            (m_idx % gb_shape0) * gb_stride0
+            + (n_idx % gb_shape1) * gb_stride1
+            + (k_idx % gb_shape2) * gb_stride2
+        )
+        g_value = tl.load(g_broad + g_offset, mask=mask)
+        out = v_vec * g_value
+        out_offset = m_idx * o_shape1 * o_shape2 + n_idx * o_shape2 + k_idx
+        tl.store(output + out_offset, out, mask=mask)
+
+
+def weight_norm(v, g, dim=0):
+    logging.debug("GEMS WEIGHTNORM")
+
+    v = v.contiguous()
+    g = g.contiguous()
+    dim_neg = dim - len(v.shape)
+    output_shape = torch.broadcast_shapes(v.shape, g.shape)
+    output = torch.empty(output_shape, device=v.device, dtype=v.dtype)
+
+    # Materialize the broadcast if a size-1 dim follows non-size-1 dims (strides alone cannot express it).
+    v_broad = v
+    g_broad = g
+    v_re_shape = v.shape[:dim_neg] + v.shape[dim_neg + 1:]
+    g_re_shape = g.shape[:dim_neg] + g.shape[dim_neg + 1:]
+    for i in range(len(v_re_shape) - 1, 0, -1):
+        if v_re_shape[i] == 1 and sum(v_re_shape[:i]) != i:
+            v_broad = torch.broadcast_to(v, output_shape).clone()
+            break
+    for i in range(len(g_re_shape) - 1, 0, -1):
+        if g_re_shape[i] == 1 and sum(g_re_shape[:i]) != i:
+            g_broad = torch.broadcast_to(g, output_shape).clone()
+            break
+
+    # Collapse each tensor to a 3D view: (dims before dim, dim, dims after dim).
+    o_g_v_vb_gb = [
+        [math.prod(x.shape[:dim_neg]), x.shape[dim_neg], math.prod(x.shape[dim_neg + 1:])]
+        for x in [output, g, v, v_broad, g_broad]
+    ]
+
+    # Strides of the collapsed views of v, v_broad and g_broad (0 for size-1 dims).
+    step = [1] * 3
+    v_vb_gb = [[0] * 3, [0] * 3, [0] * 3]
+    for i in range(2, -1, -1):
+        for j in range(len(v_vb_gb)):
+            if o_g_v_vb_gb[j + 2][i] != 1:
+                v_vb_gb[j][i] = step[j]
+                step[j] = step[j] * o_g_v_vb_gb[j + 2][i]
+
+    grid = lambda META: (triton.cdiv(o_g_v_vb_gb[0][1], META["BLOCK_ROW_SIZE"]),)
+    with torch.cuda.device(v.device):
+        weight_norm_kernel[grid](
+            output,
+            v,
+            v_broad,
+            g_broad,
+            o_g_v_vb_gb[0][0], o_g_v_vb_gb[0][1], o_g_v_vb_gb[0][2],
+            o_g_v_vb_gb[2][0], o_g_v_vb_gb[2][2],
+            o_g_v_vb_gb[3][0], o_g_v_vb_gb[3][1], o_g_v_vb_gb[3][2],
+            o_g_v_vb_gb[4][0], o_g_v_vb_gb[4][1], o_g_v_vb_gb[4][2],
+            v_vb_gb[0][0], v_vb_gb[0][1], v_vb_gb[0][2],
+            v_vb_gb[1][0], v_vb_gb[1][1], v_vb_gb[1][2],
+            v_vb_gb[2][0], v_vb_gb[2][1], v_vb_gb[2][2],
+            eps=torch.finfo(torch.float32).tiny,
+        )
+    return output
diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py
index 6c9c7ee7..be63813f 100644
--- a/tests/test_reduction_ops.py
+++ b/tests/test_reduction_ops.py
@@ -353,6 +353,21 @@ def test_accuracy_layernorm(shape, dtype):
     gems_assert_close(res_bias_grad, ref_bias_grad, dtype, reduce_dim=M)
 
 
+@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_accuracy_weightnorm(shape, dtype):
+    v = torch.randn(shape, dtype=dtype, device="cuda")
+    g = torch.randn(shape, dtype=dtype, device="cuda")
+
+    ref_v = to_reference(v, True)
+    ref_g = to_reference(g, True)
+
+    ref_out = torch._weight_norm(ref_v, ref_g)
+    res_out = flag_gems.weight_norm(v, g)
+
+    gems_assert_close(res_out, ref_out, dtype, reduce_dim=shape[1])
+
+
 @pytest.mark.parametrize("shape", REDUCTION_SHAPES)
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
 def test_accuracy_log_softmax(shape, dtype):
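With `lib.impl("_weight_norm", weight_norm, "CUDA")` registered, `torch._weight_norm` on CUDA tensors dispatches to the Triton kernel once `flag_gems.enable()` is called; the op can also be invoked directly as `flag_gems.weight_norm(v, g, dim)`. A minimal local sanity check in the spirit of the added unit test might look like the sketch below; the shapes, dtype, and tolerances are illustrative choices, not part of the patch.

```python
import torch

import flag_gems

# Illustrative shape/dtype; the patch's test sweeps REDUCTION_SHAPES instead.
v = torch.randn(4096, 256, dtype=torch.float16, device="cuda")
g = torch.randn(4096, 256, dtype=torch.float16, device="cuda")

# Reference in float32: ATen's fused primitive, w = v * g / ||v||,
# with the norm taken over every dimension except `dim`.
ref = torch._weight_norm(v.float(), g.float(), 0)

# Direct call to the new Triton-backed op (dim=0, matching the reference).
out = flag_gems.weight_norm(v, g, 0)

# Rough fp16 tolerances, not the repo's gems_assert_close policy.
torch.testing.assert_close(out.float(), ref, rtol=1e-2, atol=1e-3)

# Alternatively, register the whole library and go through the ATen entry point:
# flag_gems.enable()
# out2 = torch._weight_norm(v, g, 0)
```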