From 3272686224b9774b709a7c14e1e527d8f3052ab6 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 27 Aug 2024 13:40:20 +0800 Subject: [PATCH] Allow mixed simulations from mixed scenario types --- cge_modeling/base/cge.py | 29 ++++++++-------- cge_modeling/base/utilities.py | 42 ++++++++++++++++++++++++ tests/test_cge_model.py | 60 ++++++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 13 deletions(-) diff --git a/cge_modeling/base/cge.py b/cge_modeling/base/cge.py index e59cf46..9918797 100644 --- a/cge_modeling/base/cge.py +++ b/cge_modeling/base/cge.py @@ -32,6 +32,7 @@ from cge_modeling.base.utilities import ( _replace_dim_marker_with_dim_name, _validate_input, + create_final_param_dict, ensure_input_is_sequence, flat_array_to_variable_dict, get_method_defaults, @@ -987,22 +988,24 @@ def simulate( f"initial_state must be a Result or a dict of initial values, found {type(initial_state)}" ) - if not all(x in x0_var_param for x in self.variable_names + self.parameter_names): + all_model_objs = [*self.variable_names, *self.parameter_names] + + if not all(x in x0_var_param for x in all_model_objs): + missing_objs = [x for x in all_model_objs if x not in x0_var_param] raise ValueError( - f"initial_state must contain values for all variables and parameters in the model. Found {initial_state.keys()}" + f"initial_state must contain values for all variables and parameters in the model. Did not find values " + f"for {', '.join(missing_objs)}" ) - final_param_dict = deepcopy(x0_var_param) - if final_values is not None: - final_param_dict.update(final_values) - elif final_delta is not None: - for k, v in final_delta.items(): - final_param_dict[k] += v - elif final_delta_pct is not None: - for k, v in final_delta_pct.items(): - final_param_dict[k] *= v - else: - raise ValueError() + if any(x not in all_model_objs for x in x0_var_param): + unknown_objs = [x for x in x0_var_param if x not in all_model_objs] + raise ValueError( + f"initial_state contains values for variables or parameters not in the model: {', '.join(unknown_objs)}" + ) + + final_param_dict = create_final_param_dict( + x0_var_param, final_values, final_delta, final_delta_pct + ) _, theta_simulation = variable_dict_to_flat_array( final_param_dict, self.variables, self.parameters diff --git a/cge_modeling/base/utilities.py b/cge_modeling/base/utilities.py index 7b928f9..eae4323 100644 --- a/cge_modeling/base/utilities.py +++ b/cge_modeling/base/utilities.py @@ -1,6 +1,8 @@ import re +from collections import Counter from collections.abc import Sequence +from copy import deepcopy from itertools import product from typing import Any, cast @@ -417,3 +419,43 @@ def get_method_defaults(use_grad, use_hess, use_hessp, method): use_hess = False return use_grad, use_hess, use_hessp + + +def create_final_param_dict( + initial_params: dict[str, Any], + final_values: dict[str, Any] | None, + final_delta: dict[str, Any] | None, + final_delta_pct: dict[str, Any] | None, +) -> dict[str, Any]: + scenario_params = deepcopy(initial_params) + final_values = final_values if final_values is not None else {} + final_delta = final_delta if final_delta is not None else {} + final_delta_pct = final_delta_pct if final_delta_pct is not None else {} + + all_params_to_update = [*final_values.keys(), *final_delta.keys(), *final_delta_pct.keys()] + if len(all_params_to_update) == 0: + raise ValueError( + "No parameters to update! Cannot create a scenario without updating any parameters." + ) + + update_count = Counter(all_params_to_update) + repeated_arguments = [k for k, v in update_count.items() if v > 1] + if len(repeated_arguments) > 0: + raise ValueError( + f"Arguments {', '.join(repeated_arguments)} are repeated among final_values, final_delta," + f" and final_delta_pct. Define each scenario in exactly one way (by giving the final value, " + f"the offset from the initial value, or the percentage change from the initial value)." + ) + + # For final values, directly insert the provided values, overwriting the initial values + scenario_params.update(final_values) + + # For deltas, add the delta to the initial value + for k, v in final_delta.items(): + scenario_params[k] += v + + # For percent changes, multiply the initial value by the percentage change + for k, v in final_delta_pct.items(): + scenario_params[k] *= v + + return scenario_params diff --git a/tests/test_cge_model.py b/tests/test_cge_model.py index ea09e6c..2c43761 100644 --- a/tests/test_cge_model.py +++ b/tests/test_cge_model.py @@ -564,3 +564,63 @@ def test_can_compile_without_jax_installed(): mod = load_model_1(backend="pytensor") inital_data = calibrate_model_1(**model_1_data) mod.simulate(inital_data, final_delta_pct={"L_s": 0.5}) + + +def test_invalid_simulation_raises(): + mod = load_model_1(backend="numba") + initial_state = calibrate_model_1(**model_1_data) + + with pytest.raises( + ValueError, match="initial_state must be a Result or a dict of initial values" + ): + mod.simulate([1, 2, 3, 4, 5], final_values={"L_s": 1500}) + + with pytest.raises( + ValueError, + match="initial_state must contain values for all variables and parameters in the " + "model. Did not find values for alpha", + ): + bad_init = initial_state.copy() + del bad_init["alpha"] + + mod.simulate(bad_init, final_values={"L_s": 1500}) + + with pytest.raises( + ValueError, + match="initial_state contains values for variables or parameters not in the " "model: lol", + ): + bad_init = initial_state.copy() + bad_init["lol"] = 3 + mod.simulate(bad_init, final_values={"L_s": 1500}) + + with pytest.raises(ValueError, match="No parameters to update!"): + mod.simulate(initial_state) + + with pytest.raises(ValueError, match="Arguments K_s are repeated among"): + mod.simulate( + initial_state, final_values={"L_s": 1500, "K_s": 5000}, final_delta_pct={"K_s": 0.5} + ) + + +@pytest.mark.parametrize( + "scenario_kwargs, expected_result", + [ + ({"final_values": {"L_s": 1500}}, {"L_s": lambda x: 1500}), + ({"final_delta_pct": {"L_s": 0.5}}, {"L_s": lambda x: x * 0.5}), + ({"final_delta": {"L_s": -500}}, {"L_s": lambda x: x - 500}), + ( + {"final_values": {"L_s": 1500}, "final_delta_pct": {"K_s": 0.5}}, + {"L_s": lambda x: 1500, "K_s": lambda x: x * 0.5}, + ), + ], + ids=["final_values", "final_delta_pct", "final_delta", "combined"], +) +def test_simulate(scenario_kwargs, expected_result): + mod = load_model_1(backend="numba") + initial_state = calibrate_model_1(**model_1_data) + + res = mod.simulate(initial_state, **scenario_kwargs) + + for key in expected_result: + final_param = res["optimizer"].parameters[key].isel(step=-1) + assert final_param == expected_result[key](initial_state[key])