
Problem with quaxed.numpy.select and unxt.Quantity #314

Open
adrn opened this issue Nov 28, 2024 · 3 comments

Comments

@adrn
Contributor

adrn commented Nov 28, 2024

I haven't tried debugging this yet, but I noticed it in GalacticDynamics/coordinax#220:

```python
from unxt import Quantity
import quaxed.numpy as jnp

z = Quantity([0, 1, 2], "m")
# Raises RuntimeError: Refusing to materialise `Quantity`.
D = jnp.select([z == 0, z < 0, jnp.full(z.shape, 1, dtype=bool)], [z, z**2, z**3])
```
---------------------------------------------------------------------------
NotFoundLookupError                       Traceback (most recent call last)
File ~/projects/coordinax/.venv/lib/python3.10/site-packages/quax/_core.py:193, in _QuaxTrace.process_primitive(self, primitive, tracers, params)
    192 try:
--> 193     method, _ = rule.resolve_method(values)
    194 except plum.NotFoundLookupError:

    [... skipping hidden 1 frame]

File ~/projects/coordinax/.venv/lib/python3.10/site-packages/plum/function.py:342, in Function._handle_not_found_lookup_error(self, ex)
    340 if not self.owner:
    341     # Not in a class. Nothing we can do.
--> 342     raise ex from None
    344 # In a class. Walk through the classes in the class's MRO, except for this
    345 # class, and try to get the method.

    [... skipping hidden 1 frame]

File ~/projects/coordinax/.venv/lib/python3.10/site-packages/plum/resolver.py:377, in Resolver.resolve(self, target)
    375 if len(candidates) == 0:
    376     # There is no matching signature.
--> 377     raise NotFoundLookupError(self.function_name, target, self.methods)
    379 elif len(candidates) == 1:
    380     # There is exactly one matching signature. Success!

NotFoundLookupError: `select_n_dispatcher(Array([1, 3, 3], dtype=int32), Array([0, 0, 0], dtype=int32), Quantity['length'](Array([0, 1, 2], dtype=int32), unit='m'), Quantity['area'](Array([0, 1, 4], dtype=int32), unit='m2'), Quantity['volume'](Array([0, 1, 8], dtype=int32), unit='m3'))` could not be resolved.

Closest candidates are the following:
    select_n_dispatcher(which: typing.Union[ArrayLike], *cases: unxt._src.quantity.base.AbstractQuantity) -> unxt._src.quantity.base.AbstractQuantity
        <function _select_n_p_jqq at 0x1118f29e0> @ ~/projects/coordinax/.venv/lib/python3.10/site-packages/jaxtyping/_decorator.py:3482
    select_n_dispatcher(which: unxt._src.quantity.base.AbstractQuantity, *cases: unxt._src.quantity.base.AbstractQuantity) -> unxt._src.quantity.base.AbstractQuantity
        <function _select_n_p at 0x1118f1ea0> @ ~/projects/coordinax/.venv/lib/python3.10/site-packages/jaxtyping/_decorator.py:3448
    select_n_dispatcher(which: typing.Union[ArrayLike], case0: typing.Union[ArrayLike], case1: unxt._src.quantity.base.AbstractQuantity) -> unxt._src.quantity.base.AbstractQuantity
        <function _select_n_p_jjq at 0x1118f2440> @ ~/projects/coordinax/.venv/lib/python3.10/site-packages/jaxtyping/_decorator.py:3466


During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Cell In[5], line 3
      1 z = Quantity([0, 1, 2], "m")
----> 3 D = jnp.select(
      4     [z == 0, z < 0, jnp.full(z.shape, True, dtype=bool)],
      5     [z, z**2, z**3],
      6 )
      7 D

File ~/projects/coordinax/.venv/lib/python3.10/site-packages/quax/_core.py:324, in _Quaxify.__call__(self, *args, **kwargs)
    318 dynamic = jtu.tree_map(
    319     ft.partial(_wrap_tracer, trace),
    320     dynamic,
    321     is_leaf=_is_value,
    322 )
    323 fn, args, kwargs = eqx.combine(dynamic, static)
--> 324 out = fn(*args, **kwargs)
    325 out = jtu.tree_map(ft.partial(_unwrap_tracer, trace), out)
    326 return out

File ~/projects/coordinax/.venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2676, in select(condlist, choicelist, default)
   2674 conditions = stack(broadcast_arrays(False, *condlist))
   2675 idx = argmax(conditions.astype(bool), axis=0)
-> 2676 return lax.select_n(*broadcast_arrays(idx, *choicelist))

File ~/projects/coordinax/.venv/lib/python3.10/site-packages/jax/_src/lax/lax.py:1231, in select_n(which, *cases)
   1229 if len(cases) == 0:
   1230   raise ValueError("select_n() must have at least one case")
-> 1231 return select_n_p.bind(which, *cases)

File ~/projects/coordinax/.venv/lib/python3.10/site-packages/jax/_src/core.py:438, in Primitive.bind(self, *args, **params)
    435 def bind(self, *args, **params):
    436   assert (not config.enable_checks.value or
    437           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 438   return self.bind_with_trace(find_top_trace(args), args, params)

File ~/projects/coordinax/.venv/lib/python3.10/site-packages/jax/_src/core.py:442, in Primitive.bind_with_trace(self, trace, args, params)
    440 def bind_with_trace(self, trace, args, params):
    441   with pop_level(trace.level):
--> 442     out = trace.process_primitive(self, map(trace.full_raise, args), params)
    443   return map(full_lower, out) if self.multiple_results else full_lower(out)

File ~/projects/coordinax/.venv/lib/python3.10/site-packages/quax/_core.py:195, in _QuaxTrace.process_primitive(self, primitive, tracers, params)
    193     method, _ = rule.resolve_method(values)
    194 except plum.NotFoundLookupError:
--> 195     out = _default_process(primitive, values, params)
    196 else:
    197     out = method(*values, **params)

File ~/projects/coordinax/.venv/lib/python3.10/site-packages/quax/_core.py:133, in _default_process(primitive, values, params)
    130 # Avoid an infinite loop, by pushing a new interpreter to the dynamic interpreter
    131 # stack.
    132 with jax.ensure_compile_time_eval():
--> 133     return default(primitive, values, params)

File ~/projects/coordinax/.venv/lib/python3.10/site-packages/quax/_core.py:445, in Value.default(primitive, values, params)
    443 for x in values:
    444     if _is_value(x):
--> 445         arrays.append(x.materialise())
    446     elif eqx.is_array_like(x):
    447         arrays.append(cast(ArrayLike, x))

    [... skipping hidden 2 frame]

File ~/projects/coordinax/.venv/lib/python3.10/site-packages/jaxtyping/_decorator.py:449, in jaxtyped.<locals>.wrapped_fn_impl(args, kwargs, bound, memos)
    446             raise TypeCheckError(msg) from e
    448 # Actually call the function.
--> 449 out = fn(*args, **kwargs)
    451 if full_signature.return_annotation is not inspect.Signature.empty:
    452     # Now type-check the return value. We need to include the
    453     # parameters in the type-checking here in case there are any
   (...)
    464     # checking of the parameters. Unfortunately there doesn't seem
    465     # to be a way around that, so c'est la vie.
    466     kwargs[output_name] = out

File ~/projects/coordinax/.venv/lib/python3.10/site-packages/unxt/_src/quantity/base.py:280, in AbstractQuantity.materialise(self)
    278 def materialise(self) -> NoReturn:
    279     msg = "Refusing to materialise `Quantity`."
--> 280     raise RuntimeError(msg)

RuntimeError: Refusing to materialise `Quantity`.
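
The last frames of the traceback show how `jnp.select` reaches the failing dispatch: it stacks the conditions behind an all-`False` row, takes `argmax` to get an index array, and calls `lax.select_n(idx, *cases)`. Judging from the arguments in the `NotFoundLookupError`, the default (`0`) is prepended as case 0, so the call is `select_n(which, plain_array, m, m2, m3)`. That matches none of the registered signatures (`(ArrayLike, *AbstractQuantity)`, `(AbstractQuantity, *AbstractQuantity)`, `(ArrayLike, ArrayLike, AbstractQuantity)`), so quax falls back to materialising the `Quantity`, which `unxt` refuses. Below is a minimal NumPy sketch of that lowering, assuming `np.choose` as a stand-in for `lax.select_n`; the helper names are hypothetical.

```python
import numpy as np


def select_index(condlist):
    """Build the `which` index that jnp.select hands to lax.select_n."""
    # Put an all-False row first: argmax returns 0 (the default case) where no
    # condition holds, and otherwise 1 + the position of the first True condition.
    conditions = np.stack(np.broadcast_arrays(False, *condlist))
    return np.argmax(conditions.astype(bool), axis=0)


def select_via_select_n(condlist, choicelist, default=0):
    """Hypothetical re-implementation of select on top of a select_n stand-in."""
    idx = select_index(condlist)
    # The default becomes case 0, ahead of the user's choices.
    which, *cases = np.broadcast_arrays(idx, default, *choicelist)
    # np.choose stands in for lax.select_n: out[i] = cases[which[i]][i].
    return np.choose(which, cases)


z = np.array([0, 1, 2])
condlist = [z == 0, z < 0, np.full(z.shape, True, dtype=bool)]
choicelist = [z, z**2, z**3]

print(select_index(condlist))  # [1 3 3], the `which` argument in the error message
print(select_via_select_n(condlist, choicelist))  # [0 1 8]
```

Because the default is a plain array while the other cases are quantities, the argument list is heterogeneous before any unit logic can run.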
@nstarman
Contributor

Looks like we need to register a corresponding select_n_p dispatch.

@nstarman nstarman transferred this issue from GalacticDynamics/quaxed Dec 3, 2024
@nstarman
Contributor

nstarman commented Dec 3, 2024

When I do

```python
from unxt import Quantity
import quaxed.numpy as jnp

z = jnp.array([0, 1, 2])
D = jnp.select([z == 0, z < 0, jnp.full(z.shape, 1, dtype=bool)], [z, z**2, z**3])
```

The result is Array([0, 1, 8], dtype=int32): the first element comes from z and the other two from z**3. With quantities, the result array would therefore mix elements of different dimensions (here length and volume), which isn't allowed.
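
To make the dimension argument concrete, here is a toy, unit-aware `select_n` built on a hypothetical `ToyQuantity` class (not `unxt`'s API). It only accepts cases that share one unit, so a call shaped like the one in the traceback (a plain default plus `m`, `m2`, and `m3` cases) is rejected. As a simplification it compares unit strings rather than dimensions, so it would also reject convertible units such as `m` and `km`.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ToyQuantity:
    """Hypothetical stand-in for unxt.Quantity: an array plus a unit string."""

    value: np.ndarray
    unit: str


def toy_select_n(which, *cases):
    """Hypothetical unit-aware select_n: all cases must share a single unit."""
    # Plain arrays (like the prepended default) count as dimensionless: unit "".
    units = {c.unit if isinstance(c, ToyQuantity) else "" for c in cases}
    if len(units) != 1:
        raise TypeError(f"select_n cases have incompatible units: {sorted(units)}")
    values = [c.value if isinstance(c, ToyQuantity) else np.asarray(c) for c in cases]
    # Same elementwise selection as lax.select_n, applied to the bare values.
    result = np.choose(which, np.broadcast_arrays(*values))
    (unit,) = units
    return ToyQuantity(result, unit) if unit else result


z = ToyQuantity(np.array([0, 1, 2]), "m")
z2 = ToyQuantity(z.value**2, "m2")
z3 = ToyQuantity(z.value**3, "m3")

# Same shape of arguments as the failing call: which, plain default, then m, m2, m3.
try:
    toy_select_n(np.array([1, 3, 3]), 0, z, z2, z3)
except TypeError as err:
    print(err)  # select_n cases have incompatible units: ['', 'm', 'm2', 'm3']
```

A real dispatch would presumably convert compatible units to a common one before selecting, rather than comparing strings, but no conversion can reconcile length, area, and volume.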

@nstarman
Contributor

Close?

Labels: None yet
Projects: None yet
Development: No branches or pull requests
2 participants