Skip to content

Commit

Permalink
de-mock tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jvansanten committed Dec 13, 2024
1 parent cf81b29 commit 4da3896
Showing 1 changed file with 36 additions and 21 deletions.
57 changes: 36 additions & 21 deletions test_T2RunSncosmo.py → tests/test_T2RunSncosmo.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

import numpy as np
import pytest

from ampel.base.AuxUnitRegister import AuxUnitRegister
from ampel.content.DataPoint import DataPoint
from ampel.content.T1Document import T1Document
from ampel.contrib.hu.t2.T2RunSncosmo import T2RunSncosmo
from ampel.log.AmpelLogger import AmpelLogger
from ampel.model.UnitModel import UnitModel
from ampel.view.T2DocView import T2DocView
from ampel.ztf.view.ZTFT2Tabulator import ZTFT2Tabulator


@pytest.fixture
def mock_t2runsncosmo():
return T2RunSncosmo(
AuxUnitRegister._dyn["ZTFT2Tabulator"] = ZTFT2Tabulator # noqa: SLF001

t2 = T2RunSncosmo(
sncosmo_model_name="salt2",
redshift_kind=None,
fixed_z=None,
Expand All @@ -26,42 +32,55 @@ def mock_t2runsncosmo():
plot_matplotlib_suffix=None,
plot_matplotlib_dir=".",
t2_dependency=[],
tabulator=[UnitModel(unit="ZTFT2Tabulator")],
logger=AmpelLogger.get_logger(),
)
t2.post_init()
return t2


def inputs():
compound = MagicMock(spec=T1Document)
base = {"tag": ["ZTF"]}
datapoints = [
base | {"body": d}
for d in (
{"jd": 2450000, "mag": 20, "magerr": 0.1},
{"jd": 2450001, "mag": 21, "magerr": 0.1},
{"jd": 2450002, "mag": 22, "magerr": 0.1},
)
]
t2_views = [MagicMock(spec=T2DocView)]
return compound, datapoints, t2_views


def test_T2RunSncosmo_no_redshift(mock_t2runsncosmo):
compound = Mock(spec=T1Document)
datapoints = [Mock(spec=DataPoint)]
t2_views = [Mock(spec=T2DocView)]
compound, datapoints, t2_views = inputs()

result = mock_t2runsncosmo.process(compound, datapoints, t2_views)

assert isinstance(result, dict)
assert result["z_source"] is None
assert result["z_source"] == "Fitted"


def test_T2RunSncosmo_no_phaselimit(mock_t2runsncosmo):
compound = Mock(spec=T1Document)
datapoints = [Mock(spec=DataPoint)]
t2_views = [Mock(spec=T2DocView)]
compound = MagicMock(spec=T1Document)
datapoints = [MagicMock(spec=DataPoint)]
t2_views = [MagicMock(spec=T2DocView)]

with patch.object(mock_t2runsncosmo, "_get_redshift", return_value=(0.1, "Fixed")):
result = mock_t2runsncosmo.process(compound, datapoints, t2_views)

assert isinstance(result, dict)
assert result["jdstart"] is None
assert result["jdstart"] == -np.inf


def test_T2RunSncosmo_fit_error(mock_t2runsncosmo):
compound = Mock(spec=T1Document)
datapoints = [Mock(spec=DataPoint)]
t2_views = [Mock(spec=T2DocView)]
compound, datapoints, t2_views = inputs()

with (
patch.object(mock_t2runsncosmo, "_get_redshift", return_value=(0.1, "Fixed")),
patch.object(mock_t2runsncosmo, "_get_phaselimit", return_value=(0, 1000)),
patch.object(mock_t2runsncosmo, "get_flux_table", return_value=Mock()),
patch("sncosmo.fit_lc", side_effect=RuntimeError("fit error")),
):
result = mock_t2runsncosmo.process(compound, datapoints, t2_views)
Expand All @@ -71,14 +90,11 @@ def test_T2RunSncosmo_fit_error(mock_t2runsncosmo):


def test_T2RunSncosmo_success(mock_t2runsncosmo):
compound = Mock(spec=T1Document)
datapoints = [Mock(spec=DataPoint)]
t2_views = [Mock(spec=T2DocView)]
compound, datapoints, t2_views = inputs()

mock_flux_table = Mock()
mock_fit_result = {
"parameters": [0.1],
"data_mask": [True],
"parameters": np.array([0.1]),
"data_mask": np.array([True]),
"covariance": None,
"param_names": ["z"],
"chisq": 1.0,
Expand All @@ -89,7 +105,6 @@ def test_T2RunSncosmo_success(mock_t2runsncosmo):
with (
patch.object(mock_t2runsncosmo, "_get_redshift", return_value=(0.1, "Fixed")),
patch.object(mock_t2runsncosmo, "_get_phaselimit", return_value=(0, 1000)),
patch.object(mock_t2runsncosmo, "get_flux_table", return_value=mock_flux_table),
patch("sncosmo.fit_lc", return_value=(mock_fit_result, mock_fitted_model)),
patch.object(
mock_t2runsncosmo, "_get_fit_metrics", return_value={"metric": 1.0}
Expand Down

0 comments on commit 4da3896

Please sign in to comment.