diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py
index 1d081bce..4d259879 100755
--- a/benchmark/test_reduction_perf.py
+++ b/benchmark/test_reduction_perf.py
@@ -90,6 +90,12 @@ def cross_entropy_loss_input_fn(shape, cur_dtype, device):
     yield inp, target
 
 
+def nll_loss_input_fn(shape, cur_dtype, device):
+    inp = generate_tensor_input(shape, cur_dtype, device)
+    target = torch.randint(0, shape[-1], (shape[0],), device=device)
+    yield inp, target
+
+
 def cumsum_input_fn(shape, cur_dtype, device):
     inp = generate_tensor_input(shape, cur_dtype, device)
     yield inp, 1
@@ -126,6 +132,13 @@ def cumsum_input_fn(shape, cur_dtype, device):
             FLOAT_DTYPES + INT_DTYPES,
             marks=pytest.mark.cumsum,
         ),
+        pytest.param(
+            "nll_loss",
+            torch.nn.NLLLoss,
+            nll_loss_input_fn,
+            FLOAT_DTYPES,
+            marks=pytest.mark.NLLLoss,
+        ),
     ],
 )
 def test_generic_reduction_benchmark(op_name, torch_op, input_fn, dtypes):
diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index 3827efeb..bfb76afa 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -119,6 +119,7 @@ def enable(lib=aten_lib):
     lib.impl("max.dim", max_dim, "CUDA")
     lib.impl("min", min, "CUDA")
     lib.impl("min.dim", min_dim, "CUDA")
+    lib.impl("nll_loss", nll_loss, "AutogradCUDA")
     lib.impl("amax", amax, "CUDA")
     lib.impl("argmax", argmax, "CUDA")
     lib.impl("prod", prod, "CUDA")
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index 1700c283..feda0ebc 100755
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -60,6 +60,7 @@
 from .mv import mv
 from .ne import ne, ne_scalar
 from .neg import neg
+from .nllloss import nll_loss
 from .nonzero import nonzero
 from .normal import normal_float_tensor, normal_tensor_float, normal_tensor_tensor
 from .ones import ones
@@ -252,5 +253,6 @@
     "repeat_interleave_self_int",
     "vstack",
     "repeat_interleave_tensor",
+    "nll_loss",
     "repeat_interleave_self_tensor",
 ]
diff --git a/src/flag_gems/ops/nllloss.py b/src/flag_gems/ops/nllloss.py
new file mode 100644
index 00000000..7105e64d
--- /dev/null
+++ b/src/flag_gems/ops/nllloss.py
@@ -0,0 +1,283 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from ..utils import libentry
+from .sum import sum
+
+
+@libentry()
+@triton.autotune(
+    configs=[triton.Config({"BLOCK_N": n}, num_warps=4) for n in [256, 512, 1024]],
+    key=["N"],
+)
+@triton.jit(do_not_specialize=["ignore_index"])
+def nll_loss_2d_fwd_kernel(
+    inp_ptr,
+    tgt_ptr,
+    w_ptr,
+    w_tgt_ptr,
+    out_ptr,
+    ignore_index,
+    N,
+    C,
+    BLOCK_N: tl.constexpr,
+):
+    pid_n = tl.program_id(0)
+    offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    mask_n = offsets_n < N
+
+    tgt = tl.load(tgt_ptr + offsets_n, mask=mask_n, other=0)
+    ignore_mask = not (tgt == ignore_index)
+    mask_tgt = tgt < C
+
+    w_ptrs = w_ptr + tgt
+    w_tgt = tl.load(w_ptrs, mask=mask_n, other=0).to(tl.float32)
+    tl.store(w_tgt_ptr + offsets_n, w_tgt, mask=(mask_n & ignore_mask))
+
+    inp_tgt_ptrs = inp_ptr + offsets_n * C + tgt
+    inp_tgt = tl.load(inp_tgt_ptrs, mask=mask_n & mask_tgt, other=-float("inf")).to(
+        tl.float32
+    )
+    out = inp_tgt * w_tgt * -1
+    tl.store(out_ptr + offsets_n, out, mask=mask_n & mask_tgt & ignore_mask)
+
+
+@libentry()
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_N": n, "BLOCK_C": c}, num_warps=4)
+        for n in [256, 512, 1024]
+        for c in [1, 4, 16]
+    ],
+    key=["N", "C"],
+)
+@triton.jit(do_not_specialize=["ignore_index"])
+def nll_loss_2d_bwd_kernel(
+    out_grad_ptr,
+    tgt_ptr,
+    w_ptr,
+    inp_grad_ptr,
+    ignore_index,
+    total_weight,
+    N,
+    C,
+    BLOCK_N: tl.constexpr,
+    BLOCK_C: tl.constexpr,
+):
+    pid_n = tl.program_id(0)
+    pid_c = tl.program_id(1)
+    offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    offsets_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
+
+    mask_n = offsets_n < N
+    mask_block = offsets_n[:, None] < N and offsets_c[None, :] < C
+
+    tgt = tl.load(tgt_ptr + offsets_n, mask=mask_n, other=0)
+    out_grad = (tl.load(out_grad_ptr + offsets_n, mask=mask_n, other=0).to(tl.float32))[
+        :, None
+    ]
+    ignore_mask = (tgt != ignore_index)[:, None]
+
+    w_ptrs = w_ptr + tgt
+    w_tgt = tl.load(w_ptrs, mask=mask_n, other=0).to(tl.float32)[:, None]
+
+    mask_inp = mask_block and offsets_c[None, :] == tgt[:, None]
+    inp_grad = -1 * out_grad * w_tgt / total_weight
+    inp_grad_ptrs = inp_grad_ptr + offsets_n[:, None] * C + offsets_c[None, :]
+    tl.store(inp_grad_ptrs, inp_grad.to(tl.float32), mask=(mask_inp & ignore_mask))
+
+
+@libentry()
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_C": c, "BLOCK_D": d}, num_warps=4)
+        for c in [256, 512, 1024]
+        for d in [1, 4, 16]
+    ],
+    key=["C", "D"],
+)
+@triton.jit(do_not_specialize=["ignore_index"])
+def nll_loss_multi_fwd_kernel(
+    inp_ptr,
+    tgt_ptr,
+    w_ptr,
+    w_tgt_ptr,
+    out_ptr,
+    ignore_index,
+    N,
+    C,
+    D,
+    BLOCK_C: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    pid_n = tl.program_id(0)
+    pid_d = tl.program_id(1)
+    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
+
+    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
+    mask_tgt = offset_d < D
+    tgt = tl.load(tgt_ptrs, mask=mask_tgt, other=0)
+
+    ignore_mask = not (tgt == ignore_index)
+
+    w_ptrs = w_ptr + tgt
+    w_tgt = tl.load(w_ptrs, mask=mask_tgt, other=0).to(tl.float32)
+    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
+    tl.store(w_tgt_ptrs, w_tgt, mask=(mask_tgt & ignore_mask))
+
+    inp_tgt_ptrs = inp_ptr + pid_n * C * D + tgt * D + offset_d
+    inp_tgt = tl.load(inp_tgt_ptrs, mask=mask_tgt, other=-float("inf")).to(tl.float32)
+    out = inp_tgt * w_tgt * -1
+    out_ptrs = out_ptr + pid_n * D + offset_d
+    tl.store(out_ptrs, out, mask=(mask_tgt & ignore_mask))
+
+
+@libentry()
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_C": c, "BLOCK_D": d}, num_warps=4)
+        for c in [256, 512, 1024]
+        for d in [1, 4, 16]
+    ],
+    key=["C", "D"],
+)
+@triton.jit(do_not_specialize=["ignore_index", "total_weight"])
+def nll_loss_multi_bwd_kernel(
+    out_grad_ptr,
+    tgt_ptr,
+    w_ptr,
+    inp_grad_ptr,
+    ignore_index,
+    total_weight,
+    N,
+    C,
+    D,
+    BLOCK_C: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    pid_n = tl.program_id(0)
+    pid_d = tl.program_id(1)
+    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
+
+    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
+    mask_tgt = offset_d < D
+    tgt = tl.load(tgt_ptrs, mask=mask_tgt, other=0)
+
+    ignore_mask = (tgt != ignore_index)[None, :]
+
+    w_ptrs = w_ptr + tgt
+    w_tgt = tl.load(w_ptrs, mask=mask_tgt, other=0).to(tl.float32)[None, :]
+    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
+    out_grad = (tl.load(out_grad_ptrs, mask=mask_tgt, other=0).to(tl.float32))[None, :]
+
+    for off in range(0, C, BLOCK_C):
+        offset_c = off + tl.arange(0, BLOCK_C)
+        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
+        inp_mask = inp_mask and offset_c[:, None] == tgt
+        inp_grad = -1 * out_grad * w_tgt / total_weight
+        inp_grad_ptrs = (
+            inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
+        )
+        tl.store(inp_grad_ptrs, inp_grad.to(tl.float32), mask=(inp_mask & ignore_mask))
+
+
+class NLLLoss(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, inp, target, weight, reduction, ignore_index):
+        logging.debug("GEMS NLLLoss FWD")
+        shape = list(inp.shape)
+        dim = inp.ndim
+        N = 1 if dim == 1 else shape[0]
+        C = shape[0] if dim == 1 else shape[1]
+        D = inp.numel() // N // C
+        axis = 0 if dim == 1 else 1
+        del shape[axis]
+
+        assert (
+            ((target >= 0) & (target < C)) | (target == ignore_index)
+        ).all(), "Target is out of bounds"
+        assert list(target.shape) == shape, "Invalid target size"
+        assert inp.ndim >= 1, "Invalid input ndim"
+
+        if weight is None:
+            weight = torch.ones(
+                [
+                    C,
+                ],
+                dtype=inp.dtype,
+                device=inp.device,
+            )
+
+        inp = inp.contiguous()
+        tgt = target.contiguous()
+        w = weight.contiguous()
+        out = torch.zeros(shape, dtype=torch.float32, device=inp.device)
+        w_tgt = torch.zeros(shape, dtype=torch.float32, device=inp.device)
+
+        if inp.ndim > 2:
+            grid = lambda meta: (N, triton.cdiv(D, meta["BLOCK_D"]))
+            with torch.cuda.device(inp.device):
+                nll_loss_multi_fwd_kernel[grid](
+                    inp, tgt, w, w_tgt, out, ignore_index, N, C, D
+                )
+        else:
+            grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]),)
+            with torch.cuda.device(inp.device):
+                nll_loss_2d_fwd_kernel[grid](
+                    inp, tgt, w, w_tgt, out, ignore_index, N, C
+                )
+
+        ctx.save_for_backward(inp, tgt, w)
+        ctx.N = N
+        ctx.C = C
+        ctx.D = D
+        ctx.ignore_index = ignore_index
+        ctx.total_weight = 1
+        ctx.shape = shape
+
+        # reduction: 0-none, 1-mean, 2-sum
+        if reduction == 0:
+            res = out.to(inp.dtype)
+        elif reduction == 1:
+            ctx.total_weight = sum(w_tgt).item()
+            res = sum(out).to(inp.dtype) / ctx.total_weight
+        else:
+            res = sum(out).to(inp.dtype)
+
+        return res
+
+    @staticmethod
+    def backward(ctx, out_grad):
+        logging.debug("GEMS NLLLoss BWD")
+        inp, tgt, w = ctx.saved_tensors
+        N = ctx.N
+        C = ctx.C
+        D = ctx.D
+        ignore_index = ctx.ignore_index
+        total_weight = ctx.total_weight
+        shape = ctx.shape
+
+        out_grad = out_grad.broadcast_to(shape).contiguous()
+        inp_grad = torch.zeros(inp.shape, dtype=inp.dtype, device=inp.device)
+
+        if inp.ndim > 2:
+            grid = lambda meta: (N, triton.cdiv(D, meta["BLOCK_D"]))
+            nll_loss_multi_bwd_kernel[grid](
+                out_grad, tgt, w, inp_grad, ignore_index, total_weight, N, C, D
+            )
+        else:
+            grid = lambda meta: (
+                triton.cdiv(N, meta["BLOCK_N"]),
+                triton.cdiv(C, meta["BLOCK_C"]),
+            )
+            nll_loss_2d_bwd_kernel[grid](
+                out_grad, tgt, w, inp_grad, ignore_index, total_weight, N, C
+            )
+
+        return inp_grad, None, None, None, None, None
+
+
+def nll_loss(inp, target, weight=None, reduction=1, ignore_index=-100):
+    return NLLLoss.apply(inp, target, weight, reduction, ignore_index)
diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py
index 4197ccc9..f6a587a9 100644
--- a/tests/test_reduction_ops.py
+++ b/tests/test_reduction_ops.py
@@ -174,6 +174,50 @@ def test_accuracy_cross_entropy_loss_probabilities(
     gems_assert_close(res_in_grad, ref_in_grad, dtype, reduce_dim=shape[dim])
 
 
+@pytest.mark.NLLLoss
+@pytest.mark.parametrize("reduction", ["mean", "none", "sum"])
+@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+@pytest.mark.parametrize("ignore_index", [1, 200, -100])
+def test_accuracy_nll_loss(shape, dtype, ignore_index, reduction):
+    dim = 1
+    up_limit = shape[dim] - 1
+    target_shape = list(shape)
+    del target_shape[dim]
+
+    inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
+    m = torch.nn.LogSoftmax(dim=1)
+    inp = m(inp)
+
+    target = torch.randint(0, up_limit, target_shape, device="cuda")
+    weight = torch.randn(shape[dim], dtype=dtype, device="cuda")
+    ref_inp = to_reference(inp, True)
+    ref_target = to_reference(target)
+    ref_weight = to_reference(weight, True)
+
+    ref_criterion = torch.nn.NLLLoss(
+        weight=ref_weight,
+        ignore_index=ignore_index,
+        reduction=reduction,
+    )
+    res_criterion = torch.nn.NLLLoss(
+        weight=weight,
+        ignore_index=ignore_index,
+        reduction=reduction,
+    )
+
+    ref_out = ref_criterion(ref_inp, ref_target)
+    with flag_gems.use_gems():
+        res_out = res_criterion(inp, target)
+    gems_assert_close(res_out, ref_out, dtype, reduce_dim=shape[dim])
+
+    out_grad = torch.randn_like(res_out)
+    ref_grad = to_reference(out_grad, True)
+    (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad)
+    (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad)
+    gems_assert_close(res_in_grad, ref_in_grad, dtype, reduce_dim=shape[dim])
+
+
 CUMSUM_SHAPES = (
     [(2, 32)] if QUICK_MODE else REDUCTION_SHAPES + [(2637,), (16, 1025, 255)]
 )
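
Usage sketch (editor's note, not part of the patch): the snippet below is a minimal example of how the new nll_loss op is reached through torch.nn.NLLLoss once the flag_gems hooks registered above are active. It assumes a CUDA device is available and uses the flag_gems.use_gems() context manager shown in the tests; the shapes, the reduction choice, and the helper name demo_nll_loss are illustrative only.

    # Hypothetical usage sketch: routes torch.nn.NLLLoss to the Triton-backed
    # nll_loss added in this patch. Shapes and names are illustrative.
    import torch
    import flag_gems


    def demo_nll_loss():
        inp = torch.randn(8, 10, device="cuda", requires_grad=True)  # (N, C) scores
        log_probs = torch.log_softmax(inp, dim=1)  # NLLLoss expects log-probabilities
        target = torch.randint(0, 10, (8,), device="cuda")  # class indices in [0, C)
        weight = torch.rand(10, device="cuda")  # optional per-class weights

        criterion = torch.nn.NLLLoss(weight=weight, reduction="mean", ignore_index=-100)
        with flag_gems.use_gems():  # temporarily dispatch aten ops to the Triton kernels
            loss = criterion(log_probs, target)  # 2D path: nll_loss_2d_fwd_kernel
        loss.backward()  # gradient w.r.t. inp flows through NLLLoss.backward
        return loss, inp.grad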