Skip to content

Commit

Permalink
fix(sagemaker): read csv with escaped chars (#6102)
Browse files Browse the repository at this point in the history
Co-authored-by: Jina Dev Bot <[email protected]>
  • Loading branch information
deepankarm and jina-bot authored Nov 2, 2023
1 parent 46c0638 commit 7bed945
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 18 deletions.
18 changes: 13 additions & 5 deletions jina/serve/runtimes/worker/http_sagemaker_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ async def post(request: Request):
return await process(input_model(**json_body))

elif content_type in ('text/csv', 'application/csv'):
import csv
from io import StringIO

bytes_body = await request.body()
csv_body = bytes_body.decode('utf-8')
if not is_valid_csv(csv_body):
Expand All @@ -147,17 +150,22 @@ async def post(request: Request):
# This will also enforce the order of the fields in the csv file.
# This also means, all fields in the input model must be present in the
# csv file including the optional ones.
# We also expect the csv file to have no quotes and use the escape char '\'
field_names = [f for f in input_doc_list_model.__fields__]
data = []
for line in csv_body.splitlines():
fields = line.split(',')
if len(fields) != len(field_names):
for line in csv.reader(
StringIO(csv_body),
delimiter=',',
quoting=csv.QUOTE_NONE,
escapechar='\\',
):
if len(line) != len(field_names):
raise HTTPException(
status_code=400,
detail=f'Invalid CSV format. Line {fields} doesn\'t match '
detail=f'Invalid CSV format. Line {line} doesn\'t match '
f'the expected field order {field_names}.',
)
data.append(input_doc_list_model(**dict(zip(field_names, fields))))
data.append(input_doc_list_model(**dict(zip(field_names, line))))

return await process(input_model(data=data))

Expand Down
52 changes: 52 additions & 0 deletions jina_cli/autocomplete.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,56 @@
'normalize',
'logs',
],
'cloud deployment list': ['--help', '--phase', '--name', '--labels'],
'cloud deployment remove': ['--help', '--phase'],
'cloud deployment update': ['--help'],
'cloud deployment restart': ['--help'],
'cloud deployment pause': ['--help'],
'cloud deployment resume': ['--help'],
'cloud deployment scale': ['--help', '--replicas'],
'cloud deployment recreate': ['--help'],
'cloud deployment status': ['--help', '--verbose'],
'cloud deployment deploy': ['--help'],
'cloud deployment logs': ['--help'],
'cloud deployment': [
'--help',
'list',
'remove',
'update',
'restart',
'pause',
'resume',
'scale',
'recreate',
'status',
'deploy',
'logs',
],
'cloud jds list': ['--help', '--phase', '--name', '--labels'],
'cloud jds remove': ['--help', '--phase'],
'cloud jds update': ['--help'],
'cloud jds restart': ['--help'],
'cloud jds pause': ['--help'],
'cloud jds resume': ['--help'],
'cloud jds scale': ['--help', '--replicas'],
'cloud jds recreate': ['--help'],
'cloud jds status': ['--help', '--verbose'],
'cloud jds deploy': ['--help'],
'cloud jds logs': ['--help'],
'cloud jds': [
'--help',
'list',
'remove',
'update',
'restart',
'pause',
'resume',
'scale',
'recreate',
'status',
'deploy',
'logs',
],
'cloud job list': ['--help'],
'cloud job remove': ['--help'],
'cloud job logs': ['--help'],
Expand Down Expand Up @@ -325,6 +375,8 @@
'logout',
'flow',
'flows',
'deployment',
'jds',
'job',
'jobs',
'secret',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@


class TextDoc(BaseDoc):
text: str
text: str = Field(description="The text of the document", default="")


class EmbeddingResponseModel(BaseDoc):
class EmbeddingResponseModel(TextDoc):
embeddings: NdArray = Field(description="The embedding of the texts", default=[])

class Config(BaseDoc.Config):
Expand All @@ -25,6 +25,10 @@ def foo(self, docs: DocList[TextDoc], **kwargs) -> DocList[EmbeddingResponseMode
ret = []
for doc in docs:
ret.append(
EmbeddingResponseModel(id=doc.id, embeddings=np.random.random((1, 64)))
EmbeddingResponseModel(
id=doc.id,
text=doc.text,
embeddings=np.random.random((1, 64)),
)
)
return DocList[EmbeddingResponseModel](ret)
37 changes: 27 additions & 10 deletions tests/integration/docarray_v2/sagemaker/test_sagemaker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import csv
import io
import os
import time
from contextlib import AbstractContextManager
Expand Down Expand Up @@ -73,7 +75,14 @@ def test_provider_sagemaker_pod_inference():
assert len(resp_json['data'][0]['embeddings'][0]) == 64


def test_provider_sagemaker_pod_batch_transform_valid():
@pytest.mark.parametrize(
"filename",
[
"valid_input_1.csv",
"valid_input_2.csv",
],
)
def test_provider_sagemaker_pod_batch_transform_valid(filename):
with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')):
args, _ = set_pod_parser().parse_known_args(
[
Expand All @@ -86,24 +95,32 @@ def test_provider_sagemaker_pod_batch_transform_valid():
)
with Pod(args):
# Test `POST /invocations` endpoint for batch-transform with valid input
with open(
os.path.join(os.path.dirname(__file__), 'valid_input.csv'), 'r'
) as f:
texts = []
with open(os.path.join(os.path.dirname(__file__), filename), "r") as f:
csv_data = f.read()

for line in csv.reader(
io.StringIO(csv_data),
delimiter=",",
quoting=csv.QUOTE_NONE,
escapechar="\\",
):
texts.append(line[1])

resp = requests.post(
f'http://localhost:{sagemaker_port}/invocations',
f"http://localhost:{sagemaker_port}/invocations",
headers={
'accept': 'application/json',
'content-type': 'text/csv',
"accept": "application/json",
"content-type": "text/csv",
},
data=csv_data,
)
assert resp.status_code == 200
resp_json = resp.json()
assert len(resp_json['data']) == 10
for d in resp_json['data']:
assert len(d['embeddings'][0]) == 64
assert len(resp_json["data"]) == 10
for idx, d in enumerate(resp_json["data"]):
assert d["text"] == texts[idx]
assert len(d["embeddings"][0]) == 64


def test_provider_sagemaker_pod_batch_transform_invalid():
Expand Down
10 changes: 10 additions & 0 deletions tests/integration/docarray_v2/sagemaker/valid_input_2.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
1,abcd
2,efgh\, with comma
3,ijkl with \"quote\"
4,mn\\nop with newline
5,qrst with \\ backslash
6,uvwx with both\, comma and \"quote\"
7,yzab with newline\\nand comma\,
8,cde\"f with embedded quote
9,ghij with special char #
10,klmn with everything\, \"quote\" \\backslash and \\nnewline

0 comments on commit 7bed945

Please sign in to comment.