From fd4ffc0f7e3426dd88e3808a0976ab5d6f4447a9 Mon Sep 17 00:00:00 2001 From: Goyo Date: Thu, 7 Nov 2024 12:46:25 +0100 Subject: [PATCH] Unify RayClient usage (#1534) --- .../core/clients/local_client.py | 59 +++++---------- .../core/clients/ray_client.py | 33 +++++---- .../core/local_functions_store.py | 73 +++++++++++++++++++ client/tests/core/test_pattern.py | 3 +- 4 files changed, 113 insertions(+), 55 deletions(-) create mode 100644 client/qiskit_serverless/core/local_functions_store.py diff --git a/client/qiskit_serverless/core/clients/local_client.py b/client/qiskit_serverless/core/clients/local_client.py index c0c28a64e..237808959 100644 --- a/client/qiskit_serverless/core/clients/local_client.py +++ b/client/qiskit_serverless/core/clients/local_client.py @@ -49,6 +49,7 @@ Configuration, ) from qiskit_serverless.core.function import QiskitFunction, RunnableQiskitFunction +from qiskit_serverless.core.local_functions_store import LocalFunctionsStore from qiskit_serverless.exception import QiskitServerlessException from qiskit_serverless.serializers.program_serializers import ( QiskitObjectsEncoder, @@ -69,7 +70,7 @@ def __init__(self): super().__init__("local-client") self.in_test = os.getenv("IN_TEST") self._jobs = {} - self._patterns = [] + self._functions = LocalFunctionsStore(self) @classmethod def from_dict(cls, dictionary: dict): @@ -92,33 +93,30 @@ def run( config: Optional[Configuration] = None, ) -> Job: # pylint: disable=too-many-locals - title = "" - if isinstance(program, QiskitFunction): - title = program.title - else: - title = str(program) - - for pattern in self._patterns: - if pattern["title"] == title: - saved_program = pattern - if saved_program[ # pylint: disable=possibly-used-before-assignment - "dependencies" - ]: - dept = json.loads(saved_program["dependencies"]) - for dependency in dept: + title = program.title if isinstance(program, QiskitFunction) else str(program) + + saved_program = self.function(title) + + if not saved_program: + raise QiskitServerlessException( + "QiskitFunction provided is not uploaded to the client. Use upload() first." + ) + + if saved_program.dependencies: + for dependency in saved_program.dependencies: subprocess.check_call( [sys.executable, "-m", "pip", "install", dependency] ) arguments = arguments or {} env_vars = { - **(saved_program["env_vars"] or {}), - **{OT_PROGRAM_NAME: saved_program["title"]}, + **(saved_program.env_vars or {}), + **{OT_PROGRAM_NAME: saved_program.title}, **{"PATH": os.environ["PATH"]}, **{ENV_JOB_ARGUMENTS: json.dumps(arguments, cls=QiskitObjectsEncoder)}, } with Popen( - ["python", saved_program["working_dir"] + saved_program["entrypoint"]], + ["python", saved_program.working_dir + saved_program.entrypoint], stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, @@ -165,31 +163,12 @@ def filtered_logs(self, job_id: str, **kwargs): def upload(self, program: QiskitFunction) -> Optional[RunnableQiskitFunction]: # check if entrypoint exists - if not os.path.exists(os.path.join(program.working_dir, program.entrypoint)): - raise QiskitServerlessException( - f"Entrypoint file [{program.entrypoint}] does not exist " - f"in [{program.working_dir}] working directory." - ) - - pattern = { - "title": program.title, - "provider": program.provider, - "entrypoint": program.entrypoint, - "working_dir": program.working_dir, - "env_vars": program.env_vars, - "arguments": json.dumps({}), - "dependencies": json.dumps(program.dependencies or []), - "client": self, - } - self._patterns.append(pattern) - return RunnableQiskitFunction.from_json(pattern) + return self._functions.upload(program) def functions(self, **kwargs) -> List[RunnableQiskitFunction]: - """Returns list of programs.""" - return [RunnableQiskitFunction.from_json(program) for program in self._patterns] + return self._functions.functions() def function( self, title: str, provider: Optional[str] = None ) -> Optional[RunnableQiskitFunction]: - functions = {function.title: function for function in self.functions()} - return functions.get(title) + return self._functions.function(title) diff --git a/client/qiskit_serverless/core/clients/ray_client.py b/client/qiskit_serverless/core/clients/ray_client.py index cd2eaaa90..e039c983c 100644 --- a/client/qiskit_serverless/core/clients/ray_client.py +++ b/client/qiskit_serverless/core/clients/ray_client.py @@ -27,7 +27,6 @@ """ # pylint: disable=duplicate-code import json -import warnings from typing import Optional, List, Dict, Any, Union from uuid import uuid4 @@ -43,6 +42,8 @@ Job, ) from qiskit_serverless.core.function import QiskitFunction, RunnableQiskitFunction +from qiskit_serverless.core.local_functions_store import LocalFunctionsStore +from qiskit_serverless.exception import QiskitServerlessException from qiskit_serverless.serializers.program_serializers import ( QiskitObjectsEncoder, ) @@ -64,6 +65,7 @@ def __init__(self, host: str): """ super().__init__("ray-client", host) self.job_submission_client = JobSubmissionClient(host) + self._functions = LocalFunctionsStore(self) @classmethod def from_dict(cls, dictionary: dict): @@ -104,20 +106,23 @@ def run( arguments: Optional[Dict[str, Any]] = None, config: Optional[Configuration] = None, ) -> Job: - if not isinstance(program, QiskitFunction): - warnings.warn( - "`run` doesn't support program str yet. " - "Send a QiskitFunction instead. " + # pylint: disable=too-many-locals + title = program.title if isinstance(program, QiskitFunction) else str(program) + + saved_program = self.function(title) + + if not saved_program: + raise QiskitServerlessException( + "QiskitFunction provided is not uploaded to the client. Use upload() first." ) - raise NotImplementedError arguments = arguments or {} - entrypoint = f"python {program.entrypoint}" + entrypoint = f"python {saved_program.entrypoint}" # set program name so OT can use it as parent span name env_vars = { - **(program.env_vars or {}), - **{OT_PROGRAM_NAME: program.title}, + **(saved_program.env_vars or {}), + **{OT_PROGRAM_NAME: saved_program.title}, **{ENV_JOB_ARGUMENTS: json.dumps(arguments, cls=QiskitObjectsEncoder)}, } @@ -125,8 +130,8 @@ def run( entrypoint=entrypoint, submission_id=f"qs_{uuid4()}", runtime_env={ - "working_dir": program.working_dir, - "pip": program.dependencies, + "working_dir": saved_program.working_dir, + "pip": saved_program.dependencies, "env_vars": env_vars, }, ) @@ -160,14 +165,14 @@ def filtered_logs(self, job_id: str, **kwargs) -> str: def upload(self, program: QiskitFunction) -> Optional[RunnableQiskitFunction]: """Uploads program.""" - raise NotImplementedError("Upload is not available for RayClient.") + return self._functions.upload(program) def functions(self, **kwargs) -> List[RunnableQiskitFunction]: """Returns list of available programs.""" - raise NotImplementedError("get_programs is not available for RayClient.") + return self._functions.functions() def function( self, title: str, provider: Optional[str] = None ) -> Optional[RunnableQiskitFunction]: """Returns program based on parameters.""" - raise NotImplementedError("get_program is not available for RayClient.") + return self._functions.function(title) diff --git a/client/qiskit_serverless/core/local_functions_store.py b/client/qiskit_serverless/core/local_functions_store.py new file mode 100644 index 000000000..7358ec86c --- /dev/null +++ b/client/qiskit_serverless/core/local_functions_store.py @@ -0,0 +1,73 @@ +# This code is a Qiskit project. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +""" +================================================ +Provider (:mod:`qiskit_serverless.core.client`) +================================================ + +.. currentmodule:: qiskit_serverless.core.client + +Qiskit Serverless provider +=========================== + +.. autosummary:: + :toctree: ../stubs/ + + LocalFunctionsStore +""" +# pylint: disable=duplicate-code +import os.path +import os +from typing import Optional, List +from qiskit_serverless.core.client import BaseClient +from qiskit_serverless.core.function import QiskitFunction, RunnableQiskitFunction +from qiskit_serverless.exception import QiskitServerlessException + + +class LocalFunctionsStore: + """LocalClient.""" + + def __init__(self, client: BaseClient): + self.client = client + self._functions: List[RunnableQiskitFunction] = [] + + def upload(self, program: QiskitFunction) -> Optional[RunnableQiskitFunction]: + """Save a function in the store""" + if not os.path.exists(os.path.join(program.working_dir, program.entrypoint)): + raise QiskitServerlessException( + f"Entrypoint file [{program.entrypoint}] does not exist " + f"in [{program.working_dir}] working directory." + ) + + pattern = { + "title": program.title, + "provider": program.provider, + "entrypoint": program.entrypoint, + "working_dir": program.working_dir, + "env_vars": program.env_vars, + "arguments": {}, + "dependencies": program.dependencies or [], + "client": self.client, + } + runnable_function = RunnableQiskitFunction.from_json(pattern) + self._functions.append(runnable_function) + return runnable_function + + def functions(self) -> List[RunnableQiskitFunction]: + """Returns list of functions.""" + return list(self._functions) + + def function(self, title: str) -> Optional[RunnableQiskitFunction]: + """Returns a function with the provided title.""" + functions = {function.title: function for function in self.functions()} + return functions.get(title) diff --git a/client/tests/core/test_pattern.py b/client/tests/core/test_pattern.py index 99cbd2dd0..7b2444308 100644 --- a/client/tests/core/test_pattern.py +++ b/client/tests/core/test_pattern.py @@ -33,8 +33,9 @@ def test_program(): description="description", version="0.0.1", ) + uploaded_program = serverless.upload(program) - job = serverless.run(program) + job = serverless.run(uploaded_program) assert isinstance(job, Job)