Skip to content

Commit

Permalink
test: add integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Sep 18, 2024
1 parent 148f40a commit 05c1ee4
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 3 deletions.
16 changes: 15 additions & 1 deletion jina/serve/executors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,9 +655,23 @@ def _validate_sagemaker(self):
return

def _add_dynamic_batching(self, _dynamic_batching: Optional[Dict]):
import collections

def deep_update(source, overrides):
"""
Update a nested dictionary or similar mapping.
Modify ``source`` in place.
"""
for key, value in overrides.items():
if isinstance(value, collections.Mapping) and value:
returned = deep_update(source.get(key, {}), value)
source[key] = returned
else:
source[key] = overrides[key]
return source
if _dynamic_batching:
self.dynamic_batching = getattr(self, 'dynamic_batching', {})
self.dynamic_batching.update(_dynamic_batching)
self.dynamic_batching = deep_update(self.dynamic_batching, _dynamic_batching)

def _add_metas(self, _metas: Optional[Dict]):
from jina.serve.executors.metas import get_default_metas
Expand Down
4 changes: 2 additions & 2 deletions jina/serve/runtimes/worker/batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ async def _assign_results(
return num_assigned_docs

def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Optional = None):
if n is None and iterable_metrics is None:
if n is None:
yield iterable_1, iterable_2
return
if n is not None and iterable_metrics is None:
Expand All @@ -262,7 +262,7 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Option

if batch_weight >= n:
yield iterable_1[batch_idx: i + 1], iterable_2[batch_idx: i + 1]
batch_idx = i
batch_idx = i + 1
batch_weight = 0

# Yield any remaining items
Expand Down
64 changes: 64 additions & 0 deletions tests/integration/dynamic_batching/test_dynamic_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,3 +736,67 @@ def foo(self, docs, **kwargs):

assert smaller_than_5 == (1 if allow_concurrent else 0)
assert larger_than_5 > 0


@pytest.mark.asyncio
@pytest.mark.parametrize('use_custom_metric', [True, False])
@pytest.mark.parametrize('flush_all', [False, True])
async def test_dynamic_batching_custom_metric(use_custom_metric, flush_all):
    """Check the batch totals reported when batching with a custom metric.

    Fifty documents with the five-character text ``'aaaaa'`` are posted as
    one request each. The executor replaces every text with the summed text
    length of the batch it was processed in, so the returned texts reveal
    how the documents were grouped for each combination of
    ``use_custom_metric`` and ``flush_all``.
    """

    class DynCustomBatchProcessor(Executor):
        # Writes the total text length of the received batch into every doc.

        @dynamic_batching(preferred_batch_size=10, custom_metric=lambda x: len(x.text))
        @requests(on='/foo')
        def foo(self, docs, **kwargs):
            time.sleep(0.5)
            total_len = sum(len(doc.text) for doc in docs)
            for doc in docs:
                doc.text = f"{total_len}"

    batching_config = {
        'foo': {
            "preferred_batch_size": 10,
            "timeout": 2000,
            "use_custom_metric": use_custom_metric,
            "flush_all": flush_all,
        }
    }
    deployment = Deployment(
        uses=DynCustomBatchProcessor, uses_dynamic_batching=batching_config
    )
    documents = DocumentArray([Document(text='aaaaa') for _ in range(50)])

    with deployment:
        client = Client(
            protocol=deployment.protocol, port=deployment.port, asyncio=True
        )
        responses = []
        async for reply in client.post(
            on='/foo',
            inputs=documents,
            request_size=1,
            continue_on_error=True,
            results_in_order=True,
        ):
            responses.extend(reply)
        assert len(responses) == 50  # 1 request per input

    if not flush_all:
        # Without flush_all every document must report the same batch total:
        # "10" with the custom metric, "50" without it.
        # NOTE(review): presumably 2 docs * 5 chars vs. 10 docs * 5 chars - confirm.
        if use_custom_metric:
            for doc in responses:
                assert doc.text == "10"
        else:
            for doc in responses:
                assert doc.text == "50"
    elif use_custom_metric:
        # flush_all with the custom metric: exactly 2 docs report "10"
        # and the remaining 48 report "240".
        num_10 = sum(1 for doc in responses if doc.text == "10")
        num_240 = sum(1 for doc in responses if doc.text == "240")

        assert num_10 == 2
        assert num_240 == 48
    else:
        # flush_all without the custom metric: exactly 10 docs report "50"
        # and the remaining 40 report "200".
        num_50 = sum(1 for doc in responses if doc.text == "50")
        num_200 = sum(1 for doc in responses if doc.text == "200")

        assert num_50 == 10
        assert num_200 == 40

0 comments on commit 05c1ee4

Please sign in to comment.