Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement @as_jax_op to wrap a JAX function for use in PyTensor #1120

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from

Conversation

jdehning
Copy link

@jdehning jdehning commented Dec 12, 2024

Description

Add a decorator that transforms a JAX function such that it can be used in PyTensor. Shape and dtype inference works automatically and input and output can be any nested python structure (e.g. Pytrees). Furthermore, using a transformed function as an argument for another transformed function should also work.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

ToDos

  • Implement Op.__props__
  • Let make_node be specified by the user, to support non-inferrable shapes
    - JAXOp is now directly usable by the user
  • Add tests for JAXOp
  • Add some meaningful error messages to common runtime errors

📚 Documentation preview 📚: https://pytensor--1120.org.readthedocs.build/en/1120/

@jdehning
Copy link
Author

jdehning commented Dec 12, 2024

I have a question, where should I put the @as_jax_op. Currently, it is in a new file pytensor/link/jax/ops.py. Does that make sense? Also, how should one access it? Only by calling pytensor.link.jax.ops.as_jax_op? Or include it in a __init__.py such that pytensor.as_jax_op works?

@ricardoV94
Copy link
Member

ricardoV94 commented Dec 12, 2024

We can put in init as long as imports work in a way that jax is still optional for Pytensor users (obviously calling the decorator can raise if it's not installed, hopefully with an informative message)

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks sweet. I'll do a more careful review later, just skimmed through and annotated some thoughts

self.num_inputs = len(inputs)

# Define our output variables
outputs = [pt.as_tensor_variable(type()) for type in self.output_types]
Copy link
Member

@ricardoV94 ricardoV94 Dec 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to use jax machinery to infer the output types from the input types? Can we created TraceDArrays (or whatever they're called) and pass them through the function?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scrap that, JAX doesn't let you trace arrays without unknown shape

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I trace the shape through the JAX function in line 119 of the file. It won't work for unknown shape. But if one specifies the shape at the beginning of a graph, i.e. x = pm.Normal("x", shape=(3,)), and it loses static shape information afterwards, for instance because of a pt.cumsum, line 99 (pytensor.compile.builders.infer_shape) will be able to infer the shape. But that is a good comment, I will raise an error if pytensor.compile.builders.infer_shape isn't able to infer the shape. I think it makes sense to only use this wrapper if the shape information is known.

Copy link
Author

@jdehning jdehning Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I see a point where it will lead to problems: If there is an input x = pm.Data("x", shape=(None,), value= np.array([0., 0])): in the first run, it will work, as pytensor.compile.builders.infer_shape will infer the shape as (2,), but if one changes with x.set_value(np.array([0., 0, 0])) the shape of x, it will lead to an error, as the Pytensor Op has been created with an explicit shape. I could simply add a parameter to as_jax_op to force all output shapes to None, then it should work.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will write more tests, then it will be clearer what I mean

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we can't use shape unless it's static. Ideally it shouldn't fail for unknown shapes, but then the user has to tell user the output types.

We can allow the user to specify a make_node callable? That way it can be made to work with different dtypes/ndims if the jax function handles those fine

return (result,) # Pytensor requires a tuple here

# vector-jacobian product Op
class VJPSolOp(Op):
Copy link
Member

@ricardoV94 ricardoV94 Dec 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a nice follow up would be to also create a "ValueAndGrad" version of the Op that gets introduced in rewrites when both the Op and the VJP of Op (or JVP) are in the final graph.

This need not be a blocker for this PR

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see exactly what you mean. Is ValueAndGrad used by Pytensor? I searched the codebase but didn't find a mention of it. Does it have to do with implementing L_op? I haven't really understood the difference between it and grad

Copy link
Member

@ricardoV94 ricardoV94 Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JAX has the value and grad concept to more optimally compute both together. PyTensor doesn't have that concept because everything is lazy but we can exploit it during the rewrite phase.

If a user compiles a function that includes both forward and gradient of the same wrapped JAX Op, we could replace it by a third Op whose perform implementation requests jax to compute both.

This is not relevant when the autodiff is done in JAX, but it's relevant when it's done in PyTensor

jitted_vjp_sol_op_jax=jitted_vjp_sol_op_jax,
)

@jax_funcify.register(SolOp)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can dispatch on the base class just once?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean? This jax.funcify is once registering SolOp, once VJPSolOp. You mean, one could include the gradient calculation in SolOp?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean you can define SolOp class outside the decorator and dispatch on that.

Then the decorator can return a subclass of that and you don't need to bother dispatching because the base class dispatch will cover it

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I didn't think of that. Have a look at whether I implemented it like you had envisioned

pytensor/link/jax/ops.py Outdated Show resolved Hide resolved
jitted_vjp_sol_op_jax=jitted_vjp_sol_op_jax,
)

