[Operator] Add weight_norm op [MooreThreads] #177

Merged (1 commit, Sep 29, 2024)
17 changes: 17 additions & 0 deletions benchmark/test_reduction_perf.py
@@ -210,6 +210,23 @@ def layer_norm_args(dtype, batch, size):
    bench.run()


def test_perf_weightnorm():
    def weight_norm_args(dtype, batch, size):
        v = torch.randn([batch, size], dtype=dtype, device="cuda")
        g = torch.randn([batch], dtype=dtype, device="cuda")
        return v, g, 0

    bench = Benchmark(
        op_name="weight_norm",
        torch_op=torch._weight_norm_interface,
        arg_func=weight_norm_args,
        dtypes=FLOAT_DTYPES,
        batch=REDUCTION_BATCH,
        sizes=SIZES,
    )
    bench.run()


def test_perf_log_softmax():
    bench = Benchmark(
        op_name="log_softmax",
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -53,6 +53,7 @@ def enable(lib=aten_lib):
lib.impl("ge.Scalar", ge_scalar, "CUDA")
lib.impl("gelu", gelu, "AutogradCUDA")
lib.impl("native_group_norm", group_norm, "AutogradCUDA")
lib.impl("_weight_norm_interface", weight_norm, "AutogradCUDA")
lib.impl("gt.Tensor", gt, "CUDA")
lib.impl("gt.Scalar", gt_scalar, "CUDA")
lib.impl("isfinite", isfinite, "CUDA")
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -101,6 +101,7 @@
from .var_mean import var_mean
from .vector_norm import vector_norm
from .vstack import vstack
from .weightnorm import weight_norm
from .where import where_scalar_other, where_scalar_self, where_self
from .zeros import zeros
from .zeros_like import zeros_like
@@ -167,6 +168,7 @@
"isinf",
"isnan",
"layer_norm",
"weight_norm",
"le",
"le_scalar",
"lt",
289 changes: 289 additions & 0 deletions src/flag_gems/ops/weightnorm.py
@@ -0,0 +1,289 @@
import logging
import math

import torch
import triton
import triton.language as tl

from ..utils import libentry


def cfggen_first():
    block_m = [1, 2, 4, 8, 32]
    block_n = [512, 1024, 2048]
    warps = [4, 8, 16]
    configs = [
        triton.Config({"BLOCK_ROW_SIZE": m, "BLOCK_COL_SIZE": n}, num_warps=w)
        for m in block_m
        for n in block_n
        for w in warps
    ]
    return configs


def cfggen_last():
    block_m = [512, 1024, 2048]
    block_n = [1, 2, 4, 8, 32]
    warps = [4, 8, 16]
    configs = [
        triton.Config({"BLOCK_ROW_SIZE": m, "BLOCK_COL_SIZE": n}, num_warps=w)
        for m in block_m
        for n in block_n
        for w in warps
    ]
    return configs


@libentry()
@triton.autotune(configs=cfggen_last(), key=["M", "N"])
@triton.jit(do_not_specialize=["eps"])
def weight_norm_kernel_last(
    output,
    norm,
    v,
    g,
    M,
    N,
    eps: tl.constexpr,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    tx = tl.arange(0, BLOCK_COL_SIZE)[:, None]
    bx = tl.program_id(axis=0) * BLOCK_COL_SIZE
    col_offset = bx + tx
    col_mask = col_offset < N

    ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]
    v_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * v_value

    normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)
    tl.store(norm + col_offset, normalized[:, None], mask=col_mask)
    g_value = tl.load(g + col_offset, mask=col_mask).to(tl.float32)

    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_vec = v_value / normalized[:, None]
        out = v_vec * g_value
        tl.store(output + row_offset * N + col_offset, out, mask=mask)


@libentry()
@triton.autotune(configs=cfggen_first(), key=["M", "N"])
@triton.jit(do_not_specialize=["eps"])
def weight_norm_kernel_first(
    output,
    norm,
    v,
    g,
    M,
    N,
    eps: tl.constexpr,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    ty = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    by = tl.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = by + ty
    row_mask = row_offset < M

    tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * v_value

    normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)
Contributor: I think we should be reducing on the first dimension, i.e., axis=0.

Contributor Author: v_block is stored in row-major order, so I perform the sum along the rows regardless of whether the reduction dimension is the first or the last (the x/y indices are permuted for the last-dim case). The test hit an error because REDUCTION_SHAPES = (200, 40999, 3) and dim = 1 is not supported for weight normalization; that issue has now been resolved.

Contributor: Reducing on dim 1 is only correct provided the inputs are transposed up front. It looks like that's not the case in WeightNorm.forward. Can we verify that further?

Contributor Author: If the reduction happened on the wrong dimension, the result would definitely differ from the golden reference, but currently they are consistent. The transpose happens inside the kernel: threads load values along the row direction from global memory but store them along the column direction of v_block.

Contributor:

            M = v.shape[0]
            N = math.prod(v.shape[1:])
            grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)

Above is the blocking scheme in the code, where M is the reduction dim size. It's clear the reduction axis is split. I don't see how the transpose could be done in the kernel...

Contributor Author (@TZWX-0, Sep 26, 2024): In the kernel:

    # when the reduce dim is first
    tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    # when the reduce dim is last
    ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]
    v_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)

How about verifying this with a simple instance, for example a reduce shape of (2, 2)? If the reduce dim were wrong in the kernel, the result would not be consistent with the golden reference.

Contributor: My bad... I took for granted that the input dim is the dimension to be contracted off.

    tl.store(norm + row_offset, normalized[:, None], mask=row_mask)
    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)

    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_vec = v_value / normalized[:, None]
        out = v_vec * g_value
        tl.store(output + row_offset * N + col_offset, out, mask=mask)
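Following up on the review thread above, a quick way to sanity-check the reduction axis is the tiny case the author suggests. A sketch of such a check (my own, not part of the patch; assumes flag_gems is installed and a CUDA device is available):

    import torch
    import flag_gems

    v = torch.randn(2, 2, device="cuda")
    g = torch.randn(2, device="cuda")

    # dim=0: the norm reduces over every dim except the first
    ref_w, ref_n = torch._weight_norm_interface(v, g, 0)
    res_w, res_n = flag_gems.weight_norm(v, g, 0)
    print(torch.allclose(res_w, ref_w, atol=1e-5),
          torch.allclose(res_n.flatten(), ref_n.flatten(), atol=1e-5))

    # dim=1 (last): the norm reduces over every dim except the last
    ref_w, ref_n = torch._weight_norm_interface(v, g, 1)
    res_w, res_n = flag_gems.weight_norm(v, g, 1)
    print(torch.allclose(res_w, ref_w, atol=1e-5),
          torch.allclose(res_n.flatten(), ref_n.flatten(), atol=1e-5))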


@libentry()
@triton.autotune(configs=cfggen_last(), key=["M", "N"])
@triton.jit(do_not_specialize=["eps"])
def weight_norm_bwd_kernel_last(
    v_grad,
    g_grad,
    w,
    v,
    g,
    norm,
    M,
    N,
    eps: tl.constexpr,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    tx = tl.arange(0, BLOCK_COL_SIZE)[:, None]
    bx = tl.program_id(axis=0) * BLOCK_COL_SIZE
    col_offset = tx + bx
    col_mask = col_offset < N

    g_value = tl.load(g + col_offset, mask=col_mask).to(tl.float32)
    norm_value = tl.load(norm + col_offset, mask=col_mask).to(tl.float32)

    ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]

    vw_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        vw_block += v_value * w_value
    vw_sum = tl.sum(vw_block, 1)[:, None]

    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_grad_value = g_value * (
            w_value / (norm_value + eps)
            - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum
        )
        tl.store(v_grad + row_offset * N + col_offset, v_grad_value, mask=mask)

    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + col_offset, g_grad_value, mask=col_mask)


@libentry()
@triton.autotune(configs=cfggen_first(), key=["M", "N"])
@triton.jit(do_not_specialize=["eps"])
def weight_norm_bwd_kernel_first(
    v_grad,
    g_grad,
    w,
    v,
    g,
    norm,
    M,
    N,
    eps: tl.constexpr,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    ty = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    by = tl.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = by + ty
    row_mask = row_offset < M

    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)
    norm_value = tl.load(norm + row_offset, mask=row_mask).to(tl.float32)

    tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]

    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * w_value
    vw_sum = tl.sum(v_block, 1)[:, None]

    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_grad_value = g_value * (
            w_value / (norm_value + eps)
            - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum
        )
        tl.store(v_grad + row_offset * N + col_offset, v_grad_value, mask=mask)

    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + row_offset, g_grad_value, mask=row_mask)
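For reference (my restatement, not part of the patch): with w = g * v / norm, where norm reduces over every dim except `dim`, the backward kernels above compute g_grad = sum(v * w_grad) / norm and v_grad = g * (w_grad / norm - v * sum(v * w_grad) / norm**3). A plain PyTorch sketch of the same math that the kernels are expected to match (the helper name is my own):

    def weight_norm_reference_grads(w_grad, v, g, norm, dim):
        # Reduce over every dimension except `dim`, keeping dims for broadcasting.
        reduce_dims = [d for d in range(v.ndim) if d != dim]
        shape = [1] * v.ndim
        shape[dim] = -1
        g_b = g.reshape(shape).float()
        norm_b = norm.reshape(shape).float()
        vw_sum = (v.float() * w_grad.float()).sum(dim=reduce_dims, keepdim=True)
        v_grad = g_b * (w_grad.float() / norm_b - v.float() * vw_sum / norm_b**3)
        g_grad = (vw_sum / norm_b).reshape(g.shape)
        return v_grad.to(v.dtype), g_grad.to(g.dtype)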


class WeightNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, v, g, dim):
        logging.debug("GEMS WEIGHTNORM FORWARD")
        v = v.contiguous()
        g = g.contiguous()
        output = torch.empty_like(v)
        norm = torch.empty_like(g, dtype=torch.float32)
        if dim == 0:
            M = v.shape[0]
            N = math.prod(v.shape[1:])
            grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
            with torch.cuda.device(v.device):
                weight_norm_kernel_first[grid](
                    output, norm, v, g, M, N, eps=torch.finfo(torch.float32).tiny
                )
        elif dim == len(v.shape) - 1:
            M = math.prod(v.shape[:-1])
            N = v.shape[dim]
            grid = lambda META: (triton.cdiv(N, META["BLOCK_COL_SIZE"]),)
            with torch.cuda.device(v.device):
                weight_norm_kernel_last[grid](
                    output, norm, v, g, M, N, eps=torch.finfo(torch.float32).tiny
                )
        ctx.save_for_backward(v, g, norm)
        ctx.DIM = dim
        return output, norm

    @staticmethod
    def backward(ctx, w_grad, norm_grad):
        logging.debug("GEMS WEIGHTNORM BACKWARD")
        v, g, norm = ctx.saved_tensors
        dim = ctx.DIM
        v_grad = torch.empty_like(v)
        g_grad = torch.empty_like(g)

        if dim == 0:
            M = v.shape[0]
            N = math.prod(v.shape[1:])
            grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
            with torch.cuda.device(v.device):
                weight_norm_bwd_kernel_first[grid](
                    v_grad,
                    g_grad,
                    w_grad,
                    v,
                    g,
                    norm,
                    M,
                    N,
                    eps=torch.finfo(torch.float32).tiny,
                )
        elif dim == len(v.shape) - 1:
            M = math.prod(v.shape[:dim])
            N = v.shape[dim]
            grid = lambda META: (triton.cdiv(N, META["BLOCK_COL_SIZE"]),)
            with torch.cuda.device(v.device):
                weight_norm_bwd_kernel_last[grid](
                    v_grad,
                    g_grad,
                    w_grad,
                    v,
                    g,
                    norm,
                    M,
                    N,
                    eps=torch.finfo(torch.float32).tiny,
                )
        return v_grad, g_grad, None


def weight_norm(v, g, dim=0):
    return WeightNorm.apply(v, g, dim)
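A minimal end-to-end sketch of driving this Function (my own illustration, not part of the patch; assumes a CUDA device). The dispatch above only covers dim == 0 and the last dimension:

    v = torch.randn(8, 16, device="cuda", requires_grad=True)
    g = torch.randn(8, device="cuda", requires_grad=True)
    w, norm = weight_norm(v, g, dim=0)   # forward: weight_norm_kernel_first
    w_grad = torch.randn_like(w)         # contiguous upstream gradient
    v_grad, g_grad = torch.autograd.grad(w, (v, g), grad_outputs=w_grad)
    # v_grad: (8, 16) from weight_norm_bwd_kernel_first; g_grad: (8,)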
33 changes: 33 additions & 0 deletions tests/test_norm_ops.py
@@ -136,6 +136,39 @@ def test_accuracy_layernorm(shape, dtype):
    gems_assert_close(res_bias_grad, ref_bias_grad, dtype, reduce_dim=M)


@pytest.mark.weight_norm_interface
@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("dim", [0, -1])
def test_accuracy_weightnorm(shape, dtype, dim):
    dim = dim % len(shape)
    v = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
    g = torch.randn(shape[dim], dtype=dtype, device="cuda", requires_grad=True)

    ref_v = to_reference(v, False)
    ref_g = to_reference(g, False)

    ref_w_out, ref_norm_out = torch._weight_norm_interface(ref_v, ref_g, dim)
    res_w_out, res_norm_out = flag_gems.weight_norm(v, g, dim)
    gems_assert_close(res_w_out, ref_w_out, dtype, reduce_dim=shape[(dim - 1) % 2])
    gems_assert_close(
        res_norm_out, ref_norm_out, res_norm_out.dtype, reduce_dim=shape[(dim - 1) % 2]
    )

    res_w_grad = torch.randn_like(v)
    ref_w_grad = to_reference(res_w_grad, False)

    ref_v_grad, ref_g_grad = torch.autograd.grad(
        ref_w_out, (ref_v, ref_g), grad_outputs=ref_w_grad
    )
    res_v_grad, res_g_grad = torch.autograd.grad(
        res_w_out, (v, g), grad_outputs=res_w_grad
    )

    gems_assert_close(res_v_grad, ref_v_grad, dtype, reduce_dim=shape[(dim - 1) % 2])
    gems_assert_close(res_g_grad, ref_g_grad, dtype, reduce_dim=shape[(dim - 1) % 2])


@pytest.mark.rms_norm
@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)