Skip to content

Commit

Permalink
Skip object detection tests if pycocotools is not installed. (#257)
Browse files Browse the repository at this point in the history
  • Loading branch information
mzweilin authored May 16, 2024
1 parent 3689976 commit d349977
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
7 changes: 6 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, open_dict

from mart.utils.imports import _HAS_TIMM, _HAS_TORCHVISION
from mart.utils.imports import _HAS_PYCOCOTOOLS, _HAS_TIMM, _HAS_TORCHVISION

root = Path(os.getcwd())
pyrootutils.set_root(path=root, dotenv=True, pythonpath=True)

experiments_require_torchvision = [
"CIFAR10_CNN",
"CIFAR10_CNN_Adv",
]

experiments_require_torchvision_pycocotools = [
"COCO_TorchvisionFasterRCNN",
"COCO_TorchvisionFasterRCNN_Adv",
"COCO_TorchvisionRetinaNet",
Expand All @@ -35,6 +38,8 @@
experiments_names = []
if _HAS_TORCHVISION:
experiments_names += experiments_require_torchvision
if _HAS_TORCHVISION and _HAS_PYCOCOTOOLS:
experiments_names += experiments_require_torchvision_pycocotools
if _HAS_TIMM and _HAS_TORCHVISION:
experiments_names += experiments_require_torchvision_and_timm

Expand Down
17 changes: 13 additions & 4 deletions tests/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from hydra.core.global_hydra import GlobalHydra

from mart.utils.imports import _HAS_TIMM, _HAS_TORCHVISION
from mart.utils.imports import _HAS_PYCOCOTOOLS, _HAS_TIMM, _HAS_TORCHVISION
from tests.helpers.dataset_generator import FakeCOCODataset
from tests.helpers.run_if import RunIf
from tests.helpers.run_sh_command import run_sh_command
Expand Down Expand Up @@ -131,7 +131,10 @@ def test_imagenet_timm_experiment(classification_cfg, tmp_path):

@RunIf(sh=True)
@pytest.mark.slow
@pytest.mark.skipif(not _HAS_TORCHVISION, reason="test requires that torchvision is installed")
@pytest.mark.skipif(
not _HAS_TORCHVISION or not _HAS_PYCOCOTOOLS,
reason="test requires that torchvision and pycocotools are installed",
)
def test_coco_fasterrcnn_experiment(coco_cfg, tmp_path):
"""Test TorchVision FasterRCNN experiment."""
overrides = coco_cfg["trainer"] + coco_cfg["datamodel"]
Expand All @@ -147,7 +150,10 @@ def test_coco_fasterrcnn_experiment(coco_cfg, tmp_path):

@RunIf(sh=True)
@pytest.mark.slow
@pytest.mark.skipif(not _HAS_TORCHVISION, reason="test requires that torchvision is installed")
@pytest.mark.skipif(
not _HAS_TORCHVISION or not _HAS_PYCOCOTOOLS,
reason="test requires that torchvision and pycocotools are installed",
)
def test_coco_fasterrcnn_adv_experiment(coco_cfg, tmp_path):
"""Test TorchVision FasterRCNN Adv experiment."""
overrides = coco_cfg["trainer"] + coco_cfg["datamodel"]
Expand All @@ -163,7 +169,10 @@ def test_coco_fasterrcnn_adv_experiment(coco_cfg, tmp_path):

@RunIf(sh=True)
@pytest.mark.slow
@pytest.mark.skipif(not _HAS_TORCHVISION, reason="test requires that torchvision is installed")
@pytest.mark.skipif(
not _HAS_TORCHVISION or not _HAS_PYCOCOTOOLS,
reason="test requires that torchvision and pycocotools installed",
)
def test_coco_retinanet_experiment(coco_cfg, tmp_path):
"""Test TorchVision RetinaNet experiment."""
overrides = coco_cfg["trainer"] + coco_cfg["datamodel"]
Expand Down

0 comments on commit d349977

Please sign in to comment.