Add support for MPS Backend [without torch.amp.autocast] + CI (#3041)
* Add support for `MPS` Backend [without torch.amp.autocast] (#2993)

* Add support for "mps" device in ignite.distributed.base

* Changed the supervised_trainer API to accept mps devices and added some tests

* autopep8 fix

* Added lint fixes

* Set up CI for mps tests

---------

Co-authored-by: guptaaryan16 <[email protected]>
Co-authored-by: vfdev <[email protected]>

* Fixed mps-tests.yml

* More fixes to mps-tests.yml

* Set up working directory in mps yml

* another try

* another try

* another try

* added version check and skipped mps tests

* code formatting

* another code formatting fix

---------

Co-authored-by: Aryan Gupta <[email protected]>
Co-authored-by: guptaaryan16 <[email protected]>
3 people authored Nov 22, 2023
1 parent 4dc4e04 commit a5ee7ae
Showing 9 changed files with 199 additions and 9 deletions.
125 changes: 125 additions & 0 deletions .github/workflows/mps-tests.yml
@@ -0,0 +1,125 @@
name: Run unit tests on M1
on:
  push:
    branches:
      - master
      - "*.*.*"
    paths:
      - "ignite/**"
      - "tests/ignite/**"
      - "tests/run_code_style.sh"
      - "examples/**.py"
      - "requirements-dev.txt"
      - ".github/workflows/mps-tests.yml"
  pull_request:
    paths:
      - "ignite/**"
      - "tests/ignite/**"
      - "tests/run_code_style.sh"
      - "examples/**.py"
      - "requirements-dev.txt"
      - ".github/workflows/mps-tests.yml"
  workflow_dispatch:

concurrency:
  # <workflow_name>-<branch_name>-<true || commit_sha (if branch is protected)>
  group: mps-tests-${{ github.ref_name }}-${{ !(github.ref_protected) || github.sha }}
  cancel-in-progress: true

# Cherry-picked from
# - https://github.com/pytorch/vision/main/.github/workflows/tests.yml
# - https://github.com/pytorch/test-infra/blob/main/.github/workflows/macos_job.yml

jobs:
  mps-tests:
    strategy:
      matrix:
        python-version: [3.8]
        pytorch-channel: ["pytorch"]
        skip-distrib-tests: [1]
      fail-fast: false
    runs-on: ["macos-m1-12"]
    timeout-minutes: 60

    steps:
      - name: Clean workspace
        run: |
          echo "::group::Cleanup debug output"
          sudo rm -rfv "${GITHUB_WORKSPACE}"
          mkdir -p "${GITHUB_WORKSPACE}"
          echo "::endgroup::"
      - name: Checkout repository (pytorch/test-infra)
        uses: actions/checkout@v3
        with:
          # Support the use case where we need to checkout someone's fork
          repository: pytorch/test-infra
          path: test-infra

      - name: Checkout repository (${{ github.repository }})
        uses: actions/checkout@v3
        with:
          # Support the use case where we need to checkout someone's fork
          repository: ${{ github.repository }}
          ref: ${{ github.ref }}
          path: ${{ github.repository }}
          fetch-depth: 1

      - name: Setup miniconda
        uses: ./test-infra/.github/actions/setup-miniconda
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install PyTorch
        if: ${{ matrix.pytorch-channel == 'pytorch' }}
        shell: bash -l {0}
        run: pip install torch torchvision

      - name: Install PyTorch (nightly)
        if: ${{ matrix.pytorch-channel == 'pytorch-nightly' }}
        shell: bash -l {0}
        run: pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu

      - name: Install dependencies
        shell: bash -l {0}
        working-directory: ${{ github.repository }}
        run: |
          # TODO: We add set -xe to explicitly fail the CI if one of the commands is failing.
          # Somehow the step is passing even if a subcommand failed
          set -xe
          pip install -r requirements-dev.txt
          echo "1 returned code: $?"
          pip install -e .
          echo "2 returned code: $?"
          pip list
          echo "3 returned code: $?"
      # Download MNIST: https://github.com/pytorch/ignite/issues/1737
      # to "/tmp" for unit tests
      - name: Download MNIST
        uses: pytorch-ignite/download-mnist-github-action@master
        with:
          target_dir: /tmp

      # Copy MNIST to "." for the examples
      - name: Copy MNIST
        run: |
          cp -R /tmp/MNIST .
      - name: Run Tests
        shell: bash -l {0}
        working-directory: ${{ github.repository }}
        run: |
          SKIP_DISTRIB_TESTS=${{ matrix.skip-distrib-tests }} bash tests/run_cpu_tests.sh
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          file: ${{ github.repository }}/coverage.xml
          flags: mps
          fail_ci_if_error: false

      - name: Run MNIST Examples
        shell: bash -l {0}
        working-directory: ${{ github.repository }}
        run: python examples/mnist/mnist.py --epochs=1
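
As a side note (not part of the workflow file), here is a minimal sanity check one might run on the M1 runner before launching the test suite, assuming a torch build with MPS support; the tensor shapes and messages are illustrative only:

import torch

# Confirm the installed wheel was built with MPS and the backend is usable on this machine.
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    x = torch.ones(3, 3, device="mps")
    y = x @ x  # run a small op on the MPS device
    print("MPS OK:", y.device)
else:
    print("MPS backend not available; tests will fall back to CPU")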
1 change: 1 addition & 0 deletions .github/workflows/unit-tests.yml
@@ -93,6 +93,7 @@ jobs:
        run: |
          pip install -r requirements-dev.txt
          python setup.py install
          pip list
      - name: Check code formatting
        run: |
5 changes: 5 additions & 0 deletions ignite/distributed/comp_models/base.py
@@ -3,6 +3,9 @@
from typing import Any, Callable, cast, List, Optional, Union

import torch
from packaging.version import Version

_torch_version_le_112 = Version(torch.__version__) > Version("1.12.0")


class ComputationModel(metaclass=ABCMeta):
@@ -326,6 +329,8 @@ def get_node_rank(self) -> int:
    def device(self) -> torch.device:
        if torch.cuda.is_available():
            return torch.device("cuda")
        if _torch_version_le_112 and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def backend(self) -> Optional[str]:
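
With this change, the serial computation model prefers CUDA, then MPS (the `_torch_version_le_112` flag actually checks for torch newer than 1.12.0, despite its name), then CPU. A minimal sketch of the expected behaviour, assuming `idist.device()` continues to delegate to `ComputationModel.device()`:

import torch

import ignite.distributed as idist

# On an Apple Silicon machine without CUDA and with torch > 1.12,
# the default device should now resolve to "mps" rather than "cpu".
device = idist.device()
assert device.type in ("cuda", "mps", "cpu")
model = torch.nn.Linear(8, 2).to(device)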
19 changes: 15 additions & 4 deletions ignite/engine/__init__.py
@@ -96,6 +96,8 @@ def supervised_training_step(
        Added `model_transform` to transform model's output
    .. versionchanged:: 0.4.13
        Added `model_fn` to customize model's application on the sample
    .. versionchanged:: 0.4.14
        Added support for ``mps`` device
    """

    if gradient_accumulation_steps <= 0:
@@ -391,9 +393,12 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to


def _check_arg(
    on_tpu: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.cuda.amp.GradScaler"]]
    on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.cuda.amp.GradScaler"]]
) -> Tuple[Optional[str], Optional["torch.cuda.amp.GradScaler"]]:
    """Checking tpu, amp and GradScaler instance combinations."""
    """Checking tpu, mps, amp and GradScaler instance combinations."""
    if on_mps and amp_mode:
        raise ValueError("amp_mode cannot be used with mps device. Consider using amp_mode=None or device='cuda'.")

    if on_tpu and not idist.has_xla_support:
        raise RuntimeError("In order to run on TPU, please install PyTorch XLA")

@@ -546,11 +551,14 @@ def output_transform_fn(x, y, y_pred, loss):
        Added ``model_transform`` to transform model's output
    .. versionchanged:: 0.4.13
        Added `model_fn` to customize model's application on the sample
    .. versionchanged:: 0.4.14
        Added support for ``mps`` device
    """

    device_type = device.type if isinstance(device, torch.device) else device
    on_tpu = "xla" in device_type if device_type is not None else False
    mode, _scaler = _check_arg(on_tpu, amp_mode, scaler)
    on_mps = "mps" in device_type if device_type is not None else False
    mode, _scaler = _check_arg(on_tpu, on_mps, amp_mode, scaler)

    if mode == "amp":
        _update = supervised_training_step_amp(
@@ -791,10 +799,13 @@ def create_supervised_evaluator(
        Added ``model_transform`` to transform model's output
    .. versionchanged:: 0.4.13
        Added `model_fn` to customize model's application on the sample
    .. versionchanged:: 0.4.14
        Added support for ``mps`` device
    """
    device_type = device.type if isinstance(device, torch.device) else device
    on_tpu = "xla" in device_type if device_type is not None else False
    mode, _ = _check_arg(on_tpu, amp_mode, None)
    on_mps = "mps" in device_type if device_type is not None else False
    mode, _ = _check_arg(on_tpu, on_mps, amp_mode, None)

    metrics = metrics or {}
    if mode == "amp":
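
The trainer and evaluator factories now accept an `mps` device, but `_check_arg` rejects any `amp_mode` for it, matching the commit title's "without torch.amp.autocast". A hedged usage sketch (model, optimizer and loss are placeholders, not taken from the repository):

import torch
from torch import nn
from torch.optim import SGD

from ignite.engine import create_supervised_trainer

device = "mps" if torch.backends.mps.is_available() else "cpu"

model = nn.Linear(4, 2).to(device)
optimizer = SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

# amp_mode must stay None on MPS; amp_mode="amp" would raise the new ValueError.
trainer = create_supervised_trainer(model, optimizer, criterion, device=device, amp_mode=None)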
5 changes: 4 additions & 1 deletion tests/ignite/distributed/comp_models/test_base.py
@@ -1,9 +1,12 @@
import pytest
import torch

from ignite.distributed.comp_models.base import _SerialModel, ComputationModel
from ignite.distributed.comp_models.base import _SerialModel, _torch_version_le_112, ComputationModel


@pytest.mark.skipif(
    _torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
)
def test_serial_model():
    _SerialModel.create_from_backend()
    model = _SerialModel.create_from_context()
4 changes: 4 additions & 0 deletions tests/ignite/distributed/test_auto.py
@@ -12,6 +12,7 @@

import ignite.distributed as idist
from ignite.distributed.auto import auto_dataloader, auto_model, auto_optim, DistributedProxySampler
from ignite.distributed.comp_models.base import _torch_version_le_112


class DummyDS(Dataset):
@@ -179,6 +180,9 @@ def _test_auto_model_optimizer(ws, device):
    assert optimizer.backward_passes_per_step == backward_passes_per_step


@pytest.mark.skipif(
    _torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
)
def test_auto_methods_no_dist():
    _test_auto_dataloader(1, 1, batch_size=1)
    _test_auto_dataloader(1, 1, batch_size=10, num_workers=2)
4 changes: 4 additions & 0 deletions tests/ignite/distributed/test_launcher.py
@@ -8,6 +8,7 @@
from packaging.version import Version

import ignite.distributed as idist
from ignite.distributed.comp_models.base import _torch_version_le_112
from ignite.distributed.utils import has_hvd_support, has_native_dist_support, has_xla_support


@@ -257,6 +258,9 @@ def test_idist_parallel_n_procs_native(init_method, backend, get_fixed_dirname,


@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
@pytest.mark.skipif(
    _torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
)
def test_idist_parallel_no_dist():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with idist.Parallel(backend=None) as parallel:
5 changes: 5 additions & 0 deletions tests/ignite/distributed/utils/test_serial.py
@@ -1,6 +1,8 @@
import pytest
import torch

import ignite.distributed as idist
from ignite.distributed.comp_models.base import _torch_version_le_112
from tests.ignite.distributed.utils import (
    _sanity_check,
    _test_distrib__get_max_length,
@@ -13,6 +15,9 @@
)


@pytest.mark.skipif(
    _torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
)
def test_no_distrib(capsys):
    assert idist.backend() is None
    if torch.cuda.is_available():
40 changes: 36 additions & 4 deletions tests/ignite/engine/test_create_supervised.py
@@ -12,6 +12,7 @@
from torch.optim import SGD

import ignite.distributed as idist
from ignite.distributed.comp_models.base import _torch_version_le_112
from ignite.engine import (
    _check_arg,
    create_supervised_evaluator,
@@ -196,7 +197,8 @@ def _test_create_mocked_supervised_trainer(
    data = [(x, y)]

    on_tpu = "xla" in trainer_device if trainer_device is not None else False
    mode, _ = _check_arg(on_tpu, amp_mode, scaler)
    on_mps = "mps" in trainer_device if trainer_device is not None else False
    mode, _ = _check_arg(on_tpu, on_mps, amp_mode, scaler)

    if model_device == trainer_device or ((model_device == "cpu") ^ (trainer_device == "cpu")):
        trainer.run(data)
@@ -306,7 +308,9 @@ def _test_create_supervised_evaluator(
    else:
        if Version(torch.__version__) >= Version("1.7.0"):
            # This is broken in 1.6.0 but will be probably fixed with 1.7.0
            with pytest.raises(RuntimeError, match=r"Expected all tensors to be on the same device"):
            err_msg_1 = "Expected all tensors to be on the same device"
            err_msg_2 = "Placeholder storage has not been allocated on MPS device"
            with pytest.raises(RuntimeError, match=f"({err_msg_1}|{err_msg_2})"):
                evaluator.run(data)


@@ -358,7 +362,8 @@ def _test_create_evaluation_step_amp(

    device_type = evaluator_device.type if isinstance(evaluator_device, torch.device) else evaluator_device
    on_tpu = "xla" in device_type if device_type is not None else False
    mode, _ = _check_arg(on_tpu, amp_mode, None)
    on_mps = "mps" in device_type if device_type is not None else False
    mode, _ = _check_arg(on_tpu, on_mps, amp_mode, None)

    evaluate_step = supervised_evaluation_step_amp(model, evaluator_device, output_transform=output_transform_mock)

@@ -393,7 +398,8 @@ def _test_create_evaluation_step(

    device_type = evaluator_device.type if isinstance(evaluator_device, torch.device) else evaluator_device
    on_tpu = "xla" in device_type if device_type is not None else False
    mode, _ = _check_arg(on_tpu, amp_mode, None)
    on_mps = "mps" in device_type if device_type is not None else False
    mode, _ = _check_arg(on_tpu, on_mps, amp_mode, None)

    evaluate_step = supervised_evaluation_step(model, evaluator_device, output_transform=output_transform_mock)

@@ -475,6 +481,19 @@ def test_create_supervised_trainer_on_cuda():
    _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)


@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Skip if no MPS")
def test_create_supervised_trainer_on_mps():
    model_device = trainer_device = "mps"
    _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
    _test_create_supervised_trainer(
        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device
    )
    _test_create_supervised_trainer(
        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device
    )
    _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)


@pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_trainer_on_cuda_amp():
@@ -643,6 +662,19 @@ def test_create_supervised_evaluator_on_cuda_with_model_on_cpu():
    _test_mocked_supervised_evaluator(evaluator_device="cuda")


@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
def test_create_supervised_evaluator_on_mps():
    model_device = evaluator_device = "mps"
    _test_create_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
    _test_mocked_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)


@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
def test_create_supervised_evaluator_on_mps_with_model_on_cpu():
    _test_create_supervised_evaluator(evaluator_device="mps")
    _test_mocked_supervised_evaluator(evaluator_device="mps")


@pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_evaluator_on_cuda_amp():
