bugfixes in new loss code
pbenner committed Feb 12, 2025
1 parent fb49e22 commit 8605d5d
Showing 3 changed files with 45 additions and 27 deletions.
5 changes: 4 additions & 1 deletion equitrain/loss_fn.py
@@ -241,7 +241,10 @@ def forward(self, y_pred, y_true):
         for loss_type, loss_fn in self.loss_fns.items():
             loss[loss_type], _ = loss_fn(
                 # Detach predictions for other loss functions
-                {key: value.detach() for key, value in y_pred.items()},
+                {
+                    key: value.detach() if value is not None else None
+                    for key, value in y_pred.items()
+                },
                 y_true,
             )
 
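The change above guards against entries of y_pred that are None, for example when a model does not predict forces or stress, while still detaching every real tensor so that the monitoring losses cannot backpropagate into the model. A minimal, self-contained sketch of the same None-safe detach pattern; the helper name and example dict are illustrative, not part of the repository:

import torch

def detach_optional(y_pred: dict) -> dict:
    # Detach every tensor and pass None entries through unchanged, so that
    # auxiliary (monitoring-only) losses do not contribute gradients.
    return {
        key: value.detach() if value is not None else None
        for key, value in y_pred.items()
    }

# Hypothetical prediction dict: stress was not computed for this batch.
y_pred = {'energy': torch.randn(4, requires_grad=True), 'stress': None}
monitor_input = detach_optional(y_pred)
assert monitor_input['stress'] is None
assert not monitor_input['energy'].requires_grad
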
36 changes: 26 additions & 10 deletions equitrain/loss_metrics.py
@@ -77,10 +77,10 @@ def log(self, logger, mode: str, epoch=None, step=None, time=None, lr=None):
 
         logger.log(1, info_str)
 
-    def log_step(self, logger, epoch, step, length, time=None, lr=None):
+    def log_step(self, logger, epoch, step, length, mode, time=None, lr=None):
         """Log the current loss metrics."""
 
-        info_str_prefix = f'Epoch [{epoch:>4}][{step:>6}/{length}] -- '
+        info_str_prefix = f'Epoch [{epoch:>4}][{step:>6}/{length}] -- {mode}'
         info_str_postfix = ''
 
         if time is not None:
@@ -89,16 +89,16 @@ def log_step(self, logger, epoch, step, length, time=None, lr=None):
             info_str_postfix += f', lr={lr:.2e}'
 
         info_str = info_str_prefix
-        info_str += f': {self["total"].avg:.5f}'
+        info_str += f': {self["total"].avg:.6f}'
 
         if self['energy'] is not None:
-            info_str += f', energy: {self["energy"].avg:.5f}'
+            info_str += f', energy: {self["energy"].avg:.6f}'
 
         if self['forces'] is not None:
-            info_str += f', forces: {self["forces"].avg:.5f}'
+            info_str += f', forces: {self["forces"].avg:.6f}'
 
         if self['stress'] is not None:
-            info_str += f', stress: {self["stress"].avg:.5f}'
+            info_str += f', stress: {self["stress"].avg:.6f}'
 
         info_str += info_str_postfix
 
@@ -165,11 +165,27 @@ def log(self, logger, mode: str, epoch=None, step=None, time=None, lr=None):
                 f'{mode:>5} {"[" + loss_type + "]":7}',
                 epoch=epoch,
                 step=step,
-                time=time,
-                lr=lr,
+                time=None,
+                lr=None,
             )
 
     def log_step(self, logger, epoch, step, length, time=None, lr=None):
-        self.main.log_step(self, logger, epoch, step, length, time=time, lr=lr)
+        self.main.log_step(
+            logger,
+            epoch,
+            step,
+            length,
+            f'{"[" + self.main_type + "]":7}',
+            time=time,
+            lr=lr,
+        )
         for loss_type, loss_metric in self.items():
-            loss_metric.log_step(self, logger, epoch, step, length, time=time, lr=lr)
+            loss_metric.log_step(
+                logger,
+                epoch,
+                step,
+                length,
+                f'{"[" + loss_type + "]":7}',
+                time=None,
+                lr=None,
+            )
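
The extra mode argument threads a bracketed loss-type tag into every per-step log line, and the averages are now printed with six decimal places instead of five. A rough sketch of the resulting line layout; the function, the example tag "[mae]", and the time format are illustrative, not the repository's API:

def format_step(epoch, step, length, mode, total_avg, time=None, lr=None):
    # Mirrors the f-string layout of log_step: fixed-width epoch/step fields,
    # a bracketed mode tag, six decimals for the loss, and an optional postfix.
    line = f'Epoch [{epoch:>4}][{step:>6}/{length}] -- {mode}: {total_avg:.6f}'
    if time is not None:
        line += f', time: {time:.2f}s'   # exact time format is assumed here
    if lr is not None:
        line += f', lr={lr:.2e}'
    return line

print(format_step(3, 120, 500, f'{"[mae]":7}', 0.0123456, lr=1.0e-4))
# Epoch [   3][   120/500] -- [mae]  : 0.012346, lr=1.00e-04
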
31 changes: 15 additions & 16 deletions equitrain/train.py
@@ -62,9 +62,9 @@ def evaluate_main(
 
             loss, error = loss_fn(y_pred, data)
 
-            # if loss.isnan():
-            #     logger.log(1, 'Nan value detected. Skipping batch...')
-            #     continue
+            if loss.main.isnan():
+                logger.log(1, 'Nan value detected. Skipping batch...')
+                continue
 
             loss_collection += loss
 
@@ -74,8 +74,8 @@
         loss_for_metrics = loss_collection.gather_for_metrics(accelerator)
 
         # Check if loss was NaN for all iterations
-        # if skip['total']:
-        #     continue
+        if loss_collection.main['total'].n == 0.0:
+            continue
 
         loss_metrics.update(loss_for_metrics)
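
The re-enabled checks work in two stages: a sub-batch whose loss is NaN is skipped immediately, and after gathering across processes the whole batch is dropped when no sub-batch contributed a finite value, detected by the running count n of the 'total' metric staying at zero. A minimal sketch of that bookkeeping with a simple running-average meter; the Meter class below is illustrative, not the repository's implementation:

import math

class Meter:
    # Running average that only counts finite contributions.
    def __init__(self):
        self.sum, self.n = 0.0, 0

    def update(self, value):
        if math.isnan(value):
            return          # NaN losses are skipped and do not advance n
        self.sum += value
        self.n += 1

    @property
    def avg(self):
        return self.sum / self.n if self.n else float('nan')

meter = Meter()
for loss in [0.5, float('nan'), 0.3]:
    meter.update(loss)

# Had every sub-batch produced NaN, meter.n would still be 0 and the batch
# would be skipped, mirroring the `loss_collection.main['total'].n == 0.0` test.
assert meter.n == 2 and abs(meter.avg - 0.4) < 1e-12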

@@ -115,7 +115,6 @@ def train_one_epoch(
     optimizer: torch.optim.Optimizer,
     errors: torch.Tensor,
     epoch: int,
-    print_freq: int,
     logger: FileLogger,
 ):
     loss_fn = LossFnCollection(**vars(args))
@@ -144,6 +143,8 @@
         loss_collection = LossCollection(
             args.loss_monitor.split(','), device=accelerator.device
         )
+        # Reset gradients
+        optimizer.zero_grad()
 
         # Sub-batching causes deadlocks when the number of sub-batches varies between
         # processes. We need to loop over sub-batches withouth sync
@@ -155,14 +156,13 @@
             # Evaluate metric to be optimized
             loss, error = loss_fn(y_pred, data)
 
-            # if loss.isnan():
-            #     logger.log(2, 'Nan value detected. Skipping batch...')
-            #     continue
+            if loss.main.isnan():
+                logger.log(2, 'Nan value detected. Skipping batch...')
+                continue
 
             # Backpropagate here to prevent out-of-memory errors, gradients
             # will be accumulated. Since we accumulate gradients over sub-batches,
             # we have to rescale before the backward pass
-
             accelerator.backward(loss.main['total'].value / len(data_list))
 
             loss_collection += loss
@@ -179,17 +179,15 @@
         loss_for_metrics = loss_collection.gather_for_metrics(accelerator)
 
         # Check if loss was NaN for all iterations in one of the processes
-        # if skip['total']:
-        #     optimizer.zero_grad()
-        #     continue
+        if loss_collection.main['total'].n == 0.0:
+            continue
 
         # Clip gradients before optimization step
         if args.gradient_clipping is not None and args.gradient_clipping > 0:
             accelerator.clip_grad_value_(model.parameters(), args.gradient_clipping)
 
         # Sync of gradients across processes occurs here
         optimizer.step()
-        optimizer.zero_grad()
 
         if model_ema is not None:
             model_ema.update()
@@ -199,7 +197,7 @@
         if accelerator.is_main_process:
             # Print intermediate performance statistics only for higher verbose levels
             if args.verbose > 1 and (
-                step % print_freq == 0 or step == len(dataloader) - 1
+                step % args.print_freq == 0 or step == len(dataloader) - 1
             ):
                 w = time.perf_counter() - start_time
                 e = (step + 1) / len(dataloader)
@@ -218,6 +216,8 @@
                     f'Training (lr={optimizer.param_groups[0]["lr"]:.0e}, loss={loss_metrics.main["total"].avg:.04f})'
                 )
 
+    # Reset gradients
+    optimizer.zero_grad()
     # Synchronize updates across processes
     accelerator.wait_for_everyone()
     # Sum local errors across all processes
@@ -336,7 +336,6 @@ def _train_with_accelerator(args, accelerator: Accelerator):
         optimizer=optimizer,
         errors=errors,
         epoch=epoch,
-        print_freq=args.print_freq,
         logger=logger,
     )
 
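The training loop now resets gradients before iterating over sub-batches, calls backward once per sub-batch to accumulate gradients, and divides each sub-batch loss by the number of sub-batches so the accumulated gradient matches a single pass over the full batch; gradients are clipped just before the optimizer step. A self-contained sketch of that accumulation pattern in plain PyTorch, without Accelerate; the model, data, and clipping value are illustrative:

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

inputs = torch.randn(16, 8)
targets = torch.randn(16, 1)
sub_batches = list(zip(inputs.chunk(4), targets.chunk(4)))

optimizer.zero_grad()                      # reset before accumulation starts
for x, y in sub_batches:
    loss = torch.nn.functional.mse_loss(model(x), y)
    # Rescale so that the summed sub-batch gradients equal the gradient of the
    # mean loss over the whole batch, as in accelerator.backward(... / len(data_list)).
    (loss / len(sub_batches)).backward()

torch.nn.utils.clip_grad_value_(model.parameters(), 1.0)   # clip before the step
optimizer.step()                           # applies the accumulated gradients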
