diff --git a/tests/ignite/metrics/test_cosine_similarity.py b/tests/ignite/metrics/test_cosine_similarity.py index 3e0c99a5eee..db7a5d5d9f1 100644 --- a/tests/ignite/metrics/test_cosine_similarity.py +++ b/tests/ignite/metrics/test_cosine_similarity.py @@ -102,7 +102,7 @@ def update(engine, i): y_true_np = y_true.cpu().numpy() y_preds_np = y_preds.cpu().numpy() y_true_norm = np.clip(np.linalg.norm(y_true_np, axis=1, keepdims=True), 1e-8, None) - y_preds_norm = np.clip(np.linalg.norm(y_preds, axis=1, keepdims=True), 1e-8, None) + y_preds_norm = np.clip(np.linalg.norm(y_preds_np, axis=1, keepdims=True), 1e-8, None) true_res = np.sum((y_true_np / y_true_norm) * (y_preds_np / y_preds_norm), axis=1) true_res = np.mean(true_res)