Skip to content

Commit

Permalink
Fix QuantileForecast.plot() to use DateTimeIndex (#2269)
Browse files Browse the repository at this point in the history
  • Loading branch information
abdulfatir authored Sep 5, 2022
1 parent 44ef129 commit 2c946c6
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
1 change: 1 addition & 0 deletions requirements/requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ ujson
orjson
requests
holidays~=0.9
matplotlib
5 changes: 1 addition & 4 deletions src/gluonts/model/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,14 +779,11 @@ def plot(self, label=None, output_file=None, keys=None, *args, **kwargs):
keys = self.forecast_keys

for k, v in zip(keys, self.forecast_array):
plt.plot(
self.index,
v,
pd.Series(data=v, index=self.index.to_timestamp()).plot(
label=f"{label_prefix}q{k}",
*args,
**kwargs,
)
plt.legend()
if output_file:
plt.savefig(output_file)

Expand Down
2 changes: 2 additions & 0 deletions test/model/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def percentile(value):
assert len(forecast.index) == pred_length
assert forecast.index[0] == START_DATE

forecast.plot()


@pytest.mark.parametrize(
"forecast, exp_index",
Expand Down

0 comments on commit 2c946c6

Please sign in to comment.