From 0279a23e7cc7c0b15b01a2820cb31879e57b6909 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 19 Apr 2021 17:06:09 +0200 Subject: [PATCH] remove torch as install requirement (#14) * remove torch as install requirement * make the warning more concise --- pyproject.toml | 2 +- pytest_pytorch/plugin.py | 23 +++++++++++++++++++++-- setup.cfg | 1 - tox.ini | 1 + 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8a1cfe0..6ce930f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ order_by_type = true combine_star = true filter_files = true -known_third_party = ["pytest"] +known_third_party = ["pytest", "_pytest"] known_first_party = ["torch", "pytest_pytorch"] known_local_folder = ["tests"] diff --git a/pytest_pytorch/plugin.py b/pytest_pytorch/plugin.py index 5e665b6..39bad08 100644 --- a/pytest_pytorch/plugin.py +++ b/pytest_pytorch/plugin.py @@ -1,11 +1,27 @@ import re import unittest.mock +import warnings from typing import Pattern from _pytest.unittest import TestCaseFunction, UnitTestCase -from torch.testing._internal.common_device_type import get_device_type_test_bases -from torch.testing._internal.common_utils import TestCase as PyTorchTestCaseTemplate +try: + from torch.testing._internal.common_device_type import get_device_type_test_bases + from torch.testing._internal.common_utils import TestCase as PyTorchTestCaseTemplate + + TORCH_AVAILABLE = True +except ImportError: + TORCH_AVAILABLE = False + + warnings.warn( + "Disabling the plugin 'pytest-pytorch', because 'torch' could not be imported." + ) + + def get_device_type_test_bases(): + return [] + + class PyTorchTestCaseTemplate: + pass class PytestPyTorchInternalError(Exception): @@ -93,6 +109,9 @@ def collect(self): def pytest_pycollect_makeitem(collector, name, obj): + if not TORCH_AVAILABLE: + return None + try: if ( not issubclass(obj, PyTorchTestCaseTemplate) diff --git a/setup.cfg b/setup.cfg index d25164f..7a0838d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,7 +31,6 @@ include_package_data = True python_requires = >=3.6 install_requires = pytest - torch [options.packages.find] exclude = diff --git a/tox.ini b/tox.ini index 4e3ff85..149e434 100644 --- a/tox.ini +++ b/tox.ini @@ -28,6 +28,7 @@ pytorch_channel = nightly deps = pytest >= 6 pytest-mock >= 3.1 + torch # The nightlies do not specify numpy as requirement numpy commands =