Skip to content

Commit

Permalink
Merge pull request #41 from simon-hirsch/fix_j_from_equation
Browse files Browse the repository at this point in the history
Fix estimator.get_j_from_equation
  • Loading branch information
BerriJ authored Feb 12, 2025
2 parents 32e8151 + e18174c commit ee6f94b
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 3 deletions.
18 changes: 15 additions & 3 deletions src/rolch/estimators/online_gamlss.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,21 @@ def get_J_from_equation(self, X: np.ndarray):
J[p] = X.shape[1] + int(self.fit_intercept[p])
if self.equation[p] == "intercept":
J[p] = 1
elif isinstance(self.equation[p], np.ndarray) or isinstance(
self.equation[p], list
):
elif isinstance(self.equation[p], np.ndarray):
if np.issubdtype(self.equation[p].dtype, bool):
if self.equation[p].shape[0] != X.shape[1]:
raise ValueError(f"Shape does not match for param {p}.")
J[p] = np.sum(self.equation[p]) + int(self.fit_intercept[p])
elif np.issubdtype(self.equation[p].dtype, np.integer):
if self.equation[p].max() >= X.shape[1]:
raise ValueError(f"Shape does not match for param {p}.")
J[p] = self.equation[p].shape[0] + int(self.fit_intercept[p])
else:
raise ValueError(
"If you pass a np.ndarray in the equation, "
"please make sure it is of dtype bool or int."
)
elif isinstance(self.equation[p], list):
J[p] = len(self.equation[p]) + int(self.fit_intercept[p])
else:
raise ValueError("Something unexpected happened")
Expand Down
89 changes: 89 additions & 0 deletions tests/test_online_gamlss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import numpy as np
import pytest
from sklearn.datasets import make_regression

from rolch.distributions import DistributionJSU
from rolch.estimators import OnlineGamlss

# Both intercept settings are exercised by the parametrized test below.
FIT_INTERCEPT = [True, False]

# Five geometrically spaced feature counts between 11 and 100, rounded to
# integers. The smallest value is 11, so a boolean mask with ten leading
# True entries always fits inside the feature dimension.
N_FEATURES = np.rint(np.geomspace(11, 100, num=5)).astype(int)


@pytest.mark.parametrize("n_features", N_FEATURES)
@pytest.mark.parametrize("fit_intercept", FIT_INTERCEPT)
def test_get_J_from_equation(n_features, fit_intercept):
    """Check the coefficient count J for every supported equation entry.

    Four equation kinds are covered, one per distribution parameter:
    the string "all", the string "intercept", an integer index array and
    a boolean mask over the features.
    """
    # Boolean mask selecting the first ten features only.
    first_ten = np.zeros(n_features, dtype=bool)
    first_ten[:10] = True

    equation = {
        0: "all",  # should adjust to n_features
        1: "intercept",
        2: np.arange(4),
        3: first_ten,
    }

    # An intercept adds one coefficient, except for the pure-intercept
    # equation, which always has exactly one.
    extra = int(fit_intercept)
    expected = {
        0: n_features + extra,
        1: 1,
        2: 4 + extra,
        3: 10 + extra,
    }

    X, _ = make_regression(n_samples=100, n_features=n_features)
    estimator = OnlineGamlss(
        distribution=DistributionJSU(),
        equation=equation,
        fit_intercept=fit_intercept,
    )

    J = estimator.get_J_from_equation(X)
    assert J[0] == expected[0], "Wrong J for param == 0"
    assert J[1] == expected[1], "Wrong J for param == 1"
    assert J[2] == expected[2], "Wrong J for param == 2"
    assert J[3] == expected[3], "Wrong J for param == 3"


def test_get_J_from_equation_warnings():
    """Check that mismatching equation arrays raise a ``ValueError``.

    Despite the name, this test asserts on raised errors, not warnings.
    Two failure modes are covered: an integer index array that reaches
    past the last feature (param 2) and a boolean mask whose length
    differs from the number of features (param 3).
    """
    n_features = 10
    fit_intercept = True

    # Param 2: indices 0..19 exceed the 10 available feature columns.
    equation_fail_2 = {
        0: "all",  # should adjust to n_features
        1: "intercept",
        2: np.arange(0, 2 * n_features),
        3: np.array([True] * n_features).astype(bool),
    }

    X, _ = make_regression(n_samples=100, n_features=n_features)
    estimator = OnlineGamlss(
        distribution=DistributionJSU(),
        equation=equation_fail_2,
        fit_intercept=fit_intercept,
    )
    # The trailing dot is escaped so it matches literally; unescaped it
    # would also accept e.g. "param 20".
    with pytest.raises(
        ValueError,
        match=r"Shape does not match for param 2\.",
    ):
        estimator.get_J_from_equation(X)

    # Param 3: the mask has 2 * n_features entries, twice the number of
    # feature columns.
    equation_fail_3 = {
        0: "all",  # should adjust to n_features
        1: "intercept",
        2: np.arange(0, n_features),
        3: np.array([True, False] * n_features).astype(bool),
    }

    X, _ = make_regression(n_samples=100, n_features=n_features)
    estimator = OnlineGamlss(
        distribution=DistributionJSU(),
        equation=equation_fail_3,
        fit_intercept=fit_intercept,
    )
    with pytest.raises(
        ValueError,
        match=r"Shape does not match for param 3\.",
    ):
        estimator.get_J_from_equation(X)

0 comments on commit ee6f94b

Please sign in to comment.