Improve MetricTracker support for metrics nested in ClasswiseWrapper #2994
Labels: enhancement (New feature or request)
🚀 Feature
I would like to be able to track metrics wrapped in a ClasswiseWrapper. Currently, I am facing two issues with how MetricTracker and ClasswiseWrapper interact.

MetricTracker does not infer the higher_is_better attribute of metrics wrapped by ClasswiseWrapper
An example of what I would want to do, but that currently raises an error:
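To make this first issue concrete without depending on a particular torchmetrics version, here is a library-free sketch. Every class and function name in it (ToyAccuracy, ToyMeanSquaredError, NonForwardingWrapper, ForwardingWrapper, infer_maximize) is a hypothetical stand-in, not a torchmetrics API. It only assumes, as the issue describes, that each metric exposes a higher_is_better attribute that may be None, and that the tracker falls back to that attribute when no explicit maximize value is given.

```python
from typing import Optional


class BaseMetric:
    """Hypothetical stand-in for a metric that declares its optimization direction."""

    higher_is_better: Optional[bool] = None


class ToyAccuracy(BaseMetric):
    higher_is_better = True  # larger accuracy is better


class ToyMeanSquaredError(BaseMetric):
    higher_is_better = False  # smaller error is better


class NonForwardingWrapper(BaseMetric):
    """Models the reported behavior: the wrapper leaves higher_is_better as None."""

    def __init__(self, metric: BaseMetric) -> None:
        self.metric = metric


class ForwardingWrapper(NonForwardingWrapper):
    """Models the requested behavior: the direction is read from the wrapped metric."""

    @property
    def higher_is_better(self) -> Optional[bool]:
        # Every per-class value moves in the same direction as the base metric.
        return self.metric.higher_is_better


def infer_maximize(metric: BaseMetric) -> bool:
    """Mimic a tracker that derives `maximize` from `higher_is_better` when not given."""
    if metric.higher_is_better is None:
        raise AttributeError(
            f"{type(metric).__name__} does not define higher_is_better; pass maximize explicitly"
        )
    return metric.higher_is_better


print(infer_maximize(ForwardingWrapper(ToyAccuracy())))          # True
print(infer_maximize(ForwardingWrapper(ToyMeanSquaredError())))  # False
try:
    infer_maximize(NonForwardingWrapper(ToyAccuracy()))
except AttributeError as err:
    print(err)  # NonForwardingWrapper does not define higher_is_better; pass maximize explicitly
```

Forwarding the attribute (rather than hard-coding True) matters because the base metric may be one that should be minimized, as the ToyMeanSquaredError case shows.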
Intuitively, I would assume MetricTracker should be able to infer whether to maximize the classwise wrapper, given that it can already infer this for the base metric. Because of this, I would expect the assignment of classwise_metric_tracker_1 to work.

MetricTracker cannot get the best metric of a ClasswiseWrapper
I understand that best_metric expects to be able to convert the metric to a scalar, and that by definition the ClasswiseWrapper returns a dict of values. What I would like is to track a ClasswiseWrapper normally when it is part of a MetricCollection, as shown below. This would push users of MetricTracker and ClasswiseWrapper to also rely on MetricCollection if they want more integrated functionality, but to me it seems like a reasonable assumption and a fair compromise.

After digging through the documentation, I saw the disclaimer for the MetricTracker:

However, I would argue that this is unintuitive from a user's perspective.
MetricTracker already supports MetricCollection, and MetricCollection unpacks nested dictionaries returned by compute methods. Thus, my first intuition was that nesting the ClasswiseWrapper inside a MetricCollection would allow MetricTracker to track each classwise value individually, just like any other metric in the collection. Still, I understand that this second issue is trickier to implement, and I don't necessarily have a good recommendation for how to go about it.
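The tracking behavior described above can be sketched without torchmetrics. The sketch assumes each step's compute() output is a dict in which a classwise entry is itself a dict of per-class floats (roughly the shape a ClasswiseWrapper produces), and that the caller supplies one maximize flag per top-level metric name. flatten_results and best_metric here are hypothetical helpers: the latter only shares its name with MetricTracker.best_metric and is not that method's implementation. The metric names and values are made up for illustration.

```python
from collections.abc import Mapping
from typing import Dict, List, Tuple, Union

# A compute() result: either a scalar or a dict of per-class scalars.
Value = Union[float, Mapping[str, float]]


def flatten_results(
    results: Mapping[str, Value], maximize: Mapping[str, bool]
) -> Tuple[Dict[str, float], Dict[str, bool]]:
    """Merge nested per-class dicts into the top level, keeping each parent's direction."""
    flat_values: Dict[str, float] = {}
    flat_maximize: Dict[str, bool] = {}
    for name, value in results.items():
        if isinstance(value, Mapping):
            # Classwise result: each per-class key inherits the parent metric's direction.
            for sub_name, sub_value in value.items():
                flat_values[sub_name] = float(sub_value)
                flat_maximize[sub_name] = maximize[name]
        else:
            flat_values[name] = float(value)
            flat_maximize[name] = maximize[name]
    return flat_values, flat_maximize


def best_metric(
    history: List[Mapping[str, Value]], maximize: Mapping[str, bool]
) -> Dict[str, Tuple[float, int]]:
    """Return (best value, step index) for every flattened key across tracked steps."""
    best: Dict[str, Tuple[float, int]] = {}
    for step, results in enumerate(history):
        values, directions = flatten_results(results, maximize)
        for key, value in values.items():
            if key not in best:
                best[key] = (value, step)
                continue
            current, _ = best[key]
            improved = value > current if directions[key] else value < current
            if improved:
                best[key] = (value, step)
    return best


history = [
    {"acc": 0.70, "classwise_acc": {"acc_cat": 0.80, "acc_dog": 0.60}, "loss": 0.9},
    {"acc": 0.75, "classwise_acc": {"acc_cat": 0.78, "acc_dog": 0.72}, "loss": 0.7},
]
maximize = {"acc": True, "classwise_acc": True, "loss": False}
print(best_metric(history, maximize))
# {'acc': (0.75, 1), 'acc_cat': (0.8, 0), 'acc_dog': (0.72, 1), 'loss': (0.7, 1)}
```

Two design choices are worth flagging. Per-class keys inherit the direction of the metric that produced them, which is the same inference the first issue asks for. The flattening also assumes per-class keys are unique across the collection; in this sketch a duplicate key would silently overwrite an earlier one, and ties keep the earliest step.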
Motivation
I would argue it is a common pattern to track not just the average metric, but also classwise metrics (e.g., when classes are imbalanced). Given the tracking and classwise functionality already provided by torchmetrics, I expected to be able to implement a generic way to do this using the existing API.
A further motivation is that I would like the ClasswiseWrapper to behave as much as possible like its base metric. I provided simplified examples, but in my real use case I define my metrics (both base metrics and their classwise versions) in configuration files (with Hydra). I group everything in one MetricCollection and add a MetricTracker on top of that.

The whole torchmetrics API is really well thought-out and incredibly useful for keeping my metric-logging code generic! Therefore, I would like to avoid having to check the type of each metric and special-case ClasswiseWrapper compared to base metrics.