diff --git a/pose_evaluation/evaluation/evaluate_signclip.py b/pose_evaluation/evaluation/evaluate_signclip.py
new file mode 100644
index 0000000..d293c05
--- /dev/null
+++ b/pose_evaluation/evaluation/evaluate_signclip.py
@@ -0,0 +1,318 @@
+import argparse
+from pathlib import Path
+import time
+import json
+import random
+import pandas as pd
+import numpy as np
+import torch
+from tqdm import tqdm
+from pose_evaluation.metrics.embedding_distance_metric import EmbeddingDistanceMetric
+
+
+def load_embedding(file_path: Path) -> np.ndarray:
+    """
+    Load a SignCLIP embedding from a .npy file, ensuring it has the correct shape.
+
+    Args:
+        file_path (Path): Path to the .npy file.
+
+    Returns:
+        np.ndarray: The embedding with shape (768,).
+    """
+    embedding = np.load(file_path)
+    if embedding.ndim == 2 and embedding.shape[0] == 1:
+        embedding = embedding[0]  # Reduce shape from (1, 768) to (768,)
+    return embedding
+
+
+def match_embeddings_to_glosses(emb_dir: Path, split_df: pd.DataFrame) -> pd.DataFrame:
+    """
+    Match .npy embeddings to the corresponding glosses based on the numerical ID.
+
+    Args:
+        emb_dir (Path): Directory containing the .npy files.
+        split_df (pd.DataFrame): DataFrame containing the split file with the "Video file" column.
+
+    Returns:
+        pd.DataFrame: Updated DataFrame with an additional column for embeddings.
+    """
+
+    # Step 1: Create a mapping of numerical IDs to .npy files
+    map_start = time.perf_counter()
+    embeddings_map = {npy_file.stem.split("-")[0]: npy_file for npy_file in emb_dir.glob("*.npy")}
+    map_end = time.perf_counter()
+    print(f"Creating embeddings map took {map_end - map_start:.4f} seconds")
+
+    # Step 2: Match each row's "Video file" to an embedding by numerical ID
+    match_start = time.perf_counter()
+
+    def get_embedding(video_file):
+        numerical_id = video_file.split("-")[0]
+        npy_file = embeddings_map.get(numerical_id)
+        if npy_file is not None:
+            return load_embedding(npy_file)
+        return None
+
+    split_df["embedding"] = split_df["Video file"].apply(get_embedding)
+    match_end = time.perf_counter()
+    print(f"Matching embeddings to glosses took {match_end - match_start:.4f} seconds")
+
+    return split_df
+
+
+def calculate_mean_distances(
+    distance_matrix: torch.Tensor, indices_a: torch.Tensor, indices_b: torch.Tensor, exclude_self: bool = False
+) -> float:
+    """
+    Calculate the mean of distances between two sets of indices in a 2D distance matrix.
+
+    Args:
+        distance_matrix (torch.Tensor): A 2D tensor representing pairwise distances.
+        indices_a (torch.Tensor): A tensor of row indices.
+        indices_b (torch.Tensor): A tensor of column indices.
+        exclude_self (bool): Whether to exclude distances where indices_a == indices_b.
+
+    Returns:
+        float: The mean distance between all pairs of (indices_a, indices_b).
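+
+    Example (illustrative sketch, not from the original pipeline; values
+    follow directly from the definition):
+        >>> dm = torch.tensor([[0.0, 1.0, 2.0], [1.0, 0.0, 3.0], [2.0, 3.0, 0.0]])
+        >>> calculate_mean_distances(dm, torch.tensor([0, 1]), torch.tensor([1, 2]))
+        1.5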
+    """
+    # Create all pair combinations
+    row_indices, col_indices = torch.meshgrid(indices_a, indices_b, indexing="ij")
+
+    if exclude_self:
+        # Apply a mask to remove self-distances
+        mask = row_indices != col_indices
+        row_indices = row_indices[mask]
+        col_indices = col_indices[mask]
+
+    # Gather distances
+    selected_distances = distance_matrix[row_indices.flatten(), col_indices.flatten()]
+
+    # Return the mean
+    return selected_distances.mean().item()
+
+
+def generate_synthetic_data(num_items, num_classes, num_items_per_class=4):
+    """
+    Build a synthetic distance matrix with known in-class and out-of-class means,
+    for sanity-checking calculate_class_means below.
+    """
+    torch.manual_seed(42)
+    random.seed(42)
+    distance_matrix = torch.full((num_items, num_items), 10.0)
+    distance_matrix.fill_diagonal_(0)
+    shuffled_indices = list(range(num_items))
+    random.shuffle(shuffled_indices)
+
+    classes = {
+        f"CLASS_{i}": torch.tensor([shuffled_indices.pop() for _ in range(num_items_per_class)])
+        for i in range(num_classes)
+    }
+    # Assign intra-class distances: class i gets the mean value i + 1
+    mean_values_by_class = {class_name: i + 1 for i, class_name in enumerate(classes)}
+    for class_name, class_indices in classes.items():
+        mean_value = mean_values_by_class[class_name]
+        for i in class_indices:
+            for j in class_indices:
+                if i != j:  # Exclude self-distances
+                    distance_matrix[i, j] = mean_value
+    return classes, distance_matrix
+
+
+def calculate_class_means(gloss_indices, scores):
+    class_means_by_gloss = {}
+    all_indices = torch.arange(scores.size(0), dtype=int)
+
+    for gloss, indices in tqdm(gloss_indices.items(), desc="Finding mean values by gloss"):
+        indices = torch.LongTensor(indices)
+        class_means_by_gloss[gloss] = {}
+        within_class_mean = calculate_mean_distances(scores, indices, indices, exclude_self=True)
+
+        class_means_by_gloss[gloss]["in_class"] = within_class_mean
+
+        complement_indices = all_indices[~torch.isin(all_indices, indices)]
+        without_class_mean = calculate_mean_distances(scores, indices, complement_indices)
+        class_means_by_gloss[gloss]["out_of_class"] = without_class_mean
+
+    return class_means_by_gloss
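+
+
+# Quick illustrative self-check of the two helpers above (a sketch, not
+# wired into the script): with the synthetic matrix, each class's in-class
+# mean equals its assigned value, and its out-of-class mean stays 10.0:
+#
+#   classes, dm = generate_synthetic_data(num_items=12, num_classes=3)
+#   means = calculate_class_means(classes, dm)
+#   assert means["CLASS_0"]["in_class"] == 1.0
+#   assert means["CLASS_0"]["out_of_class"] == 10.0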
+
+
+def evaluate_signclip(emb_dir: Path, split_file: Path, out_path: Path, kind: str = "cosine"):
+    """
+    Evaluate SignCLIP embeddings using score_all.
+
+    Args:
+        emb_dir (Path): Directory containing .npy embeddings.
+        split_file (Path): Path to the split CSV file.
+        out_path (Path): Where to save the output .npz archive of scores and file names.
+        kind (str): Metric type ("cosine" or "euclidean"). Default is "cosine".
+    """
+    overall_start = time.perf_counter()  # Start overall benchmarking
+
+    # Step 1: Load split file
+    split_load_start = time.perf_counter()
+    split_df = pd.read_csv(split_file)
+    split_load_end = time.perf_counter()
+    print(f"Loading split file took {split_load_end - split_load_start:.4f} seconds")
+
+    # Step 2: Match embeddings to glosses
+    match_start = time.perf_counter()
+    split_df = match_embeddings_to_glosses(emb_dir, split_df)
+    match_end = time.perf_counter()
+    print(f"Embedding-matching step took {match_end - match_start:.4f} seconds in total")
+
+    # Step 3: Filter out rows without embeddings
+    filter_start = time.perf_counter()
+    items_with_embeddings_df = split_df.dropna(subset=["embedding"]).reset_index(drop=True)
+    embeddings = items_with_embeddings_df["embedding"].tolist()
+    filter_end = time.perf_counter()
+    print(f"Filtering embeddings took {filter_end - filter_start:.4f} seconds")
+    print(items_with_embeddings_df.info())
+
+    # Step 4: Initialize the distance metric
+    metric_start = time.perf_counter()
+    # metric = EmbeddingDistanceMetric(kind=kind, device="cpu")
+    metric = EmbeddingDistanceMetric(kind=kind)
+    metric_end = time.perf_counter()
+    print(f"Initializing metric took {metric_end - metric_start:.4f} seconds")
+
+    # Step 5: Compute all pairwise scores
+    score_start = time.perf_counter()
+    print(f"Computing {kind} distances for {len(embeddings)} embeddings...")
+    scores = metric.score_all(embeddings, embeddings)
+    score_end = time.perf_counter()
+    print(f"Score_all took {score_end - score_start:.3f} seconds")
+
+    # Step 6: Extract file list from DataFrame
+    files_start = time.perf_counter()
+    files = items_with_embeddings_df["Video file"].tolist()
+    files_end = time.perf_counter()
+    print(f"Extracting file list took {files_end - files_start:.4f} seconds")
+
+    analysis_start = time.perf_counter()
+    index_to_check = 0
+    number_to_check = 10
+    print(f"The first {number_to_check} scores for {files[index_to_check]}:")
+    for ref, score in list(zip(files, scores[index_to_check]))[:number_to_check]:
+        print("\t*------------->", f"{ref}".ljust(35), "\t", score.item())
+
+    unique_glosses = items_with_embeddings_df["Gloss"].unique()
+    print(f"We have a vocabulary of {len(unique_glosses)} glosses")
+    gloss_indices = {}
+    for gloss in unique_glosses:
+        gloss_indices[gloss] = items_with_embeddings_df.index[items_with_embeddings_df["Gloss"] == gloss].tolist()
+
+    for gloss, indices in list(gloss_indices.items())[:10]:
+        print(f"Here are the {len(indices)} indices for {gloss}: {indices}")
+
+    find_class_distances_start = time.perf_counter()
+
+    # To sanity-check on known values instead, swap in the synthetic data:
+    # synthetic_classes, synthetic_distances = generate_synthetic_data(30000, 2700, 8)
+    # class_means = calculate_class_means(synthetic_classes, synthetic_distances)
+    class_means = calculate_class_means(gloss_indices, scores)
+
+    find_class_distances_end = time.perf_counter()
+
+    print(f"Finding in-class and out-of-class means took {find_class_distances_end - find_class_distances_start:.4f} seconds")
+
+    analysis_end = time.perf_counter()
+    analysis_duration = analysis_end - analysis_start
+
+    in_class_means = [mean_dict["in_class"] for mean_dict in class_means.values()]
+    out_class_means = [mean_dict["out_of_class"] for mean_dict in class_means.values()]
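+
+    # Structure of class_means at this point (values purely illustrative):
+    #   {"HOUSE": {"in_class": 0.21, "out_of_class": 0.98}, ...}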
+    for gloss, means in list(class_means.items())[:10]:
+        print(gloss, means)
+
+    print(f"Mean of in-class means: {np.mean(in_class_means)}")
+    print(f"Mean of out-of-class means: {np.mean(out_class_means)}")
+
+    print(f"Analysis took {analysis_duration} seconds")
+
+    # Step 7: Save the scores and files to a compressed file
+    save_start = time.perf_counter()
+    class_means_json = out_path.with_name(f"{out_path.stem}_class_means").with_suffix(".json")
+    with open(class_means_json, "w") as f:
+        print(f"Writing class means to {class_means_json}")
+        json.dump(class_means, f)
+    scores_array = scores.cpu().numpy()  # move off the GPU (if any) before saving with NumPy
+    np.savez(out_path, scores=scores_array, files=files)
+    save_end = time.perf_counter()
+    print(f"Saving scores and files took {save_end - save_start:.4f} seconds")
+    print(f"Scores of shape {scores.shape} with files list of length {len(files)} saved to {out_path}")
+
+    # Step 8: Read back the saved scores
+    read_start = time.perf_counter()
+    read_back_in = np.load(f"{out_path}")
+    read_end = time.perf_counter()
+    print(f"Reading back the file took {read_end - read_start:.4f} seconds")
+
+    # Step 9: Verify that the read data matches the original scores
+    verify_start = time.perf_counter()
+    if np.allclose(read_back_in["scores"], scores_array):
+        print("Verified: scores read back in match the scores in memory.")
+    else:
+        print("Mismatch found between saved and in-memory scores!")
+    verify_end = time.perf_counter()
+    print(f"Verification step took {verify_end - verify_start:.4f} seconds")
+
+    # Overall time
+    overall_end = time.perf_counter()
+    print(f"Total script runtime: {overall_end - overall_start:.4f} seconds")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Evaluate SignCLIP embeddings with score_all.")
+    parser.add_argument("emb_dir", type=Path, help="Path to the directory containing SignCLIP .npy files")
+    parser.add_argument("--split_file", type=Path, required=True, help="Path to the split CSV file (e.g., test.csv)")
+    parser.add_argument(
+        "--kind",
+        type=str,
+        choices=["cosine", "euclidean"],
+        default="cosine",
+        help="Type of distance metric to use (default: cosine)",
+    )
+
+    parser.add_argument("--out_path", type=Path, help="Where to save output distance npz matrix+file list")
+
+    args = parser.parse_args()
+
+    output_file = args.out_path
+    if output_file is None:
+        output_file = Path(f"signclip_scores_{args.split_file.name}").with_suffix(".npz")
+
+    if output_file.suffix != ".npz":
+        output_file = Path(f"{output_file}.npz")
+
+    print(f"Scores will be saved to {output_file}")
+
+    evaluate_signclip(emb_dir=args.emb_dir, split_file=args.split_file, out_path=output_file, kind=args.kind)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pose_evaluation/metrics/.gitignore b/pose_evaluation/metrics/.gitignore
new file mode 100644
index 0000000..cd78447
--- /dev/null
+++ b/pose_evaluation/metrics/.gitignore
@@ -0,0 +1 @@
+temp/
\ No newline at end of file
diff --git a/pose_evaluation/metrics/base_embedding_metric.py b/pose_evaluation/metrics/base_embedding_metric.py
new file mode 100644
index 0000000..2fb61c8
--- /dev/null
+++ b/pose_evaluation/metrics/base_embedding_metric.py
@@ -0,0 +1,9 @@
+from typing import TypeVar
+import torch
+from pose_evaluation.metrics.base import BaseMetric
+
+
+# Define a type alias for embeddings (e.g., torch.Tensor)
+Embedding = TypeVar("Embedding", bound=torch.Tensor)
+
+EmbeddingMetric = BaseMetric[Embedding]
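+
+# Hedged usage sketch (for illustration; signature assumed from the
+# subclass in this PR): a concrete metric parameterizes BaseMetric with
+# torch.Tensor embeddings, e.g.
+#
+#   class EmbeddingDistanceMetric(EmbeddingMetric):
+#       def score(self, hypothesis: torch.Tensor, reference: torch.Tensor) -> float:
+#           ...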
diff --git a/pose_evaluation/metrics/conftest.py b/pose_evaluation/metrics/conftest.py
new file mode 100644
index 0000000..c04f587
--- /dev/null
+++ b/pose_evaluation/metrics/conftest.py
@@ -0,0 +1,50 @@
+import shutil
+from pathlib import Path
+from typing import Callable, Union
+import torch
+import numpy as np
+import pytest
+
+
+@pytest.fixture(scope="session", autouse=True)
+def clean_test_artifacts():
+    """Fixture to clean up test artifacts at the start of the test session."""
+    test_artifacts_dir = Path(__file__).parent / "tests"
+    if test_artifacts_dir.exists():
+        shutil.rmtree(test_artifacts_dir)  # shutil.rmtree accepts Path objects
+    test_artifacts_dir.mkdir(parents=True, exist_ok=True)
+    yield  # This allows the test session to run
+    # (Optional) Add cleanup logic here to run after the session if needed
+
+
+@pytest.fixture(name="distance_matrix_shape_checker")
+def fixture_distance_matrix_shape_checker() -> Callable[[int, int, torch.Tensor], None]:
+    def _check_shape(hyp_count: int, ref_count: int, distance_matrix: torch.Tensor):
+        expected_shape = torch.Size([hyp_count, ref_count])
+        assert (
+            distance_matrix.shape == expected_shape
+        ), f"For M={hyp_count} hypotheses, N={ref_count} references, Distance Matrix should be MxN={expected_shape}. Instead, received {distance_matrix.shape}"
+
+    return _check_shape
+
+
+@pytest.fixture(name="distance_range_checker")
+def fixture_distance_range_checker() -> Callable[[Union[torch.Tensor, np.ndarray], float, float], None]:
+    def _check_range(
+        distances: Union[torch.Tensor, np.ndarray],
+        min_val: float = 0,
+        max_val: float = 2,
+    ) -> None:
+        max_distance = distances.max().item()
+        min_distance = distances.min().item()
+
+        # Use np.isclose for comparisons with tolerance
+        assert (
+            np.isclose(min_distance, min_val, atol=1e-6) or min_val <= min_distance <= max_val
+        ), f"Minimum distance ({min_distance}) is outside the expected range [{min_val}, {max_val}]"
+        assert (
+            np.isclose(max_distance, max_val, atol=1e-6) or min_val <= max_distance <= max_val
+        ), f"Maximum distance ({max_distance}) is outside the expected range [{min_val}, {max_val}]"
+
+    return _check_range
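+
+
+# Hedged usage sketch: pytest injects the fixtures above by name. A
+# hypothetical test (for illustration only) would look like:
+#
+#   def test_shapes(distance_matrix_shape_checker):
+#       distance_matrix_shape_checker(3, 5, torch.rand(3, 5))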
diff --git a/pose_evaluation/metrics/embedding_distance_metric.py b/pose_evaluation/metrics/embedding_distance_metric.py
new file mode 100644
index 0000000..6044faa
--- /dev/null
+++ b/pose_evaluation/metrics/embedding_distance_metric.py
@@ -0,0 +1,185 @@
+from typing import Literal, List, Optional, Union
+import logging
+
+import torch
+from torch import Tensor
+from torch.types import Number
+import numpy as np
+from sentence_transformers import util as st_util
+
+from pose_evaluation.metrics.base_embedding_metric import EmbeddingMetric
+
+
+# Useful reference: https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/util.py#L31
+# * Helper functions such as batch_to_device, _convert_to_tensor, _convert_to_batch, _convert_to_batch_tensor
+# * a whole semantic search function, with chunking and top_k
+
+# See also pgvector's C implementation: https://github.com/pgvector/pgvector/blob/master/src/vector.c
+# * cosine_distance: https://github.com/pgvector/pgvector/blob/master/src/vector.c#L658
+# * l2_distance: https://github.com/pgvector/pgvector/blob/master/src/vector.c#L566
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+ValidDistanceKinds = Literal["cosine", "euclidean", "manhattan", "dot"]
+TensorConvertableType = Union[List, np.ndarray, Tensor]
+
+
+class EmbeddingDistanceMetric(EmbeddingMetric):
+    def __init__(
+        self,
+        kind: ValidDistanceKinds = "cosine",
+        device: Optional[Union[torch.device, str]] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        """
+        Args:
+            kind (ValidDistanceKinds): The type of distance metric, e.g. "cosine" or "euclidean".
+            device (Union[torch.device, str]): The device to use for computation.
+                If None, automatically detects.
+            dtype (torch.dtype): The data type to use for tensors.
+                If None, uses torch.get_default_dtype().
+        """
+        super().__init__(f"EmbeddingDistanceMetric {kind}", higher_is_better=False)
+        self.kind = kind
+        if device is None:
+            self.device = torch.device(st_util.get_device_name())
+        else:
+            self.device = torch.device(device) if isinstance(device, str) else device
+
+        if dtype is None:
+            dtype = torch.get_default_dtype()
+
+        self.dtype = dtype
+
+        # Dispatch table for metric computations
+        self._metric_dispatch = {
+            "cosine": self.cosine_distances,
+            "euclidean": self.euclidean_distances,
+            "dot": self.dot_product,
+            "manhattan": self.manhattan_distances,
+        }
+
+    def to(self, device: Union[torch.device, str]) -> "EmbeddingDistanceMetric":
+        """
+        Explicitly set the device used for tensors. Returns self, to allow chaining.
+        """
+        self.device = torch.device(device)
+        logger.info(f"Device set to: {self.device}")
+        return self
+
+    def _to_batch_tensor_on_device(self, data: TensorConvertableType) -> Tensor:
+        """
+        Convert input data to a batch tensor on the specified device.
+
+        Returns:
+            Tensor: Batch tensor representation of the data on the specified device.
+        """
+        # Converting a list of ndarrays via np.asanyarray first gives better
+        # performance, see https://github.com/pytorch/pytorch/issues/13918
+        if isinstance(data, list) and all(isinstance(x, np.ndarray) for x in data):
+            data = np.asanyarray(data)
+
+        if isinstance(data, list) and all(isinstance(x, torch.Tensor) for x in data):
+            # prevents ValueError: only one element tensors can be converted to Python scalars
+            # https://stackoverflow.com/questions/55050717/converting-list-of-tensors-to-tensors-pytorch
+            data = torch.stack(data)
+
+        return st_util._convert_to_batch_tensor(data).to(device=self.device, dtype=self.dtype)
+
+    def score(
+        self,
+        hypothesis: TensorConvertableType,
+        reference: TensorConvertableType,
+    ) -> Number:
+        """
+        Compute the distance between two embeddings.
+
+        Returns:
+            Number: The calculated distance.
+        """
+        return self.score_all(hypothesis, reference).item()
+
+    def score_all(
+        self,
+        hypotheses: Union[List[TensorConvertableType], Tensor],
+        references: Union[List[TensorConvertableType], Tensor],
+        progress_bar: bool = True,
+    ) -> Tensor:
+        """
+        Compute the distance between all hypotheses and all references.
+
+        Expects 2D inputs. If not already Tensors, will attempt to convert them.
+
+        Args:
+            progress_bar (bool): Currently unused.
+
+        Returns:
+            Tensor: Distance matrix. Row `i` contains the distances of `hypotheses[i]` to all rows of `references`.
+                Shape is MxN, where M is the number of hypotheses and N is the number of references.
+
+        Raises:
+            TypeError: If either hypotheses or references cannot be converted to a batch tensor.
+            ValueError: If the specified metric is unsupported.
+        """
+        try:
+            hypotheses = self._to_batch_tensor_on_device(hypotheses)
+            references = self._to_batch_tensor_on_device(references)
+        except RuntimeError as e:
+            raise TypeError(f"Inputs must support conversion to device tensors: {e}") from e
+
+        assert (
+            hypotheses.ndim == 2 and references.ndim == 2
+        ), f"score_all received non-2D input: hypotheses: {hypotheses.shape}, references: {references.shape}"
+
+        return self._metric_dispatch[self.kind](hypotheses, references)
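+
+    # Hedged usage sketch (shapes are assumptions for illustration):
+    #   metric = EmbeddingDistanceMetric(kind="cosine")
+    #   distances = metric.score_all(torch.rand(3, 768), torch.rand(5, 768))
+    #   distances.shape  # torch.Size([3, 5]), cosine distances in [0, 2]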
+
+    def dot_product(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor:
+        """
+        Compute the dot product between embeddings. Note that unlike the other
+        kinds, this returns raw dot-product similarities, not distances.
+        Uses sentence_transformers.util.dot_score.
+        """
+        # https://stackoverflow.com/questions/73924697/whats-the-difference-between-torch-mm-torch-matmul-and-torch-mul
+        return st_util.dot_score(hypotheses, references)
+
+    def euclidean_similarities(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor:
+        """
+        Returns the negative L2 norm/euclidean distances, which is what sentence-transformers uses for similarities.
+        Uses sentence_transformers.util.euclidean_sim.
+        """
+        return st_util.euclidean_sim(hypotheses, references)
+
+    def euclidean_distances(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor:
+        """
+        sentence-transformers negates euclidean distances to get "similarities",
+        so we negate again to recover the positive distances.
+        Negates the output of euclidean_similarities above.
+        """
+        return -self.euclidean_similarities(hypotheses, references)
+
+    def cosine_similarities(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor:
+        """
+        Calculates cosine similarities, i.e. the cosine of the angle between two embeddings.
+        The min value is -1 (pointing directly away), and the max is 1 (pointing in the same direction).
+        Uses sentence_transformers.util.cos_sim.
+        """
+        return st_util.cos_sim(hypotheses, references)
+
+    def cosine_distances(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor:
+        """
+        Converts cosine similarities to distances by simply subtracting from 1.
+        Max distance is 2, min distance is 0.
+        """
+        return 1 - self.cosine_similarities(hypotheses, references)
+
+    def manhattan_similarities(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor:
+        """
+        Get the L1/Manhattan similarities, aka negative distances.
+        Uses sentence_transformers.util.manhattan_sim.
+        """
+        return st_util.manhattan_sim(hypotheses, references)
+
+    def manhattan_distances(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor:
+        """
+        Convert Manhattan similarities to distances.
+        Sentence transformers defines similarity as negative distances,
+        so we negate again to recover the distances.
+        """
+        return -self.manhattan_similarities(hypotheses, references)
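+
+
+# Hedged sanity check (values follow from the definitions above):
+# identical vectors have cosine distance ~0.0, opposite vectors ~2.0:
+#   metric = EmbeddingDistanceMetric("cosine")
+#   v = torch.ones(1, 4)
+#   metric.score(v, v)   # ~0.0
+#   metric.score(v, -v)  # ~2.0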
diff --git a/pose_evaluation/metrics/test_embedding_distance_metric.py b/pose_evaluation/metrics/test_embedding_distance_metric.py
new file mode 100644
index 0000000..ab275c6
--- /dev/null
+++ b/pose_evaluation/metrics/test_embedding_distance_metric.py
@@ -0,0 +1,493 @@
+import itertools
+from pathlib import Path
+from typing import List, Callable, Tuple
+import logging
+import pytest
+import numpy as np
+import matplotlib.pyplot as plt
+import torch
+from pose_evaluation.metrics.embedding_distance_metric import EmbeddingDistanceMetric
+
+
+# TODO: many fixes needed, including the fact that we test cosine but not euclidean distances.
+
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Device configuration for PyTorch
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+torch.set_default_device(DEVICE)  # so that we get arrays on the same device
+
+
+# named the fixture this way to solve many pylint W0621
+# https://stackoverflow.com/questions/46089480/pytest-fixtures-redefining-name-from-outer-scope-pylint
+@pytest.fixture(name="cosine_metric")
+def fixture_cosine_metric():
+    """Fixture to create an EmbeddingDistanceMetric instance."""
+    return EmbeddingDistanceMetric(kind="cosine")
+
+
+@pytest.fixture(name="embeddings")
+def fixture_embeddings() -> List[torch.Tensor]:
+    """Fixture to create dummy embeddings for testing."""
+    return [random_tensor(768) for _ in range(5)]
+
+
+def test_shape_checker(distance_matrix_shape_checker):
+    emb_len = 768
+    hyps = torch.rand((3, emb_len))
+    refs = torch.rand((4, emb_len))
+
+    m = hyps.shape[0]
+    n = refs.shape[0]
+
+    wrong_shapes = [1, m, n, emb_len]
+    wrong_shapes.extend(list(itertools.permutations(wrong_shapes, r=2)))
+    for wrong_shape in wrong_shapes:
+        if wrong_shape != (m, n):
+            distances_with_wrong_shape = torch.rand(wrong_shape)
+            with pytest.raises(AssertionError, match="Distance Matrix should be MxN"):
+                # This SHOULD happen. If it doesn't, the checker itself is not working.
+                distance_matrix_shape_checker(m, n, distances_with_wrong_shape)
+
+
+def call_and_call_with_inputs_swapped(
+    hyps: torch.Tensor, refs: torch.Tensor, scoring_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    score1 = scoring_function(hyps, refs)
+    score2 = scoring_function(refs, hyps)
+    return score1, score2
+
+
+def call_with_both_input_orders_and_do_standard_checks(
+    hyps: torch.Tensor,
+    refs: torch.Tensor,
+    scoring_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    distance_range_checker,
+    distance_matrix_shape_checker,
+    expected_shape: Tuple = None,
+):
+    scores, scores2 = call_and_call_with_inputs_swapped(hyps, refs, scoring_function)
+    if expected_shape is not None:
+        m, n = expected_shape
+    else:
+        m = hyps.shape[0]
+        n = refs.shape[0]
+    distance_range_checker(scores, min_val=0, max_val=2)
+    distance_range_checker(scores2, min_val=0, max_val=2)
+    distance_matrix_shape_checker(m, n, scores)
+    distance_matrix_shape_checker(n, m, scores2)
+
+    return scores, scores2
+
+
+def save_and_plot_distances(distances, matrix_name, num_points, dim):
+    """Helper function to save distance matrix and plot distances."""
+
+    distances = distances.cpu()
+    test_artifacts_dir = Path(__file__).parent / "temp"
+    test_artifacts_dir.mkdir(parents=True, exist_ok=True)  # ensure the output directory exists
+    output_path = test_artifacts_dir / f"distance_matrix_{matrix_name}_{num_points}_{dim}D.csv"
+    np.savetxt(output_path, distances.numpy(), delimiter=",", fmt="%.4f")
+    print(f"Distance matrix saved to {output_path}")
+
+    # Generate plot
+    plt.figure(figsize=(10, 6))
+    for i, row in enumerate(distances.numpy()):
+        plt.plot(row, label=f"Point {i}")
+    plt.title(f"Distance Matrix Rows ({matrix_name})")
+    plt.xlabel("Point Index")
+    plt.ylabel("Distance")
+    plt.legend()
+    plot_path = output_path.with_suffix(".png")
+    plt.savefig(plot_path)
+    print(f"Distances plot saved to {plot_path}")
+    plt.close()
+
+
+def random_tensor(size: int) -> torch.Tensor:
+    """Generate a random tensor on the appropriate device."""
+    return torch.rand(size, dtype=torch.float32, device=DEVICE)
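+
+
+# The generators below construct embeddings with known geometry so that the
+# expected cosine distances can be asserted exactly. Illustrative example
+# (values follow from the definitions, not taken from the test suite):
+#   pts = generate_unit_circle_points(4)  # angles 0, 90, 180, 270 degrees
+#   EmbeddingDistanceMetric("cosine").score(pts[0], pts[1])  # ~1.0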
+def generate_unit_circle_points(num_points: int, dim: int = 2) -> torch.Tensor:
+    """Generate points evenly spaced on the unit circle, zero-padded to `dim`."""
+    angles = torch.linspace(0, 2 * np.pi, num_points + 1)[:-1]
+    x_coords = torch.cos(angles)
+    y_coords = torch.sin(angles)
+    points = torch.stack([x_coords, y_coords], dim=1)
+    if dim > 2:
+        padding = torch.zeros((num_points, dim - 2))
+        points = torch.cat([points, padding], dim=1)
+    return points
+
+
+def generate_orthogonal_rows_with_repeats(num_rows: int, dim: int) -> torch.Tensor:
+    """Generate up to `dim` mutually orthogonal unit rows, repeated to fill `num_rows`."""
+    orthogonal_rows = torch.empty(0, dim)
+    for _ in range(min(num_rows, dim)):
+        random_vector = torch.randn(1, dim)
+        if orthogonal_rows.shape[0] > 0:
+            random_vector -= (
+                torch.matmul(random_vector, orthogonal_rows.T)
+                @ orthogonal_rows
+                / torch.norm(orthogonal_rows, dim=1, keepdim=True) ** 2
+            )
+        orthogonal_rows = torch.cat([orthogonal_rows, random_vector / torch.norm(random_vector)])
+    if num_rows > dim:
+        orthogonal_rows = orthogonal_rows.repeat(num_rows // dim + 1, 1)[:num_rows]
+    return orthogonal_rows
+
+
+def generate_orthogonal_rows_in_pairs(num_pairs: int, dim: int) -> torch.Tensor:
+    """
+    Generates a tensor with orthogonal rows in pairs.
+    The first row of each pair is orthogonal to the second row of the same pair.
+
+    Args:
+        num_pairs: The number of orthogonal pairs to generate.
+        dim: The dimensionality of the vectors.
+
+    Returns:
+        A PyTorch tensor with orthogonal rows in pairs.
+    """
+
+    orthogonal_rows = torch.empty(0, dim)
+    for _ in range(num_pairs):
+        # Generate the first vector of the pair
+        first_vector = torch.randn(1, dim)
+        first_vector = first_vector / torch.norm(first_vector)  # Normalize
+
+        # Generate the second vector orthogonal to the first
+        second_vector = torch.randn(1, dim)
+        second_vector = second_vector - (second_vector @ first_vector.T) * first_vector
+        second_vector = second_vector / torch.norm(second_vector)  # Normalize
+
+        # Concatenate the pair to the result
+        orthogonal_rows = torch.cat([orthogonal_rows, first_vector, second_vector], dim=0)
+
+    return orthogonal_rows
+
+
+def generate_ones_tensor(rows: int, dims: int) -> torch.Tensor:
+    """Generates a tensor with all elements equal to 1.0 (float)."""
+    return torch.ones(rows, dims, dtype=torch.float32)
+
+
+def generate_identity_matrix_rows(rows, cols):
+    """
+    Returns an identity matrix with the specified number of rows and columns.
+    """
+    identity = torch.eye(max(rows, cols))
+    return identity[:rows, :cols]
+
+
+def create_increasing_rows_tensor(num_rows: int, num_cols: int) -> torch.Tensor:
+    """
+    Creates a tensor where every row has identical values all the way across,
+    but increasing row by row.
+
+    Args:
+        num_rows: The number of rows in the tensor.
+        num_cols: The number of columns in the tensor.
+
+    Returns:
+        A PyTorch tensor with the specified properties.
+    """
+
+    tensor = torch.arange(1.0, num_rows + 1).unsqueeze(1).repeat(1, num_cols)
+    return tensor
+
+
+def test_score_symmetric(cosine_metric: EmbeddingDistanceMetric) -> None:
+    """Test that the metric is symmetric for cosine distance."""
+    emb1 = random_tensor(768)
+    emb2 = random_tensor(768)
+
+    score1, score2 = call_and_call_with_inputs_swapped(emb1, emb2, cosine_metric.score)
+
+    logger.info(f"Score 1: {score1}, Score 2: {score2}")
+    assert pytest.approx(score1) == score2, "Score should be symmetric."
+
+
+def test_score_with_path(cosine_metric: EmbeddingDistanceMetric, tmp_path: Path) -> None:
+    """Test that score works with embeddings loaded from file paths."""
+    emb1 = random_tensor(768).cpu().numpy()  # Save as NumPy for file storage
+    emb2 = random_tensor(768).cpu().numpy()
+
+    # Save embeddings to temporary files
+    file1 = tmp_path / "emb1.npy"
+    file2 = tmp_path / "emb2.npy"
+    np.save(file1, emb1)
+    np.save(file2, emb2)
+
+    # Load files as PyTorch tensors
+    emb1_loaded = torch.tensor(np.load(file1), dtype=torch.float32, device=DEVICE)
+    emb2_loaded = torch.tensor(np.load(file2), dtype=torch.float32, device=DEVICE)
+
+    score = cosine_metric.score(emb1_loaded, emb2_loaded)
+    expected_score = cosine_metric.score(torch.tensor(emb1, device=DEVICE), torch.tensor(emb2, device=DEVICE))
+
+    logger.info(f"Score from file: {score}, Direct score: {expected_score}")
+    assert pytest.approx(score) == expected_score, "Score with paths should match direct computation."
+
+
+def test_score_all_against_self(
+    cosine_metric: EmbeddingDistanceMetric,
+    embeddings: List[torch.Tensor],
+    distance_range_checker,
+    distance_matrix_shape_checker,
+) -> None:
+    """Test the score_all function."""
+    scores = cosine_metric.score_all(embeddings, embeddings)
+    distance_matrix_shape_checker(len(embeddings), len(embeddings), scores)
+    distance_range_checker(scores, min_val=0, max_val=2)
+
+    assert torch.allclose(
+        torch.diagonal(scores), torch.zeros(len(embeddings), dtype=scores.dtype), atol=1e-6
+    ), "Self-comparison scores should be zero for cosine distance."
+
+    logger.info(f"Score matrix shape: {scores.shape}, Diagonal values: {torch.diagonal(scores)}")
+
+
+def test_score_all_with_one_vs_batch(cosine_metric, distance_range_checker, distance_matrix_shape_checker):
+    hyps = [np.random.rand(768) for _ in range(3)]
+    refs = np.random.rand(768)
+
+    expected_shape = (len(hyps), 1)
+
+    call_with_both_input_orders_and_do_standard_checks(
+        hyps, refs, cosine_metric.score_all, distance_range_checker, distance_matrix_shape_checker, expected_shape
+    )
+
+
+def test_score_all_with_different_sizes(cosine_metric, distance_range_checker, distance_matrix_shape_checker):
+    """Test score_all with different sizes for hypotheses and references."""
+    hyps = [np.random.rand(768) for _ in range(3)]
+    refs = [np.random.rand(768) for _ in range(5)]
+
+    expected_shape = (len(hyps), len(refs))
+    call_with_both_input_orders_and_do_standard_checks(
+        hyps, refs, cosine_metric.score_all, distance_range_checker, distance_matrix_shape_checker, expected_shape
+    )
+
+
+def test_score_with_invalid_input_mismatched_embedding_sizes(cosine_metric: EmbeddingDistanceMetric) -> None:
+    hyp = random_tensor(768)
+    ref = random_tensor(769)
+
+    with pytest.raises(RuntimeError):
+        # gives RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x768 and 769x1)
+        # TODO: we should probably raise a more descriptive/helpful error, e.g. ValueError
+        call_and_call_with_inputs_swapped(hyp, ref, cosine_metric.score)
+
+
+def test_score_with_invalid_input_single_number(cosine_metric: EmbeddingDistanceMetric) -> None:
+    hyp = random_tensor(768)
+    for ref in range(-2, 2):
+        with pytest.raises(AssertionError, match="score_all received non-2D input"):
+            # TODO: we should probably raise a more descriptive/helpful error, e.g. ValueError
+            # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
+            call_and_call_with_inputs_swapped(hyp, ref, cosine_metric.score)
+
+    logger.info("Invalid input successfully crashed as expected.")
+
+
+def test_score_with_invalid_input_string(cosine_metric: EmbeddingDistanceMetric) -> None:
+    hyp = "invalid input"
+    ref = random_tensor(768)
+    with pytest.raises(TypeError, match="invalid data type 'str'"):
+        call_and_call_with_inputs_swapped(hyp, ref, cosine_metric.score)
+
+
+def test_score_with_invalid_input_bool(cosine_metric: EmbeddingDistanceMetric) -> None:
+    hyp = random_tensor(768)
+    invalid_inputs = [True, False]
+    for ref in invalid_inputs:
+        with pytest.raises(AssertionError, match="score_all received non-2D input"):
+            call_and_call_with_inputs_swapped(hyp, ref, cosine_metric.score)
+            # TODO: why does a bool make it all the way there?
+
+
+def test_score_with_invalid_input_empty_containers(cosine_metric: EmbeddingDistanceMetric) -> None:
+    """Test the metric with invalid inputs."""
+    emb1 = random_tensor(768)
+    invalid_inputs = ["", [], {}, tuple(), set()]
+
+    for invalid_input in invalid_inputs:
+        with pytest.raises((RuntimeError, TypeError, IndexError)):
+            # gives RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x768 and 0x1)
+            # "" gives TypeError: new(): invalid data type 'str'
+            # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
+            # TODO: we should probably raise a more descriptive/helpful error, e.g. ValueError
+            call_and_call_with_inputs_swapped(emb1, invalid_input, cosine_metric.score)
+
+    logger.info("Invalid input successfully crashed as expected.")
+
+
+def test_score_tensor_input(cosine_metric):
+    """Test score function with torch.Tensor inputs."""
+    emb1 = torch.rand(768)
+    emb2 = torch.rand(768)
+
+    score = cosine_metric.score(emb1, emb2)
+    assert isinstance(score, float), "Output should be a float."
+
+
+def test_score_ndarray_input(cosine_metric):
+    """Test score function with np.ndarray inputs."""
+    emb1 = np.random.rand(768)
+    emb2 = np.random.rand(768)
+
+    score = cosine_metric.score(emb1, emb2)
+    assert isinstance(score, float), "Output should be a float."
+
+
+def test_score_all_list_of_lists_of_floats(
+    cosine_metric,
+    distance_range_checker,
+    distance_matrix_shape_checker,
+):
+    """Does a 2D list of floats work?"""
+    hyps = [[np.random.rand() for _ in range(768)] for _ in range(5)]
+    refs = [[np.random.rand() for _ in range(768)] for _ in range(5)]
+    expected_shape = (len(hyps), len(refs))
+
+    call_with_both_input_orders_and_do_standard_checks(
+        hyps,
+        refs,
+        cosine_metric.score_all,
+        distance_range_checker,
+        distance_matrix_shape_checker,
+        expected_shape=expected_shape,
+    )
+
+
+def test_score_all_list_of_tensor_input(cosine_metric, distance_range_checker, distance_matrix_shape_checker):
+    """Test score_all function with List of torch.Tensor inputs."""
+    hyps = [torch.rand(768) for _ in range(5)]
+    refs = [torch.rand(768) for _ in range(5)]
+
+    expected_shape = (len(hyps), len(refs))
+
+    call_with_both_input_orders_and_do_standard_checks(
+        hyps,
+        refs,
+        cosine_metric.score_all,
+        distance_range_checker,
+        distance_matrix_shape_checker,
+        expected_shape=expected_shape,
+    )
+
+
+def test_score_all_list_of_ndarray_input(
+    cosine_metric,
+    distance_range_checker,
+    distance_matrix_shape_checker,
+):
+    """Test score_all function with List of np.ndarray inputs."""
+    hyps = [np.random.rand(768) for _ in range(5)]
+    refs = [np.random.rand(768) for _ in range(5)]
+    expected_shape = (len(hyps), len(refs))
+
+    call_with_both_input_orders_and_do_standard_checks(
+        hyps,
+        refs,
+        cosine_metric.score_all,
+        distance_range_checker,
+        distance_matrix_shape_checker,
+        expected_shape=expected_shape,
+    )
+
+
+def test_device_handling(cosine_metric):
+    """Test device handling for the metric."""
+    assert cosine_metric.device.type in ["cuda", "cpu"], "Device should be either 'cuda' or 'cpu'."
+    if torch.cuda.is_available():
+        assert cosine_metric.device.type == "cuda", "Should use 'cuda' when available."
+    else:
+        assert cosine_metric.device.type == "cpu", "Should use 'cpu' when CUDA is unavailable."
+
+
+def test_score_mixed_input_types(cosine_metric):
+    """Test score function with mixed input types."""
+    emb1 = np.random.rand(768)
+    emb2 = torch.rand(768)
+
+    all_scores = call_and_call_with_inputs_swapped(emb1, emb2, cosine_metric.score)
+    assert all(isinstance(score, float) for score in all_scores), "Each output should be a float."
+
+
+def test_score_all_mixed_input_types(cosine_metric, distance_range_checker, distance_matrix_shape_checker):
+    """Test score_all function with mixed input types."""
+    hyps = np.random.rand(5, 768)
+    refs = torch.rand(3, 768)
+
+    expected_shape = (5, 3)
+
+    call_with_both_input_orders_and_do_standard_checks(
+        hyps,
+        refs,
+        cosine_metric.score_all,
+        distance_range_checker,
+        distance_matrix_shape_checker,
+        expected_shape=expected_shape,
+    )
+
+
+@pytest.mark.parametrize("num_points, dim", [(16, 2)])
+def test_unit_circle_points(cosine_metric, num_points, dim, distance_range_checker, distance_matrix_shape_checker):
+    embeddings = generate_unit_circle_points(num_points, dim)
+    distances = cosine_metric.score_all(embeddings, embeddings)
+    save_and_plot_distances(distances=distances, matrix_name="Unit Circle", num_points=num_points, dim=dim)
+    distance_range_checker(distances, min_val=0, max_val=2)  # Check distance range
+    distance_matrix_shape_checker(embeddings.shape[0], embeddings.shape[0], distances)
+
+
+@pytest.mark.parametrize("num_points, dim", [(20, 2)])
+def test_orthogonal_rows_with_repeats_2d(cosine_metric, num_points, dim):
+    embeddings = generate_orthogonal_rows_with_repeats(num_points, dim)
+    distances = cosine_metric.score_all(embeddings, embeddings)
+    save_and_plot_distances(
+        distances=distances, matrix_name="Orthogonal Rows (with repeats)", num_points=num_points, dim=dim
+    )
+
+    # Create expected pattern directly within the test function
+    expected_pattern = torch.zeros(num_points, num_points, dtype=distances.dtype)
+    for i in range(num_points):
+        for j in range(num_points):
+            if (i + j) % 2 != 0:
+                expected_pattern[i, j] = 1
+
+    # We expect 0 1 0 alternating across and down
+    assert torch.allclose(
+        distances, expected_pattern, atol=1e-6
+    ), "Output does not match the expected alternating pattern"
+
+
+@pytest.mark.parametrize("num_points, dim", [(20, 2)])
+def test_orthogonal_rows_in_pairs(
+    cosine_metric, num_points, dim, distance_range_checker, distance_matrix_shape_checker
+):
+    embeddings = generate_orthogonal_rows_in_pairs(num_points, dim)
+    distances = cosine_metric.score_all(embeddings, embeddings)
+    save_and_plot_distances(distances, "orthogonal_rows_in_pairs", num_points, dim)
+    distance_range_checker(distances, min_val=0, max_val=2)  # Check distance range
+    distance_matrix_shape_checker(embeddings.shape[0], embeddings.shape[0], distances)
+
+
+@pytest.mark.parametrize("num_points, dim", [(10, 5)])
+def test_ones_tensor(cosine_metric, num_points, dim, distance_range_checker, distance_matrix_shape_checker):
+    embeddings = generate_ones_tensor(num_points, dim)
+    distances = cosine_metric.score_all(embeddings, embeddings)
+    save_and_plot_distances(distances, "ones_tensor", num_points, dim)
+    distance_range_checker(distances, min_val=0, max_val=0)  # Expect all distances to be 0
+    distance_matrix_shape_checker(embeddings.shape[0], embeddings.shape[0], distances)
+
+
+@pytest.mark.parametrize("num_points, dim", [(15, 15)])  # dim should equal num_points for the identity matrix
+def test_identity_matrix_rows(cosine_metric, num_points, dim, distance_range_checker, distance_matrix_shape_checker):
+    embeddings = generate_identity_matrix_rows(num_points, dim)
+    distances = cosine_metric.score_all(embeddings, embeddings)
+    save_and_plot_distances(distances, "identity_matrix_rows", num_points, dim)
+    distance_range_checker(distances, min_val=0, max_val=2)  # Check distance range
+    distance_matrix_shape_checker(embeddings.shape[0], embeddings.shape[0], distances)
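+
+
+# To run just this module (typical pytest invocation, assumed rather than
+# taken from the repo docs):
+#   pytest pose_evaluation/metrics/test_embedding_distance_metric.py -v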
diff --git a/pyproject.toml b/pyproject.toml
index b38c04e..893fa3d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,6 +11,13 @@ readme = "README.md"
 dependencies = [
     "pose-format",
     "scipy",
+    "torch",
+    "numpy",  # possibly could replace all with torch
+    "tqdm",  # progress bars, imported by the evaluation script
+    # for various vector/tensor similarities and distances in torch
+    "sentence-transformers",
+    # For reading .csv files, etc.
+    "pandas",
     # For segment similarity
     "sign_language_segmentation @ git+https://github.com/sign-language-processing/segmentation"
 ]
@@ -34,6 +40,7 @@ disable = [
     "C0115", # Missing class docstring
     "C0116", # Missing function or method docstring
     "W0511", # TODO
+    "W1203", # use lazy % formatting in logging functions
 ]
 
 [tool.setuptools]