You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
"""Reproducer: writing a traced input into a global mutable array under jit(grad(...)).

``f`` stores its argument into the module-level mutable array ``a_ref`` and
returns ``2 * x``. The reported expectation is that ``jax.jit(jax.grad(f))``
runs without error (the gradient of ``2 * x`` is ``2.0``). The reported
traceback instead ends in ``_swap_jvp`` (``jax/_src/state/primitives.py``),
which re-binds ``swap_p`` on the ref's tangent. That tangent is a symbolic
``Zero(Ref{float32[]})``, and partial evaluation fails to instantiate it with
``TypeError: Argument 'Zero(Ref{float32[]})' ... is not a valid JAX type``.
"""

import jax
import jax.numpy as jnp

# NOTE(review): ``mutable_array`` is imported from the private ``jax._src``
# package; its name/location may differ between JAX releases -- confirm
# against the installed version before relying on this reproducer.
from jax._src.core import mutable_array

a = jnp.float32(0)
# Global mutable array that ``f`` closes over and writes into.
a_ref = mutable_array(a)


@jax.jit
@jax.grad
def f(x):
    """Write ``x`` into ``a_ref`` as a side effect and return ``2 * x``."""
    a_ref[()] = x  # writing function input to a global mutable array
    return 2 * x


if __name__ == "__main__":
    x = jnp.float32(3.)
    print(f(x))
Expected behaviour:
No error: `f(x)` should return the derivative of `2 * x`, so `print(f(x))` should print `2.0`.
Actual behaviour: calling `f(x)` raises the `TypeError` shown in the traceback below.
Traceback (most recent call last):
File "/home/ayx/dev/checkify/16.py", line 15, in <module>
print(f(x))
~^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/pjit.py", line 340, in cache_miss
pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/pjit.py", line 180, in _python_pjit_helper
p, args_flat = _infer_params(fun, jit_info, args, kwargs)
~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/pjit.py", line 740, in _infer_params
p, args_flat = _infer_params_impl(
~~~~~~~~~~~~~~~~~~^
fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals=avals)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/pjit.py", line 629, in _infer_params_impl
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
~~~~~~~~~~~~~~~~~~^
flat_fun, in_type, attr_token, dbg,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
HashableFunction(res_paths, closure=()),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IgnoreKey(ji.inline))
^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/linear_util.py", line 349, in memoized_fun
ans = call(fun, *args)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/pjit.py", line 1310, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
~~~~~~~~~~~~~~~~~~~~~~~~~^
fun, in_type, debug_info=pe_debug)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/profiler.py", line 333, in wrapper
return func(*args, **kwargs)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2159, in trace_to_jaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/linear_util.py", line 192, in call_wrapped
return self.f_transformed(*args, **kwargs)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api_util.py", line 292, in _argnums_partial
return _fun(*args, **kwargs)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api_util.py", line 72, in flatten_fun
ans = f(*py_args, **py_kwargs)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api_util.py", line 652, in result_paths
ans = _fun(*args, **kwargs)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api.py", line 394, in grad_f
_, g = value_and_grad_f(*args, **kwargs)
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api.py", line 468, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args)
~~~~^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api.py", line 1975, in _vjp
out_primals, vjp = ad.vjp(flat_fun, primals_flat)
~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/ad.py", line 252, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/ad.py", line 237, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/profiler.py", line 333, in wrapper
return func(*args, **kwargs)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 574, in trace_to_jaxpr_nounits
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
~~~~~~~~~~~~~~~~^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/linear_util.py", line 192, in call_wrapped
return self.f_transformed(*args, **kwargs)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 587, in trace_to_subjaxpr_nounits
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
~~~~~~~~~~~~~~~~~~~~~~~~~~^
f, trace, instantiate, in_pvals)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 616, in _trace_to_subjaxpr_nounits
ans = f(*in_args)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api_util.py", line 72, in flatten_fun
ans = f(*py_args, **py_kwargs)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/ad.py", line 78, in jvpfun
out_primals, out_tangents = f(tag, primals, tangents)
~^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/ad.py", line 115, in jvp_subtrace
ans = f(*in_tracers)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api_util.py", line 88, in flatten_fun_nokwargs
ans = f(*py_args)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/api_util.py", line 292, in _argnums_partial
return _fun(*args, **kwargs)
File "/home/ayx/dev/checkify/16.py", line 11, in f
a_ref[()] = x
~~~~~^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/core.py", line 1960, in __setitem__
def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x)
~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/state/types.py", line 353, in _setitem
return ref_set(tracer, idx, value)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/state/primitives.py", line 158, in ref_set
ref_swap(ref_or_view, idx, value, _function_name="ref_set")
~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/state/primitives.py", line 149, in ref_swap
return swap_p.bind(ref, value, *flat_transforms, tree=tree)
~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/core.py", line 463, in bind
return self.bind_with_trace(prev_trace, args, params)
~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/core.py", line 468, in bind_with_trace
return trace.process_primitive(self, args, params)
~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/ad.py", line 425, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/state/primitives.py", line 443, in _swap_jvp
swap_p.bind(ref_tangent, x_tangent, *idx, **params))
~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/core.py", line 463, in bind
return self.bind_with_trace(prev_trace, args, params)
~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/core.py", line 468, in bind_with_trace
return trace.process_primitive(self, args, params)
~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 214, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 224, in default_process_primitive
tracers = map(self.instantiate_const, tracers)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 198, in instantiate_const
return self.new_instantiated_const(const)
~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 167, in new_instantiated_const
aval = get_aval(val)
File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/core.py", line 1477, in get_aval
raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")
TypeError: Argument 'Zero(Ref{float32[]})' of type '<class 'jax._src.ad_util.Zero'>' is not a valid JAX type
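System info (python version, jaxlib version, accelerator, etc.): not filled in; the traceback paths indicate Python 3.13 in a virtualenv (`venv/lib/python3.13/site-packages/jax`).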
Description
Expected behaviour:
No error
Actual behaviour:
System info (python version, jaxlib version, accelerator, etc.)
The text was updated successfully, but these errors were encountered: