From 391d9116e13dbedc29f70b3e9e75152ef66dc4a4 Mon Sep 17 00:00:00 2001
From: "sarah.allec"
Date: Tue, 5 Nov 2024 10:34:23 -0800
Subject: [PATCH 1/6] added monitor_dnn_loss function to warn when DNN training loss decreases too quickly

---
 neurobayes/flax_nets/deterministic_nn.py | 16 ++++++++++++----
 neurobayes/utils/utils.py                |  5 +++++
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/neurobayes/flax_nets/deterministic_nn.py b/neurobayes/flax_nets/deterministic_nn.py
index d1da3e6..1da87bf 100644
--- a/neurobayes/flax_nets/deterministic_nn.py
+++ b/neurobayes/flax_nets/deterministic_nn.py
@@ -6,8 +6,10 @@
 import optax
 from functools import partial
 from tqdm import tqdm
+import numpy as np
+import time
 
-from ..utils.utils import split_in_batches
+from ..utils.utils import split_in_batches, monitor_dnn_loss
 
 
 class TrainState(train_state.TrainState):
@@ -97,20 +99,26 @@ def train(self, X_train: jnp.ndarray, y_train: jnp.ndarray, epochs: int, batch_s
         num_batches = len(X_batches)
 
         with tqdm(total=epochs, desc="Training Progress", leave=True) as pbar:  # Progress bar tracks epochs now
+            avg_epoch_losses = np.empty(epochs)
             for epoch in range(epochs):
                 epoch_loss = 0.0
                 for i, (X_batch, y_batch) in enumerate(zip(X_batches, y_batches)):
                     self.state, batch_loss = self.train_step(self.state, X_batch, y_batch)
                     epoch_loss += batch_loss
-
+
                 # Start storing parameters in the last n epochs
                 if epochs - epoch <= self.average_last_n_weights:
                     self._store_params(self.state.params)
 
                 avg_epoch_loss = epoch_loss / num_batches
+                avg_epoch_losses[epoch] = avg_epoch_loss
+                if epoch == 1:
+                    monitor_dnn_loss(avg_epoch_losses)
+
                 pbar.set_postfix_str(f"Epoch {epoch+1}/{epochs}, Avg Loss: {avg_epoch_loss:.4f}")
                 pbar.update(1)
-
+        timestr = time.strftime("%Y%m%d-%H%M%S")
+        np.savez(f'avg_epoch_losses_{timestr}.npz', loss=avg_epoch_losses)
         if self.params_history:  # Ensure there is something to average
             self.state = self.state.replace(params=self.average_params())
 
@@ -143,4 +151,4 @@ def average_params(self) -> Dict:
         return avg_params
 
     def get_params(self) -> Dict:
-        return self.state.params
\ No newline at end of file
+        return self.state.params
diff --git a/neurobayes/utils/utils.py b/neurobayes/utils/utils.py
index 57c16ed..4270d7b 100644
--- a/neurobayes/utils/utils.py
+++ b/neurobayes/utils/utils.py
@@ -5,6 +5,7 @@
 
 import numpy as np
 
+import warnings
 
 def infer_device(device_preference: str = None):
     """
@@ -79,6 +80,10 @@ def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int
 
     return result
 
+def monitor_dnn_loss(loss: np.ndarray) -> bool:
+    if np.diff(loss)[0] / loss[0] < -0.25:
+        warnings.warn('The deterministic training loss is decreasing rapidly - learning and accuracy may be improved by increasing the batch size, adjusting MAP sigma, or modifying the learning rate.')
+    return
 
 def mse(y_pred: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:
     """

From b163ff70b52f50dae199907065886c01b6838b51 Mon Sep 17 00:00:00 2001
From: "sarah.allec"
Date: Tue, 5 Nov 2024 13:05:51 -0800
Subject: [PATCH 2/6] remove saving of losses to file

---
 neurobayes/flax_nets/deterministic_nn.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/neurobayes/flax_nets/deterministic_nn.py b/neurobayes/flax_nets/deterministic_nn.py
index 1da87bf..02b8400 100644
--- a/neurobayes/flax_nets/deterministic_nn.py
+++ b/neurobayes/flax_nets/deterministic_nn.py
@@ -118,7 +118,6 @@ def train(self, X_train: jnp.ndarray, y_train: jnp.ndarray, epochs: int, batch_s
                 pbar.set_postfix_str(f"Epoch {epoch+1}/{epochs}, Avg Loss: {avg_epoch_loss:.4f}")
                 pbar.update(1)
         timestr = time.strftime("%Y%m%d-%H%M%S")
-        np.savez(f'avg_epoch_losses_{timestr}.npz', loss=avg_epoch_losses)
         if self.params_history:  # Ensure there is something to average
             self.state = self.state.replace(params=self.average_params())
 

From 605273ca0d05171c8abf3a70c8545aebc5ab63bf Mon Sep 17 00:00:00 2001
From: "sarah.allec"
Date: Tue, 5 Nov 2024 13:27:14 -0800
Subject: [PATCH 3/6] Updated to account for rapid drops in training loss over all of training

---
 neurobayes/flax_nets/deterministic_nn.py | 6 ++----
 neurobayes/utils/utils.py                | 3 ++-
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/neurobayes/flax_nets/deterministic_nn.py b/neurobayes/flax_nets/deterministic_nn.py
index 02b8400..f9510dc 100644
--- a/neurobayes/flax_nets/deterministic_nn.py
+++ b/neurobayes/flax_nets/deterministic_nn.py
@@ -7,7 +7,6 @@
 from functools import partial
 from tqdm import tqdm
 import numpy as np
-import time
 
 from ..utils.utils import split_in_batches, monitor_dnn_loss
 
@@ -99,7 +98,7 @@ def train(self, X_train: jnp.ndarray, y_train: jnp.ndarray, epochs: int, batch_s
         num_batches = len(X_batches)
 
         with tqdm(total=epochs, desc="Training Progress", leave=True) as pbar:  # Progress bar tracks epochs now
-            avg_epoch_losses = np.empty(epochs)
+            avg_epoch_losses = np.zeros(epochs)
             for epoch in range(epochs):
                 epoch_loss = 0.0
                 for i, (X_batch, y_batch) in enumerate(zip(X_batches, y_batches)):
@@ -112,12 +111,11 @@ def train(self, X_train: jnp.ndarray, y_train: jnp.ndarray, epochs: int, batch_s
 
                 avg_epoch_loss = epoch_loss / num_batches
                 avg_epoch_losses[epoch] = avg_epoch_loss
-                if epoch == 1:
+                if epoch > 0:
                     monitor_dnn_loss(avg_epoch_losses)
 
                 pbar.set_postfix_str(f"Epoch {epoch+1}/{epochs}, Avg Loss: {avg_epoch_loss:.4f}")
                 pbar.update(1)
-        timestr = time.strftime("%Y%m%d-%H%M%S")
         if self.params_history:  # Ensure there is something to average
             self.state = self.state.replace(params=self.average_params())
 
diff --git a/neurobayes/utils/utils.py b/neurobayes/utils/utils.py
index 4270d7b..b6046b4 100644
--- a/neurobayes/utils/utils.py
+++ b/neurobayes/utils/utils.py
@@ -81,7 +81,8 @@ def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int
     return result
 
 def monitor_dnn_loss(loss: np.ndarray) -> bool:
-    if np.diff(loss)[0] / loss[0] < -0.25:
+    loss = loss[loss != 0]
+    if np.diff(loss)[-1] / loss[0] < -0.25:
         warnings.warn('The deterministic training loss is decreasing rapidly - learning and accuracy may be improved by increasing the batch size, adjusting MAP sigma, or modifying the learning rate.')
     return

From b824b187a1946d74ad44f40da19042f2381fd858 Mon Sep 17 00:00:00 2001
From: "sarah.allec"
Date: Tue, 5 Nov 2024 13:31:34 -0800
Subject: [PATCH 4/6] Changed stacklevel of warning to only print once

---
 neurobayes/utils/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/neurobayes/utils/utils.py b/neurobayes/utils/utils.py
index b6046b4..bdd0851 100644
--- a/neurobayes/utils/utils.py
+++ b/neurobayes/utils/utils.py
@@ -83,7 +83,7 @@ def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int
 def monitor_dnn_loss(loss: np.ndarray) -> bool:
     loss = loss[loss != 0]
     if np.diff(loss)[-1] / loss[0] < -0.25:
-        warnings.warn('The deterministic training loss is decreasing rapidly - learning and accuracy may be improved by increasing the batch size, adjusting MAP sigma, or modifying the learning rate.')
+        warnings.warn('The deterministic training loss is decreasing rapidly - learning and accuracy may be improved by increasing the batch size, adjusting MAP sigma, or modifying the learning rate.', stacklevel=2)
     return
 
 def mse(y_pred: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:

From 8cf750889f90dc8c21880a5eb734c64bf3341847 Mon Sep 17 00:00:00 2001
From: "sarah.allec"
Date: Tue, 5 Nov 2024 13:40:11 -0800
Subject: [PATCH 5/6] Updated language of warning message to be more specific

---
 neurobayes/utils/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/neurobayes/utils/utils.py b/neurobayes/utils/utils.py
index bdd0851..8bcd04a 100644
--- a/neurobayes/utils/utils.py
+++ b/neurobayes/utils/utils.py
@@ -83,7 +83,7 @@ def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int
 def monitor_dnn_loss(loss: np.ndarray) -> bool:
     loss = loss[loss != 0]
     if np.diff(loss)[-1] / loss[0] < -0.25:
-        warnings.warn('The deterministic training loss is decreasing rapidly - learning and accuracy may be improved by increasing the batch size, adjusting MAP sigma, or modifying the learning rate.', stacklevel=2)
+        warnings.warn('The deterministic training loss is decreasing rapidly - learning and accuracy may be improved by increasing the batch size, increasing MAP sigma, or adjusting the learning rate.', stacklevel=2)
     return
 
 def mse(y_pred: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:

From 326b7e15d56548d2de27ad9e54164e5b62fc91d0 Mon Sep 17 00:00:00 2001
From: "sarah.allec"
Date: Mon, 18 Nov 2024 09:35:33 -0800
Subject: [PATCH 6/6] Addresses suggestions to add None return type and check for length of loss

---
 neurobayes/utils/utils.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/neurobayes/utils/utils.py b/neurobayes/utils/utils.py
index 8bcd04a..812493f 100644
--- a/neurobayes/utils/utils.py
+++ b/neurobayes/utils/utils.py
@@ -80,10 +80,12 @@ def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int
 
     return result
 
-def monitor_dnn_loss(loss: np.ndarray) -> bool:
+def monitor_dnn_loss(loss: np.ndarray) -> None:
+    """Checks whether current change in loss is greater than a 25% decrease"""
     loss = loss[loss != 0]
-    if np.diff(loss)[-1] / loss[0] < -0.25:
-        warnings.warn('The deterministic training loss is decreasing rapidly - learning and accuracy may be improved by increasing the batch size, increasing MAP sigma, or adjusting the learning rate.', stacklevel=2)
+    if len(loss) > 1:
+        if np.diff(loss)[-1] / loss[0] < -0.25:
+            warnings.warn('The deterministic training loss is decreasing rapidly - learning and accuracy may be improved by increasing the batch size, increasing MAP sigma, or adjusting the learning rate.', stacklevel=2)
     return
 
 def mse(y_pred: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:
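
After PATCH 6/6, monitor_dnn_loss drops the zero entries of the pre-allocated per-epoch loss array, requires at least two recorded losses, and warns when the most recent epoch-to-epoch drop exceeds 25% of the first recorded loss; inside train() it is called once per epoch after the first, so the warning fires for any epoch whose drop from the previous epoch crosses that threshold. A minimal sketch of exercising the final function in isolation is below; the synthetic loss values, the printed counts, and the warnings-capture harness are illustrative only (they assume the function is importable from neurobayes.utils.utils, the module path shown in the diffs) and are not part of the patches.

    import warnings

    import numpy as np

    from neurobayes.utils.utils import monitor_dnn_loss

    # Pre-allocated as in train(): zeros stand in for epochs that have not run yet.
    losses = np.zeros(10)
    losses[0], losses[1] = 1.0, 0.6       # ~40% drop relative to the first loss

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        monitor_dnn_loss(losses)          # trailing zeros filtered out; drop > 25% -> warns
        print(len(caught))                # 1

    losses[1] = 0.9                       # ~10% drop
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        monitor_dnn_loss(losses)          # drop <= 25% of the first loss -> silent
        print(len(caught))                # 0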