Apply ruff/pyupgrade to test. (#2489)
Jasper authored Dec 13, 2022
1 parent cd53fd2 commit 87f0007
Showing 36 changed files with 70 additions and 125 deletions.
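The changes below are the mechanical cleanups that ruff and pyupgrade flag: unused imports are deleted, f-strings without placeholders become plain string literals, %-formatting becomes str.format, `not x in y` becomes `x not in y`, a list comprehension passed to `sum(...)` becomes a generator expression, and `type(...) ==` assertions become `isinstance` checks. As a rough, hypothetical sketch of those rewrite patterns (the names below are illustrative and do not come from the diff):

    # Illustrative values only; epoch, loss, entry, and values are made-up names.
    epoch, loss = 3, 0.125
    entry = {"item_id": "A"}
    values = [1, 2, 3]

    # f-string with no placeholders -> plain string literal
    message = "quantile-crossing occurred"           # was: f"quantile-crossing occurred"

    # %-formatting -> str.format
    print("Epoch {}, loss: {}".format(epoch, loss))  # was: "Epoch %s, loss: %s" % (epoch, loss)

    # `not x in y` -> `x not in y`
    assert "info" not in entry                       # was: assert not "info" in entry

    # list comprehension inside sum() -> generator expression
    total = sum(v * 2 for v in values)               # was: sum([v * 2 for v in values])

    # type() equality -> isinstance()
    assert isinstance(values, list)                  # was: assert list == type(values)

    # unused imports (e.g. `from textwrap import dedent`) are simply removed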
7 changes: 3 additions & 4 deletions test/core/test_component.py
@@ -12,7 +12,6 @@
 # permissions and limitations under the License.
 
 import random
-from textwrap import dedent
 from typing import Dict, List
 
 from gluonts.core.component import validated
@@ -119,9 +118,9 @@ def test_component_ctor():
     bar01 = Bar(x_list, input_fields=fields, x_dict=x_dict)
     bar02 = load_json(dump_json(bar01))
 
-    assert list == type(bar01.x_list) == type(bar02.x_list)
-    assert dict == type(bar01.x_dict) == type(bar02.x_dict)
-    assert list == type(bar01.input_fields) == type(bar02.input_fields)
+    assert all(isinstance(bar.x_list, list) for bar in [bar01, bar02])
+    assert all(isinstance(bar.x_dict, dict) for bar in [bar01, bar02])
+    assert all(isinstance(bar.input_fields, list) for bar in [bar01, bar02])
 
     assert bar01.x_list == bar02.x_list
     assert bar01.x_dict == bar02.x_dict
2 changes: 1 addition & 1 deletion test/core/test_settings.py
@@ -13,7 +13,7 @@
 
 from pydantic import BaseModel
 
-from gluonts.core.settings import Settings, let, inject
+from gluonts.core.settings import Settings, let
 
 
 class MySettings(Settings):
1 change: 0 additions & 1 deletion test/dataset/test_arrow.py
@@ -23,7 +23,6 @@
     File,
     ArrowWriter,
     ParquetWriter,
-    write_dataset,
 )
 
 
2 changes: 1 addition & 1 deletion test/dataset/test_common.py
@@ -13,7 +13,7 @@
 
 import pandas as pd
 import pytest
-from gluonts.dataset.common import Dataset, ProcessStartField
+from gluonts.dataset.common import Dataset
 
 
 @pytest.mark.parametrize(
2 changes: 0 additions & 2 deletions test/dataset/test_dataset_types.py
@@ -27,8 +27,6 @@
 from gluonts.dataset.common import (
     FileDataset,
     ListDataset,
-    MetaData,
-    ProcessDataEntry,
 )
 from gluonts.dataset.jsonl import JsonLinesFile
 
1 change: 0 additions & 1 deletion test/dataset/test_jsonl.py
@@ -15,7 +15,6 @@
 import tempfile
 from pathlib import Path
 
-import pytest
 
 from gluonts.dataset.common import FileDataset
 
37 changes: 37 additions & 0 deletions test/dataset/test_writer.py
@@ -0,0 +1,37 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import sys
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import pytest
+
+from gluonts.dataset.common import FileDataset
+from gluonts.dataset.arrow import ArrowWriter, ParquetWriter
+from gluonts.dataset.jsonl import JsonLinesWriter
+from gluonts.dataset.repository.datasets import get_dataset
+
+
+@pytest.mark.skipif(sys.version_info < (3, 7), reason="Requires PyArrow v8.")
+@pytest.mark.parametrize(
+    "writer", [ArrowWriter(), ParquetWriter(), JsonLinesWriter()]
+)
+def test_dataset_writer(writer):
+    dataset = get_dataset("constant")
+
+    with TemporaryDirectory() as temp_dir:
+        writer.write_to_folder(dataset.train, Path(temp_dir))
+
+        loaded = FileDataset(Path(temp_dir), freq="h")
+        assert len(dataset.train) == len(loaded)
2 changes: 0 additions & 2 deletions test/ext/rotbaum/test_model.py
@@ -11,9 +11,7 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-from itertools import chain
 
-import numpy as np
 import pytest
 
 from gluonts.ext.rotbaum import TreeEstimator
8 changes: 3 additions & 5 deletions test/mx/block/test_feature.py
@@ -165,11 +165,9 @@ def test_feature_assembler(config, hybridize):
 
 def test_parameters_length():
     exp_params_len = sum(
-        [
-            len(config[k]["embedding_dims"])
-            for k in ["embed_static", "embed_dynamic"]
-            if k in enabled_embedders
-        ]
+        len(config[k]["embedding_dims"])
+        for k in ["embed_static", "embed_dynamic"]
+        if k in enabled_embedders
     )
     act_params_len = len(assemble_feature.collect_params().keys())
     assert exp_params_len == act_params_len
4 changes: 0 additions & 4 deletions test/mx/distribution/test_distribution_methods.py
@@ -26,17 +26,13 @@
     Gaussian,
     GenPareto,
     Laplace,
-    MultivariateGaussian,
     NegativeBinomial,
     OneInflatedBeta,
     PiecewiseLinear,
     Poisson,
     StudentT,
-    TransformedDistribution,
     Uniform,
     ZeroAndOneInflatedBeta,
-    ZeroInflatedBeta,
-    ZeroInflatedPoissonOutput,
 )
 
 test_cases = [
2 changes: 0 additions & 2 deletions test/mx/distribution/test_distribution_output_serde.py
@@ -11,7 +11,6 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-import mxnet as mx
 import pytest
 
 from gluonts.core.serde import encode, decode, dump_json
@@ -27,7 +26,6 @@
     DirichletMultinomialOutput,
     DirichletOutput,
     EmpiricalDistributionOutput,
-    EmpiricalDistribution,
     GammaOutput,
     GaussianOutput,
     GenParetoOutput,
2 changes: 0 additions & 2 deletions test/mx/distribution/test_distribution_sampling.py
@@ -34,10 +34,8 @@
     PiecewiseLinear,
     Poisson,
     StudentT,
-    TransformedDistribution,
     Uniform,
     ZeroAndOneInflatedBeta,
-    ZeroInflatedPoissonOutput,
 )
 from gluonts.testutil import empirical_cdf
 
1 change: 0 additions & 1 deletion test/mx/distribution/test_distribution_slice.py
@@ -10,7 +10,6 @@
 # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
-from typing import NamedTuple, Optional
 
 import mxnet as mx
 import numpy as np
1 change: 0 additions & 1 deletion test/mx/distribution/test_flows.py
@@ -13,7 +13,6 @@
 
 # Standard library imports
 from functools import partial
-from typing import Tuple
 
 # Third-party imports
 import numpy as np
1 change: 0 additions & 1 deletion test/mx/distribution/test_isqf.py
@@ -19,7 +19,6 @@
 
 from gluonts.core.serde import dump_json, load_json
 from gluonts.mx.distribution import ISQF, ISQFOutput
-from gluonts.testutil import empirical_cdf
 
 serialize_fn_list = [lambda x: x, lambda x: load_json(dump_json(x))]
 
1 change: 0 additions & 1 deletion test/mx/distribution/test_issue_287.py
@@ -15,7 +15,6 @@
 import numpy as np
 import pytest
 
-from gluonts.mx.distribution import DistributionOutput
 from gluonts.mx.distribution.beta import BetaOutput
 from gluonts.mx.distribution.gamma import GammaOutput
 from gluonts.mx.distribution.neg_binomial import NegativeBinomialOutput
2 changes: 1 addition & 1 deletion test/mx/distribution/test_mixture.py
@@ -401,7 +401,7 @@ def test_mixture_logprob(
         lp.asnumpy(),
         np.log(p) + gaussian.log_prob(values_outside_support).asnumpy(),
         atol=1e-6,
-    ), f"log_prob(x) should be equal to log(p)+gaussian.log_prob(x)"
+    ), "log_prob(x) should be equal to log(p)+gaussian.log_prob(x)"
 
     fit_mixture = fit_mixture_distribution(
         values_outside_support,
6 changes: 3 additions & 3 deletions test/mx/distribution/test_mx_distribution_inference.py
@@ -17,7 +17,7 @@
 """
 from functools import reduce
 
-from typing import Iterable, List, Tuple
+from typing import List, Tuple
 
 import mxnet as mx
 import numpy as np
@@ -158,7 +158,7 @@ def maximum_likelihood_estimate_sgd(
             cumulative_loss += mx.nd.mean(loss).asscalar()
 
         assert not np.isnan(cumulative_loss)
-        print("Epoch %s, loss: %s" % (e, cumulative_loss / num_batches))
+        print("Epoch {}, loss: {}".format(e, cumulative_loss / num_batches))
 
     if len(distr_args[0].shape) == 1:
         return [
@@ -351,7 +351,7 @@ def test_studentT_likelihood(
     ), f"sigma did not match: sigma = {sigma}, sigma_hat = {sigma_hat}"
     assert (
         np.abs(nu_hat - nu) < TOL * nu
-    ), "nu0 did not match: nu0 = %s, nu_hat = %s" % (nu, nu_hat)
+    ), "nu0 did not match: nu0 = {}, nu_hat = {}".format(nu, nu_hat)
 
 
 @pytest.mark.parametrize("alpha, beta", [(3.75, 1.25)])
5 changes: 2 additions & 3 deletions test/mx/distribution/test_nan_mixture.py
@@ -30,7 +30,6 @@
     NanMixtureOutput,
     StudentTOutput,
 )
-from gluonts.mx.distribution.distribution import Distribution
 
 serialize_fn_list = [lambda x: x, lambda x: load_json(dump_json(x))]
 
@@ -203,7 +202,7 @@ def test_nanmixture_gaussian_inference() -> None:
     args_proj.initialize()
     args_proj.hybridize()
 
-    input = mx.nd.ones((NUM_SAMPLES))
+    input = mx.nd.ones(NUM_SAMPLES)
 
     trainer = mx.gluon.Trainer(
         args_proj.collect_params(), "sgd", {"learning_rate": 0.00001}
@@ -258,7 +257,7 @@ def test_nanmixture_categorical_inference() -> None:
     args_proj.initialize()
     args_proj.hybridize()
 
-    input = mx.nd.ones((NUM_SAMPLES))
+    input = mx.nd.ones(NUM_SAMPLES)
 
     trainer = mx.gluon.Trainer(
         args_proj.collect_params(), "sgd", {"learning_rate": 0.000002}
2 changes: 1 addition & 1 deletion test/mx/model/seq2seq/test_model.py
@@ -234,7 +234,7 @@ def test_inference_quantile_prediction(quantiles, inference_quantiles):
             for (i, pred) in enumerate(
                 forecasts[item_id].quantile(inference_quantile)
             )
-        ), f"quantile-crossing occurred"
+        ), "quantile-crossing occurred"
 
 
 @pytest.mark.parametrize("is_iqf", [True, False])
10 changes: 5 additions & 5 deletions test/mx/model/seq2seq/test_quantile_loss.py
@@ -51,7 +51,7 @@ def test_compute_quantile_loss(quantile_weights, correct_qt_loss) -> None:
     else:
         assert (
             nd.mean(loss(y_true, y_pred)) - correct_qt_loss < tol
-        ), f"computing weighted quantile loss fails!"
+        ), "computing weighted quantile loss fails!"
 
 
 @pytest.mark.parametrize(
@@ -72,7 +72,7 @@ def test_crps_pwl_quantile_weights(quantiles, true_quantile_weight) -> None:
     assert len(quantiles) == len(true_quantile_weight), (
        f"length quantiles {quantiles} "
        f"and quantile_weights {true_quantile_weight} "
-       f"do not match."
+       "do not match."
     )
     tol = 1e-5
     quantile_weights = crps_weights_pwl(quantiles)
@@ -82,7 +82,7 @@ def test_crps_pwl_quantile_weights(quantiles, true_quantile_weight) -> None:
            for i in range(len(quantiles))
        )
        < tol
-    ), f"inaccurate computation of quantile weights"
+    ), "inaccurate computation of quantile weights"
 
 
 @pytest.mark.parametrize(
@@ -144,7 +144,7 @@ def test_infer_quantile_forecast(
                - quantile_forecast.quantile(q)
            )
            < tol
-        ), f"infer_quantile_forecast failed for singleton quantile."
+        ), "infer_quantile_forecast failed for singleton quantile."
 
     else:
         assert (
@@ -156,4 +156,4 @@ def test_infer_quantile_forecast(
                for q in inference_quantiles
            )
            < tol
-        ), f"infer_quantile_forecast failed."
+        ), "infer_quantile_forecast failed."
1 change: 0 additions & 1 deletion test/mx/model/simple_feedforward/test_serde.py
@@ -17,7 +17,6 @@
 import pandas as pd
 import numpy as np
 
-from gluonts.dataset.repository.datasets import get_dataset
 from gluonts.model import Predictor
 from gluonts.mx.model.simple_feedforward import SimpleFeedForwardEstimator
 from gluonts.mx import Trainer
1 change: 0 additions & 1 deletion test/mx/representation/test_rep.py
@@ -12,7 +12,6 @@
 # permissions and limitations under the License.
 
 import mxnet as mx
-import numpy as np
 import pytest
 
 from gluonts.mx.representation import Representation
4 changes: 2 additions & 2 deletions test/mx/test_jitter.py
@@ -32,7 +32,7 @@
 # Both gpu and cpu as well as single and double precision are tested.
 @pytest.mark.skipif(
     sys.platform == "linux",
-    reason=f"skipping since potrf crashes on mxnet 1.6.0 on linux when matrix is not spd",
+    reason="skipping since potrf crashes on mxnet 1.6.0 on linux when matrix is not spd",
 )
 @pytest.mark.parametrize("ctx", ["cpu", "gpu"])
 @pytest.mark.parametrize("jitter_method", ["iter", "eig"])
@@ -62,7 +62,7 @@ def test_jitter_unit(jitter_method, float_type, ctx) -> None:
 # and gpu and for single and double precision.
 @pytest.mark.skipif(
     sys.platform == "linux",
-    reason=f"skipping since potrf crashes on mxnet 1.6.0 on linux when matrix is not spd",
+    reason="skipping since potrf crashes on mxnet 1.6.0 on linux when matrix is not spd",
 )
 @pytest.mark.parametrize("ctx", ["cpu", "gpu"])
 @pytest.mark.parametrize("jitter_method", ["iter", "eig"])
4 changes: 2 additions & 2 deletions test/mx/test_mx_item_id_info.py
@@ -69,9 +69,9 @@ def test_item_id_info(dataset: Dataset, estimator: Estimator):
     predictor = estimator.train(dataset)
     forecasts = predictor.predict(dataset)
     for data_entry, forecast in zip(dataset, forecasts):
-        assert (not "item_id" in data_entry) or data_entry[
+        assert ("item_id" not in data_entry) or data_entry[
             "item_id"
         ] == forecast.item_id
-        assert (not "info" in data_entry) or data_entry[
+        assert ("info" not in data_entry) or data_entry[
             "info"
         ] == forecast.info
3 changes: 0 additions & 3 deletions test/mx/test_mx_serde.py
@@ -11,7 +11,6 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-from textwrap import dedent
 from typing import List
 
 import pytest
@@ -23,8 +22,6 @@
 from gluonts.core import serde
 from gluonts.core.component import equals
 
-import gluonts.mx.prelude as _
-
 
 class CategoricalFeatureInfo(BaseModel):
     name: str
6 changes: 2 additions & 4 deletions test/mx/test_variable_length.py
@@ -13,17 +13,15 @@
 
 import itertools
 from functools import partial
-from typing import Any, Dict, Iterable
+from typing import Iterable
 
 import mxnet as mx
 import numpy as np
 from pandas.tseries.frequencies import to_offset
 import pytest
 
-from gluonts.dataset.common import Dataset, ListDataset
+from gluonts.dataset.common import Dataset
 from gluonts.dataset.loader import (
-    DataBatch,
-    DataLoader,
     InferenceDataLoader,
     TrainDataLoader,
 )
(The remaining changed files are not shown above.)
