diff --git a/src/lcm/get_model.py b/src/lcm/get_model.py deleted file mode 100644 index 61e932ca..00000000 --- a/src/lcm/get_model.py +++ /dev/null @@ -1,136 +0,0 @@ -"""Get a user model and parameters.""" - -from typing import NamedTuple - -from pybaum import tree_update - -from tests.test_models.phelps_deaton import ( - PHELPS_DEATON, - PHELPS_DEATON_WITH_FILTERS, -) - - -class ModelAndParams(NamedTuple): - """Model and parameters.""" - - model: dict - params: dict - - -def get_model(model: str): - """Get a user model and parameters. - - Args: - model (str): Model name. - - Returns: - NamedTuple: Model and parameters. Has attributes `model` and `params`. - - """ - if model not in MODELS: - raise ValueError(f"Model {model} not found. Choose from {set(MODELS.keys())}.") - return MODELS[model] - - -# ====================================================================================== -# Models -# ====================================================================================== - -# Remove age and wage functions from Phelps-Deaton model, as they are not used in the -# original paper. -PHELPS_DEATON_WITHOUT_AGE = PHELPS_DEATON.copy() -PHELPS_DEATON_WITHOUT_AGE["functions"] = { - name: func - for name, func in PHELPS_DEATON_WITHOUT_AGE["functions"].items() - if name not in ["age", "wage"] -} - - -PHELPS_DEATON_FIVE_PERIODS = { - **PHELPS_DEATON_WITHOUT_AGE, - "choices": { - "retirement": {"options": [0, 1]}, - "consumption": { - "grid_type": "linspace", - "start": 1, - "stop": 400, - "n_points": 500, - }, - }, - "states": { - "wealth": { - "grid_type": "linspace", - "start": 1, - "stop": 400, - "n_points": 100, - }, - }, - "n_periods": 5, -} - - -ISKHAKOV_2017_FIVE_PERIODS = { - **PHELPS_DEATON_WITH_FILTERS, - "choices": { - "retirement": {"options": [0, 1]}, - "consumption": { - "grid_type": "linspace", - "start": 1, - "stop": 400, - "n_points": 500, - }, - }, - "states": { - "wealth": { - "grid_type": "linspace", - "start": 1, - "stop": 400, - "n_points": 100, - }, - "lagged_retirement": {"options": [0, 1]}, - }, - "n_periods": 5, -} - - -ISKHAKOV_2017_THREE_PERIODS = tree_update(ISKHAKOV_2017_FIVE_PERIODS, {"n_periods": 3}) - -# ====================================================================================== -# Models and params -# ====================================================================================== - -MODELS = { - "phelps_deaton_regression_test": ModelAndParams( - model=PHELPS_DEATON_FIVE_PERIODS, - params={ - "beta": 1.0, - "utility": {"disutility_of_work": 1.0}, - "next_wealth": { - "interest_rate": 0.05, - "wage": 1.0, - }, - }, - ), - "iskhakov_2017_five_periods": ModelAndParams( - model=ISKHAKOV_2017_FIVE_PERIODS, - params={ - "beta": 0.98, - "utility": {"disutility_of_work": 1.0}, - "next_wealth": { - "interest_rate": 0.0, - "wage": 20.0, - }, - }, - ), - "iskhakov_2017_low_disutility_of_work": ModelAndParams( - model=ISKHAKOV_2017_THREE_PERIODS, - params={ - "beta": 0.98, - "utility": {"disutility_of_work": 0.1}, - "next_wealth": { - "interest_rate": 0.0, - "wage": 20.0, - }, - }, - ), -} diff --git a/tests/data/regression_tests/simulation.pkl b/tests/data/regression_tests/simulation.pkl index 3abb6e85..d28f5814 100644 Binary files a/tests/data/regression_tests/simulation.pkl and b/tests/data/regression_tests/simulation.pkl differ diff --git a/tests/data/regression_tests/solution.pkl b/tests/data/regression_tests/solution.pkl index d45858dd..c1889cb6 100644 Binary files a/tests/data/regression_tests/solution.pkl and b/tests/data/regression_tests/solution.pkl differ diff --git a/tests/test_analytical_solution.py b/tests/test_analytical_solution.py index 34875f1c..46ef2fdd 100644 --- a/tests/test_analytical_solution.py +++ b/tests/test_analytical_solution.py @@ -4,14 +4,32 @@ import pytest from lcm._config import TEST_DATA from lcm.entry_point import get_lcm_function -from lcm.get_model import get_model from numpy.testing import assert_array_almost_equal as aaae +from tests.test_models.phelps_deaton import PHELPS_DEATON_WITH_FILTERS + +# ====================================================================================== +# Model specifications +# ====================================================================================== + +ISKHAVOV_2017_PARAMS = { + "beta": 0.98, + "utility": {"disutility_of_work": None}, + "next_wealth": { + "interest_rate": 0.0, + "wage": 20.0, + }, +} + TEST_CASES = { - "iskhakov_2017_five_periods": get_model("iskhakov_2017_five_periods"), - "iskhakov_2017_low_delta": get_model( - "iskhakov_2017_low_disutility_of_work", - ), + "iskhakov_2017_five_periods": { + "model": {**PHELPS_DEATON_WITH_FILTERS, "n_periods": 5}, + "params": {**ISKHAVOV_2017_PARAMS, "utility": {"disutility_of_work": 1.0}}, + }, + "iskhakov_2017_low_delta": { + "model": {**PHELPS_DEATON_WITH_FILTERS, "n_periods": 3}, + "params": {**ISKHAVOV_2017_PARAMS, "utility": {"disutility_of_work": 0.1}}, + }, } @@ -19,27 +37,34 @@ def mean_square_error(x, y, axis=None): return np.mean((x - y) ** 2, axis=axis) -@pytest.mark.parametrize(("model_name", "model_config"), TEST_CASES.items()) -def test_analytical_solution(model_name, model_config): +# ====================================================================================== +# Test +# ====================================================================================== + + +@pytest.mark.parametrize(("model_name", "model_and_params"), TEST_CASES.items()) +def test_analytical_solution(model_name, model_and_params): """Test that the numerical solution matches the analytical solution. The analytical solution is from Iskhakov et al (2017) and is generated in the development repository: github.com/opensourceeconomics/lcm-dev. """ + # ---------------------------------------------------------------------------------- # Compute LCM solution - # ================================================================================== - solve_model, _ = get_lcm_function(model=model_config.model) + # ---------------------------------------------------------------------------------- + solve_model, _ = get_lcm_function(model=model_and_params["model"]) - vf_arr_list = solve_model(params=model_config.params) + vf_arr_list = solve_model(params=model_and_params["params"]) _numerical = np.stack(vf_arr_list) numerical = { "worker": _numerical[:, 0, :], "retired": _numerical[:, 1, :], } + # ---------------------------------------------------------------------------------- # Load analytical solution - # ================================================================================== + # ---------------------------------------------------------------------------------- analytical = { _type: np.genfromtxt( TEST_DATA.joinpath( @@ -51,8 +76,9 @@ def test_analytical_solution(model_name, model_config): for _type in ["worker", "retired"] } + # ---------------------------------------------------------------------------------- # Compare - # ================================================================================== + # ---------------------------------------------------------------------------------- for _type in ["worker", "retired"]: _analytical = np.array(analytical[_type]) _numerical = numerical[_type] diff --git a/tests/test_models/phelps_deaton.py b/tests/test_models/phelps_deaton.py index e6b88e14..460c2a93 100644 --- a/tests/test_models/phelps_deaton.py +++ b/tests/test_models/phelps_deaton.py @@ -103,7 +103,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "consumption": { "grid_type": "linspace", "start": 1, - "stop": 100, + "stop": 400, "n_points": N_GRID_POINTS["consumption"], }, }, @@ -111,7 +111,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "wealth": { "grid_type": "linspace", "start": 1, - "stop": 100, + "stop": 400, "n_points": N_GRID_POINTS["wealth"], }, }, @@ -134,7 +134,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "wealth": { "grid_type": "linspace", "start": 1, - "stop": 100, + "stop": 400, "n_points": N_GRID_POINTS["wealth"], }, }, @@ -156,7 +156,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "consumption": { "grid_type": "linspace", "start": 1, - "stop": 100, + "stop": 400, "n_points": N_GRID_POINTS["consumption"], }, }, @@ -164,7 +164,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "wealth": { "grid_type": "linspace", "start": 1, - "stop": 100, + "stop": 400, "n_points": N_GRID_POINTS["wealth"], }, "lagged_retirement": {"options": [0, 1]}, diff --git a/tests/test_process_model.py b/tests/test_process_model.py index 06d9a28d..53b1a3bc 100644 --- a/tests/test_process_model.py +++ b/tests/test_process_model.py @@ -122,14 +122,14 @@ def test_process_phelps_deaton_with_filters(): # Gridspecs wealth_specs = GridSpec( kind="linspace", - specs={"start": 0, "stop": 100, "n_points": N_GRID_POINTS["wealth"]}, + specs={"start": 1, "stop": 400, "n_points": N_GRID_POINTS["wealth"]}, ) assert model.gridspecs["wealth"] == wealth_specs consumption_specs = GridSpec( kind="linspace", - specs={"start": 1, "stop": 100, "n_points": N_GRID_POINTS["consumption"]}, + specs={"start": 1, "stop": 400, "n_points": N_GRID_POINTS["consumption"]}, ) assert model.gridspecs["consumption"] == consumption_specs @@ -173,14 +173,14 @@ def test_process_phelps_deaton(): # Gridspecs wealth_specs = GridSpec( kind="linspace", - specs={"start": 0, "stop": 100, "n_points": N_GRID_POINTS["wealth"]}, + specs={"start": 1, "stop": 400, "n_points": N_GRID_POINTS["wealth"]}, ) assert model.gridspecs["wealth"] == wealth_specs consumption_specs = GridSpec( kind="linspace", - specs={"start": 0, "stop": 100, "n_points": N_GRID_POINTS["consumption"]}, + specs={"start": 1, "stop": 400, "n_points": N_GRID_POINTS["consumption"]}, ) assert model.gridspecs["consumption"] == consumption_specs diff --git a/tests/test_simulate.py b/tests/test_simulate.py index adfc92b2..fb5d268a 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -140,7 +140,7 @@ def test_simulate_using_get_lcm_function(phelps_deaton_model_solution, n_periods params, vf_arr_list=vf_arr_list, initial_states={ - "wealth": jnp.array([1.0, 20, 40, 70]), + "wealth": jnp.array([20.0, 150, 250, 320]), }, additional_targets=["utility", "consumption_constraint"], ) diff --git a/tests/test_analytical_solution_on_toy_model.py b/tests/test_solution_on_toy_model.py similarity index 100% rename from tests/test_analytical_solution_on_toy_model.py rename to tests/test_solution_on_toy_model.py