From baf9026a381a5b94e9c0bc21b97fe20bc7ea0819 Mon Sep 17 00:00:00 2001
From: David <9059044+Tansito@users.noreply.github.com>
Date: Tue, 26 Nov 2024 11:23:06 -0500
Subject: [PATCH] Catalog returns only the functions that you can run (#1540)
* move program model access to repository
* remove unused code
* fix linter
* fixed tests
* update swagger information
* added logger in the repository class
* rename provider method
* added TypeFilter enum
* fix typo
---
gateway/api/repositories/__init__.py | 0
gateway/api/repositories/programs.py | 201 +++++++++++++++++++++++++
gateway/api/v1/views/programs.py | 34 ++++-
gateway/api/views/enums/__init__.py | 0
gateway/api/views/enums/type_filter.py | 16 ++
gateway/api/views/jobs.py | 5 +-
gateway/api/views/programs.py | 142 ++++++-----------
gateway/tests/api/test_v1_program.py | 82 ++--------
gateway/tests/fixtures/fixtures.json | 2 +-
9 files changed, 306 insertions(+), 176 deletions(-)
create mode 100644 gateway/api/repositories/__init__.py
create mode 100644 gateway/api/repositories/programs.py
create mode 100644 gateway/api/views/enums/__init__.py
create mode 100644 gateway/api/views/enums/type_filter.py
diff --git a/gateway/api/repositories/__init__.py b/gateway/api/repositories/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/gateway/api/repositories/programs.py b/gateway/api/repositories/programs.py
new file mode 100644
index 000000000..f7e1aae74
--- /dev/null
+++ b/gateway/api/repositories/programs.py
@@ -0,0 +1,201 @@
+"""
+Repository implementatio for Programs model
+"""
+import logging
+
+from typing import Any, List
+
+from django.db.models import Q
+from django.contrib.auth.models import Group, Permission
+
+from api.models import RUN_PROGRAM_PERMISSION, VIEW_PROGRAM_PERMISSION, Program
+
+
+logger = logging.getLogger("gateway")
+
+
+class ProgramRepository:
+ """
+ The main objective of this class is to manage the access to the model
+ """
+
+ def get_functions(self, author) -> List[Program] | Any:
+ """
+ Returns all the functions available to the user. This means:
+ - User functions where the user is the author
+ - Provider functions with view permissions
+
+ Args:
+ author: Django author from who retrieve the functions
+
+ Returns:
+ List[Program] | Any: all the functions available to the user
+ """
+
+ view_program_permission = Permission.objects.get(
+ codename=VIEW_PROGRAM_PERMISSION
+ )
+
+ user_criteria = Q(user=author)
+ view_permission_criteria = Q(permissions=view_program_permission)
+
+ author_groups_with_view_permissions = Group.objects.filter(
+ user_criteria & view_permission_criteria
+ )
+
+ author_criteria = Q(author=author)
+ author_groups_with_view_permissions_criteria = Q(
+ instances__in=author_groups_with_view_permissions
+ )
+
+ result_queryset = Program.objects.filter(
+ author_criteria | author_groups_with_view_permissions_criteria
+ ).distinct()
+
+ count = result_queryset.count()
+ logger.info("[%d] Functions found for author [%s]", count, author.id)
+
+ return result_queryset
+
+ def get_user_functions(self, author) -> List[Program] | Any:
+ """
+ Returns the user functions available to the user. This means:
+ - User functions where the user is the author
+ - Provider is None
+
+ Args:
+ author: Django author from who retrieve the functions
+
+ Returns:
+ List[Program] | Any: user functions available to the user
+ """
+
+ author_criteria = Q(author=author)
+ provider_criteria = Q(provider=None)
+
+ result_queryset = Program.objects.filter(
+ author_criteria & provider_criteria
+ ).distinct()
+
+ count = result_queryset.count()
+ logger.info("[%d] user Functions found for author [%s]", count, author.id)
+
+ return result_queryset
+
+ def get_provider_functions_with_run_permissions(
+ self, author
+ ) -> List[Program] | Any:
+ """
+ Returns the provider functions available to the user. This means:
+ - Provider functions where the user has run permissions
+ - Provider is NOT None
+
+ Args:
+ author: Django author from who retrieve the functions
+
+ Returns:
+ List[Program] | Any: providers functions available to the user
+ """
+
+ run_program_permission = Permission.objects.get(codename=RUN_PROGRAM_PERMISSION)
+
+ user_criteria = Q(user=author)
+ run_permission_criteria = Q(permissions=run_program_permission)
+ author_groups_with_run_permissions = Group.objects.filter(
+ user_criteria & run_permission_criteria
+ )
+
+ author_groups_with_run_permissions_criteria = Q(
+ instances__in=author_groups_with_run_permissions
+ )
+
+ provider_exists_criteria = ~Q(provider=None)
+
+ result_queryset = Program.objects.filter(
+ author_groups_with_run_permissions_criteria & provider_exists_criteria
+ ).distinct()
+
+ count = result_queryset.count()
+ logger.info("[%d] provider Functions found for author [%s]", count, author.id)
+
+ return result_queryset
+
+ def get_user_function_by_title(self, author, title: str) -> Program | Any:
+ """
+ Returns the user function associated to a title:
+
+ Args:
+ author: Django author from who retrieve the function
+ title: Title that the function must have to find it
+
+ Returns:
+ Program | Any: user function with the specific title
+ """
+
+ author_criteria = Q(author=author)
+ title_criteria = Q(title=title)
+
+ result_queryset = Program.objects.filter(
+ author_criteria & title_criteria
+ ).first()
+
+ if result_queryset is None:
+ logger.warning(
+ "Function [%s] was not found or author [%s] doesn't have access to it",
+ title,
+ author.id,
+ )
+
+ return result_queryset
+
+ def get_provider_function_by_title(
+ self, author, title: str, provider_name: str
+ ) -> Program | Any:
+ """
+ Returns the provider function associated to:
+ - A Function title
+ - A Provider
+ - Author must have view permission to see it or be the author
+
+ Args:
+ author: Django author from who retrieve the function
+ title: Title that the function must have to find it
+ provider: Provider associated to the function
+
+ Returns:
+ Program | Any: provider function with the specific
+ title and provider
+ """
+
+ view_program_permission = Permission.objects.get(
+ codename=VIEW_PROGRAM_PERMISSION
+ )
+
+ user_criteria = Q(user=author)
+ view_permission_criteria = Q(permissions=view_program_permission)
+
+ author_groups_with_view_permissions = Group.objects.filter(
+ user_criteria & view_permission_criteria
+ )
+
+ author_criteria = Q(author=author)
+ author_groups_with_view_permissions_criteria = Q(
+ instances__in=author_groups_with_view_permissions
+ )
+
+ title_criteria = Q(title=title, provider__name=provider_name)
+
+ result_queryset = Program.objects.filter(
+ (author_criteria | author_groups_with_view_permissions_criteria)
+ & title_criteria
+ ).first()
+
+ if result_queryset is None:
+ logger.warning(
+ "Function [%s/%s] was not found or author [%s] doesn't have access to it",
+ provider_name,
+ title,
+ author.id,
+ )
+
+ return result_queryset
diff --git a/gateway/api/v1/views/programs.py b/gateway/api/v1/views/programs.py
index 35c44a79a..1280c2222 100644
--- a/gateway/api/v1/views/programs.py
+++ b/gateway/api/v1/views/programs.py
@@ -2,7 +2,7 @@
Programs view api for V1.
"""
-# pylint: disable=duplicate-code
+from drf_yasg import openapi
from drf_yasg.utils import swagger_auto_schema
from rest_framework import permissions, status
from rest_framework.decorators import action
@@ -42,6 +42,15 @@ def get_serializer_job(*args, **kwargs):
@swagger_auto_schema(
operation_description="List author Qiskit Functions",
+ manual_parameters=[
+ openapi.Parameter(
+ "filter",
+ openapi.IN_QUERY,
+ description="Filters that you can apply for list: serverless, catalog or empty",
+ type=openapi.TYPE_STRING,
+ required=False,
+ ),
+ ],
responses={status.HTTP_200_OK: v1_serializers.ProgramSerializer(many=True)},
)
def list(self, request):
@@ -64,3 +73,26 @@ def upload(self, request):
@action(methods=["POST"], detail=False)
def run(self, request):
return super().run(request)
+
+ @swagger_auto_schema(
+ operation_description="Retrieve a Qiskit Function using the title",
+ manual_parameters=[
+ openapi.Parameter(
+ "title",
+ openapi.IN_PATH,
+ description="The title of the function",
+ type=openapi.TYPE_STRING,
+ ),
+ openapi.Parameter(
+ "provider",
+ openapi.IN_QUERY,
+ description="The provider in case the function is owned by a provider",
+ type=openapi.TYPE_STRING,
+ required=False,
+ ),
+ ],
+ responses={status.HTTP_200_OK: v1_serializers.ProgramSerializer},
+ )
+ @action(methods=["GET"], detail=False, url_path="get_by_title/(?P
[^/.]+)")
+ def get_by_title(self, request, title):
+ return super().get_by_title(request, title)
diff --git a/gateway/api/views/enums/__init__.py b/gateway/api/views/enums/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/gateway/api/views/enums/type_filter.py b/gateway/api/views/enums/type_filter.py
new file mode 100644
index 000000000..9bd421d70
--- /dev/null
+++ b/gateway/api/views/enums/type_filter.py
@@ -0,0 +1,16 @@
+"""
+This class defines TypeFilter enum for views:
+"""
+
+from enum import Enum
+
+
+class TypeFilter(str, Enum):
+ """
+ TypeFilter values for the view end-points:
+ - SERVERLESS
+ - CATALOG
+ """
+
+ SERVERLESS = "serverless"
+ CATALOG = "catalog"
diff --git a/gateway/api/views/jobs.py b/gateway/api/views/jobs.py
index f6932ad8e..45376e50d 100644
--- a/gateway/api/views/jobs.py
+++ b/gateway/api/views/jobs.py
@@ -25,6 +25,7 @@
from qiskit_ibm_runtime import RuntimeInvalidStateError, QiskitRuntimeService
from api.models import Job, RuntimeJob
from api.ray import get_job_handler
+from api.views.enums.type_filter import TypeFilter
# pylint: disable=duplicate-code
logger = logging.getLogger("gateway")
@@ -56,13 +57,13 @@ def get_serializer_class(self):
def get_queryset(self):
type_filter = self.request.query_params.get("filter")
if type_filter:
- if type_filter == "catalog":
+ if type_filter == TypeFilter.CATALOG:
user_criteria = Q(author=self.request.user)
provider_exists_criteria = ~Q(program__provider=None)
return Job.objects.filter(
user_criteria & provider_exists_criteria
).order_by("-created")
- if type_filter == "serverless":
+ if type_filter == TypeFilter.SERVERLESS:
user_criteria = Q(author=self.request.user)
provider_not_exists_criteria = Q(program__provider=None)
return Job.objects.filter(
diff --git a/gateway/api/views/programs.py b/gateway/api/views/programs.py
index cfe5242e2..dd6eba04f 100644
--- a/gateway/api/views/programs.py
+++ b/gateway/api/views/programs.py
@@ -5,7 +5,6 @@
"""
import logging
import os
-from typing import Optional
from django.db.models import Q
from django.contrib.auth.models import Group, Permission
@@ -21,6 +20,7 @@
from rest_framework import viewsets, status
from rest_framework.response import Response
+from api.repositories.programs import ProgramRepository
from api.utils import sanitize_name
from api.serializers import (
JobConfigSerializer,
@@ -29,7 +29,8 @@
RunProgramSerializer,
UploadProgramSerializer,
)
-from api.models import VIEW_PROGRAM_PERMISSION, RUN_PROGRAM_PERMISSION, Program, Job
+from api.models import RUN_PROGRAM_PERMISSION, Program, Job
+from api.views.enums.type_filter import TypeFilter
# pylint: disable=duplicate-code
logger = logging.getLogger("gateway")
@@ -55,6 +56,8 @@ class ProgramViewSet(viewsets.GenericViewSet):
BASE_NAME = "programs"
+ program_repository = ProgramRepository()
+
@staticmethod
def get_serializer_job_config(*args, **kwargs):
"""
@@ -101,28 +104,6 @@ def get_serializer_class(self):
def get_object(self):
logger.warning("ProgramViewSet.get_object not implemented")
- def get_queryset(self):
- author = self.request.user
- title = sanitize_name(self.request.query_params.get("title"))
- provider_name = sanitize_name(self.request.query_params.get("provider"))
- type_filter = self.request.query_params.get("filter")
-
- author_programs = self._get_program_queryset_for_title_and_provider(
- author=author,
- title=title,
- provider_name=provider_name,
- type_filter=type_filter,
- ).distinct()
-
- author_programs_count = author_programs.count()
- logger.info(
- "ProgramViewSet get author [%s] programs [%s]",
- author.id,
- author_programs_count,
- )
-
- return author_programs
-
def get_run_queryset(self):
"""get run queryset"""
author = self.request.user
@@ -167,7 +148,28 @@ def list(self, request):
tracer = trace.get_tracer("gateway.tracer")
ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
with tracer.start_as_current_span("gateway.program.list", context=ctx):
- serializer = self.get_serializer(self.get_queryset(), many=True)
+
+ author = self.request.user
+ type_filter = self.request.query_params.get("filter")
+
+ if type_filter == TypeFilter.SERVERLESS:
+ # Serverless filter only returns functions created by the author
+ # with the next criterias:
+ # - user is the author of the function and there is no provider
+ functions = self.program_repository.get_user_functions(author)
+ elif type_filter == TypeFilter.CATALOG:
+ # Catalog filter only returns providers functions that user has access:
+ # author has view permissions and the function has a provider assigned
+ functions = (
+ self.program_repository.get_provider_functions_with_run_permissions(
+ author
+ )
+ )
+ else:
+ # If filter is not applied we return author and providers functions together
+ functions = self.program_repository.get_functions(author)
+
+ serializer = self.get_serializer(functions, many=True)
return Response(serializer.data)
@@ -296,87 +298,29 @@ def run(self, request):
def get_by_title(self, request, title):
"""Returns programs by title."""
author = self.request.user
- provider_name = self.request.query_params.get("provider")
-
- result_program = self._get_program_queryset_for_title_and_provider(
- author=author, title=title, provider_name=provider_name, type_filter=None
- ).first()
-
- if result_program:
- return Response(self.get_serializer(result_program).data)
-
- return Response(status=404)
+ function_title = sanitize_name(title)
+ provider_name = sanitize_name(request.query_params.get("provider", None))
- def _get_program_queryset_for_title_and_provider(
- self,
- author,
- title: str,
- provider_name: Optional[str],
- type_filter: Optional[str],
- ):
- """Returns queryset for program for gived request, title and provider."""
- view_program_permission = Permission.objects.get(
- codename=VIEW_PROGRAM_PERMISSION
+ serializer = self.get_serializer_upload_program(data=self.request.data)
+ provider_name, function_title = serializer.get_provider_name_and_title(
+ provider_name, function_title
)
- user_criteria = Q(user=author)
- view_permission_criteria = Q(permissions=view_program_permission)
- author_groups_with_view_permissions = Group.objects.filter(
- user_criteria & view_permission_criteria
- )
- author_groups_with_view_permissions_count = (
- author_groups_with_view_permissions.count()
- )
- logger.info(
- "ProgramViewSet get author [%s] groups [%s]",
- author.id,
- author_groups_with_view_permissions_count,
- )
-
- author_criteria = Q(author=author)
- author_groups_with_view_permissions_criteria = Q(
- instances__in=author_groups_with_view_permissions
- )
-
- # Serverless filter only returns functions created by the author with the next criterias:
- # user is the author of the function and there is no provider
- if type_filter == "serverless":
- provider_criteria = Q(provider=None)
- result_queryset = Program.objects.filter(
- author_criteria & provider_criteria
- )
- return result_queryset
-
- # Catalog filter only returns providers functions that user has access:
- # author has view permissions and the function has a provider assigned
- if type_filter == "catalog":
- provider_exists_criteria = ~Q(provider=None)
- result_queryset = Program.objects.filter(
- author_groups_with_view_permissions_criteria & provider_exists_criteria
- )
- return result_queryset
-
- # If filter is not applied we return author and providers functions together
- title = sanitize_name(title)
- provider_name = sanitize_name(provider_name)
- if title:
- serializer = self.get_serializer_upload_program(data=self.request.data)
- provider_name, title = serializer.get_provider_name_and_title(
- provider_name, title
- )
- title_criteria = Q(title=title)
- if provider_name:
- title_criteria = Q(title=title, provider__name=provider_name)
- result_queryset = Program.objects.filter(
- (author_criteria | author_groups_with_view_permissions_criteria)
- & title_criteria
+ if provider_name:
+ function = self.program_repository.get_provider_function_by_title(
+ author=author, title=function_title, provider_name=provider_name
)
else:
- result_queryset = Program.objects.filter(
- author_criteria | author_groups_with_view_permissions_criteria
+ function = self.program_repository.get_user_function_by_title(
+ author=author, title=function_title
)
- return result_queryset
+ if function:
+ return Response(self.get_serializer(function).data)
+
+ return Response(status=404)
+
+ # This end-point is deprecated and we need to confirm if we can remove it
@action(methods=["GET"], detail=True)
def get_jobs(
self, request, pk=None
diff --git a/gateway/tests/api/test_v1_program.py b/gateway/tests/api/test_v1_program.py
index 4676cd9f2..931047a56 100644
--- a/gateway/tests/api/test_v1_program.py
+++ b/gateway/tests/api/test_v1_program.py
@@ -68,13 +68,21 @@ def test_provider_programs_catalog_list(self):
)
self.assertEqual(programs_response.status_code, status.HTTP_200_OK)
- self.assertEqual(len(programs_response.data), 1)
+ self.assertEqual(len(programs_response.data), 2)
self.assertEqual(
programs_response.data[0].get("provider"),
"ibm",
)
self.assertEqual(
programs_response.data[0].get("title"),
+ "Docker-Image-Program-2",
+ )
+ self.assertEqual(
+ programs_response.data[1].get("provider"),
+ "ibm",
+ )
+ self.assertEqual(
+ programs_response.data[1].get("title"),
"Docker-Image-Program-3",
)
@@ -95,78 +103,6 @@ def test_provider_programs_serverless_list(self):
"Program",
)
- def test_program_list_with_title_query_parameter(self):
- """Tests program list filtered with title."""
- user = models.User.objects.get(username="test_user")
- self.client.force_authenticate(user=user)
-
- programs_response = self.client.get(
- reverse("v1:programs-list"), {"title": "Program"}, format="json"
- )
-
- self.assertEqual(programs_response.status_code, status.HTTP_200_OK)
- self.assertEqual(len(programs_response.data), 1)
- self.assertEqual(
- programs_response.data[0].get("title"),
- "Program",
- )
-
- empty_programs_response = self.client.get(
- reverse("v1:programs-list"), {"title": "Non existing name"}, format="json"
- )
- self.assertEqual(empty_programs_response.status_code, status.HTTP_200_OK)
- self.assertEqual(len(empty_programs_response.data), 0)
-
- def test_program_list_with_title_query_title_and_provider(self):
- """Tests program list filtered with title."""
- user = models.User.objects.get(username="test_user_2")
- self.client.force_authenticate(user=user)
-
- programs_response = self.client.get(
- reverse("v1:programs-list"),
- {"title": "Docker-Image-Program"},
- format="json",
- )
-
- self.assertEqual(programs_response.status_code, status.HTTP_200_OK)
- self.assertEqual(len(programs_response.data), 1)
- self.assertEqual(programs_response.data[0].get("provider"), "default")
-
- programs_response_with_provider = self.client.get(
- reverse("v1:programs-list"),
- {"title": "default/Docker-Image-Program"},
- format="json",
- )
- self.assertEqual(
- programs_response_with_provider.status_code, status.HTTP_200_OK
- )
- self.assertEqual(len(programs_response_with_provider.data), 1)
- self.assertEqual(
- programs_response_with_provider.data[0].get("provider"), "default"
- )
-
- programs_response_with_provider_as_parameter = self.client.get(
- reverse("v1:programs-list"),
- {"title": "Docker-Image-Program", "provider": "default"},
- format="json",
- )
- self.assertEqual(
- programs_response_with_provider_as_parameter.status_code, status.HTTP_200_OK
- )
- self.assertEqual(len(programs_response_with_provider_as_parameter.data), 1)
- self.assertEqual(
- programs_response_with_provider_as_parameter.data[0].get("provider"),
- "default",
- )
-
- programs_response_empty = self.client.get(
- reverse("v1:programs-list"),
- {"title": "Docker-Image-Program", "provider": "non-existing-provider"},
- format="json",
- )
- self.assertEqual(programs_response_empty.status_code, status.HTTP_200_OK)
- self.assertEqual(len(programs_response_empty.data), 0)
-
def test_run(self):
"""Tests run existing authorized."""
diff --git a/gateway/tests/fixtures/fixtures.json b/gateway/tests/fixtures/fixtures.json
index 0aea2f7d1..93d89f7bc 100644
--- a/gateway/tests/fixtures/fixtures.json
+++ b/gateway/tests/fixtures/fixtures.json
@@ -49,7 +49,7 @@
"password": "pbkdf2_sha256$390000$kcex1rxhZg6VVJYkx71cBX$e4ns0xDykbO6Dz6j4nZ4uNusqkB9GVpojyegPv5/9KM=",
"is_active": true,
"groups": [
- 101
+ 100
]
}
},