Skip to content

Commit

Permalink
Fix failing stochastic test
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Feb 27, 2025
1 parent dab3594 commit 85a592b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 19 deletions.
17 changes: 14 additions & 3 deletions tests/test_models/get_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,20 @@ def get_params(
],
# Transition from period 1 to period 2
[
# Description is the same as above
[[0, 1.0], [1.0, 0]],
[[0, 1.0], [0.0, 1.0]],
# Current working decision 0
[
# Current partner state 0
[0, 1.0],
# Current partner state 1
[1.0, 0],
],
# Current working decision 1
[
# Current partner state 0
[0, 1.0],
# Current partner state 1
[0.0, 1.0],
],
],
],
)
Expand Down
28 changes: 12 additions & 16 deletions tests/test_stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_get_lcm_function_with_simulate_target():
targets="solve_and_simulate",
)

res = simulate_model(
res: pd.DataFrame = simulate_model( # type: ignore[assignment]
params=get_params(),
initial_states={
"health": jnp.array([1, 1, 0, 0]),
Expand All @@ -28,21 +28,17 @@ def test_get_lcm_function_with_simulate_target():
},
)

expected_partner = [
0,
0,
1,
0, # period 0
1,
1,
1,
1, # period 1
1,
1,
1,
0, # period 2
]
assert jnp.array_equal(res["partner"].values, expected_partner) # type: ignore[call-overload, arg-type]
# This is derived from the partner transition in get_params.
expected_next_partner = (
(res.working.astype(bool) | ~res.partner.astype(bool)).astype(int).loc[:1]
)

pd.testing.assert_series_equal(
res["partner"].loc[1:],
expected_next_partner,
check_index=False,
check_names=False,
)


# ======================================================================================
Expand Down

0 comments on commit 85a592b

Please sign in to comment.