Skip to content

Commit

Permalink
Trying to figure out pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
cleong110 committed Jan 10, 2025
1 parent e9e8cc1 commit 4334fb8
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion pose_evaluation/metrics/test_distance_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def test_get_test_poses():





@pytest.mark.parametrize("metric_name", DISTANCE_KINDS_TO_CHECK)
def test_scores_are_symmetric(metric_name: ValidDistanceKinds):

Expand Down Expand Up @@ -87,4 +89,36 @@ def test_scores_equal_length(metric_name:ValidDistanceKinds):
score = metric.score(hyp, ref)
point_count = np.prod(hyp.body.confidence.shape)
assert np.isclose(score, expected_distance)
assert isinstance(score, float) # Check if the score is a float
assert isinstance(score, float) # Check if the score is a float

@pytest.mark.parametrize('kind', DISTANCE_KINDS_TO_CHECK)
def test_all_distance_metrics_and_kinds(DistanceMetricToTest, kind):
metric = DistanceMetricToTest(kind)
assert isinstance(metric, DistanceMetric)


def get_all_subclasses(base_class):
"""Recursively discover all subclasses of a given base class."""
subclasses = set(base_class.__subclasses__())
for subclass in base_class.__subclasses__():
subclasses.update(get_all_subclasses(subclass))
return subclasses

def generate_test_cases(base_class, kinds):
"""Generate tuples of (metric_class, kind) for parameterization."""
subclasses = get_all_subclasses(base_class)
return [(subclass, kind) for subclass in subclasses for kind in kinds]

# Parameterize with (metric_class, kind)
@pytest.mark.parametrize("metric_class,kind", generate_test_cases(DistanceMetric, DISTANCE_KINDS_TO_CHECK))
def test_distance_metric_calculations(metric_class, kind):
"""Test all DistanceMetric subclasses with various 'kinds'."""
metric = metric_class(kind)

# if kind == "default":
# result = metric.calculate(3, 7)
# elif kind == "weighted" and hasattr(metric, "calculate"): # Check for additional arguments
# result = metric.calculate(3, 7) # Modify if weighted args are supported
# else:
# pytest.skip(f"{metric_class} does not support kind '{kind}'")
# assert result is not None # Example assertion

0 comments on commit 4334fb8

Please sign in to comment.