Fixing CI for pytorch version tests
vfdev-5 committed Nov 10, 2023
1 parent 72487f4 commit 478c79d
Showing 3 changed files with 9 additions and 0 deletions.
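All three files receive the same guard: the distributed tests are skipped on PyTorch builds older than 1.7.0, gated with packaging's Version rather than a raw string comparison. The pattern is visible in the hunks below; as a minimal standalone sketch (the test name and body here are placeholders for illustration, not part of this commit):

import pytest
import torch
from packaging.version import Version

# The skipif condition is evaluated once, at collection time; on older
# PyTorch builds the test is reported as skipped with the given reason.
@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="Skip if < 1.7.0")
def test_needs_torch_1_7():  # placeholder test, for illustration only
    assert Version(torch.__version__) >= Version("1.7.0")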
3 changes: 3 additions & 0 deletions tests/ignite/metrics/test_accuracy.py
@@ -2,6 +2,7 @@
 
 import pytest
 import torch
+from packaging.version import Version
 from sklearn.metrics import accuracy_score
 
 import ignite.distributed as idist
@@ -550,6 +551,7 @@ def update(_, i):
 @pytest.mark.distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="Skip if < 1.7.0")
 def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
     device = idist.device()
     _test_distrib_multilabel_input_NHW(device)
@@ -561,6 +563,7 @@ def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
 
 @pytest.mark.distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="Skip if < 1.7.0")
 def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
     device = idist.device()
     _test_distrib_multilabel_input_NHW(device)
3 changes: 3 additions & 0 deletions tests/ignite/metrics/test_classification_report.py
@@ -3,6 +3,7 @@
 
 import pytest
 import torch
+from packaging.version import Version
 
 import ignite.distributed as idist
 from ignite.engine import Engine
@@ -161,6 +162,7 @@ def update(engine, i):
 @pytest.mark.distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="Skip if < 1.7.0")
 def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
     device = idist.device()
     _test_integration_multiclass(device, True)
@@ -171,6 +173,7 @@ def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
 
 @pytest.mark.distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="Skip if < 1.7.0")
 def test_distrib_gloo_cpu_or_gpu(local_rank, distributed_context_single_node_gloo):
     device = idist.device()
     _test_integration_multiclass(device, True)
3 changes: 3 additions & 0 deletions tests/ignite/metrics/test_metric.py
@@ -1,5 +1,6 @@
 import numbers
 import os
+from packaging.version import Version
 from typing import Dict, List
 from unittest.mock import MagicMock
 
@@ -710,6 +711,7 @@ def _test_creating_on_xla_fails(device):
 @pytest.mark.distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="Skip if < 1.7.0")
 def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
     device = idist.device()
     _test_distrib_sync_all_reduce_decorator(device)
@@ -722,6 +724,7 @@ def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
 
 @pytest.mark.distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="Skip if < 1.7.0")
 def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
     device = idist.device()
     _test_distrib_sync_all_reduce_decorator(device)
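A note on the comparison helper (general background, not stated in the commit): packaging.version.Version parses the version string into release, pre-release, and local segments, so PyTorch's build suffixes compare correctly where plain string comparison does not:

from packaging.version import Version

assert Version("1.13.1+cu117") >= Version("1.7.0")  # local build tag parses fine
assert Version("1.8.0a0") < Version("1.8.0")        # pre-release sorts below the release
assert not ("1.10.0" >= "1.7.0")                    # lexicographic string comparison gets this wrong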
