Skip to content

Commit

Permalink
some jax core functionalities are deprecated (#1966)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored Jan 30, 2025
1 parent 7514990 commit 5465982
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions numpyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
except ImportError:
import jax.linear_util as lu

try:
from jax.extend.core.primitives import call_p, closed_call_p
except ImportError:
from jax.core import call_p, closed_call_p

from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
from jax.interpreters.pxla import xla_pmap_p

Expand Down Expand Up @@ -96,15 +101,15 @@ def track_deps_call_rule(eqn, provenance_inputs):
return track_deps_jaxpr(eqn.params["call_jaxpr"], provenance_inputs)


track_deps_rules[core.call_p] = track_deps_call_rule
track_deps_rules[call_p] = track_deps_call_rule
track_deps_rules[xla_pmap_p] = track_deps_call_rule


def track_deps_closed_call_rule(eqn, provenance_inputs):
return track_deps_jaxpr(eqn.params["call_jaxpr"].jaxpr, provenance_inputs)


track_deps_rules[core.closed_call_p] = track_deps_closed_call_rule
track_deps_rules[closed_call_p] = track_deps_closed_call_rule


def track_deps_pjit_rule(eqn, provenance_inputs):
Expand Down

0 comments on commit 5465982

Please sign in to comment.