Skip to content

Commit

Permalink
Remove get_models.py module
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Feb 28, 2024
1 parent 456bb79 commit a3005c4
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 158 deletions.
136 changes: 0 additions & 136 deletions src/lcm/get_model.py

This file was deleted.

Binary file modified tests/data/regression_tests/simulation.pkl
Binary file not shown.
Binary file modified tests/data/regression_tests/solution.pkl
Binary file not shown.
50 changes: 38 additions & 12 deletions tests/test_analytical_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,67 @@
import pytest
from lcm._config import TEST_DATA
from lcm.entry_point import get_lcm_function
from lcm.get_model import get_model
from numpy.testing import assert_array_almost_equal as aaae

from tests.test_models.phelps_deaton import PHELPS_DEATON_WITH_FILTERS

# ======================================================================================
# Model specifications
# ======================================================================================

ISKHAVOV_2017_PARAMS = {
"beta": 0.98,
"utility": {"disutility_of_work": None},
"next_wealth": {
"interest_rate": 0.0,
"wage": 20.0,
},
}

TEST_CASES = {
"iskhakov_2017_five_periods": get_model("iskhakov_2017_five_periods"),
"iskhakov_2017_low_delta": get_model(
"iskhakov_2017_low_disutility_of_work",
),
"iskhakov_2017_five_periods": {
"model": {**PHELPS_DEATON_WITH_FILTERS, "n_periods": 5},
"params": {**ISKHAVOV_2017_PARAMS, "utility": {"disutility_of_work": 1.0}},
},
"iskhakov_2017_low_delta": {
"model": {**PHELPS_DEATON_WITH_FILTERS, "n_periods": 3},
"params": {**ISKHAVOV_2017_PARAMS, "utility": {"disutility_of_work": 0.1}},
},
}


def mean_square_error(x, y, axis=None):
return np.mean((x - y) ** 2, axis=axis)


@pytest.mark.parametrize(("model_name", "model_config"), TEST_CASES.items())
def test_analytical_solution(model_name, model_config):
# ======================================================================================
# Test
# ======================================================================================


@pytest.mark.parametrize(("model_name", "model_and_params"), TEST_CASES.items())
def test_analytical_solution(model_name, model_and_params):
"""Test that the numerical solution matches the analytical solution.
The analytical solution is from Iskhakov et al (2017) and is generated
in the development repository: github.com/opensourceeconomics/lcm-dev.
"""
# ----------------------------------------------------------------------------------
# Compute LCM solution
# ==================================================================================
solve_model, _ = get_lcm_function(model=model_config.model)
# ----------------------------------------------------------------------------------
solve_model, _ = get_lcm_function(model=model_and_params["model"])

vf_arr_list = solve_model(params=model_config.params)
vf_arr_list = solve_model(params=model_and_params["params"])
_numerical = np.stack(vf_arr_list)
numerical = {
"worker": _numerical[:, 0, :],
"retired": _numerical[:, 1, :],
}

# ----------------------------------------------------------------------------------
# Load analytical solution
# ==================================================================================
# ----------------------------------------------------------------------------------
analytical = {
_type: np.genfromtxt(
TEST_DATA.joinpath(
Expand All @@ -51,8 +76,9 @@ def test_analytical_solution(model_name, model_config):
for _type in ["worker", "retired"]
}

# ----------------------------------------------------------------------------------
# Compare
# ==================================================================================
# ----------------------------------------------------------------------------------
for _type in ["worker", "retired"]:
_analytical = np.array(analytical[_type])
_numerical = numerical[_type]
Expand Down
10 changes: 5 additions & 5 deletions tests/test_models/phelps_deaton.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@ def absorbing_retirement_filter(retirement, lagged_retirement):
"consumption": {
"grid_type": "linspace",
"start": 1,
"stop": 100,
"stop": 400,
"n_points": N_GRID_POINTS["consumption"],
},
},
"states": {
"wealth": {
"grid_type": "linspace",
"start": 1,
"stop": 100,
"stop": 400,
"n_points": N_GRID_POINTS["wealth"],
},
},
Expand All @@ -134,7 +134,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement):
"wealth": {
"grid_type": "linspace",
"start": 1,
"stop": 100,
"stop": 400,
"n_points": N_GRID_POINTS["wealth"],
},
},
Expand All @@ -156,15 +156,15 @@ def absorbing_retirement_filter(retirement, lagged_retirement):
"consumption": {
"grid_type": "linspace",
"start": 1,
"stop": 100,
"stop": 400,
"n_points": N_GRID_POINTS["consumption"],
},
},
"states": {
"wealth": {
"grid_type": "linspace",
"start": 1,
"stop": 100,
"stop": 400,
"n_points": N_GRID_POINTS["wealth"],
},
"lagged_retirement": {"options": [0, 1]},
Expand Down
8 changes: 4 additions & 4 deletions tests/test_process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,14 @@ def test_process_phelps_deaton_with_filters():
# Gridspecs
wealth_specs = GridSpec(
kind="linspace",
specs={"start": 0, "stop": 100, "n_points": N_GRID_POINTS["wealth"]},
specs={"start": 1, "stop": 400, "n_points": N_GRID_POINTS["wealth"]},
)

assert model.gridspecs["wealth"] == wealth_specs

consumption_specs = GridSpec(
kind="linspace",
specs={"start": 1, "stop": 100, "n_points": N_GRID_POINTS["consumption"]},
specs={"start": 1, "stop": 400, "n_points": N_GRID_POINTS["consumption"]},
)
assert model.gridspecs["consumption"] == consumption_specs

Expand Down Expand Up @@ -173,14 +173,14 @@ def test_process_phelps_deaton():
# Gridspecs
wealth_specs = GridSpec(
kind="linspace",
specs={"start": 0, "stop": 100, "n_points": N_GRID_POINTS["wealth"]},
specs={"start": 1, "stop": 400, "n_points": N_GRID_POINTS["wealth"]},
)

assert model.gridspecs["wealth"] == wealth_specs

consumption_specs = GridSpec(
kind="linspace",
specs={"start": 0, "stop": 100, "n_points": N_GRID_POINTS["consumption"]},
specs={"start": 1, "stop": 400, "n_points": N_GRID_POINTS["consumption"]},
)
assert model.gridspecs["consumption"] == consumption_specs

Expand Down
2 changes: 1 addition & 1 deletion tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def test_simulate_using_get_lcm_function(phelps_deaton_model_solution, n_periods
params,
vf_arr_list=vf_arr_list,
initial_states={
"wealth": jnp.array([1.0, 20, 40, 70]),
"wealth": jnp.array([20.0, 150, 250, 320]),
},
additional_targets=["utility", "consumption_constraint"],
)
Expand Down
File renamed without changes.

0 comments on commit a3005c4

Please sign in to comment.