Fix failing mps tests (#3145)
* Fix failing mps tests

* Fix failing tests

* tests

* remove unnecessary changes

* Fix flake8 errors

---------

Co-authored-by: vfdev <[email protected]>
pranavvp16 and vfdev-5 authored Nov 27, 2023
1 parent 8fb3ae2 commit 514e2f8
Showing 4 changed files with 9 additions and 18 deletions.
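The fix is the same across all four files: drop the temporary "skip if MPS is available" decorators and let each test resolve its device through idist.device(), asserting the "mps" type where appropriate. As an illustrative sketch (not part of the commit), the precedence the updated assertions encode looks like this, assuming torch >= 1.12 and a serial (non-distributed) run:

import torch

import ignite.distributed as idist

device = idist.device()
if torch.cuda.is_available():
    assert device.type == "cuda"
elif torch.backends.mps.is_available():  # torch.backends.mps exists only on torch >= 1.12
    assert device.type == "mps"
else:
    assert device.type == "cpu"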
5 changes: 2 additions & 3 deletions tests/ignite/distributed/comp_models/test_base.py
@@ -4,9 +4,6 @@
 from ignite.distributed.comp_models.base import _SerialModel, _torch_version_le_112, ComputationModel
 
 
-@pytest.mark.skipif(
-    _torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
-)
 def test_serial_model():
     _SerialModel.create_from_backend()
     model = _SerialModel.create_from_context()
@@ -19,6 +16,8 @@ def test_serial_model():
     assert model.get_node_rank() == 0
     if torch.cuda.is_available():
         assert model.device().type == "cuda"
+    elif _torch_version_le_112 and torch.backends.mps.is_available():
+        assert model.device().type == "mps"
     else:
         assert model.device().type == "cpu"
     assert model.backend() is None
8 changes: 2 additions & 6 deletions tests/ignite/distributed/test_auto.py
@@ -12,7 +12,6 @@
 
 import ignite.distributed as idist
 from ignite.distributed.auto import auto_dataloader, auto_model, auto_optim, DistributedProxySampler
-from ignite.distributed.comp_models.base import _torch_version_le_112
 
 
 class DummyDS(Dataset):
@@ -180,16 +179,13 @@ def _test_auto_model_optimizer(ws, device):
     assert optimizer.backward_passes_per_step == backward_passes_per_step
 
 
-@pytest.mark.skipif(
-    _torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
-)
 def test_auto_methods_no_dist():
     _test_auto_dataloader(1, 1, batch_size=1)
     _test_auto_dataloader(1, 1, batch_size=10, num_workers=2)
     _test_auto_dataloader(1, 1, batch_size=10, sampler_name="WeightedRandomSampler")
     _test_auto_dataloader(1, 1, batch_size=10, sampler_name="DistributedSampler")
 
-    _test_auto_model_optimizer(1, "cuda" if torch.cuda.is_available() else "cpu")
-
+    device = idist.device()
+    _test_auto_model_optimizer(1, device)
 
 @pytest.mark.distributed
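The last hunk shows the recurring substitution of this commit: the hard-coded "cuda"-or-"cpu" choice could never yield "mps", so Apple Silicon machines either exercised the CPU path or were skipped outright. A minimal before/after sketch, assuming a non-distributed context:

import torch

import ignite.distributed as idist

# Before: mps hosts silently fell back to cpu.
old_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# After: idist.device() resolves cuda, mps, or cpu for the current context.
new_device = idist.device()
print(old_device, new_device)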
6 changes: 1 addition & 5 deletions tests/ignite/distributed/test_launcher.py
@@ -8,7 +8,6 @@
 from packaging.version import Version
 
 import ignite.distributed as idist
-from ignite.distributed.comp_models.base import _torch_version_le_112
 from ignite.distributed.utils import has_hvd_support, has_native_dist_support, has_xla_support
 
 
@@ -258,11 +257,8 @@ def test_idist_parallel_n_procs_native(init_method, backend, get_fixed_dirname,
 
 
 @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
-@pytest.mark.skipif(
-    _torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
-)
 def test_idist_parallel_no_dist():
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = idist.device()
     with idist.Parallel(backend=None) as parallel:
         parallel.run(_test_func, ws=1, device=device, backend=None, true_init_method=None)
 
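For reference, idist.Parallel(backend=None) runs the given function once in the current process, so device here is simply whatever idist.device() resolves to on the host. A hypothetical standalone check (the names _print_device and local_rank are illustrative, not from the commit):

import ignite.distributed as idist

def _print_device(local_rank):
    # With backend=None there is a single process and local_rank is 0.
    print(local_rank, idist.device())

with idist.Parallel(backend=None) as parallel:
    parallel.run(_print_device)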
8 changes: 4 additions & 4 deletions tests/ignite/distributed/utils/test_serial.py
@@ -1,4 +1,3 @@
-import pytest
 import torch
 
 import ignite.distributed as idist
@@ -15,13 +14,12 @@
 )
 
 
-@pytest.mark.skipif(
-    _torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
-)
 def test_no_distrib(capsys):
     assert idist.backend() is None
     if torch.cuda.is_available():
         assert idist.device().type == "cuda"
+    elif _torch_version_le_112 and torch.backends.mps.is_available():
+        assert idist.device().type == "mps"
     else:
         assert idist.device().type == "cpu"
     assert idist.get_rank() == 0
@@ -43,6 +41,8 @@ def test_no_distrib(capsys):
     assert "ignite.distributed.utils INFO: backend: None" in out[-1]
     if torch.cuda.is_available():
         assert "ignite.distributed.utils INFO: device: cuda" in out[-1]
+    elif _torch_version_le_112 and torch.backends.mps.is_available():
+        assert "ignite.distributed.utils INFO: device: mps" in out[-1]
     else:
         assert "ignite.distributed.utils INFO: device: cpu" in out[-1]
     assert "ignite.distributed.utils INFO: rank: 0" in out[-1]
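A note on the guard kept in these branches: torch.backends.mps only exists in recent torch builds, so an unguarded torch.backends.mps.is_available() call would raise AttributeError on older versions; _torch_version_le_112 short-circuits the check first. A rough standalone stand-in (an assumption about the flag's intent, not ignite's actual definition):

from packaging.version import Version

import torch

# Assumption: the mps backend attribute appeared in torch 1.12.
_mps_attr_present = Version(torch.__version__) >= Version("1.12.0")
mps_usable = _mps_attr_present and torch.backends.mps.is_available()
print("mps available:", mps_usable)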
