diff --git a/pose_evaluation/metrics/embedding_distance_metric.py b/pose_evaluation/metrics/embedding_distance_metric.py
index fdfa712..875f7f1 100644
--- a/pose_evaluation/metrics/embedding_distance_metric.py
+++ b/pose_evaluation/metrics/embedding_distance_metric.py
@@ -1,8 +1,12 @@
-from typing import Literal, List
+from typing import Literal, List, Union
+import logging
+
 import torch
 from torch import Tensor
+from torch.types import Number
 import numpy as np
 from sentence_transformers import util as st_util
+
 from pose_evaluation.metrics.base_embedding_metric import EmbeddingMetric
@@ -14,20 +18,29 @@
 # * cosine_distance: https://github.com/pgvector/pgvector/blob/master/src/vector.c#L658
 # * l2_distance https://github.com/pgvector/pgvector/blob/master/src/vector.c#L566
 
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+ValidDistanceKinds = Literal["cosine", "euclidean", "manhattan", "dot"]
+TensorConvertableType = Union[List, np.ndarray, Tensor]
+
 
 class EmbeddingDistanceMetric(EmbeddingMetric):
     def __init__(
         self,
-        kind: Literal["cosine", "euclidean", "dot"] = "cosine",
-        device: torch.device | str = None,
-        dtype=torch.float64,
+        kind: ValidDistanceKinds = "cosine",
+        device: Union[torch.device, str] = None,
+        dtype=torch.float32,
     ):
         """
         Initialize the embedding distance metric.
 
         Args:
-            kind (Literal["cosine", "euclidean"]): The type of distance metric.
-            device (torch.device | str): The device to use for computation. If None, automatically detects.
+            kind (ValidDistanceKinds): The type of distance metric.
+            device (Union[torch.device, str]): The device to use for computation.
+                If None, automatically detects.
+            dtype (torch.dtype): The data type to use for tensors.
+                If None, uses torch.get_default_dtype().
         """
         super().__init__(f"EmbeddingDistanceMetric {kind}", higher_is_better=False)
         self.kind = kind
@@ -36,32 +49,77 @@ def __init__(
         else:
             self.device = torch.device(device) if isinstance(device, str) else device
 
+        if dtype is None:
+            dtype = torch.get_default_dtype()
+
+        # Dispatch table for metric computations
+        self._metric_dispatch = {
+            "cosine": self.cosine_distances,
+            "euclidean": self.euclidean_distances,
+            "dot": self.dot_product,
+            "manhattan": self.manhattan_distances,
+        }
+
         self.dtype = dtype
 
-    def _to_device_tensor(self, data: list | np.ndarray | Tensor, dtype=None) -> Tensor:
+    def set_device(self, device: Union[torch.device, str]) -> None:
+        """
+        Explicitly set the device used for tensors.
+
+        Args:
+            device (Union[torch.device, str]): The device to use for computation.
+        """
+        self.device = torch.device(device)
+        logger.info(f"Device set to: {self.device}")
+
+    def _to_tensor_on_device(self, data: TensorConvertableType, dtype=None) -> Tensor:
+        """
+        Convert input data to a tensor on the specified device.
+
+        Args:
+            data (TensorConvertableType): The input data to convert.
+            dtype (torch.dtype): The data type for the tensor.
+
+        Returns:
+            Tensor: Tensor representation of the data on the specified device.
+        """
         if dtype is None:
             dtype = self.dtype
         return st_util._convert_to_tensor(data).to(device=self.device, dtype=dtype)
 
-    def _to_batch_tensor_on_device(self, data: list | np.ndarray | Tensor, dtype=None) -> Tensor:
+    def _to_batch_tensor_on_device(self, data: TensorConvertableType, dtype=None) -> Tensor:
+        """
+        Convert input data to a batch tensor on the specified device.
+
+        Args:
+            data (TensorConvertableType): The input data to convert.
+            dtype (torch.dtype): The data type for the tensor.
+
+        Returns:
+            Tensor: Batch tensor representation of the data on the specified device.
+        """
         if dtype is None:
             dtype = self.dtype
         return st_util._convert_to_batch_tensor(data).to(device=self.device, dtype=dtype)
 
     def score(
         self,
-        hypothesis: list | np.ndarray | Tensor,
-        reference: list | np.ndarray | Tensor,
-    ) -> float:
+        hypothesis: TensorConvertableType,
+        reference: TensorConvertableType,
+    ) -> Number:
         """
         Compute the distance between two embeddings.
 
         Args:
-            hypothesis (list| np.ndarray | Tensor): A single embedding vector.
-            reference (list| np.ndarray | Tensor): Another single embedding vector.
+            hypothesis (TensorConvertableType): A single embedding vector.
+            reference (TensorConvertableType): Another single embedding vector.
 
         Returns:
-            float: The calculated distance.
+            Number: The calculated distance.
+
+        Raises:
+            ValueError: If either input is None.
+            TypeError: If inputs cannot be converted to tensors.
         """
         if hypothesis is None or reference is None:
             raise ValueError("Neither 'hypothesis' nor 'reference' can be None.")
