diff --git a/libcity/config/data/STPGCNDataset.json b/libcity/config/data/STPGCNDataset.json index ebe314e4..d2edeaae 100644 --- a/libcity/config/data/STPGCNDataset.json +++ b/libcity/config/data/STPGCNDataset.json @@ -11,7 +11,6 @@ "input_window": 12, "output_window": 12, - "points_per_hour": 12, "alpha": 4, "beta": 2 } diff --git a/libcity/data/dataset/dataset_subclass/stpgcn_dataset.py b/libcity/data/dataset/dataset_subclass/stpgcn_dataset.py index f0e1373c..7838f82a 100644 --- a/libcity/data/dataset/dataset_subclass/stpgcn_dataset.py +++ b/libcity/data/dataset/dataset_subclass/stpgcn_dataset.py @@ -38,7 +38,7 @@ def __init__(self, config): 'point_based_{}.npz'.format(self.parameters_str)) self.feature_name = {'X': 'float', 'y': 'float', 'pos_w': 'int', 'pos_d': 'int'} - self.points_per_hour = config.get('points_per_hour', 12) + self.points_per_hour = self.time_intervals // 60 self.alpha = config.get('alpha', 4) self.beta = config.get('beta', 2) self.t_size = self.beta + 1 @@ -60,7 +60,7 @@ def get_data_feature(self): """ return {"scaler": self.scaler, "ext_dim": self.ext_dim, "spatial_distance": self.spatial_distance, "range_mask": self.range_mask, "num_nodes": self.num_nodes, "feature_dim": self.feature_dim, - "output_dim": self.output_dim, "num_batches": self.num_batches} + "output_dim": self.output_dim, "num_batches": self.num_batches, "points_per_hour": self.points_per_hour} def _load_cache_train_val_test(self): self._logger.info('Loading ' + self.cache_file_name) diff --git a/libcity/model/traffic_flow_prediction/STPGCN.py b/libcity/model/traffic_flow_prediction/STPGCN.py index c6b5c687..4dfea84e 100644 --- a/libcity/model/traffic_flow_prediction/STPGCN.py +++ b/libcity/model/traffic_flow_prediction/STPGCN.py @@ -381,7 +381,7 @@ def __init__(self, config, data_feature): self.beta = config.get("beta", 2) self.t_size = self.beta + 1 self.week_len = 7 - self.day_len = config.get("points_per_hour") * 24 + self.day_len = self.data_feature.get("points_per_hour") * 24 self.range_mask = torch.Tensor(self.range_mask).to(self.device) self.PAD = GeneratePad(self.device, self.C, self.V, self.d, self.beta)