From 3c434d860c450f79f0f1d2e1ec54e5e80c3eb28f Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 27 Nov 2023 11:56:31 +0100 Subject: [PATCH] Backports for v0.14.2 (#3063) * Fix `iterable.Cached`. (#3060) * Torch: Remove double caching of dataset. (#3061) --------- Co-authored-by: Jasper --- src/gluonts/itertools.py | 25 +++++++++++++++---------- src/gluonts/torch/model/estimator.py | 20 +++++++++----------- test/test_itertools.py | 10 ++++++++++ 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/src/gluonts/itertools.py b/src/gluonts/itertools.py index 8ed8d245cb..e90281a1d8 100644 --- a/src/gluonts/itertools.py +++ b/src/gluonts/itertools.py @@ -305,10 +305,9 @@ def split_into(xs: Sequence, n: int) -> Sequence: @dataclass class Cached: """ - An iterable wrapper, which caches values in a list the first time it is - iterated. + An iterable wrapper, which caches values in a list while iterated. - The primary use-case for this is to avoid re-computing the element of the + The primary use-case for this is to avoid re-computing the elements of the sequence, in case the inner iterable does it on demand. This should be used to wrap deterministic iterables, i.e. iterables where @@ -317,15 +316,21 @@ class Cached: """ iterable: SizedIterable - cache: list = field(default_factory=list, init=False) + provider: Iterable = field(init=False) + consumed: list = field(default_factory=list, init=False) + + def __post_init__(self): + # ensure we only iterate once over the iterable + self.provider = iter(self.iterable) def __iter__(self): - if not self.cache: - for element in self.iterable: - yield element - self.cache.append(element) - else: - yield from self.cache + # Yield already provided values first + yield from self.consumed + + # Now yield remaining elements. + for element in self.provider: + self.consumed.append(element) + yield element def __len__(self) -> int: return len(self.iterable) diff --git a/src/gluonts/torch/model/estimator.py b/src/gluonts/torch/model/estimator.py index 7cca653a15..b8a1147d44 100644 --- a/src/gluonts/torch/model/estimator.py +++ b/src/gluonts/torch/model/estimator.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -from typing import NamedTuple, Optional, Iterable, Dict, Any, Union +from typing import NamedTuple, Optional, Iterable, Dict, Any import logging import numpy as np @@ -24,7 +24,7 @@ from gluonts.itertools import Cached from gluonts.model import Estimator, Predictor from gluonts.torch.model.predictor import PyTorchPredictor -from gluonts.transform import Transformation, TransformedDataset +from gluonts.transform import Transformation logger = logging.getLogger(__name__) @@ -156,18 +156,16 @@ def train_model( transformation = self.create_transformation() with env._let(max_idle_transforms=max(len(training_data), 100)): - transformed_training_data: Union[ - Cached, TransformedDataset - ] = transformation.apply(training_data, is_train=True) + transformed_training_data: Dataset = transformation.apply( + training_data, is_train=True + ) if cache_data: transformed_training_data = Cached(transformed_training_data) training_network = self.create_lightning_module() training_data_loader = self.create_training_data_loader( - Cached(transformed_training_data) - if cache_data - else transformed_training_data, + transformed_training_data, training_network, shuffle_buffer_length=shuffle_buffer_length, ) @@ -176,9 +174,9 @@ def train_model( if validation_data is not None: with env._let(max_idle_transforms=max(len(validation_data), 100)): - transformed_validation_data: Union[ - Cached, TransformedDataset - ] = transformation.apply(validation_data, is_train=True) + transformed_validation_data: Dataset = transformation.apply( + validation_data, is_train=True + ) if cache_data: transformed_validation_data = Cached( transformed_validation_data diff --git a/test/test_itertools.py b/test/test_itertools.py index 6cef2210f5..b9071d9ad0 100644 --- a/test/test_itertools.py +++ b/test/test_itertools.py @@ -119,6 +119,16 @@ def test_pickle(iterable: Iterable, assert_content: bool): assert data == data_copy +def test_cached_reentry(): + data = Cached(range(10)) + + assert len(data) == 10 + assert list(take(5, data)) == list(range(5)) + assert len(data) == 10 + assert list(take(10, data)) == list(range(10)) + assert len(data) == 10 + + @pytest.mark.parametrize( "given, expected", [