Skip to content

Commit

Permalink
tests: follow T2RunSncosmo churn
Browse files Browse the repository at this point in the history
  • Loading branch information
jvansanten committed Jan 6, 2025
1 parent e1d1435 commit 0dc5c66
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions tests/test_T2RunSncosmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,15 @@ def mock_t2runsncosmo():
sncosmo_model_name="salt2",
redshift_kind=None,
fixed_z=None,
backup_z=None,
scale_z=None,
sncosmo_bounds={},
apply_mwcorrection=False,
phaseselect_kind=None,
noisified=False,
plot_db=False,
plot_props=None,
plot_matplotlib_suffix=None,
plot_matplotlib_dir=".",
plot_suffix=None,
plot_dir=".",
t2_dependency=[],
tabulator=[UnitModel(unit="ZTFT2Tabulator")],
logger=AmpelLogger.get_logger(),
Expand Down Expand Up @@ -68,7 +67,7 @@ def test_T2RunSncosmo_no_phaselimit(mock_t2runsncosmo):
datapoints = [MagicMock(spec=DataPoint)]
t2_views = [MagicMock(spec=T2DocView)]

with patch.object(mock_t2runsncosmo, "_get_redshift", return_value=(0.1, "Fixed")):
with patch.object(mock_t2runsncosmo, "get_redshift", return_value=([0.1], ["Fixed"], [1.0])):
result = mock_t2runsncosmo.process(compound, datapoints, t2_views)

assert isinstance(result, dict)
Expand All @@ -79,14 +78,16 @@ def test_T2RunSncosmo_fit_error(mock_t2runsncosmo):
compound, datapoints, t2_views = inputs()

with (
patch.object(mock_t2runsncosmo, "_get_redshift", return_value=(0.1, "Fixed")),
patch.object(
mock_t2runsncosmo, "get_redshift", return_value=([0.1], ["Fixed"], [1.0])
),
patch.object(mock_t2runsncosmo, "_get_phaselimit", return_value=(0, 1000)),
patch("sncosmo.fit_lc", side_effect=RuntimeError("fit error")),
):
result = mock_t2runsncosmo.process(compound, datapoints, t2_views)

assert isinstance(result, dict)
assert result["run_error"] is True
assert result["success"] is False


def test_T2RunSncosmo_success(mock_t2runsncosmo):
Expand All @@ -95,15 +96,18 @@ def test_T2RunSncosmo_success(mock_t2runsncosmo):
mock_fit_result = {
"parameters": np.array([0.1]),
"data_mask": np.array([True]),
"covariance": None,
"covariance": np.array([0.01]),
"param_names": ["z"],
"chisq": 1.0,
"ndof": 1,
"success": True,
}
mock_fitted_model = Mock()

with (
patch.object(mock_t2runsncosmo, "_get_redshift", return_value=(0.1, "Fixed")),
patch.object(
mock_t2runsncosmo, "get_redshift", return_value=([0.1], ["Fixed"], [1.0])
),
patch.object(mock_t2runsncosmo, "_get_phaselimit", return_value=(0, 1000)),
patch("sncosmo.fit_lc", return_value=(mock_fit_result, mock_fitted_model)),
patch.object(
Expand All @@ -114,4 +118,4 @@ def test_T2RunSncosmo_success(mock_t2runsncosmo):

assert isinstance(result, dict)
assert "sncosmo_result" in result
assert result["fit_metrics"] == {"metric": 1.0}
assert result["sncosmo_result"]["fit_metrics"] == {"metric": 1.0}

0 comments on commit 0dc5c66

Please sign in to comment.