Improve MetricTracker support for metrics nested in ClasswiseWrapper #2994

Open
nathanpainchaud opened this issue Mar 8, 2025 · 1 comment
Labels
enhancement New feature or request

Comments

nathanpainchaud commented Mar 8, 2025

🚀 Feature

I would like to be able to track metrics wrapped in a ClasswiseWrapper. Currently, I face two issues with how MetricTracker and ClasswiseWrapper interact.

MetricTracker does not infer the higher_is_better attribute of metrics wrapped by ClasswiseWrapper

An example of what I would want to do, but that currently raises an error:

```python
from torchmetrics import ClasswiseWrapper, MetricTracker
from torchmetrics.classification import MulticlassAccuracy

metric = MulticlassAccuracy(num_classes=4, average=None)
classwise_metric = ClasswiseWrapper(metric)

# This works because MulticlassAccuracy defines a `higher_is_better` attribute
metric_tracker = MetricTracker(metric, maximize=None)

# Errors with "The metric 'ClasswiseWrapper' does not have a 'higher_is_better' attribute.
# Please provide the `maximize` argument explicitly", because ClasswiseWrapper does not
# propagate the `higher_is_better` attribute from its base metric
classwise_metric_tracker_1 = MetricTracker(classwise_metric, maximize=None)

# This works because we explicitly provide the `maximize` argument to the `MetricTracker`
classwise_metric_tracker_2 = MetricTracker(classwise_metric, maximize=True)
```

Intuitively, I would assume MetricTracker should be able to infer how to maximize the classwise wrapper, given that it can infer how to maximize the base metric. Because of this, I would expect the assignment of classwise_metric_tracker_1 to work.
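To illustrate the kind of propagation I have in mind, here is a minimal pure-Python sketch using stand-in classes (these are hypothetical, not the actual torchmetrics implementation): the wrapper simply forwards the wrapped metric's `higher_is_better` attribute so a tracker could read it.

```python
class BaseMetric:
    """Stand-in for a base metric such as MulticlassAccuracy."""

    higher_is_better = True


class ForwardingWrapper:
    """Stand-in for a wrapper (like ClasswiseWrapper) that forwards the attribute."""

    def __init__(self, metric):
        self.metric = metric
        # Propagate the wrapped metric's attribute so a tracker can infer `maximize`;
        # fall back to None when the wrapped object does not define it
        self.higher_is_better = getattr(metric, "higher_is_better", None)


wrapper = ForwardingWrapper(BaseMetric())
print(wrapper.higher_is_better)  # True
```

With something along these lines in ClasswiseWrapper, `MetricTracker(classwise_metric, maximize=None)` could infer the direction exactly as it does for the base metric.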

MetricTracker cannot get the best metric of a ClasswiseWrapper

I understand that best_metric expects to be able to convert the metric to a scalar, and by definition the ClasswiseWrapper returns a dict of values.

I would like to track ClasswiseWrapper normally when it is part of a MetricCollection, as shown below. This would push users of MetricTracker and ClasswiseWrapper to also rely on MetricCollection if they want more integrated functionality, but that seems like a reasonable assumption / fair compromise to me.

```python
from torch import tensor
from torchmetrics import MetricCollection, MetricTracker
from torchmetrics.classification import MulticlassAccuracy

target = tensor([0, 1, 2, 3])
preds = tensor([0, 2, 1, 3])

metric = MulticlassAccuracy(num_classes=4, average=None)

# As explained above, I expected this first tracker's `best_metric` to fail
metric_tracker = MetricTracker(metric, maximize=None)
metric_tracker.increment()
metric_tracker.update(preds, target)
# As expected, raises a UserWarning "a Tensor with 4 elements cannot be converted to Scalar"
metric_tracker.best_metric()

# But I thought I could solve the issue by nesting everything in a MetricCollection
collection = MetricCollection([metric])
collection_tracker = MetricTracker(collection)
collection_tracker.increment()
collection_tracker.update(preds, target)
# Still raises the same UserWarning "a Tensor with 4 elements cannot be converted to Scalar"
collection_tracker.best_metric()
```

After digging through the documentation, I saw the disclaimer for the MetricTracker:

> However, multiple layers of nesting, such as using a Metric inside a MetricWrapper inside a MetricCollection is not fully supported, especially the .best_metric method that cannot auto compute the best metric and index for such nested structures.

However, I would argue that this is unintuitive from a user's perspective. MetricTracker already supports MetricCollection, and MetricCollection unpacks nested dictionaries returned by compute methods. Thus, my first intuition was that nesting the ClasswiseWrapper inside a MetricCollection would allow MetricTracker to track each classwise value individually, just like it would any other metric in the collection.
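To make the expected semantics concrete (purely for illustration, not a proposed implementation), selecting the best value and step per classwise key over a tracked history of dict results could look roughly like this hypothetical helper:

```python
def best_per_key(history, maximize=True):
    """Return (best_value, best_step) per key over a list of per-step result dicts.

    `history` is a list of dicts mapping metric name -> float, one dict per
    tracked step, in the shape a classwise compute() would produce.
    """
    best_val, best_idx = {}, {}
    for step, results in enumerate(history):
        for key, value in results.items():
            current = best_val.get(key)
            # Keep the new value if this key is unseen or the value improves
            if current is None or (value > current if maximize else value < current):
                best_val[key], best_idx[key] = value, step
    return best_val, best_idx


# Example history with hypothetical classwise keys, one dict per tracked step
history = [
    {"multiclassaccuracy_0": 0.5, "multiclassaccuracy_1": 0.0},
    {"multiclassaccuracy_0": 0.25, "multiclassaccuracy_1": 1.0},
]
```

Here `best_per_key(history)` would report class 0 peaking at step 0 and class 1 at step 1, which is the per-key behavior I would intuitively expect from `best_metric`.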

Still, I understand that this second issue is trickier to implement, and I don't necessarily have a good recommendation for how I would go about doing it.

Motivation

I would argue it is a common pattern to track not just the average metric, but also classwise metrics (e.g. when classes are imbalanced). Given the tracking and classwise functionality already provided by torchmetrics, I expected to be able to implement a generic way to do this using the existing API.

A further motivation is that I would like the ClasswiseWrapper to behave as much as possible like its base metric. I provided simplified examples, but in my real use case I define my metrics (both base metrics and their classwise versions) in configuration files (with Hydra). I group everything in one MetricCollection, and add a MetricTracker on top of that.

The whole torchmetrics API is really well thought-out and incredibly useful to keep my metric logging code generic! Therefore, I would like to avoid having to check the type of metric and special-case ClasswiseWrapper compared to base metrics.

@nathanpainchaud nathanpainchaud added the enhancement New feature or request label Mar 8, 2025

github-actions bot commented Mar 8, 2025

Hi! Thanks for your contribution! Great first issue!

@nathanpainchaud nathanpainchaud changed the title MetricTracker to support infering higher_is_better attribute of metrics wrapped by ClasswiseWrapper Improve MetricTracker support for metrics nested in ClasswiseWrapper Mar 8, 2025