Skip to content

Commit

Permalink
exclude main loss from monitored loss types
Browse files Browse the repository at this point in the history
  • Loading branch information
pbenner committed Mar 2, 2025
1 parent 3767839 commit 8e7c658
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
10 changes: 10 additions & 0 deletions equitrain/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,16 @@ def check_args_complete(args: argparse.ArgumentParser, script_type: str):
raise ValueError(f'Unexpected arguments: {extra}')


def get_loss_monitor(args: argparse.ArgumentParser) -> list[str]:
# Create list of loss types
loss_monitor = [item.strip().lower() for item in args['loss_monitor'].split(',')]

if args.loss_type in loss_monitor:
loss_monitor.remove(args.loss_type)

return loss_monitor


class ArgumentError(ValueError):
"""Custom exception raised when invalid or missing argument is present."""

Expand Down
3 changes: 2 additions & 1 deletion equitrain/loss_fn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch

from equitrain.argparser import get_loss_monitor
from equitrain.data.scatter import scatter_mean
from equitrain.loss import Loss, LossCollection

Expand Down Expand Up @@ -228,7 +229,7 @@ def __init__(self, **args):

# Additional loss metrics
self.loss_fns = {}
for loss_type in args['loss_monitor'].split(','):
for loss_type in get_loss_monitor(args):
args_new = {**args, 'loss_type': loss_type}
self.loss_fns[loss_type] = LossFn(**args_new)

Expand Down
7 changes: 4 additions & 3 deletions equitrain/loss_metrics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from equitrain.argparser import get_loss_monitor
from equitrain.loss import LossCollection


Expand Down Expand Up @@ -141,13 +142,13 @@ def update(self, loss, epoch):
class LossMetrics(dict):
def __init__(self, args):
self.main = LossMetric(args)
self.main_type = args.loss_type
for loss_type in args.loss_monitor.split(','):
self.main_type = args.loss_type.lower()
for loss_type in get_loss_monitor(args):
self[loss_type] = LossMetric(args)

def update(self, loss: LossCollection):
self.main.update(loss.main)
for loss_type, loss_metric in self.items():
for loss_type, _ in self.items():
self[loss_type].update(loss[loss_type])

def log(self, logger, mode: str, epoch=None, step=None, time=None, lr=None):
Expand Down

0 comments on commit 8e7c658

Please sign in to comment.