From 679ce8b478e445e780c48d15feb98b22e37e4760 Mon Sep 17 00:00:00 2001
From: FatJhon <156064001+FatJhon@users.noreply.github.com>
Date: Fri, 31 May 2024 16:14:56 +0800
Subject: [PATCH] add sum and none reductions for CrossEntropyLoss (#41)

* rename variables && add reduce function

* add reduce none

* add test

* clean code

* add reduce enum

* Replace the enum interface with IntEnum && add detection of illegal reduction values

---------

Co-authored-by: Jiang Bin
---
 src/flag_gems/ops/argmax.py             |   4 +-
 src/flag_gems/ops/cross_entropy_loss.py | 121 ++++++++++++++++++++----
 src/flag_gems/ops/max.py                |   8 +-
 src/flag_gems/ops/min.py                |   8 +-
 tests/test_reduction_ops.py             |   9 +-
 5 files changed, 122 insertions(+), 28 deletions(-)

diff --git a/src/flag_gems/ops/argmax.py b/src/flag_gems/ops/argmax.py
index e31ce6ec..54e1b7ec 100644
--- a/src/flag_gems/ops/argmax.py
+++ b/src/flag_gems/ops/argmax.py
@@ -35,8 +35,8 @@ def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr
     mid_ptrs = mid_value + offset
     mask = offset < mid_size
     mid_val = tl.load(mid_ptrs, mask=mask, other=-float("inf"))
-    sum_val = tl.argmax(mid_val, axis=0)
-    mid_index_ptrs = mid_index + sum_val
+    index_val = tl.argmax(mid_val, axis=0)
+    mid_index_ptrs = mid_index + index_val
     out_val = tl.load(mid_index_ptrs)
     tl.store(out, out_val)
 
diff --git a/src/flag_gems/ops/cross_entropy_loss.py b/src/flag_gems/ops/cross_entropy_loss.py
index 951cbe72..d224a4a8 100644
--- a/src/flag_gems/ops/cross_entropy_loss.py
+++ b/src/flag_gems/ops/cross_entropy_loss.py
@@ -2,8 +2,15 @@
 import triton
 import triton.language as tl
 import logging
+from enum import IntEnum
 from ..utils import libentry
-from .sum import sum
+from .sum import sum, sum_dim
+
+
+class Reduction(IntEnum):
+    NONE = 0
+    MEAN = 1
+    SUM = 2
 
 
 @libentry()
@@ -56,7 +63,7 @@ def log_softmax_and_mul_kernel(
     denominator = tl.sum(numerator, axis=1)[:, None]
     softmax_output = tl.log(numerator / denominator)
     target = tl.load(target_ptr + offset, mask=mask, other=0.0)
-    out = softmax_output * target / (-mean_num)
+    out = softmax_output * target / (mean_num)
     output_ptrs = output_ptr + offset
     tl.store(output_ptrs, out, mask=mask)
 
@@ -114,6 +121,68 @@ def softmax_and_sub_kernel(
     softmax_output = numerator / denominator
     target_ptrs = target_ptr + offset
     target = tl.load(target_ptrs, mask=mask, other=0.0)
+    out_grad_ptr = out_grad + m_offset[:, None] * K + pid_k
+    out_grad_value = tl.load(out_grad_ptr)
+    out = out_grad_value * (softmax_output - target) / mean_num
+    output_ptrs = output_ptr + offset
+
+    tl.store(output_ptrs, out, mask=mask)
+
+
+@libentry()
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_M": 1}, num_stages=4),
+        triton.Config({"BLOCK_M": 1}, num_stages=5),
+        triton.Config({"BLOCK_M": 2}, num_stages=4),
+        triton.Config({"BLOCK_M": 2}, num_stages=5),
+        triton.Config({"BLOCK_M": 4}, num_stages=4),
+        triton.Config({"BLOCK_M": 4}, num_stages=5),
+        triton.Config({"BLOCK_M": 8}, num_stages=4),
+        triton.Config({"BLOCK_M": 8}, num_stages=5),
+    ],
+    key=[
+        "M",
+        "N",
+    ],
+)
+@triton.heuristics(
+    values={
+        "BLOCK_N": lambda args: triton.next_power_of_2(args["N"]),
+        "num_warps": lambda args: (
+            4 if args["N"] <= 1024 else (8 if args["N"] <= 2048 else 16)
+        ),
+    },
+)
+@triton.jit
+def softmax_and_sub_reduce_kernel(
+    output_ptr,
+    input_ptr,
+    target_ptr,
+    out_grad,
+    mean_num,
+    M,
+    N,
+    K,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    pid_m = tl.program_id(0)
+    pid_k = tl.program_id(1)
+    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    n_offset = tl.arange(0, BLOCK_N)
+    offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
+    mask = m_offset[:, None] < M and n_offset[None, :] < N
+    input_ptrs = input_ptr + offset
+    inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
+    row_minus_max = inp - tl.max(inp, axis=1)[:, None]
+    numerator = tl.exp(row_minus_max)
+    denominator = tl.sum(numerator, axis=1)[:, None]
+    # todo: reduce unnecessary calculations through mask operations to improve performance
+    softmax_output = numerator / denominator
+    target_ptrs = target_ptr + offset
+    target = tl.load(target_ptrs, mask=mask, other=0.0)
+    out_grad_value = tl.load(out_grad)
     out = out_grad_value * (softmax_output - target) / mean_num
     output_ptrs = output_ptr + offset
 
     tl.store(output_ptrs, out, mask=mask)
 
@@ -125,15 +194,18 @@ class CrossEntropyLoss(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, target, weight, reduction, ignore_index, label_smoothing):
         logging.debug("GEMS CrossEntropyLoss")
+        assert reduction in Reduction._value2member_map_, "Invalid reduction"
         assert isinstance(input, torch.Tensor), "input is not a tensor"
         if input.ndim >= 2:
             dim = 1
         else:
             dim = 0
-
+        if reduction != Reduction.MEAN.value:
+            mean_num = -1
+        else:
+            mean_num = -target.numel()
         shape = list(input.shape)
         shape[dim] = 1
-        mean_num = target.numel()
         target = torch.zeros_like(input).scatter(dim, target.view(shape), 1)
 
         M = 1
@@ -157,11 +229,15 @@ def forward(ctx, input, target, weight, reduction, ignore_index, label_smoothing
             N,
             K,
         )
-        out_result = sum(out)
+        if reduction != Reduction.NONE.value:
+            out_result = sum(out)
+        else:
+            out_result = sum_dim(out, dim=[dim])
 
         ctx.save_for_backward(input, target)
         ctx.dim = dim
-        ctx.mean_num = mean_num
+        ctx.mean_num = -mean_num
+        ctx.reduction = reduction
         return out_result
 
     @staticmethod
@@ -170,6 +246,7 @@ def backward(ctx, out_grad):
         input, target = ctx.saved_tensors
         dim = ctx.dim
         mean_num = ctx.mean_num
+        reduction = ctx.reduction
 
         M = 1
         N = input.shape[dim]
@@ -183,16 +260,28 @@ def backward(ctx, out_grad):
             triton.cdiv(M, meta["BLOCK_M"]),
             K,
         )
-        softmax_and_sub_kernel[grid](
-            out,
-            inp,
-            target,
-            out_grad,
-            mean_num,
-            M,
-            N,
-            K,
-        )
+        if reduction != Reduction.NONE.value:
+            softmax_and_sub_reduce_kernel[grid](
+                out,
+                inp,
+                target,
+                out_grad,
+                mean_num,
+                M,
+                N,
+                K,
+            )
+        else:
+            softmax_and_sub_kernel[grid](
+                out,
+                inp,
+                target,
+                out_grad,
+                mean_num,
+                M,
+                N,
+                K,
+            )
 
         return out, None, None, None, None, None
diff --git a/src/flag_gems/ops/max.py b/src/flag_gems/ops/max.py
index 6ff4fcbd..8e27f7a2 100644
--- a/src/flag_gems/ops/max.py
+++ b/src/flag_gems/ops/max.py
@@ -20,9 +20,9 @@ def max_kernel_1(
     inp_ptrs = inp + offset
     mask = offset < M
     inp_val = tl.load(inp_ptrs, mask=mask, other=-float("inf"))
-    sum_val = tl.max(inp_val)
+    max_val = tl.max(inp_val)
     mid_ptr = mid + pid
-    tl.store(mid_ptr, sum_val)
+    tl.store(mid_ptr, max_val)
 
 
 @libentry()
@@ -32,8 +32,8 @@ def max_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
     mid_ptrs = mid + offset
     mask = offset < mid_size
     mid_val = tl.load(mid_ptrs, mask=mask, other=-float("inf"))
-    sum_val = tl.max(mid_val)
-    tl.store(out, sum_val)
+    max_val = tl.max(mid_val)
+    tl.store(out, max_val)
 
 
 @libentry()
diff --git a/src/flag_gems/ops/min.py b/src/flag_gems/ops/min.py
index cb09eb00..1a17dde2 100644
--- a/src/flag_gems/ops/min.py
+++ b/src/flag_gems/ops/min.py
@@ -20,9 +20,9 @@ def min_kernel_1(
     inp_ptrs = inp + offset
     mask = offset < M
     inp_val = tl.load(inp_ptrs, mask=mask, other=float("inf"))
-    sum_val = tl.min(inp_val)
+    min_val = tl.min(inp_val)
     mid_ptr = mid + pid
-    tl.store(mid_ptr, sum_val)
+    tl.store(mid_ptr, min_val)
 
 
 @libentry()
@@ -32,8 +32,8 @@ def min_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
     mid_ptrs = mid + offset
     mask = offset < mid_size
     mid_val = tl.load(mid_ptrs, mask=mask, other=float("inf"))
-    sum_val = tl.min(mid_val)
-    tl.store(out, sum_val)
+    min_val = tl.min(mid_val)
+    tl.store(out, min_val)
 
 
 @libentry()
diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py
index ebdae6ab..8bdacc75 100644
--- a/tests/test_reduction_ops.py
+++ b/tests/test_reduction_ops.py
@@ -143,9 +143,12 @@ def test_accuracy_argmax(shape, dim, keepdim, dtype):
     gems_assert_equal(res_out, ref_out)
 
 
+@pytest.mark.parametrize("size_average", [None, True, False])
+@pytest.mark.parametrize("reduce", [None, True, False])
+@pytest.mark.parametrize("reduction", ["mean", "none", "sum"])
 @pytest.mark.parametrize("shape", REDUCTION_SHAPES)
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
-def test_accuracy_cross_entropy_loss(shape, dtype):
+def test_accuracy_cross_entropy_loss(shape, dtype, size_average, reduce, reduction):
     inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
     dim = 1
     up_limit = shape[dim] - 1
@@ -156,7 +159,9 @@ def test_accuracy_cross_entropy_loss(shape, dtype):
     ref_inp = to_reference(inp, True)
     ref_target = to_reference(target)
 
-    criterion = torch.nn.CrossEntropyLoss()
+    criterion = torch.nn.CrossEntropyLoss(
+        size_average=size_average, reduce=reduce, reduction=reduction
+    )
 
     ref_out = criterion(ref_inp, ref_target)
     with flag_gems.use_gems():
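For context, a minimal end-to-end sketch of how the three reduction modes added by this patch are expected to behave, mirroring the parametrized test above but skipping the deprecated `size_average`/`reduce` arguments. The shape, dtype, and tolerances here are illustrative assumptions rather than values from the test suite; `flag_gems` and a CUDA device are assumed to be available. Note that the `Reduction` IntEnum introduced in this patch (NONE=0, MEAN=1, SUM=2) matches the integer encoding PyTorch itself uses when lowering the `reduction` argument to backends, which is why forward/backward can branch on it directly.

```python
import torch
import flag_gems

# Illustrative shape: 8 samples, 128 classes; dim=1 is the class dimension,
# matching the dim = 1 convention in the test above.
inp = torch.randn(8, 128, dtype=torch.float32, device="cuda")
target = torch.randint(0, 128, (8,), device="cuda")

for reduction in ("mean", "sum", "none"):
    criterion = torch.nn.CrossEntropyLoss(reduction=reduction)
    ref_out = criterion(inp, target)  # eager PyTorch reference
    with flag_gems.use_gems():  # dispatch aten ops to the Triton kernels
        res_out = criterion(inp, target)
    # "mean" and "sum" reduce to a scalar; "none" keeps one loss per sample.
    assert res_out.shape == (() if reduction != "none" else (8,))
    torch.testing.assert_close(res_out, ref_out, rtol=1e-3, atol=1e-3)
```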