Skip to content

Commit

Permalink
Merge pull request #2094 from pintonos/timesplit
Browse files Browse the repository at this point in the history
[Feature] Time-based split in evaluation
  • Loading branch information
Fotiligner authored Feb 23, 2025
2 parents eceee7b + 7837989 commit c2e7c06
Showing 1 changed file with 63 additions and 0 deletions.
63 changes: 63 additions & 0 deletions recbole/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,62 @@ def leave_one_out(self, group_by, leave_one_mode):
next_ds = [self.copy(_) for _ in next_df]
return next_ds

def time_based_split(self, ratios, group_by):
"""Split interaction records by time-based strategy that combines global temporal and leave-one-out constraints.
Args:
ratios (list): List of split ratios.
group_by (str): Field name that interaction records should be grouped by before splitting.
Returns:
list: List of :class:`~Dataset`, whose interaction features have been split.
"""
self.logger.debug(f"time based split, group_by=[{group_by}]")
if group_by is None:
raise ValueError("Time-based split strategy requires a group field")

if self.time_field not in self.inter_feat:
raise ValueError(f"Field [{self.time_field}] is not in inter_feat.")

self.logger.debug(f"time-based split with ratios [{ratios}], group_by=[{group_by}]")
tot_ratio = sum(ratios)
ratios = [_ / tot_ratio for _ in ratios]

# Determine the global temporal boundary (e.g., 90th percentile)
all_times = self.inter_feat[self.time_field].numpy()
global_temporal_boundary = np.percentile(all_times, 100 * (1 - ratios[-1]))

train_index, valid_index, test_index = [], [], []
grouped_inter_feat_index = self._grouped_index(self.inter_feat[group_by].numpy())

for grouped_index in grouped_inter_feat_index:
grouped_index = np.array(grouped_index)
grouped_inter_feat = self.inter_feat[grouped_index]
grouped_inter_feat.sort(by=self.time_field)

# Split into training/validation and test sets based on the global temporal boundary
times = grouped_inter_feat[self.time_field].numpy()
train_valid_mask = times <= global_temporal_boundary
test_mask = ~train_valid_mask

train_valid_index = grouped_index[train_valid_mask]
test_user_indices = grouped_index[test_mask]

split_point = int(len(train_valid_index) * (ratios[0] / (1 - ratios[2])))
train_index.extend(train_valid_index[:split_point])
valid_index.extend(train_valid_index[split_point:])

test_index.extend(test_user_indices)

self._drop_unused_col()
next_df = [
self.inter_feat[train_index],
self.inter_feat[valid_index],
self.inter_feat[test_index],
]
next_ds = [self.copy(_) for _ in next_df]
return next_ds

def shuffle(self):
"""Shuffle the interaction records inplace."""
self.inter_feat.shuffle()
Expand Down Expand Up @@ -1799,6 +1855,13 @@ def build(self):
datasets = self.leave_one_out(
group_by=self.uid_field, leave_one_mode=split_args["LS"]
)
elif split_mode == "TS":
if not isinstance(split_args["TS"], list):
raise ValueError(f'The value of "TS" [{split_args}] should be a list.')
datasets = self.time_based_split(
ratios=split_args["TS"],
group_by=self.uid_field
)
else:
raise NotImplementedError(
f"The splitting_method [{split_mode}] has not been implemented."
Expand Down

0 comments on commit c2e7c06

Please sign in to comment.