Skip to content

Commit

Permalink
Program upload refactorization (#1112)
Browse files Browse the repository at this point in the history
* created exceptions class

* created base program service class

* use new service in the upload method

* services now are versioned

* applied services versioned logic

* Applied black format

* Remove request from the business logic

* created a save program service test

* fixed black format

* removed self from static methods

* fixed pylint messages

* removed commented line

* rename services test

* fix returned program without serialize

* test for program service

* added some logs for test

* configure a logger for the gateway services
  • Loading branch information
Tansito authored Nov 22, 2023
1 parent e8a6934 commit f81a2f8
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 35 deletions.
33 changes: 33 additions & 0 deletions gateway/api/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
Custom exceptions for the gateway application
"""

from rest_framework import status


class GatewayException(Exception):
"""
Generic custom exception for our application
"""

def __init__(self, message):
super().__init__(message)


class GatewayHttpException(GatewayException):
"""
Generic http custom exception for our application
"""

def __init__(self, message, http_code):
super().__init__(message)
self.http_code = http_code


class InternalServerErrorException(GatewayHttpException):
"""
A wrapper for when we want to raise an internal server error
"""

def __init__(self, message, http_code=status.HTTP_500_INTERNAL_SERVER_ERROR):
super().__init__(message, http_code)
32 changes: 1 addition & 31 deletions gateway/api/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from opentelemetry import trace

from api.models import Job, Program, ComputeResource
from api.models import Job, ComputeResource
from api.ray import submit_job, create_ray_cluster, kill_ray_cluster
from main import settings as config

Expand All @@ -22,36 +22,6 @@
logger = logging.getLogger("commands")


def save_program(serializer, request) -> Program:
"""Save program.
Args:
request: request data.
Returns:
saved program
"""

existing_program = (
Program.objects.filter(title=serializer.data.get("title"), author=request.user)
.order_by("-created")
.first()
)

if existing_program is not None:
program = existing_program
program.arguments = serializer.data.get("arguments")
program.entrypoint = serializer.data.get("entrypoint")
program.dependencies = serializer.data.get("dependencies", "[]")
program.env_vars = serializer.data.get("env_vars", "{}")
else:
program = Program(**serializer.data)
program.artifact = request.FILES.get("artifact")
program.author = request.user
program.save()
return program


def execute_job(job: Job) -> Job:
"""Executes program.
Expand Down
67 changes: 67 additions & 0 deletions gateway/api/services.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
Services for api application:
- Program Service
- Job Service
Version services inherit from the different services.
"""

# pylint: disable=too-few-public-methods

import logging

from .models import Program
from .exceptions import InternalServerErrorException

logger = logging.getLogger("gateway.services")


class ProgramService:
"""
Program service allocate the logic related with programs
"""

@staticmethod
def save(serializer, author, artifact) -> Program:
"""
Save method gets a program serializer and creates or updates a program
"""

title = serializer.data.get("title")
existing_program = (
Program.objects.filter(title=title, author=author)
.order_by("-created")
.first()
)

if existing_program is not None:
program = existing_program
program.arguments = serializer.data.get("arguments")
program.entrypoint = serializer.data.get("entrypoint")
program.dependencies = serializer.data.get("dependencies", "[]")
program.env_vars = serializer.data.get("env_vars", "{}")
logger.debug("Program [%s] will be updated by [%s]", title, author)
else:
program = Program(**serializer.data)
logger.debug("Program [%s] will be created by [%s]", title, author)
program.artifact = artifact
program.author = author

# It would be nice if we could unify all the saves logic in one unique entry-point
try:
program.save()
except (Exception) as save_program_exception:
logger.error(
"Exception was caught saving the program [%s] by [%s] \n"
"Error trace: %s",
title,
author,
save_program_exception,
)
raise InternalServerErrorException(
"Unexpected error saving the program"
) from save_program_exception

logger.debug("Program [%s] saved", title)

return program
13 changes: 13 additions & 0 deletions gateway/api/v1/services.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
Services api for V1.
"""

# pylint: disable=too-few-public-methods

from api import services


class ProgramService(services.ProgramService):
"""
Program service first version.
"""
5 changes: 5 additions & 0 deletions gateway/api/v1/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from api.models import Program, Job
from api.permissions import IsOwner
from . import serializers as v1_serializers
from . import services as v1_services


class ProgramViewSet(views.ProgramViewSet): # pylint: disable=too-many-ancestors
Expand All @@ -24,6 +25,10 @@ class ProgramViewSet(views.ProgramViewSet): # pylint: disable=too-many-ancestor
def get_serializer_job_class():
return v1_serializers.JobSerializer

@staticmethod
def get_service_program_class():
return v1_services.ProgramService

def get_serializer_class(self):
return v1_serializers.ProgramSerializer

Expand Down
36 changes: 32 additions & 4 deletions gateway/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

from .models import Program, Job
from .services import ProgramService
from .exceptions import InternalServerErrorException
from .ray import get_job_handler
from .schedule import save_program

from .serializers import JobSerializer, ExistingProgramSerializer, JobConfigSerializer
from .utils import build_env_variables, encrypt_env_vars

Expand Down Expand Up @@ -71,6 +73,14 @@ def get_serializer_job_class():

return JobSerializer

@staticmethod
def get_service_program_class():
"""
This method returns Program service to be used in Program ViewSet.
"""

return ProgramService

def get_serializer_class(self):
return self.serializer_class

Expand All @@ -89,8 +99,18 @@ def upload(self, request):
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

save_program(serializer=serializer, request=request)
return Response(serializer.data)
program_service = self.get_service_program_class()
try:
program = program_service.save(
serializer=serializer,
author=request.user,
artifact=request.FILES.get("artifact"),
)
except InternalServerErrorException as exception:
return Response(exception, exception.http_code)

program_serializer = self.get_serializer(program)
return Response(program_serializer.data)

@action(methods=["POST"], detail=False)
def run_existing(self, request):
Expand Down Expand Up @@ -173,7 +193,15 @@ def run(self, request):

jobconfig = config_serializer.save()

program = save_program(serializer=serializer, request=request)
program_service = self.get_service_program_class()
try:
program = program_service.save(
serializer=serializer,
author=request.user,
artifact=request.FILES.get("artifact"),
)
except InternalServerErrorException as exception:
return Response(exception, exception.http_code)

job = Job(
program=program,
Expand Down
5 changes: 5 additions & 0 deletions gateway/main/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@
"level": LOG_LEVEL,
"propagate": False,
},
"gateway.services": {
"handlers": ["console"],
"level": LOG_LEVEL,
"propagate": False,
},
},
}

Expand Down
26 changes: 26 additions & 0 deletions gateway/tests/api/test_v1_services.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import json
from rest_framework.test import APITestCase
from api.v1.services import ProgramService
from api.v1.serializers import ProgramSerializer
from api.models import Program
from django.contrib.auth.models import User


class ServicesTest(APITestCase):
"""Tests for V1 services."""

fixtures = ["tests/fixtures/fixtures.json"]

def test_save_program(self):
"""Test to verify that the service creates correctly an entry with its serializer."""

user = User.objects.get(id=1)
data = '{"title": "My Qiskit Pattern", "entrypoint": "pattern.py"}'
program_serializer = ProgramSerializer(data=json.loads(data))
program_serializer.is_valid()

program = ProgramService.save(program_serializer, user, "path")
entry = Program.objects.get(id=program.id)

self.assertIsNotNone(entry)
self.assertEqual(program.title, entry.title)

0 comments on commit f81a2f8

Please sign in to comment.