From 6a8b15f8758cebf2d7441dd07ca3440453387c2a Mon Sep 17 00:00:00 2001 From: Colin Leong <122366389+cleong110@users.noreply.github.com> Date: Thu, 30 Jan 2025 16:37:06 -0500 Subject: [PATCH 1/2] Adding basic MetricSignature functionality --- pose_evaluation/metrics/base.py | 71 ++++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/pose_evaluation/metrics/base.py b/pose_evaluation/metrics/base.py index 75f8803..b9d663c 100644 --- a/pose_evaluation/metrics/base.py +++ b/pose_evaluation/metrics/base.py @@ -1,11 +1,76 @@ # pylint: disable=undefined-variable from tqdm import tqdm +from typing import Any, Callable +class Signature: + """Represents reproducibility signatures for metrics. Inspired by sacreBLEU + """ + def __init__(self, args: dict): + + self._abbreviated = { + "name":"n", + "higher_is_better":"hb" + } + + self.signature_info = { + "name": args.get("name", None), + "higher_is_better": args.get("higher_is_better", None) + } + + def update(self, key: str, value: Any): + self.signature_info[key] = value + + def update_signature_and_abbr(self, key:str, abbr:str, args:dict): + self._abbreviated.update({ + key: abbr + }) + + self.signature_info.update({ + key: args.get(key, None) + }) + + def format(self, short: bool = False) -> str: + pairs = [] + keys = list(self.signature_info.keys()) + for name in keys: + value = self.signature_info[name] + if value is not None: + # Check for nested signature objects + if hasattr(value, "get_signature"): + + # Wrap nested signatures in brackets + nested_signature = value.get_signature() + if isinstance(nested_signature, Signature): + nested_signature = nested_signature.format(short=short) + value = f"{{{nested_signature}}}" + if isinstance(value, bool): + # Replace True/False with yes/no + value = "yes" if value else "no" + if isinstance(value, Callable): + value = value.__name__ + final_name = self._abbreviated[name] if short else name + pairs.append(f"{final_name}:{value}") + + return "|".join(pairs) + + def __str__(self): + return self.format() + + def __repr__(self): + return self.format() + + +class SignatureMixin: + _SIGNATURE_TYPE = Signature + def get_signature(self) -> Signature: + return self._SIGNATURE_TYPE(self.__dict__) class BaseMetric[T]: """Base class for all metrics.""" + # Each metric should define its Signature class' name here + _SIGNATURE_TYPE = Signature - def __init__(self, name: str, higher_is_better: bool = True): + def __init__(self, name: str, higher_is_better: bool = False): self.name = name self.higher_is_better = higher_is_better @@ -38,3 +103,7 @@ def score_all(self, hypotheses: list[T], references: list[T], progress_bar=True) def __str__(self): return self.name + + def get_signature(self) -> Signature: + return self._SIGNATURE_TYPE(self.__dict__) + From c4dfdd356e3ebaa198183d467b1b5b76e3fabba0 Mon Sep 17 00:00:00 2001 From: Colin Leong <122366389+cleong110@users.noreply.github.com> Date: Fri, 31 Jan 2025 13:31:09 -0500 Subject: [PATCH 2/2] Remove SignatureMixin, and make 'name' mandatory for Signatures --- pose_evaluation/metrics/base.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/pose_evaluation/metrics/base.py b/pose_evaluation/metrics/base.py index b9d663c..e9e7720 100644 --- a/pose_evaluation/metrics/base.py +++ b/pose_evaluation/metrics/base.py @@ -5,7 +5,7 @@ class Signature: """Represents reproducibility signatures for metrics. Inspired by sacreBLEU """ - def __init__(self, args: dict): + def __init__(self, name:str, args: dict): self._abbreviated = { "name":"n", @@ -13,7 +13,7 @@ def __init__(self, args: dict): } self.signature_info = { - "name": args.get("name", None), + "name": name, "higher_is_better": args.get("higher_is_better", None) } @@ -59,12 +59,6 @@ def __str__(self): def __repr__(self): return self.format() - -class SignatureMixin: - _SIGNATURE_TYPE = Signature - def get_signature(self) -> Signature: - return self._SIGNATURE_TYPE(self.__dict__) - class BaseMetric[T]: """Base class for all metrics.""" # Each metric should define its Signature class' name here