Skip to content

Commit

Permalink
update evaluation code and scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
liu-jc committed Aug 22, 2024
1 parent c5618a5 commit d742e91
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 5 deletions.
7 changes: 6 additions & 1 deletion project/benchmarks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@ This directory contains the code and scripts for benchmarking.

`chronos_scripts` contains the scripts to run Chronos on different datasets.

Example:
### Examples
On the Monash dataset:
```
sh chronos_scripts/monash_chronos_base.sh
```

On the datasets for probabilistic forecasting:
```
sh chronos_scripts/pf_chronos_base.sh
```
6 changes: 6 additions & 0 deletions project/benchmarks/chronos_scripts/pf_chronos_base.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/sh
# Evaluate Chronos-T5 on the probabilistic-forecasting (PF) benchmark datasets.
# Usage: sh pf_chronos_base.sh [model_size]
#   model_size defaults to "base" (backward compatible with the original script).
model_size=${1:-base}
model_path=amazon/chronos-t5-${model_size}
# Run each PF dataset with 20 samples; results land in pf_results_20/chronos-<size>/.
for ds in electricity solar-energy walmart jena_weather istanbul_traffic turkey_power
do
    python run_chronos.py --model_path="${model_path}" --dataset="${ds}" --run_name="chronos-${model_size}" --save_dir=pf_results_20 --test_setting=pf --num_samples=20
done
6 changes: 6 additions & 0 deletions project/benchmarks/chronos_scripts/pf_chronos_mini.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/sh
# Evaluate Chronos-T5 on the probabilistic-forecasting (PF) benchmark datasets.
# Usage: sh pf_chronos_mini.sh [model_size]
#   model_size defaults to "mini" (backward compatible with the original script).
model_size=${1:-mini}
model_path=amazon/chronos-t5-${model_size}
# Run each PF dataset with 20 samples; results land in pf_results_20/chronos-<size>/.
for ds in electricity solar-energy walmart jena_weather istanbul_traffic turkey_power
do
    python run_chronos.py --model_path="${model_path}" --dataset="${ds}" --run_name="chronos-${model_size}" --save_dir=pf_results_20 --test_setting=pf --num_samples=20
done
6 changes: 6 additions & 0 deletions project/benchmarks/chronos_scripts/pf_chronos_small.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/sh
# Evaluate Chronos-T5 on the probabilistic-forecasting (PF) benchmark datasets.
# Usage: sh pf_chronos_small.sh [model_size]
#   model_size defaults to "small" (backward compatible with the original script).
model_size=${1:-small}
model_path=amazon/chronos-t5-${model_size}
# Run each PF dataset with 20 samples; results land in pf_results_20/chronos-<size>/.
for ds in electricity solar-energy walmart jena_weather istanbul_traffic turkey_power
do
    python run_chronos.py --model_path="${model_path}" --dataset="${ds}" --run_name="chronos-${model_size}" --save_dir=pf_results_20 --test_setting=pf --num_samples=20
done
6 changes: 6 additions & 0 deletions project/benchmarks/chronos_scripts/pf_chronos_tiny.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/sh
# Evaluate Chronos-T5 on the probabilistic-forecasting (PF) benchmark datasets.
# Usage: sh pf_chronos_tiny.sh [model_size]
#   model_size defaults to "tiny" (backward compatible with the original script).
model_size=${1:-tiny}
model_path=amazon/chronos-t5-${model_size}
# Run each PF dataset with 20 samples; results land in pf_results_20/chronos-<size>/.
for ds in electricity solar-energy walmart jena_weather istanbul_traffic turkey_power
do
    python run_chronos.py --model_path="${model_path}" --dataset="${ds}" --run_name="chronos-${model_size}" --save_dir=pf_results_20 --test_setting=pf --num_samples=20
done
39 changes: 35 additions & 4 deletions project/benchmarks/run_chronos.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import os
from functools import partial

import numpy as np
import torch
Expand All @@ -24,14 +25,30 @@
from gluonts.model.forecast import SampleForecast
from tqdm.auto import tqdm

from uni2ts.eval_util.data import get_gluonts_test_dataset
from uni2ts.eval_util.data import get_gluonts_test_dataset, get_lsf_test_dataset
from uni2ts.eval_util.evaluation import evaluate_forecasts
from uni2ts.eval_util.metrics import MedianMSE


def evaluate(pipeline, dataset, save_path, num_samples=20, batch_size=512):
def evaluate(
pipeline,
dataset,
save_path,
num_samples=20,
batch_size=512,
test_setting="monash",
pred_length=96,
):
print("-" * 5, f"Evaluating {dataset}", "-" * 5)
test_data, metadata = get_gluonts_test_dataset(dataset)
if test_setting == "monash" or test_setting == "pf":
get_dataset = get_gluonts_test_dataset # for monash and pf, the prediction length can be inferred.
elif test_setting == "lsf":
get_dataset = partial(get_lsf_test_dataset, prediction_length=pred_length)
else:
raise NotImplementedError(
f"Cannot find the test setting {test_setting}. Please select from monash, pf, lsf."
)
test_data, metadata = get_dataset(dataset)
prediction_length = metadata.prediction_length

while True:
Expand Down Expand Up @@ -110,6 +127,16 @@ def evaluate(pipeline, dataset, save_path, num_samples=20, batch_size=512):
"--batch_size", type=int, default=512, help="Batch size for generating samples"
)
parser.add_argument("--run_name", type=str, default="test", help="Name of the run")
parser.add_argument(
"--test_setting",
type=str,
default="monash",
choices=["monash", "lsf", "pf"],
help="Name of the test setting",
)
parser.add_argument(
"--pred_length", type=int, default=96, help="Prediction length for LSF dataset"
)

args = parser.parse_args()
# Load Chronos
Expand All @@ -122,4 +149,8 @@ def evaluate(pipeline, dataset, save_path, num_samples=20, batch_size=512):
output_dir = os.path.join(args.save_dir, args.run_name)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
evaluate(pipeline, args.dataset, os.path.join(output_dir, f"{args.dataset}.csv"))
if args.test_setting == "lsf":
save_dir = os.path.join(output_dir, f"{args.dataset}_{args.pred_length}.csv")
else:
save_dir = os.path.join(output_dir, f"{args.dataset}.csv")
evaluate(pipeline, args.dataset, save_dir, args.num_samples, args.batch_size)

0 comments on commit d742e91

Please sign in to comment.