Skip to content

Commit

Permalink
Parameterize dl2_dpp tests
Browse files Browse the repository at this point in the history
  • Loading branch information
BerriJ committed Feb 26, 2025
1 parent a09f5f0 commit aa82e36
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 8 deletions.
7 changes: 5 additions & 2 deletions tests/test_distributions_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
import numpy as np
from rolch.distributions.gamma import DistributionGamma

PARAM = np.arange(DistributionGamma().n_params)

def test_dl2_dpp_raises_value_error():

@pytest.mark.parametrize("param", PARAM)
def test_dl2_dpp_raises_value_error(param):
dist = DistributionGamma()
y = np.array([1, 2, 3])
theta = np.array([[0, 1], [1, 2], [2, 3]])
with pytest.raises(
ValueError, match="Cross derivatives must use different parameters."
):
dist.dl2_dpp(y, theta, (1, 1))
dist.dl2_dpp(y, theta, (param, param))
7 changes: 5 additions & 2 deletions tests/test_distributions_johnsonsu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
import numpy as np
from rolch.distributions.johnsonsu import DistributionJSU

PARAM = np.arange(DistributionJSU().n_params)

def test_dl2_dpp_raises_value_error():

@pytest.mark.parametrize("param", PARAM)
def test_dl2_dpp_raises_value_error(param):
dist = DistributionJSU()
y = np.array([1, 2, 3])
theta = np.array([[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]])
with pytest.raises(
ValueError, match="Cross derivatives must use different parameters."
):
dist.dl2_dpp(y, theta, (3, 3))
dist.dl2_dpp(y, theta, (param, param))
7 changes: 5 additions & 2 deletions tests/test_distributions_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
import numpy as np
from rolch.distributions.normal import DistributionNormal

PARAM = np.arange(DistributionNormal().n_params)

def test_dl2_dpp_raises_value_error():

@pytest.mark.parametrize("param", PARAM)
def test_dl2_dpp_raises_value_error(param):
dist = DistributionNormal()
y = np.array([1, 2, 3])
theta = np.array([[0, 1], [1, 2], [2, 3]])
with pytest.raises(
ValueError, match="Cross derivatives must use different parameters."
):
dist.dl2_dpp(y, theta, (0, 0))
dist.dl2_dpp(y, theta, (param, param))
7 changes: 5 additions & 2 deletions tests/test_distributions_studentt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
import numpy as np
from rolch.distributions.studentt import DistributionT

PARAM = np.arange(DistributionT().n_params)

def test_dl2_dpp_raises_value_error():

@pytest.mark.parametrize("param", PARAM)
def test_dl2_dpp_raises_value_error(param):
dist = DistributionT()
y = np.array([1, 2, 3])
theta = np.array([[0, 1, 2], [1, 2, 3], [2, 3, 4]])
with pytest.raises(
ValueError, match="Cross derivatives must use different parameters."
):
dist.dl2_dpp(y, theta, (1, 1))
dist.dl2_dpp(y, theta, (param, param))

0 comments on commit aa82e36

Please sign in to comment.