diff --git a/src/gluonts/torch/distributions/distribution_output.py b/src/gluonts/torch/distributions/distribution_output.py index af786ca4ef..e583d941b5 100644 --- a/src/gluonts/torch/distributions/distribution_output.py +++ b/src/gluonts/torch/distributions/distribution_output.py @@ -90,14 +90,6 @@ def loss( nll = nll * (variance.detach() ** self.beta) return nll - @property - def event_shape(self) -> Tuple: - r""" - Shape of each individual event contemplated by the distributions that - this object constructs. - """ - raise NotImplementedError() - @property def event_dim(self) -> int: r""" diff --git a/src/gluonts/torch/distributions/output.py b/src/gluonts/torch/distributions/output.py index 49385f0e55..83d22246bb 100644 --- a/src/gluonts/torch/distributions/output.py +++ b/src/gluonts/torch/distributions/output.py @@ -105,6 +105,13 @@ def loss( """ raise NotImplementedError() + @property + def event_shape(self) -> Tuple: + r""" + Shape of each individual event compatible with the output object. + """ + raise NotImplementedError() + @property def forecast_generator(self) -> ForecastGenerator: raise NotImplementedError() diff --git a/src/gluonts/torch/distributions/quantile_output.py b/src/gluonts/torch/distributions/quantile_output.py index 4bce0fad53..ce104c5703 100644 --- a/src/gluonts/torch/distributions/quantile_output.py +++ b/src/gluonts/torch/distributions/quantile_output.py @@ -37,6 +37,10 @@ def __init__(self, quantiles: List[float]) -> None: def forecast_generator(self) -> ForecastGenerator: return QuantileForecastGenerator(quantiles=self.quantiles) + @property + def event_shape(self) -> Tuple: + return () + def domain_map(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]: return args diff --git a/src/gluonts/torch/model/tide/estimator.py b/src/gluonts/torch/model/tide/estimator.py index ab6c004c02..35f8cf5362 100644 --- a/src/gluonts/torch/model/tide/estimator.py +++ b/src/gluonts/torch/model/tide/estimator.py @@ -21,7 +21,6 @@ from gluonts.dataset.field_names import FieldName from gluonts.dataset.loader import as_stacked_batches from gluonts.itertools import Cyclic -from gluonts.model.forecast_generator import DistributionForecastGenerator from gluonts.time_feature import ( minute_of_hour, hour_of_day, @@ -49,10 +48,7 @@ from gluonts.torch.model.estimator import PyTorchLightningEstimator from gluonts.torch.model.predictor import PyTorchPredictor -from gluonts.torch.distributions import ( - DistributionOutput, - StudentTOutput, -) +from gluonts.torch.distributions import Output, StudentTOutput from .lightning_module import TiDELightningModule @@ -174,7 +170,7 @@ def __init__( weight_decay: float = 1e-8, patience: int = 10, scaling: Optional[str] = "mean", - distr_output: DistributionOutput = StudentTOutput(), + distr_output: Output = StudentTOutput(), batch_size: int = 32, num_batches_per_epoch: int = 50, trainer_kwargs: Optional[Dict[str, Any]] = None, @@ -403,9 +399,7 @@ def create_predictor( input_transform=transformation + prediction_splitter, input_names=PREDICTION_INPUT_NAMES, prediction_net=module, - forecast_generator=DistributionForecastGenerator( - self.distr_output - ), + forecast_generator=self.distr_output.forecast_generator, batch_size=self.batch_size, prediction_length=self.prediction_length, device="auto", diff --git a/src/gluonts/torch/model/tide/module.py b/src/gluonts/torch/model/tide/module.py index 875a05f8b2..e0eb06cb85 100644 --- a/src/gluonts/torch/model/tide/module.py +++ b/src/gluonts/torch/model/tide/module.py @@ -19,7 +19,7 @@ from gluonts.core.component import validated from gluonts.torch.modules.feature import FeatureEmbedder from gluonts.model import Input, InputSpec -from gluonts.torch.distributions import DistributionOutput +from gluonts.torch.distributions import Output from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler from gluonts.torch.model.simple_feedforward import make_linear_layer from gluonts.torch.util import weighted_average @@ -242,7 +242,7 @@ def __init__( num_layers_encoder: int, num_layers_decoder: int, layer_norm: bool, - distr_output: DistributionOutput, + distr_output: Output, scaling: str, ) -> None: super().__init__() diff --git a/test/torch/model/test_estimators.py b/test/torch/model/test_estimators.py index caa5ae10ec..faea91b313 100644 --- a/test/torch/model/test_estimators.py +++ b/test/torch/model/test_estimators.py @@ -148,6 +148,14 @@ num_batches_per_epoch=3, trainer_kwargs=dict(max_epochs=2), ), + lambda dataset: TiDEEstimator( + freq=dataset.metadata.freq, + prediction_length=dataset.metadata.prediction_length, + distr_output=QuantileOutput(quantiles=[0.1, 0.6, 0.85]), + batch_size=4, + num_batches_per_epoch=3, + trainer_kwargs=dict(max_epochs=2), + ), lambda dataset: WaveNetEstimator( freq=dataset.metadata.freq, prediction_length=dataset.metadata.prediction_length,