Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Operator] scatter & gather #96

Merged
merged 25 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
6006778
[fix] Temporarily use upcasting to make prod support bf16
GwokHiujin May 27, 2024
0a66f15
Merge branch 'master' of github.com:FlagOpen/FlagGems
GwokHiujin Jun 18, 2024
49aacfb
Merge branch 'master' of github.com:FlagOpen/FlagGems
GwokHiujin Jul 5, 2024
31ff2fe
[Operator] add select, scatter, gather
GwokHiujin Jul 5, 2024
7638b50
[fix] del unnecessary offsets calculation
GwokHiujin Jul 9, 2024
6285346
[fix] src shares the same unravel indices with index but different of…
GwokHiujin Jul 9, 2024
0b008e2
[Fix] remove useless src_indices
GwokHiujin Jul 9, 2024
8be0af4
[fix] remove unnecessary indices calculation in gather
GwokHiujin Jul 11, 2024
70b7771
[Operator] add assertion to gather
GwokHiujin Jul 12, 2024
a81463c
[Operator] add assertion to gather.out
GwokHiujin Jul 12, 2024
92268d4
[fix] fix the comparison between 2 BoolType tensors
GwokHiujin Jul 12, 2024
cae80f1
Merge branch 'master' of github.com:FlagOpen/FlagGems into scatter_ga…
GwokHiujin Jul 17, 2024
ea760fc
[fix] use inp's device to initialize tensors
GwokHiujin Jul 17, 2024
3312acf
Merge branch 'master' into scatter_gather
GwokHiujin Jul 18, 2024
9884441
[Operator] remove select op
GwokHiujin Jul 19, 2024
fc10444
Merge remote-tracking branch 'origin/master' into scatter_gather
GwokHiujin Aug 9, 2024
6cbaabf
[chore] Add offset_calculator kernel
GwokHiujin Aug 9, 2024
2c907ea
[fix] Add offset_cal_kernel to utils.init
GwokHiujin Aug 9, 2024
bf5c582
Merge branch 'master' into scatter_gather
GwokHiujin Aug 12, 2024
d3c0461
[chore] remove the previous offsets calculator in utils package's ini…
GwokHiujin Aug 12, 2024
b26a377
[chore] Considering perf, pause the replacement of the aTen operator …
GwokHiujin Aug 19, 2024
3e96bca
[fix] Use ops.scatter/gather instead in the testing
GwokHiujin Aug 19, 2024
1bf7fad
Merge branch 'master' into scatter_gather
GwokHiujin Aug 19, 2024
dfaf063
[fix] Remove the upcasting on the ref input data in the scatter&gathe…
GwokHiujin Aug 19, 2024
a3e43ee
[chore] reformat
GwokHiujin Aug 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/flag_gems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ def enable(lib=aten_lib):
lib.impl("log_softmax.int", log_softmax, "AutogradCUDA")
lib.impl("outer", outer, "AutogradCUDA")
lib.impl("cross_entropy_loss", cross_entropy_loss, "AutogradCUDA")
# lib.impl("scatter.src", scatter_src, "CUDA")
# lib.impl("scatter.reduce", scatter_reduce, "CUDA")
# lib.impl("gather", gather, "CUDA")
# lib.impl("gather.out", gather_out, "CUDA")
lib.impl("isclose", isclose, "CUDA")
lib.impl("allclose", allclose, "CUDA")
lib.impl("flip", flip, "CUDA")
Expand Down
6 changes: 6 additions & 0 deletions src/flag_gems/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .flip import flip
from .full import full
from .full_like import full_like
from .gather import gather, gather_out
from .ge import ge, ge_scalar
from .gelu import gelu
from .groupnorm import group_norm
Expand Down Expand Up @@ -71,6 +72,7 @@
from .resolve_neg import resolve_neg
from .rms_norm import rms_norm
from .rsqrt import rsqrt
from .scatter import scatter_reduce, scatter_src
from .sigmoid import sigmoid
from .silu import silu
from .sin import sin
Expand Down Expand Up @@ -121,6 +123,8 @@
"eq_scalar",
"exp",
"exponential_",
"gather",
"gather_out",
"flip",
"ones_like",
"full_like",
Expand Down Expand Up @@ -168,6 +172,8 @@
"reciprocal",
"relu",
"rsqrt",
"scatter_src",
"scatter_reduce",
"sigmoid",
"silu",
"sin",
Expand Down
130 changes: 130 additions & 0 deletions src/flag_gems/ops/gather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import logging

import torch
import triton
import triton.language as tl

