Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow mixed simulations from mixed scenario types #29

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions cge_modeling/base/cge.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from cge_modeling.base.utilities import (
_replace_dim_marker_with_dim_name,
_validate_input,
create_final_param_dict,
ensure_input_is_sequence,
flat_array_to_variable_dict,
get_method_defaults,
Expand Down Expand Up @@ -987,22 +988,24 @@ def simulate(
f"initial_state must be a Result or a dict of initial values, found {type(initial_state)}"
)

if not all(x in x0_var_param for x in self.variable_names + self.parameter_names):
all_model_objs = [*self.variable_names, *self.parameter_names]

if not all(x in x0_var_param for x in all_model_objs):
missing_objs = [x for x in all_model_objs if x not in x0_var_param]
raise ValueError(
f"initial_state must contain values for all variables and parameters in the model. Found {initial_state.keys()}"
f"initial_state must contain values for all variables and parameters in the model. Did not find values "
f"for {', '.join(missing_objs)}"
)

final_param_dict = deepcopy(x0_var_param)
if final_values is not None:
final_param_dict.update(final_values)
elif final_delta is not None:
for k, v in final_delta.items():
final_param_dict[k] += v
elif final_delta_pct is not None:
for k, v in final_delta_pct.items():
final_param_dict[k] *= v
else:
raise ValueError()
if any(x not in all_model_objs for x in x0_var_param):
unknown_objs = [x for x in x0_var_param if x not in all_model_objs]
raise ValueError(
f"initial_state contains values for variables or parameters not in the model: {', '.join(unknown_objs)}"
)

final_param_dict = create_final_param_dict(
x0_var_param, final_values, final_delta, final_delta_pct
)

_, theta_simulation = variable_dict_to_flat_array(
final_param_dict, self.variables, self.parameters
Expand Down
42 changes: 42 additions & 0 deletions cge_modeling/base/utilities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import re

from collections import Counter
from collections.abc import Sequence
from copy import deepcopy
from itertools import product
from typing import Any, cast

Expand Down Expand Up @@ -417,3 +419,43 @@ def get_method_defaults(use_grad, use_hess, use_hessp, method):
use_hess = False

return use_grad, use_hess, use_hessp


def create_final_param_dict(
initial_params: dict[str, Any],
final_values: dict[str, Any] | None,
final_delta: dict[str, Any] | None,
final_delta_pct: dict[str, Any] | None,
) -> dict[str, Any]:
scenario_params = deepcopy(initial_params)
final_values = final_values if final_values is not None else {}
final_delta = final_delta if final_delta is not None else {}
final_delta_pct = final_delta_pct if final_delta_pct is not None else {}

all_params_to_update = [*final_values.keys(), *final_delta.keys(), *final_delta_pct.keys()]
if len(all_params_to_update) == 0:
raise ValueError(
"No parameters to update! Cannot create a scenario without updating any parameters."
)

update_count = Counter(all_params_to_update)
repeated_arguments = [k for k, v in update_count.items() if v > 1]
if len(repeated_arguments) > 0:
raise ValueError(
f"Arguments {', '.join(repeated_arguments)} are repeated among final_values, final_delta,"
f" and final_delta_pct. Define each scenario in exactly one way (by giving the final value, "
f"the offset from the initial value, or the percentage change from the initial value)."
)

# For final values, directly insert the provided values, overwriting the initial values
scenario_params.update(final_values)

# For deltas, add the delta to the initial value
for k, v in final_delta.items():
scenario_params[k] += v

# For percent changes, multiply the initial value by the percentage change
for k, v in final_delta_pct.items():
scenario_params[k] *= v

return scenario_params
60 changes: 60 additions & 0 deletions tests/test_cge_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,3 +564,63 @@ def test_can_compile_without_jax_installed():
mod = load_model_1(backend="pytensor")
inital_data = calibrate_model_1(**model_1_data)
mod.simulate(inital_data, final_delta_pct={"L_s": 0.5})


def test_invalid_simulation_raises():
mod = load_model_1(backend="numba")
initial_state = calibrate_model_1(**model_1_data)

with pytest.raises(
ValueError, match="initial_state must be a Result or a dict of initial values"
):
mod.simulate([1, 2, 3, 4, 5], final_values={"L_s": 1500})

with pytest.raises(
ValueError,
match="initial_state must contain values for all variables and parameters in the "
"model. Did not find values for alpha",
):
bad_init = initial_state.copy()
del bad_init["alpha"]

mod.simulate(bad_init, final_values={"L_s": 1500})

with pytest.raises(
ValueError,
match="initial_state contains values for variables or parameters not in the " "model: lol",
):
bad_init = initial_state.copy()
bad_init["lol"] = 3
mod.simulate(bad_init, final_values={"L_s": 1500})

with pytest.raises(ValueError, match="No parameters to update!"):
mod.simulate(initial_state)

with pytest.raises(ValueError, match="Arguments K_s are repeated among"):
mod.simulate(
initial_state, final_values={"L_s": 1500, "K_s": 5000}, final_delta_pct={"K_s": 0.5}
)


@pytest.mark.parametrize(
"scenario_kwargs, expected_result",
[
({"final_values": {"L_s": 1500}}, {"L_s": lambda x: 1500}),
({"final_delta_pct": {"L_s": 0.5}}, {"L_s": lambda x: x * 0.5}),
({"final_delta": {"L_s": -500}}, {"L_s": lambda x: x - 500}),
(
{"final_values": {"L_s": 1500}, "final_delta_pct": {"K_s": 0.5}},
{"L_s": lambda x: 1500, "K_s": lambda x: x * 0.5},
),
],
ids=["final_values", "final_delta_pct", "final_delta", "combined"],
)
def test_simulate(scenario_kwargs, expected_result):
mod = load_model_1(backend="numba")
initial_state = calibrate_model_1(**model_1_data)

res = mod.simulate(initial_state, **scenario_kwargs)

for key in expected_result:
final_param = res["optimizer"].parameters[key].isel(step=-1)
assert final_param == expected_result[key](initial_state[key])
Loading