Skip to content

Commit

Permalink
Backports for v0.13.4 (#2979)
Browse files Browse the repository at this point in the history
* Turn `type` comparison into `isinstance` (#2958)

* Clean up RepresentablePredictor (#2967)

* Fix: use `isinstance` instead of type comparison (#2973)

* fix type comparison

---------

Co-authored-by: Pedro Eduardo Mercado Lopez <[email protected]>

* Fix ArrowDecoder.decode to return instead of yield (#2976)

* Unpin pyarrow version (#2977)

* Fix Cython version in XGBoost tests (#2966)

* black

---------

Co-authored-by: Pedro Mercado <[email protected]>
Co-authored-by: Pedro Eduardo Mercado Lopez <[email protected]>
  • Loading branch information
3 people authored Aug 25, 2023
1 parent 48d22d7 commit fd816ce
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 18 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests-xgboost.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
- name: Install other dependencies
run: |
python -m pip install -U pip setuptools wheel
python -m pip install -U "Cython<3.0"
pip install .
pip install -r requirements/requirements-test.txt
pip install -r requirements/requirements-rotbaum.txt
Expand Down
3 changes: 1 addition & 2 deletions requirements/requirements-arrow.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
pyarrow>=6.0; python_version=='3.6.*'
pyarrow~=8.0; python_version>='3.7'
pyarrow
6 changes: 3 additions & 3 deletions src/gluonts/core/serde/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def decode(r: Any) -> Any:
"""

# structural recursion over the possible shapes of r
if type(r) == dict and "__kind__" in r:
if isinstance(r, dict) and "__kind__" in r:
kind = r["__kind__"]
cls = cast(Any, locate(r["class"]))

Expand All @@ -331,10 +331,10 @@ def decode(r: Any) -> Any:

raise ValueError(f"Unknown kind {kind}.")

if type(r) == dict:
if isinstance(r, dict):
return valmap(decode, r)

if type(r) == list:
if isinstance(r, list):
return list(map(decode, r))

return r
2 changes: 1 addition & 1 deletion src/gluonts/dataset/arrow/dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def from_schema(cls, schema):
)

def decode(self, batch, row_number: int):
yield from self.decode_batch(batch.slice(row_number, row_number + 1))
return next(self.decode_batch(batch.slice(row_number, row_number + 1)))

def decode_batch(self, batch):
for row in batch.to_pandas().to_dict("records"):
Expand Down
17 changes: 6 additions & 11 deletions src/gluonts/model/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import gluonts
from gluonts.core import fqname_for
from gluonts.core.component import equals, from_hyperparameters, validated
from gluonts.core.component import equals, from_hyperparameters
from gluonts.core.serde import dump_json, load_json
from gluonts.dataset.common import DataEntry, Dataset
from gluonts.model.forecast import Forecast
Expand Down Expand Up @@ -135,23 +135,18 @@ def from_inputs(cls, train_iter, **params):

class RepresentablePredictor(Predictor):
"""
An abstract predictor that can be subclassed by models that are not based
on Gluon. Subclasses should have @validated() constructors.
(De)serialization and value equality are all implemented on top of the.
An abstract predictor that can be subclassed by framework-specific models.
Subclasses should have ``@validated()`` constructors:
(de)serialization and equality test are all implemented on top of its logic.
@validated() logic.
Parameters
----------
prediction_length
Prediction horizon.
lead_time
Prediction lead time.
"""

@validated()
def __init__(self, prediction_length: int, lead_time: int = 0) -> None:
super().__init__(
lead_time=lead_time, prediction_length=prediction_length
)

def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]:
for item in dataset:
yield self.predict_item(item)
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/mx/model/seq2seq/_mq_dnn_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def from_inputs(cls, train_iter, **params):
if field in params.keys():
is_params_field = (
params[field]
if type(params[field]) == bool
if isinstance(params[field], bool)
else strtobool(params[field])
)
if is_params_field and not auto_params[field]:
Expand Down
4 changes: 4 additions & 0 deletions test/dataset/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from gluonts.dataset.arrow import (
File,
ArrowFile,
ArrowWriter,
ParquetWriter,
)
Expand Down Expand Up @@ -78,3 +79,6 @@ def test_arrow(writer, flatten_arrays):

for orig, arrow_value in zip(data, dataset):
assert_equal(orig, arrow_value)

if isinstance(dataset, ArrowFile):
assert_equal(dataset[4], data[4])

0 comments on commit fd816ce

Please sign in to comment.