@jax_funcify.register(SolOp)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean you can define SolOp class outside the decorator and dispatch on that.

Then the decorator can return a subclass of that and you don't need to bother dispatching because the base class dispatch will cover it

@ricardoV94
Copy link
Member

Big level picture. What's going on with the flattening of inputs and why is it needed?

@jdehning
Copy link
Author

Big level picture. What's going on with the flattening of inputs and why is it needed?

To be able to wrap JAX function that accept pytrees as input. pytensor_diffeqsolve = as_jax_op(diffrax.diffeqsolve) works, and can be used in the same way as one would use the original diffrax.diffeqsolve.

@ricardoV94
Copy link
Member

Big level picture. What's going on with the flattening of inputs and why is it needed?

To be able to wrap JAX function that accept pytrees as input. pytensor_diffeqsolve = as_jax_op(diffrax.diffeqsolve) works, and can be used in the same way as one would use the original diffrax.diffeqsolve.

And if I have a matrix input function will this work or expect it to be a vector instead?

@jdehning
Copy link
Author

Big level picture. What's going on with the flattening of inputs and why is it needed?

To be able to wrap JAX function that accept pytrees as input. pytensor_diffeqsolve = as_jax_op(diffrax.diffeqsolve) works, and can be used in the same way as one would use the original diffrax.diffeqsolve.

And if I have a matrix input function will this work or expect it to be a vector instead?

It will work, it doesn't change pytensor.Variables, a matrix will stay a matrix. What it does, is to flatten nested python structure, e.g. {"a": tensor_a, "b": [tensor_b, tensor_c]} becomes [tensor_a, tensor_b, tensor_c] (and a treedef object which saves the structure of the tree), where tensor_x are three different tensors of potentially different shape and dtype. As pytensor operators accept a list of tensors as input, the flattened version can be used to define our op. The shapes of the tensors aren't changed. This is also basically how operators in JAX are written, see the second code box in this paragraph: https://jax.readthedocs.io/en/latest/autodidax.html#pytrees-and-flattening-user-functions-inputs-and-outputs

@jdehning
Copy link
Author

I would begin in parallel to write an example notebook. I opened an issue here

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some comments, mostly related to tests and a couple of questions regarding PR scope.

Some of the advanced behaviors are a bit opaque from the outside, and I don't get if this is related to functionality that is actually needed (but it's perhaps easier to test like this) or we could do without for the sake of a simpler implementation.

I also have to try this locally, I'm curious how it behaves without static shapes on the inputs.

Overall, this is still looking great and very promising.

pytensor/link/jax/ops.py Outdated Show resolved Hide resolved
pytensor/link/jax/ops.py Outdated Show resolved Hide resolved
rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
]

x = pt.cumsum(x) # Now x has an unknown shape
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is currently an implementation detail, better to have x = tensor(..., shape=(None,)).

How does this work btw, what is out.type.shape?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your comment above, so I guess you are using PyTensor infer_shape stuff to figure out the output shape even if at write time cumsum did not.

However it will still not work if a root input has no static shape. I would suggest allowing users to define make_node of a JAX Op which exists exactly for this purpose. JAX doesn't have a concept of f(vector)->vector of unknown shape (because shapes are always concrete during tracing), but PyTensor is perfectly happy about this.

tests/link/jax/test_as_jax_op.py Show resolved Hide resolved
tests/link/jax/test_as_jax_op.py Outdated Show resolved Hide resolved
tests/link/jax/test_as_jax_op.py Outdated Show resolved Hide resolved
tests/link/jax/test_as_jax_op.py Outdated Show resolved Hide resolved

@as_jax_op
def f(x, y, message):
return x * jnp.ones(3), "Success: " + message
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens here, this output is just ignored? Do we need to support this sort of functionality?

Copy link
Author

@jdehning jdehning Feb 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I extended this test, the output can be used, but not by pytensor. We don't need to support this functionality, but it doesn't hurt much.

Copy link
Author

@jdehning jdehning Feb 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is replaced by the test "test_pytree_input_with_non_graph_args"

fn, _ = compare_jax_and_py(fg, test_values)


def test_as_jax_op13():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this functionality?

@jdehning
Copy link
Author

