Skip to content

Commit

Permalink
Various pylint changes
Browse files Browse the repository at this point in the history
  • Loading branch information
cleong110 committed Dec 4, 2024
1 parent 12f612c commit 3ca874e
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 25 deletions.
3 changes: 2 additions & 1 deletion pose_evaluation/metrics/base_embedding_metric.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TypeVar
from pose_evaluation.metrics.base import BaseMetric
import torch
from pose_evaluation.metrics.base import BaseMetric


# Define a type alias for embeddings (e.g., torch.Tensor)
Embedding = TypeVar("Embedding", bound=torch.Tensor)
Expand Down
25 changes: 12 additions & 13 deletions pose_evaluation/metrics/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# conftest.py
import pytest
import shutil
from pathlib import Path
from typing import Callable, Union
import torch
import numpy as np
import pytest

@pytest.fixture(scope="session", autouse=True)

@pytest.fixture(scope="session", autouse=True)
def clean_test_artifacts():
"""Fixture to clean up test artifacts before each test session."""
test_artifacts_dir = Path(__file__).parent / "tests" # Using Path
Expand All @@ -17,19 +17,18 @@ def clean_test_artifacts():
# (Optional) You can add cleanup logic here to run after the session if needed


# conftest.py
from typing import Callable, Union
import torch
import numpy as np

@pytest.fixture
def distance_range_checker() -> Callable[[Union[torch.Tensor, np.ndarray], float, float], None]:
@pytest.fixture(name="distance_range_checker")
def fixture_distance_range_checker() -> Callable[[Union[torch.Tensor, np.ndarray], float, float], None]:
def _check_range(distances: Union[torch.Tensor, np.ndarray], min_val: float = 0, max_val: float = 2) -> None:
max_distance = distances.max().item()
min_distance = distances.min().item()

# Use np.isclose for comparisons with tolerance
assert np.isclose(min_distance, min_val, atol=1e-6) or min_val <= min_distance <= max_val, f"Minimum distance ({min_distance}) is outside the expected range [{min_val}, {max_val}]"
assert np.isclose(max_distance, max_val, atol=1e-6) or min_val <= max_distance <= max_val, f"Maximum distance ({max_distance}) is outside the expected range [{min_val}, {max_val}]"
assert (
np.isclose(min_distance, min_val, atol=1e-6) or min_val <= min_distance <= max_val
), f"Minimum distance ({min_distance}) is outside the expected range [{min_val}, {max_val}]"
assert (
np.isclose(max_distance, max_val, atol=1e-6) or min_val <= max_distance <= max_val
), f"Maximum distance ({max_distance}) is outside the expected range [{min_val}, {max_val}]"

return _check_range
return _check_range
25 changes: 15 additions & 10 deletions pose_evaluation/metrics/test_embedding_distance_metric.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from pathlib import Path
from typing import List
import logging
import pytest
import numpy as np
import matplotlib.pyplot as plt
import torch
from pose_evaluation.metrics.embedding_distance_metric import EmbeddingDistanceMetric
from pose_evaluation.metrics.conftest import distance_range_checker
import matplotlib.pyplot as plt
import logging
from typing import List
from pathlib import Path

# no need to import. https://github.com/pylint-dev/pylint/issues/3493#issuecomment-616761997
# from pose_evaluation.metrics.conftest import distance_range_checker


# TODO: many fixes. Including the fact that we test cosine but not Euclidean,

Expand All @@ -19,14 +22,16 @@
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@pytest.fixture
def cosine_metric():
# named the fixture this way to solve many pylint W0621
# https://stackoverflow.com/questions/46089480/pytest-fixtures-redefining-name-from-outer-scope-pylint
@pytest.fixture(name="cosine_metric")
def fixture_cosine_metric():
"""Fixture to create an EmbeddingDistanceMetric instance."""
return EmbeddingDistanceMetric(kind="cosine")


@pytest.fixture
def embeddings() -> List[torch.Tensor]:
@pytest.fixture(name="embeddings")
def fixture_embeddings() -> List[torch.Tensor]:
"""Fixture to create dummy embeddings for testing."""
return [random_tensor(768) for _ in range(5)]

Expand Down Expand Up @@ -185,7 +190,7 @@ def test_score_all_against_self(
scores = cosine_metric.score_all(embeddings, embeddings)
assert scores.shape == (len(embeddings), len(embeddings)), "Output shape mismatch for score_all."
assert torch.allclose(
torch.diagonal(scores), torch.zeros(len(embeddings),dtype=scores.dtype), atol=1e-6
torch.diagonal(scores), torch.zeros(len(embeddings), dtype=scores.dtype), atol=1e-6
), "Self-comparison scores should be zero for cosine distance."
distance_range_checker(scores, min_val=0, max_val=2)
logger.info(f"Score matrix shape: {scores.shape}, Diagonal values: {torch.diagonal(scores)}")
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
# For reading .csv files, etc
"pandas",
# For segment similarity
#"sign_language_segmentation @ git+https://github.com/sign-language-processing/segmentation"
"sign_language_segmentation @ git+https://github.com/sign-language-processing/segmentation"
]

[project.optional-dependencies]
Expand All @@ -40,6 +40,7 @@ disable = [
"C0115", # Missing class docstring
"C0116", # Missing function or method docstring
"W0511", # TODO
"W1203", # use lazy % formatting in logging functions
]

[tool.setuptools]
Expand Down

0 comments on commit 3ca874e

Please sign in to comment.