Skip to content

Commit

Permalink
[MAINT] Make imports consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-hirsch committed Feb 20, 2025
1 parent 05fd2bb commit 9949616
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 18 deletions.
8 changes: 4 additions & 4 deletions tests/test_coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sklearn.datasets import load_diabetes
from sklearn.linear_model import lasso_path

import rolch
from rolch import online_coordinate_descent_path


def test_coordinate_descent():
Expand All @@ -29,7 +29,7 @@ def test_coordinate_descent():
is_regularized = np.repeat(True, J)
beta_path = np.zeros((lambda_n, J))

rolch_lasso_path = rolch.online_coordinate_descent_path(
rolch_lasso_path = online_coordinate_descent_path(
x_gram=x_gram,
y_gram=y_gram,
beta_path=beta_path,
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_coordinate_descent_bounds():
is_regularized = np.repeat(True, J)
beta_path = np.zeros((lambda_n, J))

rolch_lasso_path_positive = rolch.online_coordinate_descent_path(
rolch_lasso_path_positive = online_coordinate_descent_path(
x_gram=x_gram,
y_gram=y_gram,
beta_path=beta_path,
Expand All @@ -86,7 +86,7 @@ def test_coordinate_descent_bounds():
max_iterations=1000,
)[0]

rolch_lasso_path_negative = rolch.online_coordinate_descent_path(
rolch_lasso_path_negative = online_coordinate_descent_path(
x_gram=x_gram,
y_gram=y_gram,
beta_path=beta_path,
Expand Down
11 changes: 5 additions & 6 deletions tests/test_gram.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
import scipy.stats as st

from rolch.gram import (
from rolch import (
init_gram,
init_inverted_gram,
init_y_gram,
Expand Down Expand Up @@ -30,20 +30,19 @@ def make_x_y_w(N, D, random_weights=True):
FORGET = [0, 0.0001, 0.001, 0.01, 0.1]
BATCH_SIZE = [10, 25]


@pytest.mark.parametrize("N", N)
@pytest.mark.parametrize("D", D)
@pytest.mark.parametrize("random_weights", RANDOM_WEIGHTS)
@pytest.mark.parametrize("forget", FORGET)
def test_inverse_rank_deficit(
N, D, random_weights, forget
):
def test_inverse_rank_deficit(N, D, random_weights, forget):
X, _, w = make_x_y_w(N, D, random_weights=random_weights)
for d in range(1, D+1):
for d in range(1, D + 1):
choice = np.random.choice(np.arange(D), d)
XX = np.hstack((X, X[:, choice]))
with pytest.raises(ValueError):
gram_start = init_inverted_gram(XX[:-1], w[:-1], forget)


@pytest.mark.parametrize("N", N)
@pytest.mark.parametrize("D", D)
Expand Down
22 changes: 17 additions & 5 deletions tests/test_link_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,23 @@
import numpy as np
import pytest

import rolch

REAL_LINE_LINKS = [rolch.IdentityLink]
POSITIVE_LINE_LINKS = [rolch.LogLink, rolch.SqrtLink, rolch.LogIdentLink]
SHIFTED_LINKS = [rolch.LogShiftValueLink, rolch.SqrtShiftValueLink]
from rolch import (
IdentityLink,
LogIdentLink,
LogLink,
LogShiftValueLink,
SqrtLink,
SqrtShiftValueLink,
)

# We don't test
# - LogShiftTwoLink
# - SqrtShiftTwoLink
# at the moment since they derive from the ShiftValueLink

REAL_LINE_LINKS = [IdentityLink]
POSITIVE_LINE_LINKS = [LogLink, SqrtLink, LogIdentLink]
SHIFTED_LINKS = [LogShiftValueLink, SqrtShiftValueLink]
VALUES = np.array([2, 5, 10, 25, 100])
M = 10000

Expand Down
6 changes: 3 additions & 3 deletions tests/test_python_against_r.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

import rolch
from rolch import DistributionNormal, OnlineGamlss

file = "https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5f23f993cd87c283ce766e7ac6b329ee7cc2e1d1/mtcars.csv"
mtcars = np.genfromtxt(file, delimiter=",", skip_header=1)[:, 1:]
Expand Down Expand Up @@ -28,8 +28,8 @@ def test_normal_distribution():
coef_R_mu = np.array([36.51776626, -2.32470221, -0.01421071])
coef_R_sg = np.array([1.8782995906, -0.1262290913, -0.0003943062])

estimator = rolch.OnlineGamlss(
distribution=rolch.DistributionNormal(),
estimator = OnlineGamlss(
distribution=DistributionNormal(),
equation={0: np.array([0, 2]), 1: np.array([0, 2])},
method="ols",
scale_inputs=False,
Expand Down

0 comments on commit 9949616

Please sign in to comment.