Skip to content

Commit

Permalink
Fix: set max idle transforms in PyTorch estimators (#2266)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored Sep 5, 2022
1 parent d82ef19 commit 44ef129
Showing 1 changed file with 29 additions and 24 deletions.
53 changes: 29 additions & 24 deletions src/gluonts/torch/model/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.env import env
from gluonts.itertools import Cached
from gluonts.model.estimator import Estimator
from gluonts.torch.model.predictor import PyTorchPredictor
Expand Down Expand Up @@ -154,36 +155,40 @@ def train_model(
) -> TrainOutput:
transformation = self.create_transformation()

transformed_training_data = transformation.apply(
training_data, is_train=True
)

training_network = self.create_lightning_module()

training_data_loader = self.create_training_data_loader(
transformed_training_data
if not cache_data
else Cached(transformed_training_data),
training_network,
num_workers=num_workers,
shuffle_buffer_length=shuffle_buffer_length,
)

validation_data_loader = None

if validation_data is not None:
transformed_validation_data = transformation.apply(
validation_data, is_train=True
with env._let(max_idle_transforms=max(len(training_data), 100)):
transformed_training_data = transformation.apply(
training_data, is_train=True
)
if cache_data:
transformed_training_data = Cached(transformed_training_data)

training_network = self.create_lightning_module()

validation_data_loader = self.create_validation_data_loader(
transformed_validation_data
if not cache_data
else Cached(transformed_validation_data),
training_data_loader = self.create_training_data_loader(
transformed_training_data,
training_network,
num_workers=num_workers,
shuffle_buffer_length=shuffle_buffer_length,
)

validation_data_loader = None

with env._let(max_idle_transforms=max(len(training_data), 100)):
if validation_data is not None:
transformed_validation_data = transformation.apply(
validation_data, is_train=True
)
if cache_data:
transformed_validation_data = Cached(
transformed_validation_data
)

validation_data_loader = self.create_validation_data_loader(
transformed_validation_data,
training_network,
num_workers=num_workers,
)

monitor = "train_loss" if validation_data is None else "val_loss"
checkpoint = pl.callbacks.ModelCheckpoint(
monitor=monitor, mode="min", verbose=True
Expand Down

0 comments on commit 44ef129

Please sign in to comment.