Various stylistic and commenting changes
cleong110 committed Dec 4, 2024
1 parent 3ca874e commit d7fb10e
Showing 1 changed file with 108 additions and 56 deletions.
164 changes: 108 additions & 56 deletions pose_evaluation/metrics/embedding_distance_metric.py
@@ -1,8 +1,12 @@
from typing import Literal, List
from typing import Literal, List, Union
import logging

import torch
from torch import Tensor
from torch.types import Number
import numpy as np
from sentence_transformers import util as st_util

from pose_evaluation.metrics.base_embedding_metric import EmbeddingMetric


@@ -14,20 +18,29 @@
# * cosine_distance: https://github.com/pgvector/pgvector/blob/master/src/vector.c#L658
# * l2_distance https://github.com/pgvector/pgvector/blob/master/src/vector.c#L566

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

ValidDistanceKinds = Literal["cosine", "euclidean", "manhattan", "dot"]
TensorConvertableType = Union[List, np.ndarray, Tensor]


class EmbeddingDistanceMetric(EmbeddingMetric):
def __init__(
self,
kind: Literal["cosine", "euclidean", "dot"] = "cosine",
device: torch.device | str = None,
dtype=torch.float64,
kind: ValidDistanceKinds = "cosine",
device: Union[torch.device, str] = None,
dtype=torch.float32,
):
"""
Initialize the embedding distance metric.
Args:
kind (Literal["cosine", "euclidean"]): The type of distance metric.
device (torch.device | str): The device to use for computation. If None, automatically detects.
kind (ValidDistanceKinds): The type of distance metric.
device (Union[torch.device, str]): The device to use for computation.
If None, automatically detects.
dtype (torch.dtype): The data type to use for tensors.
If None, uses torch.get_default_dtype()
"""
super().__init__(f"EmbeddingDistanceMetric {kind}", higher_is_better=False)
self.kind = kind
@@ -36,32 +49,77 @@ def __init__(
else:
self.device = torch.device(device) if isinstance(device, str) else device

if dtype is None:
dtype = torch.get_default_dtype()

# Dispatch table for metric computations
self._metric_dispatch = {
"cosine": self.cosine_distances,
"euclidean": self.euclidean_distances,
"dot": self.dot_product,
"manhattan": self.manhattan_distances,
}

self.dtype = dtype
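# --- Editor's illustrative sketch (not part of this commit): constructing the metric.
# `kind` selects one of the bound methods registered in `_metric_dispatch` above.
# The import path follows the file path shown in this diff.
from pose_evaluation.metrics.embedding_distance_metric import EmbeddingDistanceMetric

cosine_metric = EmbeddingDistanceMetric(kind="cosine")  # default; uses cosine_distances
l2_metric = EmbeddingDistanceMetric(kind="euclidean")   # positive L2 distances
l1_metric = EmbeddingDistanceMetric(kind="manhattan")   # positive L1 distances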

def _to_device_tensor(self, data: list | np.ndarray | Tensor, dtype=None) -> Tensor:
def set_device(self, device: Union[torch.device, str]) -> None:
"""
Explicitly set the device used for tensors.
Args:
device (Union[torch.device, str]): The device to use for computation.
"""
self.device = torch.device(device)
logger.info(f"Device set to: {self.device}")

def _to_tensor_on_device(self, data: TensorConvertableType, dtype=None) -> Tensor:
"""
Convert input data to a tensor on the specified device.
Args:
data (TensorConvertableType): The input data to convert.
dtype (torch.dtype): The data type for the tensor.
Returns:
Tensor: Tensor representation of the data on the specified device.
"""
if dtype is None:
dtype = self.dtype
return st_util._convert_to_tensor(data).to(device=self.device, dtype=dtype)

def _to_batch_tensor_on_device(self, data: list | np.ndarray | Tensor, dtype=None) -> Tensor:
def _to_batch_tensor_on_device(self, data: TensorConvertableType, dtype=None) -> Tensor:
"""
Convert input data to a batch tensor on the specified device.
Args:
data (TensorConvertableType): The input data to convert.
dtype (torch.dtype): The data type for the tensor.
Returns:
Tensor: Batch tensor representation of the data on the specified device.
"""
if dtype is None:
dtype = self.dtype
return st_util._convert_to_batch_tensor(data).to(device=self.device, dtype=dtype)
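# --- Editor's illustrative sketch (not part of this commit): the batch helper is
# expected to lift a single 1D embedding to shape (1, d) on the configured device/dtype.
from pose_evaluation.metrics.embedding_distance_metric import EmbeddingDistanceMetric

metric = EmbeddingDistanceMetric()
batch = metric._to_batch_tensor_on_device([0.1, 0.2, 0.3])
print(batch.shape, batch.dtype)  # expected: torch.Size([1, 3]) torch.float32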

def score(
self,
hypothesis: list | np.ndarray | Tensor,
reference: list | np.ndarray | Tensor,
) -> float:
hypothesis: TensorConvertableType,
reference: TensorConvertableType,
) -> Number:
"""
Compute the distance between two embeddings.
Args:
hypothesis (list| np.ndarray | Tensor): A single embedding vector.
reference (list| np.ndarray | Tensor): Another single embedding vector.
hypothesis (TensorConvertableType): A single embedding vector.
reference (TensorConvertableType): Another single embedding vector.
Returns:
float: The calculated distance.
Number: The calculated distance.
Raises:
ValueError: If either input is None.
TypeError: If inputs cannot be converted to tensors.
"""
if hypothesis is None or reference is None:
raise ValueError("Neither 'hypothesis' nor 'reference' can be None.")
@@ -75,95 +133,89 @@ def score(

def score_all(
self,
hypotheses: List[list | np.ndarray | Tensor],
references: List[list | np.ndarray | Tensor],
hypotheses: Union[List[TensorConvertableType], Tensor],
references: Union[List[TensorConvertableType], Tensor],
progress_bar: bool = True,
) -> Tensor:
"""
Compute the pairwise distance between all hypotheses and references.
Expects 2D inputs, where each element in the second dimension is one embedding
Compute the distance between all hypotheses and all references.
Expects 2D inputs. If not already Tensors, will attempt to convert them.
Args:
hypotheses (list[list| np.ndarray | Tensor]): List of hypothesis embeddings.
references (list[list| np.ndarray | Tensor]): List of reference embeddings.
progress_bar (bool): Whether to display a progress bar.
hypotheses (Union[List[TensorConvertableType], Tensor]):
List of hypothesis embeddings or a single tensor.
references (Union[List[TensorConvertableType], Tensor]):
List of reference embeddings or a single tensor.
progress_bar (bool): Whether to display a progress bar. (not implemented yet)
Returns:
Tensor, distance matrix. Row i is the distances of hypotheses[i] to all rows of references
Tensor: Distance matrix. Row `i` is the distances of `hypotheses[i]` to all rows of `references`.
Shape is NxM, where N is the number of hypotheses and M is the number of references.
Raises:
ValueError: If the specified metric is unsupported.
"""
# Convert inputs to tensors and stack
hypotheses = torch.stack([self._to_device_tensor(h) for h in hypotheses])
references = torch.stack([self._to_device_tensor(r) for r in references])

if self.kind == "dot":
distance_matrix = self.dot_product(hypotheses, references)

elif self.kind == "cosine":
distance_matrix = self.cosine_distances(hypotheses, references)

elif self.kind == "euclidean":
distance_matrix = self.euclidean_distances(hypotheses, references)
hypotheses = torch.stack([self._to_tensor_on_device(h) for h in hypotheses])
references = torch.stack([self._to_tensor_on_device(r) for r in references])

elif self.kind == "manhattan":
distance_matrix = self.manhattan_distances(hypotheses, references)

else:
if self.kind not in self._metric_dispatch:
logger.error(f"Unsupported distance metric: {self.kind}")
raise ValueError(f"Unsupported distance metric: {self.kind}")

distance_matrix = self._metric_dispatch[self.kind](hypotheses, references)
return distance_matrix
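# --- Editor's illustrative sketch (not part of this commit): hypothetical shapes.
# With 3 hypotheses and 5 references of dimension 768, `score_all` should return a
# 3x5 distance matrix; `score` should return a single Number for one pair.
import numpy as np
from pose_evaluation.metrics.embedding_distance_metric import EmbeddingDistanceMetric

hyps = np.random.rand(3, 768)  # 3 hypothesis embeddings
refs = np.random.rand(5, 768)  # 5 reference embeddings

metric = EmbeddingDistanceMetric(kind="cosine")
matrix = metric.score_all(hyps, refs)    # Tensor of shape (3, 5)
single = metric.score(hyps[0], refs[0])  # distance between one hypothesis/reference pair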

def dot_product(self, hypotheses: list | np.ndarray | Tensor, references: list | np.ndarray | Tensor) -> Tensor:
def dot_product(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor:
"""
Compute the dot product between embeddings.
Uses sentence_transformers.util.dot_score
"""
# TODO: test if this gives the same thing as previous matmul implementation, see stack overflow link below:
# https://stackoverflow.com/questions/73924697/whats-the-difference-between-torch-mm-torch-matmul-and-torch-mul
return st_util.dot_score(hypotheses, references)
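# --- Editor's illustrative sketch (not part of this commit): one way to check the
# TODO above. For 2D inputs, util.dot_score(a, b) should match the matmul a @ b.T.
import torch
from sentence_transformers import util as st_util

a = torch.randn(3, 8)
b = torch.randn(5, 8)
print(torch.allclose(st_util.dot_score(a, b), a @ b.T))  # expected: True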

def euclidean_similarities(
self, hypotheses: list | np.ndarray | Tensor, references: list | np.ndarray | Tensor
) -> Tensor:
def euclidean_similarities(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor:
"""
Returns the negative L2 (Euclidean) distances, which is what sentence-transformers uses as similarities.
Uses sentence_transformers.util.euclidean_sim
"""
return st_util.euclidean_sim(hypotheses, references)

def euclidean_distances(
self, hypotheses: list | np.ndarray | Tensor, references: list | np.ndarray | Tensor
) -> Tensor:
def euclidean_distances(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor:
"""
sentence-transformers negates Euclidean distances to report "similarities",
so re-negating recovers the positive distances.
Uses euclidean_similarities, which wraps sentence_transformers.util.euclidean_sim
"""
return -self.euclidean_similarities(hypotheses, references)
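# --- Editor's illustrative sketch (not part of this commit): sanity-checking the sign
# convention. util.euclidean_sim returns negated L2 distances, so negating it again
# should agree with torch.cdist(..., p=2).
import torch
from sentence_transformers import util as st_util

a = torch.randn(4, 16)
b = torch.randn(6, 16)
print(torch.allclose(-st_util.euclidean_sim(a, b), torch.cdist(a, b, p=2)))  # expected: True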

def cosine_similarities(
self, hypotheses: list | np.ndarray | Tensor, references: list | np.ndarray | Tensor
) -> Tensor:
def cosine_similarities(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor:
"""
Calculates cosine similarities, i.e. the cosine of the angle between two embeddings.
The minimum value is -1 (pointing in opposite directions) and the maximum is 1 (pointing in the same direction).
Uses sentence_transformers.util.cos_sim
"""
return st_util.cos_sim(hypotheses, references)

def cosine_distances(
self, hypotheses: list | np.ndarray | Tensor, references: list | np.ndarray | Tensor
) -> Tensor:
def cosine_distances(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor:
"""
Converts cosine similarities to distances by simply subtracting from 1.
Max distance is 2, min distance is 0.
"""
return 1 - self.cosine_similarities(hypotheses, references)
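# --- Editor's illustrative sketch (not part of this commit): the cosine distance above
# is 1 minus cosine similarity, so identical directions give ~0, orthogonal vectors ~1,
# and opposite directions ~2.
import torch
from pose_evaluation.metrics.embedding_distance_metric import EmbeddingDistanceMetric

v = torch.tensor([[1.0, 0.0]])
w = torch.tensor([[0.0, 1.0]])

metric = EmbeddingDistanceMetric(kind="cosine")
print(metric.cosine_distances(v, v))   # ~0.0
print(metric.cosine_distances(v, w))   # ~1.0
print(metric.cosine_distances(v, -v))  # ~2.0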

def manhattan_similarities(
self, hypotheses: list | np.ndarray | Tensor, references: list | np.ndarray | Tensor
) -> Tensor:
def manhattan_similarities(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor:
"""
Get the L1/Manhattan similarities, aka negative distances.
Uses sentence_transformers.util.manhattan_sim
"""
return st_util.manhattan_sim(hypotheses, references)

def manhattan_distances(
self, hypotheses: list | np.ndarray | Tensor, references: list | np.ndarray | Tensor
) -> Tensor:
def manhattan_distances(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor:
"""
Convert Manhattan similarities to distances.
Sentence transformers defines similarity as negative distances.
We can re-negate to recover the distances.
"""
