From 680ba4e1c50d9baf5881d66132dd173a9038e495 Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Wed, 27 Nov 2024 10:21:47 -0800 Subject: [PATCH] Review comments --- .../examples/streaming/streaming_chain.py | 4 +- truss-chains/tests/test_streaming.py | 50 ++++++++---- truss-chains/truss_chains/framework.py | 2 +- truss-chains/truss_chains/streaming.py | 79 +++++++++++-------- 4 files changed, 87 insertions(+), 48 deletions(-) diff --git a/truss-chains/examples/streaming/streaming_chain.py b/truss-chains/examples/streaming/streaming_chain.py index 5ef979f50..4b1b2488d 100644 --- a/truss-chains/examples/streaming/streaming_chain.py +++ b/truss-chains/examples/streaming/streaming_chain.py @@ -30,7 +30,9 @@ class ConsumerOutput(pydantic.BaseModel): strings: str -STREAM_TYPES = streaming.stream_types(MyDataChunk, header_t=Header, footer_t=Footer) +STREAM_TYPES = streaming.stream_types( + MyDataChunk, header_type=Header, footer_type=Footer +) class Generator(chains.ChainletBase): diff --git a/truss-chains/tests/test_streaming.py b/truss-chains/tests/test_streaming.py index 1eee6c031..88dd5421a 100644 --- a/truss-chains/tests/test_streaming.py +++ b/truss-chains/tests/test_streaming.py @@ -30,7 +30,9 @@ async def to_bytes_iterator(data_stream) -> AsyncIterator[bytes]: @pytest.mark.asyncio async def test_streaming_with_header_and_footer(): - types = streaming.stream_types(item_t=MyDataChunk, header_t=Header, footer_t=Footer) + types = streaming.stream_types( + item_type=MyDataChunk, header_type=Header, footer_type=Footer + ) writer = streaming.stream_writer(types) header = Header(time=123.456, msg="Start of stream") @@ -61,7 +63,7 @@ async def test_streaming_with_header_and_footer(): @pytest.mark.asyncio async def test_streaming_with_items_only(): - types = streaming.stream_types(item_t=MyDataChunk) + types = streaming.stream_types(item_type=MyDataChunk) writer = streaming.stream_writer(types) items = [ @@ -84,8 +86,8 @@ async def test_streaming_with_items_only(): @pytest.mark.asyncio async def test_reading_header_when_none_sent(): - types = streaming.stream_types(item_t=MyDataChunk, header_t=Header) - writer = streaming.stream_writer(streaming.stream_types(item_t=MyDataChunk)) + types = streaming.stream_types(item_type=MyDataChunk, header_type=Header) + writer = streaming.stream_writer(types) items = [MyDataChunk(words=["hello", "world"])] data_stream = [] @@ -99,8 +101,8 @@ async def test_reading_header_when_none_sent(): @pytest.mark.asyncio async def test_reading_items_with_wrong_model(): - types_writer = streaming.stream_types(item_t=MyDataChunk) - types_reader = streaming.stream_types(item_t=Header) # Wrong item type + types_writer = streaming.stream_types(item_type=MyDataChunk) + types_reader = streaming.stream_types(item_type=Header) # Wrong item type writer = streaming.stream_writer(types_writer) items = [MyDataChunk(words=["hello", "world"])] data_stream = [] @@ -117,9 +119,9 @@ async def test_reading_items_with_wrong_model(): @pytest.mark.asyncio async def test_streaming_with_wrong_order(): types = streaming.stream_types( - item_t=MyDataChunk, - header_t=Header, - footer_t=Footer, + item_type=MyDataChunk, + header_type=Header, + footer_type=Footer, ) writer = streaming.stream_writer(types) @@ -129,11 +131,14 @@ async def test_streaming_with_wrong_order(): data_stream = [] for item in items: data_stream.append(writer.yield_item(item)) - data_stream.append(writer.yield_header(header)) + + with pytest.raises( + ValueError, match="Cannot yield header after other data has been sent." 
+ ): + data_stream.append(writer.yield_header(header)) data_stream.append(writer.yield_footer(footer)) reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) - # Try to read header, should fail because the first data is an item with pytest.raises(ValueError, match="Stream does not contain header."): await reader.read_header() @@ -141,7 +146,7 @@ async def test_streaming_with_wrong_order(): @pytest.mark.asyncio async def test_reading_items_without_consuming_header(): - types = streaming.stream_types(item_t=MyDataChunk, header_t=Header) + types = streaming.stream_types(item_type=MyDataChunk, header_type=Header) writer = streaming.stream_writer(types) header = Header(time=123.456, msg="Start of stream") items = [MyDataChunk(words=["hello", "world"])] @@ -163,8 +168,8 @@ async def test_reading_items_without_consuming_header(): @pytest.mark.asyncio async def test_reading_footer_when_none_sent(): - types = streaming.stream_types(item_t=MyDataChunk, footer_t=Footer) - writer = streaming.stream_writer(streaming.stream_types(item_t=MyDataChunk)) + types = streaming.stream_types(item_type=MyDataChunk, footer_type=Footer) + writer = streaming.stream_writer(types) items = [MyDataChunk(words=["hello", "world"])] data_stream = [] for item in items: @@ -179,3 +184,20 @@ async def test_reading_footer_when_none_sent(): # Try to read footer, expect an error with pytest.raises(ValueError, match="Stream does not contain footer."): await reader.read_footer() + + +@pytest.mark.asyncio +async def test_reading_footer_with_no_items(): + types = streaming.stream_types(item_type=MyDataChunk, footer_type=Footer) + writer = streaming.stream_writer(types) + footer = Footer(time=789.012, duration_sec=665.556, msg="End of stream") + data_stream = [writer.yield_footer(footer)] + + reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) + read_items = [] + async for item in reader.read_items(): + read_items.append(item) + assert len(read_items) == 0 + + read_footer = await reader.read_footer() + assert read_footer == footer diff --git a/truss-chains/truss_chains/framework.py b/truss-chains/truss_chains/framework.py index d86015974..db2f822fa 100644 --- a/truss-chains/truss_chains/framework.py +++ b/truss-chains/truss_chains/framework.py @@ -315,7 +315,7 @@ def _validate_streaming_output_type( raw=annotation, origin_type=origin, arg_type=bytes ) - assert len(args) == 1, "AsyncIterator cannot have more than 1 arg." + assert len(args) == 1, "Iterator type annotations cannot have more than 1 arg." arg = args[0] if arg not in _STREAM_TYPES: msg = ( diff --git a/truss-chains/truss_chains/streaming.py b/truss-chains/truss_chains/streaming.py index 6dde0d473..9d9a1cae8 100644 --- a/truss-chains/truss_chains/streaming.py +++ b/truss-chains/truss_chains/streaming.py @@ -8,8 +8,10 @@ import pydantic -TAG_SIZE = 5 # uint8 + uint32. -JSONType = Union[str, int, float, bool, None, list["JSONType"], dict[str, "JSONType"]] +_TAG_SIZE = 5 # uint8 + uint32. +_JSONType = Union[ + str, int, float, bool, None, list["_JSONType"], dict[str, "_JSONType"] +] _T = TypeVar("_T") if sys.version_info < (3, 10): @@ -43,56 +45,57 @@ async def anext(iterable: AsyncIterator[_T]) -> _T: @dataclasses.dataclass class StreamTypes(Generic[ItemT, HeaderTT, FooterTT]): - item_t: Type[ItemT] - header_t: HeaderTT # Is either `Type[HeaderT]` or `None`. - footer_t: FooterTT # Is either `Type[FooterT]` or `None`. + item_type: Type[ItemT] + header_type: HeaderTT # Is either `Type[HeaderT]` or `None`. 
+ footer_type: FooterTT # Is either `Type[FooterT]` or `None`. @overload def stream_types( - item_t: Type[ItemT], + item_type: Type[ItemT], *, - header_t: Type[HeaderT], - footer_t: Type[FooterT], + header_type: Type[HeaderT], + footer_type: Type[FooterT], ) -> StreamTypes[ItemT, HeaderT, FooterT]: ... @overload def stream_types( - item_t: Type[ItemT], + item_type: Type[ItemT], *, - header_t: Type[HeaderT], + header_type: Type[HeaderT], ) -> StreamTypes[ItemT, HeaderT, None]: ... @overload def stream_types( - item_t: Type[ItemT], + item_type: Type[ItemT], *, - footer_t: Type[FooterT], + footer_type: Type[FooterT], ) -> StreamTypes[ItemT, None, FooterT]: ... @overload -def stream_types(item_t: Type[ItemT]) -> StreamTypes[ItemT, None, None]: ... +def stream_types(item_type: Type[ItemT]) -> StreamTypes[ItemT, None, None]: ... def stream_types( - item_t: Type[ItemT], + item_type: Type[ItemT], *, - header_t: Optional[Type[HeaderT]] = None, - footer_t: Optional[Type[FooterT]] = None, + header_type: Optional[Type[HeaderT]] = None, + footer_type: Optional[Type[FooterT]] = None, ) -> StreamTypes: """Creates a bundle of item type and potentially header/footer types, each as pydantic model.""" # This indirection for creating `StreamTypes` is needed to get generic typing. - return StreamTypes(item_t, header_t, footer_t) + return StreamTypes(item_type, header_type, footer_type) # Reading ############################################################################## class _Delimiter(enum.IntEnum): + NOT_SET = enum.auto() HEADER = enum.auto() ITEM = enum.auto() FOOTER = enum.auto() @@ -161,7 +164,7 @@ def _unpack_tag(tag: bytes) -> tuple[_Delimiter, int]: async def _read(self) -> tuple[_Delimiter, bytes]: try: - tag = await self._stream.readexactly(TAG_SIZE) + tag = await self._stream.readexactly(_TAG_SIZE) # It's ok to read nothing (end of stream), but unexpected to read partial. except asyncio.IncompleteReadError: raise @@ -182,12 +185,13 @@ async def read_items(self) -> AsyncIterator[ItemT]: "Called `read_items`, but there the stream contains header data, which " "is not consumed. Call `read_header` first or remove sending a header." ) - if delimiter in (_Delimiter.FOOTER, _Delimiter.END): + if delimiter in (_Delimiter.FOOTER, _Delimiter.END): # In case of 0 items. + self._footer_data = data_bytes return assert delimiter == _Delimiter.ITEM while True: - yield self._stream_types.item_t.model_validate_json(data_bytes) + yield self._stream_types.item_type.model_validate_json(data_bytes) # We don't know if the next data is another item, footer or the end. 
             delimiter, data_bytes = await self._read()
             if delimiter == _Delimiter.END:
@@ -204,7 +208,7 @@ async def read_header(
         delimiter, data_bytes = await self._read()
         if delimiter != _Delimiter.HEADER:
             raise ValueError("Stream does not contain header.")
-        return self._stream_types.header_t.model_validate_json(data_bytes)
+        return self._stream_types.header_type.model_validate_json(data_bytes)
 
 
 class _FooterReadMixin(_Streamer[ItemT, HeaderTT, FooterT]):
@@ -219,7 +223,7 @@ async def read_footer(
                 raise ValueError("Stream does not contain footer.")
             self._footer_data = data_bytes
 
-        footer = self._stream_types.footer_t.model_validate_json(self._footer_data)
+        footer = self._stream_types.footer_type.model_validate_json(self._footer_data)
         self._footer_data = None
         return footer
 
@@ -273,11 +277,11 @@ def stream_reader(
     types: StreamTypes[ItemT, HeaderTT, FooterTT],
     stream: AsyncIterator[bytes],
 ) -> _StreamReader:
-    if types.header_t is None and types.footer_t is None:
+    if types.header_type is None and types.footer_type is None:
         return _StreamReader(types, stream)
-    if types.header_t is None:
+    if types.header_type is None:
         return StreamReaderWithFooter(types, stream)
-    if types.footer_t is None:
+    if types.footer_type is None:
         return StreamReaderWithHeader(types, stream)
     return StreamReaderFull(types, stream)
 
@@ -288,11 +292,17 @@ def stream_reader(
 
 class _StreamWriterProtocol(Protocol[ItemT, HeaderTT, FooterTT]):
     _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT]
+    _last_sent: _Delimiter
 
     def _serialize(self, obj: pydantic.BaseModel, delimiter: _Delimiter) -> bytes: ...
 
 
 class _StreamWriter(_Streamer[ItemT, HeaderTT, FooterTT]):
+    def __init__(self, types: StreamTypes[ItemT, HeaderTT, FooterTT]) -> None:
+        super().__init__(types)
+        self._last_sent = _Delimiter.NOT_SET
+        self._stream_types = types
+
     @staticmethod
     def _pack_tag(delimiter: _Delimiter, length: int) -> bytes:
         return struct.pack(">BI", delimiter.value, length)
@@ -305,6 +315,9 @@ def _serialize(self, obj: pydantic.BaseModel, delimiter: _Delimiter) -> bytes:
         return memoryview(data)
 
     def yield_item(self, item: ItemT) -> bytes:
+        if self._last_sent in (_Delimiter.FOOTER, _Delimiter.END):
+            raise ValueError("Cannot yield item after sending footer / closing stream.")
+        self._last_sent = _Delimiter.ITEM
         return self._serialize(item, _Delimiter.ITEM)
 
 
@@ -312,8 +325,9 @@ class _HeaderWriteMixin(_Streamer[ItemT, HeaderT, FooterTT]):
     def yield_header(
         self: _StreamWriterProtocol[ItemT, HeaderT, FooterTT], header: HeaderT
     ) -> bytes:
-        if self._stream_types.header_t is None or header is None:
-            raise ValueError()
+        if self._last_sent != _Delimiter.NOT_SET:
+            raise ValueError("Cannot yield header after other data has been sent.")
+        self._last_sent = _Delimiter.HEADER
         return self._serialize(header, _Delimiter.HEADER)
 
 
@@ -321,8 +335,9 @@ class _FooterWriteMixin(_Streamer[ItemT, HeaderTT, FooterT]):
     def yield_footer(
         self: _StreamWriterProtocol[ItemT, HeaderTT, FooterT], footer: FooterT
     ) -> bytes:
-        if self._stream_types.header_t is None or footer is None:
-            raise ValueError()
+        if self._last_sent == _Delimiter.END:
+            raise ValueError("Cannot yield footer after closing stream.")
+        self._last_sent = _Delimiter.FOOTER
         return self._serialize(footer, _Delimiter.FOOTER)
 
 
@@ -370,11 +385,11 @@ def stream_writer(
 def stream_writer(
     types: StreamTypes[ItemT, HeaderTT, FooterTT],
 ) -> _StreamWriter:
-    if types.header_t is None and types.footer_t is None:
+    if types.header_type is None and types.footer_type is None:
         return _StreamWriter(types)
-    if types.header_t is None:
+    if types.header_type is None:
         return StreamWriterWithFooter(types)
-    if types.footer_t is None:
+    if types.footer_type is None:
         return StreamWriterWithHeader(types)
     return StreamWriterFull(types)
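
Reviewer note (not part of the patch): a minimal round-trip sketch of the renamed
keyword arguments and the new writer-state guards, in the spirit of the tests above.
The model names `Chunk` and `Footer`, the `produce`/`consume` helpers, and the
`from truss_chains import streaming` import path are illustrative assumptions, not
taken from this diff.

    import asyncio
    from typing import AsyncIterator

    import pydantic

    from truss_chains import streaming


    class Chunk(pydantic.BaseModel):
        words: list[str]


    class Footer(pydantic.BaseModel):
        msg: str


    # Footer-only stream: `stream_writer`/`stream_reader` dispatch to the
    # footer-capable variants.
    TYPES = streaming.stream_types(Chunk, footer_type=Footer)


    async def produce() -> AsyncIterator[bytes]:
        writer = streaming.stream_writer(TYPES)
        yield bytes(writer.yield_item(Chunk(words=["hello", "world"])))
        yield bytes(writer.yield_footer(Footer(msg="done")))
        # A further `writer.yield_item(...)` here would now raise:
        # ValueError: Cannot yield item after sending footer / closing stream.


    async def consume() -> None:
        reader = streaming.stream_reader(TYPES, produce())
        async for chunk in reader.read_items():
            print(chunk.words)
        footer = await reader.read_footer()  # Buffered by `read_items`.
        print(footer.msg)


    asyncio.run(consume())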
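
Also for reviewers, a sketch of the wire framing implied by `_pack_tag`/`_unpack_tag`
and `_TAG_SIZE`: each frame is a 5-byte tag (struct format ">BI", i.e. a uint8
delimiter plus a big-endian uint32 payload length) followed by the JSON payload.
`pack_tag`/`unpack_tag` below are local stand-ins for the module's private helpers,
and the delimiter value 3 assumes `enum.auto()` numbering with `NOT_SET` added first.

    import struct

    _TAG_FORMAT = ">BI"  # uint8 delimiter + big-endian uint32 length.
    assert struct.calcsize(_TAG_FORMAT) == 5  # Matches `_TAG_SIZE` in the diff.


    def pack_tag(delimiter_value: int, length: int) -> bytes:
        # Stand-in for `_StreamWriter._pack_tag` (which takes a `_Delimiter` enum).
        return struct.pack(_TAG_FORMAT, delimiter_value, length)


    def unpack_tag(tag: bytes) -> tuple[int, int]:
        # Stand-in for the reader-side `_unpack_tag`.
        delimiter_value, length = struct.unpack(_TAG_FORMAT, tag)
        return delimiter_value, length


    payload = b'{"words": ["hello"]}'
    frame = pack_tag(3, len(payload)) + payload  # 3 == ITEM under the assumption above.
    assert unpack_tag(frame[:5]) == (3, len(payload))
    assert frame[5 : 5 + len(payload)] == payload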