Skip to content

Commit

Permalink
fix: gateway forwarding
Browse files Browse the repository at this point in the history
  • Loading branch information
NarekA committed Oct 27, 2023
1 parent 38a577b commit c4afd99
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 12 deletions.
5 changes: 3 additions & 2 deletions jina/clients/base/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,15 @@ async def send_streaming_message(self, doc: 'Document', on: str):
:param on: Request endpoint
:yields: responses
"""
req_dict = doc.to_dict() if hasattr(doc, "to_dict") else doc.dict()
request_kwargs = {
'url': self.url,
'headers': {'Accept': 'text/event-stream'},
'data': doc.json().encode(),
'json': req_dict,
}

async with self.session.get(**request_kwargs) as response:
async for chunk, _ in response.content.iter_chunks():
async for chunk in response.content.iter_any():
events = chunk.split(b'event: ')[1:]
for event in events:
if event.startswith(b'update'):
Expand Down
5 changes: 2 additions & 3 deletions jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,8 @@ def add_streaming_routes(
summary=f'Streaming Endpoint {endpoint_path}',
)
async def streaming_get(request: Request, body: input_doc_model = None):
if not body:
query_params = dict(request.query_params)
body = input_doc_model.parse_obj(query_params)
body = body or dict(request.query_params)
body = input_doc_model.parse_obj(body) if docarray_v2 else Document.from_dict(body)

async def event_generator():
async for doc, error in streamer.stream_doc(
Expand Down
14 changes: 11 additions & 3 deletions jina/serve/runtimes/gateway/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,17 +175,25 @@ def _http_fastapi_default_app(
)
)

async def _load_balance(self, request):

async def _load_balance(self, request: 'aiohttp.web_request.Request'):
import aiohttp
from aiohttp import web

target_server = next(self.load_balancer_servers)
target_url = f'{target_server}{request.path_qs}'


try:
async with aiohttp.ClientSession() as session:

if request.method == 'GET':
async with session.get(target_url) as response:
request_kwargs = {}
payload = await request.json()
if payload:
request_kwargs['json'] = payload


async with session.get(url=target_url, **request_kwargs) as response:
# Create a StreamResponse with the same headers and status as the target response
stream_response = web.StreamResponse(
status=response.status,
Expand Down
5 changes: 4 additions & 1 deletion jina/serve/runtimes/worker/http_fastapi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,17 @@ def add_streaming_routes(
methods=['GET'],
summary=f'Streaming Endpoint {endpoint_path}',
)
async def streaming_get(request: Request, body: input_doc_model = None):
async def streaming_get(request: Request = None, body: input_doc_model = None):
if not body:
query_params = dict(request.query_params)
body = (
input_doc_model.parse_obj(query_params)
if docarray_v2
else Document.from_dict(query_params)
)
else:
if not docarray_v2:
body = Document.from_pydantic_model(body)
req = DataRequest()
req.header.exec_endpoint = endpoint_path
if not docarray_v2:
Expand Down
8 changes: 5 additions & 3 deletions tests/integration/streaming/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ async def non_gen_task(self, docs: DocumentArray, **kwargs):
@pytest.mark.parametrize('protocol', ['http', 'grpc'])
@pytest.mark.parametrize('include_gateway', [False, True])
async def test_streaming_deployment(protocol, include_gateway):
from jina import Deployment

port = random_port()
docs = []

with Deployment(
uses=MyExecutor,
Expand All @@ -38,10 +38,12 @@ async def test_streaming_deployment(protocol, include_gateway):
client = Client(port=port, protocol=protocol, asyncio=True)
i = 0
async for doc in client.stream_doc(
on='/hello', inputs=Document(text='hello world')
on='/hello', inputs=Document(text='hello world'), return_type=Document, input_type=Document
):
assert doc.text == f'hello world {i}'
docs.append(doc.text)
i += 1
assert docs == [f'hello world {i}' for i in range(100)]
assert len(docs) == 100


class WaitStreamExecutor(Executor):
Expand Down

0 comments on commit c4afd99

Please sign in to comment.