[Operator] scatter & gather #96
Conversation
```python
if d == dim:
    idx_dim = add_on
idx = idx // shape[d]
# FIXME: Should we write a fast div/mod
```
I think we do need a fast div-mod functor, but I wonder if it already exists in Triton?
I checked out Triton's source code and found that its fast division is implemented using NumPy's np.divide, and as far as I know, NumPy's integer division is optimized. So do you think we just need to import NumPy's div/mod operators in this file and use them?
BTW, Triton doesn't have a divmod-like functor that can return the pair (div i b, mod i b).
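For reference, a minimal sketch (plain Python, not an existing Triton helper; the name `div_mod` is hypothetical) of a div-mod pair that spends only one integer division:

```python
def div_mod(i, b):
    # One integer division; the remainder is recovered with a multiply-subtract
    # instead of a second division/modulo.
    q = i // b
    r = i - q * b
    return q, r
```

The same multiply-subtract trick works elementwise on integer tensors, so it could be reused inside a kernel if nothing equivalent is exposed.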
Should we benchmark these ops?
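If we do, a quick latency check could look like the sketch below (illustrative shapes only, not the PR's reported numbers), using Triton's built-in bench helper:

```python
import torch
import triton

# Illustrative latency check for a gather-like op; the eager torch.gather call
# stands in for whichever implementation is being measured.
inp = torch.randn(1024, 1024, device="cuda")
index = torch.randint(0, 1024, (1024, 1024), device="cuda")

ms = triton.testing.do_bench(lambda: torch.gather(inp, 1, index))
print(f"gather latency: {ms:.3f} ms")
```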
```diff
@@ -137,3 +139,30 @@ def dim_compress(inp, dims):
     sorted_reduction_dim = sorted(dims, key=lambda x: stride[x], reverse=True)
     order = batch_dim + sorted_reduction_dim
     return inp.permute(order).contiguous()
+
+
+def offsetCalculator(inp, idx, strides, dim, isInp):
```
Can we document this function to help others understand it?
The offset calculation incurs massive overhead. Let's try to do it in a Triton kernel, shall we?
* Use Triton to do the offset calculations; the perf test results can be seen in the scatter & gather doc.
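For illustration, here is a minimal sketch of what a device-side offset calculation could look like. This is an assumption-laden sketch, not the PR's actual kernel: the function names are hypothetical, the `dim`/`isInp` special-casing of `offsetCalculator` is ignored, and shape/strides are passed as small GPU tensors to sidestep the tuple-argument limitation discussed further down.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def flat_offsets_kernel(idx_ptr, out_ptr, strides_ptr, shape_ptr, n_elements,
                        NDIM: tl.constexpr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    rem = tl.load(idx_ptr + offs, mask=mask, other=0)  # flat index of each element
    acc = rem * 0                                      # integer accumulator, same shape as rem
    # Peel coordinates from the innermost dim outward: offset = sum(coord_d * stride_d).
    for k in range(NDIM):
        d = NDIM - 1 - k
        size = tl.load(shape_ptr + d)
        stride = tl.load(strides_ptr + d)
        acc += (rem % size) * stride
        rem = rem // size
    tl.store(out_ptr + offs, acc, mask=mask)


def flat_offsets(idx, shape, strides):
    # Host-side wrapper (idx is assumed to be an int64 CUDA tensor). Shape and
    # strides are materialized as small int64 tensors because a Python tuple
    # cannot be passed to a jitted function in Triton 2.2.
    idx = idx.contiguous()
    n = idx.numel()
    out = torch.empty_like(idx)
    shape_t = torch.tensor(shape, dtype=torch.int64, device=idx.device)
    strides_t = torch.tensor(strides, dtype=torch.int64, device=idx.device)
    BLOCK = 1024
    grid = (triton.cdiv(n, BLOCK),)
    flat_offsets_kernel[grid](idx, out, strides_t, shape_t, n,
                              NDIM=len(shape), BLOCK=BLOCK)
    return out
```

The real `offsetCalculator` additionally has to treat the scatter/gather `dim` (and the `isInp` flag) specially, so this only shows the skeleton of the stride walk.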
Hello reviewers, I've added the Triton-kernel version of the offsets calculator in the latest commits! Based on the perf test results (which show some improvement over the previous version, but still lag behind Torch, with latency levels remaining the same as before...), I've temporarily switched the offset calculations in scatter & gather to the kernel implementation. Please take a look at this part of the code during review.
Thanks Xiaoyan. The improvement is significant and we appreciate any effort that contributes to a better codebase.
LGTM
```python
idx = torch.arange(0, index.numel(), device=inp.device).reshape(index.shape)
# Temporarily call offsetCalculator() outside the block (although it can actually proceed in parallel),
# because a Triton jit function cannot accept a Tuple as input in version 2.2.0 (in 3.0.0 it is available),
# and we do need **the whole stride[]** to accomplish this calculation!
# FIXME: If stride[] can be passed wholly to the Triton jit function, we can do this calculation in the kernel
# so that the offset calculation can proceed in parallel.
inp_offsets = offset_calculator(inp_strided, idx, inp.stride(), dim, isInp=True)
```
It seems that if idx is always passed as a trivial iterator like this, it may not need to be materialized.
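Building on the hypothetical `flat_offsets_kernel` sketch above: since idx is just 0..numel-1 laid out in index's shape, a kernel could regenerate it from its block offsets instead of loading it, so the torch.arange buffer would never be allocated. Only the load line changes:

```python
# Delta relative to the illustrative flat_offsets_kernel above (not the PR's code):
# the flat index is regenerated in-kernel instead of being loaded from idx_ptr.
offs = pid * BLOCK + tl.arange(0, BLOCK)
rem = offs   # replaces: rem = tl.load(idx_ptr + offs, mask=mask, other=0)
```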
We have completed the development of the `select`, `scatter`, and `gather` operators. Specifically:

The corresponding ATen operators are `select.int`, `scatter.src`, `scatter.reduce`, `scatter_add`, `gather`, and `gather.out`.
According to the documentation, since `scatter.reduce` is still in beta, we have only implemented the add and multiply reduce options, which were originally supported in the `scatter` operator's parameters. Torch plans to implement additional reduce options such as mean, amax, and amin in the future. We can evaluate whether to add these reduce options based on our future work plans.

See also:
https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_
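For reference, a minimal usage sketch of the two reduce modes covered here, written against the standard PyTorch scatter API (illustrative shapes and values only):

```python
import torch

inp = torch.ones(2, 4)
src = torch.arange(8, dtype=torch.float32).reshape(2, 4)
index = torch.tensor([[0, 1, 2, 3], [3, 2, 1, 0]])

# reduce="add":      out[i][index[i][j]] += src[i][j]
added = inp.clone().scatter_(1, index, src, reduce="add")
# reduce="multiply": out[i][index[i][j]] *= src[i][j]
multiplied = inp.clone().scatter_(1, index, src, reduce="multiply")
```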
The `scatter` and `gather` operators face reproducibility issues in their output due to potentially non-deterministic results caused by non-unique indices. It's worth noting that this issue is not exclusive to scatter-related operators, so further discussion may be needed on potential solutions, for example, following Torch's approach of sacrificing some performance to ensure reproducibility for the same set of inputs.

See also:
https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms
https://pytorch.org/docs/stable/notes/randomness.html#reproducibility
To avoid issues caused by non-deterministic results, we designed test cases with unique indices to ensure consistent output.
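As an illustration of that test strategy (a sketch with made-up shapes, not the PR's actual test code), unique indices can be built from per-row permutations so that no destination is written twice:

```python
import torch

M, N = 4, 8
# Each row is a permutation of 0..N-1, so along dim=1 no index repeats and
# scatter has no colliding writes -> the result is deterministic.
index = torch.stack([torch.randperm(N) for _ in range(M)])
src = torch.randn(M, N)
out = torch.zeros(M, N).scatter_(1, index, src)
```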