diff --git a/.github/workflows/tests-xgboost.yaml b/.github/workflows/tests-xgboost.yaml index f5989d5fbd..3a3c2c5044 100644 --- a/.github/workflows/tests-xgboost.yaml +++ b/.github/workflows/tests-xgboost.yaml @@ -22,6 +22,7 @@ jobs: - name: Install other dependencies run: | python -m pip install -U pip setuptools wheel + python -m pip install -U "Cython<3.0" pip install . pip install -r requirements/requirements-test.txt pip install -r requirements/requirements-rotbaum.txt diff --git a/requirements/requirements-arrow.txt b/requirements/requirements-arrow.txt index 6e5a693686..19d8363071 100644 --- a/requirements/requirements-arrow.txt +++ b/requirements/requirements-arrow.txt @@ -1,2 +1 @@ -pyarrow>=6.0; python_version=='3.6.*' -pyarrow~=8.0; python_version>='3.7' +pyarrow diff --git a/src/gluonts/core/serde/_base.py b/src/gluonts/core/serde/_base.py index 6f51e02a70..5432bd8105 100644 --- a/src/gluonts/core/serde/_base.py +++ b/src/gluonts/core/serde/_base.py @@ -309,7 +309,7 @@ def decode(r: Any) -> Any: """ # structural recursion over the possible shapes of r - if type(r) == dict and "__kind__" in r: + if isinstance(r, dict) and "__kind__" in r: kind = r["__kind__"] cls = cast(Any, locate(r["class"])) @@ -331,10 +331,10 @@ def decode(r: Any) -> Any: raise ValueError(f"Unknown kind {kind}.") - if type(r) == dict: + if isinstance(r, dict): return valmap(decode, r) - if type(r) == list: + if isinstance(r, list): return list(map(decode, r)) return r diff --git a/src/gluonts/dataset/arrow/dec.py b/src/gluonts/dataset/arrow/dec.py index 3082b70790..148d5311c7 100644 --- a/src/gluonts/dataset/arrow/dec.py +++ b/src/gluonts/dataset/arrow/dec.py @@ -32,7 +32,7 @@ def from_schema(cls, schema): ) def decode(self, batch, row_number: int): - yield from self.decode_batch(batch.slice(row_number, row_number + 1)) + return next(self.decode_batch(batch.slice(row_number, row_number + 1))) def decode_batch(self, batch): for row in batch.to_pandas().to_dict("records"): diff --git a/src/gluonts/model/predictor.py b/src/gluonts/model/predictor.py index 69ff23ed84..e0d490c851 100644 --- a/src/gluonts/model/predictor.py +++ b/src/gluonts/model/predictor.py @@ -26,7 +26,7 @@ import gluonts from gluonts.core import fqname_for -from gluonts.core.component import equals, from_hyperparameters, validated +from gluonts.core.component import equals, from_hyperparameters from gluonts.core.serde import dump_json, load_json from gluonts.dataset.common import DataEntry, Dataset from gluonts.model.forecast import Forecast @@ -135,23 +135,18 @@ def from_inputs(cls, train_iter, **params): class RepresentablePredictor(Predictor): """ - An abstract predictor that can be subclassed by models that are not based - on Gluon. Subclasses should have @validated() constructors. - (De)serialization and value equality are all implemented on top of the. + An abstract predictor that can be subclassed by framework-specific models. + Subclasses should have ``@validated()`` constructors: + (de)serialization and equality test are all implemented on top of its logic. - @validated() logic. Parameters ---------- prediction_length Prediction horizon. + lead_time + Prediction lead time. """ - @validated() - def __init__(self, prediction_length: int, lead_time: int = 0) -> None: - super().__init__( - lead_time=lead_time, prediction_length=prediction_length - ) - def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]: for item in dataset: yield self.predict_item(item) diff --git a/src/gluonts/mx/model/seq2seq/_mq_dnn_estimator.py b/src/gluonts/mx/model/seq2seq/_mq_dnn_estimator.py index 7359beb3a7..5864d37931 100644 --- a/src/gluonts/mx/model/seq2seq/_mq_dnn_estimator.py +++ b/src/gluonts/mx/model/seq2seq/_mq_dnn_estimator.py @@ -317,7 +317,7 @@ def from_inputs(cls, train_iter, **params): if field in params.keys(): is_params_field = ( params[field] - if type(params[field]) == bool + if isinstance(params[field], bool) else strtobool(params[field]) ) if is_params_field and not auto_params[field]: diff --git a/test/dataset/test_arrow.py b/test/dataset/test_arrow.py index 3c12818165..9aba14f7bb 100644 --- a/test/dataset/test_arrow.py +++ b/test/dataset/test_arrow.py @@ -21,6 +21,7 @@ from gluonts.dataset.arrow import ( File, + ArrowFile, ArrowWriter, ParquetWriter, ) @@ -78,3 +79,6 @@ def test_arrow(writer, flatten_arrays): for orig, arrow_value in zip(data, dataset): assert_equal(orig, arrow_value) + + if isinstance(dataset, ArrowFile): + assert_equal(dataset[4], data[4])