From c0d0409c72f4c77f3b832feddd6974341f9c07f5 Mon Sep 17 00:00:00 2001 From: Paul Schweigert Date: Thu, 26 Sep 2024 09:47:57 -0400 Subject: [PATCH] Provide support for large job arguments (#1501) * pass arguments by file instead of by envvar Signed-off-by: Paul S. Schweigert * lint Signed-off-by: Paul S. Schweigert * more lint Signed-off-by: Paul S. Schweigert * lint Signed-off-by: Paul S. Schweigert * lint Signed-off-by: Paul S. Schweigert * lint Signed-off-by: Paul S. Schweigert * update tests Signed-off-by: Paul S. Schweigert * lint Signed-off-by: Paul S. Schweigert * fix test Signed-off-by: Paul S. Schweigert * lint Signed-off-by: Paul S. Schweigert * always overwrite arguments file Signed-off-by: Paul S. Schweigert * set envvar if args less than 1MB Signed-off-by: Paul S. Schweigert * fix test Signed-off-by: Paul S. Schweigert * fix test better Signed-off-by: Paul S. Schweigert * lint Signed-off-by: Paul S. Schweigert * comma Signed-off-by: Paul S. Schweigert * lint Signed-off-by: Paul S. Schweigert * more accurate sizing of args Signed-off-by: Paul S. Schweigert * lint Signed-off-by: Paul S. Schweigert * pass invalid json for envvar args to force failure Signed-off-by: Paul S. Schweigert * fix tests Signed-off-by: Paul S. Schweigert * update arguments when using envvar Signed-off-by: Paul S. Schweigert --------- Signed-off-by: Paul S. Schweigert --- .../serializers/program_serializers.py | 8 +++++--- .../serializers/test_program_serializers.py | 8 +++----- gateway/api/ray.py | 13 +++++++++++++ gateway/api/utils.py | 17 +++++++++++++++-- gateway/requirements.txt | 3 ++- gateway/tests/api/test_utils.py | 10 ++++------ 6 files changed, 42 insertions(+), 17 deletions(-) diff --git a/client/qiskit_serverless/serializers/program_serializers.py b/client/qiskit_serverless/serializers/program_serializers.py index 023b23a82..a98072d9a 100644 --- a/client/qiskit_serverless/serializers/program_serializers.py +++ b/client/qiskit_serverless/serializers/program_serializers.py @@ -33,8 +33,6 @@ from qiskit_ibm_runtime import QiskitRuntimeService from qiskit_ibm_runtime.utils.json import RuntimeDecoder, RuntimeEncoder -from qiskit_serverless.core.constants import ENV_JOB_ARGUMENTS - class QiskitObjectsEncoder(RuntimeEncoder): """Json encoder for Qiskit objects.""" @@ -81,4 +79,8 @@ def get_arguments() -> Dict[str, Any]: Returns: Dictionary of arguments. """ - return json.loads(os.environ.get(ENV_JOB_ARGUMENTS, "{}"), cls=QiskitObjectsDecoder) + arguments = "{}" + if os.path.isfile("arguments.serverless"): + with open("arguments.serverless", "r", encoding="utf-8") as f: + arguments = f.read() + return json.loads(arguments, cls=QiskitObjectsDecoder) diff --git a/client/tests/serializers/test_program_serializers.py b/client/tests/serializers/test_program_serializers.py index dd5d899ec..93ac794a1 100644 --- a/client/tests/serializers/test_program_serializers.py +++ b/client/tests/serializers/test_program_serializers.py @@ -12,14 +12,12 @@ """QiskitPattern serializers tests.""" import json -import os from unittest import TestCase, skip import numpy as np from qiskit.circuit.random import random_circuit from qiskit_ibm_runtime import QiskitRuntimeService -from qiskit_serverless.core.constants import ENV_JOB_ARGUMENTS from qiskit_serverless.serializers.program_serializers import ( QiskitObjectsDecoder, QiskitObjectsEncoder, @@ -61,8 +59,8 @@ def test_argument_parsing(self): circuit = random_circuit(4, 2) array = np.array([[42.0], [0.0]]) - os.environ[ENV_JOB_ARGUMENTS] = json.dumps( - {"circuit": circuit, "array": array}, cls=QiskitObjectsEncoder - ) + with open("arguments.serverless", "w", encoding="utf-8") as f: + json.dump({"circuit": circuit, "array": array}, f, cls=QiskitObjectsEncoder) + parsed_arguments = get_arguments() self.assertEqual(list(parsed_arguments.keys()), ["circuit", "array"]) diff --git a/gateway/api/ray.py b/gateway/api/ray.py index e12779db1..0dd03e3af 100644 --- a/gateway/api/ray.py +++ b/gateway/api/ray.py @@ -118,6 +118,19 @@ def submit(self, job: Job) -> Optional[str]: # get entrypoint entrypoint = f"python {program.entrypoint}" + # upload arguments to working directory + # if no arguments, write an empty dict to the arguments file + with open( + working_directory_for_upload + "/arguments.serverless", + "w", + encoding="utf-8", + ) as f: + if job.arguments: + logger.debug("uploading arguments for job %s", job.id) + f.write(job.arguments) + else: + f.write({}) + # set tracing carrier = {} TraceContextTextMapPropagator().inject(carrier) diff --git a/gateway/api/utils.py b/gateway/api/utils.py index d70e0b8fc..7fdece114 100644 --- a/gateway/api/utils.py +++ b/gateway/api/utils.py @@ -17,6 +17,7 @@ from ray.dashboard.modules.job.common import JobStatus from django.conf import settings from parsley import makeGrammar +import objsize from .models import Job @@ -112,18 +113,30 @@ def decrypt_string(string: str) -> str: return fernet.decrypt(string.encode("utf-8")).decode("utf-8") -def build_env_variables(token, job: Job, arguments: str) -> Dict[str, str]: +def build_env_variables(token, job: Job, args: str = None) -> Dict[str, str]: """Builds env variables for job. Args: token: django request token decoded job: job - arguments: program arguments Returns: env variables dict """ extra = {} + # only set arguments envvar if not too big + # remove this after sufficient time for users to upgrade client + arguments = "{}" + if args: + if objsize.get_deep_size(args) < 100000: + logger.debug("passing arguments as envvar for job %s", job.id) + arguments = args + else: + logger.warning( + "arguments for job %s are too large and will not be written to env var", + job.id, + ) + if settings.SETTINGS_AUTH_MECHANISM != "default": extra = { "QISKIT_IBM_TOKEN": str(token), diff --git a/gateway/requirements.txt b/gateway/requirements.txt index d29458c91..ccf9e62cd 100644 --- a/gateway/requirements.txt +++ b/gateway/requirements.txt @@ -22,4 +22,5 @@ qiskit-ibm-runtime>=0.29.0 tzdata>=2024.1 django-cors-headers>=4.4.0, <5 parsley>=1.3, <2 -whitenoise>=6.7.0, <7 \ No newline at end of file +whitenoise>=6.7.0, <7 +objsize>=0.7.0 diff --git a/gateway/tests/api/test_utils.py b/gateway/tests/api/test_utils.py index dcf80ab62..26171c1db 100644 --- a/gateway/tests/api/test_utils.py +++ b/gateway/tests/api/test_utils.py @@ -24,28 +24,26 @@ def test_build_env_for_job(self): token = "42" job = MagicMock() job.id = "42" - env_vars = build_env_variables(token=token, job=job, arguments={"answer": 42}) + env_vars = build_env_variables(token=token, job=job) self.assertEqual( env_vars, { "ENV_JOB_GATEWAY_TOKEN": "42", "ENV_JOB_GATEWAY_HOST": "http://localhost:8000", "ENV_JOB_ID_GATEWAY": "42", - "ENV_JOB_ARGUMENTS": {"answer": 42}, + "ENV_JOB_ARGUMENTS": "{}", }, ) with self.settings( SETTINGS_AUTH_MECHANISM="custom_token", SECRET_KEY="super-secret" ): - env_vars_with_qiskit_runtime = build_env_variables( - token=token, job=job, arguments={"answer": 42} - ) + env_vars_with_qiskit_runtime = build_env_variables(token=token, job=job) expecting = { "ENV_JOB_GATEWAY_TOKEN": "42", "ENV_JOB_GATEWAY_HOST": "http://localhost:8000", "ENV_JOB_ID_GATEWAY": "42", - "ENV_JOB_ARGUMENTS": {"answer": 42}, + "ENV_JOB_ARGUMENTS": "{}", "QISKIT_IBM_TOKEN": "42", "QISKIT_IBM_CHANNEL": "ibm_quantum", "QISKIT_IBM_URL": "https://auth.quantum-computing.ibm.com/api",