Skip to content

Commit

Permalink
Refar code generation for pointwise operation & PointwiseDynamicFunction
Browse files Browse the repository at this point in the history
1. Now we generate code with nd tiles and 1d grid with grid-stride-loop, where the n is the ndim of the task space;
2. Add some simple logic to simplify task space(when all operand have the same shape and same stride, and all of them are non overlapping and dense, we simplify the task space into a 1d space, although we can use better policy but we leave it for future work);
3. Use a smarter policy for output layout inference:(the output will follow the stride order of the first tensor that has the same shape as the broadcasted shape, pre-defined ouputs has higher priority than all input tensors; otherwise, the output in c-contiguous);
4. make tile size and grid size in generated code configurable;
5. work around the problem that save to block pointer does not automatically cast the value to the pointer's dtype;
6. work around the problem that values loaded from a pointer to bool is int8 and block pointer from pointer to bool has dtype int8;
7. fix the bitwise-* operators without those work arounds, and add test-cases with bool inputs & outputs for them;
8. add TypedPtr and StridedBuffer as drop-in replament for torch.Tensor to be used in generated triton kernel & wrappers, which allows some unsafe reinterpretation of Tensors(dtype, shape, stride, offset), which cannot be done by torch APIs;
9. fix a bug in flip op where the flipped view(shifted data pointer and negative strides) from input in applied to the output.
  • Loading branch information
iclementine committed Aug 16, 2024
1 parent a156268 commit eaba92b
Show file tree
Hide file tree
Showing 11 changed files with 1,215 additions and 764 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ build/
# Project files, i.e. `.project`, `.actionScriptProperties` and `.flexProperties`
# should NOT be excluded as they contain compiler settings and other important
# information for Eclipse / Flash Builder.
playground/
14 changes: 11 additions & 3 deletions src/flag_gems/ops/flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import triton

from ..utils import pointwise_dynamic
from ..utils.tensor_wrapper import StridedBuffer


@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def flip_func(x, **kwargs):
def copy_func(x):
return x


Expand All @@ -29,10 +30,17 @@ def flip(A: torch.Tensor, dims) -> torch.Tensor:
n = 0
offset = 0
for i in range(len(flip_dims_b)):
if flip_dims_b[i] and A.size()[i] > 1 and A.stride()[i] != 0:
if flip_dims_b[i] and A.size(i) > 1 and A.stride(i) != 0:
offset += strides[i] * (A.shape[i] - 1)
strides[i] = -strides[i]
n += 1
if n == 0 or A.numel() <= 1:
return A.clone()
return flip_func(A, out0_offset=offset, out0_strides=strides)
out = torch.empty_like(A)
# a flipped view of A
flipped_A = StridedBuffer(A, strides=strides, offset=offset)

# TODO: flip op can have a custom task simplification method, but we skip it now and just use A's rank.
overload = copy_func.instantiate(A.ndim)
overload(flipped_A, out0=out)
return out
3 changes: 3 additions & 0 deletions src/flag_gems/utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ def writelines(self, lines):
for line in lines:
self.writeline(line)

def writemultiline(self, s):
self.writelines(s.splitlines())

def indent(self, offset=1):
@contextlib.contextmanager
def ctx():
Expand Down
Loading

0 comments on commit eaba92b

Please sign in to comment.