diff --git a/client/qiskit_serverless/__init__.py b/client/qiskit_serverless/__init__.py index 5daefbeec..50e0c25fe 100644 --- a/client/qiskit_serverless/__init__.py +++ b/client/qiskit_serverless/__init__.py @@ -34,6 +34,7 @@ LocalClient, save_result, Configuration, + is_running_in_serverless, ) from .exception import QiskitServerlessException from .core.function import QiskitPattern, QiskitFunction diff --git a/client/qiskit_serverless/core/__init__.py b/client/qiskit_serverless/core/__init__.py index 2852c01d7..8d100aabe 100644 --- a/client/qiskit_serverless/core/__init__.py +++ b/client/qiskit_serverless/core/__init__.py @@ -44,6 +44,7 @@ get put get_refs_by_status + is_running_in_serverless """ @@ -57,6 +58,7 @@ Job, save_result, Configuration, + is_running_in_serverless, ) from .function import QiskitPattern, QiskitFunction from .decorators import ( diff --git a/client/qiskit_serverless/core/job.py b/client/qiskit_serverless/core/job.py index beec864fd..017b1557a 100644 --- a/client/qiskit_serverless/core/job.py +++ b/client/qiskit_serverless/core/job.py @@ -288,3 +288,8 @@ def _map_status_to_serverless(status: str) -> str: return status_map[status] except KeyError: return status + + +def is_running_in_serverless() -> bool: + """Return ``True`` if running as a Qiskit serverless program, ``False`` otherwise.""" + return "ENV_JOB_ID_GATEWAY" in os.environ diff --git a/client/tests/core/test_job.py b/client/tests/core/test_job.py index f0edcce0d..13321bf8c 100644 --- a/client/tests/core/test_job.py +++ b/client/tests/core/test_job.py @@ -2,10 +2,10 @@ # pylint: disable=too-few-public-methods import os -from unittest import TestCase from unittest.mock import MagicMock, Mock, patch import numpy as np +import pytest import requests_mock from qiskit.circuit.random import random_circuit @@ -16,7 +16,19 @@ ENV_JOB_ID_GATEWAY, ENV_JOB_GATEWAY_TOKEN, ) -from qiskit_serverless.core.job import save_result +from qiskit_serverless.core.job import is_running_in_serverless, save_result + + +# pylint: disable=redefined-outer-name +@pytest.fixture() +def job_env_variables(monkeypatch): + """Fixture to set mock job environment variables.""" + # Inspired by https://stackoverflow.com/a/77256931/1558890 + with patch.dict(os.environ, clear=True): + monkeypatch.setenv(ENV_JOB_GATEWAY_HOST, "https://awesome-tests.com/") + monkeypatch.setenv(ENV_JOB_ID_GATEWAY, "42") + monkeypatch.setenv(ENV_JOB_GATEWAY_TOKEN, "awesome-token") + yield # Restore the environment after the test runs class ResponseMock: @@ -26,15 +38,12 @@ class ResponseMock: text = "{}" -class TestJob(TestCase): +class TestJob: """TestJob.""" - def test_save_result(self): + def test_save_result(self, job_env_variables): """Tests job save result.""" - - os.environ[ENV_JOB_GATEWAY_HOST] = "https://awesome-tests.com/" - os.environ[ENV_JOB_ID_GATEWAY] = "42" - os.environ[ENV_JOB_GATEWAY_TOKEN] = "awesome-token" + _ = job_env_variables url = ( f"{os.environ.get(ENV_JOB_GATEWAY_HOST)}/" @@ -48,7 +57,7 @@ def test_save_result(self): "quantum_circuit": random_circuit(3, 2), } ) - self.assertTrue(result) + assert result is True @patch("requests.get", Mock(return_value=ResponseMock())) def test_filtered_logs(self): @@ -66,3 +75,16 @@ def test_filtered_logs(self): assert "This is the line 1\n" == client.filtered_logs( "id", include="This is the l.+", exclude="the.+a.+l" ) + + +class TestRunningAsServerlessProgram: + """Test ``is_running_in_serverless()``.""" + + def test_not_running_as_serverless_program(self): + """Test ``is_running_in_serverless()`` outside a serverless program.""" + assert is_running_in_serverless() is False + + def test_running_as_serverless_program(self, job_env_variables): + """Test ``is_running_in_serverless()`` in a mocked serverless program.""" + _ = job_env_variables + assert is_running_in_serverless() is True