[Operator] Add weight_norm op
TZWX-0 committed Aug 23, 2024
1 parent 2c4625e commit 54570d7
Showing 5 changed files with 185 additions and 0 deletions.
17 changes: 17 additions & 0 deletions benchmark/test_reduction_perf.py
@@ -167,6 +167,23 @@ def layer_norm_args(dtype, batch, size):
    bench.run()


def test_perf_weightnorm():
    def weight_norm_args(dtype, batch, size):
        v = torch.randn([batch, size], dtype=dtype, device="cuda")
        g = torch.randn([batch, size], dtype=dtype, device="cuda")
        return (v, g, 0)

    bench = Benchmark(
        op_name="weightnorm",
        torch_op=torch._weight_norm,
        arg_func=weight_norm_args,
        dtypes=[torch.float16, torch.bfloat16],
        batch=REDUCTION_BATCH,
        sizes=SIZES,
    )
    bench.run()


def test_perf_log_softmax():
    bench = Benchmark(
        op_name="log_softmax",
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -47,6 +47,7 @@ def enable(lib=aten_lib):
lib.impl("ge.Scalar", ge_scalar, "CUDA")
lib.impl("gelu", gelu, "AutogradCUDA")
lib.impl("native_group_norm", group_norm, "AutogradCUDA")
lib.impl("_weight_norm", weight_norm, "CUDA")
lib.impl("gt.Tensor", gt, "CUDA")
lib.impl("gt.Scalar", gt_scalar, "CUDA")
lib.impl("isfinite", isfinite, "CUDA")
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -85,6 +85,7 @@
from .unique import _unique2
from .var_mean import var_mean
from .vector_norm import vector_norm
from .weightnorm import weight_norm
from .where import where_scalar_other, where_scalar_self, where_self
from .zeros import zeros
from .zeros_like import zeros_like
@@ -141,6 +142,7 @@
"isinf",
"isnan",
"layer_norm",
"weight_norm",
"le",
"le_scalar",
"lt",
150 changes: 150 additions & 0 deletions src/flag_gems/ops/weightnorm.py
@@ -0,0 +1,150 @@
import logging
import math

import torch
import triton
import triton.language as tl

from ..utils import libentry

try:
    from triton.language.extra.cuda.libdevice import rsqrt
except ImportError:
    try:
        from triton.language.math import rsqrt
    except ImportError:
        from triton.language.libdevice import rsqrt

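# weight_norm: out = g * v / ||v||_2, where the L2 norm of v is taken over every
# dim except `dim` (matching the semantics of torch._weight_norm).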

def cfggen():
    block_m = [1, 2, 4]
    block_n = [1024, 2048, 4096]
    warps = [4, 8, 16]
    configs = [
        triton.Config({"BLOCK_ROW_SIZE": m, "BLOCK_COL_SIZE": n}, num_warps=w)
        for m in block_m
        for n in block_n
        for w in warps
    ]
    return configs

@libentry()
@triton.autotune(configs=cfggen(), key=["o_shape0", "o_shape1", "o_shape2"])
@triton.jit(do_not_specialize=["eps"])
def weight_norm_kernel(
    output,
    v,
    v_broad,
    g_broad,
    o_shape0,
    o_shape1,
    o_shape2,
    v_shape0,
    v_shape2,
    vb_shape0,
    vb_shape1,
    vb_shape2,
    gb_shape0,
    gb_shape1,
    gb_shape2,
    v_stride0,
    v_stride1,
    v_stride2,
    vb_stride0,
    vb_stride1,
    vb_stride2,
    gb_stride0,
    gb_stride1,
    gb_stride2,
    eps: tl.constexpr,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    tid_m = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    pid = tl.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = pid + tid_m
    row_mask = row_offset < o_shape1

    tid_n = tl.arange(0, BLOCK_COL_SIZE)[None, :]
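    # First pass: accumulate sum(v * v) for each slice taken along `dim`; the
    # reduced dims are flattened into v_shape0 * v_shape2 columns.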
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        mask = m_idx < v_shape0 and row_mask

        v_offsets = m_idx * v_stride0 + n_idx * v_stride1 + k_idx * v_stride2
        v_value = tl.load(v + v_offsets, mask=mask)
        v_block += v_value * v_value
    v_sum = tl.sum(v_block, axis=1) + eps


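    # Second pass: scale v_broad by rsqrt of the accumulated sum of squares and
    # by g_broad, then store into the contiguous output.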
    for base in range(0, o_shape0 * o_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // o_shape2
        n_idx = row_offset
        k_idx = col_offset % o_shape2

        mask = m_idx < o_shape0 and row_mask

        v_offsets = (
            (m_idx % vb_shape0) * vb_stride0
            + (n_idx % vb_shape1) * vb_stride1
            + (k_idx % vb_shape2) * vb_stride2
        )
        v_value = tl.load(v_broad + v_offsets, mask=mask)
        v_vec = rsqrt(v_sum[:, None]) * v_value

        g_offset = (
            (m_idx % gb_shape0) * gb_stride0
            + (n_idx % gb_shape1) * gb_stride1
            + (k_idx % gb_shape2) * gb_stride2
        )
        g_value = tl.load(g_broad + g_offset, mask=mask)
        out = v_vec * g_value
        out_offset = m_idx * o_shape1 * o_shape2 + n_idx * o_shape2 + k_idx
        tl.store(output + out_offset, out, mask=mask)

def weight_norm(v, g, dim=0):
    logging.debug("GEMS WEIGHTNORM")

    v = v.contiguous()
    g = g.contiguous()
    dim_neg = dim - len(v.shape)
    output_shape = torch.broadcast_shapes(v.shape, g.shape)
    output = torch.empty(output_shape, device=v.device, dtype=v.dtype)

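    # The kernel broadcasts v_broad/g_broad via modulo indexing on a flattened
    # (pre-dim, dim, post-dim) view. A size-1 dim that follows a larger dim
    # (with `dim` itself removed) may not be expressible that way after
    # flattening, so materialize the broadcast for that operand up front.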
    v_broad = v
    g_broad = g
    v_re_shape = v.shape[:dim_neg] + v.shape[v.ndim + dim_neg + 1 :]
    g_re_shape = g.shape[:dim_neg] + g.shape[g.ndim + dim_neg + 1 :]
    for i in range(len(v_re_shape) - 1, 0, -1):
        if v_re_shape[i] == 1 and sum(v_re_shape[:i]) != i:
            v_broad = torch.broadcast_to(v, output_shape).clone()
            break
    for i in range(len(g_re_shape) - 1, 0, -1):
        if g_re_shape[i] == 1 and sum(g_re_shape[:i]) != i:
            g_broad = torch.broadcast_to(g, output_shape).clone()
            break

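    # Collapse output, g, v, v_broad and g_broad to flattened 3D
    # (pre-dim, dim, post-dim) extents.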
    o_g_v_vb_gb = [
        [
            math.prod(x.shape[:dim_neg]),
            x.shape[dim_neg],
            math.prod(x.shape[x.ndim + dim_neg + 1 :]),
        ]
        for x in [output, g, v, v_broad, g_broad]
    ]

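    # Contiguous strides for the flattened v, v_broad and g_broad views; dims of
    # size 1 keep stride 0, which lets the kernel's modulo indexing broadcast them.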
    step = [1] * 3
    v_vb_gb = [[0] * 3, [0] * 3, [0] * 3]
    for i in range(2, -1, -1):
        for j in range(len(v_vb_gb)):
            if o_g_v_vb_gb[j + 2][i] != 1:
                v_vb_gb[j][i] = step[j]
                step[j] = step[j] * o_g_v_vb_gb[j + 2][i]

    grid = lambda META: (triton.cdiv(o_g_v_vb_gb[0][1], META["BLOCK_ROW_SIZE"]),)
    with torch.cuda.device(v.device):
        weight_norm_kernel[grid](
            output,
            v,
            v_broad,
            g_broad,
            o_g_v_vb_gb[0][0], o_g_v_vb_gb[0][1], o_g_v_vb_gb[0][2],
            o_g_v_vb_gb[2][0], o_g_v_vb_gb[2][2],
            o_g_v_vb_gb[3][0], o_g_v_vb_gb[3][1], o_g_v_vb_gb[3][2],
            o_g_v_vb_gb[4][0], o_g_v_vb_gb[4][1], o_g_v_vb_gb[4][2],
            v_vb_gb[0][0], v_vb_gb[0][1], v_vb_gb[0][2],
            v_vb_gb[1][0], v_vb_gb[1][1], v_vb_gb[1][2],
            v_vb_gb[2][0], v_vb_gb[2][1], v_vb_gb[2][2],
            eps=torch.finfo(torch.float32).tiny,
        )
    return output
15 changes: 15 additions & 0 deletions tests/test_reduction_ops.py
@@ -353,6 +353,21 @@ def test_accuracy_layernorm(shape, dtype):
    gems_assert_close(res_bias_grad, ref_bias_grad, dtype, reduce_dim=M)


@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_accuracy_weightnorm(shape, dtype):
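    # Both implementations default to dim=0, so the norm reduces over dim 1
    # (size shape[1]); reduce_dim below reflects that reduction size.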
    v = torch.randn(shape, dtype=dtype, device="cuda")
    g = torch.randn(shape, dtype=dtype, device="cuda")

    ref_v = to_reference(v, True)
    ref_g = to_reference(g, True)

    ref_out = torch._weight_norm(ref_v, ref_g)
    res_out = flag_gems.weight_norm(v, g)

    gems_assert_close(res_out, ref_out, dtype, reduce_dim=shape[1])


@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_log_softmax(shape, dtype):
