Merge pull request #24 from sarah-allec/warning-dnn-loss
Adds a warning message when the deterministic training loss stagnates too quickly in partial BNNs
ziatdinovmax authored Nov 19, 2024
2 parents 857e21a + 326b7e1 commit e735fea
Showing 2 changed files with 17 additions and 4 deletions.
13 changes: 9 additions & 4 deletions neurobayes/flax_nets/deterministic_nn.py
@@ -6,8 +6,9 @@
 import optax
 from functools import partial
 from tqdm import tqdm
+import numpy as np

-from ..utils.utils import split_in_batches
+from ..utils.utils import split_in_batches, monitor_dnn_loss


 class TrainState(train_state.TrainState):
@@ -97,20 +98,24 @@ def train(self, X_train: jnp.ndarray, y_train: jnp.ndarray, epochs: int, batch_s
         num_batches = len(X_batches)

         with tqdm(total=epochs, desc="Training Progress", leave=True) as pbar: # Progress bar tracks epochs now
+            avg_epoch_losses = np.zeros(epochs)
             for epoch in range(epochs):
                 epoch_loss = 0.0
                 for i, (X_batch, y_batch) in enumerate(zip(X_batches, y_batches)):
                     self.state, batch_loss = self.train_step(self.state, X_batch, y_batch)
                     epoch_loss += batch_loss

                     # Start storing parameters in the last n epochs
                     if epochs - epoch <= self.average_last_n_weights:
                         self._store_params(self.state.params)

                 avg_epoch_loss = epoch_loss / num_batches
+                avg_epoch_losses[epoch] = avg_epoch_loss
+                if epoch > 0:
+                    monitor_dnn_loss(avg_epoch_losses)

                 pbar.set_postfix_str(f"Epoch {epoch+1}/{epochs}, Avg Loss: {avg_epoch_loss:.4f}")
                 pbar.update(1)

         if self.params_history: # Ensure there is something to average
             self.state = self.state.replace(params=self.average_params())
@@ -143,4 +148,4 @@ def average_params(self) -> Dict:
         return avg_params

     def get_params(self) -> Dict:
-        return self.state.params
\ No newline at end of file
+        return self.state.params
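
To see the new control flow in isolation, here is a minimal runnable sketch of the monitoring pattern added above. The epoch/batch counts and the synthetic batch_loss values are hypothetical stand-ins for self.train_step(...); only the avg_epoch_losses bookkeeping and the monitor_dnn_loss call mirror the diff:

import numpy as np
from neurobayes.utils.utils import monitor_dnn_loss

epochs, num_batches = 5, 10
avg_epoch_losses = np.zeros(epochs)  # zero-filled; monitor_dnn_loss ignores slots for epochs that have not run

for epoch in range(epochs):
    epoch_loss = 0.0
    for batch in range(num_batches):
        # hypothetical stand-in for the batch loss returned by self.train_step(...)
        batch_loss = 1.0 / (epoch * num_batches + batch + 1)
        epoch_loss += batch_loss

    avg_epoch_losses[epoch] = epoch_loss / num_batches
    if epoch > 0:  # need at least two recorded losses before diffing
        monitor_dnn_loss(avg_epoch_losses)  # emits a UserWarning on a large early drop

With this synthetic trajectory the warning fires at epoch 1, where the average loss falls from about 0.29 to 0.07 (a drop of well over 25% of the initial loss); the later, gentler decreases stay silent.
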
8 changes: 8 additions & 0 deletions neurobayes/utils/utils.py
@@ -5,6 +5,7 @@

 import numpy as np

+import warnings

 def infer_device(device_preference: str = None):
     """
@@ -79,6 +80,13 @@ def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int

     return result

+def monitor_dnn_loss(loss: np.ndarray) -> None:
+    """Warns if the most recent change in loss is more than a 25% decrease relative to the initial loss"""
+    loss = loss[loss != 0]  # drop zero-initialized slots for epochs that have not run yet
+    if len(loss) > 1:
+        if np.diff(loss)[-1] / loss[0] < -0.25:
+            warnings.warn('The deterministic training loss is decreasing rapidly - learning and accuracy may be improved by increasing the batch size, increasing MAP sigma, or adjusting the learning rate.', stacklevel=2)
+    return

 def mse(y_pred: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:
     """
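
For intuition about the threshold, the check compares the latest epoch-to-epoch change against 25% of the initial (first nonzero) loss. A quick self-contained illustration with made-up loss values:

import warnings
import numpy as np
from neurobayes.utils.utils import monitor_dnn_loss

rapid = np.array([1.0, 0.9, 0.6])   # latest change: 0.6 - 0.9 = -0.3, i.e. 30% of the initial loss
gentle = np.array([1.0, 0.9, 0.8])  # latest change: -0.1, only 10% of the initial loss

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    monitor_dnn_loss(rapid)   # -0.3 / 1.0 < -0.25 -> warns
    monitor_dnn_loss(gentle)  # -0.1 / 1.0 > -0.25 -> silent

print(len(caught))  # 1

Because the drop is normalized by the first recorded loss rather than the previous epoch's loss, a trajectory that has already decayed far below its starting value will rarely re-trigger the warning.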
