From 552cbf9c6e4053a5660dede8eba003a0b10e3c79 Mon Sep 17 00:00:00 2001 From: idc9 Date: Thu, 15 Oct 2020 15:45:17 -0400 Subject: [PATCH] fixed tests --- mvdr/mcca/tests/test_mcca.py | 2 +- mvdr/mcca/tests/utils.py | 12 ++++++++---- setup.py | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/mvdr/mcca/tests/test_mcca.py b/mvdr/mcca/tests/test_mcca.py index 7c6eff5..32e9915 100644 --- a/mvdr/mcca/tests/test_mcca.py +++ b/mvdr/mcca/tests/test_mcca.py @@ -25,7 +25,7 @@ def test_mcca(): assert not out['centerers'][b].with_std if params['center']: assert np.allclose(out['centerers'][b].mean_, - Xs[b].mean(axis=1)) + Xs[b].mean(axis=0)) else: assert out['centerers'][b].mean_ is None diff --git a/mvdr/mcca/tests/utils.py b/mvdr/mcca/tests/utils.py index 9307fb3..0dbf3c5 100644 --- a/mvdr/mcca/tests/utils.py +++ b/mvdr/mcca/tests/utils.py @@ -1,8 +1,8 @@ import numpy as np from sklearn.utils import check_random_state +from mvlearn.utils import check_Xs -from mvdr.mcca.block_processing import get_blocks_metadata from mvdr.linalg_utils import normalize_cols from mvdr.mcca.mcca import check_regs, get_mcca_gevp_data @@ -51,7 +51,9 @@ def check_mcca_scores_and_loadings(Xs, out, common_norm_scores = out['common_norm_scores'] centerers = out['centerers'] - n_blocks, n_samples, n_features = get_blocks_metadata(Xs) + Xs, n_blocks, n_samples, n_features = check_Xs(Xs, multiview=True, + return_dimensions=True) + # make sure to apply centering transformations Xs = [centerers[b].transform(Xs[b]) for b in range(n_blocks)] @@ -63,7 +65,7 @@ def check_mcca_scores_and_loadings(Xs, out, # check common norm scores are the column normalized sum of the # block scores - cns_pred = normalize_cols(sum(bs for bs in block_scores)) + cns_pred = normalize_cols(sum(bs for bs in block_scores))[0] assert np.allclose(cns_pred, common_norm_scores) if check_normalization: @@ -90,7 +92,9 @@ def check_mcca_gevp(Xs, out, regs): evals = out['evals'] centerers = out['centerers'] - n_blocks, n_samples, n_features = get_blocks_metadata(Xs) + Xs, n_blocks, n_samples, n_features = check_Xs(Xs, multiview=True, + return_dimensions=True) + regs = check_regs(regs=regs, n_blocks=n_blocks) # make sure to apply centering transformations diff --git a/setup.py b/setup.py index e64fdbd..c3d3557 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ # return f.read() install_requires = ['numpy', 'scipy', 'scikit-learn', 'matplotlib', 'seaborn', - 'joblib'] + 'joblib', 'mvlearn'] setup(name='mvdr',