Skip to content

Commit

Permalink
bugfix valid_loss variable name
Browse files Browse the repository at this point in the history
  • Loading branch information
pbenner committed Mar 1, 2025
1 parent da4a1ec commit 3767839
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions equitrain/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def _train_with_accelerator(args, accelerator: Accelerator):

# Evaluate model before training
if True:
val_loss, _ = evaluate(
valid_loss, _ = evaluate(
args,
model=model,
model_ema=model_ema,
Expand All @@ -332,14 +332,14 @@ def _train_with_accelerator(args, accelerator: Accelerator):
logger=logger,
)

best_metrics.update(val_loss.main, args.epochs_start - 1)
best_metrics.update(valid_loss.main, args.epochs_start - 1)

accelerator.log(
{'val_loss': val_loss.main['total'].avg}, step=args.epochs_start - 1
{'val_loss': valid_loss.main['total'].avg}, step=args.epochs_start - 1
)

if accelerator.is_main_process:
val_loss.log(logger, 'val', epoch=args.epochs_start - 1)
valid_loss.log(logger, 'val', epoch=args.epochs_start - 1)

# Scheduler step before the first epoch for schedulers depending on the epoch
if lr_scheduler is not None:
Expand Down Expand Up @@ -401,7 +401,7 @@ def _train_with_accelerator(args, accelerator: Accelerator):
if args.scheduler_monitor == 'train':
lr_scheduler.step(metric=train_loss.main['total'].avg, epoch=epoch)
if args.scheduler_monitor == 'val':
lr_scheduler.step(metric=val_loss.main['total'].avg, epoch=epoch)
lr_scheduler.step(metric=valid_loss.main['total'].avg, epoch=epoch)

if last_lr is not None and last_lr != lr_scheduler.get_last_lr()[0]:
logger.log(
Expand Down

0 comments on commit 3767839

Please sign in to comment.