Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Local prediction saving #29

Closed
wants to merge 11 commits into from
10 changes: 8 additions & 2 deletions q2_ritme/feature_space/_process_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,19 @@ def process_train(config, train_val, target, host_id, tax, seed_data):
# SPLIT
# todo: refine assignment of features to be used for modelling
train, val = split_data_by_host(train_val_t, host_id, 0.8, seed_data)
X_train, y_train = train[microbial_ft_ls_transf], train[target]
X_val, y_val = val[microbial_ft_ls_transf], val[target]
X_train, y_train, idx_train = (
train[microbial_ft_ls_transf],
train[target],
train.index,
)
X_val, y_val, idx_val = val[microbial_ft_ls_transf], val[target], val.index

return (
X_train.values,
y_train.values,
idx_train,
X_val.values,
y_val.values,
idx_val,
microbial_ft_ls_transf,
)
154 changes: 118 additions & 36 deletions q2_ritme/model_space/static_trainables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import pickle
import random
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Tuple

import joblib
import numpy as np
Expand All @@ -16,7 +16,7 @@
from classo import Classo
from coral_pytorch.dataset import corn_label_from_logits
from coral_pytorch.losses import corn_loss
from lightning import LightningModule, Trainer, seed_everything
from lightning import Callback, LightningModule, Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from ray import train
from ray.air import session
Expand Down Expand Up @@ -118,6 +118,49 @@ def _report_results_manually(
return None


def _predict_from_engineered_x(model, model_type, X):
"""Use model of model_type to create predictions from engineered X"""
if isinstance(model, NeuralNet):
with torch.no_grad():
X_t = torch.tensor(X, dtype=torch.float32)
predicted = model(X_t)
predicted = model._prepare_predictions(predicted).values
elif isinstance(model, dict):
# trac model
log_geom, _ = _preprocess_taxonomy_aggregation(X, model["matrix_a"].values)
alpha = model["model"].values
predicted = log_geom.dot(alpha[1:]) + alpha[0]
predicted = predicted.flatten()
elif isinstance(model, xgb.core.Booster):
X_t = xgb.DMatrix(X)
predicted = model.predict(X_t).flatten()
else:
predicted = model.predict(X).flatten()
return predicted


def get_n_save_predictions(
model, model_type, X_train, y_train, idx_train, X_val, y_val, idx_val
):
split_dic = {
"train": (X_train, y_train, idx_train),
"val": (X_val, y_val, idx_val),
}
pred_ls = []
for split, data in split_dic.items():
X, y, idx = data
y_pred = _predict_from_engineered_x(model, model_type, X)
pred_df = pd.DataFrame({"true": y, "pred": y_pred}, index=idx)
pred_df["split"] = split
pred_ls.append(pred_df)
all_pred = pd.concat(pred_ls)
trial_path = train.get_context().get_trial_dir()
# todo: once you removed the former predictions -> rename to no suffix
path2save = os.path.join(trial_path, "debug_last_log_vs_preds.csv")
all_pred.to_csv(path2save, index=True)
return all_pred


def train_linreg(
config: Dict[str, Any],
train_val: pd.DataFrame,
Expand All @@ -143,10 +186,11 @@ def train_linreg(
None
"""
# ! process dataset: X with features & y with host_id
X_train, y_train, X_val, y_val, ft_col = process_train(
# todo: maybe group X,y,idx into pandas?
X_train, y_train, idx_train, X_val, y_val, idx_val, ft_col = process_train(
config, train_val, target, host_id, tax, seed_data
)

# todo: add X_test, y_test here - with inferred feature engineering à la TunedModel
# ! model
np.random.seed(seed_model)
linreg = ElasticNet(
Expand All @@ -156,6 +200,11 @@ def train_linreg(
)
linreg.fit(X_train, y_train)

# ! save predictions
_ = get_n_save_predictions(
linreg, "linreg", X_train, y_train, idx_train, X_val, y_val, idx_val
)

_report_results_manually(linreg, X_train, y_train, X_val, y_val, tax)


Expand Down Expand Up @@ -230,7 +279,7 @@ def train_trac(
None
"""
# ! process dataset: X with features & y with host_id
X_train, y_train, X_val, y_val, ft_col = process_train(
X_train, y_train, idx_train, X_val, y_val, idx_val, ft_col = process_train(
config, train_val, target, host_id, tax, seed_data
)
# ! derive matrix A
Expand All @@ -257,7 +306,11 @@ def train_trac(
matrices_train, selected_param, intercept=intercept
)

# ! save predictions
model = _bundle_trac_model(alpha, a_df)
_ = get_n_save_predictions(
model, "trac", X_train, y_train, idx_train, X_val, y_val, idx_val
)

_report_results_manually_trac(
model, log_geom_train, y_train, log_geom_val, y_val, tax
Expand Down Expand Up @@ -289,7 +342,7 @@ def train_rf(
None
"""
# ! process dataset
X_train, y_train, X_val, y_val, ft_col = process_train(
X_train, y_train, idx_train, X_val, y_val, idx_val, ft_col = process_train(
config, train_val, target, host_id, tax, seed_data
)

Expand All @@ -307,11 +360,14 @@ def train_rf(
)
rf.fit(X_train, y_train)

# ! save predictions
_ = get_n_save_predictions(
rf, "rf", X_train, y_train, idx_train, X_val, y_val, idx_val
)
_report_results_manually(rf, X_train, y_train, X_val, y_val, tax)


class NeuralNet(LightningModule):
# TODO: adjust to have option of NNcorn also within
def __init__(self, n_units, learning_rate, nn_type="regression"):
super(NeuralNet, self).__init__()
self.save_hyperparameters() # This saves all passed arguments to self.hparams
Expand Down Expand Up @@ -443,32 +499,43 @@ def load_data(X_train, y_train, X_val, y_val, config):
return train_loader, val_loader


class NNTuneReportCheckpointCallback(TuneReportCheckpointCallback):
def __init__(
self,
metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
filename: str = "checkpoint",
save_checkpoints: bool = True,
on: Union[str, List[str]] = "validation_end",
nb_features: int = None,
):
super().__init__(
metrics=metrics, filename=filename, save_checkpoints=save_checkpoints, on=on
class PostTrainingCallback(Callback):
def __init__(self, nn_type, X_train, y_train, idx_train, X_val, y_val, idx_val):
super().__init__()
self.nn_type = nn_type
self.X_train = X_train
self.y_train = y_train
self.idx_train = idx_train
self.X_val = X_val
self.y_val = y_val
self.idx_val = idx_val

def on_validation_epoch_end(self, trainer, pl_module):
# Your post-training logic here
_ = get_n_save_predictions(
pl_module,
self.nn_type,
self.X_train,
self.y_train,
self.idx_train,
self.X_val,
self.y_val,
self.idx_val,
)
self.nb_features = nb_features

def _handle(self, trainer: Trainer, pl_module: LightningModule):
# CUSTOM: includes also nb_features in report
if trainer.sanity_checking:
return

report_dict = self._get_report_dict(trainer, pl_module)
report_dict["nb_features"] = self.nb_features
if not report_dict:
return
class CustomTuneReportCallback(TuneReportCheckpointCallback):
def __init__(self, *args, post_training_callback=None, **kwargs):
super().__init__(*args, **kwargs)
self.post_training_callback = post_training_callback

with self._get_checkpoint(trainer) as checkpoint:
train.report(report_dict, checkpoint=checkpoint)
def on_validation_epoch_end(self, trainer, pl_module):
# this ensures that the predictions are saved before
# TuneReportCheckpointCallback is called and tune is getting the signal
# to stop the trial
if self.post_training_callback:
self.post_training_callback.on_validation_epoch_end(trainer, pl_module)
super().on_validation_epoch_end(trainer, pl_module)


def train_nn(
Expand All @@ -485,7 +552,7 @@ def train_nn(
seed_everything(seed_model, workers=True)

# Process dataset
X_train, y_train, X_val, y_val, ft_col = process_train(
X_train, y_train, idx_train, X_val, y_val, idx_val, ft_col = process_train(
config, train_val, target, host_id, tax, seed_data
)

Expand Down Expand Up @@ -522,6 +589,15 @@ def train_nn(

os.makedirs(checkpoint_dir, exist_ok=True)

post_training_callback = PostTrainingCallback(
nn_type=nn_type,
X_train=X_train,
y_train=y_train,
idx_train=idx_train,
X_val=X_val,
y_val=y_val,
idx_val=idx_val,
)
callbacks = [
ModelCheckpoint(
monitor="val_rmse",
Expand All @@ -531,18 +607,22 @@ def train_nn(
dirpath=checkpoint_dir, # Automatically set dirpath
filename="{epoch}-{val_rmse:.2f}",
),
NNTuneReportCheckpointCallback(
# the below callback signals to ray tune that the trainable is finished
# - hence post_training_callback must be set to store predictions
CustomTuneReportCallback(
metrics={
"rmse_val": "val_rmse",
"rmse_train": "train_rmse",
"r2_val": "val_r2",
"r2_train": "train_r2",
"loss_val": "val_loss",
"loss_train": "train_loss",
# "nb_features": "nb_features",
},
filename="checkpoint",
on="validation_end",
nb_features=X_train.shape[1],
save_checkpoints=True,
post_training_callback=post_training_callback,
),
]

Expand Down Expand Up @@ -652,7 +732,7 @@ def train_xgb(
None
"""
# ! process dataset
X_train, y_train, X_val, y_val, ft_col = process_train(
X_train, y_train, idx_train, X_val, y_val, idx_val, ft_col = process_train(
config, train_val, target, host_id, tax, seed_data
)
# Set seeds
Expand Down Expand Up @@ -682,12 +762,14 @@ def train_xgb(
)
# todo: add test set here to be tracked as well

xgb.train(
xgb_model = xgb.train(
config,
dtrain,
evals=[(dtrain, "train"), (dval, "val")],
callbacks=[checkpoint_callback],
custom_metric=custom_xgb_metric,
)

# TODO: add test set here to be tracked as well
# ! save predictions
_ = get_n_save_predictions(
xgb_model, "xgb", X_train, y_train, idx_train, X_val, y_val, idx_val
)
Loading