Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Writing function input to global mutable array fails inside jax.grad #26361

Open
ayaka14732 opened this issue Feb 6, 2025 · 0 comments
Open
Assignees
Labels
bug Something isn't working

Comments

@ayaka14732
Copy link
Member

Description

import jax
import jax.numpy as jnp
from jax._src.core import mutable_array

# Module-level state for the reproducer: a 0-d float32 value wrapped in a
# mutable array (shown as `Ref{float32[]}` in the traceback). `f` closes over
# `a_ref` rather than receiving it as an argument. Note: `mutable_array` comes
# from the private `jax._src.core` module, so its location may change between
# JAX releases.
a = jnp.float32(0)
a_ref = mutable_array(a)

# Decorators apply bottom-up, so the public `f` is jax.jit(jax.grad(<inner f>)):
# differentiate first, then jit-compile the gradient function.
@jax.jit
@jax.grad
def f(x):
    """Store ``x`` into the global ``a_ref``, then return ``2 * x``.

    The gradient of the return value with respect to ``x`` is 2; the write to
    ``a_ref`` does not feed into the returned value.

    In JAX 0.5.0, this write is where tracing fails under ``jax.grad``. The JVP
    rule for the write binds the swap primitive on the ref's tangent, and that
    tangent is a symbolic ``Zero``. Presumably this is because ``a_ref`` is a
    closed-over global rather than a differentiated input; this is inferred,
    not confirmed. Partial evaluation then rejects the ``Zero`` as "not a
    valid JAX type".
    """
    a_ref[()] = x  # writing function input to a global mutable array
    return 2 * x

# Evaluate the jitted gradient at x = 3.0. Expected output: 2.0, the derivative
# of 2 * x. With JAX 0.5.0 this call instead raises the TypeError reported in
# this issue.
x = jnp.float32(3.)
print(f(x))

Expected behaviour:

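No error: the script should run and print `2.0`, the gradient of `2 * x`. The write to `a_ref` does not affect the returned value.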

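Actual behaviour: tracing fails at the write `a_ref[()] = x`. Under `jax.grad`, the JVP rule for the write (`_swap_jvp`) calls `swap_p.bind` with the ref's tangent. That tangent is a symbolic `Zero(Ref{float32[]})`, since `a_ref` is a global rather than an input to `f`. Partial evaluation then tries to instantiate the `Zero` as a constant and raises: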

Traceback (most recent call last):
  File "/home/ayx/dev/checkify/16.py", line 15, in <module>
    print(f(x))
          ~^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/pjit.py", line 340, in cache_miss
    pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
                     ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/pjit.py", line 180, in _python_pjit_helper
    p, args_flat = _infer_params(fun, jit_info, args, kwargs)
                   ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/pjit.py", line 740, in _infer_params
    p, args_flat = _infer_params_impl(
                   ~~~~~~~~~~~~~~~~~~^
        fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals=avals)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/pjit.py", line 629, in _infer_params_impl
    jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
                                              ~~~~~~~~~~~~~~~~~~^
        flat_fun, in_type, attr_token, dbg,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        HashableFunction(res_paths, closure=()),
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        IgnoreKey(ji.inline))
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/linear_util.py", line 349, in memoized_fun
    ans = call(fun, *args)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/pjit.py", line 1310, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
                                                     ~~~~~~~~~~~~~~~~~~~~~~~~~^
        fun, in_type, debug_info=pe_debug)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2159, in trace_to_jaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/linear_util.py", line 192, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api_util.py", line 292, in _argnums_partial
    return _fun(*args, **kwargs)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api_util.py", line 72, in flatten_fun
    ans = f(*py_args, **py_kwargs)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api_util.py", line 652, in result_paths
    ans = _fun(*args, **kwargs)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api.py", line 394, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api.py", line 468, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args)
                  ~~~~^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api.py", line 1975, in _vjp
    out_primals, vjp = ad.vjp(flat_fun, primals_flat)
                       ~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/ad.py", line 252, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
                                        ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/ad.py", line 237, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
                               ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 574, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
                                      ~~~~~~~~~~~~~~~~^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/linear_util.py", line 192, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 587, in trace_to_subjaxpr_nounits
    out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
                                          ~~~~~~~~~~~~~~~~~~~~~~~~~~^
        f, trace, instantiate, in_pvals)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 616, in _trace_to_subjaxpr_nounits
    ans = f(*in_args)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api_util.py", line 72, in flatten_fun
    ans = f(*py_args, **py_kwargs)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/ad.py", line 78, in jvpfun
    out_primals, out_tangents = f(tag, primals, tangents)
                                ~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/ad.py", line 115, in jvp_subtrace
    ans = f(*in_tracers)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api_util.py", line 88, in flatten_fun_nokwargs
    ans = f(*py_args)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api_util.py", line 292, in _argnums_partial
    return _fun(*args, **kwargs)
  File "/home/ayx/dev/checkify/16.py", line 11, in f
    a_ref[()] = x
    ~~~~~^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/core.py", line 1960, in __setitem__
    def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x)
                                          ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/state/types.py", line 353, in _setitem
    return ref_set(tracer, idx, value)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/state/primitives.py", line 158, in ref_set
    ref_swap(ref_or_view, idx, value, _function_name="ref_set")
    ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/state/primitives.py", line 149, in ref_swap
    return swap_p.bind(ref, value, *flat_transforms, tree=tree)
           ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/core.py", line 463, in bind
    return self.bind_with_trace(prev_trace, args, params)
           ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/core.py", line 468, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/ad.py", line 425, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
                              ~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/state/primitives.py", line 443, in _swap_jvp
    swap_p.bind(ref_tangent, x_tangent, *idx, **params))
    ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/core.py", line 463, in bind
    return self.bind_with_trace(prev_trace, args, params)
           ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/core.py", line 468, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 214, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 224, in default_process_primitive
    tracers = map(self.instantiate_const, tracers)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 198, in instantiate_const
    return self.new_instantiated_const(const)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 167, in new_instantiated_const
    aval = get_aval(val)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/core.py", line 1477, in get_aval
    raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")
TypeError: Argument 'Zero(Ref{float32[]})' of type '<class 'jax._src.ad_util.Zero'>' is not a valid JAX type

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.2
python: 3.13.0rc3 (main, Oct  2 2024, 17:18:08) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='ayx1', release='6.10.11-1rodete2-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.10.11-1rodete2 (2024-10-16)', machine='x86_64')
@ayaka14732 ayaka14732 added the bug Something isn't working label Feb 6, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

3 participants