Skip to content

Commit

Permalink
CDL: a few pylint changes
Browse files Browse the repository at this point in the history
  • Loading branch information
cleong110 committed Dec 5, 2024
1 parent 4934c5d commit a495c67
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions pose_evaluation/metrics/test_embedding_distance_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

# Device configuration for PyTorch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(DEVICE) # so that we get arrays on the same device
torch.set_default_device(DEVICE) # so that we get arrays on the same device


# named the fixture this way to solve many pylint W0621
Expand All @@ -36,11 +36,14 @@ def fixture_embeddings() -> List[torch.Tensor]:
"""Fixture to create dummy embeddings for testing."""
return [random_tensor(768) for _ in range(5)]

def call_and_call_with_inputs_swapped(hyp:torch.Tensor, ref:torch.Tensor, scoring_function:Callable[[torch.Tensor, torch.Tensor], torch.Tensor])->Tuple[torch.Tensor, torch.Tensor]:

def call_and_call_with_inputs_swapped(
hyp: torch.Tensor, ref: torch.Tensor, scoring_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
score1 = scoring_function(hyp, ref)
score2 = scoring_function(ref, hyp)
return score1, score2


def save_and_plot_distances(distances, matrix_name, num_points, dim):
"""Helper function to save distance matrix and plot distances."""
Expand Down Expand Up @@ -211,12 +214,12 @@ def test_score_all_with_one_vs_batch(cosine_metric, distance_range_checker):

# scores = cosine_metric.score_all(hyps, refs)
scores, scores2 = call_and_call_with_inputs_swapped(hyps, refs, cosine_metric.score_all)


assert scores.shape == (len(hyps), 1)
assert scores2.shape == (1, len(hyps))
distance_range_checker(scores, min_val=0, max_val=2)


def test_score_all_with_different_sizes(cosine_metric, distance_range_checker):
"""Test score_all with different sizes for hypotheses and references."""
hyps = [np.random.rand(768) for _ in range(3)]
Expand All @@ -239,33 +242,33 @@ def test_invalid_input_mismatched_embedding_sizes(cosine_metric: EmbeddingDistan
# TODO: we should probably raise a more descriptive/helpful error/ ValueError
call_and_call_with_inputs_swapped(hyp, ref, cosine_metric.score)


def test_invalid_input_single_number(cosine_metric: EmbeddingDistanceMetric) -> None:
hyp = random_tensor(768)
for ref in range (-2, 2):
for ref in range(-2, 2):
with pytest.raises(IndexError):
# TODO: we should probably raise a more descriptive/helpful error/ ValueError
# IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
call_and_call_with_inputs_swapped(hyp, ref, cosine_metric.score)

logger.info("Invalid input successfully crashed as expected.")


def test_invalid_input_noncontainernonnumber_types(cosine_metric: EmbeddingDistanceMetric) -> None:
hyp = random_tensor(768)
invalid_inputs = [
"invalid_input",
True
]
invalid_inputs = ["invalid_input", True]
for ref in invalid_inputs:
with pytest.raises((TypeError, IndexError)):
# TypeError: new(): invalid data type 'str'
# but True gives IndexError
# TODO: better TypeError, more descriptive
call_and_call_with_inputs_swapped(hyp, ref, cosine_metric.score)


def test_invalid_input_empty_containers(cosine_metric: EmbeddingDistanceMetric) -> None:
"""Test the metric with invalid inputs."""
emb1 = random_tensor(768)
invalid_inputs = ["", list(), dict(), tuple(), set()]
invalid_inputs = ["", [], {}, tuple(), set()]

for invalid_input in invalid_inputs:
with pytest.raises((RuntimeError, TypeError, IndexError)):
Expand Down Expand Up @@ -295,14 +298,16 @@ def test_score_ndarray_input(cosine_metric):
score = cosine_metric.score(emb1, emb2)
assert isinstance(score, float), "Output should be a float."


def test_score_all_list_of_lists_of_floats(cosine_metric):
"""Does a 2D list of floats work? """
"""Does a 2D list of floats work?"""
hyps = [[np.random.rand() for _ in range(768)] for _ in range(5)]
refs = [[np.random.rand() for _ in range(768)] for _ in range(5)]
scores = cosine_metric.score_all(hyps, refs)
assert len(scores) == len(hyps), f"Output row count mismatch for torch.Tensor input. Shape:{scores.shape}"
assert len(scores[0]) == len(refs), f"Output column count mismatch for torch.Tensor input. Shape:{scores.shape}"


def test_score_all_list_of_tensor_input(cosine_metric):
"""Test score_all function with List of torch.Tensor inputs."""
hyps = [torch.rand(768) for _ in range(5)]
Expand All @@ -312,14 +317,15 @@ def test_score_all_list_of_tensor_input(cosine_metric):
assert len(scores) == len(hyps), f"Output row count mismatch for torch.Tensor input. Shape:{scores.shape}"
assert len(scores[0]) == len(refs), f"Output column count mismatch for torch.Tensor input. Shape:{scores.shape}"


def test_score_all_list_of_ndarray_input(cosine_metric):
"""Test score_all function with List of np.ndarray inputs."""
hyps = [np.random.rand(768) for _ in range(5)]
refs = [np.random.rand(768) for _ in range(5)]

scores = cosine_metric.score_all(hyps, refs)
assert len(scores) == len(hyps), f"Output row count mismatch for torch.Tensor input. Shape:{scores.shape}"
assert len(scores[0]) == len(refs), f"Output column count mismatch for torch.Tensor input. Shape:{scores.shape}"
assert len(scores[0]) == len(refs), f"Output column count mismatch for torch.Tensor input. Shape:{scores.shape}"


def test_device_handling(cosine_metric):
Expand Down

0 comments on commit a495c67

Please sign in to comment.