From 3767839c8230b4ab9447423c4ba081d5b69c7a6e Mon Sep 17 00:00:00 2001 From: Philipp Benner Date: Sat, 1 Mar 2025 18:32:18 +0100 Subject: [PATCH] bugfix valid_loss variable name --- equitrain/train.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/equitrain/train.py b/equitrain/train.py index 241772e..a6918c5 100644 --- a/equitrain/train.py +++ b/equitrain/train.py @@ -323,7 +323,7 @@ def _train_with_accelerator(args, accelerator: Accelerator): # Evaluate model before training if True: - val_loss, _ = evaluate( + valid_loss, _ = evaluate( args, model=model, model_ema=model_ema, @@ -332,14 +332,14 @@ def _train_with_accelerator(args, accelerator: Accelerator): logger=logger, ) - best_metrics.update(val_loss.main, args.epochs_start - 1) + best_metrics.update(valid_loss.main, args.epochs_start - 1) accelerator.log( - {'val_loss': val_loss.main['total'].avg}, step=args.epochs_start - 1 + {'val_loss': valid_loss.main['total'].avg}, step=args.epochs_start - 1 ) if accelerator.is_main_process: - val_loss.log(logger, 'val', epoch=args.epochs_start - 1) + valid_loss.log(logger, 'val', epoch=args.epochs_start - 1) # Scheduler step before the first epoch for schedulers depending on the epoch if lr_scheduler is not None: @@ -401,7 +401,7 @@ def _train_with_accelerator(args, accelerator: Accelerator): if args.scheduler_monitor == 'train': lr_scheduler.step(metric=train_loss.main['total'].avg, epoch=epoch) if args.scheduler_monitor == 'val': - lr_scheduler.step(metric=val_loss.main['total'].avg, epoch=epoch) + lr_scheduler.step(metric=valid_loss.main['total'].avg, epoch=epoch) if last_lr is not None and last_lr != lr_scheduler.get_last_lr()[0]: logger.log(