diff --git a/CHANGELOG.md b/CHANGELOG.md index 044ebdf25..a31b495da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,11 +11,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - New `SECRET_KEY` optional environment variable ([#671](https://github.com/Substra/substra-backend/pull/671)) - `/api-token-auth/` and the associated tokens can now be disabled through the `EXPIRY_TOKEN_ENABLED` environment variable and `server.allowImplicitLogin` chart value ([#698](https://github.com/Substra/substra-backend/pull/698)) - Tokens issued by `/api-token-auth/` can now be deleted like other API tokens, through a `DELETE` request on the `/active-api-tokens` endpoint ([#698](https://github.com/Substra/substra-backend/pull/698)) +- Field `asset_type` on `AssetFailureReport` (based on protobuf enum `orchestrator.FailedAssetKind`) ([#727](https://github.com/Substra/substra-backend/pull/727)) +- Celery task `FailableTask` that contains the logic to store the failure report and can be reused across +different asset kinds ([#727](https://github.com/Substra/substra-backend/pull/727)) ### Changed - Increase the number of tasks displayable in frontend workflow [#697](https://github.com/Substra/substra-backend/pull/697) - BREAKING: Change the format of many API responses from `{"message":...}` to `{"detail":...}` ([#705](https://github.com/Substra/substra-backend/pull/705)) +- `ComputeTaskFailureReport` renamed to `AssetFailureReport` ([#727](https://github.com/Substra/substra-backend/pull/727)) +- Field `AssetFailureReport.compute_task_key` renamed to `asset_key` ([#727](https://github.com/Substra/substra-backend/pull/727)) ### Removed diff --git a/backend/api/events/sync.py b/backend/api/events/sync.py index 8d5127356..703a2982c 100644 --- a/backend/api/events/sync.py +++ b/backend/api/events/sync.py @@ -33,6 +33,7 @@ from api.serializers import PerformanceSerializer from orchestrator import client as orc_client from orchestrator import computetask +from orchestrator import failure_report_pb2 logger = structlog.get_logger(__name__) @@ -89,7 +90,7 @@ def _on_update_function_event(event: dict) -> None: _update_function(key=event["asset_key"], name=function["name"], status=function["status"]) -def _update_function(key: str, *, name: Optional[str], status: Optional[str]) -> None: +def _update_function(key: str, *, name: Optional[str] = None, status: Optional[str] = None) -> None: """Process update function event to update local database.""" function = Function.objects.get(key=key) @@ -382,7 +383,17 @@ def _disable_model(key: str) -> None: def _on_create_failure_report(event: dict) -> None: """Process create failure report event to update local database.""" logger.debug("Syncing failure report create", asset_key=event["asset_key"], event_id=event["id"]) - _update_computetask(key=event["asset_key"], failure_report=event["failure_report"]) + + asset_key = event["asset_key"] + failure_report = event["failure_report"] + asset_type = failure_report_pb2.FailedAssetKind.Value(failure_report["asset_type"]) + + if asset_type == failure_report_pb2.FAILED_ASSET_FUNCTION: + # Needed because this field only exists on ComputeTask + compute_task_key = ComputeTask.objects.values_list("key", flat=True).get(function_id=asset_key) + _update_computetask(key=str(compute_task_key), failure_report={"error_type": failure_report.get("error_type")}) + else: + _update_computetask(key=asset_key, failure_report=failure_report) EVENT_CALLBACKS = { diff --git a/backend/api/migrations/0053_function_status.py
b/backend/api/migrations/0053_function_status.py index 21559afe4..268cc522e 100644 --- a/backend/api/migrations/0053_function_status.py +++ b/backend/api/migrations/0053_function_status.py @@ -15,14 +15,14 @@ class Migration(migrations.Migration): name="status", field=models.CharField( choices=[ - ("FUNCTION_STATUS_UNKONWN", "Function Status Unkonwn"), + ("FUNCTION_STATUS_UNKNOWN", "Function Status Unknown"), ("FUNCTION_STATUS_CREATED", "Function Status Created"), ("FUNCTION_STATUS_BUILDING", "Function Status Building"), ("FUNCTION_STATUS_READY", "Function Status Ready"), ("FUNCTION_STATUS_CANCELED", "Function Status Canceled"), ("FUNCTION_STATUS_FAILED", "Function Status Failed"), ], - default="FUNCTION_STATUS_UNKONWN", + default="FUNCTION_STATUS_UNKNOWN", max_length=64, ), preserve_default=False, diff --git a/backend/api/tests/asset_factory.py b/backend/api/tests/asset_factory.py index cdf8bd57b..fc15769a6 100644 --- a/backend/api/tests/asset_factory.py +++ b/backend/api/tests/asset_factory.py @@ -62,6 +62,7 @@ import datetime import uuid +from typing import Optional from django.core import files from django.utils import timezone @@ -80,9 +81,10 @@ from api.models import Model from api.models import Performance from api.models import TaskProfiling -from substrapp.models import ComputeTaskFailureReport as ComputeTaskLogs +from substrapp.models import AssetFailureReport from substrapp.models import DataManager as DataManagerFiles from substrapp.models import DataSample as DataSampleFiles +from substrapp.models import FailedAssetKind from substrapp.models import Function as FunctionFiles from substrapp.models import Model as ModelFiles from substrapp.utils import get_hash @@ -535,20 +537,36 @@ def create_model_files( return model_files -def create_computetask_logs( - compute_task_key: uuid.UUID, - logs: files.File = None, -) -> ComputeTaskLogs: +def create_asset_logs( + asset_key: uuid.UUID, + asset_type: FailedAssetKind, + logs: Optional[files.File] = None, +) -> AssetFailureReport: if logs is None: logs = files.base.ContentFile("dummy content") - compute_task_logs = ComputeTaskLogs.objects.create( - compute_task_key=compute_task_key, + asset_logs = AssetFailureReport.objects.create( + asset_key=asset_key, + asset_type=asset_type, logs_checksum=get_hash(logs), creation_date=timezone.now(), ) - compute_task_logs.logs.save("logs", logs) - return compute_task_logs + asset_logs.logs.save("logs", logs) + return asset_logs + + +def create_computetask_logs( + compute_task_key: uuid.UUID, + logs: Optional[files.File] = None, +) -> AssetFailureReport: + return create_asset_logs(compute_task_key, FailedAssetKind.FAILED_ASSET_COMPUTE_TASK, logs) + + +def create_function_logs( + function_key: uuid.UUID, + logs: Optional[files.File] = None, +) -> AssetFailureReport: + return create_asset_logs(function_key, FailedAssetKind.FAILED_ASSET_FUNCTION, logs) def create_computetask_profiling(compute_task: ComputeTask) -> TaskProfiling: diff --git a/backend/api/tests/views/test_views_computetask_logs.py b/backend/api/tests/views/test_views_failed_asset_logs.py similarity index 90% rename from backend/api/tests/views/test_views_computetask_logs.py rename to backend/api/tests/views/test_views_failed_asset_logs.py index 5083c8a88..5fe6db3d4 100644 --- a/backend/api/tests/views/test_views_computetask_logs.py +++ b/backend/api/tests/views/test_views_failed_asset_logs.py @@ -13,11 +13,11 @@ from api.views import utils as view_utils from organization import authentication as organization_auth from organization import 
models as organization_models -from substrapp.models import ComputeTaskFailureReport +from substrapp.models import AssetFailureReport @pytest.fixture -def compute_task_failure_report() -> tuple[ComputeTask, ComputeTaskFailureReport]: +def asset_failure_report() -> tuple[ComputeTask, AssetFailureReport]: compute_task = factory.create_computetask( factory.create_computeplan(), factory.create_function(), @@ -41,12 +41,12 @@ def test_download_logs_failure_unauthenticated(api_client: test.APIClient): @pytest.mark.django_db def test_download_local_logs_success( - compute_task_failure_report, + asset_failure_report, authenticated_client: test.APIClient, ): """An authorized user download logs located on the organization.""" - compute_task, failure_report = compute_task_failure_report + compute_task, failure_report = asset_failure_report assert compute_task.owner == conf.settings.LEDGER_MSP_ID # local assert conf.settings.LEDGER_MSP_ID in compute_task.logs_permission_authorized_ids # allowed @@ -60,12 +60,12 @@ def test_download_local_logs_success( @pytest.mark.django_db def test_download_logs_failure_forbidden( - compute_task_failure_report, + asset_failure_report, authenticated_client: test.APIClient, ): """An authenticated user cannot download logs if he is not authorized.""" - compute_task, failure_report = compute_task_failure_report + compute_task, failure_report = asset_failure_report assert compute_task.owner == conf.settings.LEDGER_MSP_ID # local compute_task.logs_permission_authorized_ids = [] # not allowed compute_task.save() @@ -77,12 +77,12 @@ def test_download_logs_failure_forbidden( @pytest.mark.django_db def test_download_local_logs_failure_not_found( - compute_task_failure_report, + asset_failure_report, authenticated_client: test.APIClient, ): """An authorized user attempt to download logs that are not referenced in the database.""" - compute_task, failure_report = compute_task_failure_report + compute_task, failure_report = asset_failure_report assert compute_task.owner == conf.settings.LEDGER_MSP_ID # local assert conf.settings.LEDGER_MSP_ID in compute_task.logs_permission_authorized_ids # allowed failure_report.delete() # not found @@ -94,12 +94,12 @@ def test_download_local_logs_failure_not_found( @pytest.mark.django_db def test_download_remote_logs_success( - compute_task_failure_report, + asset_failure_report, authenticated_client: test.APIClient, ): """An authorized user download logs on a remote organization by using his organization as proxy.""" - compute_task, failure_report = compute_task_failure_report + compute_task, failure_report = asset_failure_report outgoing_organization = "outgoing-organization" compute_task.logs_owner = outgoing_organization # remote compute_task.logs_permission_authorized_ids = [conf.settings.LEDGER_MSP_ID, outgoing_organization] # allowed @@ -139,13 +139,13 @@ def get_proxy_headers(channel_name: str) -> dict[str, str]: @pytest.mark.django_db def test_organization_download_logs_success( - compute_task_failure_report, + asset_failure_report, api_client: test.APIClient, incoming_organization_user: organization_auth.OrganizationUser, ): """An authorized organization can download logs from another organization.""" - compute_task, failure_report = compute_task_failure_report + compute_task, failure_report = asset_failure_report compute_task.logs_owner = conf.settings.LEDGER_MSP_ID # local (incoming request from remote) compute_task.logs_permission_authorized_ids = [ conf.settings.LEDGER_MSP_ID, @@ -166,13 +166,13 @@ def 
test_organization_download_logs_success( @pytest.mark.django_db def test_organization_download_logs_forbidden( - compute_task_failure_report, + asset_failure_report, api_client: test.APIClient, incoming_organization_user: organization_auth.OrganizationUser, ): """An unauthorized organization cannot download logs from another organization.""" - compute_task, failure_report = compute_task_failure_report + compute_task, failure_report = asset_failure_report compute_task.logs_owner = conf.settings.LEDGER_MSP_ID # local (incoming request from remote) compute_task.logs_permission_authorized_ids = [conf.settings.LEDGER_MSP_ID] # incoming user not allowed compute_task.channel = incoming_organization_user.username diff --git a/backend/api/urls.py b/backend/api/urls.py index 6826dff8c..6cbc27751 100644 --- a/backend/api/urls.py +++ b/backend/api/urls.py @@ -25,7 +25,7 @@ router.register(r"compute_plan_metadata", views.ComputePlanMetadataViewSet, basename="compute_plan_metadata") router.register(r"news_feed", views.NewsFeedViewSet, basename="news_feed") router.register(r"performance", views.PerformanceViewSet, basename="performance") -router.register(r"logs", views.ComputeTaskLogsViewSet, basename="logs") +router.register(r"logs", views.FailedAssetLogsViewSet, basename="logs") router.register(r"task_profiling", views.TaskProfilingViewSet, basename="task_profiling") task_profiling_router = routers.NestedDefaultRouter(router, r"task_profiling", lookup="task_profiling") diff --git a/backend/api/views/__init__.py b/backend/api/views/__init__.py index ab83a34e5..25484555e 100644 --- a/backend/api/views/__init__.py +++ b/backend/api/views/__init__.py @@ -2,10 +2,10 @@ from .computeplan import ComputePlanViewSet from .computetask import ComputeTaskViewSet from .computetask import CPTaskViewSet -from .computetask_logs import ComputeTaskLogsViewSet from .datamanager import DataManagerPermissionViewSet from .datamanager import DataManagerViewSet from .datasample import DataSampleViewSet +from .failed_asset_logs import FailedAssetLogsViewSet from .function import CPFunctionViewSet from .function import FunctionPermissionViewSet from .function import FunctionViewSet @@ -24,6 +24,7 @@ "DataManagerPermissionViewSet", "ModelViewSet", "ModelPermissionViewSet", + "FailedAssetLogsViewSet", "FunctionViewSet", "FunctionPermissionViewSet", "ComputeTaskViewSet", @@ -31,7 +32,6 @@ "CPTaskViewSet", "CPFunctionViewSet", "NewsFeedViewSet", - "ComputeTaskLogsViewSet", "CPPerformanceViewSet", "ComputePlanMetadataViewSet", "PerformanceViewSet", diff --git a/backend/api/views/computetask_logs.py b/backend/api/views/computetask_logs.py deleted file mode 100644 index 5eca090b6..000000000 --- a/backend/api/views/computetask_logs.py +++ /dev/null @@ -1,18 +0,0 @@ -from rest_framework import response as drf_response -from rest_framework import viewsets -from rest_framework.decorators import action - -from api.models import ComputeTask -from api.views import utils as view_utils -from substrapp.models import compute_task_failure_report - - -class ComputeTaskLogsViewSet(view_utils.PermissionMixin, viewsets.GenericViewSet): - queryset = compute_task_failure_report.ComputeTaskFailureReport.objects.all() - - @action(detail=True, url_path=compute_task_failure_report.LOGS_FILE_PATH) - def file(self, request, pk=None) -> drf_response.Response: - response = self.download_file(request, ComputeTask, "logs", "logs_address") - response.headers["Content-Type"] = "text/plain; charset=utf-8" - response.headers["Content-Disposition"] = f'attachment; 
filename="tuple_logs_{pk}.txt"' - return response diff --git a/backend/api/views/datamanager.py b/backend/api/views/datamanager.py index 0acc744e2..57df76802 100644 --- a/backend/api/views/datamanager.py +++ b/backend/api/views/datamanager.py @@ -212,8 +212,20 @@ class DataManagerPermissionViewSet(PermissionMixin, GenericViewSet): @action(detail=True, url_path="description", url_name="description") def description_(self, request, *args, **kwargs): - return self.download_file(request, DataManager, "description", "description_address") + return self.download_file( + request, + asset_class=DataManager, + local_file_class=DataManagerFiles, + content_field="description", + address_field="description_address", + ) @action(detail=True) def opener(self, request, *args, **kwargs): - return self.download_file(request, DataManager, "data_opener", "opener_address") + return self.download_file( + request, + asset_class=DataManager, + local_file_class=DataManagerFiles, + content_field="data_opener", + address_field="opener_address", + ) diff --git a/backend/api/views/failed_asset_logs.py b/backend/api/views/failed_asset_logs.py new file mode 100644 index 000000000..7b55f0df1 --- /dev/null +++ b/backend/api/views/failed_asset_logs.py @@ -0,0 +1,41 @@ +from rest_framework import response as drf_response +from rest_framework import status +from rest_framework import viewsets +from rest_framework.decorators import action + +from api.errors import AssetPermissionError +from api.models import ComputeTask +from api.models import Function +from api.views import utils as view_utils +from substrapp.models import asset_failure_report + + +class FailedAssetLogsViewSet(view_utils.PermissionMixin, viewsets.GenericViewSet): + queryset = asset_failure_report.AssetFailureReport.objects.all() + + @action(detail=True, url_path=asset_failure_report.LOGS_FILE_PATH) + def file(self, request, pk=None) -> drf_response.Response: + report = self.get_object() + channel_name = view_utils.get_channel_name(request) + if report.asset_type == asset_failure_report.FailedAssetKind.FAILED_ASSET_FUNCTION: + asset_class = Function + else: + asset_class = ComputeTask + + try: + asset = self.get_asset(request, report.key, channel_name, asset_class) + except AssetPermissionError as e: + return view_utils.ApiResponse({"detail": str(e)}, status=status.HTTP_403_FORBIDDEN) + + response = view_utils.get_file_response( + local_file_class=asset_failure_report.AssetFailureReport, + key=report.key, + content_field="logs", + channel_name=channel_name, + url=report.logs_address, + asset_owner=asset.get_owner(), + ) + + response.headers["Content-Type"] = "text/plain; charset=utf-8" + response.headers["Content-Disposition"] = f'attachment; filename="tuple_logs_{pk}.txt"' + return response diff --git a/backend/api/views/function.py b/backend/api/views/function.py index d0391d3ef..28b8a9f2a 100644 --- a/backend/api/views/function.py +++ b/backend/api/views/function.py @@ -202,7 +202,13 @@ class FunctionPermissionViewSet(PermissionMixin, GenericViewSet): @action(detail=True) def file(self, request, *args, **kwargs): - return self.download_file(request, Function, "file", "function_address") + return self.download_file( + request, + asset_class=Function, + local_file_class=FunctionFiles, + content_field="file", + address_field="function_address", + ) # actions cannot be named "description" # https://github.com/encode/django-rest-framework/issues/6490 @@ -210,7 +216,13 @@ def file(self, request, *args, **kwargs): # 
https://www.django-rest-framework.org/api-guide/viewsets/#introspecting-viewset-actions @action(detail=True, url_path="description", url_name="description") def description_(self, request, *args, **kwargs): - return self.download_file(request, Function, "description", "description_address") + return self.download_file( + request, + asset_class=Function, + local_file_class=FunctionFiles, + content_field="description", + address_field="description_address", + ) @action(detail=True) def image(self, request, *args, **kwargs): diff --git a/backend/api/views/model.py b/backend/api/views/model.py index 6a051da5a..c29e78aef 100644 --- a/backend/api/views/model.py +++ b/backend/api/views/model.py @@ -140,4 +140,10 @@ def _check_export_enabled(channel_name): @if_true(gzip.gzip_page, settings.GZIP_MODELS) @action(detail=True) def file(self, request, *args, **kwargs): - return self.download_file(request, Model, "file", "model_address") + return self.download_file( + request, + asset_class=Model, + local_file_class=ModelFiles, + content_field="file", + address_field="model_address", + ) diff --git a/backend/api/views/utils.py b/backend/api/views/utils.py index 912c3c1d8..5a3626daf 100644 --- a/backend/api/views/utils.py +++ b/backend/api/views/utils.py @@ -1,10 +1,13 @@ import os import uuid from typing import Callable +from typing import Type +from typing import TypeVar from wsgiref.util import is_hop_by_hop import django.http from django.conf import settings +from django.db import models from rest_framework import status from rest_framework.authentication import BasicAuthentication from rest_framework.permissions import SAFE_METHODS @@ -26,6 +29,9 @@ HTTP_HEADER_PROXY_ASSET = "Substra-Proxy-Asset" +AssetType = TypeVar("AssetType", bound=models.Model) +LocalFileType = TypeVar("LocalFileType", bound=models.Model) + class ApiResponse(Response): """The Content-Disposition header is used for downloads and web service responses @@ -80,18 +86,33 @@ def check_access(self, channel_name: str, user, asset, is_proxied_request: bool) if not asset.is_public("process") and organization_id not in asset.get_authorized_ids("process"): raise AssetPermissionError() - def download_file(self, request, asset_class, content_field, address_field): - if settings.ISOLATED: - return ApiResponse({"detail": "Asset not available in isolated mode"}, status=status.HTTP_410_GONE) + def get_key(self, request) -> str: lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field key = self.kwargs[lookup_url_kwarg] - channel_name = get_channel_name(request) - - validated_key = validate_key(key) - asset = asset_class.objects.filter(channel=channel_name).get(key=validated_key) + return validate_key(key) + + def get_asset(self, request, key: str, channel_name: str, asset_class: Type[AssetType]) -> AssetType: + asset = asset_class.objects.filter(channel=channel_name).get(key=key) + self.check_access(channel_name, request.user, asset, is_proxied_request(request)) + + return asset + + def download_file( + self, + request, + *, + asset_class: Type[AssetType], + local_file_class: Type[LocalFileType], + content_field: str, + address_field: str, + ): + if settings.ISOLATED: + return ApiResponse({"detail": "Asset not available in isolated mode"}, status=status.HTTP_410_GONE) + key = self.get_key(request) + channel_name = get_channel_name(request) try: - self.check_access(channel_name, request.user, asset, is_proxied_request(request)) + asset = self.get_asset(request, key, channel_name, asset_class) except AssetPermissionError as e: return 
ApiResponse({"detail": str(e)}, status=status.HTTP_403_FORBIDDEN) @@ -99,49 +120,70 @@ def download_file(self, request, asset_class, content_field, address_field): if not url: return ApiResponse({"detail": "Asset not available anymore"}, status=status.HTTP_410_GONE) - if get_owner() == asset.get_owner(): - response = self._get_local_file_response(content_field) - else: - response = self._download_remote_file(channel_name, asset.get_owner(), url) - - return response + return get_file_response( + key=key, + local_file_class=local_file_class, + asset_owner=asset.get_owner(), + content_field=content_field, + channel_name=channel_name, + url=url, + ) - def _get_local_file_response(self, content_field): - obj = self.get_object() - data = getattr(obj, content_field) - if isinstance(data.storage, MinioStorage): - filename = str(obj.key) - else: - filename = os.path.basename(data.path) - data = open(data.path, "rb") +def get_file_response( + *, + local_file_class: Type[LocalFileType], + content_field: str, + key: str, + asset_owner: str, + channel_name: str, + url: str, +) -> django.http.FileResponse: + if get_owner() == asset_owner: + local_file = local_file_class.objects.get(pk=key) + response = _get_local_file_response(local_file, key, content_field) + else: + response = _download_remote_file(channel_name, asset_owner, url) - response = CustomFileResponse( - data, - as_attachment=True, - filename=filename, - ) - return response + return response - def _download_remote_file(self, channel_name: str, owner: str, url: str) -> django.http.FileResponse: - proxy_response = organization_client.streamed_get( - channel=channel_name, - organization_id=owner, - url=url, - headers={HTTP_HEADER_PROXY_ASSET: "True"}, - ) - response = CustomFileResponse( - streaming_content=(chunk for chunk in proxy_response.iter_content(512 * 1024)), - status=proxy_response.status_code, - ) - for header in proxy_response.headers: - # We don't use hop_by_hop headers since they are incompatible - # with WSGI - if not is_hop_by_hop(header): - response[header] = proxy_response.headers.get(header) +def _get_local_file_response(local_file: LocalFileType, key: str, content_field: str): + data = getattr(local_file, content_field) - return response + if isinstance(data.storage, MinioStorage): + filename = key + else: + filename = os.path.basename(data.path) + data = open(data.path, "rb") + + response = CustomFileResponse( + data, + as_attachment=True, + filename=filename, + ) + return response + + +def _download_remote_file(channel_name: str, owner: str, url: str) -> django.http.FileResponse: + proxy_response = organization_client.streamed_get( + channel=channel_name, + organization_id=owner, + url=url, + headers={HTTP_HEADER_PROXY_ASSET: "True"}, + ) + response = CustomFileResponse( + streaming_content=(chunk for chunk in proxy_response.iter_content(512 * 1024)), + status=proxy_response.status_code, + ) + + for header in proxy_response.headers: + # We don't use hop_by_hop headers since they are incompatible + # with WSGI + if not is_hop_by_hop(header): + response[header] = proxy_response.headers.get(header) + + return response def validate_key(key) -> str: diff --git a/backend/builder/tasks/task.py b/backend/builder/tasks/task.py index 77fe7b12c..459f00e99 100644 --- a/backend/builder/tasks/task.py +++ b/backend/builder/tasks/task.py @@ -1,27 +1,15 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING -from typing import Any - -if TYPE_CHECKING: - from billiard.einfo import ExceptionInfo - import structlog -from 
celery import Task from django.conf import settings import orchestrator - -# from substrapp.compute_tasks import errors as compute_task_errors +from substrapp.models import FailedAssetKind from substrapp.orchestrator import get_orchestrator_client - -# from substrapp.utils.errors import store_failure - +from substrapp.tasks.task import FailableTask logger = structlog.get_logger("builder") -class BuildTask(Task): +class BuildTask(FailableTask): autoretry_for = settings.CELERY_TASK_AUTORETRY_FOR max_retries = settings.CELERY_TASK_MAX_RETRIES retry_backoff = settings.CELERY_TASK_RETRY_BACKOFF @@ -31,45 +19,13 @@ class BuildTask(Task): reject_on_worker_lost = True ignore_result = False + asset_type = FailedAssetKind.FAILED_ASSET_FUNCTION + @property def attempt(self) -> int: return self.request.retries + 1 # type: ignore - def on_failure( - self, exc: Exception, task_id: str, args: tuple, kwargs: dict[str, Any], einfo: ExceptionInfo - ) -> None: - logger.error(exc) - logger.error(einfo) - function_key, channel_name = self.get_task_info(args, kwargs) - with get_orchestrator_client(channel_name) as client: - client.update_function_status( - function_key=function_key, action=orchestrator.function_pb2.FUNCTION_ACTION_FAILED - ) - - # def on_failure( - # self, exc: Exception, task_id: str, args: tuple, kwargs: dict[str, Any], einfo: ExceptionInfo - # ) -> None: - # super():on_failure(exc, task_id, args, kwargs, einfo) - # channel_name, function, compute_task_key = self.split_args(args) - - # failure_report = store_failure(exc, compute_task_key) - # error_type = compute_task_errors.get_error_type(exc) - - # with get_orchestrator_client(channel_name) as client: - # # On the backend, only execution errors lead to the creation of compute task failure report instances - # # to store the execution logs. 
- # if failure_report: - # logs_address = { - # "checksum": failure_report.logs_checksum, - # "storage_address": failure_report.logs_address, - # } - # else: - # logs_address = None - - # client.register_failure_report( - # {"compute_task_key": compute_task_key, "error_type": error_type, "logs_address": logs_address} - # ) - + # Celery does not provide unpacked arguments, we are doing it in `get_task_info` def before_start(self, task_id: str, args: tuple, kwargs: dict) -> None: function_key, channel_name = self.get_task_info(args, kwargs) with get_orchestrator_client(channel_name) as client: diff --git a/backend/builder/tasks/tasks_build_image.py b/backend/builder/tasks/tasks_build_image.py index c6f904ed7..dda71d8dc 100644 --- a/backend/builder/tasks/tasks_build_image.py +++ b/backend/builder/tasks/tasks_build_image.py @@ -30,7 +30,7 @@ def build_image(task: BuildTask, function_serialized: str, channel_name: str) -> except BuildRetryError as e: logger.info( "Retrying build", - celery_task_id=function.key, + function_id=function.key, attempt=(task.attempt + 1), max_attempts=(task.max_retries + 1), ) diff --git a/backend/builder/tests/test_task_build_image.py b/backend/builder/tests/test_task_build_image.py index 2ea9a5043..151821d57 100644 --- a/backend/builder/tests/test_task_build_image.py +++ b/backend/builder/tests/test_task_build_image.py @@ -1,6 +1,7 @@ import pytest from builder.exceptions import BuildError +from substrapp.models import FailedAssetKind from substrapp.utils.errors import store_failure @@ -10,8 +11,10 @@ def test_store_failure_build_error(): msg = "Error building image" exc = BuildError(msg) - failure_report = store_failure(exc, compute_task_key) + failure_report = store_failure( + exc, compute_task_key, FailedAssetKind.FAILED_ASSET_FUNCTION, error_type=BuildError.error_type.value + ) failure_report.refresh_from_db() - assert str(failure_report.compute_task_key) == compute_task_key + assert str(failure_report.asset_key) == compute_task_key assert failure_report.logs.read() == str.encode(msg) diff --git a/backend/orchestrator/failure_report_pb2.py b/backend/orchestrator/failure_report_pb2.py index 0fc9bf1ef..b08cb7081 100644 --- a/backend/orchestrator/failure_report_pb2.py +++ b/backend/orchestrator/failure_report_pb2.py @@ -15,7 +15,7 @@ from . 
import common_pb2 as common__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66\x61ilure_report.proto\x12\x0corchestrator\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x0c\x63ommon.proto\"\xc9\x01\n\rFailureReport\x12\x18\n\x10\x63ompute_task_key\x18\x01 \x01(\t\x12+\n\nerror_type\x18\x02 \x01(\x0e\x32\x17.orchestrator.ErrorType\x12/\n\x0clogs_address\x18\x03 \x01(\x0b\x32\x19.orchestrator.Addressable\x12\x31\n\rcreation_date\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\r\n\x05owner\x18\x05 \x01(\t\"\x8a\x01\n\x10NewFailureReport\x12\x18\n\x10\x63ompute_task_key\x18\x01 \x01(\t\x12+\n\nerror_type\x18\x02 \x01(\x0e\x32\x17.orchestrator.ErrorType\x12/\n\x0clogs_address\x18\x03 \x01(\x0b\x32\x19.orchestrator.Addressable\"1\n\x15GetFailureReportParam\x12\x18\n\x10\x63ompute_task_key\x18\x01 \x01(\t*p\n\tErrorType\x12\x1a\n\x16\x45RROR_TYPE_UNSPECIFIED\x10\x00\x12\x14\n\x10\x45RROR_TYPE_BUILD\x10\x01\x12\x18\n\x14\x45RROR_TYPE_EXECUTION\x10\x02\x12\x17\n\x13\x45RROR_TYPE_INTERNAL\x10\x03\x32\xc2\x01\n\x14\x46\x61ilureReportService\x12T\n\x15RegisterFailureReport\x12\x1e.orchestrator.NewFailureReport\x1a\x1b.orchestrator.FailureReport\x12T\n\x10GetFailureReport\x12#.orchestrator.GetFailureReportParam\x1a\x1b.orchestrator.FailureReportB+Z)github.com/substra/orchestrator/lib/assetb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66\x61ilure_report.proto\x12\x0corchestrator\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x0c\x63ommon.proto\"\xf5\x01\n\rFailureReport\x12\x11\n\tasset_key\x18\x01 \x01(\t\x12+\n\nerror_type\x18\x02 \x01(\x0e\x32\x17.orchestrator.ErrorType\x12/\n\x0clogs_address\x18\x03 \x01(\x0b\x32\x19.orchestrator.Addressable\x12\x31\n\rcreation_date\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\r\n\x05owner\x18\x05 \x01(\t\x12\x31\n\nasset_type\x18\x06 \x01(\x0e\x32\x1d.orchestrator.FailedAssetKind\"\xb6\x01\n\x10NewFailureReport\x12\x11\n\tasset_key\x18\x01 \x01(\t\x12+\n\nerror_type\x18\x02 \x01(\x0e\x32\x17.orchestrator.ErrorType\x12/\n\x0clogs_address\x18\x03 \x01(\x0b\x32\x19.orchestrator.Addressable\x12\x31\n\nasset_type\x18\x04 \x01(\x0e\x32\x1d.orchestrator.FailedAssetKind\"*\n\x15GetFailureReportParam\x12\x11\n\tasset_key\x18\x01 \x01(\t*p\n\tErrorType\x12\x1a\n\x16\x45RROR_TYPE_UNSPECIFIED\x10\x00\x12\x14\n\x10\x45RROR_TYPE_BUILD\x10\x01\x12\x18\n\x14\x45RROR_TYPE_EXECUTION\x10\x02\x12\x17\n\x13\x45RROR_TYPE_INTERNAL\x10\x03*e\n\x0f\x46\x61iledAssetKind\x12\x18\n\x14\x46\x41ILED_ASSET_UNKNOWN\x10\x00\x12\x1d\n\x19\x46\x41ILED_ASSET_COMPUTE_TASK\x10\x01\x12\x19\n\x15\x46\x41ILED_ASSET_FUNCTION\x10\x02\x32\xc2\x01\n\x14\x46\x61ilureReportService\x12T\n\x15RegisterFailureReport\x12\x1e.orchestrator.NewFailureReport\x1a\x1b.orchestrator.FailureReport\x12T\n\x10GetFailureReport\x12#.orchestrator.GetFailureReportParam\x1a\x1b.orchestrator.FailureReportB+Z)github.com/substra/orchestrator/lib/assetb\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'failure_report_pb2', globals()) @@ -23,14 +23,16 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b'Z)github.com/substra/orchestrator/lib/asset' - _ERRORTYPE._serialized_start=481 - _ERRORTYPE._serialized_end=593 + _ERRORTYPE._serialized_start=562 + _ERRORTYPE._serialized_end=674 + _FAILEDASSETKIND._serialized_start=676 + _FAILEDASSETKIND._serialized_end=777 _FAILUREREPORT._serialized_start=86 - _FAILUREREPORT._serialized_end=287 - _NEWFAILUREREPORT._serialized_start=290 - 
_NEWFAILUREREPORT._serialized_end=428 - _GETFAILUREREPORTPARAM._serialized_start=430 - _GETFAILUREREPORTPARAM._serialized_end=479 - _FAILUREREPORTSERVICE._serialized_start=596 - _FAILUREREPORTSERVICE._serialized_end=790 + _FAILUREREPORT._serialized_end=331 + _NEWFAILUREREPORT._serialized_start=334 + _NEWFAILUREREPORT._serialized_end=516 + _GETFAILUREREPORTPARAM._serialized_start=518 + _GETFAILUREREPORTPARAM._serialized_end=560 + _FAILUREREPORTSERVICE._serialized_start=780 + _FAILUREREPORTSERVICE._serialized_end=974 # @@protoc_insertion_point(module_scope) diff --git a/backend/orchestrator/failure_report_pb2.pyi b/backend/orchestrator/failure_report_pb2.pyi index 2abced6e5..c6d6c407e 100644 --- a/backend/orchestrator/failure_report_pb2.pyi +++ b/backend/orchestrator/failure_report_pb2.pyi @@ -56,40 +56,62 @@ It is likely to be caused by a fault in the system. It would require the action """ global___ErrorType = ErrorType +class _FailedAssetKind: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _FailedAssetKindEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_FailedAssetKind.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + FAILED_ASSET_UNKNOWN: _FailedAssetKind.ValueType # 0 + FAILED_ASSET_COMPUTE_TASK: _FailedAssetKind.ValueType # 1 + FAILED_ASSET_FUNCTION: _FailedAssetKind.ValueType # 2 + +class FailedAssetKind(_FailedAssetKind, metaclass=_FailedAssetKindEnumTypeWrapper): ... + +FAILED_ASSET_UNKNOWN: FailedAssetKind.ValueType # 0 +FAILED_ASSET_COMPUTE_TASK: FailedAssetKind.ValueType # 1 +FAILED_ASSET_FUNCTION: FailedAssetKind.ValueType # 2 +global___FailedAssetKind = FailedAssetKind + @typing_extensions.final class FailureReport(google.protobuf.message.Message): """FailureReport is used to store information related to a failed ComputeTask.""" DESCRIPTOR: google.protobuf.descriptor.Descriptor - COMPUTE_TASK_KEY_FIELD_NUMBER: builtins.int + ASSET_KEY_FIELD_NUMBER: builtins.int ERROR_TYPE_FIELD_NUMBER: builtins.int LOGS_ADDRESS_FIELD_NUMBER: builtins.int CREATION_DATE_FIELD_NUMBER: builtins.int OWNER_FIELD_NUMBER: builtins.int - compute_task_key: builtins.str + ASSET_TYPE_FIELD_NUMBER: builtins.int + asset_key: builtins.str error_type: global___ErrorType.ValueType @property def logs_address(self) -> common_pb2.Addressable: ... @property def creation_date(self) -> google.protobuf.timestamp_pb2.Timestamp: ... owner: builtins.str - """The owner of a failure report matches the 'worker' field of the associated compute task but can differ from + """In the case of a compute task failure, the owner of a failure report matches the 'worker' field of the associated compute task but can differ from the owner of the compute task. Indeed, a task belonging to some user can be executed on an organization belonging - to another user. The failure report generated will be located on the execution organization and belong to the owner + to another user. + In the case of a function, the owner will be the owner of the function (which builds the function). + The failure report generated will be located on the execution organization and belong to the owner of this organization. 
""" + asset_type: global___FailedAssetKind.ValueType def __init__( self, *, - compute_task_key: builtins.str = ..., + asset_key: builtins.str = ..., error_type: global___ErrorType.ValueType = ..., logs_address: common_pb2.Addressable | None = ..., creation_date: google.protobuf.timestamp_pb2.Timestamp | None = ..., owner: builtins.str = ..., + asset_type: global___FailedAssetKind.ValueType = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["creation_date", b"creation_date", "logs_address", b"logs_address"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["compute_task_key", b"compute_task_key", "creation_date", b"creation_date", "error_type", b"error_type", "logs_address", b"logs_address", "owner", b"owner"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["asset_key", b"asset_key", "asset_type", b"asset_type", "creation_date", b"creation_date", "error_type", b"error_type", "logs_address", b"logs_address", "owner", b"owner"]) -> None: ... global___FailureReport = FailureReport @@ -101,22 +123,25 @@ class NewFailureReport(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - COMPUTE_TASK_KEY_FIELD_NUMBER: builtins.int + ASSET_KEY_FIELD_NUMBER: builtins.int ERROR_TYPE_FIELD_NUMBER: builtins.int LOGS_ADDRESS_FIELD_NUMBER: builtins.int - compute_task_key: builtins.str + ASSET_TYPE_FIELD_NUMBER: builtins.int + asset_key: builtins.str error_type: global___ErrorType.ValueType @property def logs_address(self) -> common_pb2.Addressable: ... + asset_type: global___FailedAssetKind.ValueType def __init__( self, *, - compute_task_key: builtins.str = ..., + asset_key: builtins.str = ..., error_type: global___ErrorType.ValueType = ..., logs_address: common_pb2.Addressable | None = ..., + asset_type: global___FailedAssetKind.ValueType = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["logs_address", b"logs_address"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["compute_task_key", b"compute_task_key", "error_type", b"error_type", "logs_address", b"logs_address"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["asset_key", b"asset_key", "asset_type", b"asset_type", "error_type", b"error_type", "logs_address", b"logs_address"]) -> None: ... global___NewFailureReport = NewFailureReport @@ -126,13 +151,13 @@ class GetFailureReportParam(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - COMPUTE_TASK_KEY_FIELD_NUMBER: builtins.int - compute_task_key: builtins.str + ASSET_KEY_FIELD_NUMBER: builtins.int + asset_key: builtins.str def __init__( self, *, - compute_task_key: builtins.str = ..., + asset_key: builtins.str = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["compute_task_key", b"compute_task_key"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["asset_key", b"asset_key"]) -> None: ... 
global___GetFailureReportParam = GetFailureReportParam diff --git a/backend/substrapp/compute_tasks/errors.py b/backend/substrapp/compute_tasks/errors.py index 610286593..71bd12823 100644 --- a/backend/substrapp/compute_tasks/errors.py +++ b/backend/substrapp/compute_tasks/errors.py @@ -71,10 +71,10 @@ class ExecutionError(_ComputeTaskError, CeleryNoRetryError): def __init__(self, logs: BinaryIO, *args, **kwargs): self.logs = logs - super().__init__(*args, **kwargs) + super().__init__(logs, *args, **kwargs) -def get_error_type(exc: Exception) -> failure_report_pb2.ErrorType: +def get_error_type(exc: Exception) -> failure_report_pb2.ErrorType.ValueType: """From a given exception, return an error type safe to store and to advertise to the user. Args: diff --git a/backend/substrapp/events/reactor.py b/backend/substrapp/events/reactor.py index 045c07093..6ffd803b2 100644 --- a/backend/substrapp/events/reactor.py +++ b/backend/substrapp/events/reactor.py @@ -103,8 +103,8 @@ def on_function_event(payload): } ( # TODO switch to function.model_dump_json() as soon as pydantic is updated to > 2.0 - build_image.si(**building_params).set(queue=builder_queue, task_id=function_key) - | save_image_task.si(**building_params).set(queue=WORKER_QUEUE, task_id=function_key) + build_image.si(**building_params).set(queue=builder_queue) + | save_image_task.si(**building_params).set(queue=WORKER_QUEUE) ).apply_async() else: diff --git a/backend/substrapp/migrations/0006_create_compute_task_failure_model.py b/backend/substrapp/migrations/0006_create_compute_task_failure_model.py index c6a7abe1e..1ce781d19 100644 --- a/backend/substrapp/migrations/0006_create_compute_task_failure_model.py +++ b/backend/substrapp/migrations/0006_create_compute_task_failure_model.py @@ -25,7 +25,7 @@ class Migration(migrations.Migration): models.FileField( max_length=36, storage=django.core.files.storage.FileSystemStorage(), - upload_to=substrapp.models.compute_task_failure_report._upload_to, + upload_to=substrapp.models.asset_failure_report._upload_to, ), ), ("logs_checksum", models.CharField(max_length=64)), diff --git a/backend/substrapp/migrations/0012_alter_algo_description_alter_algo_file_and_more.py b/backend/substrapp/migrations/0012_alter_algo_description_alter_algo_file_and_more.py index 0e891e227..81475f62b 100644 --- a/backend/substrapp/migrations/0012_alter_algo_description_alter_algo_file_and_more.py +++ b/backend/substrapp/migrations/0012_alter_algo_description_alter_algo_file_and_more.py @@ -3,7 +3,7 @@ from django.db import migrations from django.db import models -import substrapp.models.compute_task_failure_report +import substrapp.models.asset_failure_report import substrapp.models.datamanager import substrapp.models.function import substrapp.storages.minio @@ -39,7 +39,7 @@ class Migration(migrations.Migration): field=models.FileField( max_length=36, storage=substrapp.storages.minio.MinioStorage("substra-compute-task-logs"), - upload_to=substrapp.models.compute_task_failure_report._upload_to, + upload_to=substrapp.models.asset_failure_report._upload_to, ), ), migrations.AlterField( diff --git a/backend/substrapp/migrations/0013_alter_algo_description_alter_algo_file_and_more.py b/backend/substrapp/migrations/0013_alter_algo_description_alter_algo_file_and_more.py index a2f59d0eb..df8a7fe7d 100644 --- a/backend/substrapp/migrations/0013_alter_algo_description_alter_algo_file_and_more.py +++ b/backend/substrapp/migrations/0013_alter_algo_description_alter_algo_file_and_more.py @@ -4,7 +4,7 @@ from django.db import 
migrations from django.db import models -import substrapp.models.compute_task_failure_report +import substrapp.models.asset_failure_report import substrapp.models.datamanager import substrapp.models.function @@ -39,7 +39,7 @@ class Migration(migrations.Migration): field=models.FileField( max_length=36, storage=django.core.files.storage.FileSystemStorage(), - upload_to=substrapp.models.compute_task_failure_report._upload_to, + upload_to=substrapp.models.asset_failure_report._upload_to, ), ), migrations.AlterField( diff --git a/backend/substrapp/migrations/0015_alter_computetaskfailurereport_logs.py b/backend/substrapp/migrations/0015_alter_computetaskfailurereport_logs.py index 238959ad8..b53457814 100644 --- a/backend/substrapp/migrations/0015_alter_computetaskfailurereport_logs.py +++ b/backend/substrapp/migrations/0015_alter_computetaskfailurereport_logs.py @@ -3,7 +3,7 @@ from django.db import migrations from django.db import models -import substrapp.models.compute_task_failure_report +import substrapp.models.asset_failure_report import substrapp.storages.minio @@ -19,7 +19,7 @@ class Migration(migrations.Migration): field=models.FileField( max_length=36, storage=substrapp.storages.minio.MinioStorage("substra-compute-task-logs"), - upload_to=substrapp.models.compute_task_failure_report._upload_to, + upload_to=substrapp.models.asset_failure_report._upload_to, ), ), ] diff --git a/backend/substrapp/migrations/0017_alter_computetaskfailurereport_logs_and_more.py b/backend/substrapp/migrations/0017_alter_computetaskfailurereport_logs_and_more.py index a21c77454..48977d179 100644 --- a/backend/substrapp/migrations/0017_alter_computetaskfailurereport_logs_and_more.py +++ b/backend/substrapp/migrations/0017_alter_computetaskfailurereport_logs_and_more.py @@ -4,7 +4,7 @@ from django.db import migrations from django.db import models -import substrapp.models.compute_task_failure_report +import substrapp.models.asset_failure_report import substrapp.models.function @@ -20,7 +20,7 @@ class Migration(migrations.Migration): field=models.FileField( max_length=36, storage=django.core.files.storage.FileSystemStorage(), - upload_to=substrapp.models.compute_task_failure_report._upload_to, + upload_to=substrapp.models.asset_failure_report._upload_to, ), ), migrations.AlterField( diff --git a/backend/substrapp/migrations/0018_rename_computetaskfailurereport_and_more.py b/backend/substrapp/migrations/0018_rename_computetaskfailurereport_and_more.py new file mode 100644 index 000000000..6745ecdb0 --- /dev/null +++ b/backend/substrapp/migrations/0018_rename_computetaskfailurereport_and_more.py @@ -0,0 +1,29 @@ +# Generated by Django 4.2.3 on 2023-08-30 15:07 + +from django.db import migrations +from django.db import models + + +class Migration(migrations.Migration): + dependencies = [ + ("substrapp", "0017_alter_computetaskfailurereport_logs_and_more"), + ] + + operations = [ + migrations.RenameModel("ComputeTaskFailureReport", "AssetFailureReport"), + migrations.RenameField("AssetFailureReport", "compute_task_key", "asset_key"), + migrations.AddField( + model_name="assetfailurereport", + name="asset_type", + field=models.CharField( + choices=[ + ("FAILED_ASSET_UNKNOWN", "Failed Asset Unknown"), + ("FAILED_ASSET_COMPUTE_TASK", "Failed Asset Compute Task"), + ("FAILED_ASSET_FUNCTION", "Failed Asset Function"), + ], + default="FAILED_ASSET_UNKNOWN", + max_length=100, + ), + preserve_default=False, + ), + ] diff --git a/backend/substrapp/migrations/0019_alter_assetfailurereport_logs_and_more.py 
b/backend/substrapp/migrations/0019_alter_assetfailurereport_logs_and_more.py new file mode 100644 index 000000000..68c284b9d --- /dev/null +++ b/backend/substrapp/migrations/0019_alter_assetfailurereport_logs_and_more.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2.3 on 2023-09-06 15:58 + +import django.core.files.storage +from django.db import migrations +from django.db import models + +import substrapp.models.asset_failure_report +import substrapp.models.function + + +class Migration(migrations.Migration): + dependencies = [ + ("substrapp", "0018_rename_computetaskfailurereport_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="assetfailurereport", + name="logs", + field=models.FileField( + max_length=36, + storage=django.core.files.storage.FileSystemStorage(), + upload_to=substrapp.models.asset_failure_report._upload_to, + ), + ), + ] diff --git a/backend/substrapp/models/__init__.py b/backend/substrapp/models/__init__.py index 793035cba..28bc2de5b 100644 --- a/backend/substrapp/models/__init__.py +++ b/backend/substrapp/models/__init__.py @@ -1,4 +1,5 @@ -from .compute_task_failure_report import ComputeTaskFailureReport +from .asset_failure_report import AssetFailureReport +from .asset_failure_report import FailedAssetKind from .computeplan_worker_mapping import ComputePlanWorkerMapping from .datamanager import DataManager from .datasample import DataSample @@ -11,11 +12,12 @@ __all__ = [ "DataSample", "DataManager", + "FailedAssetKind", "Function", "FunctionImage", "Model", "ComputePlanWorkerMapping", "ImageEntrypoint", - "ComputeTaskFailureReport", + "AssetFailureReport", "WorkerLastEvent", ] diff --git a/backend/substrapp/models/compute_task_failure_report.py b/backend/substrapp/models/asset_failure_report.py similarity index 57% rename from backend/substrapp/models/compute_task_failure_report.py rename to backend/substrapp/models/asset_failure_report.py index ab5dcca3b..0dbd0f45e 100644 --- a/backend/substrapp/models/compute_task_failure_report.py +++ b/backend/substrapp/models/asset_failure_report.py @@ -5,6 +5,8 @@ from django.conf import settings from django.db import models +from orchestrator import failure_report_pb2 + LOGS_BASE_PATH: Final[str] = "logs" LOGS_FILE_PATH: Final[str] = "file" @@ -12,14 +14,20 @@ _SHA256_STRING_REPR_LENGTH: Final[int] = 256 // 4 -def _upload_to(instance: "ComputeTaskFailureReport", _filename: str) -> str: - return str(instance.compute_task_key) +def _upload_to(instance: "AssetFailureReport", _filename: str) -> str: + return str(instance.asset_key) + + +FailedAssetKind = models.TextChoices( "FailedAssetKind", [(status_name, status_name) for status_name in failure_report_pb2.FailedAssetKind.keys()] ) -class ComputeTaskFailureReport(models.Model): +class AssetFailureReport(models.Model): """Store information relative to a compute task.""" - compute_task_key = models.UUIDField(primary_key=True, editable=False) + asset_key = models.UUIDField(primary_key=True, editable=False) + asset_type = models.CharField(max_length=100, choices=FailedAssetKind.choices) logs = models.FileField( storage=settings.COMPUTE_TASK_LOGS_STORAGE, max_length=_UUID_STRING_REPR_LENGTH, upload_to=_upload_to ) @@ -28,9 +36,9 @@ class ComputeTaskFailureReport(models.Model): @property def key(self) -> uuid.UUID: - return self.compute_task_key + return self.asset_key @property def logs_address(self) -> str: - logs_path = f"{LOGS_BASE_PATH}/{self.compute_task_key}/{LOGS_FILE_PATH}/" + logs_path = f"{LOGS_BASE_PATH}/{self.asset_key}/{LOGS_FILE_PATH}/" return urllib.parse.urljoin(settings.DEFAULT_DOMAIN, logs_path)
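Note that the `FailedAssetKind` choices above are built directly from the protobuf enum keys, so the strings stored in the `asset_type` column always match the names carried in orchestrator payloads. A quick sketch of that round-trip, illustrative only: the unsaved report instance and its random key below are hypothetical, not part of this change.

    import uuid

    from orchestrator import failure_report_pb2
    from substrapp.models import AssetFailureReport
    from substrapp.models import FailedAssetKind

    # The generated EnumTypeWrapper converts between enum names and integer values.
    kind_value = failure_report_pb2.FailedAssetKind.Value("FAILED_ASSET_FUNCTION")
    assert failure_report_pb2.FailedAssetKind.Name(kind_value) == "FAILED_ASSET_FUNCTION"

    # The model stores the enum *name* as a CharField choice, so syncing an event
    # only needs the string sent by the orchestrator, with no extra mapping table.
    report = AssetFailureReport(
        asset_key=uuid.uuid4(),
        asset_type=FailedAssetKind.FAILED_ASSET_FUNCTION,
    )
    assert failure_report_pb2.FailedAssetKind.Value(report.asset_type) == kind_value

This mirrors what `_on_create_failure_report` in `backend/api/events/sync.py` does when it maps the event's `asset_type` string back to the protobuf value.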
diff --git a/backend/substrapp/tasks/task.py b/backend/substrapp/tasks/task.py new file mode 100644 index 000000000..0fab4e653 --- /dev/null +++ b/backend/substrapp/tasks/task.py @@ -0,0 +1,101 @@ +""" +This file contains the main logic for executing a compute task: + +- Create execution context +- Populate asset buffer +- Load assets from the asset buffer +- **Execute the compute task** +- Save the models/results +- Tear down the context + +We also handle the retry logic here. +""" +import enum +import pickle # nosec B403 +from typing import Any + +import structlog +from billiard.einfo import ExceptionInfo +from celery import Task +from django.conf import settings + +import orchestrator +from substrapp.compute_tasks.compute_pod import delete_compute_plan_pods +from substrapp.models import FailedAssetKind +from substrapp.task_routing import WORKER_QUEUE +from substrapp.tasks.tasks_asset_failure_report import store_asset_failure_report + +logger = structlog.get_logger(__name__) + + +class FailableTask(Task): + asset_type: FailedAssetKind + + # Celery does not provide unpacked arguments, we are doing it in `get_task_info` + def on_failure( + self, exc: Exception, task_id: str, args: tuple, kwargs: dict[str, Any], einfo: ExceptionInfo + ) -> None: + asset_key, channel_name = self.get_task_info(args, kwargs) + exception_pickled = pickle.dumps(exc) + store_asset_failure_report.apply_async( + args, + { + "asset_key": asset_key, + "asset_type": self.asset_type, + "channel_name": channel_name, + "exception_pickled": exception_pickled, + }, + queue=WORKER_QUEUE, + ) + + def get_task_info(self, args: tuple, kwargs: dict) -> tuple[str, str]: + raise NotImplementedError() + + +class ComputeTaskSteps(enum.Enum): + BUILD_IMAGE = "build_image" + PREPARE_INPUTS = "prepare_inputs" + TASK_EXECUTION = "task_execution" + SAVE_OUTPUTS = "save_outputs" + + +class ComputeTask(FailableTask): + autoretry_for = settings.CELERY_TASK_AUTORETRY_FOR + max_retries = settings.CELERY_TASK_MAX_RETRIES + retry_backoff = settings.CELERY_TASK_RETRY_BACKOFF + retry_backoff_max = settings.CELERY_TASK_RETRY_BACKOFF_MAX + retry_jitter = settings.CELERY_TASK_RETRY_JITTER + + asset_type = FailedAssetKind.FAILED_ASSET_COMPUTE_TASK + + @property + def attempt(self) -> int: + return self.request.retries + 1 # type: ignore + + # Celery does not provide unpacked arguments + def on_success(self, retval: dict[str, Any], task_id: str, args: tuple, kwargs: dict[str, Any]) -> None: + from django.db import close_old_connections + + close_old_connections() + + # Celery does not provide unpacked arguments, we are doing it in `split_args` + def on_retry(self, exc: Exception, task_id: str, args: tuple, kwargs: dict[str, Any], einfo: ExceptionInfo) -> None: + _, task = self.split_args(args) + # delete compute pod to reset hardware resources + delete_compute_plan_pods(task.compute_plan_key) + logger.info( + "Retrying task", + celery_task_id=task_id, + attempt=(self.attempt + 1), + max_attempts=(settings.CELERY_TASK_MAX_RETRIES + 1), + ) + + def split_args(self, celery_args: tuple) -> tuple[str, orchestrator.ComputeTask]: + channel_name = celery_args[0] + task = orchestrator.ComputeTask.parse_raw(celery_args[1]) + return channel_name, task + + def get_task_info(self, args: tuple, kwargs: dict) -> tuple[str, str]: + channel_name, task = self.split_args(args) + + return task.key, channel_name
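The contract `FailableTask` expects from its subclasses is deliberately small: declare which `asset_type` the failure report targets and implement `get_task_info` so the base `on_failure` hook can recover the asset key and channel name from the raw Celery arguments. A minimal hypothetical subclass, assuming a task enqueued with `function_key` and `channel_name` keyword arguments (names chosen for illustration, not taken from this diff):

    from substrapp.models import FailedAssetKind
    from substrapp.tasks.task import FailableTask


    class BuildSomethingTask(FailableTask):
        """On failure, the report is filed against the function being built."""

        asset_type = FailedAssetKind.FAILED_ASSET_FUNCTION

        # Celery hands the hooks raw args/kwargs; unpack them the same way
        # the task was enqueued.
        def get_task_info(self, args: tuple, kwargs: dict) -> tuple[str, str]:
            return kwargs["function_key"], kwargs["channel_name"]

Such a class is meant to be used as the `base=` of a Celery task declaration; when the task fails, `on_failure` pickles the exception and enqueues `store_asset_failure_report` on the worker queue, which is defined in the next file.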
diff --git a/backend/substrapp/tasks/tasks_asset_failure_report.py b/backend/substrapp/tasks/tasks_asset_failure_report.py new file mode 100644 index 000000000..bcac56e50 --- /dev/null +++ b/backend/substrapp/tasks/tasks_asset_failure_report.py @@ -0,0 +1,68 @@ +import pickle # nosec B403 - internal to the worker + +import structlog +from celery import Task +from django.conf import settings + +from backend.celery import app +from substrapp.compute_tasks import errors as compute_task_errors +from substrapp.models import FailedAssetKind +from substrapp.orchestrator import get_orchestrator_client +from substrapp.utils.errors import store_failure + +REGISTRY = settings.REGISTRY +REGISTRY_SCHEME = settings.REGISTRY_SCHEME +SUBTUPLE_TMP_DIR = settings.SUBTUPLE_TMP_DIR + +logger = structlog.get_logger("worker") + + +class StoreAssetFailureReportTask(Task): + max_retries = 0 + reject_on_worker_lost = True + ignore_result = False + + @property + def attempt(self) -> int: + return self.request.retries + 1 # type: ignore + + def get_task_info(self, args: tuple, kwargs: dict) -> tuple[str, str, str]: + asset_key = kwargs["asset_key"] + asset_type = kwargs["asset_type"] + channel_name = kwargs["channel_name"] + return asset_key, asset_type, channel_name + + +@app.task( + bind=True, + acks_late=True, + reject_on_worker_lost=True, + ignore_result=False, + base=StoreAssetFailureReportTask, +) +def store_asset_failure_report( + task: StoreAssetFailureReportTask, *, asset_key: str, asset_type: str, channel_name: str, exception_pickled: bytes +) -> None: + exception = pickle.loads(exception_pickled) # nosec B301 + + if asset_type == FailedAssetKind.FAILED_ASSET_FUNCTION: + error_type = compute_task_errors.ComputeTaskErrorType.BUILD_ERROR.value + else: + error_type = compute_task_errors.get_error_type(exception) + + failure_report = store_failure(exception, asset_key, asset_type, error_type) + + with get_orchestrator_client(channel_name) as client: + # On the backend, only building and execution errors lead to the creation of compute task failure + # report instances to store the execution logs.
+        if failure_report:
+            logs_address = {
+                "checksum": failure_report.logs_checksum,
+                "storage_address": failure_report.logs_address,
+            }
+        else:
+            logs_address = None
+
+        client.register_failure_report(
+            {"asset_key": asset_key, "error_type": error_type, "asset_type": asset_type, "logs_address": logs_address}
+        )
diff --git a/backend/substrapp/tasks/tasks_compute_task.py b/backend/substrapp/tasks/tasks_compute_task.py
index fc269b018..8c8a237a2 100644
--- a/backend/substrapp/tasks/tasks_compute_task.py
+++ b/backend/substrapp/tasks/tasks_compute_task.py
@@ -16,19 +16,14 @@
 import enum
 import errno
 import os
-from typing import TYPE_CHECKING
 from typing import Any
 
 import celery.exceptions
 import structlog
-from celery import Task
 from celery.result import AsyncResult
 from django.conf import settings
 from rest_framework import status
 
-if TYPE_CHECKING:
-    from billiard.einfo import ExceptionInfo
-
 import orchestrator
 from backend.celery import app
 from substrapp.clients import organization as organization_client
@@ -40,7 +35,6 @@
 from substrapp.compute_tasks.asset_buffer import clear_assets_buffer
 from substrapp.compute_tasks.asset_buffer import init_asset_buffer
 from substrapp.compute_tasks.chainkeys import prepare_chainkeys_dir
-from substrapp.compute_tasks.compute_pod import delete_compute_plan_pods
 from substrapp.compute_tasks.context import Context
 from substrapp.compute_tasks.datastore import Datastore
 from substrapp.compute_tasks.datastore import get_datastore
@@ -57,11 +51,11 @@
 from substrapp.exceptions import OrganizationHttpError
 from substrapp.lock_local import lock_resource
 from substrapp.orchestrator import get_orchestrator_client
+from substrapp.tasks.task import ComputeTask
 from substrapp.utils import Timer
 from substrapp.utils import get_owner
 from substrapp.utils import list_dir
 from substrapp.utils import retry
-from substrapp.utils.errors import store_failure
 from substrapp.utils.url import TASK_PROFILING_BASE_URL
 from substrapp.utils.url import get_task_profiling_detail_url
 from substrapp.utils.url import get_task_profiling_steps_base_url
@@ -77,67 +71,6 @@ class ComputeTaskSteps(enum.Enum):
     SAVE_OUTPUTS = "save_outputs"
 
 
-class ComputeTask(Task):
-    autoretry_for = settings.CELERY_TASK_AUTORETRY_FOR
-    max_retries = settings.CELERY_TASK_MAX_RETRIES
-    retry_backoff = settings.CELERY_TASK_RETRY_BACKOFF
-    retry_backoff_max = settings.CELERY_TASK_RETRY_BACKOFF_MAX
-    retry_jitter = settings.CELERY_TASK_RETRY_JITTER
-
-    @property
-    def attempt(self) -> int:
-        return self.request.retries + 1  # type: ignore
-
-    def on_success(self, retval: dict[str, Any], task_id: str, args: tuple, kwargs: dict[str, Any]) -> None:
-        from django.db import close_old_connections
-
-        close_old_connections()
-
-    def on_retry(self, exc: Exception, task_id: str, args: tuple, kwargs: dict[str, Any], einfo: ExceptionInfo) -> None:
-        _, task = self.split_args(args)
-        # delete compute pod to reset hardware ressources
-        delete_compute_plan_pods(task.compute_plan_key)
-        logger.info(
-            "Retrying task",
-            celery_task_id=task_id,
-            attempt=(self.attempt + 1),
-            max_attempts=(settings.CELERY_TASK_MAX_RETRIES + 1),
-        )
-
-    def on_failure(
-        self, exc: Exception, task_id: str, args: tuple, kwargs: dict[str, Any], einfo: ExceptionInfo
-    ) -> None:
-        from django.db import close_old_connections
-
-        close_old_connections()
-
-        channel_name, task = self.split_args(args)
-        compute_task_key = task.key
-
-        failure_report = store_failure(exc, compute_task_key)
-        error_type = compute_task_errors.get_error_type(exc)
-
-        with get_orchestrator_client(channel_name) as client:
-            # On the backend, only execution errors lead to the creation of compute task failure report instances
-            # to store the execution logs.
-            if failure_report:
-                logs_address = {
-                    "checksum": failure_report.logs_checksum,
-                    "storage_address": failure_report.logs_address,
-                }
-            else:
-                logs_address = None
-
-            client.register_failure_report(
-                {"compute_task_key": compute_task_key, "error_type": error_type, "logs_address": logs_address}
-            )
-
-    def split_args(self, celery_args: tuple) -> tuple[str, orchestrator.ComputeTask]:
-        channel_name = celery_args[0]
-        task = orchestrator.ComputeTask.parse_raw(celery_args[1])
-        return channel_name, task
-
-
 def queue_compute_task(channel_name: str, task: orchestrator.ComputeTask) -> None:
     from substrapp.task_routing import get_worker_queue
 
@@ -163,7 +96,10 @@ def queue_compute_task(channel_name: str, task: orchestrator.ComputeTask) -> Non
         worker_queue=worker_queue,
     )
 
-    compute_task.apply_async((channel_name, task, task.compute_plan_key), queue=worker_queue, task_id=task.key)
+    compute_task.apply_async(
+        (channel_name, task, task.compute_plan_key),
+        queue=worker_queue,
+    )
 
 
 @app.task(
diff --git a/backend/substrapp/tasks/tasks_save_image.py b/backend/substrapp/tasks/tasks_save_image.py
index feb5a0455..23c8672a1 100644
--- a/backend/substrapp/tasks/tasks_save_image.py
+++ b/backend/substrapp/tasks/tasks_save_image.py
@@ -3,15 +3,9 @@
 import os
 import pathlib
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING
 from typing import Any
 
 import structlog
-
-if TYPE_CHECKING:
-    from billiard.einfo import ExceptionInfo
-
-from celery import Task
 from django.conf import settings
 from django.core.files import File
 
@@ -20,9 +14,10 @@
 from image_transfer import make_payload
 from substrapp.compute_tasks import utils
 from substrapp.docker_registry import USER_IMAGE_REPOSITORY
+from substrapp.models import FailedAssetKind
 from substrapp.models import FunctionImage
 from substrapp.orchestrator import get_orchestrator_client
-from substrapp.tasks.tasks_compute_task import ComputeTask
+from substrapp.tasks.task import FailableTask
 
 REGISTRY = settings.REGISTRY
 REGISTRY_SCHEME = settings.REGISTRY_SCHEME
@@ -31,7 +26,7 @@
 logger = structlog.get_logger("worker")
 
 
-class SaveImageTask(Task):
+class SaveImageTask(FailableTask):
     autoretry_for = settings.CELERY_TASK_AUTORETRY_FOR
     max_retries = settings.CELERY_TASK_MAX_RETRIES
     retry_backoff = settings.CELERY_TASK_RETRY_BACKOFF
@@ -41,27 +36,19 @@ class SaveImageTask(Task):
     reject_on_worker_lost = True
     ignore_result = False
 
+    asset_type = FailedAssetKind.FAILED_ASSET_FUNCTION
+
     @property
     def attempt(self) -> int:
         return self.request.retries + 1  # type: ignore
 
-    def on_failure(
-        self, exc: Exception, task_id: str, args: tuple, kwargs: dict[str, Any], einfo: ExceptionInfo
-    ) -> None:
-        logger.error(exc)
-        logger.error(einfo)
-        function_key, channel_name = self.get_task_info(args, kwargs)
-        with get_orchestrator_client(channel_name) as client:
-            client.update_function_status(
-                function_key=function_key, action=orchestrator.function_pb2.FUNCTION_ACTION_FAILED
-            )
-
     # Returns (function key, channel)
     def get_task_info(self, args: tuple, kwargs: dict) -> tuple[str, str]:
         function = orchestrator.Function.parse_raw(kwargs["function_serialized"])
         channel_name = kwargs["channel_name"]
         return function.key, channel_name
 
+    # Celery does not pass unpacked arguments to these hooks, so we unpack them in `get_task_info`
     def on_success(self, retval: dict[str, Any], task_id: str, args: tuple, kwargs: dict[str, Any]) -> None:
         function_key, channel_name = self.get_task_info(args, kwargs)
         with get_orchestrator_client(channel_name) as client:
@@ -80,7 +67,7 @@
 # Ack late and reject on worker lost allows us to
 # see http://docs.celeryproject.org/en/latest/userguide/configuration.html#task-reject-on-worker-lost
 # and https://github.com/celery/celery/issues/5106
-def save_image_task(task: ComputeTask, function_serialized: str, channel_name: str) -> tuple[str, str]:
+def save_image_task(task: SaveImageTask, function_serialized: str, channel_name: str) -> tuple[str, str]:
     logger.info("Starting save_image_task")
     logger.info(f"Parameters: function_serialized {function_serialized}, " f"channel_name {channel_name}")
     # create serialized image
diff --git a/backend/substrapp/tests/tasks/test_compute_task.py b/backend/substrapp/tests/tasks/test_compute_task.py
index 8ed556294..aee99f9f6 100644
--- a/backend/substrapp/tests/tasks/test_compute_task.py
+++ b/backend/substrapp/tests/tasks/test_compute_task.py
@@ -1,9 +1,7 @@
 import datetime
 import errno
-import io
 import tempfile
 from functools import wraps
-from typing import Type
 from unittest.mock import MagicMock
 
 import pytest
@@ -19,7 +17,6 @@
 from substrapp.exceptions import OrganizationHttpError
 from substrapp.tasks import tasks_compute_task
 from substrapp.tasks.tasks_compute_task import compute_task
-from substrapp.utils.errors import store_failure
 
 CHANNEL = "mychannel"
 
@@ -185,24 +182,6 @@
     assert mock_retry.call_count == 2
 
 
-@pytest.mark.django_db
-@pytest.mark.parametrize("logs", [b"", b"Hello, World!"])
-def test_store_failure_execution_error(logs: bytes):
-    compute_task_key = "42ff54eb-f4de-43b2-a1a0-a9f4c5f4737f"
-    exc = errors.ExecutionError(logs=io.BytesIO(logs))
-
-    failure_report = store_failure(exc, compute_task_key)
-    failure_report.refresh_from_db()
-
-    assert str(failure_report.compute_task_key) == compute_task_key
-    assert failure_report.logs.read() == logs
-
-
-@pytest.mark.parametrize("exc_class", [Exception])
-def test_store_failure_ignored_exception(exc_class: Type[Exception]):
-    assert store_failure(exc_class(), "uuid") is None
-
-
 @pytest.mark.django_db
 def test_send_profiling_event(mock_retry: MagicMock, mocker: MockerFixture):
     mock_post = mocker.patch("substrapp.clients.organization.post")
diff --git a/backend/substrapp/tests/tasks/test_store_asset_failure_report.py b/backend/substrapp/tests/tasks/test_store_asset_failure_report.py
new file mode 100644
index 000000000..da897dce8
--- /dev/null
+++ b/backend/substrapp/tests/tasks/test_store_asset_failure_report.py
@@ -0,0 +1,69 @@
+import io
+import pickle
+from typing import Type
+
+import pytest
+from pytest_mock import MockerFixture
+
+from substrapp.compute_tasks import errors
+from substrapp.compute_tasks.errors import ComputeTaskErrorType
+from substrapp.models import FailedAssetKind
+from substrapp.tasks.tasks_asset_failure_report import store_asset_failure_report
+from substrapp.utils.errors import store_failure
+
+CHANNEL = "mychannel"
+
+
+@pytest.fixture
+def mock_orchestrator_client(mocker: MockerFixture):
+    return mocker.patch("substrapp.tasks.tasks_asset_failure_report.get_orchestrator_client")
+
+
+@pytest.mark.django_db
+def test_store_asset_failure_report_success(mock_orchestrator_client: MockerFixture):
+    exc = errors.ExecutionError(io.BytesIO(b"logs"))
+    exception_pickled = pickle.dumps(exc)
+    store_asset_failure_report(
+        asset_key="e21f6352-75c1-4b79-9a00-1f547697ef25",
+        asset_type=FailedAssetKind.FAILED_ASSET_COMPUTE_TASK,
+        channel_name=CHANNEL,
+        exception_pickled=exception_pickled,
+    )
+
+
+def test_store_asset_failure_report_ignored(mock_orchestrator_client):
+    exception_pickled = pickle.dumps(Exception())
+    store_asset_failure_report(
+        asset_key="750836e4-0def-465a-8397-57c49ebd38bf",
+        asset_type=FailedAssetKind.FAILED_ASSET_COMPUTE_TASK,
+        channel_name=CHANNEL,
+        exception_pickled=exception_pickled,
+    )
+
+
+@pytest.mark.django_db
+@pytest.mark.parametrize("logs", [b"", b"Hello, World!"])
+def test_store_failure_execution_error(logs: bytes):
+    compute_task_key = "42ff54eb-f4de-43b2-a1a0-a9f4c5f4737f"
+    exc = errors.ExecutionError(logs=io.BytesIO(logs))
+
+    failure_report = store_failure(
+        exc,
+        compute_task_key,
+        FailedAssetKind.FAILED_ASSET_COMPUTE_TASK,
+        error_type=ComputeTaskErrorType.EXECUTION_ERROR.value,
+    )
+    failure_report.refresh_from_db()
+
+    assert str(failure_report.asset_key) == compute_task_key
+    assert failure_report.logs.read() == logs
+
+
+@pytest.mark.parametrize("exc_class", [Exception])
+def test_store_failure_ignored_exception(exc_class: Type[Exception]):
+    assert (
+        store_failure(
+            exc_class(), "uuid", FailedAssetKind.FAILED_ASSET_COMPUTE_TASK, ComputeTaskErrorType.INTERNAL_ERROR.value
+        )
+        is None
+    )
diff --git a/backend/substrapp/utils/errors.py b/backend/substrapp/utils/errors.py
index 730e40a6b..013406f66 100644
--- a/backend/substrapp/utils/errors.py
+++ b/backend/substrapp/utils/errors.py
@@ -2,27 +2,31 @@
 from django.core import files
 
-from builder import exceptions as builder_errors
+from orchestrator import failure_report_pb2
 from substrapp import models
 from substrapp import utils
-from substrapp.compute_tasks import errors as compute_task_errors
 
 
-def store_failure(exc: Exception, compute_task_key: str) -> Optional[models.ComputeTaskFailureReport]:
+def store_failure(
+    exception: Exception,
+    asset_key: str,
+    asset_type: models.FailedAssetKind,
+    error_type: failure_report_pb2.ErrorType.ValueType,
+) -> Optional[models.AssetFailureReport]:
     """If the provided exception is a `BuildError` or an `ExecutionError`, store its logs in the Django storage and
     in the database. Otherwise, do nothing.
 
     Returns:
-        An instance of `models.ComputeTaskFailureReport` storing the error logs or None if the provided exception is
+        An instance of `models.AssetFailureReport` storing the error logs or None if the provided exception is
         neither a `BuildError` nor an `ExecutionError`.
     """
-    if not isinstance(exc, (compute_task_errors.ExecutionError, builder_errors.BuildError)):
+    if error_type not in [failure_report_pb2.ERROR_TYPE_BUILD, failure_report_pb2.ERROR_TYPE_EXECUTION]:
         return None
 
-    file = files.File(exc.logs)
-    failure_report = models.ComputeTaskFailureReport(
-        compute_task_key=compute_task_key, logs_checksum=utils.get_hash(file)
+    file = files.File(exception.logs)
+    failure_report = models.AssetFailureReport(
+        asset_key=asset_key, asset_type=asset_type, logs_checksum=utils.get_hash(file)
    )
-    failure_report.logs.save(name=compute_task_key, content=file, save=True)
+    failure_report.logs.save(name=asset_key, content=file, save=True)
     return failure_report