
Refactor code generation for pointwise operation & PointwiseDynamicFunction #167

Merged (31 commits) on Sep 12, 2024

Conversation

@iclementine (Collaborator) commented Aug 16, 2024

  1. We now generate code with nd tiles and a 1d grid with a grid-stride loop, where n is the ndim of the task space;
  2. Add some simple logic to simplify the task space (when all operands have the same shape and the same strides, and all of them are non-overlapping and dense, we collapse the task space into a 1d space; better policies are possible, but we leave them for future work);
  3. Use a smarter policy for output layout inference (the output follows the stride order of the first tensor whose shape equals the broadcasted shape, with pre-defined outputs taking priority over all input tensors; otherwise, the output is C-contiguous);
  4. Make the tile size and grid size in generated code configurable;
  5. Work around the problem that storing to a block pointer does not automatically cast the value to the pointer's dtype;
  6. Work around the problem that values loaded from a pointer-to-bool are int8, and a block pointer created from a pointer-to-bool has dtype int8;
  7. Fix the bitwise-* operators without those workarounds, and add test cases with bool inputs & outputs for them;
  8. Add TypedPtr and StridedBuffer as drop-in replacements for torch.Tensor to be used in generated Triton kernels & wrappers; they allow some unsafe reinterpretation of tensors (dtype, shape, stride, offset) that cannot be done with torch APIs;
  9. Fix a bug in the flip op where the flipped view of the input (shifted data pointer and negative strides) was applied to the output.
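Item 1 can be illustrated with a small host-side model of the schedule (plain Python, not the generated Triton code; all names here are illustrative): each program in a 1d grid walks the tiles of the nd task space with a grid-stride loop, and a linear tile id is unraveled into nd tile coordinates.

```python
def unravel_index(index, shape):
    """Convert a linear index into nd coordinates (C order)."""
    coords = []
    for d in reversed(shape):
        coords.append(index % d)
        index //= d
    return tuple(reversed(coords))

def tiles_for_block(task_shape, tile_shape, num_blocks, block_id):
    """Yield the nd tile coordinates one program handles when the 1d grid
    is smaller than the number of tiles (grid-stride loop)."""
    tiles_per_dim = [(s + t - 1) // t for s, t in zip(task_shape, tile_shape)]
    num_tiles = 1
    for n in tiles_per_dim:
        num_tiles *= n
    # each block starts at its own id and strides by the grid size
    for linear in range(block_id, num_tiles, num_blocks):
        yield unravel_index(linear, tiles_per_dim)
```

For a 4x6 task space with 2x2 tiles there are 2x3 = 6 tiles; with a grid of 4 programs, program 0 handles tiles (0,0) and (1,1), and program 1 handles (0,1) and (1,2).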

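Items 2 and 3 can be sketched in plain Python over (shape, strides) pairs; this is a simplified model under my own reading of the description, not the actual flag_gems implementation.

```python
def is_non_overlapping_and_dense(shape, strides):
    """True when the elements tile a contiguous block of memory: sorted by
    stride, each stride equals the product of the extents of all dims with
    smaller strides."""
    dims = sorted(range(len(shape)), key=lambda i: strides[i])
    expected = 1
    for i in dims:
        if shape[i] == 1:
            continue
        if strides[i] != expected:
            return False
        expected *= shape[i]
    return True

def simplify_task_space(operands):
    """operands: list of (shape, strides). Collapse the task space to 1d
    when every operand has the same shape and strides and is
    non-overlapping and dense; otherwise keep the nd task space."""
    shape0, strides0 = operands[0]
    if (all(s == shape0 and st == strides0 for s, st in operands)
            and is_non_overlapping_and_dense(shape0, strides0)):
        numel = 1
        for d in shape0:
            numel *= d
        return (numel,)  # collapsed 1d task space
    return shape0        # keep the nd task space

def infer_output_strides(broadcast_shape, outputs, inputs):
    """Follow the stride order of the first tensor whose shape equals the
    broadcast shape (pre-defined outputs take priority over inputs);
    otherwise fall back to C-contiguous strides."""
    for shape, strides in list(outputs) + list(inputs):
        if shape == broadcast_shape:
            order = sorted(range(len(shape)), key=lambda i: strides[i])
            out, acc = [0] * len(shape), 1
            for i in order:
                out[i] = acc
                acc *= broadcast_shape[i]
            return tuple(out)
    out, acc = [], 1
    for d in reversed(broadcast_shape):
        out.append(acc)
        acc *= d
    return tuple(reversed(out))
```

For example, two operands of shape (2, 3) with C-contiguous strides (3, 1) collapse to the 1d task space (6,), and an F-contiguous input with strides (1, 2) makes the inferred output F-contiguous as well.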
@iclementine iclementine changed the title Refar code generation for pointwise operation & PointwiseDynamicFunction Refactor code generation for pointwise operation & PointwiseDynamicFunction Aug 16, 2024
@sethbrin (Collaborator) left a comment:

other, LGTM

Review threads (resolved):
- src/flag_gems/utils/pointwise_dynamic.py
- tests/test_pointwise_dynamic.py
2. Add the scalar function name and ndim as part of the name of the generated functions.
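A sketch of what such a naming scheme could look like; the format below is hypothetical, for illustration only, not flag_gems's actual one.

```python
def generated_kernel_name(scalar_fn_name: str, ndim: int) -> str:
    # Embedding the scalar function name and the rank in the name makes
    # the generated source self-describing and avoids collisions between
    # specializations of different ranks.
    return f"_pointwise_{scalar_fn_name}_rank_{ndim}"
```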
@iclementine
Copy link
Collaborator Author

iclementine commented Aug 20, 2024

Since there is some performance degradation in the new PointwiseDynamicFunction, we will try improving it before we proceed.

Update: This is fixed in 8eac7f0.

Review thread on .gitignore (resolved)
Bowen12992 previously approved these changes Sep 11, 2024

@Bowen12992 (Collaborator) left a comment:

LGTM

@StrongSpoon (Collaborator) left a comment:

lgtm

@iclementine iclementine merged commit 436f1a7 into FlagOpen:master Sep 12, 2024
4 checks passed
DuanYaQi pushed a commit that referenced this pull request Sep 13, 2024
…nction (#167)

* Refactor code generation for pointwise operation & PointwiseDynamicFunction
1. We now generate code with nd tiles and a 1d grid, or 1d tiles and a 1d grid, with a grid-stride loop, where n is the ndim of the task space;
2. Add some simple logic to simplify the task space (when all operands have the same shape and the same strides, and all of them are non-overlapping and dense, we collapse the task space into a 1d space; better policies are possible, but we leave them for future work);
3. Use a smarter policy for output layout inference (the output follows the stride order of the first tensor whose shape equals the broadcasted shape, with pre-defined outputs taking priority over all input tensors; otherwise, the output is C-contiguous);
4. Make the tile size and grid size in generated code configurable;
5. Work around the problem that storing to a block pointer does not automatically cast the value to the pointer's dtype;
6. Work around the problem that values loaded from a pointer-to-bool are int8, and a block pointer created from a pointer-to-bool has dtype int8;
7. Fix the bitwise-* operators without those workarounds, and add test cases with bool inputs & outputs for them;
8. Add TypedPtr and StridedBuffer as drop-in replacements for torch.Tensor to be used in generated Triton kernels & wrappers; they allow some unsafe reinterpretation of tensors (dtype, shape, stride, offset) that cannot be done with torch APIs;
9. Fix a bug in the flip op where the flipped view of the input (shifted data pointer and negative strides) was applied to the output; add a test for the flip op with input that is not C-contiguous.
10. Add config as a parameter for PointwiseDynamicFunction
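The StridedBuffer idea in item 8, and the flipped view in item 9, can be modeled in a few lines of plain Python over a flat storage. The real class wraps a torch.Tensor's storage; everything below is an illustrative sketch, not flag_gems's implementation.

```python
class StridedBuffer:
    """A tensor-like view described by (shape, strides, offset) over flat
    storage, allowing reinterpretations (e.g. negative strides with a
    shifted base offset, as in flip) that torch's public view APIs
    do not permit."""
    def __init__(self, storage, shape, strides, offset=0):
        self.storage = storage
        self.shape = tuple(shape)
        self.strides = tuple(strides)
        self.offset = offset

    def __getitem__(self, idx):
        # idx is a full nd index tuple; strides may be negative
        flat = self.offset + sum(i * s for i, s in zip(idx, self.strides))
        return self.storage[flat]

# A view of a row-major 2x3 buffer flipped along the last axis: shift the
# base offset to the end of a row and negate that axis's stride.
storage = list(range(6))  # [[0, 1, 2], [3, 4, 5]]
flipped = StridedBuffer(storage, (2, 3), (3, -1), offset=2)
```

Here `flipped[(0, 0)]` reads storage[2], i.e. the logical rows become [2, 1, 0] and [5, 4, 3]; the bug in item 9 was that such a shifted-pointer, negative-stride view was applied to the output instead of the input.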
6 participants