[Operator] index_add #145
base: master
Conversation
src/flag_gems/ops/index_add.py (Outdated)

cur_inp = tl.load(inp + inp_off, mask=block_mask, other=0.0).to(tl.float32)
src_off = rows_offsets * N + cols_offsets[None, :]
cur_src = tl.load(src + src_off, mask=block_mask, other=0.0).to(tl.float32)
Possibly lose precision for fp64 src and inputs?
What about just keeping src and inp as-is, without casting?
> Possibly lose precision for fp64 src and inputs?

I've encountered precision loss issues with some data types (like bf16 and float32), so skipping the cast might cause problems. I'll implement the suggested changes below and see if they resolve the issue.
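A tiny illustration of the precision concern above (not code from the PR): down-casting float64 values to float32 discards mantissa bits, so results computed in float32 inside the kernel can drift from a pure float64 reference.

```python
import torch

x = torch.tensor([1.0 + 1e-12], dtype=torch.float64)
# The fp64 -> fp32 -> fp64 round trip loses the 1e-12 term.
print((x.to(torch.float32).to(torch.float64) - x).item())  # ~-1e-12, not 0.0
```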
src/flag_gems/ops/index_add.py (Outdated)

src = dim_compress(src, dim)
out = inp.clone()

grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
The input & src are permuted into shapes

input: Shape(M, ...) where product(...) == inp_len
src: Shape(M, ...) where product(...) == N

and made contiguous. So we can view them as

input: Shape(M, inp_len)
src: Shape(M, N)
index: (N,)

Then the task is partitioned along the M dimension in tiles of size BLOCK_M, while the N dimension is looped over in tiles of size BLOCK_N.
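As an illustration of that partitioning, here is a minimal Triton sketch, not the PR's actual kernel: the kernel name, argument names, and the use of tl.atomic_add are assumptions. Rows of the 2D views are tiled by BLOCK_M across programs, the N columns are looped over in tiles of BLOCK_N, and src is loaded without casting.

```python
import triton
import triton.language as tl


@triton.jit
def index_add_2d_kernel(
    out_ptr, src_ptr, index_ptr,
    M, N, inp_len,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    row_mask = rows < M
    # Loop over the N dimension in tiles of BLOCK_N.
    for start_n in range(0, N, BLOCK_N):
        cols = start_n + tl.arange(0, BLOCK_N)
        col_mask = cols < N
        mask = row_mask[:, None] & col_mask[None, :]
        # index maps each src column to a column of the (M, inp_len) input view.
        idx = tl.load(index_ptr + cols, mask=col_mask, other=0)
        src_off = rows[:, None] * N + cols[None, :]
        out_off = rows[:, None] * inp_len + idx[None, :]
        # Load src in its original dtype (no cast to float32).
        cur_src = tl.load(src_ptr + src_off, mask=mask, other=0.0)
        # atomic_add tolerates duplicate indices within a tile; note that
        # atomic adds may not be available for every dtype (e.g. bf16) on all backends.
        tl.atomic_add(out_ptr + out_off, cur_src, mask=mask)


def index_add_2d(out, index, src, BLOCK_M=32, BLOCK_N=128):
    # out: contiguous (M, inp_len) view, src: contiguous (M, N) view, index: (N,)
    M, inp_len = out.shape
    N = src.shape[1]
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    index_add_2d_kernel[grid](out, src, index, M, N, inp_len,
                              BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
    return out
```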
Though it is hard to figure out a general solution now, permuting the tensor to make the inp_len & N dimensions contiguous is not always good.
For example, if input & src are both 2D tensors and we index_add along axis 0, then the permutations are actually not needed to make index_add easier.
Yes, this is a key issue I keep thinking about (since it actually occurs in other operations, too). As a temporary solution, I added conditional checks, such as: if the indexed dimension equals (self.ndim - 1), I skip the permutation (sketched below). I'm not sure yet whether this approach is effective.
BTW, performance testing revealed that the permutations can increase latency by about 7x compared to Torch, so reducing unnecessary permutations is crucial... ;(
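For reference, a hypothetical guard along those lines (the wrapper name maybe_dim_compress and the import path are assumptions; dim_compress is the helper already used in this PR): skip the permutation/copy when dim is already the innermost axis of a contiguous tensor, since dim_compress would then only re-copy the data without changing the 2D view.

```python
from flag_gems.utils import dim_compress  # assumed import path


def maybe_dim_compress(t, dim):
    # If `dim` is already the last axis of a contiguous tensor, the permutation
    # is a no-op in layout terms, so avoid the extra copy.
    if dim == t.ndim - 1 and t.is_contiguous():
        return t
    return dim_compress(t, dim)
```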
There is still some room for optimization, but LGTM.
Some suggestions (the first two are sketched below):
- Ensure index is contiguous, or take its stride into account;
- Keep the data loaded from src as-is to avoid a down-cast;
- (Maybe) use some heuristics for better task partitioning & to avoid unnecessary data permutations.
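A small host-side sketch of the first two suggestions (hypothetical helper, not the PR's code): force a unit-stride index before handing raw pointers to the kernel, and keep inp/src in their original dtype rather than casting.

```python
import torch


def prepare_index_add_args(inp: torch.Tensor, index: torch.Tensor, src: torch.Tensor):
    # The kernel addresses `index` with unit stride, so copy it if non-contiguous;
    # inp and src are left in their original dtype (no up-/down-cast on the host).
    index = index if index.is_contiguous() else index.contiguous()
    assert inp.dtype == src.dtype, "index_add expects inp and src to share a dtype"
    return inp, index, src
```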
* Use a 2D grid with the kernel
* Ensure index is contiguous
* Keep the data loaded from src in the kernel as-is
* Try to avoid some unnecessary permutations
We have completed the development of the index_add operator. Specifically:
- index_add