-
Notifications
You must be signed in to change notification settings - Fork 359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add PyTensor backend #362
Add PyTensor backend #362
Conversation
58831b5
to
95c9c3a
Compare
Ah you're running the CI on 3.8 which we don't support anymore :) |
95c9c3a
to
fca5ff1
Compare
fca5ff1
to
65d8cad
Compare
Most imperative tests would work fine with PyTensor except for the use of |
1f5c250
to
1d0a538
Compare
def to_numpy(self, x): | ||
return x.eval() # Will only work if there are no symbolic inputs | ||
|
||
def create_symbol(self, shape): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's assume that shape is tuple/list. Does it break anything?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There were tests where it was an integer. Needs to always be a sequence for us, that's why I added the check
return self.pt.tensor(shape=shape) | ||
|
||
def eval_symbol(self, symbol, input_dict): | ||
# input_dict is actually a list of tuple? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is dict: input symbol -> input value
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tests always passed it as a list of tuples
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah, correct. Some of backends' symbols were not hashable, need to rename this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
symbol_values
?
@@ -254,7 +255,7 @@ def test_functional_symbolic(): | |||
) | |||
if predicted_out_data.shape != out_shape: | |||
raise ValueError(f"Expected output shape {out_shape} but got {predicted_out_data.shape}") | |||
assert np.testing.assert_array_almost_equal(predicted_out_data, expected_out_data, decimal=5) | |||
np.testing.assert_array_almost_equal(predicted_out_data, expected_out_data, decimal=5) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thx for fixing
frameworks: ['numpy pytorch tensorflow jax'] | ||
frameworks: ['numpy pytorch tensorflow jax', 'pytensor'] | ||
exclude: | ||
- python-version: '3.8' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you document which versions won't work and why?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, where would you like that, in the readme? A comment here? I always break yaml files whenever I try to put comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
right in the file, there are some commens about testing of other frameworks above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
need to think about this. Maybe we'll have to skip testing because of this, or use |
Not a blocker, I registered pytensor as symbolic so it's not picked by those tests. Just saying it would work fine because we can do graphs of constants. |
1d0a538
to
ea5d803
Compare
Tests are passing with the last pytensor release. Let me know if you want me to change anything else. Want me to rename the |
Looks good, I'll merge this.
NP, will rename in a separate PR. |
Any estimation for when a next release will be cut? |
I estimate this to happen a couple of hours ago :) Recommendation to add this to your CI so if you decide some changes around symbols (hashability/boolability/etc), you could confirm that einops integration is not broken.
Tests are rather fast. |
Thanks 🙏 |
Hi!
I understand that the library is moving away from specific backends in favor of the array api standard, but this was so much easier for us to implement and test.
From a first glance the array api standard compatibility suite seems completely oblivious to lazy backends right now. Will double check and open an issue with them later if my suspicion is correct.
This also seems to test part of the test codebase that was untested since the drop of MxNet (?). Fixed two errors there.
Happy to remove PyTensor from the test suite if that's a drag / blocker. Not sure if we could test it from our CI as some of these test options are hardcoded in the library? Sounds like it should be easy to make them customizable though?
Currently all tests pass except the 3 tensor einsum multiplication. Issue opened in pymc-devs/pytensor#1184Should pass when it picks the new release