
[Operator] scatter & gather #96

Merged
merged 25 commits into from
Sep 2, 2024

Conversation

GwokHiujin
Collaborator

@GwokHiujin GwokHiujin commented Jul 5, 2024

We have completed the development of the select, scatter, and gather operators. Specifically:

if d == dim:
idx_dim = add_on
idx = idx // shape[d]
# FIXME: Should we write a fast div/mod
Collaborator


I think we do need a fast div-mod functor, but I wonder if one already exists in Triton?

Collaborator Author

@GwokHiujin GwokHiujin Jul 8, 2024


I checked out Triton's source code and found that its fast division is implemented using NumPy's np.divide, and as far as I know, NumPy's integer division is optimized. So do you think we just need to import NumPy's div/mod operators in this file to use them?
BTW, Triton doesn't have a divMod-like functor that can return the pair (div i b, mod i b).
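For context, a fused helper of the kind discussed here is easy to write on top of a single integer division. A minimal sketch (the name `div_mod` is hypothetical, not an existing Triton API):

```python
def div_mod(i, b):
    """Return (i // b, i % b) using a single division.

    Hypothetical sketch of the divMod-style functor discussed above:
    the remainder is recovered as i - (i // b) * b, so only one
    (comparatively expensive) division is issued per call.
    """
    q = i // b
    return q, i - q * b
```

In the index-decomposition loop this would replace separate `//` and `%` operations on the same operands.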

Collaborator

@Bowen12992 Bowen12992 left a comment


Should we run benchmarks for these ops?

@@ -137,3 +139,30 @@ def dim_compress(inp, dims):
sorted_reduction_dim = sorted(dims, key=lambda x: stride[x], reverse=True)
order = batch_dim + sorted_reduction_dim
return inp.permute(order).contiguous()


def offsetCalculator(inp, idx, strides, dim, isInp):
Contributor


Can we document this function to help others understand it?
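For readers of this thread, the core idea can be sketched in plain Python (an illustrative reconstruction, not `offsetCalculator`'s actual signature): each flat element index is decomposed into per-dimension coordinates in row-major order, and the memory offset is the stride-weighted sum of those coordinates, with the contribution of `dim` dropped for the input tensor so the gathered index can be substituted in later.

```python
def flat_index_to_offset(flat_idx, shape, strides, dim):
    # Illustrative sketch of the per-element work (hypothetical helper,
    # not the real function signature): walk dimensions from innermost
    # to outermost, peel off each coordinate with mod/div, and
    # accumulate coord * stride while skipping `dim`, whose coordinate
    # comes from the index tensor instead.
    offset = 0
    for d in range(len(shape) - 1, -1, -1):
        coord = flat_idx % shape[d]
        flat_idx //= shape[d]
        if d != dim:
            offset += coord * strides[d]
    return offset
```

For a contiguous `(2, 3)` tensor with strides `(3, 1)`, flat index 4 decomposes to coordinates `(1, 1)`; skipping `dim=1` leaves an offset of 3.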

Contributor


The offset calculation incurs massive overhead. Let's try to do it in a Triton kernel, shall we?
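As a rough illustration of what moving the work off the Python interpreter buys (NumPy stands in for a Triton kernel here; the function name is hypothetical), the per-element decomposition can be done for all elements at once with array ops:

```python
import numpy as np

def offsets_vectorized(numel, shape, strides, dim):
    # Compute every element's memory offset in one pass of array ops;
    # a Triton kernel would perform the same decomposition per block on
    # the GPU. `dim`'s contribution is skipped, as in the input-tensor
    # path of the offset calculation.
    rem = np.arange(numel)
    offsets = np.zeros(numel, dtype=np.int64)
    for d in range(len(shape) - 1, -1, -1):
        coord = rem % shape[d]
        rem //= shape[d]
        if d != dim:
            offsets += coord * strides[d]
    return offsets
```

The loop here runs once per dimension rather than once per element, which is where the speedup over a pure-Python per-element loop comes from.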

@GwokHiujin
Collaborator Author

Hello reviewers, I’ve added the Triton-kernel version of the offsets calculator in the latest commits!

Based on the perf test results (which show some improvement over the previous version but still lag behind Torch, with latency on the same order of magnitude as before...), I've temporarily switched the offset calculations in scatter & gather to the kernel implementation.

Please take a look at this part of the code during review.

@tongxin
Contributor

tongxin commented Aug 9, 2024

Thanks Xiaoyan. The improvement is significant and we appreciate any effort that contributes to a better codebase.

Contributor

@tongxin tongxin left a comment


LGTM

Comment on lines +72 to +78
idx = torch.arange(0, index.numel(), device=inp.device).reshape(index.shape)
# Temporarily call offsetCalculator() outside the kernel (although it could actually proceed in parallel),
# because a Triton jit function cannot accept a Tuple as input in version 2.2.0 (in 3.0.0 it's available),
# and we do need **the whole stride[]** to accomplish this calculation!
# FIXME: If stride[] can be wholly passed to the Triton jit function, we can do this calculation in the kernel
# so that the offset calculation proceeds in parallel.
inp_offsets = offset_calculator(inp_strided, idx, inp.stride(), dim, isInp=True)
Contributor


It seems if idx is always passed as a trivial iterator, it may not be materialized.
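One way to avoid materializing `idx` entirely: inside a block-parallel kernel, each program can regenerate its slice of the flat index range from its program id, e.g. `tl.arange(0, BLOCK) + pid * BLOCK` masked against `numel`. A plain-Python sketch of that per-block regeneration (hypothetical helper, assuming a 1-D launch grid):

```python
def block_flat_indices(pid, block_size, numel):
    # Each "program" (pid) reconstructs its own flat indices instead of
    # reading them from a materialized torch.arange tensor; positions
    # past numel are dropped, mirroring a `< numel` mask on
    # tl.arange(0, BLOCK) + pid * BLOCK.
    start = pid * block_size
    return list(range(start, min(start + block_size, numel)))
```

With this scheme the full O(numel) index tensor never needs to exist in global memory.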

@tongxin tongxin merged commit d78b76d into master Sep 2, 2024
1 check passed
@tongxin tongxin deleted the scatter_gather branch September 2, 2024 07:42