diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index 0f1e693785..540d2db4f0 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -483,39 +483,20 @@ class UMAP(UniversalBase, umap_params.verbosity = self.verbose umap_params.a = self.a umap_params.b = self.b + umap_params.target_n_neighbors = self.target_n_neighbors + umap_params.target_weight = self.target_weight + umap_params.random_state = check_random_seed(self.random_state) + umap_params.deterministic = self.deterministic + if self.init == "spectral": umap_params.init = 1 else: # self.init == "random" umap_params.init = 0 - umap_params.target_n_neighbors = self.target_n_neighbors + if self.target_metric == "euclidean": umap_params.target_metric = MetricType.EUCLIDEAN else: # self.target_metric == "categorical" umap_params.target_metric = MetricType.CATEGORICAL - if self.build_algo == "brute_force_knn": - umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN - else: # self.init == "nn_descent" - umap_params.build_algo = graph_build_algo.NN_DESCENT - if self.build_kwds is None: - umap_params.nn_descent_params.graph_degree = 64 - umap_params.nn_descent_params.intermediate_graph_degree = 128 - umap_params.nn_descent_params.max_iterations = 20 - umap_params.nn_descent_params.termination_threshold = 0.0001 - umap_params.nn_descent_params.return_distances = True - umap_params.nn_descent_params.n_clusters = 1 - else: - umap_params.nn_descent_params.graph_degree = self.build_kwds.get("nnd_graph_degree", 64) - umap_params.nn_descent_params.intermediate_graph_degree = self.build_kwds.get("nnd_intermediate_graph_degree", 128) - umap_params.nn_descent_params.max_iterations = self.build_kwds.get("nnd_max_iterations", 20) - umap_params.nn_descent_params.termination_threshold = self.build_kwds.get("nnd_termination_threshold", 0.0001) - umap_params.nn_descent_params.return_distances = self.build_kwds.get("nnd_return_distances", True) - if self.build_kwds.get("nnd_n_clusters", 1) < 1: - logger.info("Negative number of nnd_n_clusters not allowed. Changing nnd_n_clusters to 1") - umap_params.nn_descent_params.n_clusters = self.build_kwds.get("nnd_n_clusters", 1) - - umap_params.target_weight = self.target_weight - umap_params.random_state = check_random_seed(self.random_state) - umap_params.deterministic = self.deterministic try: umap_params.metric = metric_parsing[self.metric.lower()] @@ -533,6 +514,21 @@ class UMAP(UniversalBase, else: umap_params.p = self.metric_kwds.get('p') + if self.build_algo == "brute_force_knn": + umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN + else: + umap_params.build_algo = graph_build_algo.NN_DESCENT + build_kwds = self.build_kwds or {} + umap_params.nn_descent_params.graph_degree = build_kwds.get("nnd_graph_degree", 64) + umap_params.nn_descent_params.intermediate_graph_degree = build_kwds.get("nnd_intermediate_graph_degree", 128) + umap_params.nn_descent_params.max_iterations = build_kwds.get("nnd_max_iterations", 20) + umap_params.nn_descent_params.termination_threshold = build_kwds.get("nnd_termination_threshold", 0.0001) + umap_params.nn_descent_params.return_distances = build_kwds.get("nnd_return_distances", True) + umap_params.nn_descent_params.n_clusters = build_kwds.get("nnd_n_clusters", 1) + # Forward metric & metric_kwds to nn_descent + umap_params.nn_descent_params.metric = umap_params.metric + umap_params.nn_descent_params.metric_arg = umap_params.p + cdef uintptr_t callback_ptr = 0 if self.callback: callback_ptr = self.callback.get_native_callback() diff --git a/python/cuml/cuml/manifold/umap_utils.pxd b/python/cuml/cuml/manifold/umap_utils.pxd index 498e495733..c82f0244d4 100644 --- a/python/cuml/cuml/manifold/umap_utils.pxd +++ b/python/cuml/cuml/manifold/umap_utils.pxd @@ -24,6 +24,7 @@ from libc.stdint cimport uint64_t, uintptr_t, int64_t from libcpp cimport bool from libcpp.memory cimport shared_ptr from cuml.metrics.distance_type cimport DistanceType +from cuml.metrics.raft_distance_type cimport DistanceType as RaftDistanceType from cuml.internals.logger cimport level_enum cdef extern from "cuml/manifold/umapparams.h" namespace "ML::UMAPParams": @@ -39,6 +40,7 @@ cdef extern from "cuml/common/callback.hpp" namespace "ML::Internals": cdef cppclass GraphBasedDimRedCallback + cdef extern from "raft/neighbors/nn_descent_types.hpp" namespace "raft::neighbors::experimental::nn_descent": cdef struct index_params: uint64_t graph_degree, @@ -47,6 +49,8 @@ cdef extern from "raft/neighbors/nn_descent_types.hpp" namespace "raft::neighbor float termination_threshold, bool return_distances, uint64_t n_clusters, + RaftDistanceType metric, + float metric_arg cdef extern from "cuml/manifold/umapparams.h" namespace "ML": diff --git a/python/cuml/cuml/tests/test_umap.py b/python/cuml/cuml/tests/test_umap.py index 296678666f..5a5cf0200c 100644 --- a/python/cuml/cuml/tests/test_umap.py +++ b/python/cuml/cuml/tests/test_umap.py @@ -17,10 +17,10 @@ # Please install UMAP before running the code # use 'conda install -c conda-forge umap-learn' command to install it -import platform import pytest import copy import joblib +import umap from sklearn.metrics import adjusted_rand_score from sklearn.manifold import trustworthiness from sklearn.datasets import make_blobs @@ -45,12 +45,6 @@ scipy_sparse = cpu_only_import("scipy.sparse") -IS_ARM = platform.processor() == "aarch64" - -if not IS_ARM: - import umap - - dataset_names = ["iris", "digits", "wine", "blobs"] @@ -81,9 +75,6 @@ def test_blobs_cluster(nrows, n_feats, build_algo): @pytest.mark.parametrize( "n_feats", [unit_param(10), quality_param(100), stress_param(1000)] ) -@pytest.mark.skipif( - IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" -) @pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) def test_umap_fit_transform_score(nrows, n_feats, build_algo): @@ -256,9 +247,6 @@ def test_umap_transform_on_digits(target_metric): @pytest.mark.parametrize("target_metric", ["categorical", "euclidean"]) @pytest.mark.parametrize("name", dataset_names) -@pytest.mark.skipif( - IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" -) def test_umap_fit_transform_trust(name, target_metric): if name == "iris": @@ -302,9 +290,6 @@ def test_umap_fit_transform_trust(name, target_metric): @pytest.mark.parametrize("should_downcast", [True]) @pytest.mark.parametrize("input_type", ["dataframe", "ndarray"]) @pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) -@pytest.mark.skipif( - IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" -) def test_umap_data_formats( input_type, should_downcast, @@ -343,9 +328,6 @@ def test_umap_data_formats( @pytest.mark.parametrize("target_metric", ["categorical", "euclidean"]) @pytest.mark.filterwarnings("ignore:(.*)connected(.*):UserWarning:sklearn[.*]") @pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) -@pytest.mark.skipif( - IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" -) def test_umap_fit_transform_score_default(target_metric, build_algo): n_samples = 500 @@ -545,9 +527,6 @@ def test_umap_transform_trustworthiness_with_consistency_enabled(): @pytest.mark.filterwarnings("ignore:(.*)zero(.*)::scipy[.*]|umap[.*]") -@pytest.mark.skipif( - IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" -) @pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) def test_exp_decay_params(build_algo): def compare_exp_decay_params(a=None, b=None, min_dist=0.1, spread=1.0): @@ -692,9 +671,6 @@ def correctness_sparse(a, b, atol=0.1, rtol=0.2, threshold=0.95): @pytest.mark.parametrize("n_rows", [200, 800]) @pytest.mark.parametrize("n_features", [8, 32]) @pytest.mark.parametrize("n_neighbors", [8, 16]) -@pytest.mark.skipif( - IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" -) def test_fuzzy_simplicial_set(n_rows, n_features, n_neighbors): n_clusters = 30 random_state = 42 @@ -738,12 +714,12 @@ def test_fuzzy_simplicial_set(n_rows, n_features, n_neighbors): ("canberra", True), ], ) -@pytest.mark.skipif( - IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" -) -def test_umap_distance_metrics_fit_transform_trust(metric, supported): +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) +def test_umap_distance_metrics_fit_transform_trust( + metric, supported, build_algo +): data, labels = make_blobs( - n_samples=1000, n_features=64, centers=5, random_state=42 + n_samples=500, n_features=64, centers=5, random_state=42 ) if metric == "jaccard": @@ -753,7 +729,11 @@ def test_umap_distance_metrics_fit_transform_trust(metric, supported): n_neighbors=10, min_dist=0.01, metric=metric, init="random" ) cuml_model = cuUMAP( - n_neighbors=10, min_dist=0.01, metric=metric, init="random" + n_neighbors=10, + min_dist=0.01, + metric=metric, + init="random", + build_algo=build_algo, ) if not supported: with pytest.raises(NotImplementedError): @@ -791,9 +771,6 @@ def test_umap_distance_metrics_fit_transform_trust(metric, supported): ("canberra", True, True), ], ) -@pytest.mark.skipif( - IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" -) def test_umap_distance_metrics_fit_transform_trust_on_sparse_input( metric, supported, umap_learn_supported ):