Skip to content

Commit

Permalink
Import segment_max from jax.ops
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Apr 12, 2024
1 parent 1bbf624 commit 0cf8c27
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 40 deletions.
12 changes: 6 additions & 6 deletions src/lcm/argmax.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import jax
import jax.numpy as jnp
from jax import Array
from jax.ops import segment_max

# ======================================================================================
# argmax
Expand Down Expand Up @@ -130,17 +130,17 @@ def segment_argmax(
"""
# Compute segment maximum and bring to the same shape as data
# ==================================================================================
segment_max = jax.ops.segment_max(
segment_maximum = segment_max(
data=data,
segment_ids=segment_ids,
num_segments=num_segments,
indices_are_sorted=True,
)
segment_max_expanded = segment_max[segment_ids]
segment_maximum_expanded = segment_maximum[segment_ids]

# Check where the array attains its maximum
# ==================================================================================
max_value_mask = data == segment_max_expanded
max_value_mask = data == segment_maximum_expanded

# Create index array of argmax indices for each segment (has same shape as data)
# ==================================================================================
Expand All @@ -156,11 +156,11 @@ def segment_argmax(
# ----------------------------------------------------------------------------------
# Note: If multiple maxima exist, this approach will select the last index.
# ==================================================================================
segment_argmax = jax.ops.segment_max(
segment_argmax = segment_max(
data=max_value_indices,
segment_ids=segment_ids,
num_segments=num_segments,
indices_are_sorted=True,
)

return segment_argmax, segment_max
return segment_argmax, segment_maximum
41 changes: 9 additions & 32 deletions src/lcm/discrete_problem.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Functions that aggregate the conditional continuation values over discrete choices.
"""Functions that reduce the conditional continuation values over discrete choices.
By conditional continuation value we mean continuation values conditional on a discrete
choice, i.e. the result of solving the continuous choice problem conditional on the
Expand All @@ -25,6 +25,7 @@
import jax.numpy as jnp
import pandas as pd
from jax import Array
from jax.ops import segment_max

from lcm.typing import SegmentInfo

Expand Down Expand Up @@ -54,7 +55,7 @@ def get_solve_discrete_problem(
Returns:
callable: Function that calculates the expected maximum of the conditional
continuation values. The function depends on `cc_values` (jax.Array), the
conditional continuation values, and returns the aggregated values.
conditional continuation values, and returns the reduced values.
"""
if is_last_period:
Expand Down Expand Up @@ -100,7 +101,7 @@ def _solve_discrete_problem_no_shocks(
params: See `get_solve_discrete_problem`.
Returns:
jax.Array: Array with aggregated continuation values. Has less dimensions than
jax.Array: Array with reduced continuation values. Has less dimensions than
cc_values if choice_axes is not None and is shorter in the first dimension
if choice_segments is not None.
Expand All @@ -109,39 +110,15 @@ def _solve_discrete_problem_no_shocks(
if choice_axes is not None:
out = out.max(axis=choice_axes)
if choice_segments is not None:
out = _segment_max_over_first_axis(out, choice_segments)
out = segment_max(
data=out,
indices_are_sorted=True,
**choice_segments,
)

return out


def _segment_max_over_first_axis(
data: Array,
segment_info: SegmentInfo,
) -> Array:
"""Calculate a segment_max over the first axis of data.
Wrapper around ``jax.ops.segment_max``.
Args:
data (jax.Array): Multidimensional jax array.
segment_info (SegmentInfo): Dictionary with the entries "segment_ids" and
"num_segments". segment_ids are a 1d integer array that partitions the
first dimension of `data` into segments over which we need to aggregate.
"num_segments" is the number of segments. The segment_ids are assumed to be
sorted.
Returns:
jax.Array: An array with shape (num_segments,) + data.shape[1:] representing the
segment maximums.
"""
return jax.ops.segment_max(
data=data,
indices_are_sorted=True,
**segment_info,
)


# ======================================================================================
# Discrete problem with extreme value shocks
# --------------------------------------------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions tests/test_discrete_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import jax.numpy as jnp
import pandas as pd
import pytest
from jax.ops import segment_max
from lcm.discrete_problem import (
_calculate_emax_extreme_value_shocks,
_determine_dense_discrete_choice_axes,
_segment_extreme_value_emax_over_first_axis,
_segment_logsumexp,
_segment_max_over_first_axis,
_solve_discrete_problem_no_shocks,
get_solve_discrete_problem,
)
Expand Down Expand Up @@ -266,7 +266,7 @@ def test_segment_max_over_first_axis_illustrative():
"segment_ids": jnp.array([0, 0, 1, 1]),
"num_segments": 2,
}
got = _segment_max_over_first_axis(a, segment_info=segment_info)
got = segment_max(a, indices_are_sorted=True, **segment_info)
expected = jnp.array([1, 3])
aaae(got, expected)

Expand Down

0 comments on commit 0cf8c27

Please sign in to comment.