Skip to content

Commit

Permalink
Reuse _IN_CI in tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
mzweilin committed May 9, 2024
1 parent 626e94a commit b367071
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, open_dict

from .test_utils import _IN_CI

root = Path(os.getcwd())
pyrootutils.set_root(path=root, dotenv=True, pythonpath=True)

Expand All @@ -29,7 +31,7 @@
"ImageNet_Timm",
]

if os.getenv("CI") == "true":
if _IN_CI:
# Test all experiments on CI
experiments_names = experiments_require_torchvision + experiments_require_torchvision_and_timm
else:
Expand Down
4 changes: 3 additions & 1 deletion tests/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from tests.helpers.run_if import RunIf
from tests.helpers.run_sh_command import run_sh_command

from .test_utils import _IN_CI

module = "mart"

coco_ds = {
Expand Down Expand Up @@ -107,7 +109,7 @@ def test_cifar10_cnn_experiment(classification_cfg, tmp_path):

@RunIf(sh=True)
@pytest.mark.slow
@pytest.mark.skipif(not _HAS_TIMM, reason="test requires that timm is installed")
@pytest.mark.skipif(not _IN_CI and not _HAS_TIMM, reason="test requires that timm is installed")
def test_imagenet_timm_experiment(classification_cfg, tmp_path):
"""Test ImageNet Timm experiment."""
overrides = classification_cfg["trainer"] + classification_cfg["datamodel"]
Expand Down
9 changes: 9 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD-3-Clause
#

import os

import pytest

from mart.utils import flatten_dict
Expand All @@ -18,3 +20,10 @@ def test_flatten_dict_key_collision():
d = {"a": 1, "b": {"c": 2, "d": 3}, "b.c": 4}
with pytest.raises(KeyError):
flatten_dict(d)


def in_ci():
return os.getenv("CI") == "true"


_IN_CI = in_ci()

0 comments on commit b367071

Please sign in to comment.