[Operator] scatter & gather (#96)
* [fix] Temporarily use upcasting to make prod support bf16

* [Operator] add select, scatter, gather

* [fix] del unnecessary offsets calculation

* [fix] src shares the same unraveled indices as index but different offsets

* [Fix] remove useless src_indices

* [fix] remove unnecessary indices calculation in gather

* [Operator] add assertion to gather

* [Operator] add assertion to gather.out

* [fix] fix the comparison between 2 BoolType tensors

* [fix] use inp's device to initialize tensors

* [Operator] remove select op

* [chore] Add offset_calculator kernel
* Use Triton to do the offset calculations; the perf test results can be seen in the scatter & gather doc

* [fix] Add offset_cal_kernel to utils.init

* [chore] remove the previous offsets calculator in utils package's init code

* [chore] Considering perf, pause the replacement of the ATen operators related to scatter & gather

* [fix] Use ops.scatter/gather instead in the testing

* [fix] Remove the upcasting on the ref input data in the scatter&gather test

* [chore] reformat
GwokHiujin authored Sep 2, 2024
1 parent 40945c3 commit d78b76d
Showing 7 changed files with 658 additions and 4 deletions.
4 changes: 4 additions & 0 deletions src/flag_gems/__init__.py
@@ -125,6 +125,10 @@ def enable(lib=aten_lib):
    lib.impl("log_softmax.int", log_softmax, "AutogradCUDA")
    lib.impl("outer", outer, "AutogradCUDA")
    lib.impl("cross_entropy_loss", cross_entropy_loss, "AutogradCUDA")
    # lib.impl("scatter.src", scatter_src, "CUDA")
    # lib.impl("scatter.reduce", scatter_reduce, "CUDA")
    # lib.impl("gather", gather, "CUDA")
    # lib.impl("gather.out", gather_out, "CUDA")
    lib.impl("isclose", isclose, "CUDA")
    lib.impl("allclose", allclose, "CUDA")
    lib.impl("flip", flip, "CUDA")
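Note that the four scatter/gather registrations above are left commented out for now (the commit message cites perf as the reason for pausing the ATen replacement), so the new kernels are reached by calling the ops directly, as the tests do. Below is a minimal, hypothetical sketch of that direct-call path; the tensor shapes and values are made up for illustration, and only the gather import is taken from the ops/__init__.py diff below:

import torch
from flag_gems.ops import gather  # exported in the ops/__init__.py diff below

inp = torch.randn(4, 8, device="cuda")
index = torch.randint(0, 8, (4, 8), device="cuda")

# Triton-backed gather along dim=1; expected to match the ATen reference
out = gather(inp, 1, index)
torch.testing.assert_close(out, torch.gather(inp, 1, index))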
6 changes: 6 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -27,6 +27,7 @@
from .flip import flip
from .full import full
from .full_like import full_like
from .gather import gather, gather_out
from .ge import ge, ge_scalar
from .gelu import gelu
from .groupnorm import group_norm
@@ -75,6 +76,7 @@
from .resolve_neg import resolve_neg
from .rms_norm import rms_norm
from .rsqrt import rsqrt
from .scatter import scatter_reduce, scatter_src
from .sigmoid import sigmoid
from .silu import silu
from .sin import sin
@@ -131,6 +133,8 @@
    "eq_scalar",
    "exp",
    "exponential_",
    "gather",
    "gather_out",
    "flip",
    "ones_like",
    "full_like",
@@ -180,6 +184,8 @@
    "reciprocal",
    "relu",
    "rsqrt",
    "scatter_src",
    "scatter_reduce",
    "sigmoid",
    "silu",
    "sin",
130 changes: 130 additions & 0 deletions src/flag_gems/ops/gather.py
@@ -0,0 +1,130 @@
import logging

import torch
import triton
import triton.language as tl

from ..utils import libentry, offset_calculator, restride_dim


def cfggen():
    block_m = [1, 2, 4, 8]
    configs = [
        triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m
    ]
    return configs


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def gather_kernel(
    inp,
    inp_offsets,
    out,
    index,
    idx_offsets,
    M,
    N,
    stride_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(0)
    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M

    for off in range(0, N, BLOCK_N):
        cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
        cols_mask = cols_offsets < N

        offsets = rows_offsets * N + cols_offsets
        mask = rows_mask and cols_mask

        inp_indices = tl.load(inp_offsets + offsets, mask=mask, other=0)
        idx_indices = tl.load(idx_offsets + offsets, mask=mask, other=0)

        cur_index = tl.load(index + idx_indices, mask=mask, other=0)
        inp_indices += cur_index * stride_dim
        cur_inp = tl.load(inp + inp_indices, mask=mask, other=0)

        tl.store(out + idx_indices, cur_inp, mask=mask)


def gather(inp, dim, index, sparse_grad=False):
logging.debug("GEMS GATHER")
assert (
inp.ndim == index.ndim
), "self and index should all have the same number of dimensions"
assert (
((0 <= index.size(i) and index.size(i) <= inp.size(i)) or i == dim)
for i in range(0, index.ndim)
), "index.size(d) <= self.size(d) for all dimensions d != dim"
assert ((0 <= index) * (index < inp.size(dim))).equal(
torch.ones(tuple(index.shape), dtype=torch.bool, device=inp.device)
), "0 <= index < self.size(dim)"
inp = inp.contiguous()
index = index.contiguous()
out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)

inp_strided = restride_dim(inp, dim, index.shape)
# FIXME: Are there any other way to get the "flatten offset" of a tensor?
idx = torch.arange(0, index.numel(), device=inp.device).reshape(index.shape)
# Temporarily call offsetCalculator() outside the block(although it can actually proceed in parallel),
# because the triton jit.function cannot accept Tuple as input in version 2.2.0(in 3.0.0, it's available),
# and we do need **the whole stride[]** to accomplish this calculation!
# FIXME: If stride[] can be wholely passed to triton jit.function, we can do this calculation in the kernel
# so that the offset calculation can proceed in parallel
inp_offsets = offset_calculator(inp_strided, idx, inp.stride(), dim, isInp=True)
idx_offsets = offset_calculator(index, idx, index.stride(), dim, isInp=False)
N = list(index.shape)[index.ndim - 1]
M = index.numel() // N

grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
gather_kernel[grid](
inp, inp_offsets, out, index, idx_offsets, M, N, inp.stride(dim)
)
return out


def gather_out(inp, dim, index, sparse_grad=False, out=None):
logging.debug("GEMS GATHER OUT")
assert (
inp.ndim == index.ndim and inp.ndim == out.ndim
), "self, index and out (if it is a Tensor) should all have the same number of dimensions"
assert (
(0 <= index.size(i) and index.size(i) <= out.size(i))
for i in range(0, index.ndim)
), "index.size(d) <= out.size(d) for all dimensions d"
assert (
((0 <= index.size(i) and index.size(i) <= inp.size(i)) or i == dim)
for i in range(0, index.ndim)
), "index.size(d) <= self.size(d) for all dimensions d != dim"
assert index.shape == out.shape, "out will have the same shape as index"
assert ((0 <= index) * (index < inp.size(dim))).equal(
torch.ones(tuple(index.shape), dtype=torch.bool, device=inp.device)
), "0 <= index < self.size(dim)"
if out is None:
out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
inp = inp.contiguous()
index = index.contiguous()
out = out.contiguous()

inp_strided = restride_dim(inp, dim, index.shape)
# FIXME: Are there any other way to get the "flatten offset" of a tensor?
idx = torch.arange(0, index.numel(), device=inp.device).reshape(index.shape)
# Temporarily call offsetCalculator() outside the block(although it can actually proceed in parallel),
# because the triton jit.function cannot accept Tuple as input in version 2.2.0(in 3.0.0, it's available),
# and we do need **the whole stride[]** to accomplish this calculation!
# FIXME: If stride[] can be wholely passed to triton jit.function, we can do this calculation in the kernel
# so that the offset calculation can proceed in parallel
inp_offsets = offset_calculator(inp_strided, idx, inp.stride(), dim, isInp=True)
idx_offsets = offset_calculator(index, idx, index.stride(), dim, isInp=False)
N = list(index.shape)[index.ndim - 1]
M = index.numel() // N

grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
gather_kernel[grid](
inp, inp_offsets, out, index, idx_offsets, M, N, inp.stride(dim)
)
return out
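The FIXME comments above explain why the per-element offsets are precomputed on the host side: the Triton 2.2.0 JIT cannot take the whole stride tuple as a kernel argument. The actual offset_calculator is a Triton kernel added to the utils package and is not shown in this diff; the snippet below is only a rough, hypothetical pure-PyTorch illustration of the per-element input offsets it appears to produce, i.e. the memory offset of every index position with the dim coordinate left out, which the gather kernel then completes with cur_index * stride_dim:

import torch

def reference_input_offsets(index_shape, inp_strides, dim, device="cuda"):
    # Hypothetical reference only: decompose each flat position of `index`
    # into per-dimension coordinates and accumulate coord * stride for every
    # dimension except `dim` (that contribution is added inside the kernel
    # from the loaded index values).
    numel = 1
    for s in index_shape:
        numel *= s
    flat = torch.arange(numel, device=device)
    offsets = torch.zeros_like(flat)
    remaining = flat
    for d in reversed(range(len(index_shape))):
        coord = remaining % index_shape[d]
        remaining = remaining // index_shape[d]
        if d != dim:
            offsets = offsets + coord * inp_strides[d]
    return offsets.reshape(index_shape)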
216 changes: 216 additions & 0 deletions src/flag_gems/ops/scatter.py
@@ -0,0 +1,216 @@
import logging

import torch
import triton
import triton.language as tl

from ..utils import libentry, offset_calculator, restride_dim


def cfggen():
    block_m = [1, 2, 4, 8]
    configs = [
        triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m
    ]
    return configs


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def scatter_kernel(
    inp_offsets,
    src,
    index,
    idx_offsets,
    out,
    M,
    N,
    stride_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(0)
    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M

    for off in range(0, N, BLOCK_N):
        cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
        cols_mask = cols_offsets < N

        offsets = rows_offsets * N + cols_offsets
        mask = rows_mask and cols_mask

        inp_indices = tl.load(inp_offsets + offsets, mask=mask, other=0)
        idx_indices = tl.load(idx_offsets + offsets, mask=mask, other=0)

        cur_src = tl.load(src + idx_indices, mask=mask, other=0)
        cur_index = tl.load(index + idx_indices, mask=mask, other=0)

        inp_indices += cur_index * stride_dim
        tl.store(out + inp_indices, cur_src, mask=mask)


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def scatter_add_kernel(
    inp,
    inp_offsets,
    src,
    index,
    idx_offsets,
    out,
    M,
    N,
    stride_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(0)
    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M

    for off in range(0, N, BLOCK_N):
        cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
        cols_mask = cols_offsets < N

        offsets = rows_offsets * N + cols_offsets
        mask = rows_mask and cols_mask

        inp_indices = tl.load(inp_offsets + offsets, mask=mask, other=0)
        idx_indices = tl.load(idx_offsets + offsets, mask=mask, other=0)

        cur_src = tl.load(src + idx_indices, mask=mask, other=0).to(tl.float32)
        cur_index = tl.load(index + idx_indices, mask=mask, other=0)

        inp_indices += cur_index * stride_dim
        cur_inp = tl.load(inp + inp_indices, mask=mask, other=0).to(tl.float32)
        res = cur_inp + cur_src
        tl.store(out + inp_indices, res, mask=mask)


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def scatter_mul_kernel(
    inp,
    inp_offsets,
    src,
    index,
    idx_offsets,
    out,
    M,
    N,
    stride_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(0)
    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M

    for off in range(0, N, BLOCK_N):
        cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
        cols_mask = cols_offsets < N

        offsets = rows_offsets * N + cols_offsets
        mask = rows_mask and cols_mask

        inp_indices = tl.load(inp_offsets + offsets, mask=mask, other=0)
        idx_indices = tl.load(idx_offsets + offsets, mask=mask, other=0)

        cur_src = tl.load(src + idx_indices, mask=mask, other=0).to(tl.float32)
        cur_index = tl.load(index + idx_indices, mask=mask, other=0)

        inp_indices += cur_index * stride_dim
        cur_inp = tl.load(inp + inp_indices, mask=mask, other=0).to(tl.float32)
        res = cur_inp * cur_src
        tl.store(out + inp_indices, res, mask=mask)


def scatter(inp, dim, index, src, reduction=None):
    assert (
        inp.ndim == index.ndim and inp.ndim == src.ndim
    ), "self, index and src (if it is a Tensor) should all have the same number of dimensions"
    assert (
        (0 <= index.size(i) and index.size(i) <= src.size(i))
        for i in range(0, index.ndim)
    ), "index.size(d) <= src.size(d) for all dimensions d"
    assert (
        ((0 <= index.size(i) and index.size(i) <= inp.size(i)) or i == dim)
        for i in range(0, index.ndim)
    ), "index.size(d) <= self.size(d) for all dimensions d != dim"
    inp = inp.contiguous()
    index = index.contiguous()
    src = src.contiguous()
    out = inp.clone()

    src_strided = src.as_strided(index.shape, src.stride()).contiguous()
    inp_strided = restride_dim(inp, dim, index.shape)
    # FIXME: Is there any other way to get the "flatten offset" of a tensor?
    idx = torch.arange(0, index.numel(), device=inp.device).reshape(index.shape)
    # Temporarily call offsetCalculator() outside the kernel (although it could actually proceed in parallel),
    # because the triton jit.function cannot accept a Tuple as input in version 2.2.0 (in 3.0.0 it is available),
    # and we do need **the whole stride[]** to accomplish this calculation!
    # FIXME: If stride[] can be passed wholly to the triton jit.function, we can do this calculation in the kernel
    # so that the offset calculation can proceed in parallel
    inp_offsets = offset_calculator(inp_strided, idx, inp.stride(), dim, isInp=True)
    idx_offsets = offset_calculator(index, idx, index.stride(), dim, isInp=False)
    N = list(index.shape)[index.ndim - 1]
    M = index.numel() // N

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    if reduction is None:
        scatter_kernel[grid](
            inp_offsets,
            src_strided,
            index,
            idx_offsets,
            out,
            M,
            N,
            inp.stride(dim),
        )
    elif reduction == "add":
        scatter_add_kernel[grid](
            inp,
            inp_offsets,
            src_strided,
            index,
            idx_offsets,
            out,
            M,
            N,
            inp.stride(dim),
        )
    elif reduction == "multiply":
        scatter_mul_kernel[grid](
            inp,
            inp_offsets,
            src_strided,
            index,
            idx_offsets,
            out,
            M,
            N,
            inp.stride(dim),
        )
    return out


def scatter_src(inp, dim, index, src):
logging.debug("GEMS SCATTER SRC")
return scatter(inp, dim, index, src)


def scatter_reduce(inp, dim, index, src, reduce):
logging.debug("GEMS SCATTER REDUCE")
# TODO: As is shown in PyTorch's document(torch.Tensor.scatter_reduce_),
# this function is still **in beta** and may change in the near future.
# So for now, we're just going to stick with the original "add" and "multiply" parameters.
# Maybe we can add reduction options like "sum", "prod", "mean", "amax" and "amin" in the future.
if reduce == "add":
return scatter(inp, dim, index, src, reduction="add")
elif reduce == "multiply":
return scatter(inp, dim, index, src, reduction="multiply")