[Operator] scatter & gather #96
Merged
Changes from all commits (25 commits)
All commits are authored by GwokHiujin:
6006778  [fix] Temporarily use upcasting to make prod support bf16
0a66f15  Merge branch 'master' of github.com:FlagOpen/FlagGems
49aacfb  Merge branch 'master' of github.com:FlagOpen/FlagGems
31ff2fe  [Operator] add select, scatter, gather
7638b50  [fix] del unnecessary offsets calculation
6285346  [fix] src shares the same unravel indices with index but different of…
0b008e2  [Fix] remove useless src_indices
8be0af4  [fix] remove unnecessary indices calculation in gather
70b7771  [Operator] add assertion to gather
a81463c  [Operator] add assertion to gather.out
92268d4  [fix] fix the comparison between 2 BoolType tensors
cae80f1  Merge branch 'master' of github.com:FlagOpen/FlagGems into scatter_ga…
ea760fc  [fix] use inp's device to initialize tensors
3312acf  Merge branch 'master' into scatter_gather
9884441  [Operator] remove select op
fc10444  Merge remote-tracking branch 'origin/master' into scatter_gather
6cbaabf  [chore] Add offset_calculator kernel
2c907ea  [fix] Add offset_cal_kernel to utils.init
bf5c582  Merge branch 'master' into scatter_gather
d3c0461  [chore] remove the previous offsets calculator in utils package's ini…
b26a377  [chore] Considering perf, pause the replacement of the aTen operator …
3e96bca  [fix] Use ops.scatter/gather instead in the testing
1bf7fad  Merge branch 'master' into scatter_gather
dfaf063  [fix] Remove the upcasting on the ref input data in the scatter&gathe…
a3e43ee  [chore] reformat
@@ -0,0 +1,130 @@
import logging

import torch
import triton
import triton.language as tl

from ..utils import libentry, offset_calculator, restride_dim


def cfggen():
    block_m = [1, 2, 4, 8]
    configs = [
        triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m
    ]
    return configs


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def gather_kernel(
    inp,
    inp_offsets,
    out,
    index,
    idx_offsets,
    M,
    N,
    stride_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(0)
    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M

    for off in range(0, N, BLOCK_N):
        cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
        cols_mask = cols_offsets < N

        offsets = rows_offsets * N + cols_offsets
        mask = rows_mask and cols_mask

        inp_indices = tl.load(inp_offsets + offsets, mask=mask, other=0)
        idx_indices = tl.load(idx_offsets + offsets, mask=mask, other=0)

        cur_index = tl.load(index + idx_indices, mask=mask, other=0)
        inp_indices += cur_index * stride_dim
        cur_inp = tl.load(inp + inp_indices, mask=mask, other=0)

        tl.store(out + idx_indices, cur_inp, mask=mask)


def gather(inp, dim, index, sparse_grad=False):
    logging.debug("GEMS GATHER")
    assert (
        inp.ndim == index.ndim
    ), "self and index should all have the same number of dimensions"
    assert all(
        index.size(i) <= inp.size(i) or i == dim for i in range(index.ndim)
    ), "index.size(d) <= self.size(d) for all dimensions d != dim"
    assert ((0 <= index) * (index < inp.size(dim))).equal(
        torch.ones(tuple(index.shape), dtype=torch.bool, device=inp.device)
    ), "0 <= index < self.size(dim)"
    inp = inp.contiguous()
    index = index.contiguous()
    out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)

    inp_strided = restride_dim(inp, dim, index.shape)
    # FIXME: Is there any other way to get the "flattened offset" of a tensor?
    idx = torch.arange(0, index.numel(), device=inp.device).reshape(index.shape)
    # Temporarily call offset_calculator() outside the kernel (although it could
    # actually proceed in parallel), because a triton jit function cannot accept
    # a Tuple as input in version 2.2.0 (in 3.0.0 it is available), and we do need
    # **the whole stride[]** to accomplish this calculation.
    # FIXME: If stride[] can be passed to the triton jit function as a whole, this
    # calculation can be moved into the kernel so that it proceeds in parallel.
    inp_offsets = offset_calculator(inp_strided, idx, inp.stride(), dim, isInp=True)
    idx_offsets = offset_calculator(index, idx, index.stride(), dim, isInp=False)
    N = list(index.shape)[index.ndim - 1]
    M = index.numel() // N

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    gather_kernel[grid](
        inp, inp_offsets, out, index, idx_offsets, M, N, inp.stride(dim)
    )
    return out


def gather_out(inp, dim, index, sparse_grad=False, out=None):
    logging.debug("GEMS GATHER OUT")
    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
    assert (
        inp.ndim == index.ndim and inp.ndim == out.ndim
    ), "self, index and out (if it is a Tensor) should all have the same number of dimensions"
    assert all(
        index.size(i) <= out.size(i) for i in range(index.ndim)
    ), "index.size(d) <= out.size(d) for all dimensions d"
    assert all(
        index.size(i) <= inp.size(i) or i == dim for i in range(index.ndim)
    ), "index.size(d) <= self.size(d) for all dimensions d != dim"
    assert index.shape == out.shape, "out will have the same shape as index"
    assert ((0 <= index) * (index < inp.size(dim))).equal(
        torch.ones(tuple(index.shape), dtype=torch.bool, device=inp.device)
    ), "0 <= index < self.size(dim)"
    inp = inp.contiguous()
    index = index.contiguous()
    out = out.contiguous()

    inp_strided = restride_dim(inp, dim, index.shape)
    # FIXME: Is there any other way to get the "flattened offset" of a tensor?
    idx = torch.arange(0, index.numel(), device=inp.device).reshape(index.shape)
    # Temporarily call offset_calculator() outside the kernel (although it could
    # actually proceed in parallel), because a triton jit function cannot accept
    # a Tuple as input in version 2.2.0 (in 3.0.0 it is available), and we do need
    # **the whole stride[]** to accomplish this calculation.
    # FIXME: If stride[] can be passed to the triton jit function as a whole, this
    # calculation can be moved into the kernel so that it proceeds in parallel.
    inp_offsets = offset_calculator(inp_strided, idx, inp.stride(), dim, isInp=True)
    idx_offsets = offset_calculator(index, idx, index.stride(), dim, isInp=False)
    N = list(index.shape)[index.ndim - 1]
    M = index.numel() // N

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    gather_kernel[grid](
        inp, inp_offsets, out, index, idx_offsets, M, N, inp.stride(dim)
    )
    return out
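For reference, here is a minimal usage sketch of the gather implementation above. The import path is an assumption made for illustration (the actual FlagGems module layout may differ); it simply checks the Triton-backed gather against torch.gather on a small CUDA tensor.

import torch
# Assumed import path, for illustration only; adjust to the real module layout.
from flag_gems.ops.gather import gather

inp = torch.randn(4, 8, device="cuda")
index = torch.randint(0, 8, (4, 8), device="cuda")

out = gather(inp, 1, index)        # Triton-backed gather from this PR
ref = torch.gather(inp, 1, index)  # eager PyTorch reference
assert torch.equal(out, ref)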
@@ -0,0 +1,216 @@
import logging

import torch
import triton
import triton.language as tl

from ..utils import libentry, offset_calculator, restride_dim


def cfggen():
    block_m = [1, 2, 4, 8]
    configs = [
        triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m
    ]
    return configs


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def scatter_kernel(
    inp_offsets,
    src,
    index,
    idx_offsets,
    out,
    M,
    N,
    stride_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(0)
    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M

    for off in range(0, N, BLOCK_N):
        cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
        cols_mask = cols_offsets < N

        offsets = rows_offsets * N + cols_offsets
        mask = rows_mask and cols_mask

        inp_indices = tl.load(inp_offsets + offsets, mask=mask, other=0)
        idx_indices = tl.load(idx_offsets + offsets, mask=mask, other=0)

        cur_src = tl.load(src + idx_indices, mask=mask, other=0)
        cur_index = tl.load(index + idx_indices, mask=mask, other=0)

        inp_indices += cur_index * stride_dim
        tl.store(out + inp_indices, cur_src, mask=mask)


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def scatter_add_kernel(
    inp,
    inp_offsets,
    src,
    index,
    idx_offsets,
    out,
    M,
    N,
    stride_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(0)
    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M

    for off in range(0, N, BLOCK_N):
        cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
        cols_mask = cols_offsets < N

        offsets = rows_offsets * N + cols_offsets
        mask = rows_mask and cols_mask

        inp_indices = tl.load(inp_offsets + offsets, mask=mask, other=0)
        idx_indices = tl.load(idx_offsets + offsets, mask=mask, other=0)

        cur_src = tl.load(src + idx_indices, mask=mask, other=0).to(tl.float32)
        cur_index = tl.load(index + idx_indices, mask=mask, other=0)

        inp_indices += cur_index * stride_dim
        cur_inp = tl.load(inp + inp_indices, mask=mask, other=0).to(tl.float32)
        res = cur_inp + cur_src
        tl.store(out + inp_indices, res, mask=mask)


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def scatter_mul_kernel(
    inp,
    inp_offsets,
    src,
    index,
    idx_offsets,
    out,
    M,
    N,
    stride_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(0)
    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M

    for off in range(0, N, BLOCK_N):
        cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
        cols_mask = cols_offsets < N

        offsets = rows_offsets * N + cols_offsets
        mask = rows_mask and cols_mask

        inp_indices = tl.load(inp_offsets + offsets, mask=mask, other=0)
        idx_indices = tl.load(idx_offsets + offsets, mask=mask, other=0)

        cur_src = tl.load(src + idx_indices, mask=mask, other=0).to(tl.float32)
        cur_index = tl.load(index + idx_indices, mask=mask, other=0)

        inp_indices += cur_index * stride_dim
        cur_inp = tl.load(inp + inp_indices, mask=mask, other=0).to(tl.float32)
        res = cur_inp * cur_src
        tl.store(out + inp_indices, res, mask=mask)


def scatter(inp, dim, index, src, reduction=None):
    assert (
        inp.ndim == index.ndim and inp.ndim == src.ndim
    ), "self, index and src (if it is a Tensor) should all have the same number of dimensions"
    assert all(
        index.size(i) <= src.size(i) for i in range(index.ndim)
    ), "index.size(d) <= src.size(d) for all dimensions d"
    assert all(
        index.size(i) <= inp.size(i) or i == dim for i in range(index.ndim)
    ), "index.size(d) <= self.size(d) for all dimensions d != dim"
    inp = inp.contiguous()
    index = index.contiguous()
    src = src.contiguous()
    out = inp.clone()

    src_strided = src.as_strided(index.shape, src.stride()).contiguous()
    inp_strided = restride_dim(inp, dim, index.shape)
    # FIXME: Is there any other way to get the "flattened offset" of a tensor?
    idx = torch.arange(0, index.numel(), device=inp.device).reshape(index.shape)
    # Temporarily call offset_calculator() outside the kernel (although it could
    # actually proceed in parallel), because a triton jit function cannot accept
    # a Tuple as input in version 2.2.0 (in 3.0.0 it is available), and we do need
    # **the whole stride[]** to accomplish this calculation.
    # FIXME: If stride[] can be passed to the triton jit function as a whole, this
    # calculation can be moved into the kernel so that it proceeds in parallel.
    inp_offsets = offset_calculator(inp_strided, idx, inp.stride(), dim, isInp=True)
    idx_offsets = offset_calculator(index, idx, index.stride(), dim, isInp=False)
    N = list(index.shape)[index.ndim - 1]
    M = index.numel() // N

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    if reduction is None:
        scatter_kernel[grid](
            inp_offsets,
            src_strided,
            index,
            idx_offsets,
            out,
            M,
            N,
            inp.stride(dim),
        )
    elif reduction == "add":
        scatter_add_kernel[grid](
            inp,
            inp_offsets,
            src_strided,
            index,
            idx_offsets,
            out,
            M,
            N,
            inp.stride(dim),
        )
    elif reduction == "multiply":
        scatter_mul_kernel[grid](
            inp,
            inp_offsets,
            src_strided,
            index,
            idx_offsets,
            out,
            M,
            N,
            inp.stride(dim),
        )
    return out


def scatter_src(inp, dim, index, src):
    logging.debug("GEMS SCATTER SRC")
    return scatter(inp, dim, index, src)


def scatter_reduce(inp, dim, index, src, reduce):
    logging.debug("GEMS SCATTER REDUCE")
    # TODO: As noted in PyTorch's documentation (torch.Tensor.scatter_reduce_),
    # this function is still **in beta** and may change in the near future.
    # For now we stick with the original "add" and "multiply" parameters.
    # Reduction options like "sum", "prod", "mean", "amax" and "amin" may be added later.
    if reduce == "add":
        return scatter(inp, dim, index, src, reduction="add")
    elif reduce == "multiply":
        return scatter(inp, dim, index, src, reduction="multiply")
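Similarly, a small sanity-check sketch for the scatter implementation above (again, the import path is assumed for illustration). The index is built without duplicate targets along dim so both the plain and reduced variants are deterministic and can be compared directly with the eager PyTorch references.

import torch
# Assumed import path, for illustration only; adjust to the real module layout.
from flag_gems.ops.scatter import scatter

inp = torch.zeros(4, 8, device="cuda")
src = torch.randn(4, 8, device="cuda")
# One permutation per row: no duplicate targets along dim=1.
index = torch.stack([torch.randperm(8, device="cuda") for _ in range(4)])

out = scatter(inp, 1, index, src)                        # plain scatter
ref = inp.scatter(1, index, src)
assert torch.equal(out, ref)

out_add = scatter(inp, 1, index, src, reduction="add")   # reduction variant
ref_add = inp.scatter_add(1, index, src)
assert torch.allclose(out_add, ref_add)

Note that, unlike torch.Tensor.scatter_, these wrappers return a new tensor (out = inp.clone()) rather than mutating inp in place.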
Review comment: It seems that if idx is always passed as a trivial iterator, it may not need to be materialized.
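To illustrate that remark: since idx is just torch.arange(index.numel()) reshaped to index.shape, the per-element memory offsets could in principle be derived straight from the flat positions and the strides, without materializing and reshaping an idx tensor at all. The sketch below is only an illustration of that idea, not the PR's offset_calculator (whose actual signature and behaviour live in the utils package).

import torch

def flat_offsets(shape, strides, device, skip_dim=None):
    # Unravel each flat position into per-dimension coordinates and dot them
    # with the strides; skip_dim is left out because the kernels add
    # cur_index * stride_dim for that dimension themselves.
    numel = 1
    for s in shape:
        numel *= s
    flat = torch.arange(numel, device=device)
    offsets = torch.zeros(numel, dtype=torch.int64, device=device)
    for d in range(len(shape) - 1, -1, -1):
        coord = flat % shape[d]
        flat = flat // shape[d]
        if d != skip_dim:
            offsets += coord * strides[d]
    return offsets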