Skip to content

Commit

Permalink
remove unnecessary device_put (#1952)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored Jan 17, 2025
1 parent 8a67269 commit 4704656
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 14 deletions.
4 changes: 2 additions & 2 deletions numpyro/contrib/control_flow/cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import partial
from typing import Any, Callable

from jax import device_put, lax
from jax import lax

from numpyro import handlers
from numpyro.ops.pytree import PytreeTrace
Expand Down Expand Up @@ -69,7 +69,7 @@ def cond_wrapper(

wrapped_true_fun = wrap_fn(true_fun, substitute_stack)
wrapped_false_fun = wrap_fn(false_fun, substitute_stack)
wrapped_operand = device_put((rng_key, operand))
wrapped_operand = (rng_key, operand)
return lax.cond(pred, wrapped_true_fun, wrapped_false_fun, wrapped_operand)


Expand Down
5 changes: 2 additions & 3 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Callable, Optional

import jax
from jax import device_put, lax, random
from jax import lax, random
import jax.numpy as jnp

from numpyro import handlers
Expand Down Expand Up @@ -228,7 +228,6 @@ def body_fn(wrapped_carry, x, prefix=None):
# return early if length = unroll_steps
if length == unroll_steps:
return wrapped_carry, (PytreeTrace({}), y0s)
wrapped_carry = jax.tree.map(device_put, wrapped_carry)
wrapped_carry, (pytree_trace, ys) = lax.scan(
body_fn, wrapped_carry, xs_, length - unroll_steps, reverse
)
Expand Down Expand Up @@ -331,7 +330,7 @@ def body_fn(wrapped_carry, x):

return (i + 1, rng_key, carry), (PytreeTrace(trace), y)

wrapped_carry = jax.tree.map(device_put, (0, rng_key, init))
wrapped_carry = (jnp.asarray(0), rng_key, init)
last_carry, (pytree_trace, ys) = lax.scan(
body_fn, wrapped_carry, xs, length=length, reverse=reverse
)
Expand Down
2 changes: 1 addition & 1 deletion numpyro/infer/barker.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
wa_state,
rng_key,
)
return jax.device_put(init_state)
return init_state

def postprocess_fn(self, args, kwargs):
if self._postprocess_fn is None:
Expand Down
4 changes: 2 additions & 2 deletions numpyro/infer/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import warnings

from jax import device_put, lax, random, vmap
from jax import lax, random, vmap
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp

Expand Down Expand Up @@ -359,7 +359,7 @@ def init_kernel(
wa_state,
rng_key_hmc,
)
return device_put(hmc_state)
return hmc_state

def _hmc_next(
step_size,
Expand Down
4 changes: 2 additions & 2 deletions numpyro/infer/hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np

from jax import device_put, grad, jacfwd, random, value_and_grad
from jax import grad, jacfwd, random, value_and_grad
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
from jax.scipy.special import expit
Expand Down Expand Up @@ -148,7 +148,7 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):

z = {**gibbs_sites, **hmc_state.z}

return device_put(HMCGibbsState(z, hmc_state, rng_key))
return HMCGibbsState(z, hmc_state, rng_key)

def sample(self, state, model_args, model_kwargs):
model_kwargs = {} if model_kwargs is None else model_kwargs
Expand Down
4 changes: 2 additions & 2 deletions numpyro/infer/sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from collections import namedtuple

from jax import device_put, lax, random, vmap
from jax import lax, random, vmap
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
from jax.scipy.special import logsumexp
Expand Down Expand Up @@ -174,7 +174,7 @@ def init_kernel(
adapt_state,
rng_key_sa,
)
return device_put(sa_state)
return sa_state

def sample_kernel(sa_state, model_args=(), model_kwargs=None):
pe_fn = potential_fn
Expand Down
4 changes: 2 additions & 2 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from tqdm.auto import tqdm as tqdm_auto

import jax
from jax import device_put, jit, lax, vmap
from jax import jit, lax, vmap
from jax.core import Tracer
from jax.experimental import io_callback
import jax.numpy as jnp
Expand Down Expand Up @@ -386,7 +386,7 @@ def loop_fn(collection):
diagnostics_fn = progbar_opts.pop("diagnostics_fn", None)
progbar_desc = progbar_opts.pop("progbar_desc", lambda x: "")

vals = (init_val, collection, device_put(start_idx), device_put(thinning))
vals = (init_val, collection, jnp.asarray(start_idx), jnp.asarray(thinning))

if upper == 0:
# special case, only compiling
Expand Down

0 comments on commit 4704656

Please sign in to comment.