diff --git a/README.md b/README.md index 523517b4..35bb60b9 100644 --- a/README.md +++ b/README.md @@ -153,9 +153,10 @@ The recommender models supported by Cornac are listed below. Why don't you join | | [Hybrid neural recommendation with joint deep representation learning of ratings and reviews (HRDR)](cornac/models/hrdr), [paper](https://www.sciencedirect.com/science/article/abs/pii/S0925231219313207) | [requirements.txt](cornac/models/hrdr/requirements.txt) | [hrdr_example.py](examples/hrdr_example.py) | | [LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation](cornac/models/lightgcn), [paper](https://arxiv.org/pdf/2002.02126.pdf) | [requirements.txt](cornac/models/lightgcn/requirements.txt) | [lightgcn_example.py](examples/lightgcn_example.py) | | [New Variational Autoencoder for Top-N Recommendations with Implicit Feedback (RecVAE)](cornac/models/recvae), [paper](https://doi.org/10.1145/3336191.3371831) | [requirements.txt](cornac/models/recvae/requirements.txt) | [recvae_example.py](examples/recvae_example.py) -| | [Temporal-Item-Frequency-based User-KNN (TIFUKNN)](cornac/models/tifuknn), [paper](https://arxiv.org/pdf/2006.00556.pdf) | N/A | [tifuknn_tafeng.py](examples/tifuknn_tafeng.py) | | [Recency Aware Collaborative Filtering for Next Basket Recommendation (UPCF)](cornac/models/upcf), [paper](https://dl.acm.org/doi/abs/10.1145/3340631.3394850) | [requirements.txt](cornac/models/upcf/requirements.txt) | [upcf_tafeng.py](examples/upcf_tafeng.py) -| 2019 | [Embarrassingly Shallow Autoencoders for Sparse Data (EASEᴿ)](cornac/models/ease), [paper](https://arxiv.org/pdf/1905.03375.pdf) | N/A | [ease_movielens.py](examples/ease_movielens.py) +| | [Temporal-Item-Frequency-based User-KNN (TIFUKNN)](cornac/models/tifuknn), [paper](https://arxiv.org/pdf/2006.00556.pdf) | N/A | [tifuknn_tafeng.py](examples/tifuknn_tafeng.py) +| 2019 | [Correlation-Sensitive Next-Basket Recommendation (Beacon)](cornac/models/beacon), [paper](https://www.ijcai.org/proceedings/2019/0389.pdf) | [requirements.txt](cornac/models/beacon/requirements.txt) | [beacon_tafeng.py](examples/beacon_tafeng.py) +| | [Embarrassingly Shallow Autoencoders for Sparse Data (EASEᴿ)](cornac/models/ease), [paper](https://arxiv.org/pdf/1905.03375.pdf) | N/A | [ease_movielens.py](examples/ease_movielens.py) | | [Neural Graph Collaborative Filtering (NGCF)](cornac/models/ngcf), [paper](https://arxiv.org/pdf/1905.08108.pdf) | [requirements.txt](cornac/models/ngcf/requirements.txt) | [ngcf_example.py](examples/ngcf_example.py) | 2018 | [Collaborative Context Poisson Factorization (C2PF)](cornac/models/c2pf), [paper](https://www.ijcai.org/proceedings/2018/0370.pdf) | N/A | [c2pf_exp.py](examples/c2pf_example.py) | | [Graph Convolutional Matrix Completion (GCMC)](cornac/models/gcmc), [paper](https://www.kdd.org/kdd2018/files/deep-learning-day/DLDay18_paper_32.pdf) | [requirements.txt](cornac/models/gcmc/requirements.txt) | [gcmc_example.py](examples/gcmc_example.py) diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py index 49a85c5d..e87f4aa4 100644 --- a/cornac/models/__init__.py +++ b/cornac/models/__init__.py @@ -23,6 +23,7 @@ from .ann import HNSWLibANN from .ann import ScaNNANN from .baseline_only import BaselineOnly +from .beacon import Beacon from .bivaecf import BiVAECF from .bpr import BPR from .bpr import WBPR diff --git a/cornac/models/beacon/__init__.py b/cornac/models/beacon/__init__.py new file mode 100644 index 00000000..1b5763fc --- /dev/null +++ b/cornac/models/beacon/__init__.py @@ 
-0,0 +1,16 @@ +# Copyright 2023 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from .recom_beacon import Beacon diff --git a/cornac/models/beacon/beacon_tf.py b/cornac/models/beacon/beacon_tf.py new file mode 100644 index 00000000..2b3e2fa8 --- /dev/null +++ b/cornac/models/beacon/beacon_tf.py @@ -0,0 +1,302 @@ +import numpy as np +import warnings + +# silence noisy TensorFlow deprecated-API warnings +warnings.filterwarnings("ignore", category=UserWarning) + +import tensorflow.compat.v1 as tf + +tf.logging.set_verbosity(tf.logging.ERROR) +tf.disable_v2_behavior() + + +def create_rnn_cell(cell_type, state_size, default_initializer, reuse=None): + if cell_type == "GRU": + return tf.nn.rnn_cell.GRUCell(state_size, activation=tf.nn.tanh, reuse=reuse) + elif cell_type == "LSTM": + return tf.nn.rnn_cell.LSTMCell( + state_size, + initializer=default_initializer, + activation=tf.nn.tanh, + reuse=reuse, + ) + else: + return tf.nn.rnn_cell.BasicRNNCell( + state_size, activation=tf.nn.tanh, reuse=reuse + ) + + +def create_rnn_encoder( + x, + rnn_units, + dropout_rate, + seq_length, + rnn_cell_type, + param_initializer, + seed, + reuse=None, +): + with tf.variable_scope("RNN_Encoder", reuse=reuse): + rnn_cell = create_rnn_cell(rnn_cell_type, rnn_units, param_initializer) + rnn_cell = tf.nn.rnn_cell.DropoutWrapper( + rnn_cell, input_keep_prob=1 - dropout_rate, seed=seed + ) + init_state = rnn_cell.zero_state(tf.shape(x)[0], tf.float32) + # RNN encoder: iteratively compute the outputs of the recurrent network + rnn_outputs, _ = tf.nn.dynamic_rnn( + rnn_cell, + x, + initial_state=init_state, + sequence_length=seq_length, + dtype=tf.float32, + ) + return rnn_outputs + + +def create_basket_encoder( + x, + dense_units, + param_initializer, + activation_func=None, + name="Basket_Encoder", + reuse=None, +): + with tf.variable_scope(name, reuse=reuse): + return tf.layers.dense( + x, + dense_units, + kernel_initializer=param_initializer, + bias_initializer=tf.zeros_initializer, + activation=activation_func, + ) + + +def get_last_right_output(full_output, max_length, actual_length, rnn_units): + batch_size = tf.shape(full_output)[0] + # Start indices for each sample + index = tf.range(0, batch_size) * max_length + (actual_length - 1) + # Indexing + return tf.gather(tf.reshape(full_output, [-1, rnn_units]), index) + + +class BeaconModel: + def __init__( + self, + sess, + emb_dim, + rnn_units, + alpha, + max_seq_length, + n_items, + item_probs, + adj_matrix, + rnn_cell_type, + rnn_dropout_rate, + seed, + lr, + ): + self.scope = "GRN" + self.session = sess + self.seed = seed + self.lr = tf.constant(lr) + + self.emb_dim = emb_dim + self.rnn_units = rnn_units + + self.max_seq_length = max_seq_length + self.n_items = n_items + self.item_probs = item_probs + self.alpha = alpha + + with tf.variable_scope(self.scope): + # Initialize the n-hop adjacency matrix + self.A = tf.constant( + adj_matrix.todense(),
name="Adj_Matrix", dtype=tf.float32 + ) + + uniform_initializer = ( + np.ones(shape=(self.n_items), dtype=np.float32) / self.n_items + ) + self.I_B = tf.get_variable( + dtype=tf.float32, + initializer=tf.constant(uniform_initializer, dtype=tf.float32), + name="I_B", + ) + self.I_B_Diag = tf.nn.relu(tf.diag(self.I_B, name="I_B_Diag")) + + self.C_Basket = tf.get_variable( + dtype=tf.float32, initializer=tf.constant(adj_matrix.mean()), name="C_B" + ) + self.y = tf.placeholder( + dtype=tf.float32, + shape=(None, self.n_items), + name="Target_basket", + ) + + # Basket Sequence encoder + with tf.name_scope("Basket_Sequence_Encoder"): + self.bseq = tf.sparse.placeholder( + dtype=tf.float32, + name="bseq_input", + ) + self.bseq_length = tf.placeholder( + dtype=tf.int32, shape=(None,), name="bseq_length" + ) + + self.bseq_encoder = tf.sparse.reshape( + self.bseq, shape=[-1, self.n_items], name="bseq_2d" + ) + self.bseq_encoder = self.encode_basket_graph( + self.bseq_encoder, self.C_Basket, True + ) + self.bseq_encoder = tf.reshape( + self.bseq_encoder, + shape=[-1, self.max_seq_length, self.n_items], + name="bsxMxN", + ) + self.bseq_encoder = create_basket_encoder( + self.bseq_encoder, + emb_dim, + param_initializer=tf.initializers.he_uniform(), + activation_func=tf.nn.relu, + ) + + # batch_size x max_seq_length x H + rnn_encoder = create_rnn_encoder( + self.bseq_encoder, + self.rnn_units, + rnn_dropout_rate, + self.bseq_length, + rnn_cell_type, + param_initializer=tf.initializers.glorot_uniform(), + seed=self.seed, + ) + + # Hack to build the indexing and retrieve the right output. # batch_size x H + h_T = get_last_right_output( + rnn_encoder, self.max_seq_length, self.bseq_length, self.rnn_units + ) + + # Next basket estimation + with tf.name_scope("Next_Basket"): + W_H = tf.get_variable( + dtype=tf.float32, + initializer=tf.initializers.glorot_uniform(), + shape=(self.rnn_units, self.n_items), + name="W_H", + ) + + next_item_probs = tf.nn.sigmoid(tf.matmul(h_T, W_H)) + logits = ( + 1.0 - self.alpha + ) * next_item_probs + self.alpha * self.encode_basket_graph( + next_item_probs, tf.constant(0.0) + ) + + with tf.name_scope("Loss"): + self.loss = self.compute_loss(logits, self.y) + + self.predictions = tf.nn.sigmoid(logits) + + # Adam optimizer + train_op = tf.train.RMSPropOptimizer(learning_rate=self.lr) + + # Op to calculate every variable gradient + self.grads = train_op.compute_gradients(self.loss, tf.trainable_variables()) + self.update_grads = train_op.apply_gradients(self.grads) + + def train_batch(self, s, s_length, y): + bseq_indices, bseq_values, bseq_shape = self.get_sparse_tensor_info(s, True) + + [_, loss] = self.session.run( + [self.update_grads, self.loss], + feed_dict={ + self.bseq: (bseq_indices, bseq_values, bseq_shape), + self.bseq_length: s_length, + self.y: y, + }, + ) + + return loss + + def validate_batch(self, s, s_length, y): + bseq_indices, bseq_values, bseq_shape = self.get_sparse_tensor_info(s, True) + + loss = self.session.run( + self.loss, + feed_dict={ + self.bseq: (bseq_indices, bseq_values, bseq_shape), + self.bseq_length: s_length, + self.y: y, + }, + ) + return loss + + def predict(self, s, s_length): + bseq_indices, bseq_values, bseq_shape = self.get_sparse_tensor_info(s, True) + predictions = self.session.run( + self.predictions, + feed_dict={ + self.bseq: (bseq_indices, bseq_values, bseq_shape), + self.bseq_length: s_length, + }, + ) + return predictions.squeeze() + + def encode_basket_graph(self, binput, beta, is_sparse=False): + with 
tf.name_scope("Graph_Encoder"): + if is_sparse: + encoder = tf.sparse_tensor_dense_matmul( + binput, self.I_B_Diag, name="XxI_B" + ) + encoder += self.relu_with_threshold( + tf.sparse_tensor_dense_matmul(binput, self.A, name="XxA"), beta + ) + else: + encoder = tf.matmul(binput, self.I_B_Diag, name="XxI_B") + encoder += self.relu_with_threshold( + tf.matmul(binput, self.A, name="XxA"), beta + ) + return encoder + + def get_sparse_tensor_info(self, x, is_bseq=False): + indices = [] + if is_bseq: + for sid, bseq in enumerate(x): + for t, basket in enumerate(bseq): + for item_id in basket: + indices.append([sid, t, item_id]) + else: + for bid, basket in enumerate(x): + for item_id in basket: + indices.append([bid, item_id]) + + values = np.ones(len(indices), dtype=np.float32) + indices = np.array(indices, dtype=np.int32) + shape = np.array([len(x), self.max_seq_length, self.n_items], dtype=np.int64) + return indices, values, shape + + def compute_loss(self, logits, y): + sigmoid_logits = tf.nn.sigmoid(logits) + + neg_y = 1.0 - y + pos_logits = y * logits + + pos_max = tf.reduce_max(pos_logits, axis=1) + pos_max = tf.expand_dims(pos_max, axis=-1) + + pos_min = tf.reduce_min(pos_logits + neg_y * pos_max, axis=1) + pos_min = tf.expand_dims(pos_min, axis=-1) + + nb_pos, nb_neg = tf.count_nonzero(y, axis=1), tf.count_nonzero(neg_y, axis=1) + ratio = tf.cast(nb_neg, dtype=tf.float32) / tf.cast(nb_pos, dtype=tf.float32) + + pos_weight = tf.expand_dims(ratio, axis=-1) + loss = y * -tf.log(sigmoid_logits) * pos_weight + neg_y * -tf.log( + 1.0 - tf.nn.sigmoid(logits - pos_min) + ) + + return tf.reduce_mean(loss + 1e-8) + + def relu_with_threshold(self, x, threshold): + return tf.nn.relu(x - tf.abs(threshold)) diff --git a/cornac/models/beacon/recom_beacon.py b/cornac/models/beacon/recom_beacon.py new file mode 100644 index 00000000..1838a2e3 --- /dev/null +++ b/cornac/models/beacon/recom_beacon.py @@ -0,0 +1,290 @@ +# Copyright 2023 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +from collections import Counter + +import numpy as np +from scipy.sparse import csc_matrix, csr_matrix, diags +from tqdm.auto import trange + +from ..recommender import NextBasketRecommender + + +class Beacon(NextBasketRecommender): + """Correlation-Sensitive Next-Basket Recommendation + + Parameters + ---------- + name: string, default: 'Beacon' + The name of the recommender model. + + emb_dim: int, optional, default: 2 + Embedding dimension + + rnn_unit: int, optional, default: 4 + Number of dimension in a rnn unit. + + alpha: float, optional, default: 0.5 + Hyperparameter to control the balance between correlative and sequential associations. + + rnn_cell_type: str, optional, default: 'LSTM' + RNN cell type, options including ['LSTM', 'GRU', None] + If None, BasicRNNCell will be used. 
+ + dropout_rate: float, optional, default: 0.5 + Dropout rate applied to the RNN input. + + nb_hop: int, optional, default: 1 + Number of hops for constructing the correlation matrix. + If 0, a zero matrix will be used. + + n_epochs: int, optional, default: 15 + Number of training epochs + + batch_size: int, optional, default: 32 + Batch size + + lr: float, optional, default: 0.001 + Initial learning rate for the optimizer. + + verbose: boolean, optional, default: False + When True, running logs are displayed. + + seed: int, optional, default: None + Random seed + + References + ---------- + LE, Duc Trong, Hady Wirawan LAUW, and Yuan Fang. + Correlation-sensitive next-basket recommendation. + International Joint Conferences on Artificial Intelligence, 2019. + + """ + + def __init__( + self, + name="Beacon", + emb_dim=2, + rnn_unit=4, + alpha=0.5, + rnn_cell_type="LSTM", + dropout_rate=0.5, + nb_hop=1, + n_epochs=15, + batch_size=32, + lr=0.001, + trainable=True, + verbose=False, + seed=None, + ): + super().__init__(name=name, trainable=trainable, verbose=verbose) + self.n_epochs = n_epochs + self.batch_size = batch_size + self.nb_hop = nb_hop + self.emb_dim = emb_dim + self.rnn_unit = rnn_unit + self.alpha = alpha + self.rnn_cell_type = rnn_cell_type + self.dropout_rate = dropout_rate + self.seed = seed + self.lr = lr + + def fit(self, train_set, val_set=None): + import tensorflow.compat.v1 as tf + + from .beacon_tf import BeaconModel + + tf.disable_eager_execution() + + # less verbose TF + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + tf.logging.set_verbosity(tf.logging.ERROR) + + super().fit(train_set=train_set, val_set=val_set) + + self.correlation_matrix = self._build_correlation_matrix( + train_set=train_set, val_set=val_set, n_items=self.total_items + ) + self.item_probs = self._compute_item_probs( + train_set=train_set, val_set=val_set, n_items=self.total_items + ) + + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + config.log_device_placement = False + sess = tf.Session(config=config) + + self.model = BeaconModel( + sess, + self.emb_dim, + self.rnn_unit, + self.alpha, + train_set.max_basket_size, + self.total_items, + self.item_probs, + self.correlation_matrix, + self.rnn_cell_type, + self.dropout_rate, + self.seed, + self.lr, + ) + + sess.run(tf.global_variables_initializer()) # initialize variables + + last_loss = np.inf + last_val_loss = np.inf + loop = trange(self.n_epochs, disable=not self.verbose) + loop.set_postfix( + loss=last_loss, + val_loss=last_val_loss, + ) + train_pool = [] + validation_pool = [] + for _ in loop: + train_loss = 0.0 + trained_cnt = 0 + for batch_basket_items in self._data_iter( + train_set, shuffle=True, current_pool=train_pool + ): + s, s_length, y = self._transform_data( + batch_basket_items, self.total_items + ) + loss = self.model.train_batch(s, s_length, y) + current_batch_size = len(batch_basket_items) + trained_cnt += current_batch_size + train_loss += loss * current_batch_size + last_loss = train_loss / trained_cnt + loop.set_postfix( + loss=last_loss, + val_loss=last_val_loss, + ) + + if val_set is not None: + val_loss = 0.0 + val_cnt = 0 + for batch_basket_items in self._data_iter( + val_set, shuffle=False, current_pool=validation_pool + ): + s, s_length, y = self._transform_data( + batch_basket_items, self.total_items + ) + loss = self.model.validate_batch(s, s_length, y) + current_batch_size = len(batch_basket_items) + val_cnt += current_batch_size + val_loss += loss * current_batch_size + last_val_loss = val_loss /
val_cnt + loop.set_postfix( + loss=last_loss, + val_loss=last_val_loss, + ) + + return self + + def _data_iter(self, data_set, shuffle=False, current_pool=[]): + """This iterator ensures every batch has the same size; leftover baskets are carried over and processed in the next epoch""" + for _, _, batch_basket_items in data_set.ubi_iter( + batch_size=self.batch_size, shuffle=shuffle + ): + current_pool += batch_basket_items + if len(current_pool) >= self.batch_size: + yield current_pool[: self.batch_size] + del current_pool[self.batch_size :] + + def _transform_data(self, batch_basket_items, n_items): + assert len(batch_basket_items) == self.batch_size + s = [basket_items[:-1] for basket_items in batch_basket_items] + s_length = [len(b) for b in s] + y = np.zeros((self.batch_size, n_items), dtype="int32") + for inc, basket_items in enumerate(batch_basket_items): + y[inc, basket_items[-1]] = 1 + return s, s_length, y + + def _build_correlation_matrix(self, train_set, val_set, n_items): + if self.nb_hop == 0: + return csr_matrix((n_items, n_items), dtype="float32") + + pairs_cnt = Counter() + for _, _, [basket_items] in train_set.ubi_iter(1, shuffle=False): + for items in basket_items: + current_items = np.unique(items) + for i in range(len(current_items) - 1): + for j in range(i + 1, len(current_items)): + pairs_cnt[(current_items[i], current_items[j])] += 1 + if val_set is not None: + for _, _, [basket_items] in val_set.ubi_iter(1, shuffle=False): + for items in basket_items: + current_items = np.unique(items) + for i in range(len(current_items) - 1): + for j in range(i + 1, len(current_items)): + pairs_cnt[(current_items[i], current_items[j])] += 1 + data, row, col = [], [], [] + for pair, cnt in pairs_cnt.most_common(): + data.append(cnt) + row.append(pair[0]) + col.append(pair[1]) + correlation_matrix = csc_matrix( + (data, (row, col)), shape=(n_items, n_items), dtype="float32" + ) + correlation_matrix = self._normalize(correlation_matrix) + + w_mul = correlation_matrix + coeff = 1.0 + for _ in range(1, self.nb_hop): + coeff *= 0.85 + w_mul *= correlation_matrix + w_mul = self._remove_diag(w_mul) + w_adj_matrix = self._normalize(w_mul) + correlation_matrix += coeff * w_adj_matrix + + return correlation_matrix + + def _remove_diag(self, adj_matrix): + new_adj_matrix = csr_matrix(adj_matrix) + new_adj_matrix.setdiag(0.0) + new_adj_matrix.eliminate_zeros() + return new_adj_matrix + + def _normalize(self, adj_matrix: csr_matrix): + """Symmetrically normalize adjacency matrix.""" + row_sum = adj_matrix.sum(1).A.squeeze() + d_inv_sqrt = np.power( + row_sum, + -0.5, + out=np.zeros_like(row_sum, dtype="float32"), + where=row_sum != 0, + ) + d_mat_inv_sqrt = diags(d_inv_sqrt) + + normalized_matrix = ( + adj_matrix.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt) + ) + + return normalized_matrix.tocsr() + + def _compute_item_probs(self, train_set, val_set, n_items): + item_freq = Counter(train_set.uir_tuple[1]) + if val_set is not None: + item_freq += Counter(val_set.uir_tuple[1]) + item_probs = np.zeros(n_items, dtype="float32") + total_cnt = sum(item_freq.values()) + for iid, cnt in item_freq.items(): + item_probs[iid] = cnt / total_cnt + return item_probs + + def score(self, user_idx, history_baskets, **kwargs): + s = [history_baskets] + s_length = [len(history_baskets)] + return self.model.predict(s, s_length) diff --git a/cornac/models/beacon/requirements.txt b/cornac/models/beacon/requirements.txt new file mode 100644 index 00000000..afc544f5 --- /dev/null +++
b/cornac/models/beacon/requirements.txt @@ -0,0 +1 @@ +tensorflow[and-cuda]==2.15.0 diff --git a/docs/source/api_ref/models.rst b/docs/source/api_ref/models.rst index 5006dd84..94b3a01a 100644 --- a/docs/source/api_ref/models.rst +++ b/docs/source/api_ref/models.rst @@ -54,6 +54,11 @@ Temporal-Item-Frequency-based User-KNN (TIFUKNN) .. automodule:: cornac.models.tifuknn.recom_tifuknn :members: +Correlation-Sensitive Next-Basket Recommendation (Beacon) +---------------------------------------------------------- +.. automodule:: cornac.models.beacon.recom_beacon + :members: + Embarrassingly Shallow Autoencoders for Sparse Data (EASEᴿ) ----------------------------------------------------------- .. automodule:: cornac.models.ease.recom_ease diff --git a/examples/README.md b/examples/README.md index 99b7606e..f7812f63 100644 --- a/examples/README.md +++ b/examples/README.md @@ -120,6 +120,8 @@ [gp_top_tafeng.py](gp_top_tafeng.py) - Next-basket recommendation model that merely uses item top frequency. +[beacon_tafeng.py](beacon_tafeng.py) - Example of Correlation-Sensitive Next-Basket Recommendation (Beacon). + [tifuknn_tafeng.py](tifuknn_tafeng.py) - Example of Temporal-Item-Frequency-based User-KNN (TIFUKNN). [upcf_tafeng.py](upcf_tafeng.py) - Example of Recency Aware Collaborative Filtering for Next Basket Recommendation (UPCF). diff --git a/examples/beacon_tafeng.py b/examples/beacon_tafeng.py new file mode 100644 index 00000000..9c938bbd --- /dev/null +++ b/examples/beacon_tafeng.py @@ -0,0 +1,56 @@ +# Copyright 2023 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Example of Correlation-Sensitive Next-Basket Recommendation Model (Beacon)""" + +import cornac +from cornac.eval_methods import NextBasketEvaluation +from cornac.metrics import NDCG, HitRatio, Recall +from cornac.models import Beacon + +data = cornac.datasets.tafeng.load_basket( + reader=cornac.data.Reader( + min_basket_size=3, max_basket_size=50, min_basket_sequence=2 + ) +) + +next_basket_eval = NextBasketEvaluation( + data=data, fmt="UBITJson", test_size=0.2, val_size=0.08, seed=123, verbose=True +) + +models = [ + Beacon( + emb_dim=2, + rnn_unit=4, + alpha=0.5, + rnn_cell_type="LSTM", + dropout_rate=0.5, + nb_hop=1, + n_epochs=15, + batch_size=32, + lr=0.001, + verbose=True, + ) +] + +metrics = [ + Recall(k=10), + Recall(k=50), + NDCG(k=10), + NDCG(k=50), + HitRatio(k=10), + HitRatio(k=50), +] + +cornac.Experiment(eval_method=next_basket_eval, models=models, metrics=metrics).run()
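A note on the correlation matrix: `_build_correlation_matrix` counts within-basket item co-occurrences, and `_normalize` applies the usual symmetric normalization A_hat = D^{-1/2} A D^{-1/2} before higher hops are mixed in with a 0.85 decay per hop. Below is a minimal standalone sketch of that normalization step; the 3-item co-occurrence matrix is made up for illustration and is not part of the diff.

```python
# Sketch of the symmetric normalization in Beacon._normalize:
# A_hat = D^{-1/2} A D^{-1/2}, with D the diagonal degree matrix.
import numpy as np
from scipy.sparse import csr_matrix, diags

# made-up co-occurrence counts for 3 items
A = csr_matrix(np.array([[0, 2, 1], [2, 0, 0], [1, 0, 0]], dtype="float32"))

row_sum = A.sum(1).A.squeeze()  # degree (total co-occurrence count) per item
d_inv_sqrt = np.power(
    row_sum, -0.5, out=np.zeros_like(row_sum), where=row_sum != 0
)
d_mat_inv_sqrt = diags(d_inv_sqrt)

A_hat = d_mat_inv_sqrt @ A @ d_mat_inv_sqrt  # scale rows and columns
print(A_hat.toarray())
```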
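The objective in `BeaconModel.compute_loss` is easier to read outside the TF1 graph code. The sketch below restates the same computation in plain NumPy, as I read it: positives pay a log-loss weighted by each example's negative-to-positive ratio, and every negative is penalized through a sigmoid shifted by the example's smallest positive logit. The batch at the bottom is made up for illustration.

```python
# NumPy restatement (illustrative only) of BeaconModel.compute_loss.
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def beacon_loss(logits, y, eps=1e-8):
    neg_y = 1.0 - y
    pos_logits = y * logits
    pos_max = pos_logits.max(axis=1, keepdims=True)
    # smallest positive logit per example (negatives masked with pos_max)
    pos_min = (pos_logits + neg_y * pos_max).min(axis=1, keepdims=True)
    # class-balance weight: number of negatives / number of positives
    pos_weight = neg_y.sum(axis=1, keepdims=True) / y.sum(axis=1, keepdims=True)
    loss = y * -np.log(sigmoid(logits)) * pos_weight + neg_y * -np.log(
        1.0 - sigmoid(logits - pos_min)
    )
    return (loss + eps).mean()

# made-up batch: 2 examples, 4 items; 1s in y mark the target basket
logits = np.array([[2.0, -1.0, 0.5, -2.0], [0.1, 1.5, -0.5, 0.0]])
y = np.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
print(beacon_loss(logits, y))
```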
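Beyond the `Experiment` pipeline in `beacon_tafeng.py` above, a fitted model can also be queried directly through `score()`. A hypothetical follow-on snippet; the basket contents are placeholders, and item indices must follow the training set's internal mapping.

```python
import numpy as np

model = models[0]  # the Beacon instance, fitted by the Experiment above

# made-up history: the user's past baskets (item indices), oldest first
scores = model.score(user_idx=0, history_baskets=[[1, 4, 7], [2, 4]])

top_10 = np.argsort(-scores)[:10]  # indices of the 10 highest-scoring items
print(top_10)
```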