Skip to content

Commit

Permalink
fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
idc9 committed Oct 15, 2020
1 parent 5aab4b8 commit 552cbf9
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion mvdr/mcca/tests/test_mcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_mcca():
assert not out['centerers'][b].with_std
if params['center']:
assert np.allclose(out['centerers'][b].mean_,
Xs[b].mean(axis=1))
Xs[b].mean(axis=0))
else:
assert out['centerers'][b].mean_ is None

Expand Down
12 changes: 8 additions & 4 deletions mvdr/mcca/tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np
from sklearn.utils import check_random_state

from mvlearn.utils import check_Xs

from mvdr.mcca.block_processing import get_blocks_metadata
from mvdr.linalg_utils import normalize_cols
from mvdr.mcca.mcca import check_regs, get_mcca_gevp_data

Expand Down Expand Up @@ -51,7 +51,9 @@ def check_mcca_scores_and_loadings(Xs, out,
common_norm_scores = out['common_norm_scores']
centerers = out['centerers']

n_blocks, n_samples, n_features = get_blocks_metadata(Xs)
Xs, n_blocks, n_samples, n_features = check_Xs(Xs, multiview=True,
return_dimensions=True)


# make sure to apply centering transformations
Xs = [centerers[b].transform(Xs[b]) for b in range(n_blocks)]
Expand All @@ -63,7 +65,7 @@ def check_mcca_scores_and_loadings(Xs, out,

# check common norm scores are the column normalized sum of the
# block scores
cns_pred = normalize_cols(sum(bs for bs in block_scores))
cns_pred = normalize_cols(sum(bs for bs in block_scores))[0]
assert np.allclose(cns_pred, common_norm_scores)

if check_normalization:
Expand All @@ -90,7 +92,9 @@ def check_mcca_gevp(Xs, out, regs):
evals = out['evals']
centerers = out['centerers']

n_blocks, n_samples, n_features = get_blocks_metadata(Xs)
Xs, n_blocks, n_samples, n_features = check_Xs(Xs, multiview=True,
return_dimensions=True)

regs = check_regs(regs=regs, n_blocks=n_blocks)

# make sure to apply centering transformations
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# return f.read()

install_requires = ['numpy', 'scipy', 'scikit-learn', 'matplotlib', 'seaborn',
'joblib']
'joblib', 'mvlearn']


setup(name='mvdr',
Expand Down

0 comments on commit 552cbf9

Please sign in to comment.