Skip to content

Commit

Permalink
Use jax.tree_util.tree_map in place of deprecated tree_multimap.
Browse files Browse the repository at this point in the history
The latter is a simple alias of the former, so this change is a no-op.

PiperOrigin-RevId: 460996305
  • Loading branch information
Jake VanderPlas authored and JAX-CFD authors committed Jul 14, 2022
1 parent c059a3f commit 8e24c0f
Show file tree
Hide file tree
Showing 8 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion jax_cfd/base/array_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def split_axis(
def concat_along_axis(pytrees, axis):
"""Concatenates `pytrees` along `axis`."""
concat_leaves_fn = lambda *args: jnp.concatenate(args, axis)
return jax.tree_multimap(concat_leaves_fn, *pytrees)
return jax.tree_map(concat_leaves_fn, *pytrees)


def block_reduce(
Expand Down
2 changes: 1 addition & 1 deletion jax_cfd/base/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@


def sum_fields(*args):
return jax.tree_multimap(lambda *a: sum(a), *args)
return jax.tree_map(lambda *a: sum(a), *args)


def stable_time_step(
Expand Down
2 changes: 1 addition & 1 deletion jax_cfd/base/funcutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def init_context():

def _tree_stack(trees: Sequence[PyTree]) -> PyTree:
if trees:
return tree_util.tree_multimap(lambda *xs: jnp.stack(xs), *trees)
return tree_util.tree_map(lambda *xs: jnp.stack(xs), *trees)
else:
return trees

Expand Down
2 changes: 1 addition & 1 deletion jax_cfd/base/subgrid_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def evm_model(
for j in range(grid.ndim)]
for i in range(grid.ndim)])
viscosity = viscosity_fn(s_ij, v)
tau = jax.tree_multimap(lambda x, y: -2. * x * y, viscosity, s_ij)
tau = jax.tree_map(lambda x, y: -2. * x * y, viscosity, s_ij)
return tuple(-finite_differences.divergence( # pylint: disable=g-complex-comprehension
tuple(grids.GridVariable(t, bc) # use velocity bc to compute diverence
for t in tau[i, :]))
Expand Down
2 changes: 1 addition & 1 deletion jax_cfd/collocated/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@


def sum_fields(*args):
return jax.tree_multimap(lambda *a: sum(a), *args)
return jax.tree_map(lambda *a: sum(a), *args)


def semi_implicit_navier_stokes(
Expand Down
6 changes: 3 additions & 3 deletions jax_cfd/ml/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def learned_corrector(
def step_fn(state):
next_state = base_solver(state)
corrections = corrector(next_state)
return jax.tree_multimap(lambda x, y: x + y, next_state, corrections)
return jax.tree_map(lambda x, y: x + y, next_state, corrections)

return hk.to_module(step_fn)()

Expand All @@ -242,7 +242,7 @@ def learned_corrector_v2(
def step_fn(state):
next_state = base_solver(state)
corrections = corrector(state)
return jax.tree_multimap(lambda x, y: x + dt * y, next_state, corrections)
return jax.tree_map(lambda x, y: x + dt * y, next_state, corrections)

return hk.to_module(step_fn)()

Expand All @@ -262,6 +262,6 @@ def learned_corrector_v3(
def step_fn(state):
next_state = base_solver(state)
corrections = corrector(tuple(state) + tuple(next_state))
return jax.tree_multimap(lambda x, y: x + dt * y, next_state, corrections)
return jax.tree_map(lambda x, y: x + dt * y, next_state, corrections)

return hk.to_module(step_fn)()
2 changes: 1 addition & 1 deletion jax_cfd/ml/time_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def euler_integrator(
"""
def _single_step(state, _):
deriv = derivative_module(state)
next_state = jax.tree_multimap(lambda x, dxdt: x + dt * dxdt, state, deriv)
next_state = jax.tree_map(lambda x, dxdt: x + dt * dxdt, state, deriv)
return next_state, next_state

return hk.scan(_single_step, initial_state, None, num_steps)
4 changes: 2 additions & 2 deletions jax_cfd/ml/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,9 @@ def add_noise_to_input_frame(
# TODO(dkochkov) add `split_like` method to `array_utils.py`.
rngs = jax.tree_unflatten(jax.tree_structure(time_zero_slice), rngs)
noise_fn = lambda key, s: scale * jax.random.truncated_normal(key, -2., 2., s)
noise = jax.tree_multimap(noise_fn, rngs, shapes)
noise = jax.tree_map(noise_fn, rngs, shapes)
add_noise_fn = lambda x, n: x.at[:, 0, ...].add(n)
return jax.tree_multimap(add_noise_fn, batch, noise)
return jax.tree_map(add_noise_fn, batch, noise)


def preprocess(
Expand Down

0 comments on commit 8e24c0f

Please sign in to comment.