Commit b14aa3b
Add tests for validation. Make streaming not depend on msgpack.
marius-baseten committed Nov 26, 2024
1 parent 1399421 commit b14aa3b
Showing 6 changed files with 105 additions and 54 deletions.
54 changes: 53 additions & 1 deletion truss-chains/tests/test_framework.py
@@ -2,7 +2,7 @@
 import contextlib
 import logging
 import re
-from typing import List
+from typing import AsyncIterator, Iterator, List

 import pydantic
 import pytest
@@ -505,3 +505,55 @@ def run_remote(argument: object): ...
     with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():
         with public_api.run_local():
             MultiIssue()
+
+
+def test_raises_iterator_no_yield():
+    match = (
+        rf"{TEST_FILE}:\d+ \(IteratorNoYield\.run_remote\) \[kind: IO_TYPE_ERROR\].*"
+        r"If the endpoint returns an iterator \(streaming\), it must have `yield` statements"
+    )
+
+    with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():
+
+        class IteratorNoYield(chains.ChainletBase):
+            async def run_remote(self) -> AsyncIterator[str]:
+                return "123"  # type: ignore[return-value]
+
+
+def test_raises_yield_no_iterator():
+    match = (
+        rf"{TEST_FILE}:\d+ \(YieldNoIterator\.run_remote\) \[kind: IO_TYPE_ERROR\].*"
+        r"If the endpoint is streaming \(has `yield` statements\), the return type must be an iterator"
+    )
+
+    with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():
+
+        class YieldNoIterator(chains.ChainletBase):
+            async def run_remote(self) -> str:  # type: ignore[misc]
+                yield "123"
+
+
+def test_raises_iterator_sync():
+    match = (
+        rf"{TEST_FILE}:\d+ \(IteratorSync\.run_remote\) \[kind: IO_TYPE_ERROR\].*"
+        r"Streaming endpoints \(containing `yield` statements\) are only supported for async endpoints"
+    )
+
+    with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():
+
+        class IteratorSync(chains.ChainletBase):
+            def run_remote(self) -> Iterator[str]:
+                yield "123"
+
+
+def test_raises_iterator_no_arg():
+    match = (
+        rf"{TEST_FILE}:\d+ \(IteratorNoArg\.run_remote\) \[kind: IO_TYPE_ERROR\].*"
+        r"Iterators must be annotated with type \(one of \['str', 'bytes'\]\)"
+    )
+
+    with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():
+
+        class IteratorNoArg(chains.ChainletBase):
+            async def run_remote(self) -> AsyncIterator:
+                yield "123"
12 changes: 11 additions & 1 deletion truss-chains/truss_chains/framework.py
@@ -305,6 +305,16 @@ def _validate_streaming_output_type(
     origin = get_origin(annotation)
     assert origin in (collections.abc.AsyncIterator, collections.abc.Iterator)
     args = get_args(annotation)
+    if len(args) < 1:
+        _collect_error(
+            f"Iterators must be annotated with type (one of {list(x.__name__ for x in _STREAM_TYPES)}).",
+            _ErrorKind.IO_TYPE_ERROR,
+            location,
+        )
+        return definitions.StreamingTypeDescriptor(
+            raw=annotation, origin_type=origin, arg_type=bytes
+        )
+
     assert len(args) == 1, "AsyncIterator cannot have more than 1 arg."
     arg = args[0]
     if arg not in _STREAM_TYPES:
@@ -479,7 +489,7 @@ def _validate_and_describe_endpoint(
         _collect_error(
             "Streaming endpoints (containing `yield` statements) are only "
             "supported for async endpoints.",
-            _ErrorKind.TYPE_ERROR,
+            _ErrorKind.IO_TYPE_ERROR,
             location,
         )
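
The new `len(args) < 1` branch handles bare iterator annotations, which pass the `get_origin` assertion but carry no type arguments. A quick sketch of the `typing` introspection this validation relies on:

```python
import collections.abc
from typing import AsyncIterator, get_args, get_origin

# Parametrized and bare annotations share the same origin ...
assert get_origin(AsyncIterator[str]) is collections.abc.AsyncIterator
assert get_origin(AsyncIterator) is collections.abc.AsyncIterator

# ... but only the parametrized form carries type arguments. The empty
# tuple below is what the new error branch catches, before defaulting
# the wire type to `bytes`.
assert get_args(AsyncIterator[str]) == (str,)
assert get_args(AsyncIterator) == ()
```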
41 changes: 20 additions & 21 deletions truss-chains/truss_chains/streaming.py
@@ -4,12 +4,12 @@
 import struct
 import sys
 from collections.abc import AsyncIterator
-from typing import Generic, Optional, Protocol, Type, TypeVar, overload
+from typing import Generic, Optional, Protocol, Type, TypeVar, Union, overload

 import pydantic
-from truss.templates.shared import serialization

 TAG_SIZE = 5  # uint8 + uint32.
+JSONType = Union[str, int, float, bool, None, list["JSONType"], dict[str, "JSONType"]]
 _T = TypeVar("_T")

 if sys.version_info < (3, 10):
@@ -136,14 +136,14 @@ async def readexactly(self, num_bytes: int) -> bytes:

 class _StreamReaderProtocol(Protocol[ItemT, HeaderTT, FooterTT]):
     _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT]
-    _footer_data: Optional[serialization.MsgPackType]
+    _footer_data: Optional[bytes]

-    async def _read(self) -> tuple[_Delimiter, serialization.MsgPackType]: ...
+    async def _read(self) -> tuple[_Delimiter, bytes]: ...


 class _StreamReader(_Streamer[ItemT, HeaderTT, FooterTT]):
     _stream: _ByteReader
-    _footer_data: Optional[serialization.MsgPackType]
+    _footer_data: Optional[bytes]

     def __init__(
         self,
@@ -159,24 +159,24 @@ def _unpack_tag(tag: bytes) -> tuple[_Delimiter, int]:
         enum_value, length = struct.unpack(">BI", tag)
         return _Delimiter(enum_value), length

-    async def _read(self) -> tuple[_Delimiter, serialization.MsgPackType]:
+    async def _read(self) -> tuple[_Delimiter, bytes]:
         try:
             tag = await self._stream.readexactly(TAG_SIZE)
         # It's ok to read nothing (end of stream), but unexpected to read partial.
         except asyncio.IncompleteReadError:
             raise
         except EOFError:
-            return _Delimiter.END, None
+            return _Delimiter.END, b""

         delimiter, length = self._unpack_tag(tag)
         if not length:
-            return delimiter, None
+            return delimiter, b""
         data_bytes = await self._stream.readexactly(length)
         print(f"Read Delimiter: {delimiter}")
-        return delimiter, serialization.truss_msgpack_deserialize(data_bytes)
+        return delimiter, data_bytes

     async def read_items(self) -> AsyncIterator[ItemT]:
-        delimiter, data_dict = await self._read()
+        delimiter, data_bytes = await self._read()
         if delimiter == _Delimiter.HEADER:
             raise ValueError(
                 "Called `read_items`, but the stream contains header data, which "
@@ -187,39 +187,39 @@ async def read_items(self) -> AsyncIterator[ItemT]:

         assert delimiter == _Delimiter.ITEM
         while True:
-            yield self._stream_types.item_t.model_validate(data_dict)
+            yield self._stream_types.item_t.model_validate_json(data_bytes)
             # We don't know if the next data is another item, footer or the end.
-            delimiter, data_dict = await self._read()
+            delimiter, data_bytes = await self._read()
             if delimiter == _Delimiter.END:
                 return
             if delimiter == _Delimiter.FOOTER:
-                self._footer_data = data_dict
+                self._footer_data = data_bytes
                 return


 class _HeaderReadMixin(_Streamer[ItemT, HeaderT, FooterTT]):
     async def read_header(
         self: _StreamReaderProtocol[ItemT, HeaderT, FooterTT],
     ) -> HeaderT:
-        delimiter, data_dict = await self._read()
+        delimiter, data_bytes = await self._read()
         if delimiter != _Delimiter.HEADER:
             raise ValueError("Stream does not contain header.")
-        return self._stream_types.header_t.model_validate(data_dict)
+        return self._stream_types.header_t.model_validate_json(data_bytes)


 class _FooterReadMixin(_Streamer[ItemT, HeaderTT, FooterT]):
-    _footer_data: Optional[serialization.MsgPackType]
+    _footer_data: Optional[bytes]

     async def read_footer(
         self: _StreamReaderProtocol[ItemT, HeaderTT, FooterT],
     ) -> FooterT:
         if self._footer_data is None:
-            delimiter, data_dict = await self._read()
+            delimiter, data_bytes = await self._read()
             if delimiter != _Delimiter.FOOTER:
                 raise ValueError("Stream does not contain footer.")
-            self._footer_data = data_dict
+            self._footer_data = data_bytes

-        footer = self._stream_types.footer_t.model_validate(self._footer_data)
+        footer = self._stream_types.footer_t.model_validate_json(self._footer_data)
         self._footer_data = None
         return footer
@@ -298,8 +298,7 @@ def _pack_tag(delimiter: _Delimiter, length: int) -> bytes:
         return struct.pack(">BI", delimiter.value, length)

     def _serialize(self, obj: pydantic.BaseModel, delimiter: _Delimiter) -> bytes:
-        data_dict = obj.model_dump()
-        data_bytes = serialization.truss_msgpack_serialize(data_dict)
+        data_bytes = obj.model_dump_json().encode()
         data = bytearray(self._pack_tag(delimiter, len(data_bytes)))
         data.extend(data_bytes)
         # Starlette cannot handle a bytearray, but a memoryview works.
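
Net effect of the changes in this file: the wire framing is unchanged (a 5-byte tag of uint8 delimiter plus big-endian uint32 payload length, followed by the payload), but payloads are now pydantic JSON bytes rather than msgpack. A minimal sketch of one frame round-trip, reusing the same `">BI"` format string (`Item` and `serialize_frame` are illustrative, not part of the module):

```python
import struct

import pydantic

TAG_SIZE = 5  # uint8 delimiter + uint32 big-endian payload length.


class Item(pydantic.BaseModel):
    text: str


def serialize_frame(obj: pydantic.BaseModel, delimiter_value: int) -> bytes:
    # Pydantic dumps straight to JSON bytes, so no msgpack dependency.
    payload = obj.model_dump_json().encode()
    return struct.pack(">BI", delimiter_value, len(payload)) + payload


frame = serialize_frame(Item(text="hi"), delimiter_value=1)
delimiter_value, length = struct.unpack(">BI", frame[:TAG_SIZE])
assert length == len(frame) - TAG_SIZE
assert Item.model_validate_json(frame[TAG_SIZE:]) == Item(text="hi")
```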
23 changes: 17 additions & 6 deletions truss/templates/server/model_wrapper.py
@@ -27,6 +27,7 @@
 )

 import opentelemetry.sdk.trace as sdk_trace
+import pydantic
 import starlette.requests
 import starlette.responses
 from anyio import Semaphore, to_thread
@@ -56,6 +57,15 @@
 TRT_LLM_EXTENSION_NAME = "trt_llm"
 POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS = 30

+InputType = Union[serialization.JSONType, serialization.MsgPackType, pydantic.BaseModel]
+OutputType = Union[
+    serialization.JSONType,
+    serialization.MsgPackType,
+    Generator[bytes, None, None],
+    AsyncGenerator[bytes, None],
+    "starlette.responses.Response",
+]
+

 @asynccontextmanager
 async def deferred_semaphore_and_span(
@@ -520,7 +530,7 @@ async def poll_for_environment_updates(self) -> None:

     async def preprocess(
         self,
-        inputs: serialization.InputType,
+        inputs: InputType,
         request: starlette.requests.Request,
     ) -> Any:
         descriptor = self.model_descriptor.preprocess
@@ -538,7 +548,7 @@ async def predict(
         self,
         inputs: Any,
         request: starlette.requests.Request,
-    ) -> Union[serialization.OutputType, Any]:
+    ) -> Union[OutputType, Any]:
         # The result can be a serializable data structure, byte-generator, a request,
         # or, if `postprocessing` is used, anything. In the last case postprocessing
         # must convert the result to something serializable.
@@ -555,9 +565,9 @@

     async def postprocess(
         self,
-        result: Union[serialization.InputType, Any],
+        result: Union[InputType, Any],
         request: starlette.requests.Request,
-    ) -> serialization.OutputType:
+    ) -> OutputType:
         # The postprocess function can handle outputs of `predict`, but not
         # generators and responses - in that case predict must return directly
         # and postprocess is skipped.
@@ -642,9 +652,9 @@ async def _buffered_response_generator() -> AsyncGenerator[bytes, None]:

     async def __call__(
         self,
-        inputs: Optional[serialization.InputType],
+        inputs: Optional[InputType],
         request: starlette.requests.Request,
-    ) -> serialization.OutputType:
+    ) -> OutputType:
         """
         Returns result from: preprocess -> predictor -> postprocess.
         """
@@ -726,6 +736,7 @@ async def __call__(
         ), tracing.detach_context():
             postprocess_result = await self.postprocess(predict_result, request)

+        final_result: OutputType
         if isinstance(postprocess_result, BaseModel):
             # If we return a pydantic object, convert it back to a dict
             with tracing.section_as_event(span_post, "dump-pydantic"):
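
As the comments in `predict` and `postprocess` note, a model may return plain serializable data, a pydantic model, a response object, or a byte generator; the relocated `OutputType` alias encodes exactly that. For illustration, the two streaming shapes it admits (hypothetical generators):

```python
from typing import AsyncGenerator, Generator


def sync_stream() -> Generator[bytes, None, None]:
    # Matches the `Generator[bytes, None, None]` member of `OutputType`.
    yield b"partial "
    yield b"result"


async def async_stream() -> AsyncGenerator[bytes, None]:
    # Matches the `AsyncGenerator[bytes, None]` member of `OutputType`.
    yield b"partial "
    yield b"result"
```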
6 changes: 3 additions & 3 deletions truss/templates/server/truss_server.py
@@ -16,7 +16,7 @@
 from fastapi import Depends, FastAPI, HTTPException, Request
 from fastapi.responses import ORJSONResponse, StreamingResponse
 from fastapi.routing import APIRoute as FastAPIRoute
-from model_wrapper import ModelWrapper
+from model_wrapper import InputType, ModelWrapper
 from opentelemetry import propagate as otel_propagate
 from opentelemetry import trace
 from opentelemetry.sdk import trace as sdk_trace
@@ -104,7 +104,7 @@ async def _parse_body(
         body_raw: bytes,
         truss_schema: Optional[TrussSchema],
         span: trace.Span,
-    ) -> serialization.InputType:
+    ) -> InputType:
         if self.is_binary(request):
             with tracing.section_as_event(span, "binary-deserialize"):
                 inputs = serialization.truss_msgpack_deserialize(body_raw)
@@ -157,7 +157,7 @@ async def predict(
         with self._tracer.start_as_current_span(
             "predict-endpoint", context=trace_ctx
         ) as span:
-            inputs: Optional[serialization.InputType]
+            inputs: Optional[InputType]
             if model.model_descriptor.skip_input_parsing:
                 inputs = None
             else:
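
For context on the `_parse_body` hunk above: binary requests are decoded with truss's msgpack helper, while other requests are treated as JSON. A rough, simplified sketch of that branch (the real method also applies the truss schema and emits tracing events; `parse_body` here is hypothetical):

```python
import json

from truss.templates.shared import serialization


def parse_body(body_raw: bytes, is_binary: bool):
    if is_binary:
        # Binary protocol: msgpack-encoded payload.
        return serialization.truss_msgpack_deserialize(body_raw)
    # Default path: JSON payload.
    return json.loads(body_raw)
```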
23 changes: 1 addition & 22 deletions truss/templates/shared/serialization.py
@@ -2,22 +2,9 @@
 import uuid
 from datetime import date, datetime, time, timedelta
 from decimal import Decimal
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    AsyncGenerator,
-    Callable,
-    Dict,
-    Generator,
-    List,
-    Optional,
-    Union,
-)
-
-import pydantic
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

 if TYPE_CHECKING:
-    import starlette.responses
     from numpy.typing import NDArray
@@ -38,14 +25,6 @@
     List["MsgPackType"],
     Dict[str, "MsgPackType"],
 ]
-InputType = Union[JSONType, MsgPackType, pydantic.BaseModel]
-OutputType = Union[
-    JSONType,
-    MsgPackType,
-    Generator[bytes, None, None],
-    AsyncGenerator[bytes, None],
-    "starlette.responses.Response",
-]


 # mostly cribbed from django.core.serializer.DjangoJSONEncoder
