Refactor code generation for pointwise operation & PointwiseDynamicFunction #167
Conversation
iclementine
commented
Aug 16, 2024
- Now we generate code with n-d tiles and a 1-d grid with a grid-stride loop, where n is the ndim of the task space;
- Add some simple logic to simplify the task space (when all operands have the same shape and the same strides, and all of them are non-overlapping and dense, we simplify the task space into a 1-d space; a better policy is possible, but we leave that for future work);
- Use a smarter policy for output layout inference (the output follows the stride order of the first tensor whose shape equals the broadcast shape, with pre-defined outputs taking priority over all input tensors; otherwise, the output is C-contiguous);
- Make tile size and grid size in generated code configurable;
- Work around the problem that a store to a block pointer does not automatically cast the value to the pointer's dtype;
- Work around the problem that a value loaded from a pointer-to-bool has dtype int8, and a block pointer built from a pointer-to-bool also has dtype int8;
- Fix the bitwise-* operators without those workarounds, and add test cases with bool inputs & outputs for them;
- Add TypedPtr and StridedBuffer as drop-in replacements for torch.Tensor in generated Triton kernels & wrappers, allowing some unsafe reinterpretation of tensors (dtype, shape, stride, offset) that cannot be done through torch APIs;
- Fix a bug in the flip op where the flipped view (shifted data pointer and negative strides) of the input was applied to the output.
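The task-space simplification described above can be sketched in plain Python. This is a hypothetical illustration, not the actual FlagGems implementation; the function names `is_non_overlapping_and_dense` and `simplify_task_space` are invented for this sketch, and shapes/strides are modeled as tuples rather than torch tensors.

```python
def is_non_overlapping_and_dense(shape, strides):
    # A layout is non-overlapping and dense if, after sorting dims by
    # stride, each stride equals the product of the extents of the
    # dims inside it (i.e. some permutation of a contiguous layout).
    dims = sorted(range(len(shape)), key=lambda i: strides[i])
    expected = 1
    for i in dims:
        if shape[i] == 1:
            continue
        if strides[i] != expected:
            return False
        expected *= shape[i]
    return True

def simplify_task_space(shapes, strides_list):
    """Collapse to a 1-d task space when every operand shares the same
    shape and strides and the layout is non-overlapping and dense;
    otherwise keep the n-d task space."""
    first_shape, first_strides = shapes[0], strides_list[0]
    same_layout = all(
        s == first_shape and st == first_strides
        for s, st in zip(shapes, strides_list)
    )
    if same_layout and is_non_overlapping_and_dense(first_shape, first_strides):
        numel = 1
        for d in first_shape:
            numel *= d
        return (numel,)  # collapsed 1-d task space
    return first_shape   # fall back to the n-d task space
```

For example, two C-contiguous (2, 3) operands collapse to a 1-d space of 6 elements, while operands with mismatched layouts keep the 2-d task space.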
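The output-layout inference policy can likewise be sketched. Again this is an illustrative approximation of the policy described above, not the PR's actual code; `infer_output_strides`, `strides_like`, and `c_contiguous_strides` are hypothetical names, and tensors are modeled as (shape, strides) tuples.

```python
def c_contiguous_strides(shape):
    # Standard C-order (row-major) strides for a given shape.
    strides, acc = [0] * len(shape), 1
    for i in reversed(range(len(shape))):
        strides[i] = acc
        acc *= shape[i]
    return tuple(strides)

def strides_like(shape, src_strides):
    # Reuse the stride *order* of src: rank dims by src stride, then
    # assign contiguous strides in that same order.
    order = sorted(range(len(shape)), key=lambda i: src_strides[i])
    strides, acc = [0] * len(shape), 1
    for i in order:
        strides[i] = acc
        acc *= shape[i]
    return tuple(strides)

def infer_output_strides(broadcast_shape, outputs, inputs):
    """outputs/inputs are lists of (shape, strides) tuples.
    Pre-defined outputs are scanned before inputs, so they take
    priority; fall back to C-contiguous when nothing matches."""
    for shape, strides in list(outputs) + list(inputs):
        if shape == broadcast_shape:
            return strides_like(broadcast_shape, strides)
    return c_contiguous_strides(broadcast_shape)
```

With this policy, an F-ordered input of the broadcast shape yields an F-ordered output, while a broadcast from a smaller input falls back to C-contiguous.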
… broadcasting, does not allocate outputs, and keeps outputs type(Tensor or StridedBuffer)
…by block_pointer; 2. heuristics_for_tile_sizes now prefer large size inner dimension, since we change to C order
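The tile-size heuristic mentioned in this commit could look roughly like the following. This is a hedged sketch assuming a power-of-two element budget per tile; the signature of the real `heuristics_for_tile_sizes` and its budget policy are not shown in this thread, so everything here is illustrative.

```python
def heuristics_for_tile_sizes(shape, max_tile_elems=512):
    """Assign power-of-two tile sizes across dims, filling from the
    innermost (fastest-varying, C-order) dimension first so that the
    inner dimension gets the largest share and loads stay coalesced."""
    tile_sizes = [1] * len(shape)
    budget = max_tile_elems
    for i in reversed(range(len(shape))):  # innermost dim first
        # Smallest power of two covering the extent, capped by budget.
        size = 1
        while size < shape[i] and size < budget:
            size *= 2
        tile_sizes[i] = size
        budget //= size
        if budget <= 1:
            break
    return tile_sizes
```

For a (4, 1000) task space with a 512-element budget, the inner dimension takes the whole budget (tile sizes [1, 512]); for (8, 16), both dims fit and the result is [8, 16].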
Otherwise, LGTM.
2. add scalar function name and ndim as part of the name of the generated functions.
Since there is some performance degradation in the new PointwiseDynamicFunction, we will try to improve it before we proceed. Update: this is fixed in 8eac7f0.
… the key when checking for existing overload; the key is an integer
…en triton version is less than 3
LGTM
lgtm
…nction (#167) * Refactor code generation for pointwise operation & PointwiseDynamicFunction 1. Now we generate code with n-d tiles with a 1-d grid, or 1-d tiles with a 1-d grid with a grid-stride loop, where n is the ndim of the task space; 2. Add some simple logic to simplify the task space (when all operands have the same shape and the same strides, and all of them are non-overlapping and dense, we simplify the task space into a 1-d space; a better policy is possible, but we leave that for future work); 3. Use a smarter policy for output layout inference (the output follows the stride order of the first tensor whose shape equals the broadcast shape, with pre-defined outputs taking priority over all input tensors; otherwise, the output is C-contiguous); 4. Make tile size and grid size in generated code configurable; 5. Work around the problem that a store to a block pointer does not automatically cast the value to the pointer's dtype; 6. Work around the problem that a value loaded from a pointer-to-bool has dtype int8, and a block pointer built from a pointer-to-bool also has dtype int8; 7. Fix the bitwise-* operators without those workarounds, and add test cases with bool inputs & outputs for them; 8. Add TypedPtr and StridedBuffer as drop-in replacements for torch.Tensor in generated Triton kernels & wrappers, allowing some unsafe reinterpretation of tensors (dtype, shape, stride, offset) that cannot be done through torch APIs; 9. Fix a bug in the flip op where the flipped view (shifted data pointer and negative strides) of the input was applied to the output; add a test for the flip op with input that is not C-contiguous. 10. Add config as a parameter for PointwiseDynamicFunction.
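The grid-stride loop from item 1 can be illustrated in pure Python. This is a CPU emulation of the pattern (a real generated kernel would be a Triton `@triton.jit` function using `tl.program_id` and `tl.num_programs`); the function name and parameters here are hypothetical.

```python
def grid_stride_loop_add(out, x, y, block_size=4, grid_size=2):
    """Emulate a 1-d grid of `grid_size` programs, each processing one
    block of `block_size` elements per step and then striding forward
    by grid_size * block_size, so a fixed-size grid covers any task
    space regardless of its size."""
    n = len(x)
    stride = grid_size * block_size
    for pid in range(grid_size):              # one pass per program id
        start = pid * block_size
        for base in range(start, n, stride):  # the grid-stride loop
            for i in range(base, min(base + block_size, n)):
                out[i] = x[i] + y[i]
    return out
```

With `grid_size=2` and `block_size=4`, a 10-element task space is covered as program 0 → elements 0–3 and 8–9, program 1 → elements 4–7; no launch-time dependence on the problem size is needed.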