From a5ee7ae79b66afe43552efef6539e56fb1c39359 Mon Sep 17 00:00:00 2001 From: vfdev Date: Wed, 22 Nov 2023 03:16:16 +0100 Subject: [PATCH] Add support for `MPS` Backend [without torch.amp.autocast ] + CI (#3041) * Add support for `MPS` Backend [without torch.amp.autocast ] (#2993) * Add support for "mps" device in ignite.distributed.base * Made changes in the supervised_trainer API to have mps devices, Added some tests * autopep8 fix * Added lint fixes * Setup ci for mps tests --------- Co-authored-by: guptaaryan16 Co-authored-by: vfdev * Fixed mps-tests.yml * More fixes to mps-tests.yml * Set up working directory in mps yml * another try * another try * another try * added version check and skipped mps tests * code formatting * another code formatting fix --------- Co-authored-by: Aryan Gupta <97878444+guptaaryan16@users.noreply.github.com> Co-authored-by: guptaaryan16 --- .github/workflows/mps-tests.yml | 125 ++++++++++++++++++ .github/workflows/unit-tests.yml | 1 + ignite/distributed/comp_models/base.py | 5 + ignite/engine/__init__.py | 19 ++- .../distributed/comp_models/test_base.py | 5 +- tests/ignite/distributed/test_auto.py | 4 + tests/ignite/distributed/test_launcher.py | 4 + tests/ignite/distributed/utils/test_serial.py | 5 + tests/ignite/engine/test_create_supervised.py | 40 +++++- 9 files changed, 199 insertions(+), 9 deletions(-) create mode 100644 .github/workflows/mps-tests.yml diff --git a/.github/workflows/mps-tests.yml b/.github/workflows/mps-tests.yml new file mode 100644 index 00000000000..8e48c6053c6 --- /dev/null +++ b/.github/workflows/mps-tests.yml @@ -0,0 +1,125 @@ +name: Run unit tests on M1 +on: + push: + branches: + - master + - "*.*.*" + paths: + - "ignite/**" + - "tests/ignite/**" + - "tests/run_code_style.sh" + - "examples/**.py" + - "requirements-dev.txt" + - ".github/workflows/mps-tests.yml" + pull_request: + paths: + - "ignite/**" + - "tests/ignite/**" + - "tests/run_code_style.sh" + - "examples/**.py" + - "requirements-dev.txt" + - ".github/workflows/mps-tests.yml" + workflow_dispatch: + +concurrency: + # -- + group: mps-tests-${{ github.ref_name }}-${{ !(github.ref_protected) || github.sha }} + cancel-in-progress: true + +# Cherry-picked from +# - https://github.com/pytorch/vision/main/.github/workflows/tests.yml +# - https://github.com/pytorch/test-infra/blob/main/.github/workflows/macos_job.yml + +jobs: + mps-tests: + strategy: + matrix: + python-version: [3.8] + pytorch-channel: ["pytorch"] + skip-distrib-tests: [1] + fail-fast: false + runs-on: ["macos-m1-12"] + timeout-minutes: 60 + + steps: + - name: Clean workspace + run: | + echo "::group::Cleanup debug output" + sudo rm -rfv "${GITHUB_WORKSPACE}" + mkdir -p "${GITHUB_WORKSPACE}" + echo "::endgroup::" + + - name: Checkout repository (pytorch/test-infra) + uses: actions/checkout@v3 + with: + # Support the use case where we need to checkout someone's fork + repository: pytorch/test-infra + path: test-infra + + - name: Checkout repository (${{ github.repository }}) + uses: actions/checkout@v3 + with: + # Support the use case where we need to checkout someone's fork + repository: ${{ github.repository }} + ref: ${{ github.ref }} + path: ${{ github.repository }} + fetch-depth: 1 + + - name: Setup miniconda + uses: ./test-infra/.github/actions/setup-miniconda + with: + python-version: ${{ matrix.python-version }} + + - name: Install PyTorch + if: ${{ matrix.pytorch-channel == 'pytorch' }} + shell: bash -l {0} + run: pip install torch torchvision + + - name: Install PyTorch (nightly) + if: ${{ 
matrix.pytorch-channel == 'pytorch-nightly' }} + shell: bash -l {0} + run: pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu + + - name: Install dependencies + shell: bash -l {0} + working-directory: ${{ github.repository }} + run: | + # TODO: We add set -xe to explicitly fail the CI if one of the commands is failing. + # Somehow the step is passing even if a subcommand failed + set -xe + pip install -r requirements-dev.txt + echo "1 returned code: $?" + pip install -e . + echo "2 returned code: $?" + pip list + echo "3 returned code: $?" + + # Download MNIST: https://github.com/pytorch/ignite/issues/1737 + # to "/tmp" for unit tests + - name: Download MNIST + uses: pytorch-ignite/download-mnist-github-action@master + with: + target_dir: /tmp + + # Copy MNIST to "." for the examples + - name: Copy MNIST + run: | + cp -R /tmp/MNIST . + + - name: Run Tests + shell: bash -l {0} + working-directory: ${{ github.repository }} + run: | + SKIP_DISTRIB_TESTS=${{ matrix.skip-distrib-tests }} bash tests/run_cpu_tests.sh + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + file: ${{ github.repository }}/coverage.xml + flags: mps + fail_ci_if_error: false + + - name: Run MNIST Examples + shell: bash -l {0} + working-directory: ${{ github.repository }} + run: python examples/mnist/mnist.py --epochs=1 diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 23ac6b42c9c..7673b79f792 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -93,6 +93,7 @@ jobs: run: | pip install -r requirements-dev.txt python setup.py install + pip list - name: Check code formatting run: | diff --git a/ignite/distributed/comp_models/base.py b/ignite/distributed/comp_models/base.py index 00d4383d1ac..6e86193381c 100644 --- a/ignite/distributed/comp_models/base.py +++ b/ignite/distributed/comp_models/base.py @@ -3,6 +3,9 @@ from typing import Any, Callable, cast, List, Optional, Union import torch +from packaging.version import Version + +_torch_version_le_112 = Version(torch.__version__) > Version("1.12.0") class ComputationModel(metaclass=ABCMeta): @@ -326,6 +329,8 @@ def get_node_rank(self) -> int: def device(self) -> torch.device: if torch.cuda.is_available(): return torch.device("cuda") + if _torch_version_le_112 and torch.backends.mps.is_available(): + return torch.device("mps") return torch.device("cpu") def backend(self) -> Optional[str]: diff --git a/ignite/engine/__init__.py b/ignite/engine/__init__.py index 60d6f7690b2..a67dbe08ee1 100644 --- a/ignite/engine/__init__.py +++ b/ignite/engine/__init__.py @@ -96,6 +96,8 @@ def supervised_training_step( Added `model_transform` to transform model's output .. versionchanged:: 0.4.13 Added `model_fn` to customize model's application on the sample + .. 
versionchanged:: 0.4.14 + Added support for ``mps`` device """ if gradient_accumulation_steps <= 0: @@ -391,9 +393,12 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to def _check_arg( - on_tpu: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.cuda.amp.GradScaler"]] + on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.cuda.amp.GradScaler"]] ) -> Tuple[Optional[str], Optional["torch.cuda.amp.GradScaler"]]: - """Checking tpu, amp and GradScaler instance combinations.""" + """Checking tpu, mps, amp and GradScaler instance combinations.""" + if on_mps and amp_mode: + raise ValueError("amp_mode cannot be used with mps device. Consider using amp_mode=None or device='cuda'.") + if on_tpu and not idist.has_xla_support: raise RuntimeError("In order to run on TPU, please install PyTorch XLA") @@ -546,11 +551,14 @@ def output_transform_fn(x, y, y_pred, loss): Added ``model_transform`` to transform model's output .. versionchanged:: 0.4.13 Added `model_fn` to customize model's application on the sample + .. versionchanged:: 0.4.14 + Added support for ``mps`` device """ device_type = device.type if isinstance(device, torch.device) else device on_tpu = "xla" in device_type if device_type is not None else False - mode, _scaler = _check_arg(on_tpu, amp_mode, scaler) + on_mps = "mps" in device_type if device_type is not None else False + mode, _scaler = _check_arg(on_tpu, on_mps, amp_mode, scaler) if mode == "amp": _update = supervised_training_step_amp( @@ -791,10 +799,13 @@ def create_supervised_evaluator( Added ``model_transform`` to transform model's output .. versionchanged:: 0.4.13 Added `model_fn` to customize model's application on the sample + .. versionchanged:: 0.4.14 + Added support for ``mps`` device """ device_type = device.type if isinstance(device, torch.device) else device on_tpu = "xla" in device_type if device_type is not None else False - mode, _ = _check_arg(on_tpu, amp_mode, None) + on_mps = "mps" in device_type if device_type is not None else False + mode, _ = _check_arg(on_tpu, on_mps, amp_mode, None) metrics = metrics or {} if mode == "amp": diff --git a/tests/ignite/distributed/comp_models/test_base.py b/tests/ignite/distributed/comp_models/test_base.py index cd0244ea903..ef4fd62e293 100644 --- a/tests/ignite/distributed/comp_models/test_base.py +++ b/tests/ignite/distributed/comp_models/test_base.py @@ -1,9 +1,12 @@ import pytest import torch -from ignite.distributed.comp_models.base import _SerialModel, ComputationModel +from ignite.distributed.comp_models.base import _SerialModel, _torch_version_le_112, ComputationModel +@pytest.mark.skipif( + _torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available" +) def test_serial_model(): _SerialModel.create_from_backend() model = _SerialModel.create_from_context() diff --git a/tests/ignite/distributed/test_auto.py b/tests/ignite/distributed/test_auto.py index cde9892b8de..c571e0e45b8 100644 --- a/tests/ignite/distributed/test_auto.py +++ b/tests/ignite/distributed/test_auto.py @@ -12,6 +12,7 @@ import ignite.distributed as idist from ignite.distributed.auto import auto_dataloader, auto_model, auto_optim, DistributedProxySampler +from ignite.distributed.comp_models.base import _torch_version_le_112 class DummyDS(Dataset): @@ -179,6 +180,9 @@ def _test_auto_model_optimizer(ws, device): assert optimizer.backward_passes_per_step == backward_passes_per_step +@pytest.mark.skipif( + _torch_version_le_112 and 
torch.backends.mps.is_available(), reason="Temporary skip if MPS is available" +) def test_auto_methods_no_dist(): _test_auto_dataloader(1, 1, batch_size=1) _test_auto_dataloader(1, 1, batch_size=10, num_workers=2) diff --git a/tests/ignite/distributed/test_launcher.py b/tests/ignite/distributed/test_launcher.py index 04e1e20b7c0..1a1bd801e1c 100644 --- a/tests/ignite/distributed/test_launcher.py +++ b/tests/ignite/distributed/test_launcher.py @@ -8,6 +8,7 @@ from packaging.version import Version import ignite.distributed as idist +from ignite.distributed.comp_models.base import _torch_version_le_112 from ignite.distributed.utils import has_hvd_support, has_native_dist_support, has_xla_support @@ -257,6 +258,9 @@ def test_idist_parallel_n_procs_native(init_method, backend, get_fixed_dirname, @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") +@pytest.mark.skipif( + _torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available" +) def test_idist_parallel_no_dist(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") with idist.Parallel(backend=None) as parallel: diff --git a/tests/ignite/distributed/utils/test_serial.py b/tests/ignite/distributed/utils/test_serial.py index 1fee2bb8ce1..f1b650f56ef 100644 --- a/tests/ignite/distributed/utils/test_serial.py +++ b/tests/ignite/distributed/utils/test_serial.py @@ -1,6 +1,8 @@ +import pytest import torch import ignite.distributed as idist +from ignite.distributed.comp_models.base import _torch_version_le_112 from tests.ignite.distributed.utils import ( _sanity_check, _test_distrib__get_max_length, @@ -13,6 +15,9 @@ ) +@pytest.mark.skipif( + _torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available" +) def test_no_distrib(capsys): assert idist.backend() is None if torch.cuda.is_available(): diff --git a/tests/ignite/engine/test_create_supervised.py b/tests/ignite/engine/test_create_supervised.py index 8d001a8d2cc..48a703d60a1 100644 --- a/tests/ignite/engine/test_create_supervised.py +++ b/tests/ignite/engine/test_create_supervised.py @@ -12,6 +12,7 @@ from torch.optim import SGD import ignite.distributed as idist +from ignite.distributed.comp_models.base import _torch_version_le_112 from ignite.engine import ( _check_arg, create_supervised_evaluator, @@ -196,7 +197,8 @@ def _test_create_mocked_supervised_trainer( data = [(x, y)] on_tpu = "xla" in trainer_device if trainer_device is not None else False - mode, _ = _check_arg(on_tpu, amp_mode, scaler) + on_mps = "mps" in trainer_device if trainer_device is not None else False + mode, _ = _check_arg(on_tpu, on_mps, amp_mode, scaler) if model_device == trainer_device or ((model_device == "cpu") ^ (trainer_device == "cpu")): trainer.run(data) @@ -306,7 +308,9 @@ def _test_create_supervised_evaluator( else: if Version(torch.__version__) >= Version("1.7.0"): # This is broken in 1.6.0 but will be probably fixed with 1.7.0 - with pytest.raises(RuntimeError, match=r"Expected all tensors to be on the same device"): + err_msg_1 = "Expected all tensors to be on the same device" + err_msg_2 = "Placeholder storage has not been allocated on MPS device" + with pytest.raises(RuntimeError, match=f"({err_msg_1}|{err_msg_2})"): evaluator.run(data) @@ -358,7 +362,8 @@ def _test_create_evaluation_step_amp( device_type = evaluator_device.type if isinstance(evaluator_device, torch.device) else evaluator_device on_tpu = "xla" in device_type if device_type is not None else False - 
mode, _ = _check_arg(on_tpu, amp_mode, None) + on_mps = "mps" in device_type if device_type is not None else False + mode, _ = _check_arg(on_tpu, on_mps, amp_mode, None) evaluate_step = supervised_evaluation_step_amp(model, evaluator_device, output_transform=output_transform_mock) @@ -393,7 +398,8 @@ def _test_create_evaluation_step( device_type = evaluator_device.type if isinstance(evaluator_device, torch.device) else evaluator_device on_tpu = "xla" in device_type if device_type is not None else False - mode, _ = _check_arg(on_tpu, amp_mode, None) + on_mps = "mps" in device_type if device_type is not None else False + mode, _ = _check_arg(on_tpu, on_mps, amp_mode, None) evaluate_step = supervised_evaluation_step(model, evaluator_device, output_transform=output_transform_mock) @@ -475,6 +481,19 @@ def test_create_supervised_trainer_on_cuda(): _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device) +@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Skip if no MPS") +def test_create_supervised_trainer_on_mps(): + model_device = trainer_device = "mps" + _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device) + _test_create_supervised_trainer( + gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device + ) + _test_create_supervised_trainer( + gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device + ) + _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device) + + @pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU") def test_create_supervised_trainer_on_cuda_amp(): @@ -643,6 +662,19 @@ def test_create_supervised_evaluator_on_cuda_with_model_on_cpu(): _test_mocked_supervised_evaluator(evaluator_device="cuda") +@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS") +def test_create_supervised_evaluator_on_mps(): + model_device = evaluator_device = "mps" + _test_create_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device) + _test_mocked_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device) + + +@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS") +def test_create_supervised_evaluator_on_mps_with_model_on_cpu(): + _test_create_supervised_evaluator(evaluator_device="mps") + _test_mocked_supervised_evaluator(evaluator_device="mps") + + @pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU") def test_create_supervised_evaluator_on_cuda_amp():
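
Usage sketch (minimal, assuming PyTorch > 1.12 on Apple silicon and an ignite build that includes this patch): the snippet below exercises the new MPS path end to end. The toy model, data, and hyper-parameters are illustrative only, and amp_mode is left at None because the patched _check_arg rejects any amp_mode combined with an "mps" device.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import ignite.distributed as idist
from ignite.engine import create_supervised_trainer

# With this patch, idist.device() resolves to "mps" when torch > 1.12 and
# torch.backends.mps.is_available(); otherwise it falls back to "cuda" or "cpu".
device = idist.device()

model = nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# amp_mode must stay None on MPS: _check_arg raises a ValueError when amp_mode
# is combined with an "mps" device, since torch.amp.autocast is not supported here.
trainer = create_supervised_trainer(model, optimizer, criterion, device=device, amp_mode=None)

# Illustrative random data; batches are moved to `device` by the default prepare_batch.
data = DataLoader(TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))), batch_size=8)
trainer.run(data, max_epochs=1)

A create_supervised_evaluator built the same way follows the same rule: it runs on "mps" but rejects any amp_mode, per the updated _check_arg.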