from ..utils import libentry, offset_calculator, restride_dim


def cfggen():
    """Return the autotuning candidates for the gather kernel.

    One ``triton.Config`` is produced per ``BLOCK_M`` in ``(1, 2, 4, 8)``;
    every candidate uses ``BLOCK_N = 1024`` and ``num_warps = 4``.
    """
    return [
        triton.Config({"BLOCK_M": rows_per_program, "BLOCK_N": 1024}, num_warps=4)
        for rows_per_program in (1, 2, 4, 8)
    ]


# Gather kernel.
#
# The index tensor is viewed as an (M, N) matrix. Each program handles
# BLOCK_M rows and walks the N columns in chunks of BLOCK_N.
#
#   inp          pointer to the source data
#   inp_offsets  per-element base offset into ``inp`` (flat, row * N + col)
#   out          pointer to the destination data
#   index        pointer to the gather indices
#   idx_offsets  per-element offset into ``index`` and ``out`` (flat, row * N + col)
#   M, N         logical matrix extents
#   stride_dim   stride of ``inp`` along the gathered dimension
@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def gather_kernel(
    inp,
    inp_offsets,
    out,
    index,
    idx_offsets,
    M,
    N,
    stride_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(0)
    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M

    for off in range(0, N, BLOCK_N):
        cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
        cols_mask = cols_offsets < N

        offsets = rows_offsets * N + cols_offsets
        # Use the element-wise `&`: Python's `and` on non-scalar tensors is
        # deprecated in Triton.
        mask = rows_mask & cols_mask

        inp_indices = tl.load(inp_offsets + offsets, mask=mask, other=0)
        idx_indices = tl.load(idx_offsets + offsets, mask=mask, other=0)

        # Move the base offset along the gathered dimension by the index value.
        cur_index = tl.load(index + idx_indices, mask=mask, other=0)
        inp_indices += cur_index * stride_dim
        cur_inp = tl.load(inp + inp_indices, mask=mask, other=0)

        tl.store(out + idx_indices, cur_inp, mask=mask)


def gather(inp, dim, index, sparse_grad=False):
    """Gather values from ``inp`` along ``dim`` at the positions in ``index``.

    Args:
        inp: Source tensor.
        dim: Dimension along which to index.
        index: Integer tensor with the same number of dimensions as ``inp``;
            every value must satisfy ``0 <= index < inp.size(dim)``.
        sparse_grad: Accepted for signature compatibility; not used here.

    Returns:
        A new tensor with the shape of ``index`` and the dtype of ``inp``.

    Raises:
        AssertionError: If the rank, shape or value constraints are violated.
    """
    logging.debug("GEMS GATHER")
    assert (
        inp.ndim == index.ndim
    ), "self and index should all have the same number of dimensions"
    # Wrap a negative ``dim`` for the shape check only, so that ``dim=-1``
    # is exempted from the size constraint like its positive equivalent.
    check_dim = dim + inp.ndim if dim < 0 else dim
    # ``all(...)`` is required: asserting on a bare generator is always true.
    assert all(
        index.size(i) <= inp.size(i) or i == check_dim for i in range(index.ndim)
    ), "index.size(d) <= self.size(d) for all dimensions d != dim"
    assert bool(
        ((0 <= index) & (index < inp.size(dim))).all()
    ), "0 <= index < self.size(dim)"
    inp = inp.contiguous()
    index = index.contiguous()
    out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
    if index.numel() == 0:
        # Nothing to gather; also avoids dividing by N == 0 below.
        return out

    inp_strided = restride_dim(inp, dim, index.shape)
    # FIXME: Is there another way to get the "flattened offset" of a tensor?
    # NOTE(review): idx is always a trivial iterator (an arange), so it may
    # not need to be materialized at all.
    idx = torch.arange(0, index.numel(), device=inp.device).reshape(index.shape)
    # offset_calculator() is called on the host for now, although the work
    # could run in parallel: a triton jit function cannot accept a tuple in
    # version 2.2.0 (it can in 3.0.0), and the whole stride[] is needed.
    # FIXME: Once stride[] can be passed to a jit function as a whole, do this
    # calculation inside the kernel.
    inp_offsets = offset_calculator(inp_strided, idx, inp.stride(), dim, isInp=True)
    idx_offsets = offset_calculator(index, idx, index.stride(), dim, isInp=False)
    # View the index as an (M, N) matrix whose columns are the last dimension.
    N = index.shape[-1]
    M = index.numel() // N

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    gather_kernel[grid](
        inp, inp_offsets, out, index, idx_offsets, M, N, inp.stride(dim)
    )
    return out


def gather_out(inp, dim, index, sparse_grad=False, out=None):
    """Gather values from ``inp`` along ``dim`` into ``out``.

    Args:
        inp: Source tensor.
        dim: Dimension along which to index.
        index: Integer tensor with the same number of dimensions as ``inp``;
            every value must satisfy ``0 <= index < inp.size(dim)``.
        sparse_grad: Accepted for signature compatibility; not used here.
        out: Optional destination tensor with the same shape as ``index``.
            When ``None`` a new tensor with the dtype of ``inp`` is allocated.

    Returns:
        ``out`` (or the newly allocated tensor) filled with the gathered
        values.

    Raises:
        AssertionError: If the rank, shape or value constraints are violated.
    """
    logging.debug("GEMS GATHER OUT")
    # ``out`` is optional, so only validate it when the caller supplied one.
    assert inp.ndim == index.ndim and (
        out is None or inp.ndim == out.ndim
    ), "self, index and out (if it is a Tensor) should all have the same number of dimensions"
    if out is not None:
        assert all(
            index.size(i) <= out.size(i) for i in range(index.ndim)
        ), "index.size(d) <= out.size(d) for all dimensions d"
    # Wrap a negative ``dim`` for the shape check only.
    check_dim = dim + inp.ndim if dim < 0 else dim
    # ``all(...)`` is required: asserting on a bare generator is always true.
    assert all(
        index.size(i) <= inp.size(i) or i == check_dim for i in range(index.ndim)
    ), "index.size(d) <= self.size(d) for all dimensions d != dim"
    if out is not None:
        assert index.shape == out.shape, "out will have the same shape as index"
    assert bool(
        ((0 <= index) & (index < inp.size(dim))).all()
    ), "0 <= index < self.size(dim)"
    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
    if index.numel() == 0:
        # Nothing to gather; also avoids dividing by N == 0 below.
        return out
    inp = inp.contiguous()
    index = index.contiguous()
    # The kernel writes through offsets derived from the contiguous index, so
    # a non-contiguous ``out`` is filled via a dense scratch buffer and copied
    # back; calling ``out.contiguous()`` would leave the caller's tensor stale.
    dst = (
        out
        if out.is_contiguous()
        else torch.empty(out.shape, dtype=out.dtype, device=out.device)
    )

    inp_strided = restride_dim(inp, dim, index.shape)
    # FIXME: Is there another way to get the "flattened offset" of a tensor?
    idx = torch.arange(0, index.numel(), device=inp.device).reshape(index.shape)
    # offset_calculator() is called on the host for now, although the work
    # could run in parallel: a triton jit function cannot accept a tuple in
    # version 2.2.0 (it can in 3.0.0), and the whole stride[] is needed.
    # FIXME: Once stride[] can be passed to a jit function as a whole, do this
    # calculation inside the kernel.
    inp_offsets = offset_calculator(inp_strided, idx, inp.stride(), dim, isInp=True)
    idx_offsets = offset_calculator(index, idx, index.stride(), dim, isInp=False)
    # View the index as an (M, N) matrix whose columns are the last dimension.
    N = index.shape[-1]
    M = index.numel() // N

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    gather_kernel[grid](
        inp, inp_offsets, dst, index, idx_offsets, M, N, inp.stride(dim)
    )
    if dst is not out:
        out.copy_(dst)
    return out
216 changes: 216 additions & 0 deletions src/flag_gems/ops/scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
import logging

import torch
import triton
import triton.language as tl

from ..utils import libentry, offset_calculator, restride_dim


def cfggen():
    """Return the autotuning candidates shared by the scatter kernels.

    Each candidate fixes ``BLOCK_N`` at 1024 and ``num_warps`` at 4, and
    varies ``BLOCK_M`` over 1, 2, 4 and 8.
    """
    candidates = []
    for rows_per_program in (1, 2, 4, 8):
        tile = {"BLOCK_M": rows_per_program, "BLOCK_N": 1024}
        candidates.append(triton.Config(tile, num_warps=4))
    return candidates


# Plain scatter kernel (no reduction).
#
# The index tensor is viewed as an (M, N) matrix. Each program handles
# BLOCK_M rows and walks the N columns in chunks of BLOCK_N.
#
#   inp_offsets  per-element base offset into ``out`` (flat, row * N + col)
#   src          pointer to the values to write
#   index        pointer to the scatter indices
#   idx_offsets  per-element offset into ``src`` and ``index`` (flat, row * N + col)
#   out          pointer to the destination data
#   M, N         logical matrix extents
#   stride_dim   stride of the destination along the scattered dimension
@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def scatter_kernel(
    inp_offsets,
    src,
    index,
    idx_offsets,
    out,
    M,
    N,
    stride_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(0)
    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M

    for off in range(0, N, BLOCK_N):
        cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
        cols_mask = cols_offsets < N

        offsets = rows_offsets * N + cols_offsets
        # Use the element-wise `&`: Python's `and` on non-scalar tensors is
        # deprecated in Triton.
        mask = rows_mask & cols_mask

        inp_indices = tl.load(inp_offsets + offsets, mask=mask, other=0)
        idx_indices = tl.load(idx_offsets + offsets, mask=mask, other=0)

        cur_src = tl.load(src + idx_indices, mask=mask, other=0)
        cur_index = tl.load(index + idx_indices, mask=mask, other=0)

        # Move the base offset along the scattered dimension by the index value.
        inp_indices += cur_index * stride_dim
        tl.store(out + inp_indices, cur_src, mask=mask)


# Scatter kernel with "add" reduction.
#
# Same tiling and arguments as the plain scatter kernel, plus ``inp``: the
# pointer the current destination value is read from before adding ``src``.
#
# NOTE(review): the update is a plain load from ``inp`` followed by a plain
# store to ``out`` (no atomics), so several index entries that target the same
# destination element do not accumulate -- one of the writes wins. Confirm
# whether callers rely on duplicate indices.
@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def scatter_add_kernel(
    inp,
    inp_offsets,
    src,
    index,
    idx_offsets,
    out,
    M,
    N,
    stride_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(0)
    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M

    for off in range(0, N, BLOCK_N):
        cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
        cols_mask = cols_offsets < N

        offsets = rows_offsets * N + cols_offsets
        # Use the element-wise `&`: Python's `and` on non-scalar tensors is
        # deprecated in Triton.
        mask = rows_mask & cols_mask

        inp_indices = tl.load(inp_offsets + offsets, mask=mask, other=0)
        idx_indices = tl.load(idx_offsets + offsets, mask=mask, other=0)

        # Both operands are upcast to float32 before the addition.
        cur_src = tl.load(src + idx_indices, mask=mask, other=0).to(tl.float32)
        cur_index = tl.load(index + idx_indices, mask=mask, other=0)

        # Move the base offset along the scattered dimension by the index value.
        inp_indices += cur_index * stride_dim
        cur_inp = tl.load(inp + inp_indices, mask=mask, other=0).to(tl.float32)
        res = cur_inp + cur_src
        tl.store(out + inp_indices, res, mask=mask)


# Scatter kernel with "multiply" reduction.
#
# Same tiling and arguments as the plain scatter kernel, plus ``inp``: the
# pointer the current destination value is read from before multiplying by
# ``src``.
#
# NOTE(review): the update is a plain load from ``inp`` followed by a plain
# store to ``out`` (no atomics), so several index entries that target the same
# destination element do not combine -- one of the writes wins. Confirm
# whether callers rely on duplicate indices.
@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def scatter_mul_kernel(
    inp,
    inp_offsets,
    src,
    index,
    idx_offsets,
    out,
    M,
    N,
    stride_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(0)
    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M

    for off in range(0, N, BLOCK_N):
        cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
        cols_mask = cols_offsets < N

        offsets = rows_offsets * N + cols_offsets
        # Use the element-wise `&`: Python's `and` on non-scalar tensors is
        # deprecated in Triton.
        mask = rows_mask & cols_mask

        inp_indices = tl.load(inp_offsets + offsets, mask=mask, other=0)
        idx_indices = tl.load(idx_offsets + offsets, mask=mask, other=0)

        # Both operands are upcast to float32 before the multiplication.
        cur_src = tl.load(src + idx_indices, mask=mask, other=0).to(tl.float32)
        cur_index = tl.load(index + idx_indices, mask=mask, other=0)

        # Move the base offset along the scattered dimension by the index value.
        inp_indices += cur_index * stride_dim
        cur_inp = tl.load(inp + inp_indices, mask=mask, other=0).to(tl.float32)
        res = cur_inp * cur_src
        tl.store(out + inp_indices, res, mask=mask)


def scatter(inp, dim, index, src, reduction=None):
    """Scatter ``src`` into a copy of ``inp`` along ``dim`` at ``index``.

    Args:
        inp: Tensor providing the initial contents of the result.
        dim: Dimension along which to index.
        index: Integer tensor with the same number of dimensions as ``inp``
            and ``src``.
        src: Tensor holding the values to scatter.
        reduction: ``None`` to overwrite, ``"add"`` to add ``src`` to the
            existing value, or ``"multiply"`` to multiply by it.

    Returns:
        A new contiguous tensor; ``inp`` itself is not modified.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
        AssertionError: If the rank or shape constraints are violated.
    """
    if reduction not in (None, "add", "multiply"):
        # Previously an unknown mode silently returned an unmodified copy.
        raise ValueError(
            f"reduction must be None, 'add' or 'multiply', got {reduction!r}"
        )
    assert (
        inp.ndim == index.ndim and inp.ndim == src.ndim
    ), "self, index and src (if it is a Tensor) should all have the same number of dimensions"
    # ``all(...)`` is required: asserting on a bare generator is always true.
    assert all(
        index.size(i) <= src.size(i) for i in range(index.ndim)
    ), "index.size(d) <= src.size(d) for all dimensions d"
    # Wrap a negative ``dim`` for the shape check only.
    check_dim = dim + inp.ndim if dim < 0 else dim
    assert all(
        index.size(i) <= inp.size(i) or i == check_dim for i in range(index.ndim)
    ), "index.size(d) <= self.size(d) for all dimensions d != dim"
    inp = inp.contiguous()
    index = index.contiguous()
    src = src.contiguous()
    out = inp.clone()
    if index.numel() == 0:
        # Nothing to scatter; also avoids dividing by N == 0 below.
        return out

    # Restrict ``src`` to the region covered by ``index`` and make it dense.
    src_strided = src.as_strided(index.shape, src.stride()).contiguous()
    inp_strided = restride_dim(inp, dim, index.shape)
    # FIXME: Is there another way to get the "flattened offset" of a tensor?
    idx = torch.arange(0, index.numel(), device=inp.device).reshape(index.shape)
    # offset_calculator() is called on the host for now, although the work
    # could run in parallel: a triton jit function cannot accept a tuple in
    # version 2.2.0 (it can in 3.0.0), and the whole stride[] is needed.
    # FIXME: Once stride[] can be passed to a jit function as a whole, do this
    # calculation inside the kernel.
    inp_offsets = offset_calculator(inp_strided, idx, inp.stride(), dim, isInp=True)
    idx_offsets = offset_calculator(index, idx, index.stride(), dim, isInp=False)
    # View the index as an (M, N) matrix whose columns are the last dimension.
    N = index.shape[-1]
    M = index.numel() // N

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    if reduction is None:
        scatter_kernel[grid](
            inp_offsets,
            src_strided,
            index,
            idx_offsets,
            out,
            M,
            N,
            inp.stride(dim),
        )
    elif reduction == "add":
        scatter_add_kernel[grid](
            inp,
            inp_offsets,
            src_strided,
            index,
            idx_offsets,
            out,
            M,
            N,
            inp.stride(dim),
        )
    else:  # reduction == "multiply"
        scatter_mul_kernel[grid](
            inp,
            inp_offsets,
            src_strided,
            index,
            idx_offsets,
            out,
            M,
            N,
            inp.stride(dim),
        )
    return out


def scatter_src(inp, dim, index, src):
    """Scatter ``src`` into a copy of ``inp`` along ``dim`` without reduction.

    Logs the call and forwards to ``scatter`` with its default reduction.
    """
    logging.debug("GEMS SCATTER SRC")
    result = scatter(inp, dim, index, src)
    return result


def scatter_reduce(inp, dim, index, src, reduce):
    """Scatter ``src`` into a copy of ``inp`` along ``dim`` with a reduction.

    Args:
        inp: Tensor providing the initial contents of the result.
        dim: Dimension along which to index.
        index: Integer tensor selecting the destination positions.
        src: Tensor holding the values to scatter.
        reduce: Either ``"add"`` or ``"multiply"``.

    Returns:
        The tensor produced by ``scatter`` for the requested reduction.

    Raises:
        ValueError: If ``reduce`` is not ``"add"`` or ``"multiply"``.
    """
    logging.debug("GEMS SCATTER REDUCE")
    # TODO: As PyTorch's documentation for torch.Tensor.scatter_reduce_ notes,
    # that function is still in beta and may change. For now only the original
    # "add" and "multiply" options are supported; "sum", "prod", "mean",
    # "amax" and "amin" may be added later.
    if reduce not in ("add", "multiply"):
        # Previously an unsupported mode fell through and returned None.
        raise ValueError(f"reduce must be 'add' or 'multiply', got {reduce!r}")
    return scatter(inp, dim, index, src, reduction=reduce)
Loading