-
We have pm.sampling_jax, so you can sample with NumPyro or BlackJAX using the JAX XLA backend.
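Whether sampling_jax ends up on a GPU depends on which devices JAX itself can see, so a quick check before sampling can save some confusion. A minimal sketch, assuming only that jax is installed (the helper name describe_jax_backend is made up for illustration):

```python
import jax


def describe_jax_backend():
    """Return the active JAX backend name and the devices it exposes."""
    backend = jax.default_backend()  # for example "cpu" or "gpu"
    devices = [str(d) for d in jax.devices()]
    return backend, devices


backend, devices = describe_jax_backend()
print(f"JAX backend: {backend}")
print(f"Devices: {devices}")
```

If this reports only a CPU on a machine that has a GPU, the likely cause is a CPU-only jaxlib install rather than anything in PyMC.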
-
Thanks, Ferrine. Yes, pm.sampling_jax runs wonderfully fast on the HPC. One problem I have is that the ODE class does not yet have an Op that JAX can recognize (within PyMC or Aesara), and I need it in the code I am running (a sort of Lotka-Volterra problem).
-
Thanks, the idea does look promising.
-
Is this example from the PyMC site a good starting point? "How to wrap a JAX function for use in PyMC"
-
@fbarfi did you try
-
Thank you, @michaelosthege, for the suggestion. I tried to install sunode on my M1 Max earlier but could not get it to work. On the other hand, I made some progress following @twiecki's suggestion. I wrote an Aesara Op using SciPy's odeint (and another one using JAX's odeint), and it works great with pm.sample(). I get errors when I try to jaxify it (to use sampling_jax), which I am trying to figure out. The current problems involve messages like "TypeError: Shapes must be 1D sequences of concrete values of integer type, got (50, Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>)." or "The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[4])>with<DynamicJaxprTrace(level=0/1)>". I am learning how to solve these issues.
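Both kinds of message can be reproduced in a few lines of plain JAX, which may help narrow down where they come from. This is only a sketch under assumptions: the function names are invented for illustration and are not taken from the Op described above. The first pair shows NumPy being called on a traced value; the second shows a traced integer being used as an array shape.

```python
import numpy as np
import jax
import jax.numpy as jnp


def scale_with_numpy(y):
    # np.asarray needs concrete values, which a tracer cannot provide under jit.
    return np.asarray(y) * 2.0


def scale_with_jax(y):
    # The jnp equivalent stays inside JAX, so it traces without trouble.
    return jnp.asarray(y) * 2.0


def make_grid(n):
    # Under a plain jit, n is traced, so it cannot be used as a shape.
    return jnp.zeros((50, n))


def tracing_error(fn, *args):
    """Return the TypeError raised while calling fn, or None if it succeeds."""
    try:
        fn(*args)
    except TypeError as err:
        return err
    return None


y = jnp.ones(4)
numpy_error = tracing_error(jax.jit(scale_with_numpy), y)
shape_error = tracing_error(jax.jit(make_grid), 4)

# Possible fixes: use jnp inside traced code, and mark shape arguments static.
scaled = jax.jit(scale_with_jax)(y)
grid = jax.jit(make_grid, static_argnums=0)(4)

print(type(numpy_error).__name__)
print(type(shape_error).__name__)
print(scaled.shape, grid.shape)
```

Marking an argument static makes JAX recompile for each distinct value, which is usually acceptable for something like a fixed number of time points.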
-
I should have added that using odeint from JAX in my Aesara Op is much faster than using odeint from SciPy, even without sampling_jax.
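For readers who want to try the JAX solver on its own, here is a minimal sketch of a Lotka-Volterra system integrated with jax.experimental.ode.odeint. It is not the Op from this thread: the parameter values, initial state, time grid, and function names are all assumptions chosen for illustration, and the Aesara wrapper is left out entirely.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

# Use float64 so the solver tolerances below are meaningful.
jax.config.update("jax_enable_x64", True)


def lotka_volterra(y, t, theta):
    """Right-hand side dy/dt, with theta = (alpha, beta, gamma, delta)."""
    prey, predator = y[0], y[1]
    alpha, beta, gamma, delta = theta[0], theta[1], theta[2], theta[3]
    d_prey = alpha * prey - beta * prey * predator
    d_predator = delta * prey * predator - gamma * predator
    return jnp.stack([d_prey, d_predator])


@jax.jit
def solve(theta, y0, times):
    """Integrate the system; the result has shape (len(times), 2)."""
    return odeint(lotka_volterra, y0, times, theta, rtol=1e-6, atol=1e-6)


def sum_of_squares(theta, y0, times, observed):
    """Squared distance between the simulated and observed trajectories."""
    return jnp.sum((solve(theta, y0, times) - observed) ** 2)


# Illustrative values only (assumptions, not from the thread).
theta = jnp.array([1.0, 0.1, 1.5, 0.075])
y0 = jnp.array([10.0, 5.0])
times = jnp.linspace(0.0, 10.0, 50)

trajectory = solve(theta, y0, times)

# Gradient with respect to theta, evaluated at slightly perturbed parameters.
grad_theta = jax.grad(sum_of_squares)(theta * 1.1, y0, times, trajectory)

print(trajectory.shape)
print(grad_theta.shape)
```

Because odeint is differentiable, jax.grad can supply the parameter gradients that a gradient-based sampler such as NUTS needs.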
-
I am using PyMC 4.0 on an HPC server. It works fine but is much slower than on my laptop. Can one use a GPU backend? Thanks for any feedback.