Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: refactor model definitions + evaluations #30

Merged
merged 7 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@ jobs:

- name: Install dependencies
run: |
pip install -q https://github.com/qiime2/q2lint/archive/master.zip
pip install -q flake8
pip install -q ruff

- name: Lint
run: make lint
- name: Ruff
uses: chartboost/ruff-action@v1

build-and-test:
needs: lint
Expand Down
3 changes: 1 addition & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ PYTHON ?= python
all: ;

lint:
q2lint
flake8
ruff check

test: all
py.test
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ If you are using SLURM and ...
## Model tracking
In the config file you can choose to track your trials with MLflow (tracking_uri=="mlruns") or with WandB (tracking_uri=="wandb").

### Choice between MLflow & WandB
WandB stores aggregate metrics on their servers. The way *ritme* is set up no sample-specific information is stored remotely. This information is stored on your local machine.
To choose which tracking set-up works best for you, it is best to review the respective services.

### MLflow
In case of using MLflow you can view your models with `mlflow ui --backend-store-uri experiments/mlruns`. For more information check out the [official MLflow documentation](https://mlflow.org/docs/latest/index.html).

Expand Down
278 changes: 278 additions & 0 deletions experiments/raytune_nn_rmse_mismatch.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inconsistent train metrics and logging frequency in Ray Tune with PyTorch Lightning\n",
"\n",
"**The Bug:**\n",
"When I train a Pytorch Lightning model with Ray Tune and I extract the best train metric from `results.get_best_result()` the metric differs from the recalculated metric using the best checkpoint on the same data (`rmse_train_recalc` in script). However, the respective validation metrics(`best_rmse_val` & `rmse_val_recalc`) remain consistent. \n",
"Also, the counts on how often the metrics are logged differ between train and validation set: train is only logged 9 times (`train_log_count`) whereas validation is logged 10 times (`val_log_count`).\n",
"Why is that? \n",
"\n",
"**Expected behavior:**\n",
"- Consistent train metrics between best result and recalculation using best checkpoint\n",
"- Equal logging frequency (10 times) for both train and validation metrics\n",
"\n",
"**Environment:**\n",
"- PyTorch: 2.4\n",
"- Lightning: 2.4.0\n",
"- Ray: 2.24.0\n",
"\n",
"can be created as:\n",
"```\n",
"conda create -n ray_bug_recom -y\n",
"conda activate ray_bug_recom\n",
"conda install -c conda-forge -c pytorch pytorch==2.4 lightning==2.4.0 -y\n",
"pip install \"ray[data,train,tune,serve]\"==2.24.0\n",
"```\n",
"\n",
"**Opened issue with `raytune_nn_rmse_mismatch.py` to [ray repos here](https://github.com/ray-project/ray/issues/47333) on Monday 26th Aug'24.**\n",
"\n",
"Note: Issue is reproducible in both ray_bug_recom and ritme conda environments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import torch\n",
"from lightning import LightningModule, Trainer\n",
"from ray import init, tune\n",
"from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback\n",
"from torch import nn\n",
"from torch.utils.data import DataLoader, TensorDataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SimpleNN(LightningModule):\n",
" def __init__(self, input_size, hidden_size, learning_rate):\n",
" super().__init__()\n",
" self.save_hyperparameters()\n",
" self.model = nn.Sequential(\n",
" nn.Linear(input_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 1)\n",
" )\n",
" self.learning_rate = learning_rate\n",
" self.train_loss = 0\n",
" self.val_loss = 0\n",
" self.train_predictions = []\n",
" self.train_targets = []\n",
" self.val_predictions = []\n",
" self.val_targets = []\n",
" self.train_log_count = 0\n",
" self.val_log_count = 0\n",
"\n",
" def forward(self, x):\n",
" return self.model(x)\n",
"\n",
" def training_step(self, batch, batch_idx):\n",
" x, y = batch\n",
" y_hat = self(x)\n",
" loss = nn.MSELoss()(y_hat, y)\n",
" self.train_loss = loss\n",
"\n",
" self.train_predictions.append(y_hat.detach())\n",
" self.train_targets.append(y.detach())\n",
"\n",
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
" x, y = batch\n",
" y_hat = self(x)\n",
" loss = nn.MSELoss()(y_hat, y)\n",
" self.val_loss = loss\n",
"\n",
" self.val_predictions.append(y_hat.detach())\n",
" self.val_targets.append(y.detach())\n",
"\n",
" self.log(\"val_loss\", loss)\n",
"\n",
" return {\"val_loss\": loss}\n",
"\n",
" def on_train_epoch_end(self):\n",
" all_preds_train = torch.cat(self.train_predictions)\n",
" all_targets_train = torch.cat(self.train_targets)\n",
"\n",
" rmse_train = torch.sqrt(\n",
" nn.functional.mse_loss(all_preds_train, all_targets_train)\n",
" )\n",
" self.train_log_count += 1\n",
" self.log(\"train_log_count\", self.train_log_count)\n",
" self.log(\"rmse_train\", rmse_train)\n",
"\n",
" self.train_predictions.clear()\n",
" self.train_targets.clear()\n",
"\n",
" def on_validation_epoch_end(self):\n",
" all_preds_val = torch.cat(self.val_predictions)\n",
" all_targets_val = torch.cat(self.val_targets)\n",
"\n",
" rmse_val = torch.sqrt(nn.functional.mse_loss(all_preds_val, all_targets_val))\n",
"\n",
" self.val_log_count += 1\n",
" self.log(\"val_log_count\", self.val_log_count)\n",
" self.log(\"rmse_val\", rmse_val)\n",
"\n",
" self.val_predictions.clear()\n",
" self.val_targets.clear()\n",
"\n",
" def configure_optimizers(self):\n",
" return torch.optim.Adam(self.parameters(), lr=self.learning_rate)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train_nn(config, train_data, val_data):\n",
" model = SimpleNN(\n",
" input_size=10, hidden_size=config[\"hidden_size\"], learning_rate=config[\"lr\"]\n",
" )\n",
"\n",
" train_loader = DataLoader(train_data, batch_size=800)\n",
" val_loader = DataLoader(val_data, batch_size=800)\n",
"\n",
" trainer = Trainer(\n",
" max_epochs=10,\n",
" num_sanity_val_steps=0,\n",
" check_val_every_n_epoch=1,\n",
" val_check_interval=1,\n",
" callbacks=[\n",
" TuneReportCheckpointCallback(\n",
" metrics={\n",
" \"loss\": \"val_loss\",\n",
" \"rmse_val\": \"rmse_val\",\n",
" \"rmse_train\": \"rmse_train\",\n",
" \"val_log_count\": \"val_log_count\",\n",
" \"train_log_count\": \"train_log_count\",\n",
" },\n",
" filename=\"checkpoint\",\n",
" on=\"validation_end\",\n",
" save_checkpoints=True,\n",
" ),\n",
" ],\n",
" deterministic=True,\n",
" )\n",
"\n",
" trainer.fit(model, train_loader, val_loader)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = {\n",
" \"lr\": tune.loguniform(1e-4, 1e-1),\n",
" \"hidden_size\": tune.choice([32, 64, 128]),\n",
"}\n",
"\n",
"init(\n",
" address=\"local\",\n",
" include_dashboard=False,\n",
" ignore_reinit_error=True,\n",
")\n",
"\n",
"torch.manual_seed(42)\n",
"X = torch.randn(1000, 10)\n",
"y = torch.sum(X, dim=1, keepdim=True)\n",
"X_train, y_train = X[:800], y[:800]\n",
"X_val, y_val = X[800:], y[800:]\n",
"\n",
"train_data = TensorDataset(X_train, y_train)\n",
"val_data = TensorDataset(X_val, y_val)\n",
"\n",
"tuner = tune.Tuner(\n",
" tune.with_parameters(train_nn, train_data=train_data, val_data=val_data),\n",
" tune_config=tune.TuneConfig(metric=\"rmse_val\", mode=\"min\", num_samples=2),\n",
" param_space=config,\n",
")\n",
"\n",
"results = tuner.fit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get best ray result\n",
"best_result = results.get_best_result(\"rmse_val\", \"min\", scope=\"all\")\n",
"best_rmse_train = best_result.metrics[\"rmse_train\"]\n",
"best_rmse_val = best_result.metrics[\"rmse_val\"]\n",
"print(f\"Best trial final train rmse: {best_rmse_train}\")\n",
"print(f\"Best trial final validation rmse: {best_rmse_val}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get best model checkpoint\n",
"checkpoint_dir = best_result.checkpoint.path\n",
"checkpoint_path = os.path.join(checkpoint_dir, \"checkpoint\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load model\n",
"model = SimpleNN.load_from_checkpoint(checkpoint_path)\n",
"\n",
"# recalculate rmse_train\n",
"rmse_train_recalc = torch.sqrt(nn.functional.mse_loss(model(X_train), y_train)).item()\n",
"print(f\"rmse_train_recalc: {rmse_train_recalc}\")\n",
"\n",
"# recalculate rmse_val\n",
"rmse_val_recalc = torch.sqrt(nn.functional.mse_loss(model(X_val), y_val)).item()\n",
"print(f\"rmse_val_recalc: {rmse_val_recalc}\")\n",
"\n",
"# assertions\n",
"if not best_rmse_val == rmse_val_recalc:\n",
" raise ValueError(\"best_rmse_val != rmse_val_recalc\")\n",
"if not best_rmse_train == rmse_train_recalc:\n",
" raise ValueError(\"best_rmse_train != rmse_train_recalc\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ritme",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading
Loading