Skip to content

Commit

Permalink
Zebras: Rework Schema. (#2916)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasper authored Jun 27, 2023
1 parent f4af090 commit 14ec5ca
Show file tree
Hide file tree
Showing 5 changed files with 332 additions and 215 deletions.
5 changes: 2 additions & 3 deletions src/gluonts/zebras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
"time_series",
"BatchTimeSeries",
"TimeSeries",
"Schema",
"Field",
"schema",
]

from typing import TypeVar
Expand All @@ -38,7 +37,7 @@
from ._split_frame import split_frame, SplitFrame, BatchSplitFrame
from ._time_frame import time_frame, TimeFrame, BatchTimeFrame
from ._time_series import time_series, TimeSeries, BatchTimeSeries
from ._schema import Field, Schema
from . import schema

Batchable = TypeVar("Batchable", TimeSeries, TimeFrame, SplitFrame)

Expand Down
202 changes: 0 additions & 202 deletions src/gluonts/zebras/_schema.py

This file was deleted.

15 changes: 13 additions & 2 deletions src/gluonts/zebras/_split_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ def future(self):
_pad=Pad(0, self._pad.right),
)

def __getitem__(self, name):
if name in self._past:
if name in self._future:
return np.concatenate(
self._past[name], self._future[name], axis=self.tdims[name]
)
return self._past[name]
return self._future[name]

def __len__(self):
return self.past_length + self.future_length

Expand Down Expand Up @@ -117,7 +126,8 @@ def set_past(self, name, value, tdim=None):
assert self.tdims.get(name, tdim) == tdim

return _replace(
past=merge(self.past, {name: value}),
self,
_past=merge(self._past, {name: value}),
tdims=merge(self.tdims, {name: tdim}),
)

Expand All @@ -127,7 +137,8 @@ def set_future(self, name, value, tdim=None):
assert self.tdims.get(name, tdim) == tdim

return _replace(
future=merge(self.future, {name: value}),
self,
_future=merge(self.future, {name: value}),
tdims=merge(self.tdims, {name: tdim}),
)

Expand Down
Loading

0 comments on commit 14ec5ca

Please sign in to comment.