From a9b93cf72a792a4282d2c5bb16fcb47970cc8721 Mon Sep 17 00:00:00 2001
From: Andreas Peintner
Date: Wed, 2 Oct 2024 11:00:16 +0200
Subject: [PATCH 1/2] init time based split

---
Review note: time_based_split now orders each group's row indices by timestamp
before masking (the masks were misaligned whenever inter_feat was not already
time-sorted) and validates ``ratios``. Hunk sizes and the diffstat were updated
to match; the abbreviated blob id on the dataset.py index line predates this
revision.

 config.yaml                     |  6 ++++
 recbole/data/dataset/dataset.py | 82 +++++++++++++++++++++++++++++++++
 2 files changed, 88 insertions(+)
 create mode 100644 config.yaml

diff --git a/config.yaml b/config.yaml
new file mode 100644
index 000000000..60d521f2a
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,6 @@
+model: GRU4Rec
+eval_args:
+  split: {'TS': [0.8, 0.1, 0.1]}
+  #split: {'RS': [0.8, 0.1, 0.1]}
+  mode: full
+  order: TO
\ No newline at end of file
diff --git a/recbole/data/dataset/dataset.py b/recbole/data/dataset/dataset.py
index 35fce89c6..9dbbe23cf 100644
--- a/recbole/data/dataset/dataset.py
+++ b/recbole/data/dataset/dataset.py
@@ -1729,6 +1729,81 @@ def leave_one_out(self, group_by, leave_one_mode):
         next_ds = [self.copy(_) for _ in next_df]
         return next_ds
 
+    def time_based_split(self, ratios, group_by):
+        """Split interaction records by a global time boundary and a per-group ratio split.
+
+        The newest share of all interactions (``ratios[2]`` after normalization), measured
+        on the global timeline, forms the test set. Within every group, the interactions at
+        or before that boundary are ordered by time and divided into training and validation
+        parts in the proportion ``ratios[0] : ratios[1]``.
+
+        Args:
+            ratios (list): Three non-negative split ratios ``[train, valid, test]``.
+                They are normalized so that they sum to one.
+            group_by (str): Field name that interaction records should be grouped by before splitting.
+
+        Returns:
+            list: List of :class:`~Dataset`, whose interaction features have been split.
+
+        Raises:
+            ValueError: If ``group_by`` is ``None``, the time field is missing from
+                ``inter_feat``, or ``ratios`` is not three non-negative numbers with a
+                positive sum.
+        """
+        if group_by is None:
+            raise ValueError("Time-based split strategy requires a group field")
+
+        if self.time_field not in self.inter_feat:
+            raise ValueError(f"Field [{self.time_field}] is not in inter_feat.")
+
+        if len(ratios) != 3 or min(ratios) < 0 or sum(ratios) <= 0:
+            raise ValueError(
+                f"Time-based split needs three non-negative ratios with a positive sum, got [{ratios}]."
+            )
+
+        self.logger.debug(f"time-based split with ratios [{ratios}], group_by=[{group_by}]")
+        tot_ratio = sum(ratios)
+        ratios = [_ / tot_ratio for _ in ratios]
+
+        # Global temporal boundary: interactions newer than it go to the test set.
+        all_times = self.inter_feat[self.time_field].numpy()
+        global_temporal_boundary = np.percentile(all_times, 100 * (1 - ratios[2]))
+
+        # Share of each group's pre-boundary interactions that is used for training.
+        train_valid_ratio = ratios[0] + ratios[1]
+        train_share = ratios[0] / train_valid_ratio if train_valid_ratio > 0 else 0.0
+
+        train_index, valid_index, test_index = [], [], []
+        grouped_inter_feat_index = self._grouped_index(self.inter_feat[group_by].numpy())
+
+        for grouped_index in grouped_inter_feat_index:
+            # Reorder the row indices themselves by time so that the mask below stays
+            # aligned with them even when inter_feat is not already sorted by time.
+            grouped_index = np.array(grouped_index)
+            times = all_times[grouped_index]
+            time_order = np.argsort(times, kind="stable")
+            grouped_index = grouped_index[time_order]
+            times = times[time_order]
+
+            # Split into training/validation and test sets based on the global temporal boundary
+            train_valid_mask = times <= global_temporal_boundary
+            train_valid_index = grouped_index[train_valid_mask]
+            test_user_indices = grouped_index[~train_valid_mask]
+
+            split_point = int(len(train_valid_index) * train_share)
+            train_index.extend(train_valid_index[:split_point])
+            valid_index.extend(train_valid_index[split_point:])
+            test_index.extend(test_user_indices)
+
+        self._drop_unused_col()
+        next_df = [
+            self.inter_feat[train_index],
+            self.inter_feat[valid_index],
+            self.inter_feat[test_index],
+        ]
+        next_ds = [self.copy(_) for _ in next_df]
+        return next_ds
+
     def shuffle(self):
         """Shuffle the interaction records inplace."""
         self.inter_feat.shuffle()
@@ -1799,6 +1874,13 @@ def build(self):
             datasets = self.leave_one_out(
                 group_by=self.uid_field, leave_one_mode=split_args["LS"]
             )
+        elif split_mode == "TS":
+            if not isinstance(split_args["TS"], list):
+                raise ValueError(f'The value of "TS" [{split_args}] should be a list.')
+            datasets = self.time_based_split(
+                ratios=split_args["TS"],
+                group_by=self.uid_field,
+            )
         else:
             raise NotImplementedError(
                 f"The splitting_method [{split_mode}] has not been implemented."
From 7837989cf12c49d6a78173bd7b61b97ba04bd8d1 Mon Sep 17 00:00:00 2001
From: Andreas Peintner
Date: Wed, 2 Oct 2024 13:21:57 +0200
Subject: [PATCH 2/2] remove test config file

---
 config.yaml | 6 ------
 1 file changed, 6 deletions(-)
 delete mode 100644 config.yaml

diff --git a/config.yaml b/config.yaml
deleted file mode 100644
index 60d521f2a..000000000
--- a/config.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-model: GRU4Rec
-eval_args:
-  split: {'TS': [0.8, 0.1, 0.1]}
-  #split: {'RS': [0.8, 0.1, 0.1]}
-  mode: full
-  order: TO
\ No newline at end of file