From ccd159a48c101f7f0215b527becc5651b7f94d0d Mon Sep 17 00:00:00 2001 From: David <9059044+Tansito@users.noreply.github.com> Date: Mon, 13 Jan 2025 13:33:41 -0500 Subject: [PATCH] Improve data directory functionality (#1509) * Create services storage script * Files list end point refactor (#1529) * First refactorization for files * Completed use cases for list end-point * Renamed storage to file_storage * Migrated files to file_storage * Migrated get files to the new service * Check if function exists first * Separated end-points to the new provider * Improved tests * Restore original fixtures * Solved the problem with the external if * Added Literal type in the FileStorage service * functions methods refactorization * use self.username instead of the variable * run path sanitization only one time * function title is mandatory for file storage * fix black * Updated 403 by 404 * Updated swagger for list * Makes use of enum with integers * use is instead of equals * unify path build * get_function refactor * Add support new files (#1546) * new list files for the gateway refactor * unify files url * remove enum and linter * remove provider parameter * Check provider * Download end point refactor (#1547) * refactor of download end-points * additional test to check non existing file * renamed files test to v1_files * added additional checks for the query * included the checks in the list end-point * include swagger documentation updated * remove unneeded try except * make use of regex instead of a manual parsing * check not all instead of None * rename file_extension_is_valid * Update gateway/api/utils.py Co-authored-by: Goyo --------- Co-authored-by: Goyo * add support to new download files endpoints (#1550) * Gateway/delete end point (#1554) * add support to new delete files endpoints * fix get parameter * fix tests * delete programDelete fixture * update file_storage/remove_file description * Upload end point (#1555) * Update upload end-point * removed unused imports * updated comment in upload end-point * remove file extension limitation from files (#1559) * Client - files upload refactor (#1557) * add support to new upload file endpoints * replace data with params * Client - files delete refactor (#1556) * add support to new delete files endpoints * fix get parameter * fix tests * adapt files delete to the refactor * replace data with params * remove not used fixture * Integration tests fix (#1561) * add support to new delete files endpoints * fix get parameter * fix tests * adapt files delete to the refactor * replace data with params * remove not used fixture * tests fixed * fix client integration * fix context manager in files when download * added a new test for provider end-points * fix black * migrated old tests * Update gateway/api/services/file_storage.py Co-authored-by: Goyo * remove additional line --------- Co-authored-by: David <9059044+Tansito@users.noreply.github.com> * Trace decorator (#1553) * create a trace decorator * fix decorator * change documentation --------- Co-authored-by: David <9059044+Tansito@users.noreply.github.com> * Files refactor to include repositories logic (#1560) * remove unneeded provider methods * create access policies file * refactor get_functions repository method * refactor groups and repositories * repository refactor from files * fix some linter problems * fixed a bug when the user retrieves a function * fix lint * refactor of get_function method * remove artifact test file * remove programs access policies * refactor programs references to functions * group repository refactor * rename groups repository into user repository * simplified get_function methods * fix query * adapt get_functions methods * updated comments * create path if doesn't exist * remove some unused code * fix files client * fix typos * fixed the creation of the directory * added a test for the provider end-points * fix some typos from the provider end-points * fix black on tests * files documentation updated (#1565) * Refactor storage in scheduler (#1563) * refactor storage in scheduler * fix typo in ray template * remove unused node_image * use sub-paths for the cluster template * storage absolute path refactor * docstrings updated in the client --------- Co-authored-by: Goyo --- .../gateway/templates/rayclustertemplate.yaml | 6 +- .../core/clients/serverless_client.py | 253 ++++---- client/qiskit_serverless/core/decorators.py | 35 + client/qiskit_serverless/core/files.py | 236 +++++-- .../experimental/file_download.ipynb | 70 +- .../experimental/manage_data_directory.ipynb | 206 ++++-- gateway/api/access_policies/__init__.py | 0 gateway/api/access_policies/providers.py | 38 ++ gateway/api/ray.py | 33 +- gateway/api/repositories/functions.py | 224 +++++++ gateway/api/repositories/programs.py | 201 ------ gateway/api/repositories/providers.py | 32 + gateway/api/repositories/users.py | 31 + gateway/api/services/__init__.py | 0 gateway/api/services/file_storage.py | 239 +++++++ gateway/api/utils.py | 21 +- gateway/api/v1/views/files.py | 195 +++++- gateway/api/views/files.py | 607 +++++++++++++----- gateway/api/views/programs.py | 23 +- gateway/tests/api/test_files.py | 257 -------- gateway/tests/api/test_v1_files.py | 505 +++++++++++++++ gateway/tests/fixtures/files_fixtures.json | 93 +++ .../Program/provider_program_artifact.tar | Bin 0 -> 10240 bytes .../default/Program/user_program_artifact.tar | Bin 0 -> 10240 bytes tests/docker/test_docker_experimental.py | 121 +++- tests/experimental/file_download.py | 4 +- tests/experimental/manage_data_directory.py | 18 +- 27 files changed, 2517 insertions(+), 931 deletions(-) create mode 100644 gateway/api/access_policies/__init__.py create mode 100644 gateway/api/access_policies/providers.py create mode 100644 gateway/api/repositories/functions.py delete mode 100644 gateway/api/repositories/programs.py create mode 100644 gateway/api/repositories/providers.py create mode 100644 gateway/api/repositories/users.py create mode 100644 gateway/api/services/__init__.py create mode 100644 gateway/api/services/file_storage.py delete mode 100644 gateway/tests/api/test_files.py create mode 100644 gateway/tests/api/test_v1_files.py create mode 100644 gateway/tests/fixtures/files_fixtures.json create mode 100644 gateway/tests/resources/fake_media/default/Program/provider_program_artifact.tar create mode 100644 gateway/tests/resources/fake_media/test_user_2/default/Program/user_program_artifact.tar diff --git a/charts/qiskit-serverless/charts/gateway/templates/rayclustertemplate.yaml b/charts/qiskit-serverless/charts/gateway/templates/rayclustertemplate.yaml index cc1f36771..8e27c5951 100644 --- a/charts/qiskit-serverless/charts/gateway/templates/rayclustertemplate.yaml +++ b/charts/qiskit-serverless/charts/gateway/templates/rayclustertemplate.yaml @@ -120,10 +120,10 @@ data: {{- end }} - mountPath: /data name: user-storage - subPath: {{`{{ user_id }}`}} + subPath: {{`{{ user_data_folder }}`}} - mountPath: /function_data name: user-storage - subPath: {{`{{ function_data }}`}} + subPath: {{`{{ provider_data_folder }}`}} env: # Environment variables for Ray TLS authentication. # See https://docs.ray.io/en/latest/ray-core/configure.html#tls-authentication for more details. @@ -184,7 +184,7 @@ data: {{- end }} - mountPath: /data name: user-storage - subPath: {{`{{ user_id }}`}} + subPath: {{`{{ user_data_folder }}`}} env: # Environment variables for Ray TLS authentication. # See https://docs.ray.io/en/latest/ray-core/configure.html#tls-authentication for more details. diff --git a/client/qiskit_serverless/core/clients/serverless_client.py b/client/qiskit_serverless/core/clients/serverless_client.py index 3478b9843..a7d215fdb 100644 --- a/client/qiskit_serverless/core/clients/serverless_client.py +++ b/client/qiskit_serverless/core/clients/serverless_client.py @@ -49,6 +49,7 @@ MAX_ARTIFACT_FILE_SIZE_MB, ) from qiskit_serverless.core.client import BaseClient +from qiskit_serverless.core.decorators import trace_decorator_factory from qiskit_serverless.core.files import GatewayFilesClient from qiskit_serverless.core.job import ( Job, @@ -72,6 +73,9 @@ QiskitObjectsDecoder, ) +_trace_job = trace_decorator_factory("job") +_trace_functions = trace_decorator_factory("function") + class ServerlessClient(BaseClient): """ @@ -146,47 +150,45 @@ def _verify_token(self, token: str): ####### JOBS ####### #################### + @_trace_job("list") def jobs(self, **kwargs) -> List[Job]: - tracer = trace.get_tracer("client.tracer") - with tracer.start_as_current_span("job.list"): - limit = kwargs.get("limit", 10) - kwargs["limit"] = limit - offset = kwargs.get("offset", 0) - kwargs["offset"] = offset - - response_data = safe_json_request_as_dict( - request=lambda: requests.get( - f"{self.host}/api/{self.version}/jobs", - params=kwargs, - headers={"Authorization": f"Bearer {self.token}"}, - timeout=REQUESTS_TIMEOUT, - ) + limit = kwargs.get("limit", 10) + kwargs["limit"] = limit + offset = kwargs.get("offset", 0) + kwargs["offset"] = offset + + response_data = safe_json_request_as_dict( + request=lambda: requests.get( + f"{self.host}/api/{self.version}/jobs", + params=kwargs, + headers={"Authorization": f"Bearer {self.token}"}, + timeout=REQUESTS_TIMEOUT, ) + ) return [ Job(job.get("id"), job_service=self, raw_data=job) for job in response_data.get("results", []) ] + @_trace_job("get") def job(self, job_id: str) -> Optional[Job]: - tracer = trace.get_tracer("client.tracer") - with tracer.start_as_current_span("job.get"): - url = f"{self.host}/api/{self.version}/jobs/{job_id}/" - response_data = safe_json_request_as_dict( - request=lambda: requests.get( - url, - headers={"Authorization": f"Bearer {self.token}"}, - timeout=REQUESTS_TIMEOUT, - ) + url = f"{self.host}/api/{self.version}/jobs/{job_id}/" + response_data = safe_json_request_as_dict( + request=lambda: requests.get( + url, + headers={"Authorization": f"Bearer {self.token}"}, + timeout=REQUESTS_TIMEOUT, ) + ) - job = None - job_id = response_data.get("id") - if job_id is not None: - job = Job( - job_id=job_id, - job_service=self, - ) + job = None + job_id = response_data.get("id") + if job_id is not None: + job = Job( + job_id=job_id, + job_service=self, + ) return job @@ -205,7 +207,7 @@ def run( tracer = trace.get_tracer("client.tracer") with tracer.start_as_current_span("job.run") as span: - span.set_attribute("program", title) + span.set_attribute("function", title) span.set_attribute("provider", provider) span.set_attribute("arguments", str(arguments)) @@ -234,66 +236,62 @@ def run( return Job(job_id, job_service=self) + @_trace_job def status(self, job_id: str): - tracer = trace.get_tracer("client.tracer") - with tracer.start_as_current_span("job.status"): - default_status = "Unknown" - response_data = safe_json_request_as_dict( - request=lambda: requests.get( - f"{self.host}/api/{self.version}/jobs/{job_id}/", - headers={"Authorization": f"Bearer {self.token}"}, - timeout=REQUESTS_TIMEOUT, - ) + default_status = "Unknown" + response_data = safe_json_request_as_dict( + request=lambda: requests.get( + f"{self.host}/api/{self.version}/jobs/{job_id}/", + headers={"Authorization": f"Bearer {self.token}"}, + timeout=REQUESTS_TIMEOUT, ) + ) return response_data.get("status", default_status) + @_trace_job def stop(self, job_id: str, service: Optional[QiskitRuntimeService] = None): - tracer = trace.get_tracer("client.tracer") - with tracer.start_as_current_span("job.stop"): - if service: - data = { - "service": json.dumps(service, cls=QiskitObjectsEncoder), - } - else: - data = { - "service": None, - } - response_data = safe_json_request_as_dict( - request=lambda: requests.post( - f"{self.host}/api/{self.version}/jobs/{job_id}/stop/", - headers={"Authorization": f"Bearer {self.token}"}, - timeout=REQUESTS_TIMEOUT, - json=data, - ) + if service: + data = { + "service": json.dumps(service, cls=QiskitObjectsEncoder), + } + else: + data = { + "service": None, + } + response_data = safe_json_request_as_dict( + request=lambda: requests.post( + f"{self.host}/api/{self.version}/jobs/{job_id}/stop/", + headers={"Authorization": f"Bearer {self.token}"}, + timeout=REQUESTS_TIMEOUT, + json=data, ) + ) return response_data.get("message") + @_trace_job def result(self, job_id: str): - tracer = trace.get_tracer("client.tracer") - with tracer.start_as_current_span("job.result"): - response_data = safe_json_request_as_dict( - request=lambda: requests.get( - f"{self.host}/api/{self.version}/jobs/{job_id}/", - headers={"Authorization": f"Bearer {self.token}"}, - timeout=REQUESTS_TIMEOUT, - ) + response_data = safe_json_request_as_dict( + request=lambda: requests.get( + f"{self.host}/api/{self.version}/jobs/{job_id}/", + headers={"Authorization": f"Bearer {self.token}"}, + timeout=REQUESTS_TIMEOUT, ) + ) return json.loads( response_data.get("result", "{}") or "{}", cls=QiskitObjectsDecoder ) + @_trace_job def logs(self, job_id: str): - tracer = trace.get_tracer("client.tracer") - with tracer.start_as_current_span("job.logs"): - response_data = safe_json_request_as_dict( - request=lambda: requests.get( - f"{self.host}/api/{self.version}/jobs/{job_id}/logs/", - headers={"Authorization": f"Bearer {self.token}"}, - timeout=REQUESTS_TIMEOUT, - ) + response_data = safe_json_request_as_dict( + request=lambda: requests.get( + f"{self.host}/api/{self.version}/jobs/{job_id}/logs/", + headers={"Authorization": f"Bearer {self.token}"}, + timeout=REQUESTS_TIMEOUT, ) + ) return response_data.get("logs") def filtered_logs(self, job_id: str, **kwargs): @@ -323,8 +321,8 @@ def filtered_logs(self, job_id: str, **kwargs): def upload(self, program: QiskitFunction) -> Optional[RunnableQiskitFunction]: tracer = trace.get_tracer("client.tracer") - with tracer.start_as_current_span("job.run") as span: - span.set_attribute("program", program.title) + with tracer.start_as_current_span("function.upload") as span: + span.set_attribute("function", program.title) url = f"{self.host}/api/{self.version}/programs/upload/" if program.image is not None: @@ -344,18 +342,17 @@ def upload(self, program: QiskitFunction) -> Optional[RunnableQiskitFunction]: return function_uploaded + @_trace_functions("list") def functions(self, **kwargs) -> List[RunnableQiskitFunction]: - """Returns list of available programs.""" - tracer = trace.get_tracer("client.tracer") - with tracer.start_as_current_span("program.list"): - response_data = safe_json_request_as_list( - request=lambda: requests.get( - f"{self.host}/api/{self.version}/programs", - headers={"Authorization": f"Bearer {self.token}"}, - params=kwargs, - timeout=REQUESTS_TIMEOUT, - ) + """Returns list of available functions.""" + response_data = safe_json_request_as_list( + request=lambda: requests.get( + f"{self.host}/api/{self.version}/programs", + headers={"Authorization": f"Bearer {self.token}"}, + params=kwargs, + timeout=REQUESTS_TIMEOUT, ) + ) return [ RunnableQiskitFunction( @@ -368,6 +365,7 @@ def functions(self, **kwargs) -> List[RunnableQiskitFunction]: for program in response_data ] + @_trace_functions("get_by_title") def function( self, title: str, provider: Optional[str] = None ) -> Optional[RunnableQiskitFunction]: @@ -376,50 +374,73 @@ def function( request_provider=provider, title=title ) - tracer = trace.get_tracer("client.tracer") - with tracer.start_as_current_span("program.get_by_title"): - response_data = safe_json_request_as_dict( - request=lambda: requests.get( - f"{self.host}/api/{self.version}/programs/get_by_title/{title}", - headers={"Authorization": f"Bearer {self.token}"}, - params={"provider": provider}, - timeout=REQUESTS_TIMEOUT, - ) - ) - return RunnableQiskitFunction( - client=self, - title=response_data.get("title"), - provider=response_data.get("provider", None), - raw_data=response_data, + response_data = safe_json_request_as_dict( + request=lambda: requests.get( + f"{self.host}/api/{self.version}/programs/get_by_title/{title}", + headers={"Authorization": f"Bearer {self.token}"}, + params={"provider": provider}, + timeout=REQUESTS_TIMEOUT, ) + ) + + return RunnableQiskitFunction( + client=self, + title=response_data.get("title"), + provider=response_data.get("provider", None), + raw_data=response_data, + ) ##################### ####### FILES ####### ##################### - def files(self, provider: Optional[str] = None) -> List[str]: - """Returns list of available files produced by programs to download.""" - return self._files_client.list(provider) + def files(self, function: QiskitFunction) -> List[str]: + """Returns the list of files available for the user in the Qiskit Function folder.""" + return self._files_client.list(function) + + def provider_files(self, function: QiskitFunction) -> List[str]: + """Returns the list of files available for the provider in the Qiskit Function folder.""" + return self._files_client.provider_list(function) def file_download( self, file: str, + function: QiskitFunction, target_name: Optional[str] = None, download_location: str = "./", - provider: Optional[str] = None, ): - """Download file.""" + """Download a file available to the user for the specific Qiskit Function.""" return self._files_client.download( - file, download_location, target_name, provider + file, download_location, function, target_name + ) + + def provider_file_download( + self, + file: str, + function: QiskitFunction, + target_name: Optional[str] = None, + download_location: str = "./", + ): + """Download a file available to the provider for the specific Qiskit Function.""" + return self._files_client.provider_download( + file, download_location, function, target_name ) - def file_delete(self, file: str, provider: Optional[str] = None): - """Deletes file uploaded or produced by the programs,""" - return self._files_client.delete(file, provider) + def file_delete(self, file: str, function: QiskitFunction): + """Deletes a file available to the user for the specific Qiskit Function.""" + return self._files_client.delete(file, function) + + def provider_file_delete(self, file: str, function: QiskitFunction): + """Deletes a file available to the provider for the specific Qiskit Function.""" + return self._files_client.provider_delete(file, function) + + def file_upload(self, file: str, function: QiskitFunction): + """Uploads a file in the specific user's Qiskit Function folder.""" + return self._files_client.upload(file, function) - def file_upload(self, file: str, provider: Optional[str] = None): - """Upload file.""" - return self._files_client.upload(file, provider) + def provider_file_upload(self, file: str, function: QiskitFunction): + """Uploads a file in the specific provider's Qiskit Function folder.""" + return self._files_client.provider_upload(file, function) class IBMServerlessClient(ServerlessClient): @@ -520,8 +541,8 @@ def _upload_with_docker_image( ) program_title = response_data.get("title", "na") program_provider = response_data.get("provider", "na") - span.set_attribute("program.title", program_title) - span.set_attribute("program.provider", program_provider) + span.set_attribute("function.title", program_title) + span.set_attribute("function.provider", program_provider) response_data["client"] = client return RunnableQiskitFunction.from_json(response_data) @@ -588,8 +609,8 @@ def _upload_with_artifact( timeout=REQUESTS_TIMEOUT, ) ) - span.set_attribute("program.title", response_data.get("title", "na")) - span.set_attribute("program.provider", response_data.get("provider", "na")) + span.set_attribute("function.title", response_data.get("title", "na")) + span.set_attribute("function.provider", response_data.get("provider", "na")) response_data["client"] = client response_function = RunnableQiskitFunction.from_json(response_data) except Exception as error: # pylint: disable=broad-exception-caught diff --git a/client/qiskit_serverless/core/decorators.py b/client/qiskit_serverless/core/decorators.py index ffccc9305..07ddf8947 100644 --- a/client/qiskit_serverless/core/decorators.py +++ b/client/qiskit_serverless/core/decorators.py @@ -34,6 +34,7 @@ import inspect import os import shutil +from types import FunctionType import warnings from dataclasses import dataclass from typing import Optional, Dict, Any, Union, List, Callable, Sequence @@ -451,3 +452,37 @@ def distribute_program( "Please, use `distribute_qiskit_function` instead." ) return distribute_qiskit_function(provider, dependencies, working_dir) + + +def trace_decorator_factory(traced_feature: str): + """Factory for generate decorators for classes or features.""" + + def generated_decorator(traced_function: Union[FunctionType, str]): + """ + The decorator wrapper to generate optional arguments + if traced_function is string it will be used in the span, + the function.__name__ attribute will be used otherwise + """ + + def decorator_trace(func: FunctionType): + """The decorator that python call""" + + def wrapper(*args, **kwargs): + """The wrapper""" + tracer = trace.get_tracer("client.tracer") + function_name = ( + traced_function + if isinstance(traced_function, str) + else func.__name__ + ) + with tracer.start_as_current_span(f"{traced_feature}.${function_name}"): + result = func(*args, **kwargs) + return result + + return wrapper + + if callable(traced_function): + return decorator_trace(traced_function) + return decorator_trace + + return generated_decorator diff --git a/client/qiskit_serverless/core/files.py b/client/qiskit_serverless/core/files.py index fbb799bcc..2547871a8 100644 --- a/client/qiskit_serverless/core/files.py +++ b/client/qiskit_serverless/core/files.py @@ -30,16 +30,21 @@ from typing import List, Optional import requests -from opentelemetry import trace from tqdm import tqdm from qiskit_serverless.core.constants import ( REQUESTS_STREAMING_TIMEOUT, REQUESTS_TIMEOUT, ) +from qiskit_serverless.core.decorators import trace_decorator_factory +from qiskit_serverless.core.function import QiskitFunction +from qiskit_serverless.exception import QiskitServerlessException from qiskit_serverless.utils.json import safe_json_request_as_dict +_trace = trace_decorator_factory("files") + + class GatewayFilesClient: """GatewayFilesClient.""" @@ -54,84 +59,183 @@ def __init__(self, host: str, token: str, version: str): self.host = host self.version = version self._token = token + self._files_url = os.path.join(self.host, "api", self.version, "files") + def _download_with_url( # pylint: disable=too-many-positional-arguments + self, + file: str, + download_location: str, + function: QiskitFunction, + url: str, + target_name: Optional[str] = None, + ) -> Optional[str]: + """Auxiliar function to download a file using an url.""" + with requests.get( + url, + params={ + "file": file, + "provider": function.provider, + "function": function.title, + }, + stream=True, + headers={"Authorization": f"Bearer {self._token}"}, + timeout=REQUESTS_STREAMING_TIMEOUT, + ) as req: + req.raise_for_status() + + total_size_in_bytes = int(req.headers.get("content-length", 0)) + chunk_size = 8192 + progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + file_name = target_name or f"downloaded_{str(uuid.uuid4())[:8]}_{file}" + with open(os.path.join(download_location, file_name), "wb") as f: + for chunk in req.iter_content(chunk_size=chunk_size): + progress_bar.update(len(chunk)) + f.write(chunk) + progress_bar.close() + return file_name + + @_trace def download( self, file: str, download_location: str, + function: QiskitFunction, target_name: Optional[str] = None, - provider: Optional[str] = None, ) -> Optional[str]: - """Downloads file.""" - tracer = trace.get_tracer("client.tracer") - with tracer.start_as_current_span("files.download"): - with requests.get( - f"{self.host}/api/{self.version}/files/download/", - params={"file": file, "provider": provider}, + """Download a file available to the user for the specific Qiskit Function.""" + return self._download_with_url( + file, + download_location, + function, + os.path.join(self._files_url, "download"), + target_name, + ) + + @_trace + def provider_download( + self, + file: str, + download_location: str, + function: QiskitFunction, + target_name: Optional[str] = None, + ) -> Optional[str]: + """Download a file available to the provider for the specific Qiskit Function.""" + if not function.provider: + raise QiskitServerlessException("`function` doesn't have a provider.") + + return self._download_with_url( + file, + download_location, + function, + os.path.join(self._files_url, "provider", "download"), + target_name, + ) + + @_trace + def upload(self, file: str, function: QiskitFunction) -> Optional[str]: + """Uploads a file in the specific user's Qiskit Function folder.""" + with open(file, "rb") as f: + with requests.post( + os.path.join(self._files_url, "upload/"), + files={"file": f}, + params={"provider": function.provider, "function": function.title}, stream=True, headers={"Authorization": f"Bearer {self._token}"}, timeout=REQUESTS_STREAMING_TIMEOUT, ) as req: - req.raise_for_status() - - total_size_in_bytes = int(req.headers.get("content-length", 0)) - chunk_size = 8192 - progress_bar = tqdm( - total=total_size_in_bytes, unit="iB", unit_scale=True - ) - file_name = target_name or f"downloaded_{str(uuid.uuid4())[:8]}_{file}" - with open(os.path.join(download_location, file_name), "wb") as f: - for chunk in req.iter_content(chunk_size=chunk_size): - progress_bar.update(len(chunk)) - f.write(chunk) - progress_bar.close() - return file_name - - def upload(self, file: str, provider: Optional[str] = None) -> Optional[str]: - """Uploads file.""" - tracer = trace.get_tracer("client.tracer") - with tracer.start_as_current_span("files.upload"): - with open(file, "rb") as f: - with requests.post( - f"{self.host}/api/{self.version}/files/upload/", - files={"file": f}, - data={"provider": provider}, - stream=True, - headers={"Authorization": f"Bearer {self._token}"}, - timeout=REQUESTS_STREAMING_TIMEOUT, - ) as req: - if req.ok: - return req.text - return "Upload failed" - return "Can not open file" - - def list(self, provider: Optional[str] = None) -> List[str]: - """Returns list of available files to download produced by programs,""" - tracer = trace.get_tracer("client.tracer") - with tracer.start_as_current_span("files.list"): - response_data = safe_json_request_as_dict( - request=lambda: requests.get( - f"{self.host}/api/{self.version}/files/", - params={"provider": provider}, - headers={"Authorization": f"Bearer {self._token}"}, - timeout=REQUESTS_TIMEOUT, - ) + if req.ok: + return req.text + return "Upload failed" + return "Can not open file" + + @_trace + def provider_upload(self, file: str, function: QiskitFunction) -> Optional[str]: + """Uploads a file in the specific provider's Qiskit Function folder.""" + if not function.provider: + raise QiskitServerlessException("`function` doesn't have a provider.") + + with open(file, "rb") as f: + with requests.post( + os.path.join(self._files_url, "provider", "upload/"), + files={"file": f}, + params={"provider": function.provider, "function": function.title}, + stream=True, + headers={"Authorization": f"Bearer {self._token}"}, + timeout=REQUESTS_STREAMING_TIMEOUT, + ) as req: + if req.ok: + return req.text + return "Upload failed" + return "Can not open file" + + @_trace + def list(self, function: QiskitFunction) -> List[str]: + """Returns the list of files available for the user in the Qiskit Function folder.""" + response_data = safe_json_request_as_dict( + request=lambda: requests.get( + self._files_url, + params={"function": function.title, "provider": function.provider}, + headers={"Authorization": f"Bearer {self._token}"}, + timeout=REQUESTS_TIMEOUT, + ) + ) + return response_data.get("results", []) + + @_trace + def provider_list(self, function: QiskitFunction) -> List[str]: + """Returns the list of files available for the provider in the Qiskit Function folder.""" + if not function.provider: + raise QiskitServerlessException("`function` doesn't have a provider.") + + response_data = safe_json_request_as_dict( + request=lambda: requests.get( + os.path.join(self._files_url, "provider"), + params={"function": function.title, "provider": function.provider}, + headers={"Authorization": f"Bearer {self._token}"}, + timeout=REQUESTS_TIMEOUT, ) + ) return response_data.get("results", []) - def delete(self, file: str, provider: Optional[str] = None) -> Optional[str]: - """Deletes file uploaded or produced by the programs,""" - tracer = trace.get_tracer("client.tracer") - with tracer.start_as_current_span("files.delete"): - response_data = safe_json_request_as_dict( - request=lambda: requests.delete( - f"{self.host}/api/{self.version}/files/delete/", - data={"file": file, "provider": provider}, - headers={ - "Authorization": f"Bearer {self._token}", - "format": "json", - }, - timeout=REQUESTS_TIMEOUT, - ) + @_trace + def delete(self, file: str, function: QiskitFunction) -> Optional[str]: + """Deletes a file available to the user for the specific Qiskit Function.""" + response_data = safe_json_request_as_dict( + request=lambda: requests.delete( + os.path.join(self._files_url, "delete"), + params={ + "file": file, + "function": function.title, + "provider": function.provider, + }, + headers={ + "Authorization": f"Bearer {self._token}", + "format": "json", + }, + timeout=REQUESTS_TIMEOUT, + ) + ) + return response_data.get("message", "") + + @_trace + def provider_delete(self, file: str, function: QiskitFunction) -> Optional[str]: + """Deletes a file available to the provider for the specific Qiskit Function.""" + if not function.provider: + raise QiskitServerlessException("`function` doesn't have a provider.") + + response_data = safe_json_request_as_dict( + request=lambda: requests.delete( + os.path.join(self._files_url, "provider", "delete"), + params={ + "file": file, + "function": function.title, + "provider": function.provider, + }, + headers={ + "Authorization": f"Bearer {self._token}", + "format": "json", + }, + timeout=REQUESTS_TIMEOUT, ) + ) return response_data.get("message", "") diff --git a/docs/getting_started/experimental/file_download.ipynb b/docs/getting_started/experimental/file_download.ipynb index e1adec3a0..2fa8edf3e 100644 --- a/docs/getting_started/experimental/file_download.ipynb +++ b/docs/getting_started/experimental/file_download.ipynb @@ -9,23 +9,21 @@ "\n", "In this tutorial we will describe a way to retrieve files produced by Qiskit Functions.\n", "\n", - "This function provides a way to download files produced by functions during execution. All you need is to call `QiskitServerless.download` function and pass `tar` file name to start downloading the file. Or you can list all available files to you by calling `QiskitServerless.files`.\n", + "This function provides a way to download files produced by functions during execution. All you need is to call `QiskitServerless.file_download` function and pass a file name and the Qiskit Function to start downloading the file. Or you can list all available files to you by calling `QiskitServerless.files`.\n", "\n", "Limitations:\n", "\n", - "- only `tar` and `h5` files are supported\n", - "- `tar` or `h5` file should be saved in `/data` directory during your function execution to be visible by `.files()` method call\n", - "- only `/data` directory is supported, `/data/other_folder` will not be visible\n", + "- files should be saved in `/data` directory during your function execution to be visible by `.files()` method call.\n", + "- only `/data` directory is supported, `/data/other_folder` will not be visible.\n", "- as a provider you have access to `/function-data`, it works in a similar way as the `/data` folder with the distinction that users don't have access to it. Only the providers of the specific functions can see files under that path.\n", - "\n", - "> ⚠ This interface is experimental, therefore it is subjected to breaking changes.\n", + "- Qiskit Functions created by you and Qiskit Functions created by others don't share directories.\n", "\n", "> ⚠ This provider is set up with default credentials to a test cluster intended to run on your machine. For information on setting up infrastructure on your local machine, check out the guide on [local infrastructure setup](https://qiskit.github.io/qiskit-serverless/deployment/local.html)." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "93717e14-d06e-4e11-bd5b-6cdc3f1b1abd", "metadata": {}, "outputs": [ @@ -35,7 +33,7 @@ "" ] }, - "execution_count": 1, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -46,7 +44,7 @@ "\n", "serverless = ServerlessClient(\n", " token=os.environ.get(\"GATEWAY_TOKEN\", \"awesome_token\"),\n", - " host=os.environ.get(\"GATEWAY_HOST\", \"http://localhost:8000\"),\n", + " host=os.environ.get(\"GATEWAY_HOST\", \"http://localhost\"),\n", " # If you are using the kubernetes approach the URL must be http://localhost\n", ")\n", "serverless" @@ -62,7 +60,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "8d93f33b-f7f1-475d-b46e-1106cbe45cae", "metadata": {}, "outputs": [ @@ -72,7 +70,7 @@ "QiskitFunction(file-producer)" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -87,29 +85,51 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, + "id": "3577cc07", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "QiskitFunction(file-producer)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "my_function = serverless.get(\"file-producer\")\n", + "my_function" + ] + }, + { + "cell_type": "code", + "execution_count": 5, "id": "3fef0868-7574-4fbf-b8de-4a7889bdf5ec", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 3, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "job = serverless.run(\"file-producer\")\n", + "job = my_function.run()\n", "job" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "id": "ecd0bb68-4d3c-450e-b363-a58fd91880b3", "metadata": {}, "outputs": [ @@ -119,7 +139,7 @@ "{'Message': 'my_file.txt archived into my_file.tar'}" ] }, - "execution_count": 4, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -138,7 +158,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "id": "08205fd4-b3d6-44d1-a33c-fb3918c26b12", "metadata": {}, "outputs": [ @@ -148,13 +168,13 @@ "['my_file.tar']" ] }, - "execution_count": 5, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "available_files = serverless.files()\n", + "available_files = serverless.files(my_function)\n", "available_files" ] }, @@ -168,7 +188,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "id": "39ca652d-77d7-49d2-97e9-42b60963a671", "metadata": {}, "outputs": [ @@ -176,22 +196,22 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 201/201 [00:00<00:00, 331kiB/s]\n" + "100%|██████████| 200/200 [00:00<00:00, 309kiB/s]\n" ] }, { "data": { "text/plain": [ - "'downloaded_91ea37d9_my_file.tar'" + "'downloaded_8d3f92ba_my_file.tar'" ] }, - "execution_count": 6, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "serverless.file_download(available_files[0])" + "serverless.file_download(available_files[0], my_function)" ] } ], diff --git a/docs/getting_started/experimental/manage_data_directory.ipynb b/docs/getting_started/experimental/manage_data_directory.ipynb index 233854217..1a1925249 100644 --- a/docs/getting_started/experimental/manage_data_directory.ipynb +++ b/docs/getting_started/experimental/manage_data_directory.ipynb @@ -7,26 +7,23 @@ "source": [ "# Shared data directory (Experimental)\n", "\n", - "In this tutorial we will describe a shared directory `data`. Qiskit Functions can produce and consume files in the `data` directory. The files in the `data` directory are shared among Qiskit Functions.\n", + "In this tutorial we will describe a shared directory `data`. Qiskit Functions can produce and consume files in the `data` directory.\n", "\n", - "`QiskitServerless` has `file_download`, `file_upload`, `file_delete` and `files` functions that provide file upload, file download, file delete and list files function.\n", + "`QiskitServerless` has `file_download`, `file_upload`, `file_delete` and `files` functions that provide file upload, file download, file delete and list files.\n", "\n", "Limitations:\n", "\n", - "- only `tar` and `h5` files are supported\n", - "- `tar` or `h5` file should be saved in `/data` director during your function execution to be visible by `.files()` method call\n", - "- only `/data` directory is supported, `/data/other_folder` will not be visible\n", - "- the working directory of these functions are `/data`\n", + "- files should be saved in `/data` directory during your function execution to be visible by `.files()` method call.\n", + "- only `/data` directory is supported, `/data/other_folder` will not be visible.\n", "- as a provider you have access to `/function-data`, it works in a similar way as the `/data` folder with the distinction that users don't have access to it. Only the providers of the specific functions can see files under that path.\n", - "\n", - "> ⚠ This interface is experimental, therefore it is subjected to breaking changes.\n", + "- Qiskit Functions created by you and Qiskit Functions created by others don't share directories.\n", "\n", "> ⚠ This provider is set up with default credentials to a test cluster intended to run on your machine. For information on setting up infrastructure on your local machine, check out the guide on [local infrastructure setup](https://qiskit.github.io/qiskit-serverless/deployment/local.html)." ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "429d8e73-aa75-47ca-9dcd-ed3b33a2cdf8", "metadata": {}, "outputs": [ @@ -36,7 +33,7 @@ "" ] }, - "execution_count": 1, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -53,6 +50,53 @@ "serverless" ] }, + { + "cell_type": "code", + "execution_count": 3, + "id": "87622133", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "QiskitFunction(file-producer)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "producer_function = QiskitFunction(\n", + " title=\"file-producer\", entrypoint=\"produce_files.py\", working_dir=\"./source_files/\"\n", + ")\n", + "\n", + "serverless.upload(producer_function)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c26a3422", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "QiskitFunction(file-producer)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "producer_function = serverless.get(\"file-producer\")\n", + "producer_function" + ] + }, { "cell_type": "markdown", "id": "bc286ee3-2b3b-47ca-a60e-b4db8a57ee06", @@ -63,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 5, "id": "4f35187e-952e-4f2b-8871-1c971e28a739", "metadata": {}, "outputs": [], @@ -78,7 +122,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 6, "id": "71d0cd3f-95fd-42ed-ab77-c363433e4172", "metadata": {}, "outputs": [ @@ -88,13 +132,13 @@ "'{\"message\":\"/usr/src/app/media/mockuser/uploaded_file.tar\"}'" ] }, - "execution_count": 3, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "serverless.file_upload(filename)" + "serverless.file_upload(filename, producer_function)" ] }, { @@ -107,28 +151,44 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "id": "d0009872-7da2-45c9-8bd6-d42be9da67a9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'Message': 'my_file.txt archived into my_file.tar'}" + "" ] }, - "execution_count": 4, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "function = QiskitFunction(\n", - " title=\"file-producer\", entrypoint=\"produce_files.py\", working_dir=\"./source_files/\"\n", - ")\n", - "\n", - "serverless.upload(function)\n", "job = serverless.run(\"file-producer\")\n", + "job" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "e2e46a7f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'Message': 'my_file.txt archived into my_file.tar'}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ "job.result()" ] }, @@ -142,7 +202,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "id": "81ca879c-352f-41ae-a97a-5eb08b4b2c19", "metadata": {}, "outputs": [ @@ -152,13 +212,13 @@ "['uploaded_file.tar', 'my_file.tar']" ] }, - "execution_count": 5, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "serverless.files()" + "serverless.files(producer_function)" ] }, { @@ -171,28 +231,90 @@ }, { "cell_type": "code", - "execution_count": 6, - "id": "06c345d7-b7dd-49b4-8c67-e8158acdcf33", + "execution_count": 10, + "id": "0741ca7e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'Message': \"b'Hello!'\"}" + "QiskitFunction(file-consumer)" ] }, - "execution_count": 6, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "function = QiskitFunction(\n", + "consumer_function = QiskitFunction(\n", " title=\"file-consumer\", entrypoint=\"consume_files.py\", working_dir=\"./source_files/\"\n", ")\n", - "\n", - "serverless.upload(function)\n", - "job = serverless.run(\"file-consumer\")\n", + "serverless.upload(consumer_function)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "05286e6d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "QiskitFunction(file-consumer)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "consumer_function = serverless.get(\"file-consumer\")\n", + "consumer_function" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "06c345d7-b7dd-49b4-8c67-e8158acdcf33", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "job = consumer_function.run()\n", + "job" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "0fff1967", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'Message': \"b'Hello!'\"}" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ "job.result()" ] }, @@ -206,7 +328,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 14, "id": "da35c2cd-8b26-4d36-83f3-02fda396d56d", "metadata": {}, "outputs": [ @@ -216,13 +338,13 @@ "['uploaded_file.tar', 'my_file.tar']" ] }, - "execution_count": 7, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "serverless.files()" + "serverless.files(consumer_function)" ] }, { @@ -235,7 +357,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 15, "id": "c826ed42-9ca6-44f4-83c3-20b2f28c09fc", "metadata": {}, "outputs": [ @@ -245,18 +367,18 @@ "'Requested file was deleted.'" ] }, - "execution_count": 8, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "serverless.file_delete(\"uploaded_file.tar\")" + "serverless.file_delete(filename, consumer_function)" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 16, "id": "cd6a3009-80bf-4c70-a731-63c5d58bfa6a", "metadata": {}, "outputs": [ @@ -266,13 +388,13 @@ "['my_file.tar']" ] }, - "execution_count": 9, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "serverless.files()" + "serverless.files(consumer_function)" ] } ], diff --git a/gateway/api/access_policies/__init__.py b/gateway/api/access_policies/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gateway/api/access_policies/providers.py b/gateway/api/access_policies/providers.py new file mode 100644 index 000000000..18a019cf0 --- /dev/null +++ b/gateway/api/access_policies/providers.py @@ -0,0 +1,38 @@ +""" +Access policies implementation for Provider access +""" +import logging + +from api.models import Provider + + +logger = logging.getLogger("gateway") + + +class ProviderAccessPolicy: # pylint: disable=too-few-public-methods + """ + The main objective of this class is to manage the access for the user + to the Provider entities. + """ + + @staticmethod + def can_access(user, provider: Provider) -> bool: + """ + Checks if the user has access to a Provider: + + Args: + user: Django user from the request + provider: Provider instance against to check the access + + Returns: + bool: True or False in case the user has access + """ + + user_groups = user.groups.all() + admin_groups = provider.admin_groups.all() + has_access = any(group in admin_groups for group in user_groups) + if not has_access: + logger.warning( + "User [%s] has no access to provider [%s].", user.id, provider.name + ) + return has_access diff --git a/gateway/api/ray.py b/gateway/api/ray.py index a02e32eb4..fffa5e64a 100644 --- a/gateway/api/ray.py +++ b/gateway/api/ray.py @@ -22,6 +22,7 @@ from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from api.models import ComputeResource, Job, JobConfig, DEFAULT_PROGRAM_ENTRYPOINT +from api.services.file_storage import FileStorage, WorkingDir from api.utils import ( try_json_loads, retry_function, @@ -214,7 +215,7 @@ def submit_job(job: Job) -> Job: return job -def create_ray_cluster( # pylint: disable=too-many-branches +def create_ray_cluster( # pylint: disable=too-many-branches,too-many-locals,too-many-statements job: Job, cluster_name: Optional[str] = None, cluster_data: Optional[str] = None, @@ -247,7 +248,6 @@ def create_ray_cluster( # pylint: disable=too-many-branches job_config.max_workers = settings.RAY_CLUSTER_WORKER_MAX_REPLICAS if not job_config.auto_scaling: job_config.auto_scaling = settings.RAY_CLUSTER_WORKER_AUTO_SCALING - node_image = settings.RAY_NODE_IMAGE # cpu job settings node_selector_label = settings.RAY_CLUSTER_CPU_NODE_SELECTOR_LABEL @@ -257,19 +257,34 @@ def create_ray_cluster( # pylint: disable=too-many-branches node_selector_label = settings.RAY_CLUSTER_GPU_NODE_SELECTOR_LABEL gpu_request = settings.LIMITS_GPU_PER_TASK - # if user specified image use specified image - function_data = user.username - if job.program.image is not None: + # configure provider configuration if needed + node_image = settings.RAY_NODE_IMAGE + provider_name = None + if job.program.provider is not None: node_image = job.program.image - if job.program.provider.name: - function_data = job.program.provider.name + provider_name = job.program.provider.name + + user_file_storage = FileStorage( + username=user.username, + working_dir=WorkingDir.USER_STORAGE, + function_title=job.program.title, + provider_name=provider_name, + ) + provider_file_storage = user_file_storage + if job.program.provider is not None: + provider_file_storage = FileStorage( + username=user.username, + working_dir=WorkingDir.PROVIDER_STORAGE, + function_title=job.program.title, + provider_name=provider_name, + ) cluster = get_template("rayclustertemplate.yaml") manifest = cluster.render( { "cluster_name": cluster_name, - "user_id": user.username, - "function_data": function_data, + "user_data_folder": user_file_storage.sub_path, + "provider_data_folder": provider_file_storage.sub_path, "node_image": node_image, "workers": job_config.workers, "min_workers": job_config.min_workers, diff --git a/gateway/api/repositories/functions.py b/gateway/api/repositories/functions.py new file mode 100644 index 000000000..33020f8e7 --- /dev/null +++ b/gateway/api/repositories/functions.py @@ -0,0 +1,224 @@ +""" +Repository implementation for Programs model +""" +import logging + +from typing import List + +from django.db.models import Q + +from api.models import Program as Function + +from api.repositories.users import UserRepository + + +logger = logging.getLogger("gateway") + + +class FunctionRepository: + """ + The main objective of this class is to manage the access to the model + """ + + # This repository should be in the use case implementation + # but this class is not ready yet so it will live here + # in the meantime + user_repository = UserRepository() + + def get_functions_by_permission( + self, author, permission_name: str + ) -> List[Function]: + """ + Returns all the functions available to the user. This means: + - User functions where the user is the author + - Provider functions with the permission specified + + Args: + author: Django author from who retrieve the functions + permission_name (str): name of the permission. Values accepted + RUN_PROGRAM_PERMISSION, VIEW_PROGRAM_PERMISSION + + Returns: + List[Function]: all the functions available to the user + """ + + view_groups = self.user_repository.get_groups_by_permissions( + user=author, permission_name=permission_name + ) + author_groups_with_view_permissions_criteria = Q(instances__in=view_groups) + author_criteria = Q(author=author) + + result_queryset = Function.objects.filter( + author_criteria | author_groups_with_view_permissions_criteria + ).distinct() + + count = result_queryset.count() + logger.info("[%d] Functions found for author [%s]", count, author.id) + + return result_queryset + + def get_user_functions(self, author) -> List[Function]: + """ + Returns the user functions available to the user. This means: + - User functions where the user is the author + - Provider is None + + Args: + author: Django author from who retrieve the functions + + Returns: + List[Program]: user functions available to the user + """ + + author_criteria = Q(author=author) + provider_criteria = Q(provider=None) + + result_queryset = Function.objects.filter( + author_criteria & provider_criteria + ).distinct() + + count = result_queryset.count() + logger.info("[%d] user Functions found for author [%s]", count, author.id) + + return result_queryset + + def get_provider_functions_by_permission( + self, author, permission_name: str + ) -> List[Function]: + """ + Returns the provider functions available to the user. This means: + - Provider functions where the user has run permissions + - Provider functions where the user is the author + - Provider is NOT None + + Args: + author: Django author from who retrieve the functions + permission_name (str): name of the permission. Values accepted + RUN_PROGRAM_PERMISSION, VIEW_PROGRAM_PERMISSION + + Returns: + List[Program]: providers functions available to the user + """ + + run_groups = self.user_repository.get_groups_by_permissions( + user=author, permission_name=permission_name + ) + author_groups_with_run_permissions_criteria = Q(instances__in=run_groups) + provider_exists_criteria = ~Q(provider=None) + author_criteria = Q(author=author) + + result_queryset = Function.objects.filter( + (author_groups_with_run_permissions_criteria & provider_exists_criteria) + | (author_criteria & provider_exists_criteria) + ).distinct() + + count = result_queryset.count() + logger.info("[%d] provider Functions found for author [%s]", count, author.id) + + return result_queryset + + def get_user_function(self, author, title: str) -> Function | None: + """ + Returns the user function associated to a title: + + Args: + author: Django author from who retrieve the function + title (str): Title that the function must have to find it + + Returns: + Program | None: user function with the specific title + """ + + author_criteria = Q(author=author) + title_criteria = Q(title=title) + + result_queryset = Function.objects.filter( + author_criteria & title_criteria + ).first() + + if result_queryset is None: + logger.warning( + "Function [%s] was not found or author [%s] doesn't have access to it", + title, + author.id, + ) + + return result_queryset + + def get_provider_function_by_permission( + self, author, permission_name: str, title: str, provider_name: str + ) -> Function | None: + """ + Returns the provider function associated to: + - A Function title + - A Provider + - Author must have a permission to see it or be the author + + Args: + author: Django author from who retrieve the function + permission_name (str): name of the permission. Values accepted + RUN_PROGRAM_PERMISSION, VIEW_PROGRAM_PERMISSION + title (str): Title that the function must have to find it + provider (str): the name of the provider + + Returns: + Program | None: provider function with the specific + title and provider + """ + + # This access should be checked in the use-case but how we don't + # have it implemented yet we will do the check by now in the + # repository call + view_groups = self.user_repository.get_groups_by_permissions( + user=author, permission_name=permission_name + ) + author_groups_with_view_permissions_criteria = Q(instances__in=view_groups) + author_criteria = Q(author=author) + title_criteria = Q(title=title, provider__name=provider_name) + + result_queryset = Function.objects.filter( + (author_criteria | author_groups_with_view_permissions_criteria) + & title_criteria + ).first() + + if result_queryset is None: + logger.warning( + "Function [%s/%s] was not found or author [%s] doesn't have access to it", + provider_name, + title, + author.id, + ) + + return result_queryset + + def get_function_by_permission( + self, + user, + permission_name: str, + function_title: str, + provider_name: str | None, + ) -> None: + """ + This method returns the specified function if the user is + the author of the function or it has a permission. + + Args: + user: Django user of the function that wants to get it + permission_name (str): name of the permission. Values accepted + RUN_PROGRAM_PERMISSION, VIEW_PROGRAM_PERMISSION + function_title (str): title of the function + provider_name (str | None): name of the provider owner of the function + + Returns: + Program | None: returns the function if it exists + """ + + if provider_name: + return self.get_provider_function_by_permission( + author=user, + permission_name=permission_name, + title=function_title, + provider_name=provider_name, + ) + + return self.get_user_function(author=user, title=function_title) diff --git a/gateway/api/repositories/programs.py b/gateway/api/repositories/programs.py deleted file mode 100644 index f7e1aae74..000000000 --- a/gateway/api/repositories/programs.py +++ /dev/null @@ -1,201 +0,0 @@ -""" -Repository implementatio for Programs model -""" -import logging - -from typing import Any, List - -from django.db.models import Q -from django.contrib.auth.models import Group, Permission - -from api.models import RUN_PROGRAM_PERMISSION, VIEW_PROGRAM_PERMISSION, Program - - -logger = logging.getLogger("gateway") - - -class ProgramRepository: - """ - The main objective of this class is to manage the access to the model - """ - - def get_functions(self, author) -> List[Program] | Any: - """ - Returns all the functions available to the user. This means: - - User functions where the user is the author - - Provider functions with view permissions - - Args: - author: Django author from who retrieve the functions - - Returns: - List[Program] | Any: all the functions available to the user - """ - - view_program_permission = Permission.objects.get( - codename=VIEW_PROGRAM_PERMISSION - ) - - user_criteria = Q(user=author) - view_permission_criteria = Q(permissions=view_program_permission) - - author_groups_with_view_permissions = Group.objects.filter( - user_criteria & view_permission_criteria - ) - - author_criteria = Q(author=author) - author_groups_with_view_permissions_criteria = Q( - instances__in=author_groups_with_view_permissions - ) - - result_queryset = Program.objects.filter( - author_criteria | author_groups_with_view_permissions_criteria - ).distinct() - - count = result_queryset.count() - logger.info("[%d] Functions found for author [%s]", count, author.id) - - return result_queryset - - def get_user_functions(self, author) -> List[Program] | Any: - """ - Returns the user functions available to the user. This means: - - User functions where the user is the author - - Provider is None - - Args: - author: Django author from who retrieve the functions - - Returns: - List[Program] | Any: user functions available to the user - """ - - author_criteria = Q(author=author) - provider_criteria = Q(provider=None) - - result_queryset = Program.objects.filter( - author_criteria & provider_criteria - ).distinct() - - count = result_queryset.count() - logger.info("[%d] user Functions found for author [%s]", count, author.id) - - return result_queryset - - def get_provider_functions_with_run_permissions( - self, author - ) -> List[Program] | Any: - """ - Returns the provider functions available to the user. This means: - - Provider functions where the user has run permissions - - Provider is NOT None - - Args: - author: Django author from who retrieve the functions - - Returns: - List[Program] | Any: providers functions available to the user - """ - - run_program_permission = Permission.objects.get(codename=RUN_PROGRAM_PERMISSION) - - user_criteria = Q(user=author) - run_permission_criteria = Q(permissions=run_program_permission) - author_groups_with_run_permissions = Group.objects.filter( - user_criteria & run_permission_criteria - ) - - author_groups_with_run_permissions_criteria = Q( - instances__in=author_groups_with_run_permissions - ) - - provider_exists_criteria = ~Q(provider=None) - - result_queryset = Program.objects.filter( - author_groups_with_run_permissions_criteria & provider_exists_criteria - ).distinct() - - count = result_queryset.count() - logger.info("[%d] provider Functions found for author [%s]", count, author.id) - - return result_queryset - - def get_user_function_by_title(self, author, title: str) -> Program | Any: - """ - Returns the user function associated to a title: - - Args: - author: Django author from who retrieve the function - title: Title that the function must have to find it - - Returns: - Program | Any: user function with the specific title - """ - - author_criteria = Q(author=author) - title_criteria = Q(title=title) - - result_queryset = Program.objects.filter( - author_criteria & title_criteria - ).first() - - if result_queryset is None: - logger.warning( - "Function [%s] was not found or author [%s] doesn't have access to it", - title, - author.id, - ) - - return result_queryset - - def get_provider_function_by_title( - self, author, title: str, provider_name: str - ) -> Program | Any: - """ - Returns the provider function associated to: - - A Function title - - A Provider - - Author must have view permission to see it or be the author - - Args: - author: Django author from who retrieve the function - title: Title that the function must have to find it - provider: Provider associated to the function - - Returns: - Program | Any: provider function with the specific - title and provider - """ - - view_program_permission = Permission.objects.get( - codename=VIEW_PROGRAM_PERMISSION - ) - - user_criteria = Q(user=author) - view_permission_criteria = Q(permissions=view_program_permission) - - author_groups_with_view_permissions = Group.objects.filter( - user_criteria & view_permission_criteria - ) - - author_criteria = Q(author=author) - author_groups_with_view_permissions_criteria = Q( - instances__in=author_groups_with_view_permissions - ) - - title_criteria = Q(title=title, provider__name=provider_name) - - result_queryset = Program.objects.filter( - (author_criteria | author_groups_with_view_permissions_criteria) - & title_criteria - ).first() - - if result_queryset is None: - logger.warning( - "Function [%s/%s] was not found or author [%s] doesn't have access to it", - provider_name, - title, - author.id, - ) - - return result_queryset diff --git a/gateway/api/repositories/providers.py b/gateway/api/repositories/providers.py new file mode 100644 index 000000000..cc3e31125 --- /dev/null +++ b/gateway/api/repositories/providers.py @@ -0,0 +1,32 @@ +""" +Repository implementation for Provider model +""" +import logging + +from api.models import Provider + + +logger = logging.getLogger("gateway") + + +class ProviderRepository: # pylint: disable=too-few-public-methods + """ + The main objective of this class is to manage the access to the model + """ + + def get_provider_by_name(self, name: str) -> Provider | None: + """ + Returns the provider associated with a name. + + Args: + - name: provider name + + Returns: + - Provider | None: returns the specific provider if it exists + """ + + provider = Provider.objects.filter(name=name).first() + if provider is None: + logger.warning("Provider [%s] does not exist.", name) + + return provider diff --git a/gateway/api/repositories/users.py b/gateway/api/repositories/users.py new file mode 100644 index 000000000..5827e4478 --- /dev/null +++ b/gateway/api/repositories/users.py @@ -0,0 +1,31 @@ +""" +Repository implementation for Groups model +""" + +from typing import List +from django.contrib.auth.models import Group, Permission +from django.db.models import Q + + +class UserRepository: # pylint: disable=too-few-public-methods + """ + The main objective of this class is to manage the access to the model + """ + + def get_groups_by_permissions(self, user, permission_name: str) -> List[Group]: + """ + Returns all the groups associated to a permission available in the user. + + Args: + user: Django user from the request + permission_name (str): name of the permission by look for + + Returns: + List[Group]: all the groups available to the user + """ + + function_permission = Permission.objects.get(codename=permission_name) + user_criteria = Q(user=user) + permission_criteria = Q(permissions=function_permission) + + return Group.objects.filter(user_criteria & permission_criteria) diff --git a/gateway/api/services/__init__.py b/gateway/api/services/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gateway/api/services/file_storage.py b/gateway/api/services/file_storage.py new file mode 100644 index 000000000..875a8d235 --- /dev/null +++ b/gateway/api/services/file_storage.py @@ -0,0 +1,239 @@ +""" +This file stores the logic to manage the access to data stores +""" +import glob +import logging +import mimetypes +import os +from enum import Enum +from typing import Optional, Tuple +from wsgiref.util import FileWrapper + +from django.conf import settings +from django.core.files import File + +from utils import sanitize_file_path + + +class WorkingDir(Enum): + """ + This Enum has the values: + USER_STORAGE + PROVIDER_STORAGE + + Both values are being used to identify in + FileStorage service the path to be used + """ + + USER_STORAGE = 1 + PROVIDER_STORAGE = 2 + + +logger = logging.getLogger("gateway") + + +class FileStorage: # pylint: disable=too-few-public-methods + """ + The main objective of this class is to manage the access of the users to their storage. + + Attributes: + username (str): storgae user's username + working_dir (WorkingDir(Enum)): working directory + function_title (str): title of the function in case is needed to build the path + provider_name (str | None): name of the provider in caseis needed to build the path + """ + + def __init__( + self, + username: str, + working_dir: WorkingDir, + function_title: str, + provider_name: str | None, + ) -> None: + self.sub_path = None + self.absolute_path = None + self.username = username + + if working_dir is WorkingDir.USER_STORAGE: + self.sub_path = self.__get_user_sub_path(function_title, provider_name) + elif working_dir is WorkingDir.PROVIDER_STORAGE: + self.sub_path = self.__get_provider_sub_path(function_title, provider_name) + + self.absolute_path = self.__get_absolute_path(self.sub_path) + + def __get_user_sub_path( + self, function_title: str, provider_name: str | None + ) -> str: + """ + This method returns the sub-path where the user or the function + will store files + + Args: + function_title (str): in case the function is from a + provider it will identify the function folder + provider_name (str | None): in case a provider is provided it will + identify the folder for the specific function + + Returns: + str: storage sub-path. + - In case the function is from a provider that sub-path would + be: username/provider_name/function_title + - In case the function is from a user that path would + be: username/ + """ + if provider_name is None: + path = os.path.join(self.username) + else: + path = os.path.join(self.username, provider_name, function_title) + + return sanitize_file_path(path) + + def __get_provider_sub_path(self, function_title: str, provider_name: str) -> str: + """ + This method returns the provider sub-path where the user + or the function will store files + + Args: + function_title (str): in case the function is from a provider + it will identify the function folder + provider_name (str): in case a provider is provided + it will identify the folder for the specific function + + Returns: + str: storage sub-path following the format provider_name/function_title/ + """ + path = os.path.join(provider_name, function_title) + + return sanitize_file_path(path) + + def __get_absolute_path(self, sub_path: str) -> str: + """ + This method returns the absolute path where the user + or the function will store files + + Args: + sub_path (str): the sub-path that we will use to build + the absolute path + + Returns: + str: storage path. + """ + path = os.path.join(settings.MEDIA_ROOT, sub_path) + sanitized_path = sanitize_file_path(path) + + # Create directory if it doesn't exist + if not os.path.exists(sanitized_path): + os.makedirs(sanitized_path, exist_ok=True) + logger.debug("Path %s was created.", sanitized_path) + + return sanitized_path + + def get_files(self) -> list[str]: + """ + This method returns a list of file names following the next rules: + - It returns only files from a user or a provider file storage + - Directories are excluded + + Returns: + list[str]: list of file names + """ + + return [ + os.path.basename(path) + for path in glob.glob(f"{self.absolute_path}/*") + if os.path.isfile(path) + ] + + def get_file(self, file_name: str) -> Optional[Tuple[FileWrapper, str, int]]: + """ + This method returns a file from file_name: + - Only files with supported extensions are available to download + - It returns only a file from a user or a provider file storage + + Args: + file_name (str): the name of the file to download + + Returns: + FileWrapper: the file itself + str: with the type of the file + int: with the size of the file + """ + + file_name_path = os.path.basename(file_name) + path_to_file = sanitize_file_path( + os.path.join(self.absolute_path, file_name_path) + ) + + if not os.path.exists(path_to_file): + logger.warning( + "File %s not found in %s.", + file_name_path, + path_to_file, + ) + return None + + # We can not use context manager here. Django close the file automatically: + # https://docs.djangoproject.com/en/5.1/ref/request-response/#fileresponse-objects + file_wrapper = FileWrapper( + open(path_to_file, "rb") # pylint: disable=consider-using-with + ) + + file_type = mimetypes.guess_type(path_to_file)[0] + file_size = os.path.getsize(path_to_file) + + return file_wrapper, file_type, file_size + + def upload_file(self, file: File) -> str: + """ + This method upload a file to the specific path: + - Only files with supported extensions are available to download + - It returns only a file from a user or a provider file storage + + Args: + file (django.File): the file to store in the specific path + + Returns: + str: the path where the file was stored + """ + + file_name = sanitize_file_path(file.name) + basename = os.path.basename(file_name) + path_to_file = sanitize_file_path(os.path.join(self.absolute_path, basename)) + + with open(path_to_file, "wb+") as destination: + for chunk in file.chunks(): + destination.write(chunk) + + return path_to_file + + def remove_file(self, file_name: str) -> bool: + """ + This method remove a file in the path of file_name + + Args: + file_name (str): the name of the file to remove + + Returns: + - True if it was deleted + - False otherwise + """ + + file_name_path = os.path.basename(file_name) + path_to_file = sanitize_file_path( + os.path.join(self.absolute_path, file_name_path) + ) + + try: + os.remove(path_to_file) + except FileNotFoundError: + logger.warning( + "File %s not found in %s.", + file_name_path, + path_to_file, + ) + return False + except OSError as ex: + logger.warning("OSError: %s.", ex.strerror) + return False + + return True diff --git a/gateway/api/utils.py b/gateway/api/utils.py index 994236c10..779568b15 100644 --- a/gateway/api/utils.py +++ b/gateway/api/utils.py @@ -419,15 +419,12 @@ def create_dependency_allowlist(): return allowlist -def sanitize_name(name: str): +def sanitize_name(name: str | None): """Sanitize name""" - if name: - sanitized_name = "" - for c in name: - if c.isalnum() or c in ["_", "-", "/"]: - sanitized_name += c - return sanitized_name - return name + if not name: + return name + # Remove all characters except alphanumeric, _, -, / + return re.sub("[^a-zA-Z0-9_\\-/]", "", name) def create_gpujob_allowlist(): @@ -448,3 +445,11 @@ def create_gpujob_allowlist(): raise ValueError("Unable to decode gpujob allowlist") from e return gpujobs + + +def sanitize_file_name(name: str | None): + """Sanitize the name of a file""" + if not name: + return name + # Remove all characters except alphanumeric, _, ., - + return re.sub("[^a-zA-Z0-9_\\.\\-]", "", name) diff --git a/gateway/api/v1/views/files.py b/gateway/api/v1/views/files.py index faa5a3150..456ad7d15 100644 --- a/gateway/api/v1/views/files.py +++ b/gateway/api/v1/views/files.py @@ -19,7 +19,7 @@ class FilesViewSet(views.FilesViewSet): permission_classes = [permissions.IsAuthenticated, IsOwner] @swagger_auto_schema( - operation_description="List of available for user files", + operation_description="List of available files in the user directory", manual_parameters=[ openapi.Parameter( "provider", @@ -28,25 +28,62 @@ class FilesViewSet(views.FilesViewSet): type=openapi.TYPE_STRING, required=False, ), + openapi.Parameter( + "function", + openapi.IN_QUERY, + description="function title", + type=openapi.TYPE_STRING, + required=True, + ), ], ) def list(self, request): return super().list(request) @swagger_auto_schema( - operation_description="Download a specific file", + operation_description="List of available files in the provider directory", + manual_parameters=[ + openapi.Parameter( + "provider", + openapi.IN_QUERY, + description="provider name", + type=openapi.TYPE_STRING, + required=True, + ), + openapi.Parameter( + "function", + openapi.IN_QUERY, + description="function title", + type=openapi.TYPE_STRING, + required=True, + ), + ], + ) + @action(methods=["GET"], detail=False, url_path="provider") + def provider_list(self, request): + return super().provider_list(request) + + @swagger_auto_schema( + operation_description="Download a specific file in the user directory", manual_parameters=[ openapi.Parameter( "file", openapi.IN_QUERY, - description="file name", + description="File name", + type=openapi.TYPE_STRING, + required=True, + ), + openapi.Parameter( + "function", + openapi.IN_QUERY, + description="Qiskit Function title", type=openapi.TYPE_STRING, required=True, ), openapi.Parameter( "provider", openapi.IN_QUERY, - description="provider name", + description="Provider name", type=openapi.TYPE_STRING, required=False, ), @@ -56,38 +93,156 @@ def list(self, request): def download(self, request): return super().download(request) + @swagger_auto_schema( + operation_description="Download a specific file in the provider directory", + manual_parameters=[ + openapi.Parameter( + "file", + openapi.IN_QUERY, + description="File name", + type=openapi.TYPE_STRING, + required=True, + ), + openapi.Parameter( + "function", + openapi.IN_QUERY, + description="Qiskit Function title", + type=openapi.TYPE_STRING, + required=True, + ), + openapi.Parameter( + "provider", + openapi.IN_QUERY, + description="Provider name", + type=openapi.TYPE_STRING, + required=True, + ), + ], + ) + @action(methods=["GET"], detail=False, url_path="provider/download") + def provider_download(self, request): + return super().provider_download(request) + @swagger_auto_schema( operation_description="Deletes file uploaded or produced by the programs", - request_body=openapi.Schema( - type=openapi.TYPE_OBJECT, - properties={ - "file": openapi.Schema( - type=openapi.TYPE_STRING, description="file name" - ), - "provider": openapi.Schema( - type=openapi.TYPE_STRING, description="provider name" - ), - }, - required=["file"], - ), + manual_parameters=[ + openapi.Parameter( + "file", + openapi.IN_QUERY, + description="File name", + type=openapi.TYPE_STRING, + required=True, + ), + openapi.Parameter( + "function", + openapi.IN_QUERY, + description="Qiskit Function title", + type=openapi.TYPE_STRING, + required=True, + ), + openapi.Parameter( + "provider", + openapi.IN_QUERY, + description="Provider name", + type=openapi.TYPE_STRING, + required=False, + ), + ], ) @action(methods=["DELETE"], detail=False) def delete(self, request): return super().delete(request) + @swagger_auto_schema( + operation_description="Deletes file uploaded or produced by the programs", + manual_parameters=[ + openapi.Parameter( + "file", + openapi.IN_QUERY, + description="File name", + type=openapi.TYPE_STRING, + required=True, + ), + openapi.Parameter( + "function", + openapi.IN_QUERY, + description="Qiskit Function title", + type=openapi.TYPE_STRING, + required=True, + ), + openapi.Parameter( + "provider", + openapi.IN_QUERY, + description="Provider name", + type=openapi.TYPE_STRING, + required=True, + ), + ], + ) + @action(methods=["DELETE"], detail=False, url_path="provider/delete") + def provider_delete(self, request): + return super().provider_delete(request) + @swagger_auto_schema( operation_description="Upload selected file", request_body=openapi.Schema( type=openapi.TYPE_OBJECT, properties={ - "file": openapi.Schema(type=openapi.TYPE_FILE, description="file name"), - "provider": openapi.Schema( - type=openapi.TYPE_STRING, description="provider name" - ), + "file": openapi.Schema( + type=openapi.TYPE_FILE, description="File to be uploaded" + ) }, required=["file"], ), + manual_parameters=[ + openapi.Parameter( + "function", + openapi.IN_QUERY, + description="Qiskit Function title", + type=openapi.TYPE_STRING, + required=True, + ), + openapi.Parameter( + "provider", + openapi.IN_QUERY, + description="Provider name", + type=openapi.TYPE_STRING, + required=False, + ), + ], ) @action(methods=["POST"], detail=False) def upload(self, request): return super().upload(request) + + @swagger_auto_schema( + operation_description="Upload a file into the provider directory", + request_body=openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + "file": openapi.Schema( + type=openapi.TYPE_FILE, description="File to be uploaded" + ) + }, + required=["file"], + ), + manual_parameters=[ + openapi.Parameter( + "function", + openapi.IN_QUERY, + description="Qiskit Function title", + type=openapi.TYPE_STRING, + required=True, + ), + openapi.Parameter( + "provider", + openapi.IN_QUERY, + description="Provider name", + type=openapi.TYPE_STRING, + required=True, + ), + ], + ) + @action(methods=["POST"], detail=False, url_path="provider/upload") + def provider_upload(self, request): + return super().provider_upload(request) diff --git a/gateway/api/views/files.py b/gateway/api/views/files.py index 10a54a4dc..8ad1431b3 100644 --- a/gateway/api/views/files.py +++ b/gateway/api/views/files.py @@ -3,13 +3,9 @@ Version views inherit from the different views. """ -import glob import logging -import mimetypes import os -from wsgiref.util import FileWrapper -from django.conf import settings from django.http import StreamingHttpResponse # pylint: disable=duplicate-code @@ -23,13 +19,17 @@ from rest_framework.decorators import action from rest_framework.response import Response -from utils import sanitize_file_path -from api.models import Provider +from api.access_policies.providers import ProviderAccessPolicy +from api.models import RUN_PROGRAM_PERMISSION +from api.repositories.functions import FunctionRepository +from api.repositories.providers import ProviderRepository +from api.services.file_storage import FileStorage, WorkingDir +from api.utils import sanitize_file_name, sanitize_name # pylint: disable=duplicate-code logger = logging.getLogger("gateway") resource = Resource(attributes={SERVICE_NAME: "QiskitServerless-Gateway"}) -provider = TracerProvider(resource=resource) +tracer_provider = TracerProvider(resource=resource) otel_exporter = BatchSpanProcessor( OTLPSpanExporter( endpoint=os.environ.get( @@ -38,179 +38,488 @@ insecure=bool(int(os.environ.get("OTEL_EXPORTER_OTLP_TRACES_INSECURE", "0"))), ) ) -provider.add_span_processor(otel_exporter) +tracer_provider.add_span_processor(otel_exporter) if bool(int(os.environ.get("OTEL_ENABLED", "0"))): - trace._set_tracer_provider(provider, log=False) # pylint: disable=protected-access + trace._set_tracer_provider( # pylint: disable=protected-access + tracer_provider, log=False + ) class FilesViewSet(viewsets.ViewSet): - """ViewSet for file operations handling. - - Note: only tar files are available for list and download + """ + ViewSet for file operations handling. """ BASE_NAME = "files" - def list_user_providers(self, user): - """list provider names that the user in""" - provider_list = [] - providers = Provider.objects.all() - for instance in providers: - user_groups = user.groups.all() - admin_groups = instance.admin_groups.all() - provider_found = any(group in admin_groups for group in user_groups) - if provider_found: - provider_list.append(instance.name) - return provider_list - - def check_user_has_provider(self, user, provider_name): - """check if user has the provider""" - return provider_name in self.list_user_providers(user) + function_repository = FunctionRepository() + provider_repository = ProviderRepository() def list(self, request): - """List of available for user files.""" - response = Response( - {"message": "Requested file was not found."}, - status=status.HTTP_404_NOT_FOUND, - ) - files = [] + """ + It returns a list with the names of available files for the user directory: + it will look under its username or username/provider_name/function_title + """ + tracer = trace.get_tracer("gateway.tracer") ctx = TraceContextTextMapPropagator().extract(carrier=request.headers) with tracer.start_as_current_span("gateway.files.list", context=ctx): - user_dir = request.user.username - provider_name = request.query_params.get("provider") - if provider_name is not None: - if self.check_user_has_provider(request.user, provider_name): - user_dir = provider_name + username = request.user.username + provider_name = sanitize_name(request.query_params.get("provider", None)) + function_title = sanitize_name(request.query_params.get("function", None)) + working_dir = WorkingDir.USER_STORAGE + + if function_title is None: + return Response( + {"message": "Qiskit Function title is mandatory"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + function = self.function_repository.get_function_by_permission( + user=request.user, + permission_name=RUN_PROGRAM_PERMISSION, + function_title=function_title, + provider_name=provider_name, + ) + if not function: + if provider_name: + error_message = f"Qiskit Function {provider_name}/{function_title} doesn't exist." # pylint: disable=line-too-long else: - return response - user_dir = os.path.join( - sanitize_file_path(settings.MEDIA_ROOT), - sanitize_file_path(user_dir), - ) - if os.path.exists(user_dir): - files = [ - os.path.basename(path) - for path in glob.glob(f"{user_dir}/*.tar") - + glob.glob(f"{user_dir}/*.h5") - ] - else: - logger.warning( - "Directory %s does not exist for %s.", user_dir, request.user + error_message = f"Qiskit Function {function_title} doesn't exist." + return Response( + {"message": error_message}, + status=status.HTTP_404_NOT_FOUND, + ) + + file_storage = FileStorage( + username=username, + working_dir=working_dir, + function_title=function_title, + provider_name=provider_name, + ) + files = file_storage.get_files() + + return Response({"results": files}) + + @action(methods=["GET"], detail=False, url_path="provider") + def provider_list(self, request): + """ + It returns a list with the names of available files for the provider working directory: + provider_name/function_title + """ + tracer = trace.get_tracer("gateway.tracer") + ctx = TraceContextTextMapPropagator().extract(carrier=request.headers) + with tracer.start_as_current_span("gateway.files.provider_list", context=ctx): + username = request.user.username + provider_name = sanitize_name(request.query_params.get("provider")) + function_title = sanitize_name(request.query_params.get("function")) + working_dir = WorkingDir.PROVIDER_STORAGE + + if function_title is None or provider_name is None: + return Response( + { + "message": "File name, Qiskit Function title and Provider name are mandatory" # pylint: disable=line-too-long + }, + status=status.HTTP_400_BAD_REQUEST, + ) + + provider = self.provider_repository.get_provider_by_name(name=provider_name) + if provider is None: + return Response( + {"message": f"Provider {provider_name} doesn't exist."}, + status=status.HTTP_404_NOT_FOUND, + ) + if not ProviderAccessPolicy.can_access( + user=request.user, provider=provider + ): + return Response( + {"message": f"Provider {provider_name} doesn't exist."}, + status=status.HTTP_404_NOT_FOUND, + ) + + function = self.function_repository.get_function_by_permission( + user=request.user, + permission_name=RUN_PROGRAM_PERMISSION, + function_title=function_title, + provider_name=provider_name, + ) + if not function: + return Response( + { + "message": f"Qiskit Function {provider_name}/{function_title} doesn't exist." # pylint: disable=line-too-long + }, + status=status.HTTP_404_NOT_FOUND, ) + file_storage = FileStorage( + username=username, + working_dir=working_dir, + function_title=function_title, + provider_name=provider_name, + ) + files = file_storage.get_files() + return Response({"results": files}) @action(methods=["GET"], detail=False) - def download(self, request): # pylint: disable=invalid-name - """Download selected file.""" - # default response for file not found, overwritten if file is found - response = Response( - {"message": "Requested file was not found."}, - status=status.HTTP_404_NOT_FOUND, - ) + def download(self, request): + """ + It returns a file from user paths: + - username/ + - username/provider_name/function_title + """ tracer = trace.get_tracer("gateway.tracer") ctx = TraceContextTextMapPropagator().extract(carrier=request.headers) with tracer.start_as_current_span("gateway.files.download", context=ctx): - requested_file_name = request.query_params.get("file") - provider_name = request.query_params.get("provider") - if requested_file_name is not None: - user_dir = request.user.username - if provider_name is not None: - if self.check_user_has_provider(request.user, provider_name): - user_dir = provider_name - else: - return response - # look for file in user's folder - filename = os.path.basename(requested_file_name) - user_dir = os.path.join( - sanitize_file_path(settings.MEDIA_ROOT), - sanitize_file_path(user_dir), - ) - file_path = os.path.join( - sanitize_file_path(user_dir), sanitize_file_path(filename) - ) - if os.path.exists(user_dir) and os.path.exists(file_path) and filename: - chunk_size = 8192 - # note: we do not use with statements as Streaming response closing file itself. - response = StreamingHttpResponse( - FileWrapper( - open( # pylint: disable=consider-using-with - file_path, "rb" - ), - chunk_size, - ), - content_type=mimetypes.guess_type(file_path)[0], - ) - response["Content-Length"] = os.path.getsize(file_path) - response["Content-Disposition"] = f"attachment; filename={filename}" + username = request.user.username + requested_file_name = sanitize_file_name( + request.query_params.get("file", None) + ) + provider_name = sanitize_name(request.query_params.get("provider", None)) + function_title = sanitize_name(request.query_params.get("function", None)) + working_dir = WorkingDir.USER_STORAGE + + if not all([requested_file_name, function_title]): + return Response( + {"message": "File name and Qiskit Function title are mandatory"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + function = self.function_repository.get_function_by_permission( + user=request.user, + permission_name=RUN_PROGRAM_PERMISSION, + function_title=function_title, + provider_name=provider_name, + ) + if not function: + if provider_name: + error_message = f"Qiskit Function {provider_name}/{function_title} doesn't exist." # pylint: disable=line-too-long + else: + error_message = f"Qiskit Function {function_title} doesn't exist." + return Response( + {"message": error_message}, + status=status.HTTP_404_NOT_FOUND, + ) + + file_storage = FileStorage( + username=username, + working_dir=working_dir, + function_title=function_title, + provider_name=provider_name, + ) + result = file_storage.get_file(file_name=requested_file_name) + if result is None: + return Response( + {"message": "Requested file was not found."}, + status=status.HTTP_404_NOT_FOUND, + ) + + file_wrapper, file_type, file_size = result + response = StreamingHttpResponse(file_wrapper, content_type=file_type) + response["Content-Length"] = file_size + response[ + "Content-Disposition" + ] = f"attachment; filename={requested_file_name}" + return response + + @action(methods=["GET"], detail=False, url_path="provider/download") + def provider_download(self, request): + """ + It returns a file from provider path: + - provider_name/function_title + """ + tracer = trace.get_tracer("gateway.tracer") + ctx = TraceContextTextMapPropagator().extract(carrier=request.headers) + with tracer.start_as_current_span( + "gateway.files.provider_download", context=ctx + ): + username = request.user.username + requested_file_name = sanitize_file_name( + request.query_params.get("file", None) + ) + provider_name = sanitize_name(request.query_params.get("provider", None)) + function_title = sanitize_name(request.query_params.get("function", None)) + working_dir = WorkingDir.PROVIDER_STORAGE + + if not all([requested_file_name, function_title, provider_name]): + return Response( + { + "message": "File name, Qiskit Function title and Provider name are mandatory" # pylint: disable=line-too-long + }, + status=status.HTTP_400_BAD_REQUEST, + ) + + provider = self.provider_repository.get_provider_by_name(name=provider_name) + if provider is None: + return Response( + {"message": f"Provider {provider_name} doesn't exist."}, + status=status.HTTP_404_NOT_FOUND, + ) + if not ProviderAccessPolicy.can_access( + user=request.user, provider=provider + ): + return Response( + {"message": f"Provider {provider_name} doesn't exist."}, + status=status.HTTP_404_NOT_FOUND, + ) + + function = self.function_repository.get_function_by_permission( + user=request.user, + permission_name=RUN_PROGRAM_PERMISSION, + function_title=function_title, + provider_name=provider_name, + ) + if not function: + return Response( + { + "message": f"Qiskit Function {provider_name}/{function_title} doesn't exist." # pylint: disable=line-too-long + }, + status=status.HTTP_404_NOT_FOUND, + ) + + file_storage = FileStorage( + username=username, + working_dir=working_dir, + function_title=function_title, + provider_name=provider_name, + ) + result = file_storage.get_file(file_name=requested_file_name) + if result is None: + return Response( + {"message": "Requested file was not found."}, + status=status.HTTP_404_NOT_FOUND, + ) + + file_wrapper, file_type, file_size = result + response = StreamingHttpResponse(file_wrapper, content_type=file_type) + response["Content-Length"] = file_size + response[ + "Content-Disposition" + ] = f"attachment; filename={requested_file_name}" return response @action(methods=["DELETE"], detail=False) - def delete(self, request): # pylint: disable=invalid-name - """Deletes file uploaded or produced by the programs,""" - # default response for file not found, overwritten if file is found - response = Response( - {"message": "Requested file was not found."}, - status=status.HTTP_404_NOT_FOUND, - ) + def delete(self, request): + """Deletes file uploaded or produced by the functions""" tracer = trace.get_tracer("gateway.tracer") ctx = TraceContextTextMapPropagator().extract(carrier=request.headers) with tracer.start_as_current_span("gateway.files.delete", context=ctx): - if request.data and "file" in request.data: - # look for file in user's folder - filename = os.path.basename(request.data["file"]) - provider_name = request.data.get("provider") - user_dir = request.user.username - if provider_name is not None: - if self.check_user_has_provider(request.user, provider_name): - user_dir = provider_name - else: - return response - user_dir = os.path.join( - sanitize_file_path(settings.MEDIA_ROOT), - sanitize_file_path(user_dir), - ) - file_path = os.path.join( - sanitize_file_path(user_dir), sanitize_file_path(filename) - ) - if os.path.exists(user_dir) and os.path.exists(file_path) and filename: - os.remove(file_path) - response = Response( - {"message": "Requested file was deleted."}, - status=status.HTTP_200_OK, - ) - return response + username = request.user.username + file_name = sanitize_file_name(request.query_params.get("file", None)) + provider_name = sanitize_name(request.query_params.get("provider")) + function_title = sanitize_name(request.query_params.get("function", None)) + working_dir = WorkingDir.USER_STORAGE + + if not all([file_name, function_title]): + return Response( + {"message": "File name and Qiskit Function title are mandatory"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + function = self.function_repository.get_function_by_permission( + user=request.user, + permission_name=RUN_PROGRAM_PERMISSION, + function_title=function_title, + provider_name=provider_name, + ) + + if not function: + if provider_name: + error_message = f"Qiskit Function {provider_name}/{function_title} doesn't exist." # pylint: disable=line-too-long + else: + error_message = f"Qiskit Function {function_title} doesn't exist." + return Response( + {"message": error_message}, + status=status.HTTP_404_NOT_FOUND, + ) + + file_storage = FileStorage( + username=username, + working_dir=working_dir, + function_title=function_title, + provider_name=provider_name, + ) + result = file_storage.remove_file(file_name=file_name) + if not result: + return Response( + {"message": "Requested file was not found."}, + status=status.HTTP_404_NOT_FOUND, + ) + + return Response( + {"message": "Requested file was deleted."}, status=status.HTTP_200_OK + ) + + @action(methods=["DELETE"], detail=False, url_path="provider/delete") + def provider_delete(self, request): + """Deletes file uploaded or produced by the functions""" + tracer = trace.get_tracer("gateway.tracer") + ctx = TraceContextTextMapPropagator().extract(carrier=request.headers) + with tracer.start_as_current_span("gateway.files.delete", context=ctx): + username = request.user.username + file_name = sanitize_file_name(request.query_params.get("file")) + provider_name = sanitize_name(request.query_params.get("provider")) + function_title = sanitize_name(request.query_params.get("function", None)) + working_dir = WorkingDir.PROVIDER_STORAGE + + if not all([file_name, function_title, provider_name]): + return Response( + {"message": "File name and Qiskit Function title are mandatory"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + provider = self.provider_repository.get_provider_by_name(name=provider_name) + if provider is None: + return Response( + {"message": f"Provider {provider_name} doesn't exist."}, + status=status.HTTP_404_NOT_FOUND, + ) + if not ProviderAccessPolicy.can_access( + user=request.user, provider=provider + ): + return Response( + {"message": f"Provider {provider_name} doesn't exist."}, + status=status.HTTP_404_NOT_FOUND, + ) + + function = self.function_repository.get_function_by_permission( + user=request.user, + permission_name=RUN_PROGRAM_PERMISSION, + function_title=function_title, + provider_name=provider_name, + ) + + if not function: + error_message = f"Qiskit Function {provider_name}/{function_title} doesn't exist." # pylint: disable=line-too-long + return Response( + {"message": error_message}, + status=status.HTTP_404_NOT_FOUND, + ) + + file_storage = FileStorage( + username=username, + working_dir=working_dir, + function_title=function_title, + provider_name=provider_name, + ) + result = file_storage.remove_file(file_name=file_name) + if not result: + return Response( + {"message": "Requested file was not found."}, + status=status.HTTP_404_NOT_FOUND, + ) + + return Response( + {"message": "Requested file was deleted."}, status=status.HTTP_200_OK + ) @action(methods=["POST"], detail=False) - def upload(self, request): # pylint: disable=invalid-name - """Upload selected file.""" - response = Response( - {"message": "Requested file was not found."}, - status=status.HTTP_404_NOT_FOUND, - ) + def upload(self, request): + """ + It upload a file to a specific user paths: + - username/ + - username/provider_name/function_title + """ tracer = trace.get_tracer("gateway.tracer") ctx = TraceContextTextMapPropagator().extract(carrier=request.headers) with tracer.start_as_current_span("gateway.files.download", context=ctx): + username = request.user.username upload_file = request.FILES["file"] - filename = os.path.basename(upload_file.name) - user_dir = request.user.username - if request.data and "provider" in request.data: - provider_name = request.data["provider"] - if provider_name is not None: - if self.check_user_has_provider(request.user, provider_name): - user_dir = provider_name - else: - return response - user_dir = os.path.join( - sanitize_file_path(settings.MEDIA_ROOT), - sanitize_file_path(user_dir), - ) - file_path = os.path.join( - sanitize_file_path(user_dir), sanitize_file_path(filename) - ) - with open(file_path, "wb+") as destination: - for chunk in upload_file.chunks(): - destination.write(chunk) - return Response({"message": file_path}) - return Response("server error", status=status.HTTP_500_INTERNAL_SERVER_ERROR) + file_name = sanitize_file_name(upload_file.name) + provider_name = sanitize_name(request.query_params.get("provider", None)) + function_title = sanitize_name(request.query_params.get("function", None)) + working_dir = WorkingDir.USER_STORAGE + + if not all([file_name, function_title]): + return Response( + {"message": "A file and Qiskit Function title are mandatory"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + function = self.function_repository.get_function_by_permission( + user=request.user, + permission_name=RUN_PROGRAM_PERMISSION, + function_title=function_title, + provider_name=provider_name, + ) + if not function: + if provider_name: + error_message = f"Qiskit Function {provider_name}/{function_title} doesn't exist." # pylint: disable=line-too-long + else: + error_message = f"Qiskit Function {function_title} doesn't exist." + return Response( + {"message": error_message}, + status=status.HTTP_404_NOT_FOUND, + ) + + file_storage = FileStorage( + username=username, + working_dir=working_dir, + function_title=function_title, + provider_name=provider_name, + ) + result = file_storage.upload_file(file=upload_file) + + return Response({"message": result}) + + @action(methods=["POST"], detail=False, url_path="provider/upload") + def provider_upload(self, request): + """ + It upload a file to a specific user paths: + - provider_name/function_title + """ + tracer = trace.get_tracer("gateway.tracer") + ctx = TraceContextTextMapPropagator().extract(carrier=request.headers) + with tracer.start_as_current_span("gateway.files.download", context=ctx): + username = request.user.username + upload_file = request.FILES["file"] + file_name = sanitize_file_name(upload_file.name) + provider_name = sanitize_name(request.query_params.get("provider", None)) + function_title = sanitize_name(request.query_params.get("function", None)) + working_dir = WorkingDir.PROVIDER_STORAGE + + if not all([file_name, function_title, provider_name]): + return Response( + { + "message": "The file, Qiskit Function title and Provider name are mandatory" # pylint: disable=line-too-long + }, + status=status.HTTP_400_BAD_REQUEST, + ) + + provider = self.provider_repository.get_provider_by_name(name=provider_name) + if provider is None: + return Response( + {"message": f"Provider {provider_name} doesn't exist."}, + status=status.HTTP_404_NOT_FOUND, + ) + if not ProviderAccessPolicy.can_access( + user=request.user, provider=provider + ): + return Response( + {"message": f"Provider {provider_name} doesn't exist."}, + status=status.HTTP_404_NOT_FOUND, + ) + + function = self.function_repository.get_function_by_permission( + user=request.user, + permission_name=RUN_PROGRAM_PERMISSION, + function_title=function_title, + provider_name=provider_name, + ) + if not function: + return Response( + { + "message": f"Qiskit Function {provider_name}/{function_title} doesn't exist." # pylint: disable=line-too-long + }, + status=status.HTTP_404_NOT_FOUND, + ) + + file_storage = FileStorage( + username=username, + working_dir=working_dir, + function_title=function_title, + provider_name=provider_name, + ) + result = file_storage.upload_file(file=upload_file) + + return Response({"message": result}) diff --git a/gateway/api/views/programs.py b/gateway/api/views/programs.py index dd6eba04f..e3682ff29 100644 --- a/gateway/api/views/programs.py +++ b/gateway/api/views/programs.py @@ -20,7 +20,7 @@ from rest_framework import viewsets, status from rest_framework.response import Response -from api.repositories.programs import ProgramRepository +from api.repositories.functions import FunctionRepository from api.utils import sanitize_name from api.serializers import ( JobConfigSerializer, @@ -29,7 +29,7 @@ RunProgramSerializer, UploadProgramSerializer, ) -from api.models import RUN_PROGRAM_PERMISSION, Program, Job +from api.models import RUN_PROGRAM_PERMISSION, VIEW_PROGRAM_PERMISSION, Program, Job from api.views.enums.type_filter import TypeFilter # pylint: disable=duplicate-code @@ -56,7 +56,7 @@ class ProgramViewSet(viewsets.GenericViewSet): BASE_NAME = "programs" - program_repository = ProgramRepository() + program_repository = FunctionRepository() @staticmethod def get_serializer_job_config(*args, **kwargs): @@ -161,13 +161,15 @@ def list(self, request): # Catalog filter only returns providers functions that user has access: # author has view permissions and the function has a provider assigned functions = ( - self.program_repository.get_provider_functions_with_run_permissions( - author + self.program_repository.get_provider_functions_by_permission( + author, permission_name=RUN_PROGRAM_PERMISSION ) ) else: # If filter is not applied we return author and providers functions together - functions = self.program_repository.get_functions(author) + functions = self.program_repository.get_functions_by_permission( + author, permission_name=VIEW_PROGRAM_PERMISSION + ) serializer = self.get_serializer(functions, many=True) @@ -307,11 +309,14 @@ def get_by_title(self, request, title): ) if provider_name: - function = self.program_repository.get_provider_function_by_title( - author=author, title=function_title, provider_name=provider_name + function = self.program_repository.get_provider_function_by_permission( + author=author, + permission_name=VIEW_PROGRAM_PERMISSION, + title=function_title, + provider_name=provider_name, ) else: - function = self.program_repository.get_user_function_by_title( + function = self.program_repository.get_user_function( author=author, title=function_title ) diff --git a/gateway/tests/api/test_files.py b/gateway/tests/api/test_files.py deleted file mode 100644 index 9510da604..000000000 --- a/gateway/tests/api/test_files.py +++ /dev/null @@ -1,257 +0,0 @@ -"""Tests files api.""" - -import os -from urllib.parse import quote_plus - -from django.urls import reverse -from rest_framework import status -from rest_framework.test import APITestCase -from django.contrib.auth import models - - -class TestFilesApi(APITestCase): - """TestProgramApi.""" - - fixtures = ["tests/fixtures/fixtures.json"] - - def test_files_list_non_authorized(self): - """Tests files list non-authorized.""" - url = reverse("v1:files-list") - response = self.client.get(url, format="json") - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - - def test_files_list(self): - """Tests files list.""" - - media_root = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "..", - "resources", - "fake_media", - ) - media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) - - with self.settings(MEDIA_ROOT=media_root): - user = models.User.objects.get(username="test_user") - self.client.force_authenticate(user=user) - url = reverse("v1:files-list") - response = self.client.get(url, format="json") - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, {"results": ["artifact.tar"]}) - - def test_provider_files_list(self): - """Tests files list.""" - - media_root = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "..", - "resources", - "fake_media", - ) - media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) - - with self.settings(MEDIA_ROOT=media_root): - user = models.User.objects.get(username="test_user_2") - self.client.force_authenticate(user=user) - url = reverse("v1:files-list") - response = self.client.get(url, data={"provider": "default"}, format="json") - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, {"results": ["provider_artifact.tar"]}) - - def test_non_existing_file_download(self): - """Tests downloading non-existing file.""" - user = models.User.objects.get(username="test_user") - self.client.force_authenticate(user=user) - url = reverse("v1:files-download") - response = self.client.get( - url, data={"file": "non_existing.tar"}, format="json" - ) - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - response = self.client.get(url, format="json") - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - - def test_file_download(self): - """Tests downloading non-existing file.""" - media_root = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "..", - "resources", - "fake_media", - ) - media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) - - with self.settings(MEDIA_ROOT=media_root): - user = models.User.objects.get(username="test_user") - self.client.force_authenticate(user=user) - url = reverse("v1:files-download") - response = self.client.get( - url, data={"file": "artifact.tar"}, format="json" - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertTrue(response.streaming) - - def test_provider_file_download(self): - """Tests downloading non-existing file.""" - media_root = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "..", - "resources", - "fake_media", - ) - media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) - - with self.settings(MEDIA_ROOT=media_root): - user = models.User.objects.get(username="test_user_2") - self.client.force_authenticate(user=user) - url = reverse("v1:files-download") - response = self.client.get( - url, - data={"file": "provider_artifact.tar", "provider": "default"}, - format="json", - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertTrue(response.streaming) - - def test_file_delete(self): - """Tests delete file.""" - media_root = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "..", - "resources", - "fake_media", - ) - media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) - - with open( - os.path.join(media_root, "test_user", "artifact_delete.tar"), "w" - ) as fp: - fp.write("This is first line") - print(fp) - fp.close() - - with self.settings(MEDIA_ROOT=media_root): - user = models.User.objects.get(username="test_user") - self.client.force_authenticate(user=user) - url = reverse("v1:files-delete") - response = self.client.delete( - url, data={"file": "artifact_delete.tar"}, format="json" - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - - def test_provider_file_delete(self): - """Tests delete file.""" - media_root = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "..", - "resources", - "fake_media", - ) - media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) - - with open( - os.path.join(media_root, "default", "artifact_delete.tar"), "w" - ) as fp: - fp.write("This is first line") - print(fp) - fp.close() - - with self.settings(MEDIA_ROOT=media_root): - user = models.User.objects.get(username="test_user_2") - self.client.force_authenticate(user=user) - url = reverse("v1:files-delete") - response = self.client.delete( - url, - data={"file": "artifact_delete.tar", "provider": "default"}, - format="json", - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - - def test_non_existing_file_delete(self): - """Tests delete file.""" - media_root = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "..", - "resources", - "fake_media", - ) - media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) - - with self.settings(MEDIA_ROOT=media_root): - user = models.User.objects.get(username="test_user") - self.client.force_authenticate(user=user) - url = reverse("v1:files-delete") - response = self.client.delete( - url, data={"file": "artifact_delete.tar"}, format="json" - ) - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - - def test_file_upload(self): - """Tests uploading existing file.""" - media_root = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "..", - "resources", - "fake_media", - ) - media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) - - with self.settings(MEDIA_ROOT=media_root): - user = models.User.objects.get(username="test_user") - self.client.force_authenticate(user=user) - url = reverse("v1:files-upload") - with open("README.md") as f: - response = self.client.post( - url, - data={"file": f}, - format="multipart", - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertTrue(os.path.join(media_root, "test_user", "README.md")) - os.remove(os.path.join(media_root, "test_user", "README.md")) - - def test_provider_file_upload(self): - """Tests uploading existing file.""" - media_root = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "..", - "resources", - "fake_media", - ) - media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) - - with self.settings(MEDIA_ROOT=media_root): - user = models.User.objects.get(username="test_user_2") - self.client.force_authenticate(user=user) - url = reverse("v1:files-upload") - with open("README.md") as f: - response = self.client.post( - url, - data={"file": f, "provider": "default"}, - format="multipart", - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertTrue(os.path.join(media_root, "test_user", "README.md")) - os.remove(os.path.join(media_root, "default", "README.md")) - - def test_escape_directory(self): - """Tests directory escape / injection.""" - with self.settings( - MEDIA_ROOT=os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "..", - "resources", - "fake_media", - ) - ): - user = models.User.objects.get(username="test_user") - self.client.force_authenticate(user=user) - url = reverse("v1:files-download") - response = self.client.get( - url, data={"file": "../test_user_2/artifact_2.tar"}, format="json" - ) - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - - response = self.client.get( - url, data={"file": "../test_user_2/artifact_2.tar/"}, format="json" - ) - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) diff --git a/gateway/tests/api/test_v1_files.py b/gateway/tests/api/test_v1_files.py new file mode 100644 index 000000000..de1ee58b1 --- /dev/null +++ b/gateway/tests/api/test_v1_files.py @@ -0,0 +1,505 @@ +"""Tests files api.""" + +import os +from urllib.parse import urlencode + +from django.urls import reverse +from pytest import mark +from rest_framework import status +from rest_framework.test import APITestCase +from django.contrib.auth import models + + +class TestFilesApi(APITestCase): + """TestProgramApi.""" + + fixtures = ["tests/fixtures/files_fixtures.json"] + + def test_files_list_non_authorized(self): + """Tests files list non-authorized.""" + url = reverse("v1:files-list") + response = self.client.get(url, format="json") + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + + def test_files_list_with_empty_params(self): + """Tests files list using empty params""" + + media_root = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "resources", + "fake_media", + ) + media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) + + with self.settings(MEDIA_ROOT=media_root): + user = models.User.objects.get(username="test_user") + self.client.force_authenticate(user=user) + url = reverse("v1:files-list") + response = self.client.get(url, format="json") + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_files_list_from_user_working_dir(self): + """Tests files list with working dir as user""" + + function = "personal-program" + + media_root = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "resources", + "fake_media", + ) + media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) + + with self.settings(MEDIA_ROOT=media_root): + user = models.User.objects.get(username="test_user_2") + self.client.force_authenticate(user=user) + url = reverse("v1:files-list") + response = self.client.get( + url, + { + "function": function, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {"results": ["artifact_2.tar"]}) + + def test_files_list_from_user_without_access_to_function(self): + """Tests files list with working dir as user where the user has no access to the function""" + + provider = "default" + function = "Program" + + user = models.User.objects.get(username="test_user") + self.client.force_authenticate(user=user) + url = reverse("v1:files-list") + response = self.client.get( + url, + { + "provider": provider, + "function": function, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_files_list_from_user_with_access_to_function(self): + """Tests files list with working dir as user where the user has access to the function""" + + provider = "default" + function = "Program" + + media_root = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "resources", + "fake_media", + ) + media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) + + with self.settings(MEDIA_ROOT=media_root): + user = models.User.objects.get(username="test_user_2") + self.client.force_authenticate(user=user) + url = reverse("v1:files-list") + response = self.client.get( + url, + { + "provider": provider, + "function": function, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {"results": ["user_program_artifact.tar"]}) + + def test_files_list_from_a_provider_that_not_exist(self): + """Tests files list with a provider that it doesn't exist""" + + provider = "noexist" + function = "Program" + + user = models.User.objects.get(username="test_user_2") + self.client.force_authenticate(user=user) + url = reverse("v1:files-list") + response = self.client.get( + url, + { + "provider": provider, + "function": function, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_files_provider_list_using_provider_working_dir(self): + """Tests files provider list with working dir as provider""" + + provider = "default" + function = "Program" + + media_root = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "resources", + "fake_media", + ) + media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) + + with self.settings(MEDIA_ROOT=media_root): + user = models.User.objects.get(username="test_user_2") + self.client.force_authenticate(user=user) + url = reverse("v1:files-provider-list") + response = self.client.get( + url, + { + "provider": provider, + "function": function, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual( + response.data, {"results": ["provider_program_artifact.tar"]} + ) + + def test_files_provider_list_with_a_user_that_has_no_access_to_provider(self): + """Tests files provider list with working dir as provider""" + + provider = "default" + function = "Program" + + user = models.User.objects.get(username="test_user") + self.client.force_authenticate(user=user) + url = reverse("v1:files-provider-list") + response = self.client.get( + url, + { + "provider": provider, + "function": function, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_non_existing_file_download(self): + """Tests downloading non-existing file.""" + + media_root = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "resources", + "fake_media", + ) + media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) + + with self.settings(MEDIA_ROOT=media_root): + file = "non_existing_file.tar" + function = "personal-program" + + user = models.User.objects.get(username="test_user_2") + self.client.force_authenticate(user=user) + url = reverse("v1:files-download") + response = self.client.get( + url, + { + "file": file, + "function": function, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_file_download(self): + """Tests downloading an existing file.""" + media_root = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "resources", + "fake_media", + ) + media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) + + with self.settings(MEDIA_ROOT=media_root): + file = "artifact_2.tar" + function = "personal-program" + + user = models.User.objects.get(username="test_user_2") + self.client.force_authenticate(user=user) + url = reverse("v1:files-download") + response = self.client.get( + url, + { + "file": file, + "function": function, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertTrue(response.streaming) + + def test_non_existing_provider_file_download(self): + """Tests downloading a non-existing file from a provider storage.""" + media_root = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "resources", + "fake_media", + ) + media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) + + with self.settings(MEDIA_ROOT=media_root): + file = "non-existing_artifact.tar" + provider = "default" + function = "Program" + + user = models.User.objects.get(username="test_user_2") + self.client.force_authenticate(user=user) + url = reverse("v1:files-provider-download") + response = self.client.get( + url, + { + "file": file, + "provider": provider, + "function": function, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_provider_file_download(self): + """Tests downloading a file from a provider storage.""" + media_root = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "resources", + "fake_media", + ) + media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) + + with self.settings(MEDIA_ROOT=media_root): + file = "provider_program_artifact.tar" + provider = "default" + function = "Program" + + user = models.User.objects.get(username="test_user_2") + self.client.force_authenticate(user=user) + url = reverse("v1:files-provider-download") + response = self.client.get( + url, + { + "file": file, + "provider": provider, + "function": function, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertTrue(response.streaming) + + def test_file_delete(self): + """Tests delete file.""" + media_root = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "resources", + "fake_media", + ) + media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) + function = "personal-program" + file = "artifact_delete.tar" + username = "test_user_2" + functionPath = os.path.join(media_root, username) + + if not os.path.exists(functionPath): + os.makedirs(functionPath) + + with open( + os.path.join(functionPath, file), + "w+", + ) as fp: + fp.write("This is first line") + print(fp) + fp.close() + + with self.settings(MEDIA_ROOT=media_root): + query_params = {"function": function, "file": file} + user = models.User.objects.get(username=username) + self.client.force_authenticate(user=user) + url = reverse("v1:files-delete") + response = self.client.delete(f"{url}?{urlencode(query_params)}") + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_provider_file_delete(self): + """Tests delete file.""" + media_root = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "resources", + "fake_media", + ) + media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) + provider = "default" + function = "Program" + file = "artifact_delete.tar" + username = "test_user_2" + functionPath = os.path.join(media_root, provider, function) + + if not os.path.exists(functionPath): + os.makedirs(functionPath) + + with open( + os.path.join(functionPath, file), + "w+", + ) as fp: + fp.write("This is first line") + print(fp) + fp.close() + + with self.settings(MEDIA_ROOT=media_root): + query_params = {"function": function, "provider": provider, "file": file} + user = models.User.objects.get(username=username) + self.client.force_authenticate(user=user) + url = reverse("v1:files-provider-delete") + response = self.client.delete(f"{url}?{urlencode(query_params)}") + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_non_existing_file_delete(self): + """Tests delete file.""" + media_root = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "resources", + "fake_media", + ) + media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) + function = "personal-program" + file = "non-existing-artifact_delete.tar" + username = "test_user_2" + + with self.settings(MEDIA_ROOT=media_root): + query_params = {"function": function, "file": file} + user = models.User.objects.get(username=username) + self.client.force_authenticate(user=user) + url = reverse("v1:files-delete") + response = self.client.delete(f"{url}?{urlencode(query_params)}") + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_non_existing_provider_file_delete(self): + """Tests delete file.""" + media_root = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "resources", + "fake_media", + ) + media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) + provider = "default" + function = "Program" + file = "non-existing-artifact_delete.tar" + username = "test_user_2" + + with self.settings(MEDIA_ROOT=media_root): + query_params = {"function": function, "provider": provider, "file": file} + user = models.User.objects.get(username=username) + self.client.force_authenticate(user=user) + url = reverse("v1:files-provider-delete") + response = self.client.delete(f"{url}?{urlencode(query_params)}") + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_file_upload(self): + """Tests uploading existing file.""" + media_root = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "resources", + "fake_media", + ) + media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) + + with self.settings(MEDIA_ROOT=media_root): + function = "personal-program" + user = models.User.objects.get(username="test_user_2") + self.client.force_authenticate(user=user) + url = reverse("v1:files-upload") + + with open("README.md") as f: + query_params = {"function": function} + response = self.client.post( + f"{url}?{urlencode(query_params)}", + {"file": f}, + format="multipart", + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertTrue(os.path.join(media_root, "test_user_2", "README.md")) + os.remove(os.path.join(media_root, "test_user_2", "README.md")) + + def test_provider_file_upload(self): + """Tests uploading existing file.""" + media_root = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "resources", + "fake_media", + ) + media_root = os.path.normpath(os.path.join(os.getcwd(), media_root)) + + with self.settings(MEDIA_ROOT=media_root): + provider = "default" + function = "Program" + user = models.User.objects.get(username="test_user_2") + self.client.force_authenticate(user=user) + url = reverse("v1:files-provider-upload") + + with open("README.md") as f: + query_params = {"function": function, "provider": provider} + response = self.client.post( + f"{url}?{urlencode(query_params)}", + {"file": f}, + format="multipart", + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertTrue( + os.path.join(media_root, "default", "Program", "README.md") + ) + os.remove(os.path.join(media_root, "default", "Program", "README.md")) + + def test_escape_directory(self): + """Tests directory escape / injection.""" + with self.settings( + MEDIA_ROOT=os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "resources", + "fake_media", + ) + ): + file = "../test_user/artifact.tar" + function = "personal-program" + + user = models.User.objects.get(username="test_user_2") + self.client.force_authenticate(user=user) + url = reverse("v1:files-download") + response = self.client.get( + url, + { + "file": file, + "function": function, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + file = "../test_user/artifact.tar/" + response = self.client.get( + url, + { + "file": file, + "function": function, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) diff --git a/gateway/tests/fixtures/files_fixtures.json b/gateway/tests/fixtures/files_fixtures.json new file mode 100644 index 000000000..1e65eab64 --- /dev/null +++ b/gateway/tests/fixtures/files_fixtures.json @@ -0,0 +1,93 @@ +[ + { + "model": "auth.user", + "pk": 1, + "fields": { + "email": "test_user@email.com", + "username": "test_user", + "password": "pbkdf2_sha256$390000$kcex1rxhZg6VVJYkx71cBX$e4ns0xDykbO6Dz6j4nZ4uNusqkB9GVpojyegPv5/9KM=", + "is_active": true, + "groups": [ + 100 + ] + } + }, + { + "model": "auth.user", + "pk": 2, + "fields": { + "email": "test_user_2@email.com", + "username": "test_user_2", + "password": "pbkdf2_sha256$390000$kcex1rxhZg6VVJYkx71cBX$e4ns0xDykbO6Dz6j4nZ4uNusqkB9GVpojyegPv5/9KM=", + "is_active": true, + "groups": [ + 101, + 105 + ] + } + }, + { + "model": "auth.group", + "pk": 100, + "fields": { + "name": "runner", + "permissions": [ + 65 + ] + } + }, + { + "model": "auth.group", + "pk": 101, + "fields": { + "name": "viewer", + "permissions": [ + 64 + ] + } + }, + { + "model": "auth.group", + "pk": 105, + "fields": { + "name": "default-group" + } + }, + { + "model": "api.provider", + "pk": "bfe8aa6a-2127-4123-bf57-5b547293cbea", + "fields": { + "name": "default", + "created": "2023-02-01T15:30:43.281796Z", + "admin_groups": [ + 105 + ], + "registry": "docker.io/awesome" + } + }, + { + "model": "api.program", + "pk": "6160a2ff-e482-443d-af23-15110b646ae2", + "fields": { + "created": "2023-02-01T15:30:43.281796Z", + "title": "Program", + "image": "icr.io/awesome-namespace/awesome-title", + "author": 2, + "provider": "bfe8aa6a-2127-4123-bf57-5b547293cbea", + "instances": [ + 101 + ] + } + }, + { + "model": "api.program", + "pk": "5b7deef6-45e2-4142-9e31-efb9e32e3592", + "fields": { + "created": "2023-02-01T15:30:43.281796Z", + "title": "personal-program", + "entrypoint": "program.py", + "artifact": "path", + "author": 2 + } + } +] \ No newline at end of file diff --git a/gateway/tests/resources/fake_media/default/Program/provider_program_artifact.tar b/gateway/tests/resources/fake_media/default/Program/provider_program_artifact.tar new file mode 100644 index 0000000000000000000000000000000000000000..a3a73711326ad926b9a685f1ef5fd35a41a7ac62 GIT binary patch literal 10240 zcmeH{OK%)E41o8nUm++cVh8T*c)i{f26An$E$Tyyz!0=EoEh1ARS$ayMgM!KdDtns zJ?0QV+JjakC6XV>X{_lgH*UV=s&O6m`%T_I9&V?--pywe?e^FE&Gr7~rrsTH4!gsq zzP_z*KW=FKErVL$P9X|U=|+YpVOrsXjIDgK&#Qjj(X8v%!nq( zkmf{N+awJpiYTq2$0St?o@4{^j{X7$yinA-tPOWWYG-|rcT5`$s=7X ze#IjpsI9W163&5XFYJ0ehAh5#Usc;iNp%1I@{$ne?ISdiexmHRBuq@22Jp||6@ykK6z-NlLp!!b23pHB2pZz&#z)q+*^ya&c__% z1?Qbq!1*Nx>M@XI%h}%#K)a+A#RvlCe?_HfRBEN^sZeeZ+apA=^$m@p55$845g_}i zEBvwLiGa={{iFJGeTOXXEtVJ*oFbU)BB1Oup}1I>>QkYrGQwE}*08Jps4zE+CHx;v zNLmX&p3!Q|{ZX9GXr3j!G`58^DvSkLKZRBGnA{k_^EuR2I-mFw8IPj_?6mlL3n|-b zHo^zjTLxPqH{^O=Udc5-rWmK>A4ml4l!y>`?w<^c@mCppASUTf_9(yi36>M@tjBaw z_&1BWMc5r>_xTU~A+Udl5~fe0p9%+t5jXaEAOa6@L>nw;W0m#T35bGOwpm;`(@!J3lwq6qr5#?qd?n-kH0m+uJFz>b;}$#j z#mmWR&O4hmN`!!k8*JifHjg$Alk>Hxt4Df)wV)C2c!}69$!fr()L>J-ef?AYCWaF< zX>!L5pmX-7dXa4pSf!;AaJJ!;rxUUSyKR1PXydpb#ho3V}kP5GVu+fkL1VCR N1PXydpb+>MfiKM?(`Enw literal 0 HcmV?d00001 diff --git a/gateway/tests/resources/fake_media/test_user_2/default/Program/user_program_artifact.tar b/gateway/tests/resources/fake_media/test_user_2/default/Program/user_program_artifact.tar new file mode 100644 index 0000000000000000000000000000000000000000..a3a73711326ad926b9a685f1ef5fd35a41a7ac62 GIT binary patch literal 10240 zcmeH{OK%)E41o8nUm++cVh8T*c)i{f26An$E$Tyyz!0=EoEh1ARS$ayMgM!KdDtns zJ?0QV+JjakC6XV>X{_lgH*UV=s&O6m`%T_I9&V?--pywe?e^FE&Gr7~rrsTH4!gsq zzP_z*KW=FKErVL$P9X|U=|+YpVOrsXjIDgK&#Qjj(X8v%!nq( zkmf{N+awJpiYTq2$0St?o@4{^j{X7$yinA-tPOWWYG-|rcT5`$s=7X ze#IjpsI9W163&5XFYJ0ehAh5#Usc;iNp%1I@{$ne?ISdiexmHRBuq@22Jp||6@ykK6z-NlLp!!b23pHB2pZz&#z)q+*^ya&c__% z1?Qbq!1*Nx>M@XI%h}%#K)a+A#RvlCe?_HfRBEN^sZeeZ+apA=^$m@p55$845g_}i zEBvwLiGa={{iFJGeTOXXEtVJ*oFbU)BB1Oup}1I>>QkYrGQwE}*08Jps4zE+CHx;v zNLmX&p3!Q|{ZX9GXr3j!G`58^DvSkLKZRBGnA{k_^EuR2I-mFw8IPj_?6mlL3n|-b zHo^zjTLxPqH{^O=Udc5-rWmK>A4ml4l!y>`?w<^c@mCppASUTf_9(yi36>M@tjBaw z_&1BWMc5r>_xTU~A+Udl5~fe0p9%+t5jXaEAOa6@L>nw;W0m#T35bGOwpm;`(@!J3lwq6qr5#?qd?n-kH0m+uJFz>b;}$#j z#mmWR&O4hmN`!!k8*JifHjg$Alk>Hxt4Df)wV)C2c!}69$!fr()L>J-ef?AYCWaF< zX>!L5pmX-7dXa4pSf!;AaJJ!;rxUUSyKR1PXydpb#ho3V}kP5GVu+fkL1VCR N1PXydpb+>MfiKM?(`Enw literal 0 HcmV?d00001 diff --git a/tests/docker/test_docker_experimental.py b/tests/docker/test_docker_experimental.py index 03f62f44c..d2bfa192c 100644 --- a/tests/docker/test_docker_experimental.py +++ b/tests/docker/test_docker_experimental.py @@ -23,14 +23,15 @@ class TestDockerExperimental: @mark.order(1) def test_file_producer(self, serverless_client: ServerlessClient): """Integration test for files.""" + functionTitle = "file-producer-for-consume" function = QiskitFunction( - title="file-producer-for-consume", + title=functionTitle, entrypoint="produce_files.py", working_dir=resources_path, ) serverless_client.upload(function) - file_producer_function = serverless_client.function("file-producer-for-consume") + file_producer_function = serverless_client.function(functionTitle) job = file_producer_function.run() @@ -39,7 +40,7 @@ def test_file_producer(self, serverless_client: ServerlessClient): assert job.status() == "DONE" assert isinstance(job.logs(), str) - assert len(serverless_client.files()) > 0 + assert len(serverless_client.files(functionTitle)) > 0 @mark.skip( reason="File producing and consuming is not working. Maybe write permissions for functions?" @@ -47,14 +48,15 @@ def test_file_producer(self, serverless_client: ServerlessClient): @mark.order(2) def test_file_consumer(self, serverless_client: ServerlessClient): """Integration test for files.""" + functionTitle = "file-consumer" function = QiskitFunction( - title="file-consumer", + title=functionTitle, entrypoint="consume_files.py", working_dir=resources_path, ) serverless_client.upload(function) - file_consumer_function = serverless_client.function("file-consumer") + file_consumer_function = serverless_client.function(functionTitle) job = file_consumer_function.run() assert job is not None @@ -62,7 +64,7 @@ def test_file_consumer(self, serverless_client: ServerlessClient): assert job.status() == "DONE" assert isinstance(job.logs(), str) - files = serverless_client.files() + files = serverless_client.files(functionTitle) assert files is not None @@ -70,18 +72,19 @@ def test_file_consumer(self, serverless_client: ServerlessClient): assert file_count > 0 - serverless_client.file_delete("uploaded_file.tar") + serverless_client.file_delete("uploaded_file.tar", functionTitle) - assert (file_count - len(serverless_client.files())) == 1 + assert (file_count - len(serverless_client.files(functionTitle))) == 1 @mark.order(1) - def test_upload_download_delete(self, serverless_client: ServerlessClient): + def test_list_upload_download_delete(self, serverless_client: ServerlessClient): """Integration test for upload files.""" + function = serverless_client.function("hello-world") print("::: file_upload :::") - print(serverless_client.file_upload(filename_path)) + print(serverless_client.file_upload(filename_path, function)) - files = serverless_client.files() + files = serverless_client.files(function) print("::: files :::") print(files) @@ -92,19 +95,107 @@ def test_upload_download_delete(self, serverless_client: ServerlessClient): assert file_count == 1 print("::: file_download :::") - assert serverless_client.file_download(filename) is not None + assert serverless_client.file_download(filename, function) is not None - files = serverless_client.files() + files = serverless_client.files(function) print("::: files after download :::") print(files) assert file_count == len(files) print("::: file_delete :::") - print(serverless_client.file_delete(filename)) + print(serverless_client.file_delete(filename, function)) print("::: files after delete:::") - files = serverless_client.files() + files = serverless_client.files(function) + print(files) + + assert (file_count - len(files)) == 1 + + def test_list_upload_download_delete_with_provider_function( + self, serverless_client: ServerlessClient + ): + """Integration test for upload files.""" + function = QiskitFunction( + title="provider-function", + provider="mockprovider", + image="test-local-provider-function:latest", + ) + serverless_client.upload(function) + + function = serverless_client.function("mockprovider/provider-function") + + print("::: file_upload :::") + print(serverless_client.file_upload(filename_path, function)) + + files = serverless_client.files(function) + print("::: files :::") + print(files) + + file_count = len(files) + print("::: file_count :::") + print(file_count) + + assert file_count == 1 + + print("::: file_download :::") + assert serverless_client.file_download(filename, function) is not None + + files = serverless_client.files(function) + print("::: files after download :::") + print(files) + + assert file_count == len(files) + + print("::: file_delete :::") + print(serverless_client.file_delete(filename, function)) + + print("::: files after delete:::") + files = serverless_client.files(function) + print(files) + + assert (file_count - len(files)) == 1 + + def test_provider_list_upload_download_delete( + self, serverless_client: ServerlessClient + ): + """Integration test for upload files.""" + function = QiskitFunction( + title="provider-function", + provider="mockprovider", + image="test-local-provider-function:latest", + ) + serverless_client.upload(function) + + function = serverless_client.function("mockprovider/provider-function") + + print("::: Provider file_upload :::") + print(serverless_client.provider_file_upload(filename_path, function)) + + files = serverless_client.provider_files(function) + print("::: Provider files :::") + print(files) + + file_count = len(files) + print("::: Provider file_count :::") + print(file_count) + + assert file_count == 1 + + print("::: Provider file_download :::") + assert serverless_client.provider_file_download(filename, function) is not None + + files = serverless_client.provider_files(function) + print("::: Provider files after download :::") + print(files) + + assert file_count == len(files) + + print("::: Provider file_delete :::") + print(serverless_client.provider_file_delete(filename, function)) + + print("::: Provider files after delete:::") + files = serverless_client.provider_files(function) print(files) assert (file_count - len(files)) == 1 diff --git a/tests/experimental/file_download.py b/tests/experimental/file_download.py index ee4da23f7..658740577 100644 --- a/tests/experimental/file_download.py +++ b/tests/experimental/file_download.py @@ -21,9 +21,9 @@ print(job.status()) print(job.logs()) -available_files = serverless.files() +available_files = serverless.files(function) print(available_files) if len(available_files) > 0: - serverless.file_download(available_files[0]) + serverless.file_download(available_files[0], function) print("Download complete") diff --git a/tests/experimental/manage_data_directory.py b/tests/experimental/manage_data_directory.py index 70455a923..9abbf563d 100644 --- a/tests/experimental/manage_data_directory.py +++ b/tests/experimental/manage_data_directory.py @@ -9,6 +9,11 @@ ) print(serverless) +function = QiskitFunction( + title="file-producer", entrypoint="produce_files.py", working_dir="./source_files/" +) +serverless.upload(function) + import tarfile filename = "uploaded_file.tar" @@ -16,12 +21,7 @@ file.add("manage_data_directory.py") file.close() -serverless.file_upload(filename) - -function = QiskitFunction( - title="file-producer", entrypoint="produce_files.py", working_dir="./source_files/" -) -serverless.upload(function) +serverless.file_upload(filename, function) functions = {f.title: f for f in serverless.list()} file_producer_function = functions.get("file-producer") @@ -33,7 +33,7 @@ print(job.logs()) -print(serverless.files()) +print(serverless.files(file_producer_function)) function = QiskitFunction( title="file-consumer", entrypoint="consume_files.py", working_dir="./source_files/" @@ -49,8 +49,8 @@ print(job.status()) print(job.logs()) -print(serverless.files()) +print(serverless.files(file_consumer_function)) -serverless.file_delete("uploaded_file.tar") +serverless.file_delete("uploaded_file.tar", file_consumer_function) print("Done deleting files")