@@ -75,95 +133,89 @@ def score(
 
     def score_all(
         self,
-        hypotheses: List[list | np.ndarray | Tensor],
-        references: List[list | np.ndarray | Tensor],
+        hypotheses: Union[List[TensorConvertableType], Tensor],
+        references: Union[List[TensorConvertableType], Tensor],
         progress_bar: bool = True,
     ) -> Tensor:
         """
-        Compute the pairwise distance between all hypotheses and references.
-        Expects 2D inputs, where each element in the second dimension is one embedding
+        Compute the distance between all hypotheses and all references.
+
+        Expects 2D inputs. If not already Tensors, will attempt to convert them.
 
         Args:
-            hypotheses (list[list| np.ndarray | Tensor]): List of hypothesis embeddings.
-            references (list[list| np.ndarray | Tensor]): List of reference embeddings.
-            progress_bar (bool): Whether to display a progress bar.
+            hypotheses (Union[List[TensorConvertableType], Tensor]):
+                List of hypothesis embeddings or a single tensor.
+            references (Union[List[TensorConvertableType], Tensor]):
+                List of reference embeddings or a single tensor.
+            progress_bar (bool): Whether to display a progress bar (not implemented yet).
 
         Returns:
-            Tensor, distance matrix. Row i is the distances of hypotheses[i] to all rows of references
+            Tensor: Distance matrix. Row `i` is the distances of `hypotheses[i]` to all rows of `references`.
+                Shape is NxM, where N is the number of hypotheses and M is the number of references.
+
+        Raises:
+            ValueError: If the specified metric is unsupported.
""" # Convert inputs to tensors and stack - hypotheses = torch.stack([self._to_device_tensor(h) for h in hypotheses]) - references = torch.stack([self._to_device_tensor(r) for r in references]) - - if self.kind == "dot": - distance_matrix = self.dot_product(hypotheses, references) - - elif self.kind == "cosine": - distance_matrix = self.cosine_distances(hypotheses, references) - - elif self.kind == "euclidean": - distance_matrix = self.euclidean_distances(hypotheses, references) + hypotheses = torch.stack([self._to_tensor_on_device(h) for h in hypotheses]) + references = torch.stack([self._to_tensor_on_device(r) for r in references]) - elif self.kind == "manhattan": - distance_matrix = self.manhattan_distances(hypotheses, references) - - else: + if self.kind not in self._metric_dispatch: + logger.error(f"Unsupported distance metric: {self.kind}") raise ValueError(f"Unsupported distance metric: {self.kind}") + distance_matrix = self._metric_dispatch[self.kind](hypotheses, references) return distance_matrix - def dot_product(self, hypotheses: list | np.ndarray | Tensor, references: list | np.ndarray | Tensor) -> Tensor: + def dot_product(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor: + """ + Compute the dot product between embeddings. + Uses sentence_transformers.util.dot_score + """ # TODO: test if this gives the same thing as previous matmul implementation, see stack overflow link below: # https://stackoverflow.com/questions/73924697/whats-the-difference-between-torch-mm-torch-matmul-and-torch-mul return st_util.dot_score(hypotheses, references) - def euclidean_similarities( - self, hypotheses: list | np.ndarray | Tensor, references: list | np.ndarray | Tensor - ) -> Tensor: + def euclidean_similarities(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor: """ Returns the negative L2 norm/euclidean distances, which is what sentence-transformers uses for similarities. + Uses sentence_transformers.util.euclidean_sim """ return st_util.euclidean_sim(hypotheses, references) - def euclidean_distances( - self, hypotheses: list | np.ndarray | Tensor, references: list | np.ndarray | Tensor - ) -> Tensor: + def euclidean_distances(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor: """ Seeing as how sentence-transformers just negates the distances to get "similarities", We can re-negate to get them positive again. + Uses sentence_transformers.util.euclidean_similarities """ return -self.euclidean_similarities(hypotheses, references) - def cosine_similarities( - self, hypotheses: list | np.ndarray | Tensor, references: list | np.ndarray | Tensor - ) -> Tensor: + def cosine_similarities(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor: """ Calculates cosine similarities, which can be thought of as the angle between two embeddings. The min value is -1 (least similar/pointing directly away), and the max is 1 (exactly the same angle). + Uses sentence_transformers.util.cos_sim """ return st_util.cos_sim(hypotheses, references) - def cosine_distances( - self, hypotheses: list | np.ndarray | Tensor, references: list | np.ndarray | Tensor - ) -> Tensor: + def cosine_distances(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor: """ Converts cosine similarities to distances by simply subtracting from 1. Max distance is 2, min distance is 0. 
""" return 1 - self.cosine_similarities(hypotheses, references) - def manhattan_similarities( - self, hypotheses: list | np.ndarray | Tensor, references: list | np.ndarray | Tensor - ) -> Tensor: + def manhattan_similarities(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor: """ Get the L1/Manhattan similarities, aka negative distances. + Uses sentence_transformers.util.manhattan_sim """ return st_util.manhattan_sim(hypotheses, references) - def manhattan_distances( - self, hypotheses: list | np.ndarray | Tensor, references: list | np.ndarray | Tensor - ) -> Tensor: + def manhattan_distances(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor: """ + Convert Manhattan similarities to distances. Sentence transformers defines similarity as negative distances. We can re-negate to recover the distances. """