diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 11a49cd8..b653928d 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -111,6 +111,10 @@ def enable(lib=aten_lib): lib.impl("log_softmax.int", log_softmax, "AutogradCUDA") lib.impl("outer", outer, "AutogradCUDA") lib.impl("cross_entropy_loss", cross_entropy_loss, "AutogradCUDA") + # lib.impl("scatter.src", scatter_src, "CUDA") + # lib.impl("scatter.reduce", scatter_reduce, "CUDA") + # lib.impl("gather", gather, "CUDA") + # lib.impl("gather.out", gather_out, "CUDA") lib.impl("isclose", isclose, "CUDA") lib.impl("allclose", allclose, "CUDA") lib.impl("flip", flip, "CUDA") diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 9dde9c8c..354811e8 100644 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -27,6 +27,7 @@ from .flip import flip from .full import full from .full_like import full_like +from .gather import gather, gather_out from .ge import ge, ge_scalar from .gelu import gelu from .groupnorm import group_norm @@ -71,6 +72,7 @@ from .resolve_neg import resolve_neg from .rms_norm import rms_norm from .rsqrt import rsqrt +from .scatter import scatter_reduce, scatter_src from .sigmoid import sigmoid from .silu import silu from .sin import sin @@ -121,6 +123,8 @@ "eq_scalar", "exp", "exponential_", + "gather", + "gather_out", "flip", "ones_like", "full_like", @@ -168,6 +172,8 @@ "reciprocal", "relu", "rsqrt", + "scatter_src", + "scatter_reduce", "sigmoid", "silu", "sin", diff --git a/src/flag_gems/ops/gather.py b/src/flag_gems/ops/gather.py new file mode 100644 index 00000000..d935ecd7 --- /dev/null +++ b/src/flag_gems/ops/gather.py @@ -0,0 +1,130 @@ +import logging + +import torch +import triton +import triton.language as tl + +from ..utils import libentry, offset_calculator, restride_dim + + +def cfggen(): + block_m = [1, 2, 4, 8] + configs = [ + triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m + ] + return configs + + +@libentry() +@triton.autotune(configs=cfggen(), key=["M", "N"]) +@triton.jit +def gather_kernel( + inp, + inp_offsets, + out, + index, + idx_offsets, + M, + N, + stride_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + rows_mask = rows_offsets < M + + for off in range(0, N, BLOCK_N): + cols_offsets = off + tl.arange(0, BLOCK_N)[None, :] + cols_mask = cols_offsets < N + + offsets = rows_offsets * N + cols_offsets + mask = rows_mask and cols_mask + + inp_indices = tl.load(inp_offsets + offsets, mask=mask, other=0) + idx_indices = tl.load(idx_offsets + offsets, mask=mask, other=0) + + cur_index = tl.load(index + idx_indices, mask=mask, other=0) + inp_indices += cur_index * stride_dim + cur_inp = tl.load(inp + inp_indices, mask=mask, other=0) + + tl.store(out + idx_indices, cur_inp, mask=mask) + + +def gather(inp, dim, index, sparse_grad=False): + logging.debug("GEMS GATHER") + assert ( + inp.ndim == index.ndim + ), "self and index should all have the same number of dimensions" + assert ( + ((0 <= index.size(i) and index.size(i) <= inp.size(i)) or i == dim) + for i in range(0, index.ndim) + ), "index.size(d) <= self.size(d) for all dimensions d != dim" + assert ((0 <= index) * (index < inp.size(dim))).equal( + torch.ones(tuple(index.shape), dtype=torch.bool, device=inp.device) + ), "0 <= index < self.size(dim)" + inp = inp.contiguous() + index = index.contiguous() + out = torch.empty_like(index, 
dtype=inp.dtype, device=inp.device)
+
+    inp_strided = restride_dim(inp, dim, index.shape)
+    # FIXME: Is there a better way to get the flattened offsets of a tensor?
+    idx = torch.arange(0, index.numel(), device=inp.device).reshape(index.shape)
+    # For now offset_calculator() is called outside the kernel (even though the work could
+    # proceed in parallel), because a Triton jit function cannot take a tuple as an argument
+    # in Triton 2.2.0 (3.0.0 lifts this restriction) and the whole stride[] is needed for
+    # this calculation.
+    # FIXME: If stride[] can be passed to the Triton jit function in one piece, the offset
+    # calculation can move into the kernel so that it runs in parallel.
+    inp_offsets = offset_calculator(inp_strided, idx, inp.stride(), dim, isInp=True)
+    idx_offsets = offset_calculator(index, idx, index.stride(), dim, isInp=False)
+    N = list(index.shape)[index.ndim - 1]
+    M = index.numel() // N
+
+    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
+    gather_kernel[grid](
+        inp, inp_offsets, out, index, idx_offsets, M, N, inp.stride(dim)
+    )
+    return out
+
+
+def gather_out(inp, dim, index, sparse_grad=False, out=None):
+    logging.debug("GEMS GATHER OUT")
+    if out is None:
+        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
+    assert (
+        inp.ndim == index.ndim and inp.ndim == out.ndim
+    ), "self, index and out (if it is a Tensor) should all have the same number of dimensions"
+    assert all(
+        0 <= index.size(i) <= out.size(i) for i in range(0, index.ndim)
+    ), "index.size(d) <= out.size(d) for all dimensions d"
+    assert all(
+        0 <= index.size(i) <= inp.size(i) or i == dim for i in range(0, index.ndim)
+    ), "index.size(d) <= self.size(d) for all dimensions d != dim"
+    assert index.shape == out.shape, "out will have the same shape as index"
+    assert ((0 <= index) * (index < inp.size(dim))).equal(
+        torch.ones(tuple(index.shape), dtype=torch.bool, device=inp.device)
+    ), "0 <= index < self.size(dim)"
+    inp = inp.contiguous()
+    index = index.contiguous()
+    out = out.contiguous()
+
+    inp_strided = restride_dim(inp, dim, index.shape)
+    # FIXME: Is there a better way to get the flattened offsets of a tensor?
+    idx = torch.arange(0, index.numel(), device=inp.device).reshape(index.shape)
+    # For now offset_calculator() is called outside the kernel (even though the work could
+    # proceed in parallel), because a Triton jit function cannot take a tuple as an argument
+    # in Triton 2.2.0 (3.0.0 lifts this restriction) and the whole stride[] is needed for
+    # this calculation.
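+    # In effect, offset_calculator(t, idx, strides, dim, isInp) accumulates, for every flat
+    # position in idx, the per-dimension contribution (coordinate_d * strides[d]); with
+    # isInp=True the dim-d contribution is subtracted, since the kernel re-adds
+    # cur_index * stride_dim for the gathered index itself.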
+ # FIXME: If stride[] can be wholely passed to triton jit.function, we can do this calculation in the kernel + # so that the offset calculation can proceed in parallel + inp_offsets = offset_calculator(inp_strided, idx, inp.stride(), dim, isInp=True) + idx_offsets = offset_calculator(index, idx, index.stride(), dim, isInp=False) + N = list(index.shape)[index.ndim - 1] + M = index.numel() // N + + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) + gather_kernel[grid]( + inp, inp_offsets, out, index, idx_offsets, M, N, inp.stride(dim) + ) + return out diff --git a/src/flag_gems/ops/scatter.py b/src/flag_gems/ops/scatter.py new file mode 100644 index 00000000..3077011a --- /dev/null +++ b/src/flag_gems/ops/scatter.py @@ -0,0 +1,216 @@ +import logging + +import torch +import triton +import triton.language as tl + +from ..utils import libentry, offset_calculator, restride_dim + + +def cfggen(): + block_m = [1, 2, 4, 8] + configs = [ + triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m + ] + return configs + + +@libentry() +@triton.autotune(configs=cfggen(), key=["M", "N"]) +@triton.jit +def scatter_kernel( + inp_offsets, + src, + index, + idx_offsets, + out, + M, + N, + stride_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + rows_mask = rows_offsets < M + + for off in range(0, N, BLOCK_N): + cols_offsets = off + tl.arange(0, BLOCK_N)[None, :] + cols_mask = cols_offsets < N + + offsets = rows_offsets * N + cols_offsets + mask = rows_mask and cols_mask + + inp_indices = tl.load(inp_offsets + offsets, mask=mask, other=0) + idx_indices = tl.load(idx_offsets + offsets, mask=mask, other=0) + + cur_src = tl.load(src + idx_indices, mask=mask, other=0) + cur_index = tl.load(index + idx_indices, mask=mask, other=0) + + inp_indices += cur_index * stride_dim + tl.store(out + inp_indices, cur_src, mask=mask) + + +@libentry() +@triton.autotune(configs=cfggen(), key=["M", "N"]) +@triton.jit +def scatter_add_kernel( + inp, + inp_offsets, + src, + index, + idx_offsets, + out, + M, + N, + stride_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + rows_mask = rows_offsets < M + + for off in range(0, N, BLOCK_N): + cols_offsets = off + tl.arange(0, BLOCK_N)[None, :] + cols_mask = cols_offsets < N + + offsets = rows_offsets * N + cols_offsets + mask = rows_mask and cols_mask + + inp_indices = tl.load(inp_offsets + offsets, mask=mask, other=0) + idx_indices = tl.load(idx_offsets + offsets, mask=mask, other=0) + + cur_src = tl.load(src + idx_indices, mask=mask, other=0).to(tl.float32) + cur_index = tl.load(index + idx_indices, mask=mask, other=0) + + inp_indices += cur_index * stride_dim + cur_inp = tl.load(inp + inp_indices, mask=mask, other=0).to(tl.float32) + res = cur_inp + cur_src + tl.store(out + inp_indices, res, mask=mask) + + +@libentry() +@triton.autotune(configs=cfggen(), key=["M", "N"]) +@triton.jit +def scatter_mul_kernel( + inp, + inp_offsets, + src, + index, + idx_offsets, + out, + M, + N, + stride_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + rows_mask = rows_offsets < M + + for off in range(0, N, BLOCK_N): + cols_offsets = off + tl.arange(0, BLOCK_N)[None, :] + cols_mask = cols_offsets < N + + offsets = rows_offsets * N + cols_offsets + mask = rows_mask and cols_mask + + 
inp_indices = tl.load(inp_offsets + offsets, mask=mask, other=0) + idx_indices = tl.load(idx_offsets + offsets, mask=mask, other=0) + + cur_src = tl.load(src + idx_indices, mask=mask, other=0).to(tl.float32) + cur_index = tl.load(index + idx_indices, mask=mask, other=0) + + inp_indices += cur_index * stride_dim + cur_inp = tl.load(inp + inp_indices, mask=mask, other=0).to(tl.float32) + res = cur_inp * cur_src + tl.store(out + inp_indices, res, mask=mask) + + +def scatter(inp, dim, index, src, reduction=None): + assert ( + inp.ndim == index.ndim and inp.ndim == src.ndim + ), "self, index and src (if it is a Tensor) should all have the same number of dimensions" + assert ( + (0 <= index.size(i) and index.size(i) <= src.size(i)) + for i in range(0, index.ndim) + ), "index.size(d) <= src.size(d) for all dimensions d" + assert ( + ((0 <= index.size(i) and index.size(i) <= inp.size(i)) or i == dim) + for i in range(0, index.ndim) + ), "index.size(d) <= self.size(d) for all dimensions d != dim" + inp = inp.contiguous() + index = index.contiguous() + src = src.contiguous() + out = inp.clone() + + src_strided = src.as_strided(index.shape, src.stride()).contiguous() + inp_strided = restride_dim(inp, dim, index.shape) + # FIXME: Are there any other way to get the "flatten offset" of a tensor? + idx = torch.arange(0, index.numel(), device=inp.device).reshape(index.shape) + # Temporarily call offsetCalculator() outside the block(although it can actually proceed in parallel), + # because the triton jit.function cannot accept Tuple as input in version 2.2.0(in 3.0.0, it's available), + # and we do need **the whole stride[]** to accomplish this calculation! + # FIXME: If stride[] can be wholely passed to triton jit.function, we can do this calculation in the kernel + # so that the offset calculation can proceed in parallel + inp_offsets = offset_calculator(inp_strided, idx, inp.stride(), dim, isInp=True) + idx_offsets = offset_calculator(index, idx, index.stride(), dim, isInp=False) + N = list(index.shape)[index.ndim - 1] + M = index.numel() // N + + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) + if reduction is None: + scatter_kernel[grid]( + inp_offsets, + src_strided, + index, + idx_offsets, + out, + M, + N, + inp.stride(dim), + ) + elif reduction == "add": + scatter_add_kernel[grid]( + inp, + inp_offsets, + src_strided, + index, + idx_offsets, + out, + M, + N, + inp.stride(dim), + ) + elif reduction == "multiply": + scatter_mul_kernel[grid]( + inp, + inp_offsets, + src_strided, + index, + idx_offsets, + out, + M, + N, + inp.stride(dim), + ) + return out + + +def scatter_src(inp, dim, index, src): + logging.debug("GEMS SCATTER SRC") + return scatter(inp, dim, index, src) + + +def scatter_reduce(inp, dim, index, src, reduce): + logging.debug("GEMS SCATTER REDUCE") + # TODO: As is shown in PyTorch's document(torch.Tensor.scatter_reduce_), + # this function is still **in beta** and may change in the near future. + # So for now, we're just going to stick with the original "add" and "multiply" parameters. + # Maybe we can add reduction options like "sum", "prod", "mean", "amax" and "amin" in the future. 
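+    # Note that only "add" and "multiply" are handled below; any other reduce string falls
+    # through and returns None. scatter() maps "add" onto scatter_add_kernel and "multiply"
+    # onto scatter_mul_kernel.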
+ if reduce == "add": + return scatter(inp, dim, index, src, reduction="add") + elif reduce == "multiply": + return scatter(inp, dim, index, src, reduction="multiply") diff --git a/src/flag_gems/utils/__init__.py b/src/flag_gems/utils/__init__.py index c3bcb0f8..c4db131e 100644 --- a/src/flag_gems/utils/__init__.py +++ b/src/flag_gems/utils/__init__.py @@ -1,10 +1,12 @@ from .libentry import libentry from .pointwise_dynamic import pointwise_dynamic -from .shape_utils import broadcastable_to, dim_compress +from .shape_utils import broadcastable_to, dim_compress, offset_calculator, restride_dim __all__ = [ "libentry", "pointwise_dynamic", "dim_compress", + "restride_dim", + "offset_calculator", "broadcastable_to", ] diff --git a/src/flag_gems/utils/shape_utils.py b/src/flag_gems/utils/shape_utils.py index f98ff89a..0a154c82 100644 --- a/src/flag_gems/utils/shape_utils.py +++ b/src/flag_gems/utils/shape_utils.py @@ -3,6 +3,8 @@ from typing import Iterable, Tuple import torch +import triton +import triton.language as tl Shape = Tuple[int] Stride = Tuple[int] @@ -156,3 +158,91 @@ def can_use_int32_index(a): if max_offset > INT32_MAX: return False return True + + +def offsetCalculator(inp, idx, strides, dim, isInp): + ndim = inp.ndim + shape = list(inp.shape) + offsets = 0 + idx_dim = 0 + for d in range(0, ndim): + mod = idx % shape[d] + add_on = mod * strides[d] + offsets += add_on + if d == dim: + idx_dim = add_on + idx = idx // shape[d] + # FIXME: Should we write a fast div/mod + # to boost the '%' and '//'? (Since they may be run many times) + # See also: + # - https://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html + # - Division by Invariant Integers Using Multiplication, + # Torbjörn Granlund and Peter L. Montgomery, 1994. + return (offsets) if not isInp else (offsets - idx_dim) + + +def restride_dim(src, dim, shape, step=0, storage_offset=None): + strides = list(src.stride()) + strides[dim] *= step + return src.as_strided(shape, strides, storage_offset) + + +def cfggen(): + block_m = [1, 2, 4] + block_n = [256, 1024, 2048, 4096] + configs = [ + triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_warps=4) + for m in block_m + for n in block_n + ] + return configs + + +@triton.autotune(configs=cfggen(), key=["M", "N"]) +@triton.jit +def add_on_kernel( + idx, + add_on, + cur_shape, + cur_strides, + M, + N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + rows_offset = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + rows_mask = rows_offset < M + + cols_offset = pid_y + tl.arange(0, BLOCK_N)[None, :] + cols_mask = cols_offset < N + block_mask = rows_mask and cols_mask + + offsets = rows_offset * N + cols_offset + cur_idx = tl.load(idx + offsets, mask=block_mask, other=1) + mod = cur_idx % cur_shape + res = mod * cur_strides + tl.store(add_on + offsets, res, mask=block_mask) + + +def offset_calculator(inp, idx, strides, dim, isInp): + ndim = inp.ndim + shape = list(inp.shape) + offsets = torch.zeros_like(inp, dtype=torch.int32, device=inp.device) + idx_dim = torch.zeros_like(inp, dtype=torch.int32, device=inp.device) + for d in range(0, ndim): + add_on = torch.zeros_like(inp, dtype=torch.int32, device=inp.device) + N = idx.size(idx.ndim - 1) + M = idx.numel() // N + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_M"]), + triton.cdiv(N, meta["BLOCK_N"]), + ) + add_on_kernel[grid](idx, add_on, shape[d], strides[d], M, N) + + offsets = torch.add(offsets, add_on) + if d == dim: + idx_dim = add_on + idx = idx // 
shape[d] + return offsets if not isInp else (offsets - idx_dim) diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index 6c9c7ee7..beb04d59 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -672,6 +672,212 @@ def test_accuracy_vectornorm(shape, ord, dim, keepdim, dtype): gems_assert_close(res_out, ref_out, dtype) +@pytest.mark.parametrize("src_shape", [(128, 16 * i, 32 * i) for i in range(1, 10, 4)]) +@pytest.mark.parametrize("inp_shape", [(512, 32 * i, 64 * i) for i in range(1, 10, 4)]) +@pytest.mark.parametrize("dim", [0, 1, 2]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_scatter_src(src_shape, inp_shape, dim, dtype): + inp = torch.randn(inp_shape, dtype=dtype, device="cuda") + src = torch.randn(src_shape, dtype=dtype, device="cuda") + size_dim = min(src_shape[dim], inp_shape[dim]) + + import random + + index_shape = [ + random.randint(1, min(src_shape[0], inp_shape[0])), + random.randint(1, min(src_shape[1], inp_shape[1])), + random.randint(1, min(src_shape[2], inp_shape[2])), + ] + index = torch.empty(tuple(index_shape), dtype=torch.long, device="cuda") + + m, n, o = index_shape + + index_size_dim = index_shape[dim] + # make unique indices + for i in range(1 if dim == 0 else m): + for j in range(1 if dim == 1 else n): + for k in range(1 if dim == 2 else o): + ii = [i, j, k] + ii[dim] = slice(0, index.size(dim) + 1) + index[tuple(ii)] = torch.randperm(size_dim)[0:index_size_dim] + + # ref_inp = to_reference(inp) + # ref_index = to_reference(index) + # ref_src = to_reference(src) + ref_out = torch.scatter(inp, dim, index, src) + with flag_gems.use_gems(): + from src.flag_gems.ops import scatter_src + + # res_out = scatter_src(inp, dim, index, src) + res_out = scatter_src(inp, dim, index, src) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("src_shape", [(2, 2, 2)]) +@pytest.mark.parametrize("inp_shape", [(3, 3, 3)]) +@pytest.mark.parametrize("dim", [0, 1, 2]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_scatter_add(src_shape, inp_shape, dim, dtype): + inp = torch.randn(inp_shape, dtype=dtype, device="cuda") + src = torch.randn(src_shape, dtype=dtype, device="cuda") + size_dim = min(src_shape[dim], inp_shape[dim]) + + import random + + index_shape = [ + random.randint(1, min(src_shape[0], inp_shape[0])), + random.randint(1, min(src_shape[1], inp_shape[1])), + random.randint(1, min(src_shape[2], inp_shape[2])), + ] + index = torch.empty(tuple(index_shape), dtype=torch.long, device="cuda") + + m, n, o = index_shape + + index_size_dim = index_shape[dim] + # make unique indices + for i in range(1 if dim == 0 else m): + for j in range(1 if dim == 1 else n): + for k in range(1 if dim == 2 else o): + ii = [i, j, k] + ii[dim] = slice(0, index.size(dim) + 1) + index[tuple(ii)] = torch.randperm(size_dim)[0:index_size_dim] + + # ref_inp = to_reference(inp) + # ref_index = to_reference(index) + # ref_src = to_reference(src) + ref_out = torch.scatter(inp, dim, index, src, reduce="add") + with flag_gems.use_gems(): + from src.flag_gems.ops import scatter_reduce + + # res_out = torch.scatter(inp, dim, index, src, reduce="add") + res_out = scatter_reduce(inp, dim, index, src, reduce="add") + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("src_shape", [(128, 16 * i, 32 * i) for i in range(1, 10, 4)]) +@pytest.mark.parametrize("inp_shape", [(512, 32 * i, 64 * i) for i in range(1, 10, 4)]) +@pytest.mark.parametrize("dim", [0, 1, 2]) 
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_scatter_mul(src_shape, inp_shape, dim, dtype): + inp = torch.randn(inp_shape, dtype=dtype, device="cuda") + src = torch.randn(src_shape, dtype=dtype, device="cuda") + size_dim = min(src_shape[dim], inp_shape[dim]) + + import random + + index_shape = [ + random.randint(1, min(src_shape[0], inp_shape[0])), + random.randint(1, min(src_shape[1], inp_shape[1])), + random.randint(1, min(src_shape[2], inp_shape[2])), + ] + index = torch.empty(tuple(index_shape), dtype=torch.long, device="cuda") + + m, n, o = index_shape + + index_size_dim = index_shape[dim] + # make unique indices + for i in range(1 if dim == 0 else m): + for j in range(1 if dim == 1 else n): + for k in range(1 if dim == 2 else o): + ii = [i, j, k] + ii[dim] = slice(0, index.size(dim) + 1) + index[tuple(ii)] = torch.randperm(size_dim)[0:index_size_dim] + + # ref_inp = to_reference(inp) + # ref_index = to_reference(index) + # ref_src = to_reference(src) + ref_out = torch.scatter(inp, dim, index, src, reduce="multiply") + with flag_gems.use_gems(): + from src.flag_gems.ops import scatter_reduce + + # res_out = torch.scatter(inp, dim, index, src, reduce="multiply") + res_out = scatter_reduce(inp, dim, index, src, reduce="multiply") + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("inp_shape", [(512, 32 * i, 64 * i) for i in range(1, 10, 4)]) +@pytest.mark.parametrize("dim", [0, 1, 2]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_gather(inp_shape, dim, dtype): + inp = torch.randn(inp_shape, dtype=dtype, device="cuda") + size_dim = inp_shape[dim] + + import random + + index_shape = [ + random.randint(1, inp_shape[0]), + random.randint(1, inp_shape[1]), + random.randint(1, inp_shape[2]), + ] + index = torch.empty(tuple(index_shape), dtype=torch.long, device="cuda") + + m, n, o = index_shape + + index_size_dim = index_shape[dim] + # make unique indices + for i in range(1 if dim == 0 else m): + for j in range(1 if dim == 1 else n): + for k in range(1 if dim == 2 else o): + ii = [i, j, k] + ii[dim] = slice(0, index.size(dim) + 1) + index[tuple(ii)] = torch.randperm(size_dim)[0:index_size_dim] + + # ref_inp = to_reference(inp) + # ref_index = to_reference(index) + ref_out = torch.gather(inp, dim, index) + with flag_gems.use_gems(): + from src.flag_gems.ops import gather + + # res_out = torch.gather(inp, dim, index) + res_out = gather(inp, dim, index) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("out_shape", [(128, 16 * i, 32 * i) for i in range(1, 10, 4)]) +@pytest.mark.parametrize("inp_shape", [(512, 32 * i, 64 * i) for i in range(1, 10, 4)]) +@pytest.mark.parametrize("dim", [0, 1, 2]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_gather_out(out_shape, inp_shape, dim, dtype): + inp = torch.randn(inp_shape, dtype=dtype, device="cuda") + size_dim = min(out_shape[dim], inp_shape[dim]) + + import random + + index_shape = [ + random.randint(1, min(out_shape[0], inp_shape[0])), + random.randint(1, min(out_shape[1], inp_shape[1])), + random.randint(1, min(out_shape[2], inp_shape[2])), + ] + index = torch.empty(tuple(index_shape), dtype=torch.long, device="cuda") + out = torch.randn(tuple(index_shape), dtype=dtype, device="cuda") + + m, n, o = index_shape + + index_size_dim = index_shape[dim] + # make unique indices + for i in range(1 if dim == 0 else m): + for j in range(1 if dim == 1 else n): + for k in range(1 if dim == 2 else o): + ii = [i, j, k] + ii[dim] = slice(0, 
index.size(dim) + 1)
+                index[tuple(ii)] = torch.randperm(size_dim)[0:index_size_dim]
+
+    # ref_inp = to_reference(inp)
+    # ref_index = to_reference(index)
+    ref_out = torch.gather(inp, dim, index, sparse_grad=False, out=out.clone())
+    with flag_gems.use_gems():
+        from src.flag_gems.ops import gather_out
+
+        # res_out = torch.gather(inp, dim, index, sparse_grad=False, out=out)
+        res_out = gather_out(inp, dim, index, sparse_grad=False, out=out)
+
+    gems_assert_equal(res_out, ref_out)
+
+
 @pytest.mark.parametrize("shape", REDUCTION_SHAPES)
 @pytest.mark.parametrize("dim", DIM_LIST)
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@@ -682,9 +888,9 @@ def test_accuracy_index_select(shape, dim, dtype):
     index = torch.randint(0, index_size, [floor(index_size * 0.8)], device="cuda")
 
-    ref_inp = to_reference(inp)
-    ref_index = to_reference(index)
-    ref_out = torch.index_select(ref_inp, dim, ref_index)
+    # ref_inp = to_reference(inp)
+    # ref_index = to_reference(index)
+    ref_out = torch.index_select(inp, dim, index)
     with flag_gems.use_gems():
         res_out = torch.index_select(inp, dim, index)
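
For reference, a minimal usage sketch of the new ops (a sketch only, assuming a CUDA device and an importable
flag_gems package; while the aten registrations in __init__.py stay commented out, the functions are called
directly, as the tests do, rather than through torch.gather/torch.scatter):

    import torch
    from flag_gems.ops import gather, scatter_src, scatter_reduce

    inp = torch.randn(4, 8, device="cuda")
    src = torch.randn(4, 5, device="cuda")
    # keep indices unique along dim so that plain scatter is well defined
    index = torch.stack([torch.randperm(8, device="cuda")[:5] for _ in range(4)])

    gathered = gather(inp, 1, index)                           # same result as torch.gather(inp, 1, index)
    scattered = scatter_src(inp, 1, index, src)                # out-of-place torch.scatter(inp, 1, index, src)
    summed = scatter_reduce(inp, 1, index, src, reduce="add")  # torch.scatter(..., reduce="add")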