
Compilation error when processing pytrees with empty leaves #1087

Open
benoitsteiner opened this issue Jan 13, 2025 · 3 comments

Comments

@benoitsteiner

The Neuron compiler fails to compile the code below:

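import flax.struct
import jax

@flax.struct.dataclass
class Tree:
    t1: jax.Array
    t2: jax.Array

# t1 is a zero-size array; t2 has three elements.
tree1 = Tree(jax.numpy.zeros([0]), jax.numpy.zeros([3]))
tree2 = Tree(jax.numpy.zeros([0]), jax.numpy.zeros([3]))

# Summing the zero-size arrays directly compiles fine on Neuron.
@jax.jit
def reduce_one(*args):
    return sum(args)

print(reduce_one(tree1.t1, tree2.t1))

# Summing the same arrays as pytree leaves fails to compile on Neuron.
@jax.jit
def reduce(*trees):
    # Sum matching leaves across all input trees.
    return jax.tree.map(lambda *leaves: sum(leaves), *trees)

print(reduce(tree1, tree2))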

The code compiles and runs correctly on CPU; the bug is only triggered on Neuron.
It also compiles on Neuron when the call to the reduce function is commented out, which indicates that the compiler can handle empty JAX arrays in general but has problems when these empty arrays are embedded in a pytree.
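To check which leaves of a pytree are zero-size (and so are candidates for triggering this failure), a small helper can walk the tree. This is a sketch assuming JAX 0.4.6 or later for `jax.tree_util.tree_flatten_with_path`; the `NamedTuple` stand-in for the flax dataclass and the helper name `zero_size_leaves` are hypothetical, and it has not been run on Neuron.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Tree(NamedTuple):
    # Hypothetical stand-in for the flax.struct.dataclass above;
    # NamedTuples are registered as pytrees by JAX out of the box.
    t1: jax.Array
    t2: jax.Array


def zero_size_leaves(tree):
    """Return the key paths (as strings) of all leaves with zero elements."""
    paths_and_leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
    return [
        jax.tree_util.keystr(path)
        for path, leaf in paths_and_leaves
        if jnp.size(leaf) == 0
    ]


tree1 = Tree(jnp.zeros([0]), jnp.zeros([3]))
print(zero_size_leaves(tree1))
```

The exact path string (for example an attribute-style or index-style key) depends on how the JAX version in use registers the container type.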

@benoitsteiner
Author

Another data point: the code also executes correctly when the @jax.jit annotations are removed, so that it runs eagerly instead of being compiled.
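The same comparison can be made without editing the decorators by using `jax.disable_jit()`, a context manager under which functions decorated with `@jax.jit` run eagerly, op by op. A minimal sketch, using a `NamedTuple` in place of the flax dataclass (an assumption, so that it needs only JAX):

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Tree(NamedTuple):
    # Hypothetical stand-in for the flax.struct.dataclass in the report.
    t1: jax.Array
    t2: jax.Array


@jax.jit
def reduce(*trees):
    # Same body as in the report: sum matching leaves across trees.
    return jax.tree.map(lambda *leaves: sum(leaves), *trees)


tree1 = Tree(jnp.zeros([0]), jnp.zeros([3]))
tree2 = Tree(jnp.zeros([0]), jnp.zeros([3]))

# Inside this context the jitted function is not traced and compiled as a
# single program; each operation is dispatched eagerly instead.
with jax.disable_jit():
    eager_result = reduce(tree1, tree2)

print(eager_result)
```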

@aws-liuayu

Thanks for filing the issue. We will look into it and get back to you.

@aws-zhehongb

The current version of the Neuron compiler doesn't support zero-size arrays as inputs. We are looking for a solution internally. In the meantime, could you work around the problem by avoiding sending zero-size arrays to the TRN device?
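One possible shape for that workaround, sketched under assumptions: all input trees share the same structure and leaf shapes, and reducing zero-size leaves yields a zero-size leaf of the same shape and dtype. The wrapper `reduce_skipping_empty` and helper `_sum_leaves` are hypothetical names, and the `NamedTuple` again stands in for the flax dataclass. Only non-empty leaves are passed to the jitted function; empty results are rebuilt as host-side NumPy arrays. It is untested on a TRN device.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class Tree(NamedTuple):
    # Hypothetical stand-in for the flax.struct.dataclass in the report.
    t1: jax.Array
    t2: jax.Array


@jax.jit
def _sum_leaves(*leaf_lists):
    # Each argument holds the non-empty leaves of one tree, in flatten order;
    # sum them position by position across trees.
    return [sum(group) for group in zip(*leaf_lists)]


def reduce_skipping_empty(*trees):
    """Hypothetical wrapper: zero-size leaves never reach the jitted function."""
    leaves_per_tree = [jax.tree_util.tree_leaves(t) for t in trees]
    treedef = jax.tree_util.tree_structure(trees[0])
    template = leaves_per_tree[0]
    keep = [i for i, leaf in enumerate(template) if leaf.size > 0]

    # Only the non-empty leaves are sent to the compiled computation.
    summed = _sum_leaves(*[[leaves[i] for i in keep] for leaves in leaves_per_tree])

    # Rebuild the output: empty leaves become host-side NumPy arrays with the
    # same shape and dtype; non-empty leaves take the device results.
    out = [np.zeros(leaf.shape, leaf.dtype) for leaf in template]
    for i, value in zip(keep, summed):
        out[i] = value
    return jax.tree_util.tree_unflatten(treedef, out)


tree1 = Tree(jnp.zeros([0]), jnp.zeros([3]))
tree2 = Tree(jnp.zeros([0]), jnp.ones([3]))
result = reduce_skipping_empty(tree1, tree2)
print(result)
```

Since the list of non-empty leaves is computed in Python before the jitted call, a change in which leaves are empty changes the number of arguments and triggers a retrace, much as a change in leaf shapes already would.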
