[Tripy] Return the shape immediately if it is statically known instead of producing a trace operator. #379
base: main
Conversation
When benchmarking informally, however, I did not find a large difference between the current implementation and this change. Perhaps the issue is how the benchmark was measured.

Informal benchmark. The function to test:

```py
def func(t1, t2):
    return tp.Tensor(list(t1.shape) != list(t2.shape), dtype=tp.bool)
```

I wrote it this way because the …

Comparison: after warming up, both complete in about 25 microseconds. Should we do this comparison without compiling instead? I could also vary the input values.
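For reference, a minimal sketch of how the uncompiled (eager) comparison might be timed. The use of `tp.ones` and the particular shapes are assumptions for illustration, not part of the original benchmark:

```py
import time

import tripy as tp

def func(t1, t2):
    return tp.Tensor(list(t1.shape) != list(t2.shape), dtype=tp.bool)

# Arbitrary static-shape inputs; varying these would exercise both the
# equal-shape and unequal-shape branches of the comparison.
t1 = tp.ones((2, 3))
t2 = tp.ones((2, 4))

func(t1, t2)  # warm-up

iters = 1000
start = time.perf_counter()
for _ in range(iters):
    func(t1, t2)
print(f"avg per call: {(time.perf_counter() - start) / iters * 1e6:.1f} us")
```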
```py
self.shape_memo = [DimensionSize(dim) for dim in self.trace_tensor.shape]
# ...
return self.shape_memo or [
    GetDimensionSize.build([self], dim=index, always_cast_to_dimension_size=True) for index in range(self.rank)
]
```
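For context, these fragments appear to come from a `shape` property along the following lines. This is a reconstruction pieced together from the diff hunks in this thread, not the verbatim source:

```py
@property
def shape(self):
    # Cache DimensionSize objects once all extents are statically known
    # (no -1 placeholders); otherwise emit one GetDimensionSize trace op
    # per dimension on every call.
    if self.shape_memo is None and all(dim >= 0 for dim in self.trace_tensor.shape):
        self.shape_memo = [DimensionSize(dim) for dim in self.trace_tensor.shape]
    return self.shape_memo or [
        GetDimensionSize.build([self], dim=index, always_cast_to_dimension_size=True)
        for index in range(self.rank)
    ]
```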
Memoizing the `GetDimensionSize` ops led to an MLIR error upon repeated calls, and I'm not entirely sure why.
The case that fails (the documentation example for `repeat`):

```py
inp = tp.reshape(tp.arange(4, dtype=tp.int32), (2, 2))
out0 = tp.repeat(inp, 2, dim=0)
out1 = tp.repeat(inp, 2, dim=1)

np_inp = np.from_dlpack(tp.copy(inp, device=tp.device("cpu")))  # doc: omit

ref_out0 = np.repeat(np_inp, 2, 0)
assert np.array_equal(ref_out0, np.from_dlpack(tp.copy(out0, device=tp.device("cpu"))))  # <- succeeds

ref_out1 = np.repeat(np_inp, 2, 1)
assert np.array_equal(ref_out1, np.from_dlpack(tp.copy(out1, device=tp.device("cpu"))))  # <- fails
```
Error:

```
(t347) error: number of output elements (8) doesn't match expected number of elements (4)

This error occured while trying to compile the following FlatIR expression:
      |
      | t347: [rank=(3), shape=((-1, -1, -1)), dtype=(int32), loc=(gpu:0)] = DynamicReshapeOp(t305, t346)
      |

Note: This originated from the following expression:

--> /tripy/docs/../tripy/frontend/ops/unsqueeze.py:56 in unsqueeze()
      |
   56 |     return reshape(input, result_shape)
      |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^

--> /tripy/docs/../tripy/frontend/ops/repeat.py:92 in repeat()
      |
   92 |     out = unsqueeze(input, dim + 1)
      |           ^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here
```
The `unsqueeze` gets the shape of the input and computes a new shape as an argument to `reshape`. I will see if I can figure out how that results in an error when lowering to MLIR.
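For reference, the pattern the stack trace points at looks roughly like the following. This is a simplified sketch of a reshape-based `unsqueeze`; the names `reshape` and `input.shape` come from the trace above, and `unsqueeze_sketch` is a hypothetical stand-in, not the actual Tripy source:

```py
def unsqueeze_sketch(input, dim):
    # Build the result shape from the input's shape, then reshape; repeat()
    # calls this at dim + 1 per the trace above. If `input.shape` returned
    # memoized GetDimensionSize ops created during an earlier trace, the
    # reshape target could be wired to stale graph nodes, which would be
    # consistent with the element-count mismatch in the error.
    result_shape = list(input.shape)
    result_shape.insert(dim, 1)  # insert the new unit dimension
    return reshape(input, result_shape)
```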
```py
if self.shape_memo is None and all(dim >= 0 for dim in self.trace_tensor.shape):
    self.shape_memo = [DimensionSize(dim) for dim in self.trace_tensor.shape]
```
I'm not sure there's actually a benefit to storing `shape_memo`. Can't we simplify this to:

```py
if all(dim >= 0 for dim in self.trace_tensor.shape):
    return list(self.trace_tensor.shape)
```

? I don't think we even need to return `DimensionSize` in this case.
It's the signature of the `shape` operator.
Sure, but we could always change the signature. Ideally, `DimensionSize` and `int` behave basically the same from the user's perspective. I'm assuming you have `shape_memo` here because it's expensive to construct `DimensionSize`s?
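Putting the thread together, the simplification under discussion would look something like the sketch below. The names (`trace_tensor`, `DimensionSize`, `GetDimensionSize`, `rank`) are taken from the diff above; this is an illustration of the suggestion, not a committed design:

```py
@property
def shape(self):
    if all(dim >= 0 for dim in self.trace_tensor.shape):
        # Every extent is statically known: return plain Python ints and
        # skip constructing DimensionSize objects or trace ops entirely.
        return list(self.trace_tensor.shape)
    # Otherwise, emit one GetDimensionSize trace op per dimension.
    return [
        GetDimensionSize.build([self], dim=index, always_cast_to_dimension_size=True)
        for index in range(self.rank)
    ]
```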
@slyubomirsky most of the impact will be on the compile time, I think.

Ah, okay. I'll profile compile time then.
Addresses issue #360, though a benchmark has yet to be added.