diff --git a/CHANGELOG.md b/CHANGELOG.md index d41c572cc..d9d6b451a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Celery task `FailableTask` that contains the logic to store the failure report, that can be re-used in different assets. ([#727](https://github.com/Substra/substra-backend/pull/727)) - Add `FunctionStatus` enum ([#714](https://github.com/Substra/orchestrator/pull/714)) - BREAKING: Add `status` on `api.Function` (type `FunctionStatus`) ([#714](https://github.com/Substra/substra-backend/pull/714)) +- Tests to ensure images are built in order of submission (including retries) ([#740](https://github.com/Substra/substra-backend/pull/740)) ### Changed diff --git a/backend/backend/settings/test.py b/backend/backend/settings/test.py index 29d3905a8..166f76f10 100644 --- a/backend/backend/settings/test.py +++ b/backend/backend/settings/test.py @@ -31,3 +31,7 @@ MSP_ID = "testOrgMSP" CHANNELS = {"mychannel": {"model_export_enabled": False}} + +CELERY_BROKER_URL = "memory://" +CELERY_RESULT_BACKEND = "cache+memory://" +CELERY_TASK_ALWAYS_EAGER = False diff --git a/backend/builder/tasks/task.py b/backend/builder/tasks/task.py index 459f00e99..50f59ddb3 100644 --- a/backend/builder/tasks/task.py +++ b/backend/builder/tasks/task.py @@ -34,6 +34,6 @@ def before_start(self, task_id: str, args: tuple, kwargs: dict) -> None: ) def get_task_info(self, args: tuple, kwargs: dict) -> tuple[str, str]: - function = orchestrator.Function.parse_raw(kwargs["function_serialized"]) + function = orchestrator.Function.model_validate_json(kwargs["function_serialized"]) channel_name = kwargs["channel_name"] return function.key, channel_name diff --git a/backend/builder/tasks/tasks_build_image.py b/backend/builder/tasks/tasks_build_image.py index dda71d8dc..d335fe1c4 100644 --- a/backend/builder/tasks/tasks_build_image.py +++ b/backend/builder/tasks/tasks_build_image.py @@ -20,7 +20,7 
@@ # see http://docs.celeryproject.org/en/latest/userguide/configuration.html#task-reject-on-worker-lost # and https://github.com/celery/celery/issues/5106 def build_image(task: BuildTask, function_serialized: str, channel_name: str) -> None: - function = orchestrator.Function.parse_raw(function_serialized) + function = orchestrator.Function.model_validate_json(function_serialized) attempt = 0 while attempt <= task.max_retries: diff --git a/backend/builder/tests/test_task_build_image.py b/backend/builder/tests/test_task_build_image.py index 151821d57..bb868d196 100644 --- a/backend/builder/tests/test_task_build_image.py +++ b/backend/builder/tests/test_task_build_image.py @@ -1,9 +1,17 @@ +import time + +import celery import pytest +import orchestrator.mock as orc_mock from builder.exceptions import BuildError +from builder.exceptions import BuildRetryError +from builder.tasks.tasks_build_image import build_image from substrapp.models import FailedAssetKind from substrapp.utils.errors import store_failure +CHANNEL = "mychannel" + @pytest.mark.django_db def test_store_failure_build_error(): @@ -18,3 +26,60 @@ def test_store_failure_build_error(): assert str(failure_report.asset_key) == compute_task_key assert failure_report.logs.read() == str.encode(msg) + + +@pytest.mark.parametrize("execution_number", range(10)) +def test_order_building_success(celery_app, celery_worker, mocker, execution_number): + function_1 = orc_mock.FunctionFactory() + function_2 = orc_mock.FunctionFactory() + + # BuildTask `before_start` uses this client to change the status, which would lead to `OrcError` + mocker.patch("builder.tasks.task.get_orchestrator_client") + mocker.patch("builder.tasks.tasks_build_image.build_image_if_missing", side_effect=lambda x, y: time.sleep(0.5)) + + result_1 = build_image.apply_async( + kwargs={"function_serialized": function_1.model_dump_json(), "channel_name": CHANNEL} + ) + result_2 = build_image.apply_async( + kwargs={"function_serialized": 
function_2.model_dump_json(), "channel_name": CHANNEL} + ) + # `get` blocks until the first task has completed + result_1.get() + + assert result_1.state == celery.states.SUCCESS + assert result_2.state == "WAITING" + + +@pytest.mark.parametrize("execution_number", range(10)) +def test_order_building_retry(celery_app, celery_worker, mocker, execution_number): + function_retry = orc_mock.FunctionFactory() + function_other = orc_mock.FunctionFactory() + + # Raise a retriable error only on the first build attempt of function_retry + def side_effect_creator(): + already_raised = False + + def side_effect(*args, **kwargs): + nonlocal already_raised + time.sleep(0.5) + key = args[1].key + if not already_raised and function_retry.key == key: + already_raised = True + raise BuildRetryError("random retriable error") + + return side_effect + + # BuildTask `before_start` uses this client to change the status, which would lead to `OrcError` + mocker.patch("builder.tasks.task.get_orchestrator_client") + mocker.patch("builder.tasks.tasks_build_image.build_image_if_missing", side_effect=side_effect_creator()) + + result_retry = build_image.apply_async( + kwargs={"function_serialized": function_retry.model_dump_json(), "channel_name": CHANNEL} + ) + result_other = build_image.apply_async( + kwargs={"function_serialized": function_other.model_dump_json(), "channel_name": CHANNEL} + ) + + result_retry.get() + assert result_retry.state == celery.states.SUCCESS + assert result_other.state == "WAITING" diff --git a/backend/dev-requirements.txt b/backend/dev-requirements.txt index eac575e7a..877c2e474 100644 --- a/backend/dev-requirements.txt +++ b/backend/dev-requirements.txt @@ -21,4 +21,5 @@ mypy==1.4.1 djangorestframework-stubs==1.8.0 django-stubs==1.14.0 celery-types==0.14.0 -docker==6.1.3 \ No newline at end of file +docker==6.1.3 +celery[pytest] \ No newline at end of file diff --git a/backend/substrapp/tasks/task.py b/backend/substrapp/tasks/task.py index 9c1915192..60db3d507 100644 --- a/backend/substrapp/tasks/task.py +++ 
b/backend/substrapp/tasks/task.py @@ -91,7 +91,7 @@ def on_retry(self, exc: Exception, task_id: str, args: tuple, kwargs: dict[str, def split_args(self, celery_args: tuple) -> tuple[str, orchestrator.ComputeTask]: channel_name = celery_args[0] - task = orchestrator.ComputeTask.parse_raw(celery_args[1]) + task = orchestrator.ComputeTask.model_validate_json(celery_args[1]) return channel_name, task def get_task_info(self, args: tuple, kwargs: dict) -> tuple[str, str]: diff --git a/backend/substrapp/tasks/tasks_save_image.py b/backend/substrapp/tasks/tasks_save_image.py index 1ada2b6ae..15d9eaf16 100644 --- a/backend/substrapp/tasks/tasks_save_image.py +++ b/backend/substrapp/tasks/tasks_save_image.py @@ -46,7 +46,7 @@ def attempt(self) -> int: # Returns (function key, channel) def get_task_info(self, args: tuple, kwargs: dict) -> tuple[str, str]: - function = orchestrator.Function.parse_raw(kwargs["function_serialized"]) + function = orchestrator.Function.model_validate_json(kwargs["function_serialized"]) channel_name = kwargs["channel_name"] return function.key, channel_name @@ -75,7 +75,7 @@ def save_image_task(task: SaveImageTask, function_serialized: str, channel_name: logger.info("Starting save_image_task") logger.info(f"Parameters: function_serialized {function_serialized}, " f"channel_name {channel_name}") # create serialized image - function = orchestrator.Function.parse_raw(function_serialized) + function = orchestrator.Function.model_validate_json(function_serialized) container_image_tag = utils.container_image_tag_from_function(function) os.makedirs(SUBTUPLE_TMP_DIR, exist_ok=True)