Skip to content

Commit

Permalink
Improve readability of dispatchers
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Feb 13, 2025
1 parent a5b900a commit 4199a06
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 71 deletions.
72 changes: 25 additions & 47 deletions src/lcm/dispatchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
from jax import Array, vmap

from lcm.functools import allow_args, allow_only_kwargs
from lcm.utils import find_duplicates

F = TypeVar("F", bound=Callable[..., Array])


def spacemap(
func: F,
product_vars: list[str],
combination_vars: list[str] | None = None,
product_vars: tuple[str, ...],
combination_vars: tuple[str, ...],
) -> F:
"""Apply vmap such that func can be evaluated on product and combination variables.
Expand Down Expand Up @@ -48,52 +49,32 @@ def spacemap(
described above but there might be additional dimensions.
"""
# Check inputs and prepare function
# ==================================================================================
duplicates = {v for v in product_vars if product_vars.count(v) > 1}
if duplicates:
raise ValueError(
f"Same argument provided more than once in product variables: {duplicates}",
if duplicates := find_duplicates(product_vars, combination_vars):
msg = (
"Same argument provided more than once in product variables or combination "
f"variables, or is present in both: {duplicates}"
)
raise ValueError(msg)

func_callable_with_args = allow_args(func)

vmapped = _base_productmap(func_callable_with_args, product_vars)

if combination_vars:
overlap = set(product_vars).intersection(combination_vars)
if overlap:
raise ValueError(
"Product and combination variables must be disjoint. Overlap: "
f"{overlap}",
)

duplicates = {v for v in combination_vars if combination_vars.count(v) > 1}
if duplicates:
raise ValueError(
"Same argument provided more than once in combination variables: "
f"{duplicates}",
)

# jax.vmap cannot deal with keyword-only arguments
func = allow_args(func)

# Apply vmap_1d for combination variables and _base_productmap for product variables
# ==================================================================================
if not combination_vars:
vmapped = _base_productmap(func, product_vars)
else:
vmapped = _base_productmap(func, product_vars)
vmapped = vmap_1d(
vmapped, variables=combination_vars, callable_with="only_args"
)

# This raises a mypy error but is perfectly fine to do. See
# https://github.com/python/mypy/issues/12472
vmapped.__signature__ = inspect.signature(func) # type: ignore[attr-defined]
vmapped.__signature__ = inspect.signature(func_callable_with_args) # type: ignore[attr-defined]

return allow_only_kwargs(vmapped)


def vmap_1d(
func: F,
variables: list[str],
variables: tuple[str, ...],
*,
callable_with: Literal["only_args", "only_kwargs"] = "only_kwargs",
) -> F:
Expand All @@ -105,7 +86,7 @@ def vmap_1d(
Args:
func: The function to be dispatched.
variables: List with names of arguments that over which we map.
variables: Tuple with names of arguments that over which we map.
callable_with: Whether to apply the allow_kwargs decorator to the dispatched
function. If "only_args", the returned function can only be called with
positional arguments. If "only_kwargs", the returned function can only be
Expand All @@ -123,8 +104,7 @@ def vmap_1d(
described above but there might be additional dimensions.
"""
duplicates = {v for v in variables if variables.count(v) > 1}
if duplicates:
if duplicates := find_duplicates(variables):
raise ValueError(
f"Same argument provided more than once in variables: {duplicates}",
)
Expand Down Expand Up @@ -161,7 +141,7 @@ def vmap_1d(
return out


def productmap(func: F, variables: list[str]) -> F:
def productmap(func: F, variables: tuple[str, ...]) -> F:
"""Apply vmap such that func is evaluated on the Cartesian product of variables.
This is achieved by an iterative application of vmap.
Expand All @@ -171,7 +151,7 @@ def productmap(func: F, variables: list[str]) -> F:
Args:
func: The function to be dispatched.
variables: List with names of arguments that over which the Cartesian product
variables: Tuple with names of arguments that over which the Cartesian product
should be formed.
Returns:
Expand All @@ -185,33 +165,31 @@ def productmap(func: F, variables: list[str]) -> F:
described above but there might be additional dimensions.
"""
func = allow_args(func) # jax.vmap cannot deal with keyword-only arguments

duplicates = {v for v in variables if variables.count(v) > 1}
if duplicates:
if duplicates := find_duplicates(variables):
raise ValueError(
f"Same argument provided more than once in variables: {duplicates}",
)

signature = inspect.signature(func)
vmapped = _base_productmap(func, variables)
func_callable_with_args = allow_args(func)

vmapped = _base_productmap(func_callable_with_args, variables)

# This raises a mypy error but is perfectly fine to do. See
# https://github.com/python/mypy/issues/12472
vmapped.__signature__ = signature # type: ignore[attr-defined]
vmapped.__signature__ = inspect.signature(func_callable_with_args) # type: ignore[attr-defined]

return allow_only_kwargs(vmapped)


def _base_productmap(func: F, product_axes: list[str]) -> F:
def _base_productmap(func: F, product_axes: tuple[str, ...]) -> F:
"""Map func over the Cartesian product of product_axes.
Like vmap, this function does not preserve the function signature and does not allow
the function to be called with keyword arguments.
Args:
func: The function to be dispatched. Cannot have keyword-only arguments.
product_axes: List with names of arguments over which we apply vmap.
product_axes: Tuple with names of arguments over which we apply vmap.
Returns:
A callable with the same arguments as func. See `product_map` for details.
Expand Down
4 changes: 2 additions & 2 deletions src/lcm/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def create_compute_conditional_continuation_value(
if continuous_choice_variables:
utility_and_feasibility = productmap(
func=utility_and_feasibility,
variables=continuous_choice_variables,
variables=tuple(continuous_choice_variables),
)

@functools.wraps(utility_and_feasibility)
Expand Down Expand Up @@ -236,7 +236,7 @@ def create_compute_conditional_continuation_policy(
if continuous_choice_variables:
utility_and_feasibility = productmap(
func=utility_and_feasibility,
variables=continuous_choice_variables,
variables=tuple(continuous_choice_variables),
)

@functools.wraps(utility_and_feasibility)
Expand Down
5 changes: 3 additions & 2 deletions src/lcm/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from lcm import grid_helpers
from lcm.exceptions import GridInitializationError, format_messages
from lcm.typing import Scalar
from lcm.utils import find_duplicates


class Grid(ABC):
Expand Down Expand Up @@ -188,12 +189,12 @@ def _validate_discrete_grid(category_class: type) -> None:

values = list(names_and_values.values())

duplicated_values = [v for v in values if values.count(v) > 1]
duplicated_values = find_duplicates(values)
if duplicated_values:
error_messages.append(
"Field values of the category_class passed to DiscreteGrid must be unique. "
"The following values are duplicated: "
f"{set(duplicated_values)}"
f"{duplicated_values}"
)

if values != list(range(len(values))):
Expand Down
2 changes: 1 addition & 1 deletion src/lcm/model_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def get_multiply_weights(stochastic_variables):
callable
"""
arg_names = [f"weight_next_{var}" for var in stochastic_variables]
arg_names = tuple(f"weight_next_{var}" for var in stochastic_variables)

@with_signature(args=arg_names)
def _outer(*args, **kwargs):
Expand Down
6 changes: 3 additions & 3 deletions src/lcm/simulation/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ def solve_continuous_problem(
"""
_gridmapped = spacemap(
func=compute_ccv,
product_vars=list(data_scs.choices),
combination_vars=list(data_scs.states),
product_vars=tuple(data_scs.choices),
combination_vars=tuple(data_scs.states),
)
gridmapped = jax.jit(_gridmapped)

Expand Down Expand Up @@ -372,7 +372,7 @@ def _generate_simulation_keys(key, ids):
# ======================================================================================


@partial(vmap_1d, variables=["ccv_policy", "discrete_argmax"])
@partial(vmap_1d, variables=("ccv_policy", "discrete_argmax"))
def filter_ccv_policy(
ccv_policy,
discrete_argmax,
Expand Down
3 changes: 2 additions & 1 deletion src/lcm/solution/solve_brute.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def solve_continuous_problem(
"""
_gridmapped = spacemap(
func=compute_ccv,
product_vars=list(state_choice_space.ordered_var_names),
product_vars=state_choice_space.ordered_var_names,
combination_vars=(),
)
gridmapped = jax.jit(_gridmapped)

Expand Down
12 changes: 12 additions & 0 deletions src/lcm/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from collections import Counter
from collections.abc import Iterable
from itertools import chain
from typing import TypeVar

T = TypeVar("T")


def find_duplicates(*containers: Iterable[T]) -> set[T]:
combined = chain.from_iterable(containers)
counts = Counter(combined)
return {v for v, count in counts.items() if count > 1}
24 changes: 12 additions & 12 deletions tests/test_dispatchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_productmap_with_all_arguments_mapped(func, args, grids, expected, reque


def test_productmap_with_positional_args(setup_productmap_f):
decorated = productmap(f, ["a", "b", "c"])
decorated = productmap(f, ("a", "b", "c"))
match = (
"This function has been decorated so that it allows only kwargs, but was "
"called with positional arguments."
Expand All @@ -113,10 +113,10 @@ def test_productmap_with_positional_args(setup_productmap_f):


def test_productmap_different_func_order(setup_productmap_f):
decorated_f = productmap(f, ["a", "b", "c"])
decorated_f = productmap(f, ("a", "b", "c"))
expected = decorated_f(**setup_productmap_f)

decorated_f2 = productmap(f2, ["a", "b", "c"])
decorated_f2 = productmap(f2, ("a", "b", "c"))
calculated_f2 = decorated_f2(**setup_productmap_f)

aaae(calculated_f2, expected)
Expand All @@ -125,7 +125,7 @@ def test_productmap_different_func_order(setup_productmap_f):
def test_productmap_change_arg_order(setup_productmap_f, expected_productmap_f):
expected = jnp.transpose(expected_productmap_f, (1, 0, 2))

decorated = productmap(f, ["b", "a", "c"])
decorated = productmap(f, ("b", "a", "c"))
calculated = decorated(**setup_productmap_f)

aaae(calculated, expected)
Expand All @@ -142,7 +142,7 @@ def test_productmap_with_all_arguments_mapped_some_len_one():

expected = allow_args(f)(*helper).reshape(1, 1, 5)

decorated = productmap(f, ["a", "b", "c"])
decorated = productmap(f, ("a", "b", "c"))
calculated = decorated(**grids)
aaae(calculated, expected)

Expand All @@ -154,7 +154,7 @@ def test_productmap_with_all_arguments_mapped_some_scalar():
"c": jnp.linspace(1, 5, 5),
}

decorated = productmap(f, ["a", "b", "c"])
decorated = productmap(f, ("a", "b", "c"))
with pytest.raises(ValueError, match="vmap was requested to map its argument"):
decorated(**grids)

Expand All @@ -170,15 +170,15 @@ def test_productmap_with_some_arguments_mapped():

expected = allow_args(f)(*helper).reshape(10, 5)

decorated = productmap(f, ["a", "c"])
decorated = productmap(f, ("a", "c"))
calculated = decorated(**grids)
aaae(calculated, expected)


def test_productmap_with_some_argument_mapped_twice():
error_msg = "Same argument provided more than once."
with pytest.raises(ValueError, match=error_msg):
productmap(f, ["a", "a", "c"])
productmap(f, ("a", "a", "c"))


# ======================================================================================
Expand Down Expand Up @@ -233,8 +233,8 @@ def test_spacemap_all_arguments_mapped(

decorated = spacemap(
g,
list(product_vars),
list(combination_vars),
tuple(product_vars),
tuple(combination_vars),
)
calculated = decorated(**product_vars, **combination_vars)

Expand All @@ -245,12 +245,12 @@ def test_spacemap_all_arguments_mapped(
("error_msg", "product_vars", "combination_vars"),
[
(
"Product and combination variables must be disjoint. Overlap: {'a'}",
"Same argument provided more than once in product variables or combination",
["a", "b"],
["a", "c", "d"],
),
(
"Same argument provided more than once in product variables: {'a'}",
"Same argument provided more than once in product variables or combination",
["a", "a", "b"],
["c", "d"],
),
Expand Down
6 changes: 3 additions & 3 deletions tests/test_function_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_function_evaluator():
# create a value function array
discrete_part = jnp.arange(4).repeat(6 * 7).reshape((2, 2, 6, 7)) * 100

cont_func = productmap(lambda x, y: x + y, ["x", "y"])
cont_func = productmap(lambda x, y: x + y, ("x", "y"))
cont_part = cont_func(x=jnp.linspace(100, 1100, 6), y=jnp.linspace(-3, 3, 7))

vf_arr = discrete_part + cont_part
Expand Down Expand Up @@ -181,7 +181,7 @@ def test_get_interpolator():
def _utility(wealth, working):
return 2 * wealth - working

prod_utility = productmap(_utility, variables=["wealth", "working"])
prod_utility = productmap(_utility, variables=("wealth", "working"))

values = prod_utility(
wealth=jnp.arange(4, dtype=float),
Expand Down Expand Up @@ -260,7 +260,7 @@ def test_get_interpolator_illustrative():
def f(a, b):
return a - b

prod_f = productmap(f, variables=["a", "b"])
prod_f = productmap(f, variables=("a", "b"))

values = prod_f(a=jnp.arange(2, dtype=float), b=jnp.arange(3, dtype=float))

Expand Down
17 changes: 17 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from lcm.utils import find_duplicates


def test_find_duplicates_singe_container_no_duplicates():
assert find_duplicates([1, 2, 3, 4, 5]) == set()


def test_find_duplicates_single_container_with_duplicates():
assert find_duplicates([1, 2, 3, 4, 5, 5]) == {5}


def test_find_duplicates_multiple_containers_no_duplicates():
assert find_duplicates([1, 2, 3, 4, 5], [6, 7, 8, 9, 10]) == set()


def test_find_duplicates_multiple_containers_with_duplicates():
assert find_duplicates([1, 2, 3, 4, 5, 5], [6, 7, 8, 9, 10, 5]) == {5}

0 comments on commit 4199a06

Please sign in to comment.