Regarding the functionality and whether one should remove some of it for sake of simplicity: My goal was that diffrax.diffeqsolve can be easily wrapped, but I understand that this might not be the goal for an inclusion in pytensor. The wrapping of diffrax.diffeqsolve requires three parts:

  1. Infering shapes. That includes infering the dimensionality of (None,) dimensions in the pytensor graph by walking up the predecessors until a non-None dimension is found in in its parents; and infering the output shapes from the jax function. I think we want to keep this, as it drastically facilitates the usage of wrapper.
  2. Automatically dealing with non-numerical arguments and outputs. Some of the arguments of diffrax.diffeqsolve are not array-like objects, and also some of the returned values are non-array objects. The additional code for this functionality is not much, basically 4 lines: the pt_vars/static_vars partitioning with eqx.partition in line 99, the output partitioning in line 307, the input combination with eqx.combine in line 302 and the output combination in line 213. I would argue that this is quite useful: non-numerical arguments are quite common in jax functions (but non-numerical outputs less), and it eliminates a whole category of potential runtime errors, that is that the wrapper tries to either transform non-numerical variables to jax, or non-numerical outputs to pytensor variables.
  3. Allowing wrapped jax functions as arguments to a wrapped function. The reason I programmed it, is that if ODEs have time-dependent variables, one has to define the time-dependent variables as functions that interpolate between the variables between their definition timepoints. As this function has to be called from inside the system of differential equation defined in JAX, one cannot directly use a pytensor function; and also one cannot use a jax function if one wants to add the time-dependent variables to the pytensor graph. The workaround for time-dependent ODEs if we would remove this functionality would be to write everything in a JAX function: the definition of the interpolation function and the ODE solver, and wrap with @as_jax_op the whole function. This functionality does add quite a bit of complexity, namely the class _WrappedFunc and several eqx.partition and eqx.combine in the rest of the code. One could think to remove it, I don't have a strong opinion about it. I also don't think it is useful for other use cases besides differential equations.

One additional remark, removing the functionality of point 2 and 3 would also remove the additionally dependency on equinox, I don't know how relevant it is for the decision
I will go through you other remarks later.

@ricardoV94
Copy link
Member

ricardoV94 commented Jan 28, 2025

Infering shapes. That includes infering the dimensionality of (None,) dimensions in the pytensor graph by walking up the predecessors until a non-None dimension is found in in its parents; and infering the output shapes from the jax function. I think we want to keep this, as it drastically facilitates the usage of wrapper.

There's already something like that: infer_static_shape that's used in Ops like Alloc and RandomVariable where the output types are functions of the input values, not just their types. However, note that something like x = vector(shape=(None,)) is a fundamental valid PyTensor type and we shouldn't a-priori prevent jaxifying Ops with these types of inputs. It's also very common. All PyMC models with dims actually look like None shape variables, because those are allowed to change size.

I suggested allowing the user to specify make_node which is the Op API of specifying how input types translate into output types in PyTensor. The static-shape logic you're doing can be a default, but shouldn't be the only option because it's fundamentally limited.

Automatically dealing with non-numerical arguments and outputs.

My issue with non-numerical outputs is that, from reading the tests, are arbitrarily truncated? In that test where a JAX function has a string output. PyTensor is rather flexible in what types of variables it can accommodate, for instance we have string types implemented here: https://github.com/pymc-devs/pymc/blob/e0e751199319e68f376656e2477c1543606c49c7/pymc/pytensorf.py#L1101-L1116

PyTensor itself has sparse matrices, homogenous lists, slices, None, scalars ... As such it seems odd to me to support some extra types only on this JAX wrapper Op helper. If those types are deemed useful enough for this wrapper to handle them, then the case would be made we should add them as regular PyTensor types, and not-special case JAX.

I guess I'm just not clear as to what the wrapper is doing with these special inputs (I'm assuming outputs are just being ignored as I wrote above). For the inputs, it's creating a partial function on the perform method? Then it sounds like they should also be implemented as Op.__props__, which is the PyTensor API for parametrizing Ops with non-symbolic inputs. It uses this for nice debugprint and reasoning for stuff like two Ops with the same props and inputs can be considered equivalent and merged.

Allowing wrapped jax functions as arguments to a wrapped function.

Also seems somewhat similar to PyTensor Ops with inner functions (ScalarLoop, OpFromGraph, Scan), that compile inner PyTensor functions (or dispatched equivalents on backends like JAX).

I guess the common theme is that this PR may be reinventing several things that PyTensor already does (I could be wrong), and there may be room to reuse existing functionality, or expanding it so that it's not restricted to the JAX backend, and more specifically this wrapper. Let me know if any of this makes sense.

@jdehning
Copy link
Author

jdehning commented Feb 6, 2025

I added a ToDo list in the first post, so you can check the progress. I refactored the code with the help of Cursor AI, and now JAXOp can also be called directly, which can be used to specify undetermined output shapes. I also think it would be useful to have a Zoom meeting to have a better idea of which direction to go. You could for example write me via the Pymc discourse. Mondays and Tuesdays are quite full for me, but otherwise, I am generally available.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement helper @as_jax_op to wrap JAX functions in PyTensor
2 participants