Skip to content

Commit

Permalink
Backports for v0.10.9 (#2610)
Browse files Browse the repository at this point in the history
* Fix version in requirements to comply with stricter setuptools. (#2604)

Co-authored-by: Lorenzo Stella <[email protected]>

* Backport: Add gluonts.util.safe_extract (#2606)

* Expose aggregation method in ensemble NBEATS, fix forecast shape (#2598)

* Disable Py36 tests, fix version.

* Fixup.

* Cap numpy compatibility in `mxnet` extra requirements (#2506)

* xfail multivariate grouper test

Co-authored-by: Lorenzo Stella <[email protected]>
Co-authored-by: Jasper <[email protected]>

---------

Co-authored-by: Lorenzo Stella <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>
  • Loading branch information
3 people authored Feb 6, 2023
1 parent be1a39e commit c9a6f96
Show file tree
Hide file tree
Showing 12 changed files with 130 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests-torch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
max-parallel: 4
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8]
python-version: [3.7, 3.8]
platform: [ubuntu-latest]

runs-on: ${{ matrix.platform }}
Expand Down
10 changes: 2 additions & 8 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
max-parallel: 4
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8]
python-version: [3.7, 3.8]
platform: [ubuntu-latest]

runs-on: ${{ matrix.platform }}
Expand All @@ -19,16 +19,10 @@ jobs:
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install MXNet (Linux)
if: ${{ runner.os == 'Linux' }}
run: pip install mxnet~=1.8.0
- name: Install MXNet (Windows)
if: ${{ runner.os == 'Windows' }}
run: pip install mxnet~=1.7.0
- name: Install other dependencies
run: |
python -m pip install -U pip
pip install ".[arrow,shell]"
pip install ".[mxnet,arrow,shell]"
pip install -r requirements/requirements-test.txt
pip install -r requirements/requirements-extras-m-competitions.txt
- name: Test with pytest
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-arrow.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pyarrow>=6.*; python_version=='3.6.*'
pyarrow>=6.0; python_version=='3.6'
pyarrow~=8.0; python_version>='3.7'
2 changes: 1 addition & 1 deletion requirements/requirements-extras-r.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
rpy2>=2.9.*,<3.*
rpy2>=2.9.0,<3.0
3 changes: 3 additions & 0 deletions requirements/requirements-mxnet.txt
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
# upper bound added since numpy==1.24 broke importing mxnet,
# see https://github.com/awslabs/gluonts/pull/2506
numpy<1.24
mxnet~=1.7
3 changes: 2 additions & 1 deletion src/gluonts/dataset/repository/_gp_copula_2019.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from gluonts.dataset.common import MetaData, TrainDatasets
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.repository._util import metadata
from gluonts.util import safe_extractall


class GPCopulaDataset(NamedTuple):
Expand Down Expand Up @@ -140,7 +141,7 @@ def download_dataset(dataset_path: Path, ds_info: GPCopulaDataset):
request.urlretrieve(ds_info.url, dataset_path / f"{ds_info.name}.tar.gz")

with tarfile.open(dataset_path / f"{ds_info.name}.tar.gz") as tar:
tar.extractall(path=dataset_path)
safe_extractall(tar, path=dataset_path)


def get_data(dataset_path: Path, ds_info: GPCopulaDataset):
Expand Down
18 changes: 14 additions & 4 deletions src/gluonts/model/n_beats/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,15 @@ def predict(
# get the forecast start date
if start_date is None:
start_date = prediction.start_date
output = np.stack(output, axis=0)
output = np.concatenate(output, axis=0)

# aggregating output of different models
# default according to paper is median,
# but we can also make use of not aggregating
if self.aggregation_method == "median":
output = np.median(output, axis=0)
output = np.median(output, axis=0, keepdims=True)
elif self.aggregation_method == "mean":
output = np.mean(output, axis=0)
output = np.mean(output, axis=0, keepdims=True)
else: # "none": do not aggregate
pass

Expand Down Expand Up @@ -312,6 +312,10 @@ class NBEATSEnsembleEstimator(Estimator):
(trend). A list of strings of length 1 or 'num_stacks'.
Default and recommended value for generic mode: ["G"]
Recommended value for interpretable mode: ["T","S"]
aggregation_method
The method by which to aggregate the individual predictions of the
models. Either 'median', 'mean' or 'none', in which case no aggregation
happens. Default is 'median'.
**kwargs
Arguments passed down to the individual estimators.
"""
Expand All @@ -336,6 +340,7 @@ def __init__(
expansion_coefficient_lengths: Optional[List[int]] = None,
sharing: Optional[List[bool]] = None,
stack_types: Optional[List[str]] = None,
aggregation_method: str = "median",
**kwargs,
) -> None:
super().__init__()
Expand Down Expand Up @@ -385,6 +390,7 @@ def __init__(
self.expansion_coefficient_lengths = expansion_coefficient_lengths
self.sharing = sharing
self.stack_types = stack_types
self.aggregation_method = aggregation_method

# Actually instantiate the different models
self.estimators = self._estimator_factory(**kwargs)
Expand Down Expand Up @@ -449,4 +455,8 @@ def train(
)
predictors.append(estimator.train(training_data, validation_data))

return NBEATSEnsemblePredictor(self.prediction_length, predictors)
return NBEATSEnsemblePredictor(
self.prediction_length,
predictors,
aggregation_method=self.aggregation_method,
)
3 changes: 2 additions & 1 deletion src/gluonts/nursery/sagemaker_sdk/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from gluonts.dataset.repository import datasets
from gluonts.model.estimator import Estimator
from gluonts.model.predictor import Predictor
from gluonts.util import safe_extractall

from .defaults import (
ENTRY_POINTS_FOLDER,
Expand Down Expand Up @@ -503,7 +504,7 @@ def _retrieve_model(self, locations):
with self._s3fs.open(locations.model_archive, "rb") as stream:
with tarfile.open(mode="r:gz", fileobj=stream) as archive:
with TemporaryDirectory() as temp_dir:
archive.extractall(temp_dir)
safe_extractall(archive, temp_dir)
predictor = Predictor.deserialize(Path(temp_dir))

return predictor
Expand Down
3 changes: 2 additions & 1 deletion src/gluonts/nursery/tsbench/src/cli/evaluations/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, cast, Dict, List, Optional
import botocore
import click
from gluonts.util import safe_extractall
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import process_map
from tsbench.analysis.utils import run_parallel
Expand Down Expand Up @@ -104,7 +105,7 @@ def _download_public_evaluations(
file = Path(tmp) / "metrics.tar.gz"
client.download_file(public_bucket, "metrics.tar.gz", str(file))
with tarfile.open(file, mode="r:gz") as tar:
tar.extractall(evaluations_path)
safe_extractall(tar, evaluations_path)

# Then, optionally download the forecasts
if include_forecasts:
Expand Down
48 changes: 48 additions & 0 deletions src/gluonts/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import tarfile

from pathlib import Path


def will_extractall_into(tar: tarfile.TarFile, path: Path) -> None:
"""
Check that the content of ``tar`` will be extracted within ``path``
upon calling ``extractall``.
Raise a ``PermissionError`` if not.
"""
path = Path(path).resolve()

for member in tar.getmembers():
member_path = (path / member.name).resolve()

try:
member_path.relative_to(path)
except ValueError:
raise PermissionError(f"'{member.name}' extracts out of target.")


def safe_extractall(
tar: tarfile.TarFile,
path: Path = Path("."),
members=None,
*,
numeric_owner=False,
):
"""
Safe wrapper around ``TarFile.extractall`` that checks all destination
files to be strictly within the given ``path``.
"""
will_extractall_into(tar, path)
tar.extractall(path, members, numeric_owner=numeric_owner)
3 changes: 3 additions & 0 deletions test/dataset/test_multivariate_grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ def test_multivariate_grouper_train(
MAX_TARGET_DIM = [2, 1]


@pytest.mark.xfail(
reason="This test is known to fail with numpy>=1.24, and a fix is pending"
)
@pytest.mark.parametrize(
"univariate_ts, multivariate_ts, test_fill_rule, max_target_dim",
zip(
Expand Down
51 changes: 51 additions & 0 deletions test/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import tempfile
import tarfile
from pathlib import Path
from typing import Optional

import pytest

from gluonts.util import will_extractall_into


@pytest.mark.parametrize(
"arcname, expect_failure",
[
(None, False),
("./file.txt", False),
("/a/../file.txt", False),
("/a/../../file.txt", True),
("../file.txt", True),
],
)
def test_will_extractall_into(arcname: Optional[str], expect_failure: bool):
with tempfile.TemporaryDirectory() as tempdir:
file_path = Path(tempdir) / "a" / "file.txt"
file_path.parent.mkdir(parents=True)
file_path.touch()

with tarfile.open(Path(tempdir) / "archive.tar.gz", "w:gz") as tar:
tar.add(file_path, arcname=arcname)

if expect_failure:
with pytest.raises(PermissionError):
with tarfile.open(
Path(tempdir) / "archive.tar.gz", "r:gz"
) as tar:
will_extractall_into(tar, Path(tempdir) / "b")
else:
with tarfile.open(Path(tempdir) / "archive.tar.gz", "r:gz") as tar:
will_extractall_into(tar, Path(tempdir) / "b")

0 comments on commit c9a6f96

Please sign in to comment.