Commit 680ba4e

marius-baseten committed Nov 27, 2024
1 parent b14aa3b commit 680ba4e
Showing 4 changed files with 87 additions and 48 deletions.
4 changes: 3 additions & 1 deletion truss-chains/examples/streaming/streaming_chain.py
@@ -30,7 +30,9 @@ class ConsumerOutput(pydantic.BaseModel):
strings: str


STREAM_TYPES = streaming.stream_types(MyDataChunk, header_t=Header, footer_t=Footer)
STREAM_TYPES = streaming.stream_types(
MyDataChunk, header_type=Header, footer_type=Footer
)


class Generator(chains.ChainletBase):
…
50 changes: 36 additions & 14 deletions truss-chains/tests/test_streaming.py
@@ -30,7 +30,9 @@ async def to_bytes_iterator(data_stream) -> AsyncIterator[bytes]:

@pytest.mark.asyncio
async def test_streaming_with_header_and_footer():
types = streaming.stream_types(item_t=MyDataChunk, header_t=Header, footer_t=Footer)
types = streaming.stream_types(
item_type=MyDataChunk, header_type=Header, footer_type=Footer
)

writer = streaming.stream_writer(types)
header = Header(time=123.456, msg="Start of stream")
@@ -61,7 +63,7 @@ async def test_streaming_with_header_and_footer():

@pytest.mark.asyncio
async def test_streaming_with_items_only():
types = streaming.stream_types(item_t=MyDataChunk)
types = streaming.stream_types(item_type=MyDataChunk)
writer = streaming.stream_writer(types)

items = [
@@ -84,8 +86,8 @@ async def test_streaming_with_items_only():

@pytest.mark.asyncio
async def test_reading_header_when_none_sent():
types = streaming.stream_types(item_t=MyDataChunk, header_t=Header)
writer = streaming.stream_writer(streaming.stream_types(item_t=MyDataChunk))
types = streaming.stream_types(item_type=MyDataChunk, header_type=Header)
writer = streaming.stream_writer(types)
items = [MyDataChunk(words=["hello", "world"])]

data_stream = []
@@ -99,8 +101,8 @@ async def test_reading_header_when_none_sent():

@pytest.mark.asyncio
async def test_reading_items_with_wrong_model():
types_writer = streaming.stream_types(item_t=MyDataChunk)
types_reader = streaming.stream_types(item_t=Header) # Wrong item type
types_writer = streaming.stream_types(item_type=MyDataChunk)
types_reader = streaming.stream_types(item_type=Header) # Wrong item type
writer = streaming.stream_writer(types_writer)
items = [MyDataChunk(words=["hello", "world"])]
data_stream = []
@@ -117,9 +119,9 @@ async def test_reading_items_with_wrong_model():
@pytest.mark.asyncio
async def test_streaming_with_wrong_order():
types = streaming.stream_types(
item_t=MyDataChunk,
header_t=Header,
footer_t=Footer,
item_type=MyDataChunk,
header_type=Header,
footer_type=Footer,
)

writer = streaming.stream_writer(types)
@@ -129,19 +131,22 @@ async def test_streaming_with_wrong_order():
data_stream = []
for item in items:
data_stream.append(writer.yield_item(item))
data_stream.append(writer.yield_header(header))

with pytest.raises(
ValueError, match="Cannot yield header after other data has been sent."
):
data_stream.append(writer.yield_header(header))
data_stream.append(writer.yield_footer(footer))

reader = streaming.stream_reader(types, to_bytes_iterator(data_stream))

# Try to read header, should fail because the first data is an item
with pytest.raises(ValueError, match="Stream does not contain header."):
await reader.read_header()


@pytest.mark.asyncio
async def test_reading_items_without_consuming_header():
types = streaming.stream_types(item_t=MyDataChunk, header_t=Header)
types = streaming.stream_types(item_type=MyDataChunk, header_type=Header)
writer = streaming.stream_writer(types)
header = Header(time=123.456, msg="Start of stream")
items = [MyDataChunk(words=["hello", "world"])]
@@ -163,8 +168,8 @@ async def test_reading_items_without_consuming_header():

@pytest.mark.asyncio
async def test_reading_footer_when_none_sent():
types = streaming.stream_types(item_t=MyDataChunk, footer_t=Footer)
writer = streaming.stream_writer(streaming.stream_types(item_t=MyDataChunk))
types = streaming.stream_types(item_type=MyDataChunk, footer_type=Footer)
writer = streaming.stream_writer(types)
items = [MyDataChunk(words=["hello", "world"])]
data_stream = []
for item in items:
@@ -179,3 +184,20 @@ async def test_reading_footer_when_none_sent():
# Try to read footer, expect an error
with pytest.raises(ValueError, match="Stream does not contain footer."):
await reader.read_footer()


@pytest.mark.asyncio
async def test_reading_footer_with_no_items():
types = streaming.stream_types(item_type=MyDataChunk, footer_type=Footer)
writer = streaming.stream_writer(types)
footer = Footer(time=789.012, duration_sec=665.556, msg="End of stream")
data_stream = [writer.yield_footer(footer)]

reader = streaming.stream_reader(types, to_bytes_iterator(data_stream))
read_items = []
async for item in reader.read_items():
read_items.append(item)
assert len(read_items) == 0

read_footer = await reader.read_footer()
assert read_footer == footer
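
These tests feed the collected writer output through the `to_bytes_iterator` helper whose body is collapsed in the first hunk above. A minimal sketch of such a helper, assuming it merely replays the buffered chunks as the `AsyncIterator[bytes]` the reader expects (only the signature is visible in this diff):

async def to_bytes_iterator(data_stream) -> AsyncIterator[bytes]:
    # Replay buffered writer output (bytes / memoryview chunks) as an async stream.
    for data in data_stream:
        yield bytes(data)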
2 changes: 1 addition & 1 deletion truss-chains/truss_chains/framework.py
@@ -315,7 +315,7 @@ def _validate_streaming_output_type(
raw=annotation, origin_type=origin, arg_type=bytes
)

assert len(args) == 1, "AsyncIterator cannot have more than 1 arg."
assert len(args) == 1, "Iterator type annotations cannot have more than 1 arg."
arg = args[0]
if arg not in _STREAM_TYPES:
msg = (
…
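
For reference, the annotation validated here is the return type of a streaming chainlet method. A sketch of an annotation that satisfies this check (class and method names are assumptions, mirroring the example chain above):

class Generator(chains.ChainletBase):
    async def run_remote(self) -> AsyncIterator[bytes]:
        ...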
79 changes: 47 additions & 32 deletions truss-chains/truss_chains/streaming.py
@@ -8,8 +8,10 @@

import pydantic

TAG_SIZE = 5 # uint8 + uint32.
JSONType = Union[str, int, float, bool, None, list["JSONType"], dict[str, "JSONType"]]
_TAG_SIZE = 5 # uint8 + uint32.
_JSONType = Union[
str, int, float, bool, None, list["_JSONType"], dict[str, "_JSONType"]
]
_T = TypeVar("_T")

if sys.version_info < (3, 10):
@@ -43,56 +45,57 @@ async def anext(iterable: AsyncIterator[_T]) -> _T:

@dataclasses.dataclass
class StreamTypes(Generic[ItemT, HeaderTT, FooterTT]):
item_t: Type[ItemT]
header_t: HeaderTT # Is either `Type[HeaderT]` or `None`.
footer_t: FooterTT # Is either `Type[FooterT]` or `None`.
item_type: Type[ItemT]
header_type: HeaderTT # Is either `Type[HeaderT]` or `None`.
footer_type: FooterTT # Is either `Type[FooterT]` or `None`.


@overload
def stream_types(
item_t: Type[ItemT],
item_type: Type[ItemT],
*,
header_t: Type[HeaderT],
footer_t: Type[FooterT],
header_type: Type[HeaderT],
footer_type: Type[FooterT],
) -> StreamTypes[ItemT, HeaderT, FooterT]: ...


@overload
def stream_types(
item_t: Type[ItemT],
item_type: Type[ItemT],
*,
header_t: Type[HeaderT],
header_type: Type[HeaderT],
) -> StreamTypes[ItemT, HeaderT, None]: ...


@overload
def stream_types(
item_t: Type[ItemT],
item_type: Type[ItemT],
*,
footer_t: Type[FooterT],
footer_type: Type[FooterT],
) -> StreamTypes[ItemT, None, FooterT]: ...


@overload
def stream_types(item_t: Type[ItemT]) -> StreamTypes[ItemT, None, None]: ...
def stream_types(item_type: Type[ItemT]) -> StreamTypes[ItemT, None, None]: ...


def stream_types(
item_t: Type[ItemT],
item_type: Type[ItemT],
*,
header_t: Optional[Type[HeaderT]] = None,
footer_t: Optional[Type[FooterT]] = None,
header_type: Optional[Type[HeaderT]] = None,
footer_type: Optional[Type[FooterT]] = None,
) -> StreamTypes:
"""Creates a bundle of item type and potentially header/footer types,
each as a pydantic model."""
# This indirection for creating `StreamTypes` is needed to get generic typing.
return StreamTypes(item_t, header_t, footer_t)
return StreamTypes(item_type, header_type, footer_type)


# Reading ##############################################################################


class _Delimiter(enum.IntEnum):
NOT_SET = enum.auto()
HEADER = enum.auto()
ITEM = enum.auto()
FOOTER = enum.auto()
@@ -161,7 +164,7 @@ def _unpack_tag(tag: bytes) -> tuple[_Delimiter, int]:

async def _read(self) -> tuple[_Delimiter, bytes]:
try:
tag = await self._stream.readexactly(TAG_SIZE)
tag = await self._stream.readexactly(_TAG_SIZE)
# It's OK to read nothing (end of stream), but unexpected to read a partial tag.
except asyncio.IncompleteReadError:
raise
@@ -182,12 +185,13 @@ async def read_items(self) -> AsyncIterator[ItemT]:
"Called `read_items`, but there the stream contains header data, which "
"is not consumed. Call `read_header` first or remove sending a header."
)
if delimiter in (_Delimiter.FOOTER, _Delimiter.END):
if delimiter in (_Delimiter.FOOTER, _Delimiter.END): # In case of 0 items.
self._footer_data = data_bytes
return

assert delimiter == _Delimiter.ITEM
while True:
yield self._stream_types.item_t.model_validate_json(data_bytes)
yield self._stream_types.item_type.model_validate_json(data_bytes)
# We don't know if the next data is another item, footer or the end.
delimiter, data_bytes = await self._read()
if delimiter == _Delimiter.END:
@@ -204,7 +208,7 @@ async def read_header(
delimiter, data_bytes = await self._read()
if delimiter != _Delimiter.HEADER:
raise ValueError("Stream does not contain header.")
return self._stream_types.header_t.model_validate_json(data_bytes)
return self._stream_types.header_type.model_validate_json(data_bytes)


class _FooterReadMixin(_Streamer[ItemT, HeaderTT, FooterT]):
@@ -219,7 +223,7 @@ async def read_footer(
raise ValueError("Stream does not contain footer.")
self._footer_data = data_bytes

footer = self._stream_types.footer_t.model_validate_json(self._footer_data)
footer = self._stream_types.footer_type.model_validate_json(self._footer_data)
self._footer_data = None
return footer

@@ -273,11 +277,11 @@ def stream_reader(
types: StreamTypes[ItemT, HeaderTT, FooterTT],
stream: AsyncIterator[bytes],
) -> _StreamReader:
if types.header_t is None and types.footer_t is None:
if types.header_type is None and types.footer_type is None:
return _StreamReader(types, stream)
if types.header_t is None:
if types.header_type is None:
return StreamReaderWithFooter(types, stream)
if types.footer_t is None:
if types.footer_type is None:
return StreamReaderWithHeader(types, stream)

return StreamReaderFull(types, stream)
@@ -288,11 +292,17 @@ def stream_reader(

class _StreamWriterProtocol(Protocol[ItemT, HeaderTT, FooterTT]):
_stream_types: StreamTypes[ItemT, HeaderTT, FooterTT]
_last_sent: _Delimiter

def _serialize(self, obj: pydantic.BaseModel, delimiter: _Delimiter) -> bytes: ...


class _StreamWriter(_Streamer[ItemT, HeaderTT, FooterTT]):
def __init__(self, types: StreamTypes[ItemT, HeaderTT, FooterTT]) -> None:
super().__init__(types)
self._last_sent = _Delimiter.NOT_SET
self._stream_types = types

@staticmethod
def _pack_tag(delimiter: _Delimiter, length: int) -> bytes:
return struct.pack(">BI", delimiter.value, length)
@@ -305,24 +315,29 @@ def _serialize(self, obj: pydantic.BaseModel, delimiter: _Delimiter) -> bytes:
return memoryview(data)

def yield_item(self, item: ItemT) -> bytes:
if self._last_sent in (_Delimiter.FOOTER, _Delimiter.END):
raise ValueError("Cannot yield item after sending footer / closing stream.")
self._last_sent = _Delimiter.ITEM
return self._serialize(item, _Delimiter.ITEM)


class _HeaderWriteMixin(_Streamer[ItemT, HeaderT, FooterTT]):
def yield_header(
self: _StreamWriterProtocol[ItemT, HeaderT, FooterTT], header: HeaderT
) -> bytes:
if self._stream_types.header_t is None or header is None:
raise ValueError()
if self._last_sent != _Delimiter.NOT_SET:
raise ValueError("Cannot yield header after other data has been sent.")
self._last_sent = _Delimiter.HEADER
return self._serialize(header, _Delimiter.HEADER)


class _FooterWriteMixin(_Streamer[ItemT, HeaderTT, FooterT]):
def yield_footer(
self: _StreamWriterProtocol[ItemT, HeaderTT, FooterT], footer: FooterT
) -> bytes:
if self._stream_types.header_t is None or footer is None:
raise ValueError()
if self._last_sent == _Delimiter.END:
raise ValueError("Cannot yield footer after closing stream.")
self._last_sent = _Delimiter.FOOTER
return self._serialize(footer, _Delimiter.FOOTER)


@@ -370,11 +385,11 @@ def stream_writer(
def stream_writer(
types: StreamTypes[ItemT, HeaderTT, FooterTT],
) -> _StreamWriter:
if types.header_t is None and types.footer_t is None:
if types.header_type is None and types.footer_type is None:
return _StreamWriter(types)
if types.header_t is None:
if types.header_type is None:
return StreamWriterWithFooter(types)
if types.footer_t is None:
if types.footer_type is None:
return StreamWriterWithHeader(types)

return StreamWriterFull(types)
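
Taken together, a minimal end-to-end sketch of the renamed API and the ordering now enforced via `_last_sent` (model classes as defined in the tests; each `yield_*` call returns one chunk framed by the 5-byte tag from `struct.pack(">BI", ...)`, i.e. a uint8 delimiter plus a uint32 length):

# Inside an async function:
types = streaming.stream_types(MyDataChunk, header_type=Header, footer_type=Footer)
writer = streaming.stream_writer(types)
chunks = [
    writer.yield_header(Header(time=0.0, msg="start")),  # must precede all other data
    writer.yield_item(MyDataChunk(words=["hello"])),
    writer.yield_footer(Footer(time=1.0, duration_sec=1.0, msg="end")),  # no items may follow
]
reader = streaming.stream_reader(types, to_bytes_iterator(chunks))
header = await reader.read_header()  # must be consumed before `read_items`
async for item in reader.read_items():
    print(item.words)
footer = await reader.read_footer()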
