From aa82e36c2a5e4bfde24576e6d330ac00e82290cc Mon Sep 17 00:00:00 2001 From: Jonathan Berrisch Date: Wed, 26 Feb 2025 10:42:02 +0100 Subject: [PATCH] Parameterize dl2_dpp tests --- tests/test_distributions_gamma.py | 7 +++++-- tests/test_distributions_johnsonsu.py | 7 +++++-- tests/test_distributions_normal.py | 7 +++++-- tests/test_distributions_studentt.py | 7 +++++-- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/tests/test_distributions_gamma.py b/tests/test_distributions_gamma.py index e540de5..443d244 100644 --- a/tests/test_distributions_gamma.py +++ b/tests/test_distributions_gamma.py @@ -3,12 +3,15 @@ import numpy as np from rolch.distributions.gamma import DistributionGamma +PARAM = np.arange(DistributionGamma().n_params) -def test_dl2_dpp_raises_value_error(): + +@pytest.mark.parametrize("param", PARAM) +def test_dl2_dpp_raises_value_error(param): dist = DistributionGamma() y = np.array([1, 2, 3]) theta = np.array([[0, 1], [1, 2], [2, 3]]) with pytest.raises( ValueError, match="Cross derivatives must use different parameters." ): - dist.dl2_dpp(y, theta, (1, 1)) + dist.dl2_dpp(y, theta, (param, param)) diff --git a/tests/test_distributions_johnsonsu.py b/tests/test_distributions_johnsonsu.py index ca01f0d..ce16ecb 100644 --- a/tests/test_distributions_johnsonsu.py +++ b/tests/test_distributions_johnsonsu.py @@ -2,12 +2,15 @@ import numpy as np from rolch.distributions.johnsonsu import DistributionJSU +PARAM = np.arange(DistributionJSU().n_params) -def test_dl2_dpp_raises_value_error(): + +@pytest.mark.parametrize("param", PARAM) +def test_dl2_dpp_raises_value_error(param): dist = DistributionJSU() y = np.array([1, 2, 3]) theta = np.array([[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]]) with pytest.raises( ValueError, match="Cross derivatives must use different parameters." ): - dist.dl2_dpp(y, theta, (3, 3)) + dist.dl2_dpp(y, theta, (param, param)) diff --git a/tests/test_distributions_normal.py b/tests/test_distributions_normal.py index 0cc949a..29966eb 100644 --- a/tests/test_distributions_normal.py +++ b/tests/test_distributions_normal.py @@ -2,12 +2,15 @@ import numpy as np from rolch.distributions.normal import DistributionNormal +PARAM = np.arange(DistributionNormal().n_params) -def test_dl2_dpp_raises_value_error(): + +@pytest.mark.parametrize("param", PARAM) +def test_dl2_dpp_raises_value_error(param): dist = DistributionNormal() y = np.array([1, 2, 3]) theta = np.array([[0, 1], [1, 2], [2, 3]]) with pytest.raises( ValueError, match="Cross derivatives must use different parameters." ): - dist.dl2_dpp(y, theta, (0, 0)) + dist.dl2_dpp(y, theta, (param, param)) diff --git a/tests/test_distributions_studentt.py b/tests/test_distributions_studentt.py index 1f0dc58..67089d7 100644 --- a/tests/test_distributions_studentt.py +++ b/tests/test_distributions_studentt.py @@ -3,12 +3,15 @@ import numpy as np from rolch.distributions.studentt import DistributionT +PARAM = np.arange(DistributionT().n_params) -def test_dl2_dpp_raises_value_error(): + +@pytest.mark.parametrize("param", PARAM) +def test_dl2_dpp_raises_value_error(param): dist = DistributionT() y = np.array([1, 2, 3]) theta = np.array([[0, 1, 2], [1, 2, 3], [2, 3, 4]]) with pytest.raises( ValueError, match="Cross derivatives must use different parameters." ): - dist.dl2_dpp(y, theta, (1, 1)) + dist.dl2_dpp(y, theta, (param, param))