Skip to content

Commit

Permalink
optional monitoring of validation loss for the scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
pbenner committed Mar 1, 2025
1 parent c8256d5 commit 80a112a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
9 changes: 9 additions & 0 deletions equitrain/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def add_optimizer_args(parser: argparse.ArgumentParser) -> argparse.ArgumentPars
)
parser.add_argument(
'--plateau-mode',
choices=['min', 'max'],
type=str,
default='min',
help='One of min, max. In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing (default: min)',
Expand Down Expand Up @@ -276,9 +277,17 @@ def get_args_parser(script_type: str) -> argparse.ArgumentParser:
parser.add_argument(
'--scheduler', help='LR scheduler type', type=str, default='plateau'
)
parser.add_argument(
'--scheduler_monitor',
help='Loss monitored by the scheduler [train (default), val]',
choices=['train', 'val'],
type=str,
default='train',
)
parser.add_argument(
'--loss-type',
help='Type of loss function [mae, smooth-l1, mse, huber (default)]',
choices=['mae', 'smooth-l1', 'mse', 'huber'],
type=str,
default='huber',
)
Expand Down
5 changes: 4 additions & 1 deletion equitrain/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,10 @@ def _train_with_accelerator(args, accelerator: Accelerator):
)

if lr_scheduler is not None:
lr_scheduler.step(metric=train_loss.main['total'].avg, epoch=epoch)
if args.scheduler_monitor == 'train':
lr_scheduler.step(metric=train_loss.main['total'].avg, epoch=epoch)
if args.scheduler_monitor == 'val':
lr_scheduler.step(metric=val_loss.main['total'].avg, epoch=epoch)

if last_lr is not None and last_lr != lr_scheduler.get_last_lr()[0]:
logger.log(
Expand Down

0 comments on commit 80a112a

Please sign in to comment.