The Neuron compiler fails to compile the code in this report. The same code compiles properly on CPU; the bug is only triggered on Neuron.
The code also compiles properly when the call to the reduce function is commented out, which indicates that the compiler can handle empty JAX arrays in general but fails when those empty arrays are embedded in a pytree.
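A minimal sketch of the pattern described above, assuming the reduce function sums the leaves of a pytree. `reduce_tree` and the tree contents are hypothetical stand-ins, not the reporter's code, and whether this exact snippet triggers the Neuron failure is untested; on CPU it runs normally.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-in for the reporter's reduce function (an assumption):
# sum each leaf of the pytree, then add the partial sums together.
@jax.jit
def reduce_tree(tree):
    partial_sums = jax.tree_util.tree_map(jnp.sum, tree)
    return jax.tree_util.tree_reduce(lambda a, b: a + b, partial_sums)


# A pytree in which one leaf is a zero-size array (shape (0,)).
tree = {
    "full": np.arange(4.0, dtype=np.float32),
    "empty": np.zeros((0,), dtype=np.float32),
}

result = reduce_tree(tree)
print(float(result))  # 6.0 on CPU
```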
Another data point: the code also executes correctly when the @jax.jit decorators are removed, so that the functions run eagerly instead of being compiled.
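For comparing the two paths without editing the decorators, JAX provides the `jax.disable_jit()` context manager, which makes `@jax.jit` a no-op inside the block. The sketch below reuses a hypothetical `reduce_tree` (an assumption, not the reporter's function) and checks that the compiled and eager results agree on CPU.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def reduce_tree(tree):
    # Hypothetical reduction (assumption): sum all elements of all leaves.
    partial_sums = jax.tree_util.tree_map(jnp.sum, tree)
    return jax.tree_util.tree_reduce(lambda a, b: a + b, partial_sums)


tree = {
    "full": np.ones((3,), dtype=np.float32),
    "empty": np.zeros((0,), dtype=np.float32),
}

# Compiled path: the whole function is traced and compiled as one program.
compiled = reduce_tree(tree)

# Eager path: inside this block @jax.jit is disabled, so the same code
# runs op by op instead of as one compiled program.
with jax.disable_jit():
    eager = reduce_tree(tree)

print(float(compiled), float(eager))  # 3.0 3.0
```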
The current version of the Neuron compiler does not support zero-size arrays as inputs. We are looking for a solution internally. In the meantime, could you work around the problem by avoiding sending zero-size arrays to the TRN device?
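One way to apply the suggested workaround is to strip zero-size leaves on the host before calling the jitted function. This is a sketch under stated assumptions: `drop_empty_leaves` and `reduce_tree` are hypothetical helpers, and the approach relies on JAX treating `None` as an empty subtree, so the dropped leaves are never passed in as device inputs.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def reduce_tree(tree):
    # Hypothetical sum reduction over all leaves (assumption).
    partial_sums = jax.tree_util.tree_map(jnp.sum, tree)
    # initializer=0.0 keeps the result defined even if every leaf was dropped.
    return jax.tree_util.tree_reduce(lambda a, b: a + b, partial_sums, 0.0)


def drop_empty_leaves(tree):
    # Replace zero-size host arrays with None; JAX treats None as a pytree
    # node with no children, so those leaves disappear from the inputs.
    return jax.tree_util.tree_map(lambda x: x if np.size(x) > 0 else None, tree)


tree = {
    "full": np.arange(4.0, dtype=np.float32),
    "empty": np.zeros((0,), dtype=np.float32),
}

filtered = drop_empty_leaves(tree)
print(len(jax.tree_util.tree_leaves(filtered)))  # 1: only the non-empty array
print(float(reduce_tree(filtered)))  # 6.0
```

Dropping empty leaves is only safe when an empty array contributes the identity of the reduction (0 for a sum); for other reductions, the function would need to handle the missing leaves explicitly.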