Skip to content

Commit

Permalink
Merge branch 'master' into beacon
Browse files Browse the repository at this point in the history
  • Loading branch information
lthoang authored Jan 17, 2024
2 parents 0ad497c + b845e88 commit d796824
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ The recommender models supported by Cornac are listed below. Why don't you join
| | [Hybrid neural recommendation with joint deep representation learning of ratings and reviews (HRDR)](cornac/models/hrdr), [paper](https://www.sciencedirect.com/science/article/abs/pii/S0925231219313207) | [requirements.txt](cornac/models/hrdr/requirements.txt) | [hrdr_example.py](examples/hrdr_example.py)
| | [LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation](cornac/models/lightgcn), [paper](https://arxiv.org/pdf/2002.02126.pdf) | [requirements.txt](cornac/models/lightgcn/requirements.txt) | [lightgcn_example.py](examples/lightgcn_example.py)
| | [New Variational Autoencoder for Top-N Recommendations with Implicit Feedback (RecVAE)](cornac/models/recvae), [paper](https://doi.org/10.1145/3336191.3371831) | [requirements.txt](cornac/models/recvae/requirements.txt) | [recvae_example.py](examples/recvae_example.py)
| | [Recency Aware Collaborative Filtering for Next Basket Recommendation (UPCF)](cornac/models/upcf), [paper](https://dl.acm.org/doi/abs/10.1145/3340631.3394850) | [requirements.txt](cornac/models/upcf/requirements.txt) | [upcf_tafeng.py](examples/upcf_tafeng.py)
| | [Temporal-Item-Frequency-based User-KNN (TIFUKNN)](cornac/models/tifuknn), [paper](https://arxiv.org/pdf/2006.00556.pdf) | N/A | [tifuknn_tafeng.py](examples/tifuknn_tafeng.py)
| 2019 | [Correlation-Sensitive Next-Basket Recommendation (Beacon)](cornac/models/beacon), [paper](https://www.ijcai.org/proceedings/2019/0389.pdf) | [requirements.txt](cornac/models/beacon/requirements.txt) | [beacon_tafeng.py](examples/beacon_tafeng.py)
| | [Embarrassingly Shallow Autoencoders for Sparse Data (EASEᴿ)](cornac/models/ease), [paper](https://arxiv.org/pdf/1905.03375.pdf) | N/A | [ease_movielens.py](examples/ease_movielens.py)
Expand Down
1 change: 1 addition & 0 deletions cornac/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from .svd import SVD
from .tifuknn import TIFUKNN
from .trirank import TriRank
from .upcf import UPCF
from .vaecf import VAECF
from .vbpr import VBPR
from .vmf import VMF
Expand Down
16 changes: 16 additions & 0 deletions cornac/models/upcf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2023 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from .recom_upcf import UPCF
133 changes: 133 additions & 0 deletions cornac/models/upcf/recom_upcf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright 2023 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import itertools

import numpy as np
from scipy.sparse import csr_matrix, vstack

from ..recommender import NextBasketRecommender


class UPCF(NextBasketRecommender):
"""User Popularity-based CF (UPCF)
Parameters
----------
name: string, default: 'UPCF'
The name of the recommender model.
recency: int, optional, default: 1
The size of recency window.
If 0, all baskets will be used.
locality: int, optional, default: 1
The strength we enforce the similarity between two items within a basket
asymmetry: float, optional, default: 0.25
Trade-off parameter which balances the importance of the probability of having item i given j and probability having item j given i.
This value will be computed via `similaripy.asymetric_cosine`.
verbose: boolean, optional, default: False
When True, running logs are displayed.
References
----------
Guglielmo Faggioli, Mirko Polato, and Fabio Aiolli. 2020.
Recency Aware Collaborative Filtering for Next Basket Recommendation.
In Proceedings of the 28th ACM Conference on User Modeling, Adaptation and Personalization (UMAP '20). Association for Computing Machinery, New York, NY, USA, 80–87. https://doi.org/10.1145/3340631.3394850
"""

def __init__(
self,
name="UPCF",
recency=1,
locality=1,
asymmetry=0.25,
verbose=False,
):
super().__init__(name=name, trainable=False, verbose=verbose)
self.recency = recency
self.locality = locality
self.asymmetry = asymmetry

def fit(self, train_set, val_set=None):
super().fit(train_set=train_set, val_set=val_set)
self.user_wise_popularity = vstack(
[
self._get_user_wise_popularity(basket_items)
for _, _, [basket_items] in train_set.ubi_iter(
batch_size=1, shuffle=False
)
]
)
(u_indices, i_indices, r_values) = train_set.uir_tuple
self.user_item_matrix = csr_matrix(
(r_values, (u_indices, i_indices)),
shape=(train_set.num_users, self.total_items),
dtype="float32",
)
return self

def _get_user_wise_popularity(self, basket_items):
users = []
items = []
scores = []
recent_basket_items = (
basket_items[-self.recency :] if self.recency > 0 else basket_items
)
for iid in list(set(itertools.chain.from_iterable(recent_basket_items))):
users.append(0)
items.append(iid)
denominator = (
min(self.recency, len(recent_basket_items))
if self.recency > 0
else len(recent_basket_items)
)
numerator = sum([1 for items in recent_basket_items if iid in items])
scores.append(numerator / denominator)
return csr_matrix(
(scores, (users, items)), shape=(1, self.total_items), dtype="float32"
)

def score(self, user_idx, history_baskets, **kwargs):
import similaripy as sim

items = list(set(itertools.chain.from_iterable(history_baskets)))
current_user_item_matrix = csr_matrix(
(np.ones(len(items)), (np.zeros(len(items)), items)),
shape=(1, self.total_items),
dtype="float32",
)
current_user_wise_popularity = self._get_user_wise_popularity(history_baskets)
user_wise_popularity = vstack(
[current_user_wise_popularity, self.user_wise_popularity]
)
user_item_matrix = vstack([current_user_item_matrix, self.user_item_matrix])
user_sim = sim.asymmetric_cosine(
user_item_matrix, alpha=self.asymmetry, target_rows=[0], verbose=False
)
scores = (
sim.dot_product(
user_sim.power(self.locality).tocsr()[0],
user_wise_popularity,
verbose=False,
)
.toarray()
.squeeze()
)

return scores
1 change: 1 addition & 0 deletions cornac/models/upcf/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
similaripy==0.1.3
2 changes: 2 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,5 @@
[beacon_tafeng.py](beacon_tafeng.py) - Correlation-Sensitive Next-Basket Recommendation (Beacon).

[tifuknn_tafeng.py](tifuknn_tafeng.py) - Example of Temporal-Item-Frequency-based User-KNN (TIFUKNN).

[upcf_tafeng.py](upcf_tafeng.py) - Example of Recency Aware Collaborative Filtering for Next Basket Recommendation (UPCF).
50 changes: 50 additions & 0 deletions examples/upcf_tafeng.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2023 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Example of Recency Aware Collaborative Filtering for Next Basket Recommendation (UPCF)"""

import cornac
from cornac.eval_methods import NextBasketEvaluation
from cornac.metrics import NDCG, HitRatio, Recall
from cornac.models import UPCF

data = cornac.datasets.tafeng.load_basket(
reader=cornac.data.Reader(
min_basket_size=3, max_basket_size=50, min_basket_sequence=2
)
)

next_basket_eval = NextBasketEvaluation(
data=data, fmt="UBITJson", test_size=0.2, val_size=0.08, seed=123, verbose=True
)

models = [
UPCF(
recency=1,
locality=1,
asymmetry=0.25,
verbose=True,
)
]

metrics = [
Recall(k=10),
Recall(k=50),
NDCG(k=10),
NDCG(k=50),
HitRatio(k=10),
HitRatio(k=50),
]

cornac.Experiment(eval_method=next_basket_eval, models=models, metrics=metrics).run()

0 comments on commit d796824

Please sign in to comment.