From a60bb86d2d2d56485d519e39af099a9ffa87f028 Mon Sep 17 00:00:00 2001 From: Quoc-Tuan Truong Date: Tue, 12 Dec 2023 15:51:40 -0800 Subject: [PATCH] Make ANNs compatible with Experiment (#563) --- README.md | 8 +- cornac/eval_methods/base_method.py | 6 +- cornac/exception.py | 11 +-- cornac/experiment/experiment.py | 2 +- cornac/models/ann/recom_ann_annoy.py | 12 ++- cornac/models/ann/recom_ann_base.py | 103 ++++++++++++++++++++++--- cornac/models/ann/recom_ann_faiss.py | 8 +- cornac/models/ann/recom_ann_hnswlib.py | 3 + cornac/models/ann/recom_ann_scann.py | 19 ++++- cornac/models/recommender.py | 41 ++++++++-- examples/ann_all.ipynb | 18 ++--- examples/ann_example.py | 45 +++++++++++ tutorials/ann_hnswlib.ipynb | 26 +++---- 13 files changed, 244 insertions(+), 58 deletions(-) create mode 100644 examples/ann_example.py diff --git a/README.md b/README.md index 108a9039d..a11768f7f 100644 --- a/README.md +++ b/README.md @@ -134,10 +134,10 @@ One important aspect of deploying recommender model is efficient retrieval via A | Supported framework | Cornac wrapper | Examples | | :---: | :---: | :---: | -| [spotify/annoy](https://github.com/spotify/annoy) | [AnnoyANN](cornac/models/ann/recom_ann_annoy.py) | [ann_all.ipynb](examples/ann_all.ipynb) -| [meta/faiss](https://github.com/facebookresearch/faiss) | [FaissANN](cornac/models/ann/recom_ann_faiss.py) | [ann_all.ipynb](examples/ann_all.ipynb) -| [nmslib/hnswlib](https://github.com/nmslib/hnswlib) | [HNSWLibANN](cornac/models/ann/recom_ann_hnswlib.py) | [ann_hnswlib.ipynb](tutorials/ann_hnswlib.ipynb), [ann_all.ipynb](examples/ann_all.ipynb) -| [google/scann](https://github.com/google-research/google-research/tree/master/scann) | [ScaNNANN](cornac/models/ann/recom_ann_scann.py) | [ann_all.ipynb](examples/ann_all.ipynb) +| [spotify/annoy](https://github.com/spotify/annoy) | [AnnoyANN](cornac/models/ann/recom_ann_annoy.py) | [ann_example.py](examples/ann_example.py), [ann_all.ipynb](examples/ann_all.ipynb) +| 
[meta/faiss](https://github.com/facebookresearch/faiss) | [FaissANN](cornac/models/ann/recom_ann_faiss.py) | [ann_example.py](examples/ann_example.py), [ann_all.ipynb](examples/ann_all.ipynb) +| [nmslib/hnswlib](https://github.com/nmslib/hnswlib) | [HNSWLibANN](cornac/models/ann/recom_ann_hnswlib.py) | [ann_example.py](examples/ann_example.py), [ann_hnswlib.ipynb](tutorials/ann_hnswlib.ipynb), [ann_all.ipynb](examples/ann_all.ipynb) +| [google/scann](https://github.com/google-research/google-research/tree/master/scann) | [ScaNNANN](cornac/models/ann/recom_ann_scann.py) | [ann_example.py](examples/ann_example.py), [ann_all.ipynb](examples/ann_all.ipynb) ## Models diff --git a/cornac/eval_methods/base_method.py b/cornac/eval_methods/base_method.py index 66ad8165a..7540bf40c 100644 --- a/cornac/eval_methods/base_method.py +++ b/cornac/eval_methods/base_method.py @@ -157,6 +157,8 @@ def ranking_eval( if len(metrics) == 0: return [], [] + max_k = max(m.k for m in metrics) + avg_results = [] user_results = [{} for _ in enumerate(metrics)] @@ -203,7 +205,9 @@ def pos_items(csr_row): u_gt_pos_items = np.nonzero(u_gt_pos_mask)[0] u_gt_neg_items = np.nonzero(u_gt_neg_mask)[0] - item_rank, item_scores = model.rank(user_idx, item_indices) + item_rank, item_scores = model.rank( + user_idx=user_idx, item_indices=item_indices, k=max_k + ) for i, mt in enumerate(metrics): mt_score = mt.compute( diff --git a/cornac/exception.py b/cornac/exception.py index 610478e13..7f79050ed 100644 --- a/cornac/exception.py +++ b/cornac/exception.py @@ -13,17 +13,14 @@ # limitations under the License. 
# ============================================================================ -class CornacException(Exception): - """Exception base class to extend from - """ +class CornacException(Exception): + """Exception base class to extend from""" pass class ScoreException(CornacException): - """Exception raised in score function when facing unknowns + """Exception raised in score function when facing unknowns""" - """ - - pass \ No newline at end of file + pass diff --git a/cornac/experiment/experiment.py b/cornac/experiment/experiment.py index c3ee49e26..139417908 100644 --- a/cornac/experiment/experiment.py +++ b/cornac/experiment/experiment.py @@ -150,7 +150,7 @@ def run(self): if self.val_result is not None: self.val_result.append(val_result) - if not isinstance(self.result, CVExperimentResult): + if self.save_dir and (not isinstance(self.result, CVExperimentResult)): model.save(self.save_dir) output = "" diff --git a/cornac/models/ann/recom_ann_annoy.py b/cornac/models/ann/recom_ann_annoy.py index 4970de543..f818c9afc 100644 --- a/cornac/models/ann/recom_ann_annoy.py +++ b/cornac/models/ann/recom_ann_annoy.py @@ -69,7 +69,6 @@ def __init__( ): super().__init__(model=model, name=name, verbose=verbose) - self.model = model self.n_trees = n_trees self.search_k = search_k self.num_threads = num_threads @@ -85,6 +84,8 @@ def __init__( def build_index(self): """Building index from the base recommender model.""" + super().build_index() + from annoy import AnnoyIndex assert self.measure in SUPPORTED_MEASURES @@ -92,7 +93,9 @@ def build_index(self): self.index = AnnoyIndex( self.item_vectors.shape[1], SUPPORTED_MEASURES[self.measure] ) - self.index.set_seed(self.seed) + + if self.seed is not None: + self.index.set_seed(self.seed) for i, v in enumerate(self.item_vectors): self.index.add_item(i, v) @@ -115,6 +118,11 @@ def knn_query(self, query, k): ] neighbors = np.array([r[0] for r in result], dtype="int") distances = np.array([r[1] for r in result], dtype="float32") + + # 
make sure distances respect the notion of nearest neighbors (smaller is better) + if self.higher_is_better: + distances = 1.0 - distances + return neighbors, distances def save(self, save_dir=None): diff --git a/cornac/models/ann/recom_ann_base.py b/cornac/models/ann/recom_ann_base.py index c256a721e..e3935595c 100644 --- a/cornac/models/ann/recom_ann_base.py +++ b/cornac/models/ann/recom_ann_base.py @@ -14,10 +14,12 @@ # ============================================================================ import copy +import warnings import numpy as np from ..recommender import Recommender from ..recommender import is_ann_supported +from ..recommender import MEASURE_DOT, MEASURE_COSINE class BaseANN(Recommender): @@ -41,20 +43,50 @@ def __init__(self, model, name="BaseANN", verbose=False): if not is_ann_supported(model): raise ValueError(f"{model.name} doesn't support ANN search") - # ANN required attributes - self.measure = copy.deepcopy(model.get_vector_measure()) - self.user_vectors = copy.deepcopy(model.get_user_vectors()) - self.item_vectors = copy.deepcopy(model.get_item_vectors()) + self.model = model - # get basic attributes to be a proper recommender - super().fit(train_set=model.train_set, val_set=model.val_set) + self.ignored_attrs.append("model") # not to save the base model with ANN - def build_index(self): - """Building index from the base recommender model. + if model.is_fitted: + Recommender.fit(self, model.train_set, model.val_set) - :raise NotImplementedError + def fit(self, train_set, val_set=None): + """Fit the model to observations. + + Parameters + ---------- + train_set: :obj:`cornac.data.Dataset`, required + User-Item preference data as well as additional modalities. + + val_set: :obj:`cornac.data.Dataset`, optional, default: None + User-Item preference data for model selection purposes (e.g., early stopping). 
+ + Returns + ------- + self : object """ - raise NotImplementedError() + Recommender.fit(self, train_set, val_set) + + if not self.model.is_fitted: + if self.verbose: + print(f"Fitting base recommender model {self.model.name}...") + self.model.fit(train_set, val_set) + + self.build_index() + + return self + + def build_index(self): + """Building index from the base recommender model.""" + if not self.model.is_fitted: + warnings.warn(f"Base recommender model {self.model.name} is not fitted!") + + # ANN required attributes + self.measure = copy.deepcopy(self.model.get_vector_measure()) + self.user_vectors = copy.deepcopy(self.model.get_user_vectors()) + self.item_vectors = copy.deepcopy(self.model.get_item_vectors()) + + self.higher_is_better = self.measure in {MEASURE_DOT, MEASURE_COSINE} def knn_query(self, query, k): """Implementing ANN search for a given query. @@ -65,6 +97,57 @@ def knn_query(self, query, k): """ raise NotImplementedError() + def rank(self, user_idx, item_indices=None, k=-1, **kwargs): + """Rank all test items for a given user. + + Parameters + ---------- + user_idx: int, required + The index of the user for whom to perform item ranking. + + item_indices: 1d array, optional, default: None + A list of candidate item indices to be ranked by the user. + If `None`, list of ranked known item indices and their scores will be returned. + + k: int, optional, default: -1 + Cut-off length for recommendations, k=-1 will return ranked list of all items. + + Returns + ------- + (ranked_items, item_scores): tuple + `ranked_items` contains item indices being ranked by their scores. + `item_scores` contains scores of items corresponding to index in `item_indices` input. 
+ + """ + query = self.user_vectors[[user_idx]] + knn_items, distances = self.knn_query(query, k=k) + + top_k_items = knn_items[0] + top_k_scores = -distances[0] + + item_scores = np.full(self.total_items, -np.Inf) + item_scores[top_k_items] = top_k_scores + + all_items = np.arange(self.total_items) + ranked_items = np.concatenate( + [ + top_k_items, + all_items[~np.isin(all_items, top_k_items, assume_unique=True)], + ] + ) + + # rank items based on their scores + if item_indices is None: + item_scores = item_scores[: self.num_items] + ranked_items = ranked_items[: self.num_items] + else: + item_scores = item_scores[item_indices] + ranked_items = ranked_items[ + np.isin(ranked_items, item_indices, assume_unique=True) + ] + + return ranked_items, item_scores + def recommend(self, user_id, k=-1, remove_seen=False, train_set=None): """Generate top-K item recommendations for a given user. Backward compatibility. diff --git a/cornac/models/ann/recom_ann_faiss.py b/cornac/models/ann/recom_ann_faiss.py index f30521195..d90725fa4 100644 --- a/cornac/models/ann/recom_ann_faiss.py +++ b/cornac/models/ann/recom_ann_faiss.py @@ -68,7 +68,6 @@ def __init__( ): super().__init__(model=model, name=name, verbose=verbose) - self.model = model self.nlist = nlist self.nprobe = nprobe self.use_gpu = use_gpu @@ -87,6 +86,8 @@ def __init__( def build_index(self): """Building index from the base recommender model.""" + super().build_index() + import faiss faiss.omp_set_num_threads(self.num_threads) @@ -129,6 +130,11 @@ def knn_query(self, query, k): Array of k-nearest neighbors and corresponding distances for the given query. 
""" distances, neighbors = self.index.search(query, k) + + # make sure distances respect the notion of nearest neighbors (smaller is better) + if self.higher_is_better: + distances = 1.0 - distances + return neighbors, distances def save(self, save_dir=None): diff --git a/cornac/models/ann/recom_ann_hnswlib.py b/cornac/models/ann/recom_ann_hnswlib.py index 784c59708..91f554146 100644 --- a/cornac/models/ann/recom_ann_hnswlib.py +++ b/cornac/models/ann/recom_ann_hnswlib.py @@ -78,6 +78,7 @@ def __init__( verbose=False, ): super().__init__(model=model, name=name, verbose=verbose) + self.M = M self.ef_construction = ef_construction self.ef = ef @@ -96,6 +97,8 @@ def __init__( def build_index(self): """Building index from the base recommender model.""" + super().build_index() + import hnswlib assert self.measure in SUPPORTED_MEASURES diff --git a/cornac/models/ann/recom_ann_scann.py b/cornac/models/ann/recom_ann_scann.py index 662e5944f..50ebdfa4a 100644 --- a/cornac/models/ann/recom_ann_scann.py +++ b/cornac/models/ann/recom_ann_scann.py @@ -80,10 +80,18 @@ def __init__( ): super().__init__(model=model, name=name, verbose=verbose) + if partition_params is None: + partition_params = {"num_leaves": 100, "num_leaves_to_search": 50} + if score_params is None: - score_params = {} + score_params = { + "dimensions_per_block": 2, + "anisotropic_quantization_threshold": 0.2, + } + + if rescore_params is None: + rescore_params = {"reordering_num_neighbors": 100} - self.model = model self.partition_params = partition_params self.score_params = score_params self.score_brute_force = score_brute_force @@ -103,6 +111,8 @@ def __init__( def build_index(self): """Building index from the base recommender model.""" + super().build_index() + import scann assert self.measure in SUPPORTED_MEASURES @@ -148,6 +158,11 @@ def knn_query(self, query, k): Array of k-nearest neighbors and corresponding distances for the given query. 
""" neighbors, distances = self.index.search_batched(query, final_num_neighbors=k) + + # make sure distances respect the notion of nearest neighbors (smaller is better) + if self.higher_is_better: + distances = 1.0 - distances + return neighbors, distances def save(self, save_dir=None): diff --git a/cornac/models/recommender.py b/cornac/models/recommender.py index cfe86c552..6e174ce1a 100644 --- a/cornac/models/recommender.py +++ b/cornac/models/recommender.py @@ -17,6 +17,7 @@ import copy import inspect import pickle +import warnings from glob import glob from datetime import datetime @@ -130,6 +131,7 @@ def __init__(self, name, trainable=True, verbose=False): self.name = name self.trainable = trainable self.verbose = verbose + self.is_fitted = False # attributes to be ignored when saving model self.ignored_attrs = ["train_set", "val_set", "test_set"] @@ -180,8 +182,9 @@ def reset_info(self): def __deepcopy__(self, memo): cls = self.__class__ result = cls.__new__(cls) + ignored_attrs = set(self.ignored_attrs) for k, v in self.__dict__.items(): - if k in self.ignored_attrs: + if k in ignored_attrs: continue setattr(result, k, copy.deepcopy(v)) return result @@ -302,6 +305,11 @@ def fit(self, train_set, val_set=None): ------- self : object """ + if self.is_fitted: + warnings.warn( + "Model is already fitted. Re-fitting will overwrite the previous model." + ) + self.reset_info() train_set.reset() if val_set is not None: @@ -320,6 +328,8 @@ def fit(self, train_set, val_set=None): self.train_set = train_set self.val_set = val_set + self.is_fitted = True + return self def knows_user(self, user_idx): @@ -450,7 +460,7 @@ def rate(self, user_idx, item_idx, clipping=True): return rating_pred - def rank(self, user_idx, item_indices=None, **kwargs): + def rank(self, user_idx, item_indices=None, k=-1, **kwargs): """Rank all test items for a given user. 
Parameters @@ -462,6 +472,10 @@ def rank(self, user_idx, item_indices=None, **kwargs): A list of candidate item indices to be ranked by the user. If `None`, list of ranked known item indices and their scores will be returned. + k: int, optional, default: -1 + Cut-off length for recommendations, k=-1 will return ranked list of all items. + This is more important for ANN to know the limit to avoid exhaustive ranking. + Returns ------- (ranked_items, item_scores): tuple @@ -484,12 +498,23 @@ def rank(self, user_idx, item_indices=None, **kwargs): all_item_scores[: self.num_items] = known_item_scores # rank items based on their scores - if item_indices is None: - item_scores = all_item_scores[: self.num_items] - ranked_items = item_scores.argsort()[::-1] - else: - item_scores = all_item_scores[item_indices] - ranked_items = np.array(item_indices)[item_scores.argsort()[::-1]] + item_indices = ( + np.arange(self.num_items) + if item_indices is None + else np.asarray(item_indices) + ) + item_scores = all_item_scores[item_indices] + + if ( + k != -1 + ): # O(n + k log k), faster for small k which is usually the case + partitioned_idx = np.argpartition(item_scores, -k) + top_k_idx = partitioned_idx[-k:] + sorted_top_k_idx = top_k_idx[np.argsort(item_scores[top_k_idx])] + partitioned_idx[-k:] = sorted_top_k_idx + ranked_items = item_indices[partitioned_idx[::-1]] + else: # O(n log n) + ranked_items = item_indices[item_scores.argsort()[::-1]] return ranked_items, item_scores diff --git a/examples/ann_all.ipynb b/examples/ann_all.ipynb index f9794a456..277551666 100644 --- a/examples/ann_all.ipynb +++ b/examples/ann_all.ipynb @@ -60,7 +60,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b873c3de940d4f06be58dfbac3b7536e", + "model_id": "49afd52e202546a69dd9c4a245f6db80", "version_major": 2, "version_minor": 0 }, @@ -83,7 +83,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fc6054361c6b47aea308253e7f86a80e", + "model_id": 
"190f65f8aa9e4b6d9cd8a0e1805dfc53", "version_major": 2, "version_minor": 0 }, @@ -103,7 +103,7 @@ "...\n", " | AUC | Recall@20 | Train (s) | Test (s)\n", "-- + ------ + --------- + --------- + --------\n", - "MF | 0.8530 | 0.0669 | 0.9041 | 6.4182\n", + "MF | 0.8530 | 0.0669 | 0.9060 | 6.7622\n", "\n" ] } @@ -179,8 +179,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 1min 17s, sys: 15.3 ms, total: 1min 17s\n", - "Wall time: 1.63 s\n" + "CPU times: user 1min 14s, sys: 27.3 ms, total: 1min 14s\n", + "Wall time: 1.56 s\n" ] } ], @@ -215,7 +215,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 6, "id": "fa280b9d-ec04-41eb-9de2-acfb67fbeb80", "metadata": {}, "outputs": [ @@ -223,10 +223,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "AnnoyANN\t\tIndexing=30ms\t\tRetrieval=590ms\t\tRecall=0.01299\n", - "FaissANN\t\tIndexing=16ms\t\tRetrieval=897ms\t\tRecall=0.99938\n", + "AnnoyANN\t\tIndexing=34ms\t\tRetrieval=589ms\t\tRecall=0.01299\n", + "FaissANN\t\tIndexing=109ms\t\tRetrieval=905ms\t\tRecall=0.99938\n", "HNSWLibANN\t\tIndexing=91ms\t\tRetrieval=215ms\t\tRecall=0.99874\n", - "ScaNNANN\t\tIndexing=107ms\t\tRetrieval=512ms\t\tRecall=0.99997\n" + "ScaNNANN\t\tIndexing=1564ms\t\tRetrieval=479ms\t\tRecall=0.99997\n" ] } ], diff --git a/examples/ann_example.py b/examples/ann_example.py new file mode 100644 index 000000000..d5dc62fa8 --- /dev/null +++ b/examples/ann_example.py @@ -0,0 +1,45 @@ +# Copyright 2018 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Example for comparing different ANN Searchers with BPR model""" + +import cornac +from cornac.data import Reader +from cornac.datasets.netflix import load_feedback +from cornac.eval_methods import RatioSplit +from cornac.metrics import AUC, Recall +from cornac.models import BPR, AnnoyANN, FaissANN, HNSWLibANN, ScaNNANN + + +bpr = BPR(k=50, max_iter=200, learning_rate=0.001, lambda_reg=0.001, verbose=True) + +# using default params of the ANN searchers +# performance could be better if they are carefully tuned +ann1 = AnnoyANN(bpr, verbose=True) +ann2 = FaissANN(bpr, verbose=True) +ann3 = HNSWLibANN(bpr, verbose=True) +ann4 = ScaNNANN(bpr, verbose=True) + +cornac.Experiment( + eval_method=RatioSplit( + data=load_feedback(variant="small", reader=Reader(bin_threshold=1.0)), + test_size=0.1, + rating_threshold=1.0, + exclude_unknowns=True, + verbose=True, + ), + models=[bpr, ann1, ann2, ann3, ann4], + metrics=[AUC(), Recall(k=50)], + user_based=True, +).run() diff --git a/tutorials/ann_hnswlib.ipynb b/tutorials/ann_hnswlib.ipynb index 583b929cd..fab367036 100644 --- a/tutorials/ann_hnswlib.ipynb +++ b/tutorials/ann_hnswlib.ipynb @@ -10,8 +10,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: hnswlib in /opt/conda/lib/python3.10/site-packages (0.7.0)\n", - "Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from hnswlib) (1.26.0)\n" + "Requirement already satisfied: hnswlib in /home/ubuntu/miniconda3/envs/cornac/lib/python3.10/site-packages (0.7.0)\n", + "Requirement already satisfied: numpy in /home/ubuntu/miniconda3/envs/cornac/lib/python3.10/site-packages (from hnswlib) (1.26.2)\n" ] } ], @@ -82,7 +82,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e588c08cad71410aae17e2cc84456631", 
+ "model_id": "018ab8d62ba047868f24b1b138ecd40e", "version_major": 2, "version_minor": 0 }, @@ -105,7 +105,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "31b149c19a0549b782edb0a4bc597b94", + "model_id": "916ea093c8be41deaafdb2e5cc546b54", "version_major": 2, "version_minor": 0 }, @@ -125,7 +125,7 @@ "...\n", " | AUC | Recall@20 | Train (s) | Test (s)\n", "-- + ------ + --------- + --------- + --------\n", - "MF | 0.8530 | 0.0669 | 1.1054 | 11.9213\n", + "MF | 0.8530 | 0.0669 | 0.9060 | 6.3865\n", "\n" ] } @@ -229,8 +229,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 2min 39s, sys: 15.7 ms, total: 2min 39s\n", - "Wall time: 4.98 s\n" + "CPU times: user 1min 14s, sys: 18.1 ms, total: 1min 14s\n", + "Wall time: 1.56 s\n" ] } ], @@ -252,8 +252,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 288 ms, sys: 29 µs, total: 288 ms\n", - "Wall time: 285 ms\n" + "CPU times: user 218 ms, sys: 32 µs, total: 218 ms\n", + "Wall time: 216 ms\n" ] } ], @@ -283,7 +283,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "99.87549999999999\n" + "99.87450000000001\n" ] } ], @@ -319,7 +319,7 @@ { "data": { "text/plain": [ - "'save_dir/HNSWLibANN/2023-11-14_00-19-12-481323.pkl'" + "'save_dir/HNSWLibANN/2023-12-08_19-27-58-137671.pkl'" ] }, "execution_count": 9, @@ -391,7 +391,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "99.87549999999999\n" + "99.87450000000001\n" ] } ], @@ -423,7 +423,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.13" } }, "nbformat": 4,