diff --git a/paper/README.md b/paper/README.md
index a053acf..30b2f1a 100644
--- a/paper/README.md
+++ b/paper/README.md
@@ -4,25 +4,25 @@
 - `20220310_plot_causalimpact.ipynb` used to plot the results of the causal impact analysis (e.g. generated via `causalimpact_xgboost.py`)
 - `20220310_plot_forecast_overview.ipynb` used to plot an overview of the historical forecasts (Figure 7 in the main text as well as the model evaluation)
 - `20220310_train_gbdt_on_all.ipynb` used to train the GBDT model on all data (subsequently used to compute the scenarios)
-- `20220306_predict_w_gbdt.ipynb` example for training a GBDT model 
+- `20220306_predict_w_gbdt.ipynb` example for training a GBDT model
 
-## Scripts 
-### Causal impact analysis 
+## Scripts
+### Causal impact analysis
 - `causalimpact_sweep.py` run the hyperparameter sweep (assumes [Weights and Biases](https://wandb.ai/site) is set up)
-- `causalimpact_xgboost.py` run the causal impact analysis using GBDT models 
+- `causalimpact_xgboost.py` run the causal impact analysis using GBDT models
 - `tcn_causalimpact.py` run the analysis using TCN models
-- `step_times.pkl` contains the timestamps for the step changes in our study 
+- `step_times.pkl` contains the timestamps for the step changes in our study
 
-### Scenarios 
-- `loop_over_maps_gbdt.py` / `loop_over_maps_scitas.py` used to create and submit slurm script for "scenario" analysis 
-- `plot_effects_gbdt.py` / `plot_effects.py` used to convert the outputs of the scenario scripts into heatmaps 
-- `run_gbdt_scenarios.py` / `run_scenarios.py` contain the logic for running the scenarios 
+### Scenarios
+- `loop_over_maps_gbdt.py` / `loop_over_maps_scitas.py` used to create and submit slurm scripts for the "scenario" analysis
+- `plot_effects_gbdt.py` / `plot_effects.py` used to convert the outputs of the scenario scripts into heatmaps
+- `run_gbdt_scenarios.py` / `run_scenarios.py` contain the logic for running the scenarios
 
 ### Models
-Model checkpoints are archived on Zenodo (DOI: [https://dx.doi.org/10.5281/zenodo.5153417](10.5281/zenodo.5153417)) but also available in the `model` subdirectory.
+Model checkpoints are archived on Zenodo (DOI: [10.5281/zenodo.5153417](https://dx.doi.org/10.5281/zenodo.5153417)) but are also available in the `model` subdirectory. Unfortunately, we could only serialize the models as pickle files, so the same Python and package versions are needed to reuse the models.
 
-### Results 
+### Results
 
-The `results` subdirectory contains pre-computed results that are used in the notebooks that plot the results.
\ No newline at end of file
+The `results` subdirectory contains the pre-computed results used by the plotting notebooks.
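Since the checkpoints are plain pickles, reusing them amounts to a `joblib.load` in a matching environment. Below is a minimal sketch of how one of the archived GBDT checkpoints might be loaded and backtested; the checkpoint file names are the ones referenced in `run_gbdt_scenarios.py` further down, while the input DataFrame, the measurement-column list, and the `start` fraction are placeholders to adapt.

# -*- coding: utf-8 -*-
# Sketch: re-use a pickled GBDT checkpoint from the Zenodo archive / `model` dir.
# Assumes the same Python and package versions used to serialize the models.
import joblib
import pandas as pd
from darts import TimeSeries

TARGETS_clean = ["2-Amino-2-methylpropanol C4H11NO", "Piperazine C4H10N2"]
MEAS_COLUMNS = []  # placeholder: fill in the measurement columns defined in the scripts

model = joblib.load("20220312_model_all_data_0")  # GBDT model for the first target
x_scaler = joblib.load("20220312_x_transformer")  # darts Scaler for the covariates
y_scaler = joblib.load("20220312_y_transformer")  # darts Scaler for the targets

# Hypothetical time-indexed process data with measurement and target columns
df = pd.read_pickle("process_data.pkl")
x = x_scaler.transform(TimeSeries.from_dataframe(df[MEAS_COLUMNS]))
y = y_scaler.transform(TimeSeries.from_dataframe(df[TARGETS_clean]))

# One-step-ahead backtest, mirroring run_update_historical_forecast below
predictions = model.historical_forecasts(
    series=y[TARGETS_clean[0]],
    past_covariates=x,
    start=0.3,  # initialization fraction; cf. calculate_initialization_percentage
    forecast_horizon=1,
    retrain=False,
)

The scenario and causal-impact scripts below wrap this same pattern behind `click` CLIs, e.g. `python causalimpact_xgboost.py <day> <target>`.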
diff --git a/paper/causalimpact_sweep.py b/paper/causalimpact_sweep.py
index ab89069..b6eb7fa 100644
--- a/paper/causalimpact_sweep.py
+++ b/paper/causalimpact_sweep.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 import logging
 import pickle
 
@@ -92,7 +93,7 @@ def inner_train_test(x, y, day, target):
     x = x[features]
 
     x_trains = []
-    y_trains = [] 
+    y_trains = []
 
     before, during, after, way_after = get_causalimpact_splits(x, y, day, times, DF)
 
@@ -115,9 +116,9 @@ def inner_train_test(x, y, day, target):
         x_trains[shorter] = xscaler.transform(x_trains[shorter])
         y_trains[shorter] = yscaler.transform(y_trains[shorter])
 
-    steps =len(during[0])
+    steps = len(during[0])
 
-    if steps > len(x_trains[shorter]): 
+    if steps > len(x_trains[shorter]):
         ts = choose_index(x, 0.3)
         x_before, x_after = x_trains[longer].split_before(ts)
         y_before, y_after = y_trains[longer].split_before(ts)
diff --git a/paper/causalimpact_xgboost.py b/paper/causalimpact_xgboost.py
index 8aa2f3e..f5918d8 100644
--- a/paper/causalimpact_xgboost.py
+++ b/paper/causalimpact_xgboost.py
@@ -1,27 +1,32 @@
-from aeml.causalimpact.utils import get_timestep_tuples, get_causalimpact_splits
+# -*- coding: utf-8 -*-
+import math
 import pickle
-from aeml.causalimpact.utils import _select_unrelated_x
+import time
+from copy import deepcopy
+
+import click
+import numpy as np
+import pandas as pd
+from darts import TimeSeries
+from darts.dataprocessing.transformers import Scaler
+
+from aeml.causalimpact.utils import (
+    _select_unrelated_x,
+    get_causalimpact_splits,
+    get_timestep_tuples,
+)
 from aeml.models.gbdt.gbmquantile import LightGBMQuantileRegressor
 from aeml.models.gbdt.run import run_ci_model
 from aeml.models.gbdt.settings import *
-from darts.dataprocessing.transformers import Scaler
-from darts import TimeSeries
-import pandas as pd
-from copy import deepcopy
-import time
-import numpy as np
-import click
-import math
-
 
 settings = {
-    0: {0: ci_0_0, 1: ci_0_1},
-1: {0: ci_1_0, 1: ci_1_1},
-2: {0: ci_2_0, 1: ci_2_1},
-3: {0: ci_3_0, 1: ci_3_1},
-4: {0: ci_4_0, 1: ci_4_1},
-5: {0: ci_5_0, 1: ci_5_1},
-6: {0: ci_6_0, 1: ci_6_1}
+    0: {0: ci_0_0, 1: ci_0_1},
+    1: {0: ci_1_0, 1: ci_1_1},
+    2: {0: ci_2_0, 1: ci_2_1},
+    3: {0: ci_3_0, 1: ci_3_1},
+    4: {0: ci_4_0, 1: ci_4_1},
+    5: {0: ci_5_0, 1: ci_5_1},
+    6: {0: ci_6_0, 1: ci_6_1},
 }
 
 TIMESTR = time.strftime("%Y%m%d-%H%M%S")
@@ -87,16 +92,17 @@
 6: [],
 }
 
+
 def select_columns(day):
     feat_to_exclude = to_exclude[day]
     feats = [f for f in MEAS_COLUMNS if f not in feat_to_exclude]
     return feats
 
 
-@click.command('cli')
-@click.argument('day', type=click.INT)
-@click.argument('target', type=click.INT)
-def run_causalimpact_analysis(day, target): 
+@click.command("cli")
+@click.argument("day", type=click.INT)
+@click.argument("target", type=click.INT)
+def run_causalimpact_analysis(day, target):
     cols = select_columns(day)
     y = TimeSeries.from_dataframe(DF)[TARGETS_clean[target]]
     x = TimeSeries.from_dataframe(DF[cols])
@@ -104,11 +110,9 @@ def run_causalimpact_analysis(day, target):
     x_trains = []
     y_trains = []
 
-    before, during, after, way_after = get_causalimpact_splits(
-        x, y, day, times, DF
-    )
+    before, during, after, way_after = get_causalimpact_splits(x, y, day, times, DF)
 
-    # We do multiseries training 
+    # We do multiseries training
     x_trains.append(before[0])
     y_trains.append(before[1])
     x_trains.append(way_after[0])
    y_trains.append(way_after[1])
@@ -126,7 +130,7 @@ def run_causalimpact_analysis(day, target):
         x_trains[shorter] = xscaler.transform(x_trains[shorter])
         y_trains[shorter] = yscaler.transform(y_trains[shorter])
 
-    if len(x_trains[shorter]) < 300: 
+    if len(x_trains[shorter]) < 300:
         x_trains.pop(shorter)
         y_trains.pop(shorter)
@@ -160,39 +164,39 @@ def run_causalimpact_analysis(day, target):
     day_y_df = pd.concat([before_y_df, during_y_df, after_y_df], axis=0)
     day_y_ts = TimeSeries.from_dataframe(day_y_df)
-    
-    steps = math.ceil(len(during[0])/2)# * 2
+
+    steps = math.ceil(len(during[0]) / 2)  # * 2
 
     model = run_ci_model(
         x_trains,
         y_trains,
         **settings[day][target],
         num_features=len(cols),
-        quantiles=(0.05, 0.5, 0.95), 
-        output_chunk_length=steps
+        quantiles=(0.05, 0.5, 0.95),
+        output_chunk_length=steps,
     )
 
-    buffer = math.ceil(len(during[0])/3)
+    buffer = math.ceil(len(during[0]) / 3)
     b = before[1][:-buffer]
 
-    predictions = model.forecast( 
-        n = len(during[0]) + 2* buffer, 
-        series = b, 
-        past_covariates = day_x_ts, 
-
-)
+    predictions = model.forecast(
+        n=len(during[0]) + 2 * buffer,
+        series=b,
+        past_covariates=day_x_ts,
+    )
 
     results = {
-        'predictions': predictions, 
-        'x_all': day_x_ts, 
-        'before': before, 
-        'during': during, 
-        'after': after
+        "predictions": predictions,
+        "x_all": day_x_ts,
+        "before": before,
+        "during": during,
+        "after": after,
    }
 
     with open(
-        f"{TIMESTR}-causalimpact_{day}_{target}",
-        "wb",
-    ) as handle:
-        pickle.dump(results, handle)
+        f"{TIMESTR}-causalimpact_{day}_{target}",
+        "wb",
+    ) as handle:
+        pickle.dump(results, handle)
+
 
-if __name__ == '__main__':
-    run_causalimpact_analysis()
\ No newline at end of file
+if __name__ == "__main__":
+    run_causalimpact_analysis()
diff --git a/paper/loop_over_maps.py b/paper/loop_over_maps.py
index 49d66a3..21b9aac 100644
--- a/paper/loop_over_maps.py
+++ b/paper/loop_over_maps.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
-from os import fchdir
 import subprocess
 import time
+from os import fchdir
 
 import click
 
@@ -46,7 +46,7 @@
 #SBATCH --constraint=gpu
 #SBATCH --account=pr128
 
-module load daint-gpu 
+module load daint-gpu
 source /home/kjablonk/anaconda3/bin/activate
 conda activate aeml
diff --git a/paper/loop_over_maps_scitas.py b/paper/loop_over_maps_scitas.py
index 187f05f..bc17dc3 100644
--- a/paper/loop_over_maps_scitas.py
+++ b/paper/loop_over_maps_scitas.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
-from os import fchdir
 import subprocess
 import time
+from os import fchdir
 
 import click
diff --git a/paper/plot_effects.py b/paper/plot_effects.py
index 181a29e..bca3f38 100644
--- a/paper/plot_effects.py
+++ b/paper/plot_effects.py
@@ -1,13 +1,15 @@
-import matplotlib.pyplot as plt
-from glob import glob
+# -*- coding: utf-8 -*-
+import os
 import pickle
+import traceback
+from glob import glob
 from pathlib import Path
-import numpy as np
+
 import click
 import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
 from scipy.ndimage import gaussian_filter
-import os
-import traceback
 
 
 def load_pickle(filename):
diff --git a/paper/plot_effects_gbdt.py b/paper/plot_effects_gbdt.py
index d9b74d5..9f65e4f 100644
--- a/paper/plot_effects_gbdt.py
+++ b/paper/plot_effects_gbdt.py
@@ -1,23 +1,27 @@
-import matplotlib.pyplot as plt
-from glob import glob
+# -*- coding: utf-8 -*-
+import os
 import pickle
+import traceback
+from glob import glob
 from pathlib import Path
-import numpy as np
+
 import click
 import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+from loguru import logger
 from scipy.ndimage import gaussian_filter
-import os
-import traceback
-from loguru import logger
 
 plt.style.reload_library()
-plt.style.use('science')
+plt.style.use("science")
 from matplotlib import rcParams
-rcParams['font.family'] = 'sans-serif'
+
+rcParams["font.family"] = "sans-serif"
 from scipy.constants import golden
 
 TARGETS_clean = ["2-Amino-2-methylpropanol C4H11NO", "Piperazine C4H10N2"]
 
+
 def load_pickle(filename):
     with open(filename, "rb") as handle:
         res = pickle.load(handle)
@@ -37,6 +41,7 @@ def get_grids(d):
 
     return outer_keys, inner_keys
 
+
 def make_image(res, objective="1"):
     outer, inner = get_grids(res)
 
@@ -47,21 +52,17 @@ def make_image(res, objective="1"):
     for i, point_x in enumerate(outer):
         for j, point_y in enumerate(inner):
             image_m[i][j] = np.sum(
-                res[point_x][point_y][1][objective].values()
-                - res[0][0][1][objective].values()
+                res[point_x][point_y][1][objective].values() - res[0][0][1][objective].values()
             )
 
             image_l[i][j] = np.sum(
-                res[point_x][point_y][0][objective].values()
-                - res[0][0][1][objective].values()
+                res[point_x][point_y][0][objective].values() - res[0][0][1][objective].values()
             )
 
             image_t[i][j] = np.sum(
-                res[point_x][point_y][2][objective].values()
-                - res[0][0][1][objective].values()
+                res[point_x][point_y][2][objective].values() - res[0][0][1][objective].values()
             )
 
-
     return image_m, outer, inner
@@ -79,7 +80,9 @@ def get_conditions_from_name(name):
         .replace("_amine", "")
         .replace("_amp", "")
         .replace("_pz", "")
-        .replace("_nh3", "").replace("_True", "").replace('_False', "")
+        .replace("_nh3", "")
+        .replace("_True", "")
+        .replace("_False", "")
     )
     parts = stem.split("_")
     _ = parts.pop(0)
@@ -121,9 +124,9 @@ def get_color_norm(numbers):
 def cm2inch(*tupl):
     inch = 2.54
     if isinstance(tupl[0], tuple):
-        return tuple(i/inch for i in tupl[0])
+        return tuple(i / inch for i in tupl[0])
     else:
-        return tuple(i/inch for i in tupl)
+        return tuple(i / inch for i in tupl)
@@ -134,13 +137,12 @@ def plot_amp_pz_image(
     one_color_scale: bool = False,
     single_output: bool = False,
     forecast: bool = False,
-    targets = TARGETS_clean
+    targets=TARGETS_clean,
 ):
 
     pip_image, pip_inner, pip_outer = None, None, None
     amp_image, pip_inner, pip_outer = None, None, None
 
-
     for file in all_files:
 
         if condition in file:
@@ -150,14 +152,22 @@ def plot_amp_pz_image(
             if single_output:
                 if "amp" in file:
                     print(f"amp file {file}")
-                    amp_image, amp_inner, amp_outer = make_image(load_pickle(file), objective=amp_name)
+                    amp_image, amp_inner, amp_outer = make_image(
+                        load_pickle(file), objective=amp_name
+                    )
                 if "pz" in file:
                     print(f"pz file {file}")
-                    pip_image, pip_inner, pip_outer = make_image(load_pickle(file), objective=pz_name)
+                    pip_image, pip_inner, pip_outer = make_image(
+                        load_pickle(file), objective=pz_name
+                    )
             else:
                 if "amp" in file:
-                    amp_image, amp_inner, amp_outer = make_image(load_pickle(file), objective=amp_name)
-                    pip_image, pip_inner, pip_outer = make_image(load_pickle(file), objective=pz_name)
+                    amp_image, amp_inner, amp_outer = make_image(
+                        load_pickle(file), objective=amp_name
+                    )
+                    pip_image, pip_inner, pip_outer = make_image(
+                        load_pickle(file), objective=pz_name
+                    )
 
             continue
 
@@ -170,7 +180,7 @@ def plot_amp_pz_image(
     #     except AssertionError:
     #         print(condition)
 
-    fig, ax = plt.subplots(1, 2, sharex="all", sharey="all", figsize=cm2inch(10, 10/golden))
+    fig, ax = plt.subplots(1, 2, sharex="all", sharey="all", figsize=cm2inch(10, 10 / golden))
     print(pip_image.shape)
 
     if blur is not None:
@@ -246,7 +256,9 @@ def plot_amp_pz_image(
 
     if outdir is not None:
         fig.savefig(
-            os.path.join(outdir, f"{raw_conditions}_{str(one_color_scale)}_{str(blur)}_{str(forecast)}.pdf"),
+            os.path.join(
+                outdir, f"{raw_conditions}_{str(one_color_scale)}_{str(blur)}_{str(forecast)}.pdf"
+            ),
             bbox_inches="tight",
         )
         plt.close(fig)
@@ -257,24 +269,31 @@ def plot_amp_pz_image(
 @click.command("cli")
 @click.argument("indir", type=click.Path(exists=True))
 @click.argument("outdir", type=click.Path(exists=False))
-@click.option('--forecast', is_flag=True)
+@click.option("--forecast", is_flag=True)
 def compute_single_output_maps(indir, outdir, forecast):
     if not os.path.exists(outdir):
         os.makedirs(outdir)
 
     if forecast:
         all_files = glob(os.path.join(indir, "*_True"))
-        targets = ['2-Amino-2-methylpropanol C4H11NO', 'Piperazine C4H10N2']
-    else: 
+        targets = ["2-Amino-2-methylpropanol C4H11NO", "Piperazine C4H10N2"]
+    else:
         all_files = glob(os.path.join(indir, "*_False"))
-        targets = ['0', '0']
+        targets = ["0", "0"]
 
-    logger.info(f'Found {len(all_files)} files')
+    logger.info(f"Found {len(all_files)} files")
     all_conditions = get_all_conditions(all_files)
 
     for condition in all_conditions:
         try:
-            plot_amp_pz_image(condition, all_files, outdir=outdir, single_output=True, forecast=forecast, targets=targets)
+            plot_amp_pz_image(
+                condition,
+                all_files,
+                outdir=outdir,
+                single_output=True,
+                forecast=forecast,
+                targets=targets,
+            )
         except Exception as e:
             print(traceback.format_exc())
             print(e)
diff --git a/paper/run_gbdt_scenarios.py b/paper/run_gbdt_scenarios.py
index dfad150..bfa26d7 100644
--- a/paper/run_gbdt_scenarios.py
+++ b/paper/run_gbdt_scenarios.py
@@ -1,10 +1,10 @@
 # -*- coding: utf-8 -*-
-from curses.ascii import TAB
 import os
 import pickle
 import time
 from collections import defaultdict
 from copy import deepcopy
+from curses.ascii import TAB
 
 import click
 import joblib
@@ -12,7 +12,6 @@
 import pandas as pd
 from darts import TimeSeries
 
-
 from aeml.utils import choose_index
 
 THIS_DIR = os.path.dirname(os.path.realpath(__file__))
@@ -68,25 +67,20 @@ def dump_pickle(object, filename):
 FEAT_NUM_MAPPING = dict(zip(MEAS_COLUMNS, [str(i) for i in range(len(MEAS_COLUMNS))]))
 
 UPDATE_MAPPING = {
     "amp": {
-        "scaler": joblib.load(
-            "20220312_y_transformer"
-        ),
+        "scaler": joblib.load("20220312_y_transformer"),
         "model": joblib.load("20220312_model_all_data_0"),
         "name": ["2-Amino-2-methylpropanol C4H11NO"],
     },
     "pz": {
-        "scaler": joblib.load(
-            "20220312_y_transformer"
-        ),
+        "scaler": joblib.load("20220312_y_transformer"),
         "model": joblib.load("20220312_model_all_data_1"),
-        "name": [ "Piperazine C4H10N2"],
+        "name": ["Piperazine C4H10N2"],
     },
 }
 
-SCALER = joblib.load(
-    "20220312_x_transformer"
-)
+SCALER = joblib.load("20220312_x_transformer")
+
+TARGETS_clean = ["2-Amino-2-methylpropanol C4H11NO", "Piperazine C4H10N2"]
 
-TARGETS_clean = ['2-Amino-2-methylpropanol C4H11NO', 'Piperazine C4H10N2']
 
 # making the input one item longer as "safety margin"
 def calculate_initialization_percentage(timeseries_length: int, input_sequence_length: int = 61):
@@ -101,10 +95,10 @@ def run_update_historical_forecast(df, x, target="amine"):
 
     predictions = model_dict["model"].historical_forecasts(
         past_covariates=x,
-        series=y[model_dict['name']],
+        series=y[model_dict["name"]],
         start=calculate_initialization_percentage(len(y)),
         forecast_horizon=1,
-        retrain=False
+        retrain=False,
     )
     return predictions
@@ -120,7 +114,7 @@ def run_update_forecast(df, x, target="amine"):
 
     predictions = model_dict["model"].forecast(
         past_covariates=x,
-        series=y_past[model_dict['name']],
+        series=y_past[model_dict["name"]],
         n=len(x) - len(y_past),
     )
     return predictions
diff --git a/paper/run_scenarios.py b/paper/run_scenarios.py
index e39dec7..048b83a 100644
--- a/paper/run_scenarios.py
+++ b/paper/run_scenarios.py
@@ -9,11 +9,10 @@
 import joblib
 import numpy as np
 import pandas as pd
-from darts import TimeSeries
 import torch
+from darts import TimeSeries
 
-
-from aeml.models.forecast import parallelized_inference, forecast, summarize_results
+from aeml.models.forecast import forecast, parallelized_inference, summarize_results
 from aeml.utils import choose_index
 
 THIS_DIR = os.path.dirname(os.path.realpath(__file__))
diff --git a/paper/tcn_causalimpact.py b/paper/tcn_causalimpact.py
index 9f0f6d6..dc45388 100644
--- a/paper/tcn_causalimpact.py
+++ b/paper/tcn_causalimpact.py
@@ -1,23 +1,20 @@
-from statistics import quantiles
+# -*- coding: utf-8 -*-
 import sys
 import traceback
+from statistics import quantiles
 
 sys.path.append("../src")
-from aeml.causalimpact.utils import get_timestep_tuples
 import pickle
-from aeml.causalimpact.utils import _select_unrelated_x
-from aeml.models.forecast import (
-    forecast,
-    parallelized_inference,
-)
-
-from aeml.models.tcn_dropout import TCNDropout
+import time
+from copy import deepcopy
 
-from darts.dataprocessing.transformers import Scaler
-from darts import TimeSeries
 import pandas as pd
-from copy import deepcopy
-import time
+from darts import TimeSeries
+from darts.dataprocessing.transformers import Scaler
+
+from aeml.causalimpact.utils import _select_unrelated_x, get_timestep_tuples
+from aeml.models.forecast import forecast, parallelized_inference
+from aeml.models.tcn_dropout import TCNDropout
 
 TIMESTR = time.strftime("%Y%m%d-%H%M%S")
diff --git a/paper/train_forecast.py b/paper/train_forecast.py
index b583615..c381807 100644
--- a/paper/train_forecast.py
+++ b/paper/train_forecast.py
@@ -1,19 +1,20 @@
-from aeml.models.run import run_model
-from aeml.models.forecast import forecast, parallelized_inference, summarize_results
-from aeml.models.utils import split_data
-from aeml.models.plotting import make_forecast_plot
-from aeml.utils.io import dump_pickle
+# -*- coding: utf-8 -*-
+import os
+import time
+import traceback
 
+import hydra
+import matplotlib.pyplot as plt
+import pandas as pd
 from darts import TimeSeries
 from darts.dataprocessing.transformers import Scaler
-import pandas as pd
-
-import matplotlib.pyplot as plt
-import hydra
-import time
-import os
 from matplotlib import rcParams
-import traceback
+
+from aeml.models.forecast import forecast, parallelized_inference, summarize_results
+from aeml.models.plotting import make_forecast_plot
+from aeml.models.run import run_model
+from aeml.models.utils import split_data
+from aeml.utils.io import dump_pickle
 
 plt.style.use("science")
 rcParams["font.family"] = "sans-serif"
@@ -36,7 +37,6 @@
 
 from aeml.models.run import run_model
 
-
 STEP_INDICES = [75, 645, 2075, 2880, 3520, 4245]
diff --git a/src/aeml/models/tcn/tcn_dropout.py b/src/aeml/models/tcn/tcn_dropout.py
index 38dab2f..ccf534a 100644
--- a/src/aeml/models/tcn/tcn_dropout.py
+++ b/src/aeml/models/tcn/tcn_dropout.py
@@ -177,7 +177,6 @@ def predict_from_dataset(
         num_loader_workers: int = 0,
         enable_mc_dropout: bool = False,
     ) -> Sequence[TimeSeries]:
-
         """
         This method allows for predicting with a specific :class:`darts.utils.data.InferenceDataset` instance.
         These datasets implement a PyTorch ``Dataset``, and specify how the target and covariates are sliced
@@ -315,7 +314,6 @@ def historical_forecasts(
         verbose: bool = False,
         enable_mc_dropout: bool = False,
     ) -> Union[TimeSeries, List[TimeSeries]]:
-
        """Compute the historical forecasts that would have been obtained by this model on the `series`.
 
        This method uses an expanding training window; it repeatedly builds a training set from the
        beginning of `series`. It trains the