From 3eed000f8ad247857e4638b3df94c675a9d33d63 Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Thu, 14 Dec 2023 11:02:07 +0100 Subject: [PATCH] test: add extra tests for dynamic batching --- jina/serve/runtimes/worker/batch_queue.py | 9 +++--- .../dynamic_batching/test_dynamic_batching.py | 30 +++++++++++++++++++ 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/jina/serve/runtimes/worker/batch_queue.py b/jina/serve/runtimes/worker/batch_queue.py index 4f8cc387a9150..b4100c6f6d880 100644 --- a/jina/serve/runtimes/worker/batch_queue.py +++ b/jina/serve/runtimes/worker/batch_queue.py @@ -72,17 +72,17 @@ def _cancel_timer_if_pending(self): def _start_timer(self): self._cancel_timer_if_pending() self._timer_task = asyncio.create_task( - self._sleep_then_set(self._flush_trigger) + self._sleep_then_set() ) self._timer_started = True - async def _sleep_then_set(self, event: Event): + async def _sleep_then_set(self): """Sleep and then set the event :param event: event to set """ await asyncio.sleep(self._timeout / 1000) - event.set() + self._flush_trigger.set() async def push(self, request: DataRequest) -> asyncio.Queue: """Append request to the the list of requests to be processed. @@ -220,7 +220,6 @@ def batch(iterable_1, iterable_2, n=1): # communicate that the request has been processed properly. At this stage the data_lock is ours and # therefore noone can add requests to this list. self._flush_trigger: Event = Event() - self._cancel_timer_if_pending() self._timer_task = None try: if not docarray_v2: @@ -274,7 +273,7 @@ def batch(iterable_1, iterable_2, n=1): await request_full.put(exc) else: # We need to attribute the docs to their requests - non_assigned_to_response_docs.extend(batch_res_docs) + non_assigned_to_response_docs.extend(batch_res_docs or docs_inner_batch) non_assigned_to_response_request_idxs.extend(req_idxs) num_assigned_docs = await _assign_results( non_assigned_to_response_docs, diff --git a/tests/integration/dynamic_batching/test_dynamic_batching.py b/tests/integration/dynamic_batching/test_dynamic_batching.py index 3e174162c2894..b21769e09fddd 100644 --- a/tests/integration/dynamic_batching/test_dynamic_batching.py +++ b/tests/integration/dynamic_batching/test_dynamic_batching.py @@ -14,6 +14,7 @@ DocumentArray, Executor, Flow, + Deployment, dynamic_batching, requests, ) @@ -21,6 +22,7 @@ from jina.serve.networking.utils import send_request_sync from jina.serve.runtimes.servers import BaseServer from jina_cli.api import executor_native +from jina.proto import jina_pb2 from tests.helper import _generate_pod_args cur_dir = os.path.dirname(__file__) @@ -311,6 +313,7 @@ def test_preferred_batch_size(add_parameters, use_stream): assert time_taken < TIMEOUT_TOLERANCE +@pytest.mark.repeat(10) @pytest.mark.parametrize('use_stream', [False, True]) def test_correctness(use_stream): f = Flow().add(uses=PlaceholderExecutor) @@ -492,6 +495,7 @@ def test_param_correctness(use_stream): ] assert [doc.text for doc in results[2]] == [f'D{str(PARAM1)}'] + @pytest.mark.parametrize( 'uses', [ @@ -622,3 +626,29 @@ def test_failure_propagation(): '/wronglennone', inputs=DocumentArray([Document(text=str(i)) for i in range(8)]), ) + + +@pytest.mark.repeat(10) +def test_exception_handling_in_dynamic_batch(): + class SlowExecutorWithException(Executor): + + @dynamic_batching(preferred_batch_size=3, timeout=1000) + @requests(on='/foo') + def foo(self, docs, **kwargs): + for doc in docs: + if doc.text == 'fail': + raise Exception('Fail is in the Batch') + + depl = Deployment(uses=SlowExecutorWithException) + + with depl: + da = DocumentArray([Document(text='good') for _ in range(50)]) + da[4].text = 'fail' + responses = depl.post(on='/foo', inputs=da, request_size=1, return_responses=True, continue_on_error=True, results_in_order=True) + assert len(responses) == 50 # 1 request per input + num_failed_requests = 0 + for r in responses: + if r.header.status.code == jina_pb2.StatusProto.StatusCode.ERROR: + num_failed_requests += 1 + + assert 1 <= num_failed_requests <= 3 # 3 requests in the dynamic batch failing