From 39a271d2b8543b400ee52460ce8e754d4f11a203 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 12:24:04 +0100 Subject: [PATCH 01/27] Refactor example models --- .pre-commit-config.yaml | 1 + MANIFEST.in | 5 +- examples/README.md | 22 ++++ examples/long_running.py | 121 ++++++++++++++++++ pyproject.toml | 5 +- .../example_models/testing_example_models.py | 95 -------------- src/lcm/get_model.py | 2 +- src/lcm/test.py | 8 -- tests/test_entry_point.py | 21 +-- tests/test_long.py | 13 -- tests/test_model_functions.py | 5 +- .../test_models}/__init__.py | 0 .../test_models/phelps_deaton.py | 105 +++++++-------- .../test_models/stochastic.py | 39 +++++- tests/test_next_state.py | 3 +- tests/test_process_model.py | 13 +- tests/test_simulate.py | 11 +- tests/test_state_space.py | 3 +- tests/test_stochastic.py | 10 +- 19 files changed, 271 insertions(+), 211 deletions(-) create mode 100644 examples/README.md create mode 100644 examples/long_running.py delete mode 100644 src/lcm/example_models/testing_example_models.py delete mode 100644 src/lcm/test.py delete mode 100644 tests/test_long.py rename {src/lcm/example_models => tests/test_models}/__init__.py (100%) rename src/lcm/example_models/basic_example_models.py => tests/test_models/phelps_deaton.py (53%) rename src/lcm/example_models/stochastic_example_models.py => tests/test_models/stochastic.py (67%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6d1270fa..28ef322c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,6 +33,7 @@ repos: - id: name-tests-test args: - --pytest-test-first + exclude: ^tests/test_models/ - id: no-commit-to-branch args: - --branch diff --git a/MANIFEST.in b/MANIFEST.in index ce3bc0f2..1a6451bb 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -21,10 +21,11 @@ exclude *.yml exclude *.pickle exclude pytask.ini -prune src/lcm/sandbox +prune .envs +prune examples prune docs +prune src/lcm/sandbox prune tests -prune .envs global-exclude __pycache__ global-exclude *.py[co] diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..b7e51f44 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,22 @@ +# Example model specifications + +## Choosing an example + +| Example name | Description | Runtime | +| -------------- | -------------------------------------------------- | ------------- | +| `long_running` | Consumptions-savings model with health and leisure | a few minutes | + +## Running an example + +Say you want to run the [`long_running`](./long_running.py) example locally. In a Python +shell, execute: + +```python +from lcm.entry_point import get_lcm_function + +from long_running import MODEL_CONFIG, PARAMS + + +solve_model, _ = get_lcm_function(model=MODEL_CONFIG) +vf_arr = solve_model(PARAMS) +``` diff --git a/examples/long_running.py b/examples/long_running.py new file mode 100644 index 00000000..4502e763 --- /dev/null +++ b/examples/long_running.py @@ -0,0 +1,121 @@ +"""Example specification for a consumption-savings model with health and leisure.""" + +import jax.numpy as jnp + +# ====================================================================================== +# Numerical parameters and constants +# ====================================================================================== +N_GRID_POINTS = { + "states": 100, + "choices": 200, +} + +RETIREMENT_AGE = 65 + +# ====================================================================================== +# Model functions +# ====================================================================================== + + +# -------------------------------------------------------------------------------------- +# Utility function +# -------------------------------------------------------------------------------------- +def utility(consumption, working, health, sport, delta): + return jnp.log(consumption) - (delta - health) * working - sport + + +# -------------------------------------------------------------------------------------- +# Auxiliary variables +# -------------------------------------------------------------------------------------- +def working(leisure): + return 1 - leisure + + +def wage(age): + return 1 + 0.1 * age + + +def age(_period): + return _period + 18 + + +# -------------------------------------------------------------------------------------- +# State transitions +# -------------------------------------------------------------------------------------- +def next_wealth(wealth, consumption, working, wage, interest_rate): + return (1 + interest_rate) * (wealth - consumption) + wage * working + + +def next_health(health, sport, working): + return health * (1 + sport - working / 2) + + +def next_wealth_with_shock( + wealth, + consumption, + working, + wage, + wage_shock, + interest_rate, +): + return interest_rate * (wealth - consumption) + wage * wage_shock * working + + +# -------------------------------------------------------------------------------------- +# Constraints +# -------------------------------------------------------------------------------------- +def consumption_constraint(consumption, wealth): + return consumption <= wealth + + +# ====================================================================================== +# Model specification and parameters +# ====================================================================================== + +MODEL_CONFIG = { + "functions": { + "utility": utility, + "next_wealth": next_wealth, + "consumption_constraint": consumption_constraint, + "working": working, + "wage": wage, + "age": age, + "next_health": next_health, + }, + "choices": { + "leisure": {"options": [0, 1]}, + "consumption": { + "grid_type": "linspace", + "start": 1, + "stop": 100, + "n_points": N_GRID_POINTS["choices"], + }, + "sport": { + "grid_type": "linspace", + "start": 0, + "stop": 1, + "n_points": N_GRID_POINTS["choices"], + }, + }, + "states": { + "wealth": { + "grid_type": "linspace", + "start": 1, + "stop": 100, + "n_points": N_GRID_POINTS["states"], + }, + "health": { + "grid_type": "linspace", + "start": 0, + "stop": 1, + "n_points": N_GRID_POINTS["states"], + }, + }, + "n_periods": RETIREMENT_AGE - 18, +} + +PARAMS = { + "beta": 0.95, + "utility": {"delta": 0.05}, + "next_wealth": {"interest_rate": 0.05}, +} diff --git a/pyproject.toml b/pyproject.toml index bfba52a5..7bcb07db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,8 @@ write_to = "src/lcm/_version.py" [tool.ruff] target-version = "py311" fix = true + +[tool.ruff.lint] select = ["ALL"] extend-ignore = [ # missing type annotation @@ -57,6 +59,7 @@ extend-ignore = [ [tool.ruff.lint.per-file-ignores] "docs/source/conf.py" = ["E501", "ERA001", "DTZ005"] "tests/test_*.py" = ["PLR2004"] +"examples/*" = ["INP001"] [tool.ruff.lint.pydocstyle] convention = "google" @@ -75,7 +78,7 @@ markers = [ "slow: Tests that take a long time to run and are skipped in continuous integration.", "illustrative: Tests are designed for illustrative purposes", ] -norecursedirs = ["docs", ".envs"] +norecursedirs = ["docs", ".envs", "tests/test_models"] [tool.yamlfix] diff --git a/src/lcm/example_models/testing_example_models.py b/src/lcm/example_models/testing_example_models.py deleted file mode 100644 index 097207bf..00000000 --- a/src/lcm/example_models/testing_example_models.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Define example model specifications.""" - -import jax.numpy as jnp - -RETIREMENT_AGE = 65 -N_CHOICE_GRID_POINTS = 200 -N_STATE_GRID_POINTS = 100 - - -def phelps_deaton_utility(consumption, working, health, sport, delta): - return jnp.log(consumption) - (delta - health) * working - sport - - -def working(retirement): - return 1 - retirement - - -def next_wealth_with_shock( - wealth, - consumption, - working, - wage, - wage_shock, - interest_rate, -): - return interest_rate * (wealth - consumption) + wage * wage_shock * working - - -def next_wealth(wealth, consumption, working, wage, interest_rate): - return (1 + interest_rate) * (wealth - consumption) + wage * working - - -def next_health(health, sport, working): - return health * (1 + sport - working / 2) - - -def consumption_constraint(consumption, wealth): - return consumption <= wealth - - -def wage(age): - return 1 + 0.1 * age - - -def age(_period): - return _period + 18 - - -PHELPS_DEATON = { - "functions": { - "utility": phelps_deaton_utility, - "next_wealth": next_wealth, - "consumption_constraint": consumption_constraint, - "working": working, - "wage": wage, - "age": age, - "next_health": next_health, - }, - "choices": { - "retirement": {"options": [0, 1]}, - "consumption": { - "grid_type": "linspace", - "start": 1, - "stop": 100, - "n_points": N_CHOICE_GRID_POINTS, - }, - "sport": { - "grid_type": "linspace", - "start": 0, - "stop": 1, - "n_points": N_CHOICE_GRID_POINTS, - }, - }, - "states": { - "wealth": { - "grid_type": "linspace", - "start": 1, - "stop": 100, - "n_points": N_STATE_GRID_POINTS, - }, - "health": { - "grid_type": "linspace", - "start": 0, - "stop": 1, - "n_points": N_STATE_GRID_POINTS, - }, - }, - "n_periods": RETIREMENT_AGE - 18, -} - -PARAMS = { - "beta": 0.95, - "utility": {"delta": 0.05}, - "next_wealth": {"interest_rate": 0.05}, -} diff --git a/src/lcm/get_model.py b/src/lcm/get_model.py index 4e0f37e4..774115f3 100644 --- a/src/lcm/get_model.py +++ b/src/lcm/get_model.py @@ -4,7 +4,7 @@ from pybaum import tree_update -from lcm.example_models.basic_example_models import ( +from tests.test_models.phelps_deaton import ( PHELPS_DEATON, PHELPS_DEATON_WITH_FILTERS, ) diff --git a/src/lcm/test.py b/src/lcm/test.py deleted file mode 100644 index dceb1619..00000000 --- a/src/lcm/test.py +++ /dev/null @@ -1,8 +0,0 @@ -from lcm.logger import get_logger - -if __name__ == "__main__": - logger = get_logger(debug_mode=False) - - logger.info("This is an info message.") - logger.debug("This is a debug message.") - logger.warning("This is a warning message.") diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 64fa6e49..05a9b833 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -5,17 +5,18 @@ create_compute_conditional_continuation_value, get_lcm_function, ) -from lcm.example_models.basic_example_models import ( - PHELPS_DEATON, - PHELPS_DEATON_FULLY_DISCRETE, - PHELPS_DEATON_WITH_FILTERS, - phelps_deaton_utility, -) from lcm.model_functions import get_utility_and_feasibility_function from lcm.process_model import process_model from lcm.state_space import create_state_choice_space from pybaum import tree_equal, tree_map +from tests.test_models.phelps_deaton import ( + PHELPS_DEATON, + PHELPS_DEATON_FULLY_DISCRETE, + PHELPS_DEATON_WITH_FILTERS, + utility, +) + MODELS = { "simple": PHELPS_DEATON, "with_filters": PHELPS_DEATON_WITH_FILTERS, @@ -175,7 +176,7 @@ def test_create_compute_conditional_continuation_value(): params=params, vf_arr=None, ) - assert val == phelps_deaton_utility(consumption=30.0, working=0, delta=1.0) + assert val == utility(consumption=30.0, working=0, delta=1.0) def test_create_compute_conditional_continuation_value_with_discrete_model(): @@ -218,7 +219,7 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): params=params, vf_arr=None, ) - assert val == phelps_deaton_utility(consumption=2, working=0, delta=1.0) + assert val == utility(consumption=2, working=0, delta=1.0) # ====================================================================================== @@ -267,7 +268,7 @@ def test_create_compute_conditional_continuation_policy(): vf_arr=None, ) assert policy == 2 - assert val == phelps_deaton_utility(consumption=30.0, working=0, delta=1.0) + assert val == utility(consumption=30.0, working=0, delta=1.0) def test_create_compute_conditional_continuation_policy_with_discrete_model(): @@ -311,4 +312,4 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): vf_arr=None, ) assert policy == 1 - assert val == phelps_deaton_utility(consumption=2, working=0, delta=1.0) + assert val == utility(consumption=2, working=0, delta=1.0) diff --git a/tests/test_long.py b/tests/test_long.py deleted file mode 100644 index b7dacabc..00000000 --- a/tests/test_long.py +++ /dev/null @@ -1,13 +0,0 @@ -import pytest -from lcm.entry_point import get_lcm_function -from lcm.example_models.testing_example_models import PARAMS, PHELPS_DEATON - -SKIP_REASON = """The test is designed to run approximately 1 minute on a standard -laptop, such that we can differentiate the performance of running LCM on a GPU versus -on the CPU.""" - - -@pytest.mark.skip(reason=SKIP_REASON) -def test_long(): - solve_model, template = get_lcm_function(PHELPS_DEATON, targets="solve") - solve_model(PARAMS) diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index 197ee350..f2a0a7c4 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -1,6 +1,5 @@ import jax.numpy as jnp import pandas as pd -from lcm.example_models.basic_example_models import PHELPS_DEATON, phelps_deaton_utility from lcm.interfaces import Model from lcm.model_functions import ( get_combined_constraint, @@ -11,6 +10,8 @@ from lcm.state_space import create_state_choice_space from numpy.testing import assert_array_equal +from tests.test_models.phelps_deaton import PHELPS_DEATON, utility + def test_get_combined_constraint(): def f(): @@ -82,7 +83,7 @@ def test_get_utility_and_feasibility_function(): assert_array_equal( u, - phelps_deaton_utility( + utility( consumption=consumption, working=1 - retirement, delta=1.0, diff --git a/src/lcm/example_models/__init__.py b/tests/test_models/__init__.py similarity index 100% rename from src/lcm/example_models/__init__.py rename to tests/test_models/__init__.py diff --git a/src/lcm/example_models/basic_example_models.py b/tests/test_models/phelps_deaton.py similarity index 53% rename from src/lcm/example_models/basic_example_models.py rename to tests/test_models/phelps_deaton.py index 8ae9a442..06184a10 100644 --- a/src/lcm/example_models/basic_example_models.py +++ b/tests/test_models/phelps_deaton.py @@ -1,65 +1,71 @@ -"""Define example model specifications.""" +"""Example specifications of the Phelps-Deaton model.""" import jax.numpy as jnp -RETIREMENT_AGE = 65 -N_CHOICE_GRID_POINTS = 500 -N_STATE_GRID_POINTS = 100 +# ====================================================================================== +# Numerical parameters and constants +# ====================================================================================== +N_GRID_POINTS = { + "states": 100, + "choices": 200, +} -def phelps_deaton_utility_with_shock( - consumption, - working, - delta, - additive_utility_shock, -): - return jnp.log(consumption) + additive_utility_shock - delta * working +RETIREMENT_AGE = 65 +# ====================================================================================== +# Model functions +# ====================================================================================== -def phelps_deaton_utility(consumption, working, delta): + +# -------------------------------------------------------------------------------------- +# Utility functions +# -------------------------------------------------------------------------------------- +def utility(consumption, working, delta): return jnp.log(consumption) - delta * working -def phelps_deaton_utility_with_filter( +def utility_with_filter( consumption, working, delta, lagged_retirement, # noqa: ARG001 ): - return jnp.log(consumption) - delta * working + return utility(consumption=consumption, working=working, delta=delta) +# -------------------------------------------------------------------------------------- +# Auxiliary variables +# -------------------------------------------------------------------------------------- def working(retirement): return 1 - retirement -def next_wealth_with_shock( - wealth, - consumption, - working, - wage, - wage_shock, - interest_rate, -): - return interest_rate * (wealth - consumption) + wage * wage_shock * working +def wage(age): + return 1 + 0.1 * age + + +def age(_period): + return _period + 18 +# -------------------------------------------------------------------------------------- +# State transitions +# -------------------------------------------------------------------------------------- def next_wealth(wealth, consumption, working, wage, interest_rate): return (1 + interest_rate) * (wealth - consumption) + wage * working +# -------------------------------------------------------------------------------------- +# Constraints +# -------------------------------------------------------------------------------------- def consumption_constraint(consumption, wealth): return consumption <= wealth -def wage(age): - return 1 + 0.1 * age - - -def age(_period): - return _period + 18 - - +# -------------------------------------------------------------------------------------- +# Filters +# -------------------------------------------------------------------------------------- def mandatory_retirement_filter(retirement, age): return jnp.logical_or(retirement == 1, age < RETIREMENT_AGE) @@ -68,9 +74,13 @@ def absorbing_retirement_filter(retirement, lagged_retirement): return jnp.logical_or(retirement == 1, lagged_retirement == 0) +# ====================================================================================== +# Model specification and parameters +# ====================================================================================== + PHELPS_DEATON = { "functions": { - "utility": phelps_deaton_utility, + "utility": utility, "next_wealth": next_wealth, "consumption_constraint": consumption_constraint, "working": working, @@ -83,7 +93,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "grid_type": "linspace", "start": 0, "stop": 100, - "n_points": N_CHOICE_GRID_POINTS, + "n_points": N_GRID_POINTS["choices"], }, }, "states": { @@ -91,7 +101,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "grid_type": "linspace", "start": 0, "stop": 100, - "n_points": N_STATE_GRID_POINTS, + "n_points": N_GRID_POINTS["states"], }, }, "n_periods": 3, @@ -100,7 +110,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): PHELPS_DEATON_FULLY_DISCRETE = { "functions": { - "utility": phelps_deaton_utility, + "utility": utility, "next_wealth": next_wealth, "consumption_constraint": consumption_constraint, "working": working, @@ -114,33 +124,16 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "grid_type": "linspace", "start": 0, "stop": 100, - "n_points": N_STATE_GRID_POINTS, + "n_points": N_GRID_POINTS["states"], }, }, "n_periods": 3, } -PHELPS_DEATON_WITH_SHOCKS = { - **PHELPS_DEATON, - "functions": { - "utility": phelps_deaton_utility_with_shock, - "next_wealth": next_wealth_with_shock, - "consumption_constraint": consumption_constraint, - "working": working, - }, - "shocks": { - "wage_shock": "lognormal", - # special name to signal that this shock can be set to zero to calculate - # expected utility - "additive_utility_shock": "extreme_value", - }, -} - - PHELPS_DEATON_WITH_FILTERS = { "functions": { - "utility": phelps_deaton_utility_with_filter, + "utility": utility_with_filter, "next_wealth": next_wealth, "consumption_constraint": consumption_constraint, "working": working, @@ -153,7 +146,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "grid_type": "linspace", "start": 1, "stop": 100, - "n_points": N_CHOICE_GRID_POINTS, + "n_points": N_GRID_POINTS["choices"], }, }, "states": { @@ -161,7 +154,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "grid_type": "linspace", "start": 0, "stop": 100, - "n_points": N_STATE_GRID_POINTS, + "n_points": N_GRID_POINTS["states"], }, "lagged_retirement": {"options": [0, 1]}, }, diff --git a/src/lcm/example_models/stochastic_example_models.py b/tests/test_models/stochastic.py similarity index 67% rename from src/lcm/example_models/stochastic_example_models.py rename to tests/test_models/stochastic.py index fb139424..0e8cb3ad 100644 --- a/src/lcm/example_models/stochastic_example_models.py +++ b/tests/test_models/stochastic.py @@ -1,21 +1,39 @@ -"""Define example model specifications.""" +"""Example specifications of a simple Phelps-Deaton style stochastic model.""" import jax.numpy as jnp - import lcm -N_CHOICE_GRID_POINTS = 500 -N_STATE_GRID_POINTS = 100 +# ====================================================================================== +# Numerical parameters and constants +# ====================================================================================== + +N_GRID_POINTS = { + "states": 100, + "choices": 200, +} +# ====================================================================================== +# Model functions +# ====================================================================================== + +# -------------------------------------------------------------------------------------- +# Utility function +# -------------------------------------------------------------------------------------- def utility(consumption, working, health, partner, delta, gamma): # noqa: ARG001 return jnp.log(consumption) + (gamma * health - delta) * working +# -------------------------------------------------------------------------------------- +# Deterministic state transitions +# -------------------------------------------------------------------------------------- def next_wealth(wealth, consumption, working, wage, interest_rate): return (1 + interest_rate) * (wealth - consumption) + wage * working +# -------------------------------------------------------------------------------------- +# Stochastic state transitions +# -------------------------------------------------------------------------------------- @lcm.mark.stochastic def next_health(health, partner): # noqa: ARG001 pass @@ -26,11 +44,18 @@ def next_partner(_period, working, partner): # noqa: ARG001 pass +# -------------------------------------------------------------------------------------- +# Constraints +# -------------------------------------------------------------------------------------- def consumption_constraint(consumption, wealth): return consumption <= wealth -MODEL = { +# ====================================================================================== +# Model specification and parameters +# ====================================================================================== + +MODEL_CONFIG = { "functions": { "utility": utility, "next_wealth": next_wealth, @@ -44,7 +69,7 @@ def consumption_constraint(consumption, wealth): "grid_type": "linspace", "start": 1, "stop": 100, - "n_points": N_CHOICE_GRID_POINTS, + "n_points": N_GRID_POINTS["choices"], }, }, "states": { @@ -54,7 +79,7 @@ def consumption_constraint(consumption, wealth): "grid_type": "linspace", "start": 1, "stop": 100, - "n_points": N_STATE_GRID_POINTS, + "n_points": N_GRID_POINTS["states"], }, }, "n_periods": 3, diff --git a/tests/test_next_state.py b/tests/test_next_state.py index 6ff5c5fa..572ad993 100644 --- a/tests/test_next_state.py +++ b/tests/test_next_state.py @@ -1,11 +1,12 @@ import jax.numpy as jnp import pandas as pd -from lcm.example_models.basic_example_models import PHELPS_DEATON from lcm.interfaces import Model from lcm.next_state import _get_stochastic_next_func, get_next_state_function from lcm.process_model import process_model from pybaum import tree_equal +from tests.test_models.phelps_deaton import PHELPS_DEATON + # ====================================================================================== # Solve target # ====================================================================================== diff --git a/tests/test_process_model.py b/tests/test_process_model.py index 846f3818..e5ecb684 100644 --- a/tests/test_process_model.py +++ b/tests/test_process_model.py @@ -4,12 +4,6 @@ import numpy as np import pandas as pd import pytest -from lcm.example_models.basic_example_models import ( - N_CHOICE_GRID_POINTS, - N_STATE_GRID_POINTS, - PHELPS_DEATON, - PHELPS_DEATON_WITH_FILTERS, -) from lcm.interfaces import GridSpec from lcm.mark import StochasticInfo from lcm.process_model import ( @@ -23,6 +17,13 @@ from numpy.testing import assert_array_equal from pandas.testing import assert_frame_equal +from tests.test_models.phelps_deaton import ( + N_CHOICE_GRID_POINTS, + N_STATE_GRID_POINTS, + PHELPS_DEATON, + PHELPS_DEATON_WITH_FILTERS, +) + @pytest.fixture() def user_model(): diff --git a/tests/test_simulate.py b/tests/test_simulate.py index bbfb772f..e3c0287f 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -6,11 +6,6 @@ create_compute_conditional_continuation_policy, get_lcm_function, ) -from lcm.example_models.basic_example_models import ( - N_CHOICE_GRID_POINTS, - PHELPS_DEATON, - PHELPS_DEATON_WITH_FILTERS, -) from lcm.logging import get_logger from lcm.model_functions import get_utility_and_feasibility_function from lcm.next_state import _get_next_state_function_simulation @@ -32,6 +27,12 @@ from numpy.testing import assert_array_almost_equal, assert_array_equal from pybaum import tree_equal +from tests.test_models.phelps_deaton import ( + N_CHOICE_GRID_POINTS, + PHELPS_DEATON, + PHELPS_DEATON_WITH_FILTERS, +) + # ====================================================================================== # Test simulate using raw inputs # ====================================================================================== diff --git a/tests/test_state_space.py b/tests/test_state_space.py index 8f9bc33c..a5a5b8e0 100644 --- a/tests/test_state_space.py +++ b/tests/test_state_space.py @@ -2,7 +2,6 @@ import numpy as np import pandas as pd import pytest -from lcm.example_models.basic_example_models import PHELPS_DEATON_WITH_FILTERS from lcm.interfaces import Model from lcm.process_model import process_model from lcm.state_space import ( @@ -14,6 +13,8 @@ ) from numpy.testing import assert_array_almost_equal as aaae +from tests.test_models.phelps_deaton import PHELPS_DEATON_WITH_FILTERS + def test_create_state_choice_space(): _model = process_model(PHELPS_DEATON_WITH_FILTERS) diff --git a/tests/test_stochastic.py b/tests/test_stochastic.py index 2a0223cc..636247a2 100644 --- a/tests/test_stochastic.py +++ b/tests/test_stochastic.py @@ -5,7 +5,8 @@ from lcm.entry_point import ( get_lcm_function, ) -from lcm.example_models.stochastic_example_models import MODEL, PARAMS + +from tests.test_models.stochastic import MODEL_CONFIG, PARAMS # ====================================================================================== # Simulate @@ -13,7 +14,10 @@ def test_get_lcm_function_with_simulate_target(): - simulate_model, _ = get_lcm_function(model=MODEL, targets="solve_and_simulate") + simulate_model, _ = get_lcm_function( + model=MODEL_CONFIG, + targets="solve_and_simulate", + ) res = simulate_model( PARAMS, @@ -47,7 +51,7 @@ def test_get_lcm_function_with_simulate_target(): def test_get_lcm_function_with_solve_target(): - solve_model, _ = get_lcm_function(model=MODEL) + solve_model, _ = get_lcm_function(model=MODEL_CONFIG) solve_model(PARAMS) From 7ee660c16ba9b08dfae22746d94c0ea16d2c9644 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 12:26:31 +0100 Subject: [PATCH 02/27] Do not make unnecessary ignores --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7bcb07db..1cedd9f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ markers = [ "slow: Tests that take a long time to run and are skipped in continuous integration.", "illustrative: Tests are designed for illustrative purposes", ] -norecursedirs = ["docs", ".envs", "tests/test_models"] +norecursedirs = ["docs", ".envs"] [tool.yamlfix] From b6c9eb1c697ff3baf93803fb0df24564e920da1d Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 12:30:22 +0100 Subject: [PATCH 03/27] Fix tests --- tests/test_process_model.py | 11 +++++------ tests/test_simulate.py | 4 ++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/test_process_model.py b/tests/test_process_model.py index e5ecb684..2ce0c291 100644 --- a/tests/test_process_model.py +++ b/tests/test_process_model.py @@ -18,8 +18,7 @@ from pandas.testing import assert_frame_equal from tests.test_models.phelps_deaton import ( - N_CHOICE_GRID_POINTS, - N_STATE_GRID_POINTS, + N_GRID_POINTS, PHELPS_DEATON, PHELPS_DEATON_WITH_FILTERS, ) @@ -123,14 +122,14 @@ def test_process_phelps_deaton_with_filters(): # Gridspecs wealth_specs = GridSpec( kind="linspace", - specs={"start": 0, "stop": 100, "n_points": N_STATE_GRID_POINTS}, + specs={"start": 0, "stop": 100, "n_points": N_GRID_POINTS["states"]}, ) assert model.gridspecs["wealth"] == wealth_specs consumption_specs = GridSpec( kind="linspace", - specs={"start": 1, "stop": 100, "n_points": N_CHOICE_GRID_POINTS}, + specs={"start": 1, "stop": 100, "n_points": N_GRID_POINTS["choices"]}, ) assert model.gridspecs["consumption"] == consumption_specs @@ -174,14 +173,14 @@ def test_process_phelps_deaton(): # Gridspecs wealth_specs = GridSpec( kind="linspace", - specs={"start": 0, "stop": 100, "n_points": N_STATE_GRID_POINTS}, + specs={"start": 0, "stop": 100, "n_points": N_GRID_POINTS["states"]}, ) assert model.gridspecs["wealth"] == wealth_specs consumption_specs = GridSpec( kind="linspace", - specs={"start": 0, "stop": 100, "n_points": N_CHOICE_GRID_POINTS}, + specs={"start": 0, "stop": 100, "n_points": N_GRID_POINTS["choices"]}, ) assert model.gridspecs["consumption"] == consumption_specs diff --git a/tests/test_simulate.py b/tests/test_simulate.py index e3c0287f..aef9eaa0 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -28,7 +28,7 @@ from pybaum import tree_equal from tests.test_models.phelps_deaton import ( - N_CHOICE_GRID_POINTS, + N_GRID_POINTS, PHELPS_DEATON, PHELPS_DEATON_WITH_FILTERS, ) @@ -69,7 +69,7 @@ def simulate_inputs(): return { "state_indexers": [{}], "continuous_choice_grids": [ - {"consumption": jnp.linspace(1, 100, num=N_CHOICE_GRID_POINTS)}, + {"consumption": jnp.linspace(1, 100, num=N_GRID_POINTS["choices"])}, ], "compute_ccv_policy_functions": compute_ccv_policy_functions, "model": model, From 2053bce1a2383b8b4e320718742287bc4c448d61 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 12:33:58 +0100 Subject: [PATCH 04/27] Change no. of grid points back to previous level --- tests/test_models/phelps_deaton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models/phelps_deaton.py b/tests/test_models/phelps_deaton.py index 06184a10..7f251ee7 100644 --- a/tests/test_models/phelps_deaton.py +++ b/tests/test_models/phelps_deaton.py @@ -8,7 +8,7 @@ N_GRID_POINTS = { "states": 100, - "choices": 200, + "choices": 500, } RETIREMENT_AGE = 65 From 9a0937624e48dd29f5628cc6b3d1fe215213c764 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 12:44:23 +0100 Subject: [PATCH 05/27] Fix typo in examples/README.md --- examples/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/README.md b/examples/README.md index b7e51f44..a8e6038e 100644 --- a/examples/README.md +++ b/examples/README.md @@ -2,9 +2,9 @@ ## Choosing an example -| Example name | Description | Runtime | -| -------------- | -------------------------------------------------- | ------------- | -| `long_running` | Consumptions-savings model with health and leisure | a few minutes | +| Example name | Description | Runtime | +| -------------- | ------------------------------------------------- | ------------- | +| `long_running` | Consumption-savings model with health and leisure | a few minutes | ## Running an example From 1102de49354023d3debb5244364e9291ed383401 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 15:02:11 +0100 Subject: [PATCH 06/27] Update examples/README.md Co-authored-by: Hans-Martin von Gaudecker --- examples/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/README.md b/examples/README.md index a8e6038e..34a23960 100644 --- a/examples/README.md +++ b/examples/README.md @@ -8,7 +8,7 @@ ## Running an example -Say you want to run the [`long_running`](./long_running.py) example locally. In a Python +Say you want to solve the [`long_running`](./long_running.py) example locally. In a Python shell, execute: ```python From d721fcb0cd28991b02fb59f8f2d397ae0a3a69ce Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 15:02:32 +0100 Subject: [PATCH 07/27] Update examples/long_running.py Co-authored-by: Hans-Martin von Gaudecker --- examples/long_running.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/long_running.py b/examples/long_running.py index 4502e763..cf9dbf24 100644 --- a/examples/long_running.py +++ b/examples/long_running.py @@ -46,7 +46,7 @@ def next_wealth(wealth, consumption, working, wage, interest_rate): return (1 + interest_rate) * (wealth - consumption) + wage * working -def next_health(health, sport, working): +def next_health(health, exercise, working): return health * (1 + sport - working / 2) From 20cdd61b2ba92c0c93ed6fcea36a0dd9e04f808a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Feb 2024 14:03:10 +0000 Subject: [PATCH 08/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/README.md b/examples/README.md index 34a23960..7b86f009 100644 --- a/examples/README.md +++ b/examples/README.md @@ -8,8 +8,8 @@ ## Running an example -Say you want to solve the [`long_running`](./long_running.py) example locally. In a Python -shell, execute: +Say you want to solve the [`long_running`](./long_running.py) example locally. In a +Python shell, execute: ```python from lcm.entry_point import get_lcm_function From fc4812d82a9464abe654a347259ad425b6920293 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 15:06:10 +0100 Subject: [PATCH 09/27] Update examples/long_running.py Co-authored-by: Hans-Martin von Gaudecker --- examples/long_running.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/long_running.py b/examples/long_running.py index cf9dbf24..3039ca0d 100644 --- a/examples/long_running.py +++ b/examples/long_running.py @@ -43,7 +43,7 @@ def age(_period): # State transitions # -------------------------------------------------------------------------------------- def next_wealth(wealth, consumption, working, wage, interest_rate): - return (1 + interest_rate) * (wealth - consumption) + wage * working + return (1 + interest_rate) * (wealth + working * wage - consumption) def next_health(health, exercise, working): From 53411d96fd4f234576c5a2cbb5ea2cf62e5900de Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 15:07:46 +0100 Subject: [PATCH 10/27] Rename sport -> exericse --- examples/long_running.py | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/examples/long_running.py b/examples/long_running.py index cf9dbf24..d244583f 100644 --- a/examples/long_running.py +++ b/examples/long_running.py @@ -20,8 +20,8 @@ # -------------------------------------------------------------------------------------- # Utility function # -------------------------------------------------------------------------------------- -def utility(consumption, working, health, sport, delta): - return jnp.log(consumption) - (delta - health) * working - sport +def utility(consumption, working, health, exercise, delta): + return jnp.log(consumption) - (delta - health) * working - exercise # -------------------------------------------------------------------------------------- @@ -47,18 +47,7 @@ def next_wealth(wealth, consumption, working, wage, interest_rate): def next_health(health, exercise, working): - return health * (1 + sport - working / 2) - - -def next_wealth_with_shock( - wealth, - consumption, - working, - wage, - wage_shock, - interest_rate, -): - return interest_rate * (wealth - consumption) + wage * wage_shock * working + return health * (1 + exercise - working / 2) # -------------------------------------------------------------------------------------- @@ -90,7 +79,7 @@ def consumption_constraint(consumption, wealth): "stop": 100, "n_points": N_GRID_POINTS["choices"], }, - "sport": { + "exercise": { "grid_type": "linspace", "start": 0, "stop": 1, From dcaf4a65d3c37537979c9e86a361cfd01f8585e1 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 15:23:46 +0100 Subject: [PATCH 11/27] Implement suggestions from review --- examples/long_running.py | 14 ++++++++------ tests/test_models/phelps_deaton.py | 14 +++++++------- tests/test_models/stochastic.py | 8 ++++---- tests/test_process_model.py | 8 ++++---- tests/test_simulate.py | 2 +- 5 files changed, 24 insertions(+), 22 deletions(-) diff --git a/examples/long_running.py b/examples/long_running.py index 3b370f6d..7a089ec5 100644 --- a/examples/long_running.py +++ b/examples/long_running.py @@ -6,8 +6,10 @@ # Numerical parameters and constants # ====================================================================================== N_GRID_POINTS = { - "states": 100, - "choices": 200, + "wealth": 100, + "health": 100, + "consumption": 100, + "exericse": 200, } RETIREMENT_AGE = 65 @@ -77,13 +79,13 @@ def consumption_constraint(consumption, wealth): "grid_type": "linspace", "start": 1, "stop": 100, - "n_points": N_GRID_POINTS["choices"], + "n_points": N_GRID_POINTS["consumption"], }, "exercise": { "grid_type": "linspace", "start": 0, "stop": 1, - "n_points": N_GRID_POINTS["choices"], + "n_points": N_GRID_POINTS["exercise"], }, }, "states": { @@ -91,13 +93,13 @@ def consumption_constraint(consumption, wealth): "grid_type": "linspace", "start": 1, "stop": 100, - "n_points": N_GRID_POINTS["states"], + "n_points": N_GRID_POINTS["wealth"], }, "health": { "grid_type": "linspace", "start": 0, "stop": 1, - "n_points": N_GRID_POINTS["states"], + "n_points": N_GRID_POINTS["health"], }, }, "n_periods": RETIREMENT_AGE - 18, diff --git a/tests/test_models/phelps_deaton.py b/tests/test_models/phelps_deaton.py index 7f251ee7..18f1afc9 100644 --- a/tests/test_models/phelps_deaton.py +++ b/tests/test_models/phelps_deaton.py @@ -7,8 +7,8 @@ # ====================================================================================== N_GRID_POINTS = { - "states": 100, - "choices": 500, + "wealth": 100, + "consumption": 500, } RETIREMENT_AGE = 65 @@ -93,7 +93,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "grid_type": "linspace", "start": 0, "stop": 100, - "n_points": N_GRID_POINTS["choices"], + "n_points": N_GRID_POINTS["consumption"], }, }, "states": { @@ -101,7 +101,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "grid_type": "linspace", "start": 0, "stop": 100, - "n_points": N_GRID_POINTS["states"], + "n_points": N_GRID_POINTS["wealth"], }, }, "n_periods": 3, @@ -124,7 +124,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "grid_type": "linspace", "start": 0, "stop": 100, - "n_points": N_GRID_POINTS["states"], + "n_points": N_GRID_POINTS["wealth"], }, }, "n_periods": 3, @@ -146,7 +146,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "grid_type": "linspace", "start": 1, "stop": 100, - "n_points": N_GRID_POINTS["choices"], + "n_points": N_GRID_POINTS["consumption"], }, }, "states": { @@ -154,7 +154,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "grid_type": "linspace", "start": 0, "stop": 100, - "n_points": N_GRID_POINTS["states"], + "n_points": N_GRID_POINTS["wealth"], }, "lagged_retirement": {"options": [0, 1]}, }, diff --git a/tests/test_models/stochastic.py b/tests/test_models/stochastic.py index 0e8cb3ad..00eeae35 100644 --- a/tests/test_models/stochastic.py +++ b/tests/test_models/stochastic.py @@ -8,8 +8,8 @@ # ====================================================================================== N_GRID_POINTS = { - "states": 100, - "choices": 200, + "wealth": 100, + "consumption": 200, } # ====================================================================================== @@ -69,7 +69,7 @@ def consumption_constraint(consumption, wealth): "grid_type": "linspace", "start": 1, "stop": 100, - "n_points": N_GRID_POINTS["choices"], + "n_points": N_GRID_POINTS["consumption"], }, }, "states": { @@ -79,7 +79,7 @@ def consumption_constraint(consumption, wealth): "grid_type": "linspace", "start": 1, "stop": 100, - "n_points": N_GRID_POINTS["states"], + "n_points": N_GRID_POINTS["wealth"], }, }, "n_periods": 3, diff --git a/tests/test_process_model.py b/tests/test_process_model.py index 2ce0c291..06d9a28d 100644 --- a/tests/test_process_model.py +++ b/tests/test_process_model.py @@ -122,14 +122,14 @@ def test_process_phelps_deaton_with_filters(): # Gridspecs wealth_specs = GridSpec( kind="linspace", - specs={"start": 0, "stop": 100, "n_points": N_GRID_POINTS["states"]}, + specs={"start": 0, "stop": 100, "n_points": N_GRID_POINTS["wealth"]}, ) assert model.gridspecs["wealth"] == wealth_specs consumption_specs = GridSpec( kind="linspace", - specs={"start": 1, "stop": 100, "n_points": N_GRID_POINTS["choices"]}, + specs={"start": 1, "stop": 100, "n_points": N_GRID_POINTS["consumption"]}, ) assert model.gridspecs["consumption"] == consumption_specs @@ -173,14 +173,14 @@ def test_process_phelps_deaton(): # Gridspecs wealth_specs = GridSpec( kind="linspace", - specs={"start": 0, "stop": 100, "n_points": N_GRID_POINTS["states"]}, + specs={"start": 0, "stop": 100, "n_points": N_GRID_POINTS["wealth"]}, ) assert model.gridspecs["wealth"] == wealth_specs consumption_specs = GridSpec( kind="linspace", - specs={"start": 0, "stop": 100, "n_points": N_GRID_POINTS["choices"]}, + specs={"start": 0, "stop": 100, "n_points": N_GRID_POINTS["consumption"]}, ) assert model.gridspecs["consumption"] == consumption_specs diff --git a/tests/test_simulate.py b/tests/test_simulate.py index aef9eaa0..7e14c14f 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -69,7 +69,7 @@ def simulate_inputs(): return { "state_indexers": [{}], "continuous_choice_grids": [ - {"consumption": jnp.linspace(1, 100, num=N_GRID_POINTS["choices"])}, + {"consumption": jnp.linspace(1, 100, num=N_GRID_POINTS["consumption"])}, ], "compute_ccv_policy_functions": compute_ccv_policy_functions, "model": model, From 8769ad090355a1caea2d354515b507008bb534bf Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 15:40:36 +0100 Subject: [PATCH 12/27] Reference issue #30 --- tests/test_models/phelps_deaton.py | 4 +++- tests/test_models/stochastic.py | 11 ++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/test_models/phelps_deaton.py b/tests/test_models/phelps_deaton.py index 18f1afc9..c20367ca 100644 --- a/tests/test_models/phelps_deaton.py +++ b/tests/test_models/phelps_deaton.py @@ -29,7 +29,9 @@ def utility_with_filter( consumption, working, delta, - lagged_retirement, # noqa: ARG001 + # Temporary workaround for bug described in issue #30, which requires us to pass + # all state variables to the utility function. + lagged_retirement, # noqa: ARG001, TODO: Remove unused arguments once #30 is fixed. ): return utility(consumption=consumption, working=working, delta=delta) diff --git a/tests/test_models/stochastic.py b/tests/test_models/stochastic.py index 00eeae35..71e2f375 100644 --- a/tests/test_models/stochastic.py +++ b/tests/test_models/stochastic.py @@ -20,7 +20,16 @@ # -------------------------------------------------------------------------------------- # Utility function # -------------------------------------------------------------------------------------- -def utility(consumption, working, health, partner, delta, gamma): # noqa: ARG001 +def utility( + consumption, + working, + health, + # Temporary workaround for bug described in issue #30, which requires us to pass + # all state variables to the utility function. + partner, # noqa: ARG001, TODO: Remove unused arguments once #30 is fixed. + delta, + gamma, +): return jnp.log(consumption) + (gamma * health - delta) * working From 249c75c2f2fa3bfe52da3b1da7fc61008b4748f7 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 15:59:03 +0100 Subject: [PATCH 13/27] Add reference to DC-EGM paper --- tests/test_models/phelps_deaton.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/test_models/phelps_deaton.py b/tests/test_models/phelps_deaton.py index c20367ca..6f34c67b 100644 --- a/tests/test_models/phelps_deaton.py +++ b/tests/test_models/phelps_deaton.py @@ -1,4 +1,11 @@ -"""Example specifications of the Phelps-Deaton model.""" +"""Example specifications of the Phelps-Deaton model. + +This specification corresponds to the example model presented in the paper: "The +endogenous grid method for discrete-continuous dynamic choice models with (or without) +taste shocks" by Fedor Iskhakov, Thomas H. Jørgensen, John Rust and Bertel Schjerning +(2017, https://doi.org/10.3982/QE643). + +""" import jax.numpy as jnp From a61c5dcd21b2b964c7c9ed595bc3a032d974217d Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 16:16:53 +0100 Subject: [PATCH 14/27] Rename delta -> disutility_of_work --- examples/long_running.py | 6 +++--- tests/test_models/phelps_deaton.py | 8 ++++---- tests/test_models/stochastic.py | 18 +++++++++++++----- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/examples/long_running.py b/examples/long_running.py index 7a089ec5..f5dec21d 100644 --- a/examples/long_running.py +++ b/examples/long_running.py @@ -22,8 +22,8 @@ # -------------------------------------------------------------------------------------- # Utility function # -------------------------------------------------------------------------------------- -def utility(consumption, working, health, exercise, delta): - return jnp.log(consumption) - (delta - health) * working - exercise +def utility(consumption, working, health, exercise, disutility_of_work): + return jnp.log(consumption) - (disutility_of_work - health) * working - exercise # -------------------------------------------------------------------------------------- @@ -107,6 +107,6 @@ def consumption_constraint(consumption, wealth): PARAMS = { "beta": 0.95, - "utility": {"delta": 0.05}, + "utility": {"disutility_of_work": 0.05}, "next_wealth": {"interest_rate": 0.05}, } diff --git a/tests/test_models/phelps_deaton.py b/tests/test_models/phelps_deaton.py index 6f34c67b..acc193c8 100644 --- a/tests/test_models/phelps_deaton.py +++ b/tests/test_models/phelps_deaton.py @@ -28,19 +28,19 @@ # -------------------------------------------------------------------------------------- # Utility functions # -------------------------------------------------------------------------------------- -def utility(consumption, working, delta): - return jnp.log(consumption) - delta * working +def utility(consumption, working, disutility_of_work): + return jnp.log(consumption) - disutility_of_work * working def utility_with_filter( consumption, working, - delta, + disutility_of_work, # Temporary workaround for bug described in issue #30, which requires us to pass # all state variables to the utility function. lagged_retirement, # noqa: ARG001, TODO: Remove unused arguments once #30 is fixed. ): - return utility(consumption=consumption, working=working, delta=delta) + return utility(consumption, working=working, disutility_of_work=disutility_of_work) # -------------------------------------------------------------------------------------- diff --git a/tests/test_models/stochastic.py b/tests/test_models/stochastic.py index 71e2f375..8d4707d3 100644 --- a/tests/test_models/stochastic.py +++ b/tests/test_models/stochastic.py @@ -1,4 +1,13 @@ -"""Example specifications of a simple Phelps-Deaton style stochastic model.""" +"""Example specifications of a simple Phelps-Deaton style stochastic model. + +This specification is motivated by the example model presented in the paper: "The +endogenous grid method for discrete-continuous dynamic choice models with (or without) +taste shocks" by Fedor Iskhakov, Thomas H. Jørgensen, John Rust and Bertel Schjerning +(2017, https://doi.org/10.3982/QE643). + +See also the specifications in tests/test_models/phelps_deaton.py. + +""" import jax.numpy as jnp import lcm @@ -27,10 +36,9 @@ def utility( # Temporary workaround for bug described in issue #30, which requires us to pass # all state variables to the utility function. partner, # noqa: ARG001, TODO: Remove unused arguments once #30 is fixed. - delta, - gamma, + disutility_of_work, ): - return jnp.log(consumption) + (gamma * health - delta) * working + return jnp.log(consumption) - (1 - health / 2) * disutility_of_work * working # -------------------------------------------------------------------------------------- @@ -97,7 +105,7 @@ def consumption_constraint(consumption, wealth): PARAMS = { "beta": 0.95, - "utility": {"delta": 0.5, "gamma": 0.25}, + "utility": {"disutility_of_work": 0.5}, "next_wealth": {"interest_rate": 0.05, "wage": 10.0}, "next_health": {}, "consumption_constraint": {}, From 0a9424742a77db1d75d0bd1f34586c1528ac8bbe Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 16:18:55 +0100 Subject: [PATCH 15/27] Rename delta -> disutility_of_work throughout --- src/lcm/get_model.py | 8 ++++---- tests/test_analytical_solution.py | 4 +++- tests/test_entry_point.py | 16 ++++++++-------- tests/test_model_functions.py | 4 ++-- tests/test_next_state.py | 2 +- tests/test_simulate.py | 24 ++++++++++++------------ tests/test_stochastic.py | 6 +++--- 7 files changed, 33 insertions(+), 31 deletions(-) diff --git a/src/lcm/get_model.py b/src/lcm/get_model.py index 774115f3..61e932ca 100644 --- a/src/lcm/get_model.py +++ b/src/lcm/get_model.py @@ -104,7 +104,7 @@ def get_model(model: str): model=PHELPS_DEATON_FIVE_PERIODS, params={ "beta": 1.0, - "utility": {"delta": 1.0}, + "utility": {"disutility_of_work": 1.0}, "next_wealth": { "interest_rate": 0.05, "wage": 1.0, @@ -115,18 +115,18 @@ def get_model(model: str): model=ISKHAKOV_2017_FIVE_PERIODS, params={ "beta": 0.98, - "utility": {"delta": 1.0}, + "utility": {"disutility_of_work": 1.0}, "next_wealth": { "interest_rate": 0.0, "wage": 20.0, }, }, ), - "iskhakov_2017_low_delta": ModelAndParams( + "iskhakov_2017_low_disutility_of_work": ModelAndParams( model=ISKHAKOV_2017_THREE_PERIODS, params={ "beta": 0.98, - "utility": {"delta": 0.1}, + "utility": {"disutility_of_work": 0.1}, "next_wealth": { "interest_rate": 0.0, "wage": 20.0, diff --git a/tests/test_analytical_solution.py b/tests/test_analytical_solution.py index 2cf0abda..9fabaadd 100644 --- a/tests/test_analytical_solution.py +++ b/tests/test_analytical_solution.py @@ -9,7 +9,9 @@ TEST_CASES = { "iskhakov_2017_five_periods": get_model("iskhakov_2017_five_periods"), - "iskhakov_2017_low_delta": get_model("iskhakov_2017_low_delta"), + "iskhakov_2017_low_disutility_of_work": get_model( + "iskhakov_2017_low_disutility_of_work", + ), } diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 05a9b833..24db8d81 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -141,7 +141,7 @@ def test_create_compute_conditional_continuation_value(): params = { "beta": 1.0, - "utility": {"delta": 1.0}, + "utility": {"disutility_of_work": 1.0}, "next_wealth": { "interest_rate": 0.05, "wage": 1.0, @@ -176,7 +176,7 @@ def test_create_compute_conditional_continuation_value(): params=params, vf_arr=None, ) - assert val == utility(consumption=30.0, working=0, delta=1.0) + assert val == utility(consumption=30.0, working=0, disutility_of_work=1.0) def test_create_compute_conditional_continuation_value_with_discrete_model(): @@ -184,7 +184,7 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): params = { "beta": 1.0, - "utility": {"delta": 1.0}, + "utility": {"disutility_of_work": 1.0}, "next_wealth": { "interest_rate": 0.05, "wage": 1.0, @@ -219,7 +219,7 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): params=params, vf_arr=None, ) - assert val == utility(consumption=2, working=0, delta=1.0) + assert val == utility(consumption=2, working=0, disutility_of_work=1.0) # ====================================================================================== @@ -232,7 +232,7 @@ def test_create_compute_conditional_continuation_policy(): params = { "beta": 1.0, - "utility": {"delta": 1.0}, + "utility": {"disutility_of_work": 1.0}, "next_wealth": { "interest_rate": 0.05, "wage": 1.0, @@ -268,7 +268,7 @@ def test_create_compute_conditional_continuation_policy(): vf_arr=None, ) assert policy == 2 - assert val == utility(consumption=30.0, working=0, delta=1.0) + assert val == utility(consumption=30.0, working=0, disutility_of_work=1.0) def test_create_compute_conditional_continuation_policy_with_discrete_model(): @@ -276,7 +276,7 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): params = { "beta": 1.0, - "utility": {"delta": 1.0}, + "utility": {"disutility_of_work": 1.0}, "next_wealth": { "interest_rate": 0.05, "wage": 1.0, @@ -312,4 +312,4 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): vf_arr=None, ) assert policy == 1 - assert val == utility(consumption=2, working=0, delta=1.0) + assert val == utility(consumption=2, working=0, disutility_of_work=1.0) diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index f2a0a7c4..dfa58cbd 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -46,7 +46,7 @@ def test_get_utility_and_feasibility_function(): params = { "beta": 1.0, - "utility": {"delta": 1.0}, + "utility": {"disutility_of_work": 1.0}, "next_wealth": { "interest_rate": 0.05, "wage": 1.0, @@ -86,7 +86,7 @@ def test_get_utility_and_feasibility_function(): utility( consumption=consumption, working=1 - retirement, - delta=1.0, + disutility_of_work=1.0, ), ) assert_array_equal(f, jnp.array([True, True, False])) diff --git a/tests/test_next_state.py b/tests/test_next_state.py index 572ad993..f37f02c9 100644 --- a/tests/test_next_state.py +++ b/tests/test_next_state.py @@ -18,7 +18,7 @@ def test_get_next_state_function_with_solve_target(): params = { "beta": 1.0, - "utility": {"delta": 1.0}, + "utility": {"disutility_of_work": 1.0}, "next_wealth": { "interest_rate": 0.05, }, diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 7e14c14f..adfc92b2 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -80,7 +80,7 @@ def simulate_inputs(): def test_simulate_using_raw_inputs(simulate_inputs): params = { "beta": 1.0, - "utility": {"delta": 1.0}, + "utility": {"disutility_of_work": 1.0}, "next_wealth": { "interest_rate": 0.05, }, @@ -117,7 +117,7 @@ def _model_solution(n_periods): params = { "beta": 1.0, - "utility": {"delta": 1.0}, + "utility": {"disutility_of_work": 1.0}, "next_wealth": { "interest_rate": 0.05, "wage": 1.0, @@ -183,7 +183,7 @@ def test_effect_of_beta_on_last_period(): params = { "beta": None, - "utility": {"delta": 1.0}, + "utility": {"disutility_of_work": 1.0}, "next_wealth": { "interest_rate": 0.05, }, @@ -193,7 +193,7 @@ def test_effect_of_beta_on_last_period(): params_low = params.copy() params_low["beta"] = 0.5 - # high delta + # high disutility_of_work params_high = params.copy() params_high["beta"] = 0.99 @@ -228,7 +228,7 @@ def test_effect_of_beta_on_last_period(): ).all() -def test_effect_of_delta(): +def test_effect_of_disutility_of_work(): model = {**PHELPS_DEATON, "n_periods": 5} # Model solutions @@ -237,19 +237,19 @@ def test_effect_of_delta(): params = { "beta": 1.0, - "utility": {"delta": None}, + "utility": {"disutility_of_work": None}, "next_wealth": { "interest_rate": 0.05, }, } - # low delta + # low disutility_of_work params_low = params.copy() - params_low["utility"]["delta"] = 0.2 + params_low["utility"]["disutility_of_work"] = 0.2 - # high delta + # high disutility_of_work params_high = params.copy() - params_high["utility"]["delta"] = 1.5 + params_high["utility"]["disutility_of_work"] = 1.5 # solutions solution_low = solve_model(params_low) @@ -325,7 +325,7 @@ def test_compute_targets(): } def f_a(a, params): - return a + params["delta"] + return a + params["disutility_of_work"] def f_b(b, params): # noqa: ARG001 return b @@ -336,7 +336,7 @@ def f_b(b, params): # noqa: ARG001 processed_results=processed_results, targets=["fa", "fb"], model_functions=model_functions, - params={"delta": -1.0}, + params={"disutility_of_work": -1.0}, ) expected = { "fa": jnp.arange(3) - 1.0, diff --git a/tests/test_stochastic.py b/tests/test_stochastic.py index 636247a2..4615eaa7 100644 --- a/tests/test_stochastic.py +++ b/tests/test_stochastic.py @@ -62,8 +62,8 @@ def test_get_lcm_function_with_solve_target(): @pytest.fixture() def model_and_params(): - def utility(consumption, working, health, delta, gamma): - return jnp.log(consumption) + (gamma * health - delta) * working + def utility(consumption, working, health, disutility_of_work, gamma): + return jnp.log(consumption) + (gamma * health - disutility_of_work) * working def next_wealth(wealth, consumption, working, wage, interest_rate): return (1 + interest_rate) * (wealth - consumption) + wage * working @@ -113,7 +113,7 @@ def consumption_constraint(consumption, wealth): params = { "beta": 0.95, - "utility": {"delta": 0.5, "gamma": 0.5}, + "utility": {"disutility_of_work": 0.5, "gamma": 0.5}, "next_wealth": {"interest_rate": 0.05, "wage": 10.0}, "next_health": {}, "consumption_constraint": {}, From 456bb79c0c5381f0dc1faf1f819562483973f263 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 17:01:33 +0100 Subject: [PATCH 16/27] Remove dependency of get_model from test_regression_test.py --- pyproject.toml | 3 + src/lcm/_config.py | 2 +- tests/data/regression_tests/simulation.csv | 21 - tests/data/regression_tests/simulation.pkl | Bin 0 -> 1617 bytes tests/data/regression_tests/solution.json | 512 --------------------- tests/data/regression_tests/solution.pkl | Bin 0 -> 2725 bytes tests/test_analytical_solution.py | 6 +- tests/test_models/phelps_deaton.py | 18 +- tests/test_regression_test.py | 54 ++- 9 files changed, 47 insertions(+), 569 deletions(-) delete mode 100644 tests/data/regression_tests/simulation.csv create mode 100644 tests/data/regression_tests/simulation.pkl delete mode 100644 tests/data/regression_tests/solution.json create mode 100644 tests/data/regression_tests/solution.pkl diff --git a/pyproject.toml b/pyproject.toml index 1cedd9f9..443a74ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,9 @@ extend-ignore = [ # use of `assert` detected "S101", + # `pickle` module is unsafe + "S301", + # Private member accessed: `_stochastic_info` "SLF001", diff --git a/src/lcm/_config.py b/src/lcm/_config.py index 40b5d764..40950b80 100644 --- a/src/lcm/_config.py +++ b/src/lcm/_config.py @@ -1,3 +1,3 @@ from pathlib import Path -TEST_DATA_PATH = Path(__file__).parent.parent.parent.resolve().joinpath("tests", "data") +TEST_DATA = Path(__file__).parent.parent.parent.resolve().joinpath("tests", "data") diff --git a/tests/data/regression_tests/simulation.csv b/tests/data/regression_tests/simulation.csv deleted file mode 100644 index e7124ef2..00000000 --- a/tests/data/regression_tests/simulation.csv +++ /dev/null @@ -1,21 +0,0 @@ -period,initial_state_id,value,retirement,consumption,wealth,_period -0,0,0.0,1,1.0,1.0,0 -0,1,7.349747,1,3.3987975,20.0,0 -0,2,10.85528,1,8.196393,40.0,0 -0,3,13.665637,1,13.793587,70.0,0 -1,0,-inf,0,1.0,0.0,1 -1,1,6.13457,1,4.9979963,17.431261,1 -1,2,8.751791,1,8.995992,33.393787,1 -1,3,11.042786,1,14.593186,59.01673,1 -2,0,-inf,0,1.0,-0.049999952,2 -2,1,4.5309544,1,4.1983967,13.054928,2 -2,2,6.559553,1,8.995992,25.617685,2 -2,3,8.367242,1,15.392786,46.644722,2 -3,0,-inf,0,1.0,-0.1024999,3 -3,1,3.0912378,1,4.1983967,9.299357,3 -3,2,4.3671656,1,8.995992,17.452776,3 -3,3,5.634691,1,16.192385,32.81453,3 -4,0,-inf,0,1.0,-0.1576249,4 -4,1,1.6090372,1,4.9979963,5.3560085,4 -4,2,2.1036942,1,8.196393,8.879623,4 -4,3,2.8327417,1,16.991983,17.45325,4 diff --git a/tests/data/regression_tests/simulation.pkl b/tests/data/regression_tests/simulation.pkl new file mode 100644 index 0000000000000000000000000000000000000000..3abb6e85bffdf27f6c9066800c397cf0f83538f0 GIT binary patch literal 1617 zcmZ`(U1%It6yDj*Y?@}1w4q58EKR6_tBH#>#<&>Vi>XyRWNigODRH>7b2oR)?94KA zCowd^C^n{YU$oc0s1H&_K?ouUJ_r^B1@Wg3ilNvRe2IN(6jUfkJ$HV#Ys_AlId|sV zbME>1zO$mfc{G~}7vEdQj*0z3#q~&`>fss@yyFOF_-MF4EADzpEQ>sE(`&Py;0K$J zTMi?hgKfW1!w#M$o-cUr5!GAhUuc*J-lGRj-KxylL^q_3JlyL9 zwfbTd3~Y&KRRr(UtDajkf@+m`a43InSc!u7UZ3fh>-wkUZajPZ&wM}X`QfuWH+v7F zpRSCa{AKbG+9)3Dd|+^^^v-YCyzz85Qs~BO@4m7R9kMoiE_M#0x8CUc`S0c5OMk7e zkBnz;mX^-vZY*`F=;I&t`#w+YMgv=S(1`;X^tHJA>KEVNhCVodDk$DNimu&$W&FE~ z6KMTCbz$w3wub)Y`7aN@|IJ84f2m7aw*OBbntfYDQ5O4Rd4DXAL~`W>ntnc49KEh0 zZDVrpht(1EYyXgYYV09&;j0su?m@f*$RA z`(gK55ecq_^#p0M#X6A{PFEikXVzX7Pl>#i;mV{am&?a)!U|tniU~u@Xt`Cvl_|mC z812rV8;BxpywmY66H7sRdD?1nO#63ods@&z=~7F}`)Q|ieIa%cw=~sC#>ZOfTiGT5 z6Z=*&IoC=iDYue>4rysww}IMTX)EU{)7Ynz?_b8HQ9%hG=Zjq_cd{*7Qmg@^Y9`tQTIs2b$sRp6(-I1 z@OGejsDxNjr^S?FhzJRarb2U|iBC2)kt<8v&`|<(oRKlB^Vkk#fF5C%M{2}jz*Bh} z@RwPxBY1X!V4G3F+x0r|*fj+KYu@I_5mOFUB?89z;7-YhS%C9tnN4%XS)0rg+m~gd z<8%UwK#M!!GihSdEZLk)fU`?)ZOK7j8GpYCy)X^GR!aID|W3Fey zkSqi7q0G^>&(SK0f{`_cK@Ot;X1oMXGVk>oVK_w748RpL29talTBuiTJR5HJW7~EY VeBBJ{0A`t@jSbm~H9s&4{{nvY2$}!@ literal 0 HcmV?d00001 diff --git a/tests/data/regression_tests/solution.json b/tests/data/regression_tests/solution.json deleted file mode 100644 index 206771d1..00000000 --- a/tests/data/regression_tests/solution.json +++ /dev/null @@ -1,512 +0,0 @@ -[ - [ - 0.0, - 1.609037160873413, - 3.280139207839966, - 4.959736347198486, - 6.436992645263672, - 7.636216163635254, - 8.49956226348877, - 9.236945152282715, - 9.895655632019043, - 10.485862731933594, - 11.009462356567383, - 11.476234436035156, - 11.90827751159668, - 12.30354118347168, - 12.668807029724121, - 13.009539604187012, - 13.329602241516113, - 13.631184577941895, - 13.914976119995117, - 14.181705474853516, - 14.434638023376465, - 14.677236557006836, - 14.907997131347656, - 15.127725601196289, - 15.338500022888184, - 15.541778564453125, - 15.736351013183594, - 15.922971725463867, - 16.103235244750977, - 16.2774715423584, - 16.44595718383789, - 16.60894203186035, - 16.766494750976562, - 16.919342041015625, - 17.06781768798828, - 17.211830139160156, - 17.35169219970703, - 17.48772430419922, - 17.62046241760254, - 17.749614715576172, - 17.875568389892578, - 17.998262405395508, - 18.117958068847656, - 18.234933853149414, - 18.349817276000977, - 18.461841583251953, - 18.570756912231445, - 18.677404403686523, - 18.782127380371094, - 18.885183334350586, - 18.986446380615234, - 19.084327697753906, - 19.18060302734375, - 19.275423049926758, - 19.368736267089844, - 19.460474014282227, - 19.549938201904297, - 19.63794708251953, - 19.7244873046875, - 19.809663772583008, - 19.893457412719727, - 19.975894927978516, - 20.0570068359375, - 20.136667251586914, - 20.214962005615234, - 20.292137145996094, - 20.368305206298828, - 20.443410873413086, - 20.517166137695312, - 20.5898494720459, - 20.66151237487793, - 20.73221778869629, - 20.801939010620117, - 20.870697021484375, - 20.938554763793945, - 21.005416870117188, - 21.071449279785156, - 21.136539459228516, - 21.200878143310547, - 21.264432907104492, - 21.327144622802734, - 21.38905143737793, - 21.45021629333496, - 21.510589599609375, - 21.570310592651367, - 21.629364013671875, - 21.68773078918457, - 21.74538230895996, - 21.802289962768555, - 21.858627319335938, - 21.914451599121094, - 21.969585418701172, - 22.023998260498047, - 22.078208923339844, - 22.132692337036133, - 22.18666648864746, - 22.240039825439453, - 22.292774200439453, - 22.344619750976562, - 22.395854949951172 - ], - [ - 0.0, - 1.609037160873413, - 3.2538504600524902, - 4.800864219665527, - 6.063236713409424, - 6.8833208084106445, - 7.583081245422363, - 8.197111129760742, - 8.734527587890625, - 9.188711166381836, - 9.603611946105957, - 9.979996681213379, - 10.32372760772705, - 10.639446258544922, - 10.933663368225098, - 11.206408500671387, - 11.459930419921875, - 11.69947624206543, - 11.927656173706055, - 12.140583038330078, - 12.342621803283691, - 12.536150932312012, - 12.721176147460938, - 12.896724700927734, - 13.065574645996094, - 13.227193832397461, - 13.38270378112793, - 13.53246784210205, - 13.677145004272461, - 13.816749572753906, - 13.950849533081055, - 14.08076286315918, - 14.206995964050293, - 14.329532623291016, - 14.448339462280273, - 14.563127517700195, - 14.674884796142578, - 14.784219741821289, - 14.890746116638184, - 14.993427276611328, - 15.093796730041504, - 15.19232177734375, - 15.288949012756348, - 15.381830215454102, - 15.473026275634766, - 15.56264877319336, - 15.650506019592285, - 15.735862731933594, - 15.819493293762207, - 15.901622772216797, - 15.982186317443848, - 16.06104850769043, - 16.138484954833984, - 16.214248657226562, - 16.288612365722656, - 16.3616886138916, - 16.433692932128906, - 16.504037857055664, - 16.573286056518555, - 16.641393661499023, - 16.708454132080078, - 16.774280548095703, - 16.83905029296875, - 16.90280532836914, - 16.965579986572266, - 17.027347564697266, - 17.088260650634766, - 17.148141860961914, - 17.20716094970703, - 17.265356063842773, - 17.322708129882812, - 17.37929916381836, - 17.435081481933594, - 17.490036010742188, - 17.54423713684082, - 17.5977840423584, - 17.650787353515625, - 17.702754974365234, - 17.754154205322266, - 17.806262969970703, - 17.857818603515625, - 17.908693313598633, - 17.95891571044922, - 18.007951736450195, - 18.05640411376953, - 18.104270935058594, - 18.151195526123047, - 18.197465896606445, - 18.2431640625, - 18.288122177124023, - 18.3326416015625, - 18.376657485961914, - 18.420188903808594, - 18.46322250366211, - 18.505821228027344, - 18.547943115234375, - 18.589618682861328, - 18.630905151367188, - 18.671831130981445, - 18.71220588684082 - ], - [ - 0.0, - 1.609037160873413, - 3.193450689315796, - 4.537532329559326, - 5.304745197296143, - 5.954254150390625, - 6.507669448852539, - 6.945423603057861, - 7.33719539642334, - 7.689281463623047, - 7.995795249938965, - 8.278403282165527, - 8.536031723022461, - 8.7706937789917, - 8.990982055664062, - 9.193929672241211, - 9.384664535522461, - 9.565553665161133, - 9.733380317687988, - 9.894050598144531, - 10.046565055847168, - 10.190986633300781, - 10.329140663146973, - 10.461769104003906, - 10.587770462036133, - 10.70872688293457, - 10.82571029663086, - 10.938405990600586, - 11.045883178710938, - 11.150434494018555, - 11.252116203308105, - 11.348937034606934, - 11.443496704101562, - 11.535870552062988, - 11.62411880493164, - 11.710451126098633, - 11.794950485229492, - 11.876179695129395, - 11.955656051635742, - 12.033387184143066, - 12.108747482299805, - 12.182483673095703, - 12.254426956176758, - 12.324460983276367, - 12.393216133117676, - 12.460168838500977, - 12.525819778442383, - 12.590036392211914, - 12.652765274047852, - 12.714395523071289, - 12.774778366088867, - 12.833782196044922, - 12.891794204711914, - 12.948715209960938, - 13.004526138305664, - 13.059261322021484, - 13.113110542297363, - 13.166114807128906, - 13.217918395996094, - 13.26899528503418, - 13.319427490234375, - 13.368610382080078, - 13.41728401184082, - 13.46737003326416, - 13.516111373901367, - 13.564022064208984, - 13.610557556152344, - 13.65637493133545, - 13.700997352600098, - 13.744844436645508, - 13.787697792053223, - 13.830053329467773, - 13.871835708618164, - 13.912891387939453, - 13.953496932983398, - 13.993532180786133, - 14.03302001953125, - 14.072063446044922, - 14.11050796508789, - 14.148538589477539, - 14.186107635498047, - 14.223125457763672, - 14.259763717651367, - 14.296005249023438, - 14.331695556640625, - 14.367011070251465, - 14.401971817016602, - 14.436481475830078, - 14.47055721282959, - 14.504310607910156, - 14.537744522094727, - 14.570648193359375, - 14.603280067443848, - 14.635626792907715, - 14.667494773864746, - 14.699098587036133, - 14.73045825958252, - 14.761331558227539, - 14.791975021362305, - 14.822362899780273 - ], - [ - 0.0, - 1.609037160873413, - 3.0546796321868896, - 3.756441116333008, - 4.328797340393066, - 4.741797924041748, - 5.102178573608398, - 5.395397186279297, - 5.661606311798096, - 5.886263847351074, - 6.095755577087402, - 6.279889106750488, - 6.451984405517578, - 6.608402729034424, - 6.754400253295898, - 6.890329360961914, - 7.017242908477783, - 7.136931419372559, - 7.249366760253906, - 7.356710433959961, - 7.457448482513428, - 7.554689884185791, - 7.645936012268066, - 7.734623432159424, - 7.818196773529053, - 7.8995361328125, - 7.976801872253418, - 8.05201530456543, - 8.123757362365723, - 8.193826675415039, - 8.260684967041016, - 8.326040267944336, - 8.388843536376953, - 8.450052261352539, - 8.509283065795898, - 8.566838264465332, - 8.622849464416504, - 8.677196502685547, - 8.730266571044922, - 8.781784057617188, - 8.832263946533203, - 8.881186485290527, - 8.929313659667969, - 8.975887298583984, - 9.022041320800781, - 9.070188522338867, - 9.116161346435547, - 9.160223007202148, - 9.202659606933594, - 9.24348258972168, - 9.28371524810791, - 9.32281494140625, - 9.3614501953125, - 9.399123191833496, - 9.436283111572266, - 9.47262954711914, - 9.508438110351562, - 9.543533325195312, - 9.57809066772461, - 9.612010955810547, - 9.64537239074707, - 9.67822265625, - 9.710469245910645, - 9.742314338684082, - 9.773517608642578, - 9.804378509521484, - 9.834641456604004, - 9.864582061767578, - 9.893961906433105, - 9.923052787780762, - 9.951576232910156, - 9.979854583740234, - 10.007579803466797, - 10.035058975219727, - 10.06205940246582, - 10.088783264160156, - 10.115095138549805, - 10.14111328125, - 10.166763305664062, - 10.192110061645508, - 10.217130661010742, - 10.241827964782715, - 10.26626205444336, - 10.290343284606934, - 10.314216613769531, - 10.337711334228516, - 10.36102294921875, - 10.383988380432129, - 10.406768798828125, - 10.429222106933594, - 10.451505661010742, - 10.474639892578125, - 10.49744987487793, - 10.519763946533203, - 10.541597366333008, - 10.562971115112305, - 10.583905220031738, - 10.604476928710938, - 10.62489128112793, - 10.645040512084961 - ], - [ - 0.0, - 1.609037160873413, - 2.1967790126800537, - 2.5644867420196533, - 2.8327417373657227, - 3.0440452098846436, - 3.2183947563171387, - 3.366811990737915, - 3.4960215091705322, - 3.6104302406311035, - 3.713083028793335, - 3.8061723709106445, - 3.891329288482666, - 3.9698002338409424, - 4.0425591468811035, - 4.110381126403809, - 4.173893928527832, - 4.233612537384033, - 4.2899651527404785, - 4.343310832977295, - 4.393954277038574, - 4.442155838012695, - 4.488141059875488, - 4.532103538513184, - 4.574214935302734, - 4.622513771057129, - 4.6610541343688965, - 4.698163986206055, - 4.733945846557617, - 4.768491268157959, - 4.801883220672607, - 4.834196090698242, - 4.865497589111328, - 4.895848751068115, - 4.9253058433532715, - 4.953919887542725, - 4.981738090515137, - 5.008803367614746, - 5.035155296325684, - 5.060830593109131, - 5.08586311340332, - 5.110284328460693, - 5.134122848510742, - 5.157406806945801, - 5.180160999298096, - 5.202408790588379, - 5.224172592163086, - 5.245472431182861, - 5.26632833480835, - 5.286757946014404, - 5.310734748840332, - 5.330286026000977, - 5.349462032318115, - 5.368277072906494, - 5.386744976043701, - 5.404877662658691, - 5.422687530517578, - 5.440185546875, - 5.457382678985596, - 5.474289417266846, - 5.490914821624756, - 5.50726842880249, - 5.5233588218688965, - 5.539194583892822, - 5.554783344268799, - 5.570132732391357, - 5.585249900817871, - 5.600142478942871, - 5.614816188812256, - 5.62927770614624, - 5.643533229827881, - 5.657588005065918, - 5.671448230743408, - 5.685119152069092, - 5.698605537414551, - 5.714552402496338, - 5.727650165557861, - 5.740578651428223, - 5.753342151641846, - 5.765944480895996, - 5.778390407562256, - 5.790683269500732, - 5.802826404571533, - 5.814824104309082, - 5.826679706573486, - 5.838396072387695, - 5.849977016448975, - 5.861425399780273, - 5.872744083404541, - 5.883936405181885, - 5.8950042724609375, - 5.905951499938965, - 5.9167799949646, - 5.927492141723633, - 5.938091278076172, - 5.948578834533691, - 5.958958148956299, - 5.9692301750183105, - 5.979397773742676, - 5.991464614868164 - ] -] diff --git a/tests/data/regression_tests/solution.pkl b/tests/data/regression_tests/solution.pkl new file mode 100644 index 0000000000000000000000000000000000000000..d45858ddd7277383d02eabf174440fba25ada4b5 GIT binary patch literal 2725 zcma);X*d*W8^`S$F`Jn&yBRapSW=y1sgs(2)}mA#r6So%ma$}L+E7WNqnNU^=tP@S zB8gIIL3CQ=v{Q&EM@ot-ig!-uyx04Fct5=Nhv&ZT=l{RI+}HKHpC?g!U!I=Y_wZFp z(utw|>n;7(#ROXV$He$6l<8EznBc%j*}B*md0?#H-7X8Zfdd0fK+>drkdK=oZ9NM&m_`;+7Z_tVwbdbPgdd8 zE&WoWi?-a%Yfu`q*Ow4x@J*^%?PgK1MIAP;l3+Va3-d!VIxJjukKf|7aK|oA zQM5j0#no)gI!Z*I;hKw2y~uD7FF(0*CI$JOQyX(=R4m}$W)`(k5y2^>2oD;f^30o0 zBlBr^QcF3e^O**1`{rYhGUT=epsj{;$w;C(8MV`tP9-N zVX&D8x6JE{8|!%>-LIU~P30pk>{HX&az0wiq!Av!@=j|ri58|-~PX|+DEX?_+ z(-mqAgPpO*DZd$G@LSsL4{wb@U)b}g#!!GZ-Om;erwJez9Wz_JPyqjr<)$mc1(;w* zSa@Wo0KvBN=g0~L$apxs^M17e+pSh@-~L>Hbv}Lik3R|U@zjt04MZV&$6O=m5+UX# z+Ef$g2w`(__Ld9FglM)dF25cv#O=k?PA8-av10A+)UhK%ERZtW+bV=O=v?Bs?Vb?a z{0+_>okIN1iX5vO79uiFQuc~$g5%~i?zBlJc=Xa@MT?yY;;t{gVC`Xo4XuacM^~F* zi5r=iu+ap+XbCBi878TPy|w1SPgxGs(v8}ygqEyW>z9yG#Tt*1Uy@p&~5K zUaPihqX@NSC#=-dMc5zon@n*~1nr#>l{KeDxIeyhwDqzGLy9Z*_wR@>xTyB>o)!@_ z`UoBqdR2PA&H4Q!A{^(mx;)VnV_#Uq$^n)b3g$S&Hxe<5gVac+wqg{t?3laESqx>Y za=!6jdqB3Osf0(5HO_@XK zdGGes zHSzp?Xv(|+ZG^3Ktgfoo#pI*ANyowo80BrPOX<|Zih8f%_YV5FIoUQXHOTbXYw!o3bI54#l%qhl}g!n44N`5vaoeeW`q7)f@&Y&C})Y zH!^S_)zRZc83U0H8cBD17?3sG82Q842ovI>#hso;s4no-vfp8ZM4Huo4wiWWsnvuKsH@6OvoBBP$Lup*3?`a#jNq-{M~jvIm(s_iSQ& zm5_z=*2-BDHx}&l4o@GAVIefB*=J)e3p$G7AVoC``N9Y7nVl@;u&8=g1UBXl`Zn}i zuyN%}G;_fsHWt^Oxm_Q_hPf}F*R5n@+#}K%hYM_Ywi82sdcuZ5z=U_HLu|x7T3a4( z#6ea|pG&n32TOZ0glD}t@GjXgRV$VQ(_2P`YMC4;2;Nq4XE``Symn>x9S*L4dcJ7t zTMmA5B`=BB=7JYE`I@Vcit8{ z!G+bA$MKcdxagWUyVLMF7Y6rsub();#TJM5(bn-i&}1X5Wnv!6)xD{4c08P%@6+n@ zGY^Vu`Hwoocxan7-R`GtJXGX8{W2qu2bs)$wQea7X2h>^w%+1FQgicS;|m@$h z7~p|F`bB(*z(=`$h%v{QkMK1qGqzdrQU60h(*svN650y37Ovvs6~jDxOwLEg6w?E} zyZBi6x}z^LpN}FkzjSvQ{~yXin5iMG|BI);-&!@b{D!f!J_cRV1DW(EL!0}hFB$i` zGw6fT^sc7i?qgr1cS#_#Eo>^#AKql3jhe`_AJ6A$!ew1DT2;~{POs;?8L z%3VtP!xzRw?BDUg?(=x$mea0`P1DEf%|e+*gg#EJu~T1pTpxAyoMh9N`WWS0{$T^n z05eD59SWRhfW7j%)}c)X=vwu(%H^y9A|ylC0@@5fohw?sluU$K<|OftQX+1=@lMx> zAi}Qtsn3I4A{b7mGmC18Sk35DHeXK*qo^x(LPhye5Z@J$&!JLeS9}aduv?HTvy(QV! zos6j7fcNfxsyge^XMPkJv`sPcQ;B3`L`4;T*h9uv0palA5i-nN3xckckl|=Fv%~os z84)|``;_;|7>Z1DIrWl^${1;0@;fptp3g`!|4N4Cign2c^(c7nm0~iq@4E+8b5Bq*8;Dz(!*#Zfa!=mzq)0 ztRLhOXHCV5`_sF4pdz8!QFh9eipN97F)S}te%qfQ^P_^{RbEvZM*SZ=1zY|LiwA92 literal 0 HcmV?d00001 diff --git a/tests/test_analytical_solution.py b/tests/test_analytical_solution.py index 9fabaadd..34875f1c 100644 --- a/tests/test_analytical_solution.py +++ b/tests/test_analytical_solution.py @@ -2,14 +2,14 @@ import numpy as np import pytest -from lcm._config import TEST_DATA_PATH +from lcm._config import TEST_DATA from lcm.entry_point import get_lcm_function from lcm.get_model import get_model from numpy.testing import assert_array_almost_equal as aaae TEST_CASES = { "iskhakov_2017_five_periods": get_model("iskhakov_2017_five_periods"), - "iskhakov_2017_low_disutility_of_work": get_model( + "iskhakov_2017_low_delta": get_model( "iskhakov_2017_low_disutility_of_work", ), } @@ -42,7 +42,7 @@ def test_analytical_solution(model_name, model_config): # ================================================================================== analytical = { _type: np.genfromtxt( - TEST_DATA_PATH.joinpath( + TEST_DATA.joinpath( "analytical_solution", f"{model_name}__values_{_type}.csv", ), diff --git a/tests/test_models/phelps_deaton.py b/tests/test_models/phelps_deaton.py index acc193c8..e6b88e14 100644 --- a/tests/test_models/phelps_deaton.py +++ b/tests/test_models/phelps_deaton.py @@ -1,9 +1,11 @@ """Example specifications of the Phelps-Deaton model. -This specification corresponds to the example model presented in the paper: "The -endogenous grid method for discrete-continuous dynamic choice models with (or without) -taste shocks" by Fedor Iskhakov, Thomas H. Jørgensen, John Rust and Bertel Schjerning -(2017, https://doi.org/10.3982/QE643). +This specification extends the example model presented in the paper: "The endogenous +grid method for discrete-continuous dynamic choice models with (or without) taste +shocks" by Fedor Iskhakov, Thomas H. Jørgensen, John Rust and Bertel Schjerning (2017, +https://doi.org/10.3982/QE643). + +In comparison to the original paper, it adds the auxiliary variables "age" and "wage". """ @@ -100,7 +102,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "retirement": {"options": [0, 1]}, "consumption": { "grid_type": "linspace", - "start": 0, + "start": 1, "stop": 100, "n_points": N_GRID_POINTS["consumption"], }, @@ -108,7 +110,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "states": { "wealth": { "grid_type": "linspace", - "start": 0, + "start": 1, "stop": 100, "n_points": N_GRID_POINTS["wealth"], }, @@ -131,7 +133,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "states": { "wealth": { "grid_type": "linspace", - "start": 0, + "start": 1, "stop": 100, "n_points": N_GRID_POINTS["wealth"], }, @@ -161,7 +163,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "states": { "wealth": { "grid_type": "linspace", - "start": 0, + "start": 1, "stop": 100, "n_points": N_GRID_POINTS["wealth"], }, diff --git a/tests/test_regression_test.py b/tests/test_regression_test.py index 4bd91bda..3f2e2630 100644 --- a/tests/test_regression_test.py +++ b/tests/test_regression_test.py @@ -1,50 +1,56 @@ -import json - import jax.numpy as jnp import pandas as pd -from lcm._config import TEST_DATA_PATH +from lcm._config import TEST_DATA from lcm.entry_point import get_lcm_function -from lcm.get_model import get_model from numpy.testing import assert_array_almost_equal as aaae from pandas.testing import assert_frame_equal +from tests.test_models.phelps_deaton import PHELPS_DEATON + +REGRESSION_TEST_MODEL = {**PHELPS_DEATON, "n_perids": 5} +REGRESSION_TEST_PARAMS = { + "beta": 0.95, + "utility": {"disutility_of_work": 1.0}, + "next_wealth": { + "interest_rate": 0.05, + }, +} + def test_regression_test(): """Test that the output of lcm does not change.""" + # ---------------------------------------------------------------------------------- # Load generated output - # ================================================================================== - with TEST_DATA_PATH.joinpath("regression_tests", "simulation.csv").open() as file: - expected_simulate = pd.read_csv(file, index_col=["period", "initial_state_id"]) - - with TEST_DATA_PATH.joinpath("regression_tests", "solution.json").open() as file: - _expected_solve = json.load(file) - - # Stack value function array along time dimension - expected_solve = jnp.stack([jnp.array(data) for data in _expected_solve]) + # ---------------------------------------------------------------------------------- + expected_simulate = pd.read_pickle( + TEST_DATA.joinpath("regression_tests", "simulation.pkl"), + ) - # Create current lcm ouput - # ================================================================================== - model_config = get_model("phelps_deaton_regression_test") + expected_solve = pd.read_pickle( + TEST_DATA.joinpath("regression_tests", "solution.pkl"), + ) - solve, _ = get_lcm_function(model=model_config.model, targets="solve") + # ---------------------------------------------------------------------------------- + # Generate current lcm ouput + # ---------------------------------------------------------------------------------- + solve, _ = get_lcm_function(model=REGRESSION_TEST_MODEL, targets="solve") - _got_solve = solve(model_config.params) - # Stack value function array along time dimension - got_solve = jnp.stack(_got_solve) + got_solve = solve(REGRESSION_TEST_PARAMS) solve_and_simulate, _ = get_lcm_function( - model=model_config.model, + model=REGRESSION_TEST_MODEL, targets="solve_and_simulate", ) got_simulate = solve_and_simulate( - params=model_config.params, + params=REGRESSION_TEST_PARAMS, initial_states={ - "wealth": jnp.array([1.0, 20, 40, 70]), + "wealth": jnp.array([5.0, 20, 40, 70]), }, ) + # ---------------------------------------------------------------------------------- # Compare - # ================================================================================== + # ---------------------------------------------------------------------------------- aaae(expected_solve, got_solve, decimal=5) assert_frame_equal(expected_simulate, got_simulate) From a3005c434843e03e59c79e3e61da2325bb3b7686 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 17:32:16 +0100 Subject: [PATCH 17/27] Remove get_models.py module --- src/lcm/get_model.py | 136 ------------------ tests/data/regression_tests/simulation.pkl | Bin 1617 -> 1617 bytes tests/data/regression_tests/solution.pkl | Bin 2725 -> 2725 bytes tests/test_analytical_solution.py | 50 +++++-- tests/test_models/phelps_deaton.py | 10 +- tests/test_process_model.py | 8 +- tests/test_simulate.py | 2 +- ...model.py => test_solution_on_toy_model.py} | 0 8 files changed, 48 insertions(+), 158 deletions(-) delete mode 100644 src/lcm/get_model.py rename tests/{test_analytical_solution_on_toy_model.py => test_solution_on_toy_model.py} (100%) diff --git a/src/lcm/get_model.py b/src/lcm/get_model.py deleted file mode 100644 index 61e932ca..00000000 --- a/src/lcm/get_model.py +++ /dev/null @@ -1,136 +0,0 @@ -"""Get a user model and parameters.""" - -from typing import NamedTuple - -from pybaum import tree_update - -from tests.test_models.phelps_deaton import ( - PHELPS_DEATON, - PHELPS_DEATON_WITH_FILTERS, -) - - -class ModelAndParams(NamedTuple): - """Model and parameters.""" - - model: dict - params: dict - - -def get_model(model: str): - """Get a user model and parameters. - - Args: - model (str): Model name. - - Returns: - NamedTuple: Model and parameters. Has attributes `model` and `params`. - - """ - if model not in MODELS: - raise ValueError(f"Model {model} not found. Choose from {set(MODELS.keys())}.") - return MODELS[model] - - -# ====================================================================================== -# Models -# ====================================================================================== - -# Remove age and wage functions from Phelps-Deaton model, as they are not used in the -# original paper. -PHELPS_DEATON_WITHOUT_AGE = PHELPS_DEATON.copy() -PHELPS_DEATON_WITHOUT_AGE["functions"] = { - name: func - for name, func in PHELPS_DEATON_WITHOUT_AGE["functions"].items() - if name not in ["age", "wage"] -} - - -PHELPS_DEATON_FIVE_PERIODS = { - **PHELPS_DEATON_WITHOUT_AGE, - "choices": { - "retirement": {"options": [0, 1]}, - "consumption": { - "grid_type": "linspace", - "start": 1, - "stop": 400, - "n_points": 500, - }, - }, - "states": { - "wealth": { - "grid_type": "linspace", - "start": 1, - "stop": 400, - "n_points": 100, - }, - }, - "n_periods": 5, -} - - -ISKHAKOV_2017_FIVE_PERIODS = { - **PHELPS_DEATON_WITH_FILTERS, - "choices": { - "retirement": {"options": [0, 1]}, - "consumption": { - "grid_type": "linspace", - "start": 1, - "stop": 400, - "n_points": 500, - }, - }, - "states": { - "wealth": { - "grid_type": "linspace", - "start": 1, - "stop": 400, - "n_points": 100, - }, - "lagged_retirement": {"options": [0, 1]}, - }, - "n_periods": 5, -} - - -ISKHAKOV_2017_THREE_PERIODS = tree_update(ISKHAKOV_2017_FIVE_PERIODS, {"n_periods": 3}) - -# ====================================================================================== -# Models and params -# ====================================================================================== - -MODELS = { - "phelps_deaton_regression_test": ModelAndParams( - model=PHELPS_DEATON_FIVE_PERIODS, - params={ - "beta": 1.0, - "utility": {"disutility_of_work": 1.0}, - "next_wealth": { - "interest_rate": 0.05, - "wage": 1.0, - }, - }, - ), - "iskhakov_2017_five_periods": ModelAndParams( - model=ISKHAKOV_2017_FIVE_PERIODS, - params={ - "beta": 0.98, - "utility": {"disutility_of_work": 1.0}, - "next_wealth": { - "interest_rate": 0.0, - "wage": 20.0, - }, - }, - ), - "iskhakov_2017_low_disutility_of_work": ModelAndParams( - model=ISKHAKOV_2017_THREE_PERIODS, - params={ - "beta": 0.98, - "utility": {"disutility_of_work": 0.1}, - "next_wealth": { - "interest_rate": 0.0, - "wage": 20.0, - }, - }, - ), -} diff --git a/tests/data/regression_tests/simulation.pkl b/tests/data/regression_tests/simulation.pkl index 3abb6e85bffdf27f6c9066800c397cf0f83538f0..d28f5814ef91df0d48d59bca48c87c20c79d6cf5 100644 GIT binary patch delta 271 zcmcb}bCGAliTcgV40Gr1`)SW&Sb4wnzo-L~-|6R5*yS7w^?l~aS}Hj(fWe3VcVG63 zYE|$$ywo^$&^SuO;qk{>In{3x4!P!MzXqLsWbfVFE4@kopM8fI?=ry_76%K@6Z>5E zb2{8D@LBokudu_^Ef4qtr^!0p+kZ}eo4&5Y^2ODqmIf9MFbnLN=CQnOOO}Ev)P^hY ziYjQ=IQ<7M4>IocA58}cZ85o>QPIvga+B8$n?(Enxdu0?%e5R_Z%h7(Tx;b}y;cLrH!(Yn9?oV3I8#nq*j209KemG7CJyb?yKF diff --git a/tests/data/regression_tests/solution.pkl b/tests/data/regression_tests/solution.pkl index d45858ddd7277383d02eabf174440fba25ada4b5..c1889cb64d28d9805109c634b7eaf20f36d6eb4c 100644 GIT binary patch delta 2428 zcmZXWX*|^l7sq8^!oA$Z?k<0KC3zBBhF`ZrWs9O>EG1Nih-$LjLXR5C+9VATN+o*_ zQL@z-Mroo#QnV*`qLPrgFP`Vsb6%a#`Fzjsy!xGUN~k5&^hUAFf1^fDCNHmP+Ac}( z);(yuGF();J^!{QW09P~uLeR2))HM`T?=WDPux`=c}oWunhSe1Vwhs| zpXo?H@Q8_DkN16VUBrU&$%~5e2`v0DMb-(u$3oJbh3=)gZ0zr;i751EV~6bltGz{R zM1H-}IzGzA%pu?WN+y9T6Lvbgd25B{HAuhg6J@ODw8#ArPa z-ZEJQ?TVdti)wkuo3K z3&KteGx?ZFTlJ)EIUi*fj-9JK`1l-IQ@wHzAC1CXr`HL5xamKng*)_y%%=s&s~g_2K~0E1S4MS_422lkXD3^+NQj~7*Un4`@t`oH)nTI$ zZOl+Z$L&IVSmWcy*eAqf_F+L)q7d=9A=w7mLd>2yx68j+i0rev6`{34q)!D5yR{3^ zGu&;e-Ydi~x%zSDs1R#?P0Ac+CWOe%^_Q`hHAUpTs!0*q6b{ZMcgz{47!OLm&(VAD*m}4z$5$UwrK1Cpz53vKRuU|G1GKEpUEWt=0N*{CrA1UjOx)Pj&K_r{LpWOy@>^;7x+%|i^drGD9Q{xkyyM!gS1 zY8hZy1u1UmVZh0CvpC^B1Mk#xBi<@7k;?MZX)tEOV71)iKW&(x`+3p*H!|_B=ca`= zyO?l1Ae(vRFca|(AxnZ!Gx2~Ff3D*y6QA2JOP1YcBK&1^!42^s6ODB-x7M3 zX{5+PBIN{2pvS_}&YrpZW-QEmZzRuw1+slh+j9>V9u)K%`UJ6HO)~qKeUOEMTh6-M ze_y_mSr+J=oyy#+EO3>|L?%sAJ^qQr=DRGEGxqcghgq<_TB#(Gupl4a)F&3q zv*2*jZ23(MHsE*2QfS16^T9Bd$cznj!PbXXE7?eLn%y?Mo{f=;6>|YwrS-*!SDX)K zLnB;+(G<-_&9RrWzb3NrXYa!ot{H4Zq#qpU$YWC^)U*=W>F8@*@*Yj!h1Mpml?rCFcpaUn;H$yDDwiN*=7QBha`Y zaqM0Tkp-n?;{V-}7v1q6WwI9F7)|{GIY}O&2P|TuE-6BTZRXks^(R^S^4R+VuocJRP5lq5FV68g(BtK`U?$ba@*qloQQxoYtmqJdfCiIV8 z&&U>%fJ=L7l_E%ZFgLfHTulg{ zicOPfL6z1XBGJ;u&ehg$7)!N5liibS?WqlB(V5vb`^DM_>{klddP)lLssa*gw6Rt) z8`9IG4d;Q9R;LfzxU7AnV6W5V&g1JqjG&KZ`u^l7d$C=Y~JdQ1B*>x2axCqT-3) z_@g2dDiW0Ioh)prSSOPjW#vwV=kM=4`gc;X&AKAb@MkI}Q-0`COr=6lv~gKZJ{8=b z_fPj$QjtGnuGHL0MV`&Dg-I|b5EmA&FV?l8VO>=HvmOT;zI@w#e}fkduS!F2rUcQTa*^ks8BK$qYIi_D5)CiQ zcWn8VNkj1FM;*S>|01OH6*s+t2D@U#!L*w+9BWv*%&&`v%)8}fHG?$FJaH2mzLgFh zRkPprg@$cmQ%0e3bj&Ya((OT_L+!6f@sC=Dbo3e=`Ky#q$N80wCOTGhm}Vt81c>M` zuBtIU>OzP8*w+(D+obDP-fp*l5FOt|ibr)L>3F)k;bLVR9Wx_xr#Q#y$n>VD#Aegs zyhd`p>Kq*@o6EgAE=%=)YE5mvM#qZz#ou$A>EN%qyJ1Tg9j#%8QKacUI&R6@jKn>o Y`=+1FAFd-~a#s delta 2428 zcmWMoc{mh^1C1OJbD0@)GiH>lt#%!?sd>4hQmiPXYdf4#NXk%AQcEaD2OBC@j^s|r z5w?(DO2`pW7Lo8d;^(jTz3;vE&->$jFO8N)iz*c_ZmlWL{-7b6sF|}IA`FOXqGD2? z2aSse=F+}ec2gps8_sfWia$jGN3C>Hf+V2-DFg(B186gFt zcxxT2j6p+Pp?|b#i_$L z`xMV}VHI7uySa`FVpGxTK?)BMZa-QU_wmr3C-SiQn}^gVcJG-3JP^|AZH+ZmU>*Nm zX24YiidRlQApfI+sh<%IV_#K3vmR=#7AvWuM|M)@?Rr)C3evQ8TdU&W_k7L$?y68X zl(SC0qKZEawr}-JRYm06*(*)us<@k2_{LCKns| z;3S+9#^U#8b$?JOT&i5hvINNRYXP35d#t%lYv9mh^XHTYHTF3_`4!`be8{tHLd zu*Zr-5C^K^FDX9RD^d+5-h=7<6g7ylto-*CsNteE_0i8u&V&O z<4!C)5GX)R-a}o9YXaOl@{ec09RZ}TcobD<3edDNccJ^40Mh|wMoq5-nA%nIEUHrg z$x%5Q_2DIbT>t5!OC1$TDf9$6s?mlpH-&wKZW}G_6{x*L-10#9RovSvEB<`kczH+2k#O zpPA{JlCPJe{hRCMZ4=UP+iy}{S}uz<_pcMv+~u&qy-<6(UmlKicC+7%6;QRtAR;_O z5wX#xgW^mA&U#fQsXZrRjZQve# zN5Q&&`;y0E3ifu~+>_EtK`Gr=*;Y>p%*5B1Y_BT8_?%cc)}aLX#wcB$78Tr}Vb?rg zDrC(=y(rJA$m?32WiLg8Xv$Qt#heD+w_0n@hSL!6;nTgWIx!6v;aNI|W$2*U`v#V7 zrK3pun(wzjI&O!X*nG^R!^>DQf0bw6)mzpXOLQ*5P-!jutA)m__I zsEk_+9f>BwC`H?m5tvmNqhXIIp7{%Q(?*HU{h$%(^^9gf?L+z`e@IAl`tze-e*ZCSy#<3 zZO+c?WGjh>9O%xx^DnC4U|`#pex-H}6dSLL&#ao@Ajr6Pp?f74RL?o)ULhCx5)KqU zLoOa|cj|WfgNuNQq}F~nE_ya^F#P=@7lny!GaD1R@buhyR5q6jEyDcP3op6QsIDq* z{>X*xzY}rw6I}2XW`uX;c*s{cuEJK~!Ts3fjTd!!sQWdg<&7l|;vV&s3#kWq_(a!^ zTlD3jZ>{F-;j27Yf9@OgO5!1d#LK;&$BTjNMsZU8Vz!gwfarEK?fvxmaZwk&aWIlL zCAv1yGCP^gm0$DbVzEU`sqA+_-Q5w}=+25ZgO5?RfqEa(Y2JcK+re70f#5ne| zc3_Ao3qtX!T?s+5h|ua!P8^iQr|7Gx=d4!X$$$5qiZfQ=)3eK$N~GnmZ&Ag*(@72* z37%D6IdT}gd_2URBoCL>y&SHyJd|x0B;_8;qj2XRaR*1`F)^hB+uv3~)H1Ii#)C-&4{<0S}KEO6-52fZ95CsAiV}7TC{zJxf)@ zriGEahqo!>hHq{6^f^Tg9B3;w&sKzo#&pG@9z{?r1V`;j1ZYLC7XBt8;Kdh*Ymy!W z7?!s=y-6T|ZkidLQA5B{`e;D>7y%*1d!D6=SwyU&Zue|6Cc>d_xWV8w5l)_YZke}< z_!99>_0S6+0?JDO+x*zcLv`Yl8|xc7m~qF z5>5^u`nJ<$iN};o@=lUKJ?G<_9!x^y$&;yLQ6yYYmAf~UOoEnW$`Nr{4hbg8oBGTu zNbm@&8;xxuVcILgJiUvAA|Fv==m-fq?Hfb1=Sld*@l@y?c{0A)iB;MeWW=iZ6n$Pr z#zOXUYKL^34H*Nayy`9&GVE*)&CR-#p?Bjz&2Ru2-~TdOpC3ks zOV~!SiCqjCzT^L!d7VTCk$c-jHIuwlnFnD<5gGb~;SF}xOM!ns+}+qrMt{y(;#|j) zFaE1T;vg9Gw^L4mk~YvP%1Nx{2m6(6SE5`G;Qd%93Sx63chb)&>Y$ Date: Wed, 28 Feb 2024 17:46:21 +0100 Subject: [PATCH 18/27] Update example model --- examples/long_running.py | 17 +++++++++++------ tests/test_models/phelps_deaton.py | 2 -- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/long_running.py b/examples/long_running.py index f5dec21d..a54368f7 100644 --- a/examples/long_running.py +++ b/examples/long_running.py @@ -9,7 +9,7 @@ "wealth": 100, "health": 100, "consumption": 100, - "exericse": 200, + "exercise": 200, } RETIREMENT_AGE = 65 @@ -29,6 +29,10 @@ def utility(consumption, working, health, exercise, disutility_of_work): # -------------------------------------------------------------------------------------- # Auxiliary variables # -------------------------------------------------------------------------------------- +def labor_income(wage, working): + return wage * working + + def working(leisure): return 1 - leisure @@ -44,8 +48,8 @@ def age(_period): # -------------------------------------------------------------------------------------- # State transitions # -------------------------------------------------------------------------------------- -def next_wealth(wealth, consumption, working, wage, interest_rate): - return (1 + interest_rate) * (wealth + working * wage - consumption) +def next_wealth(wealth, consumption, labor_income, interest_rate): + return (1 + interest_rate) * (wealth + labor_income - consumption) def next_health(health, exercise, working): @@ -55,8 +59,8 @@ def next_health(health, exercise, working): # -------------------------------------------------------------------------------------- # Constraints # -------------------------------------------------------------------------------------- -def consumption_constraint(consumption, wealth): - return consumption <= wealth +def consumption_constraint(consumption, wealth, labor_income): + return consumption <= wealth + labor_income # ====================================================================================== @@ -67,11 +71,12 @@ def consumption_constraint(consumption, wealth): "functions": { "utility": utility, "next_wealth": next_wealth, + "next_health": next_health, "consumption_constraint": consumption_constraint, + "labor_income": labor_income, "working": working, "wage": wage, "age": age, - "next_health": next_health, }, "choices": { "leisure": {"options": [0, 1]}, diff --git a/tests/test_models/phelps_deaton.py b/tests/test_models/phelps_deaton.py index 460c2a93..0c344f40 100644 --- a/tests/test_models/phelps_deaton.py +++ b/tests/test_models/phelps_deaton.py @@ -5,8 +5,6 @@ shocks" by Fedor Iskhakov, Thomas H. Jørgensen, John Rust and Bertel Schjerning (2017, https://doi.org/10.3982/QE643). -In comparison to the original paper, it adds the auxiliary variables "age" and "wage". - """ import jax.numpy as jnp From e4ebed36a25383043f85bce71741830af4d8a6a5 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 18:36:32 +0100 Subject: [PATCH 19/27] Rename phelps_deaton -> deterministic --- tests/test_analytical_solution.py | 6 ++-- tests/test_entry_point.py | 28 +++++++++---------- tests/test_model_functions.py | 4 +-- .../{phelps_deaton.py => deterministic.py} | 12 ++++---- tests/test_models/stochastic.py | 4 +-- tests/test_next_state.py | 4 +-- tests/test_process_model.py | 14 +++++----- tests/test_regression_test.py | 4 +-- tests/test_simulate.py | 24 ++++++++-------- tests/test_state_space.py | 4 +-- 10 files changed, 52 insertions(+), 52 deletions(-) rename tests/test_models/{phelps_deaton.py => deterministic.py} (95%) diff --git a/tests/test_analytical_solution.py b/tests/test_analytical_solution.py index 46ef2fdd..76d5e231 100644 --- a/tests/test_analytical_solution.py +++ b/tests/test_analytical_solution.py @@ -6,7 +6,7 @@ from lcm.entry_point import get_lcm_function from numpy.testing import assert_array_almost_equal as aaae -from tests.test_models.phelps_deaton import PHELPS_DEATON_WITH_FILTERS +from tests.test_models.deterministic import BASE_MODEL_WITH_FILTERS # ====================================================================================== # Model specifications @@ -23,11 +23,11 @@ TEST_CASES = { "iskhakov_2017_five_periods": { - "model": {**PHELPS_DEATON_WITH_FILTERS, "n_periods": 5}, + "model": {**BASE_MODEL_WITH_FILTERS, "n_periods": 5}, "params": {**ISKHAVOV_2017_PARAMS, "utility": {"disutility_of_work": 1.0}}, }, "iskhakov_2017_low_delta": { - "model": {**PHELPS_DEATON_WITH_FILTERS, "n_periods": 3}, + "model": {**BASE_MODEL_WITH_FILTERS, "n_periods": 3}, "params": {**ISKHAVOV_2017_PARAMS, "utility": {"disutility_of_work": 0.1}}, }, } diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 24db8d81..7364f3ca 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -10,17 +10,17 @@ from lcm.state_space import create_state_choice_space from pybaum import tree_equal, tree_map -from tests.test_models.phelps_deaton import ( - PHELPS_DEATON, - PHELPS_DEATON_FULLY_DISCRETE, - PHELPS_DEATON_WITH_FILTERS, +from tests.test_models.deterministic import ( + BASE_MODEL, + BASE_MODEL_FULLY_DISCRETE, + BASE_MODEL_WITH_FILTERS, utility, ) MODELS = { - "simple": PHELPS_DEATON, - "with_filters": PHELPS_DEATON_WITH_FILTERS, - "fully_discrete": PHELPS_DEATON_FULLY_DISCRETE, + "simple": BASE_MODEL, + "with_filters": BASE_MODEL_WITH_FILTERS, + "fully_discrete": BASE_MODEL_FULLY_DISCRETE, } @@ -45,7 +45,7 @@ def test_get_lcm_function_with_solve_target(user_model): @pytest.mark.parametrize( "user_model", - [PHELPS_DEATON, PHELPS_DEATON_FULLY_DISCRETE], + [BASE_MODEL, BASE_MODEL_FULLY_DISCRETE], ids=["simple", "fully_discrete"], ) def test_get_lcm_function_with_simulation_target_simple(user_model): @@ -66,7 +66,7 @@ def test_get_lcm_function_with_simulation_target_simple(user_model): @pytest.mark.parametrize( "user_model", - [PHELPS_DEATON, PHELPS_DEATON_FULLY_DISCRETE], + [BASE_MODEL, BASE_MODEL_FULLY_DISCRETE], ids=["simple", "fully_discrete"], ) def test_get_lcm_function_with_simulation_is_coherent(user_model): @@ -109,7 +109,7 @@ def test_get_lcm_function_with_simulation_is_coherent(user_model): @pytest.mark.parametrize( "user_model", - [PHELPS_DEATON_WITH_FILTERS], + [BASE_MODEL_WITH_FILTERS], ids=["with_filters"], ) def test_get_lcm_function_with_simulation_target_with_filters(user_model): @@ -137,7 +137,7 @@ def test_get_lcm_function_with_simulation_target_with_filters(user_model): def test_create_compute_conditional_continuation_value(): - model = process_model(PHELPS_DEATON) + model = process_model(BASE_MODEL) params = { "beta": 1.0, @@ -180,7 +180,7 @@ def test_create_compute_conditional_continuation_value(): def test_create_compute_conditional_continuation_value_with_discrete_model(): - model = process_model(PHELPS_DEATON_FULLY_DISCRETE) + model = process_model(BASE_MODEL_FULLY_DISCRETE) params = { "beta": 1.0, @@ -228,7 +228,7 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): def test_create_compute_conditional_continuation_policy(): - model = process_model(PHELPS_DEATON) + model = process_model(BASE_MODEL) params = { "beta": 1.0, @@ -272,7 +272,7 @@ def test_create_compute_conditional_continuation_policy(): def test_create_compute_conditional_continuation_policy_with_discrete_model(): - model = process_model(PHELPS_DEATON_FULLY_DISCRETE) + model = process_model(BASE_MODEL_FULLY_DISCRETE) params = { "beta": 1.0, diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index dfa58cbd..837f4146 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -10,7 +10,7 @@ from lcm.state_space import create_state_choice_space from numpy.testing import assert_array_equal -from tests.test_models.phelps_deaton import PHELPS_DEATON, utility +from tests.test_models.deterministic import BASE_MODEL, utility def test_get_combined_constraint(): @@ -42,7 +42,7 @@ def h(): def test_get_utility_and_feasibility_function(): - model = process_model(PHELPS_DEATON) + model = process_model(BASE_MODEL) params = { "beta": 1.0, diff --git a/tests/test_models/phelps_deaton.py b/tests/test_models/deterministic.py similarity index 95% rename from tests/test_models/phelps_deaton.py rename to tests/test_models/deterministic.py index 0c344f40..80f682f4 100644 --- a/tests/test_models/phelps_deaton.py +++ b/tests/test_models/deterministic.py @@ -1,6 +1,6 @@ -"""Example specifications of the Phelps-Deaton model. +"""Example specifications of a deterministic consumption-saving model. -This specification extends the example model presented in the paper: "The endogenous +This specification builds on the example model presented in the paper: "The endogenous grid method for discrete-continuous dynamic choice models with (or without) taste shocks" by Fedor Iskhakov, Thomas H. Jørgensen, John Rust and Bertel Schjerning (2017, https://doi.org/10.3982/QE643). @@ -84,10 +84,10 @@ def absorbing_retirement_filter(retirement, lagged_retirement): # ====================================================================================== -# Model specification and parameters +# Model specifications # ====================================================================================== -PHELPS_DEATON = { +BASE_MODEL = { "functions": { "utility": utility, "next_wealth": next_wealth, @@ -117,7 +117,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): } -PHELPS_DEATON_FULLY_DISCRETE = { +BASE_MODEL_FULLY_DISCRETE = { "functions": { "utility": utility, "next_wealth": next_wealth, @@ -140,7 +140,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): } -PHELPS_DEATON_WITH_FILTERS = { +BASE_MODEL_WITH_FILTERS = { "functions": { "utility": utility_with_filter, "next_wealth": next_wealth, diff --git a/tests/test_models/stochastic.py b/tests/test_models/stochastic.py index 8d4707d3..8c1789de 100644 --- a/tests/test_models/stochastic.py +++ b/tests/test_models/stochastic.py @@ -1,11 +1,11 @@ -"""Example specifications of a simple Phelps-Deaton style stochastic model. +"""Example specifications of a stochastic consumption-saving model. This specification is motivated by the example model presented in the paper: "The endogenous grid method for discrete-continuous dynamic choice models with (or without) taste shocks" by Fedor Iskhakov, Thomas H. Jørgensen, John Rust and Bertel Schjerning (2017, https://doi.org/10.3982/QE643). -See also the specifications in tests/test_models/phelps_deaton.py. +See also the specifications in tests/test_models/deterministic.py. """ diff --git a/tests/test_next_state.py b/tests/test_next_state.py index f37f02c9..3737ae86 100644 --- a/tests/test_next_state.py +++ b/tests/test_next_state.py @@ -5,7 +5,7 @@ from lcm.process_model import process_model from pybaum import tree_equal -from tests.test_models.phelps_deaton import PHELPS_DEATON +from tests.test_models.deterministic import BASE_MODEL # ====================================================================================== # Solve target @@ -13,7 +13,7 @@ def test_get_next_state_function_with_solve_target(): - model = process_model(PHELPS_DEATON) + model = process_model(BASE_MODEL) got_func = get_next_state_function(model, target="solve") params = { diff --git a/tests/test_process_model.py b/tests/test_process_model.py index 53b1a3bc..c3acc07f 100644 --- a/tests/test_process_model.py +++ b/tests/test_process_model.py @@ -17,10 +17,10 @@ from numpy.testing import assert_array_equal from pandas.testing import assert_frame_equal -from tests.test_models.phelps_deaton import ( +from tests.test_models.deterministic import ( + BASE_MODEL, + BASE_MODEL_WITH_FILTERS, N_GRID_POINTS, - PHELPS_DEATON, - PHELPS_DEATON_WITH_FILTERS, ) @@ -100,8 +100,8 @@ def test_get_grids(user_model): assert_array_equal(got["c"], jnp.array([2, 3])) -def test_process_phelps_deaton_with_filters(): - model = process_model(PHELPS_DEATON_WITH_FILTERS) +def test_process_model_with_filters(): + model = process_model(BASE_MODEL_WITH_FILTERS) # Variable Info assert ( @@ -156,8 +156,8 @@ def test_process_phelps_deaton_with_filters(): assert ~model.function_info.loc["utility"].to_numpy().any() -def test_process_phelps_deaton(): - model = process_model(PHELPS_DEATON) +def test_process_model(): + model = process_model(BASE_MODEL) # Variable Info assert ~(model.variable_info["is_sparse"].to_numpy()).any() diff --git a/tests/test_regression_test.py b/tests/test_regression_test.py index 3f2e2630..8fdb29c8 100644 --- a/tests/test_regression_test.py +++ b/tests/test_regression_test.py @@ -5,9 +5,9 @@ from numpy.testing import assert_array_almost_equal as aaae from pandas.testing import assert_frame_equal -from tests.test_models.phelps_deaton import PHELPS_DEATON +from tests.test_models.deterministic import BASE_MODEL -REGRESSION_TEST_MODEL = {**PHELPS_DEATON, "n_perids": 5} +REGRESSION_TEST_MODEL = {**BASE_MODEL, "n_perids": 5} REGRESSION_TEST_PARAMS = { "beta": 0.95, "utility": {"disutility_of_work": 1.0}, diff --git a/tests/test_simulate.py b/tests/test_simulate.py index fb5d268a..6291cc0e 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -27,10 +27,10 @@ from numpy.testing import assert_array_almost_equal, assert_array_equal from pybaum import tree_equal -from tests.test_models.phelps_deaton import ( +from tests.test_models.deterministic import ( + BASE_MODEL, + BASE_MODEL_WITH_FILTERS, N_GRID_POINTS, - PHELPS_DEATON, - PHELPS_DEATON_WITH_FILTERS, ) # ====================================================================================== @@ -40,7 +40,7 @@ @pytest.fixture() def simulate_inputs(): - user_model = {**PHELPS_DEATON, "n_periods": 1} + user_model = {**BASE_MODEL, "n_periods": 1} model = process_model(user_model) _, space_info, _, _ = create_state_choice_space( @@ -104,9 +104,9 @@ def test_simulate_using_raw_inputs(simulate_inputs): @pytest.fixture() -def phelps_deaton_model_solution(): +def base_model_solution(): def _model_solution(n_periods): - model = {**PHELPS_DEATON, "n_periods": n_periods} + model = {**BASE_MODEL, "n_periods": n_periods} model["functions"] = { # remove dependency on age, so that wage becomes a parameter name: func @@ -130,9 +130,9 @@ def _model_solution(n_periods): return _model_solution -@pytest.mark.parametrize("n_periods", range(3, PHELPS_DEATON["n_periods"] + 1)) -def test_simulate_using_get_lcm_function(phelps_deaton_model_solution, n_periods): - vf_arr_list, params, model = phelps_deaton_model_solution(n_periods) +@pytest.mark.parametrize("n_periods", range(3, BASE_MODEL["n_periods"] + 1)) +def test_simulate_using_get_lcm_function(base_model_solution, n_periods): + vf_arr_list, params, model = base_model_solution(n_periods) simulate_model, _ = get_lcm_function(model=model, targets="simulate") @@ -175,7 +175,7 @@ def test_simulate_using_get_lcm_function(phelps_deaton_model_solution, n_periods def test_effect_of_beta_on_last_period(): - model = {**PHELPS_DEATON, "n_periods": 5} + model = {**BASE_MODEL, "n_periods": 5} # Model solutions # ================================================================================== @@ -229,7 +229,7 @@ def test_effect_of_beta_on_last_period(): def test_effect_of_disutility_of_work(): - model = {**PHELPS_DEATON, "n_periods": 5} + model = {**BASE_MODEL, "n_periods": 5} # Model solutions # ================================================================================== @@ -404,7 +404,7 @@ def test_filter_ccv_policy(): def test_create_data_state_choice_space(): - model = process_model(PHELPS_DEATON_WITH_FILTERS) + model = process_model(BASE_MODEL_WITH_FILTERS) got_space, got_segment_info = create_data_scs( states={ "wealth": jnp.array([10.0, 20.0]), diff --git a/tests/test_state_space.py b/tests/test_state_space.py index a5a5b8e0..7a2e7fe4 100644 --- a/tests/test_state_space.py +++ b/tests/test_state_space.py @@ -13,11 +13,11 @@ ) from numpy.testing import assert_array_almost_equal as aaae -from tests.test_models.phelps_deaton import PHELPS_DEATON_WITH_FILTERS +from tests.test_models.deterministic import BASE_MODEL_WITH_FILTERS def test_create_state_choice_space(): - _model = process_model(PHELPS_DEATON_WITH_FILTERS) + _model = process_model(BASE_MODEL_WITH_FILTERS) create_state_choice_space( model=_model, period=0, From 21a05f909c7f2c965e0e0bc95ec7c7cf9961ffb4 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 29 Feb 2024 09:43:20 +0100 Subject: [PATCH 20/27] Implement requested changes and correct typo --- examples/long_running.py | 9 ++------- tests/test_analytical_solution.py | 15 +++++++++++---- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/examples/long_running.py b/examples/long_running.py index a54368f7..d1787e45 100644 --- a/examples/long_running.py +++ b/examples/long_running.py @@ -1,4 +1,4 @@ -"""Example specification for a consumption-savings model with health and leisure.""" +"""Example specification for a consumption-savings model with health and exercise.""" import jax.numpy as jnp @@ -33,10 +33,6 @@ def labor_income(wage, working): return wage * working -def working(leisure): - return 1 - leisure - - def wage(age): return 1 + 0.1 * age @@ -74,12 +70,11 @@ def consumption_constraint(consumption, wealth, labor_income): "next_health": next_health, "consumption_constraint": consumption_constraint, "labor_income": labor_income, - "working": working, "wage": wage, "age": age, }, "choices": { - "leisure": {"options": [0, 1]}, + "working": {"options": [0, 1]}, "consumption": { "grid_type": "linspace", "start": 1, diff --git a/tests/test_analytical_solution.py b/tests/test_analytical_solution.py index 76d5e231..a795c716 100644 --- a/tests/test_analytical_solution.py +++ b/tests/test_analytical_solution.py @@ -1,4 +1,11 @@ -"""Testing against the analytical solution by Iskhakov et al (2017).""" +"""Testing against the analytical solution of Iskhakov et al. (2017). + +The benchmark is taken from the paper "The endogenous grid method for +discrete-continuous dynamic choice models with (or without) taste shocks" by Fedor +Iskhakov, Thomas H. Jørgensen, John Rust and Bertel Schjerning (2017, +https://doi.org/10.3982/QE643). + +""" import numpy as np import pytest @@ -12,7 +19,7 @@ # Model specifications # ====================================================================================== -ISKHAVOV_2017_PARAMS = { +ISKHAKOV_2017_PARAMS = { "beta": 0.98, "utility": {"disutility_of_work": None}, "next_wealth": { @@ -24,11 +31,11 @@ TEST_CASES = { "iskhakov_2017_five_periods": { "model": {**BASE_MODEL_WITH_FILTERS, "n_periods": 5}, - "params": {**ISKHAVOV_2017_PARAMS, "utility": {"disutility_of_work": 1.0}}, + "params": {**ISKHAKOV_2017_PARAMS, "utility": {"disutility_of_work": 1.0}}, }, "iskhakov_2017_low_delta": { "model": {**BASE_MODEL_WITH_FILTERS, "n_periods": 3}, - "params": {**ISKHAVOV_2017_PARAMS, "utility": {"disutility_of_work": 0.1}}, + "params": {**ISKHAKOV_2017_PARAMS, "utility": {"disutility_of_work": 0.1}}, }, } From bd8e90fa30c2122887a64b1121e2625971eba833 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 29 Feb 2024 10:14:26 +0100 Subject: [PATCH 21/27] Improve todos --- pyproject.toml | 3 +++ tests/test_models/deterministic.py | 4 +++- tests/test_models/stochastic.py | 4 +++- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 443a74ea..372edfd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,9 @@ extend-ignore = [ # exception must not use an f-string literal "EM102", + # line contains a todo + "FIX002", + # Too many arguments to function call "PLR0913", diff --git a/tests/test_models/deterministic.py b/tests/test_models/deterministic.py index 80f682f4..e725e973 100644 --- a/tests/test_models/deterministic.py +++ b/tests/test_models/deterministic.py @@ -38,7 +38,9 @@ def utility_with_filter( disutility_of_work, # Temporary workaround for bug described in issue #30, which requires us to pass # all state variables to the utility function. - lagged_retirement, # noqa: ARG001, TODO: Remove unused arguments once #30 is fixed. + # TODO(@timmens): Remove unused arguments once #30 is fixed. + # https://github.com/OpenSourceEconomics/lcm/issues/30 + lagged_retirement, # noqa: ARG001 ): return utility(consumption, working=working, disutility_of_work=disutility_of_work) diff --git a/tests/test_models/stochastic.py b/tests/test_models/stochastic.py index 8c1789de..1cf616b3 100644 --- a/tests/test_models/stochastic.py +++ b/tests/test_models/stochastic.py @@ -35,7 +35,9 @@ def utility( health, # Temporary workaround for bug described in issue #30, which requires us to pass # all state variables to the utility function. - partner, # noqa: ARG001, TODO: Remove unused arguments once #30 is fixed. + # TODO(@timmens): Remove unused arguments once #30 is fixed. + # https://github.com/OpenSourceEconomics/lcm/issues/30 + partner, # noqa: ARG001 disutility_of_work, ): return jnp.log(consumption) - (1 - health / 2) * disutility_of_work * working From fb4a1e48dc51d6773efea9f052db95c6bb89ff0c Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 29 Feb 2024 10:21:43 +0100 Subject: [PATCH 22/27] Use labor_income in stochastic models --- tests/test_models/stochastic.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_models/stochastic.py b/tests/test_models/stochastic.py index 1cf616b3..d180426c 100644 --- a/tests/test_models/stochastic.py +++ b/tests/test_models/stochastic.py @@ -46,8 +46,12 @@ def utility( # -------------------------------------------------------------------------------------- # Deterministic state transitions # -------------------------------------------------------------------------------------- -def next_wealth(wealth, consumption, working, wage, interest_rate): - return (1 + interest_rate) * (wealth - consumption) + wage * working +def next_wealth(wealth, consumption, labor_income, interest_rate): + return (1 + interest_rate) * (wealth - consumption) + labor_income + + +def labor_income(working, wage): + return working * wage # -------------------------------------------------------------------------------------- @@ -81,6 +85,7 @@ def consumption_constraint(consumption, wealth): "next_health": next_health, "next_partner": next_partner, "consumption_constraint": consumption_constraint, + "labor_income": labor_income, }, "choices": { "working": {"options": [0, 1]}, @@ -108,9 +113,10 @@ def consumption_constraint(consumption, wealth): PARAMS = { "beta": 0.95, "utility": {"disutility_of_work": 0.5}, - "next_wealth": {"interest_rate": 0.05, "wage": 10.0}, + "next_wealth": {"interest_rate": 0.05}, "next_health": {}, "consumption_constraint": {}, + "labor_income": {"wage": 10.0}, "shocks": { # Health shock: # ------------------------------------------------------------------------------ From a797bf3bc05d371f93b7706af9bdd98ae4dba616 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 29 Feb 2024 10:30:36 +0100 Subject: [PATCH 23/27] Fix test_process_model.py --- tests/test_models/deterministic.py | 19 +++++++++++++++---- tests/test_models/stochastic.py | 11 +++++++---- tests/test_process_model.py | 12 +++++++++--- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/tests/test_models/deterministic.py b/tests/test_models/deterministic.py index e725e973..a497448e 100644 --- a/tests/test_models/deterministic.py +++ b/tests/test_models/deterministic.py @@ -48,6 +48,10 @@ def utility_with_filter( # -------------------------------------------------------------------------------------- # Auxiliary variables # -------------------------------------------------------------------------------------- +def labor_income(working, wage): + return working * wage + + def working(retirement): return 1 - retirement @@ -63,8 +67,8 @@ def age(_period): # -------------------------------------------------------------------------------------- # State transitions # -------------------------------------------------------------------------------------- -def next_wealth(wealth, consumption, working, wage, interest_rate): - return (1 + interest_rate) * (wealth - consumption) + wage * working +def next_wealth(wealth, consumption, labor_income, interest_rate): + return (1 + interest_rate) * (wealth - consumption) + labor_income # -------------------------------------------------------------------------------------- @@ -94,6 +98,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "utility": utility, "next_wealth": next_wealth, "consumption_constraint": consumption_constraint, + "labor_income": labor_income, "working": working, "wage": wage, "age": age, @@ -124,7 +129,10 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "utility": utility, "next_wealth": next_wealth, "consumption_constraint": consumption_constraint, + "labor_income": labor_income, "working": working, + "wage": wage, + "age": age, }, "choices": { "retirement": {"options": [0, 1]}, @@ -146,10 +154,13 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "functions": { "utility": utility_with_filter, "next_wealth": next_wealth, + "next_lagged_retirement": lambda retirement: retirement, "consumption_constraint": consumption_constraint, - "working": working, "absorbing_retirement_filter": absorbing_retirement_filter, - "next_lagged_retirement": lambda retirement: retirement, + "labor_income": labor_income, + "working": working, + "wage": wage, + "age": age, }, "choices": { "retirement": {"options": [0, 1]}, diff --git a/tests/test_models/stochastic.py b/tests/test_models/stochastic.py index d180426c..ce18acab 100644 --- a/tests/test_models/stochastic.py +++ b/tests/test_models/stochastic.py @@ -43,6 +43,13 @@ def utility( return jnp.log(consumption) - (1 - health / 2) * disutility_of_work * working +# -------------------------------------------------------------------------------------- +# Auxiliary variables +# -------------------------------------------------------------------------------------- +def labor_income(working, wage): + return working * wage + + # -------------------------------------------------------------------------------------- # Deterministic state transitions # -------------------------------------------------------------------------------------- @@ -50,10 +57,6 @@ def next_wealth(wealth, consumption, labor_income, interest_rate): return (1 + interest_rate) * (wealth - consumption) + labor_income -def labor_income(working, wage): - return working * wage - - # -------------------------------------------------------------------------------------- # Stochastic state transitions # -------------------------------------------------------------------------------------- diff --git a/tests/test_process_model.py b/tests/test_process_model.py index c3acc07f..a104a786 100644 --- a/tests/test_process_model.py +++ b/tests/test_process_model.py @@ -151,8 +151,14 @@ def test_process_model_with_filters(): # Functions assert ( model.function_info["is_next"].to_numpy() - == np.array([False, True, False, False, False, True]) + == np.array([False, True, True, False, False, False, False, False, False]) ).all() + + assert ( + model.function_info["is_constraint"].to_numpy() + == np.array([False, False, False, True, False, False, False, False, False]) + ).all() + assert ~model.function_info.loc["utility"].to_numpy().any() @@ -200,12 +206,12 @@ def test_process_model(): # Functions assert ( model.function_info["is_next"].to_numpy() - == np.array([False, True, False, False, False, False]) + == np.array([False, True, False, False, False, False, False]) ).all() assert ( model.function_info["is_constraint"].to_numpy() - == np.array([False, False, True, False, False, False]) + == np.array([False, False, True, False, False, False, False]) ).all() assert ~model.function_info.loc["utility"].to_numpy().any() From 435636d32f8e2907b956d91c66c4260bf33cee7c Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 29 Feb 2024 10:49:31 +0100 Subject: [PATCH 24/27] Fix test_analytical_solution.py --- tests/test_analytical_solution.py | 41 ++++++++++++++++++++++--------- tests/test_simulate.py | 23 +++++++++-------- 2 files changed, 40 insertions(+), 24 deletions(-) diff --git a/tests/test_analytical_solution.py b/tests/test_analytical_solution.py index a795c716..e0ba1a73 100644 --- a/tests/test_analytical_solution.py +++ b/tests/test_analytical_solution.py @@ -7,6 +7,8 @@ """ +from copy import deepcopy + import numpy as np import pytest from lcm._config import TEST_DATA @@ -19,23 +21,38 @@ # Model specifications # ====================================================================================== -ISKHAKOV_2017_PARAMS = { - "beta": 0.98, - "utility": {"disutility_of_work": None}, - "next_wealth": { - "interest_rate": 0.0, - "wage": 20.0, - }, -} + +def _get_iskhakov_2017_model(n_periods): + model_config = deepcopy(BASE_MODEL_WITH_FILTERS) + model_config["n_periods"] = n_periods + # remove age and wage functions, as they are not modelled in Iskhakov et al. (2017) + model_config["functions"] = { + name: func + for name, func in model_config["functions"].items() + if name not in ("age", "wage") + } + return model_config + + +def _get_iskhakov_2017_params(disutility_of_work): + return { + "beta": 0.98, + "utility": {"disutility_of_work": disutility_of_work}, + "next_wealth": { + "interest_rate": 0.0, + }, + "labor_income": {"wage": 20.0}, + } + TEST_CASES = { "iskhakov_2017_five_periods": { - "model": {**BASE_MODEL_WITH_FILTERS, "n_periods": 5}, - "params": {**ISKHAKOV_2017_PARAMS, "utility": {"disutility_of_work": 1.0}}, + "model": _get_iskhakov_2017_model(n_periods=5), + "params": _get_iskhakov_2017_params(disutility_of_work=1.0), }, "iskhakov_2017_low_delta": { - "model": {**BASE_MODEL_WITH_FILTERS, "n_periods": 3}, - "params": {**ISKHAKOV_2017_PARAMS, "utility": {"disutility_of_work": 0.1}}, + "model": _get_iskhakov_2017_model(n_periods=3), + "params": _get_iskhakov_2017_params(disutility_of_work=0.1), }, } diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 6291cc0e..2304f24f 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -116,12 +116,12 @@ def _model_solution(n_periods): solve_model, _ = get_lcm_function(model=model) params = { - "beta": 1.0, - "utility": {"disutility_of_work": 1.0}, + "beta": 0.95, + "utility": {"disutility_of_work": 0.25}, "next_wealth": { "interest_rate": 0.05, - "wage": 1.0, }, + "labor_income": {"wage": 5.0}, } vf_arr_list = solve_model(params) @@ -130,9 +130,9 @@ def _model_solution(n_periods): return _model_solution -@pytest.mark.parametrize("n_periods", range(3, BASE_MODEL["n_periods"] + 1)) -def test_simulate_using_get_lcm_function(base_model_solution, n_periods): - vf_arr_list, params, model = base_model_solution(n_periods) +def test_simulate_using_get_lcm_function(base_model_solution): + n_periods = 3 + vf_arr_list, params, model = base_model_solution(n_periods=n_periods) simulate_model, _ = get_lcm_function(model=model, targets="simulate") @@ -159,14 +159,13 @@ def test_simulate_using_get_lcm_function(base_model_solution, n_periods): last_period_index = n_periods - 1 assert_array_equal(res.loc[last_period_index, :]["retirement"], 1) - # assert that higher wealth leads to higher consumption for period in range(n_periods): - assert (res.loc[period, :]["consumption"].diff()[1:] >= 0).all() - # The following does not work. I.e. the continuation value in each period is not - # weakly increasing in wealth. It is unclear if this needs to hold. - # ------------------------------------------------------------------------------ - # assert jnp.all(jnp.diff(res[period]["value"]) >= 0) # noqa: ERA001 + # assert that higher wealth leads to higher consumption in each period + assert (res.loc[period]["consumption"].diff()[1:] >= 0).all() + + # assert that higher wealth leads to higher value function in each period + assert (res.loc[period]["value"].diff()[1:] >= 0).all() # ====================================================================================== From 36b383f74b8da01d82afd29101a1dbfdbbbd580c Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 29 Feb 2024 10:56:15 +0100 Subject: [PATCH 25/27] Rename utility -> base_model_utility in tests --- tests/test_entry_point.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 7364f3ca..c74041c5 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -14,8 +14,8 @@ BASE_MODEL, BASE_MODEL_FULLY_DISCRETE, BASE_MODEL_WITH_FILTERS, - utility, ) +from tests.test_models.deterministic import utility as base_model_utility MODELS = { "simple": BASE_MODEL, @@ -176,7 +176,11 @@ def test_create_compute_conditional_continuation_value(): params=params, vf_arr=None, ) - assert val == utility(consumption=30.0, working=0, disutility_of_work=1.0) + assert val == base_model_utility( + consumption=30.0, + working=0, + disutility_of_work=1.0, + ) def test_create_compute_conditional_continuation_value_with_discrete_model(): @@ -219,7 +223,7 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): params=params, vf_arr=None, ) - assert val == utility(consumption=2, working=0, disutility_of_work=1.0) + assert val == base_model_utility(consumption=2, working=0, disutility_of_work=1.0) # ====================================================================================== @@ -268,7 +272,11 @@ def test_create_compute_conditional_continuation_policy(): vf_arr=None, ) assert policy == 2 - assert val == utility(consumption=30.0, working=0, disutility_of_work=1.0) + assert val == base_model_utility( + consumption=30.0, + working=0, + disutility_of_work=1.0, + ) def test_create_compute_conditional_continuation_policy_with_discrete_model(): @@ -312,4 +320,4 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): vf_arr=None, ) assert policy == 1 - assert val == utility(consumption=2, working=0, disutility_of_work=1.0) + assert val == base_model_utility(consumption=2, working=0, disutility_of_work=1.0) From efe11cbb432f9830aaf10a2ab999df8a90c1d23a Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 29 Feb 2024 10:59:52 +0100 Subject: [PATCH 26/27] Remove MODELS dictionary in test_entry_point.py --- tests/test_entry_point.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index c74041c5..a7d66810 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -17,19 +17,16 @@ ) from tests.test_models.deterministic import utility as base_model_utility -MODELS = { - "simple": BASE_MODEL, - "with_filters": BASE_MODEL_WITH_FILTERS, - "fully_discrete": BASE_MODEL_FULLY_DISCRETE, -} - - # ====================================================================================== # Solve # ====================================================================================== -@pytest.mark.parametrize("user_model", list(MODELS.values()), ids=list(MODELS)) +@pytest.mark.parametrize( + "user_model", + [BASE_MODEL, BASE_MODEL_FULLY_DISCRETE, BASE_MODEL_WITH_FILTERS], + ids=["base", "fully_discrete", "with_filters"], +) def test_get_lcm_function_with_solve_target(user_model): solve_model, params_template = get_lcm_function(model=user_model) @@ -46,7 +43,7 @@ def test_get_lcm_function_with_solve_target(user_model): @pytest.mark.parametrize( "user_model", [BASE_MODEL, BASE_MODEL_FULLY_DISCRETE], - ids=["simple", "fully_discrete"], + ids=["base", "fully_discrete"], ) def test_get_lcm_function_with_simulation_target_simple(user_model): simulate, params_template = get_lcm_function( @@ -67,7 +64,7 @@ def test_get_lcm_function_with_simulation_target_simple(user_model): @pytest.mark.parametrize( "user_model", [BASE_MODEL, BASE_MODEL_FULLY_DISCRETE], - ids=["simple", "fully_discrete"], + ids=["base", "fully_discrete"], ) def test_get_lcm_function_with_simulation_is_coherent(user_model): """Test that solve_and_simulate creates same output as solve then simulate.""" From 48e6070ef490632d6a62678ead7af23f55b62cb9 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 29 Feb 2024 11:24:55 +0100 Subject: [PATCH 27/27] Update examples README --- examples/README.md | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/README.md b/examples/README.md index 7b86f009..5dfbbb0e 100644 --- a/examples/README.md +++ b/examples/README.md @@ -8,8 +8,16 @@ ## Running an example -Say you want to solve the [`long_running`](./long_running.py) example locally. In a -Python shell, execute: +Say you want to solve the [`long_running`](./long_running.py) example locally. First, +clone this repository and move into the example folder. In a console, type: + +```console +$ git clone https://github.com/OpenSourceEconomics/lcm.git +$ cd lcm/examples +``` + +Make sure that you have `lcm` installed in your Python environment. Then, in a Python +shell run ```python from lcm.entry_point import get_lcm_function