From 380c4f454af2c53526ac9e5a3eed2991f0d8110a Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Thu, 14 Nov 2024 13:05:31 -0800 Subject: [PATCH 01/11] It works! Next step is clean up. Get types right. Maybe move some things in repo and cleanup code gen. Try bytes iterator, works local --- .../examples/streaming/streaming_chain.py | 106 ++++++++++ truss-chains/truss_chains/code_gen.py | 135 +++++++----- truss-chains/truss_chains/framework.py | 29 ++- truss-chains/truss_chains/model_skeleton.py | 5 +- truss-chains/truss_chains/remote.py | 8 +- truss-chains/truss_chains/streaming.py | 192 ++++++++++++++++++ truss-chains/truss_chains/stub.py | 53 ++++- truss-chains/truss_chains/utils.py | 68 ++++--- truss/templates/server/common/schema.py | 6 +- 9 files changed, 511 insertions(+), 91 deletions(-) create mode 100644 truss-chains/examples/streaming/streaming_chain.py create mode 100644 truss-chains/truss_chains/streaming.py diff --git a/truss-chains/examples/streaming/streaming_chain.py b/truss-chains/examples/streaming/streaming_chain.py new file mode 100644 index 000000000..48069a4e9 --- /dev/null +++ b/truss-chains/examples/streaming/streaming_chain.py @@ -0,0 +1,106 @@ +import logging +from typing import AsyncIterator + +LOG_FORMAT = ( + "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s" +) + +logging.basicConfig( + level=logging.INFO, format=LOG_FORMAT, datefmt="%Y-%m-%d %H:%M:%S", force=True +) + +logging.info("Start") + +import asyncio +import time + +import pydantic + +logging.info("Import Chains") +import truss_chains as chains +from truss_chains import streaming + +logging.info("Chains imported") + + +class Header(pydantic.BaseModel): + time: float + msg: str + + +class MyDataChunk(pydantic.BaseModel): + words: list[str] + # numbers: np.ndarray + + +class Footer(pydantic.BaseModel): + time: float + duration_sec: float + msg: str + + +class Generator(chains.ChainletBase): + async def run_remote(self) -> AsyncIterator[bytes]: + print("Entering Generator") + streamer = streaming.stream_writer( + MyDataChunk, header_t=Header, footer_t=Footer + ) + header = Header(time=time.time(), msg="Start.") + yield streamer.yield_header(header) + for i in range(1, 5): + # numbers = np.full((3, 4), i) + data = MyDataChunk( + words=[chr(x + 70) * x for x in range(1, i + 1)], + # numbers=numbers + ) + print("Yield") + # await streamer.yield_header(item) # TyeError because type mismatch. + yield streamer.yield_item(data) + # if i >2: + # raise ValueError() + await asyncio.sleep(0.2) + + end_time = time.time() + footer = Footer(time=end_time, duration_sec=end_time - header.time, msg="Done.") + yield streamer.yield_footer(footer) # TyeError because footer type is None. + print("Exiting Generator") + + +class Consumer(chains.ChainletBase): + def __init__(self, generator=chains.depends(Generator)): + self._generator = generator + + async def run_remote(self) -> None: + print("Entering Consumer") + reader = streaming.StreamReader( + self._generator.run_remote(), MyDataChunk, header_t=Header, footer_t=Footer + ) + print("Consuming...") + header = await reader.read_header() + print(header) + async for data in reader.read_items(): + print(f"Read: {data}") + + # reader.yield_item() # Type error, is reader, not writer. + # footer = await generator.reader_footer() # Example does not have a footer. 
+ print("Exiting Consumer") + + +logging.info("Module initialized") + +if __name__ == "__main__": + with chains.run_local(): + chain = Consumer() + result = asyncio.run(chain.run_remote()) + print(result) + + + from truss_chains import definitions, remote + + service = remote.push( + Consumer, + options=definitions.PushOptionsLocalDocker( + chain_name="stream", only_generate_trusses=False, use_local_chains_src=True + ), + ) + service.run_remote({}) diff --git a/truss-chains/truss_chains/code_gen.py b/truss-chains/truss_chains/code_gen.py index 832e7c524..58c9c3269 100644 --- a/truss-chains/truss_chains/code_gen.py +++ b/truss-chains/truss_chains/code_gen.py @@ -23,6 +23,7 @@ requirement (site-package), it will not be copied from the local host. """ +import collections import logging import os import pathlib @@ -32,7 +33,7 @@ import subprocess import sys import textwrap -from typing import Any, Iterable, Mapping, Optional +from typing import Any, Iterable, Mapping, Optional, get_args, get_origin import libcst import truss @@ -93,7 +94,7 @@ def _update_src(new_source: _Source, src_parts: list[str], imports: set[str]) -> imports.update(new_source.imports) -def _gen_import_and_ref(raw_type: Any) -> _Source: +def _gen_pydantic_import_and_ref(raw_type: Any) -> _Source: """Returns e.g. ("from sub_package import module", "module.OutputType").""" if raw_type.__module__ == "__main__": # TODO: assuming that main is copied into package dir and can be imported. @@ -122,7 +123,7 @@ def _gen_import_and_ref(raw_type: Any) -> _Source: def _gen_type_import_and_ref(type_descr: definitions.TypeDescriptor) -> _Source: """Returns e.g. ("from sub_package import module", "module.OutputType").""" if type_descr.is_pydantic: - return _gen_import_and_ref(type_descr.raw) + return _gen_pydantic_import_and_ref(type_descr.raw) elif isinstance(type_descr.raw, type): if not type_descr.raw.__module__ == "builtins": @@ -134,11 +135,28 @@ def _gen_type_import_and_ref(type_descr: definitions.TypeDescriptor) -> _Source: return _Source(src=str(type_descr.raw)) +def _gen_generator_type_import_and_ref( + endpoint: definitions.EndpointAPIDescriptor, +) -> _Source: + """Unlike other `_gen`-helpers, this does not define a type, it creates a symbol.""" + assert len(endpoint.output_types) == 1 + output_type = endpoint.output_types[0] + assert not output_type.is_pydantic + origin = get_origin(output_type.raw) + assert origin in (collections.abc.AsyncIterator, collections.abc.Iterator), origin + args = get_args(output_type.raw) + assert len(args) == 1, "AsyncIterator cannot have more than 1 arg." + + arg = args[0] + type_src = f"{origin.__module__}.{origin.__name__}[{arg.__name__}]" + return _Source(src=type_src, imports={f"import {origin.__module__}"}) + + def _gen_chainlet_import_and_ref( chainlet_descriptor: definitions.ChainletAPIDescriptor, ) -> _Source: """Returns e.g. 
("from sub_package import module", "module.OutputType").""" - return _gen_import_and_ref(chainlet_descriptor.chainlet_cls) + return _gen_pydantic_import_and_ref(chainlet_descriptor.chainlet_cls) # I/O used by Stubs and Truss models ################################################### @@ -206,28 +224,31 @@ async def run_remote( ) -> tuple[shared_chainlet.SplitTextOutput, int]: ``` """ - if endpoint.is_generator: - raise NotImplementedError("Generator.") - imports = set() - args = [] + args = ["self"] for arg in endpoint.input_args: arg_ref = _gen_type_import_and_ref(arg.type) imports.update(arg_ref.imports) args.append(f"{arg.name}: {arg_ref.src}") - outputs: list[str] = [] - for output_type in endpoint.output_types: - _update_src(_gen_type_import_and_ref(output_type), outputs, imports) - - if len(outputs) == 1: - output = outputs[0] + if endpoint.is_generator: + generator_src = _gen_generator_type_import_and_ref(endpoint) + imports.update(generator_src.imports) + output = generator_src.src else: - output = f"tuple[{', '.join(outputs)}]" + outputs: list[str] = [] + for output_type in endpoint.output_types: + _update_src(_gen_type_import_and_ref(output_type), outputs, imports) + + if len(outputs) == 1: + output = outputs[0] + else: + output = f"tuple[{', '.join(outputs)}]" + # If we produce an async generator, we just pass it through. def_str = "async def" if endpoint.is_async else "def" return _Source( - src=f"{def_str} {endpoint.name}(self, {','.join(args)}) -> {output}:", + src=f"{def_str} {endpoint.name}({','.join(args)}) -> {output}:", imports=imports, ) @@ -244,23 +265,37 @@ def _stub_endpoint_body_src( return SplitTextOutput.model_validate(json_result).output ``` """ - if endpoint.is_generator: - raise NotImplementedError("Generator") - imports: set[str] = set() args = [f"{arg.name}={arg.name}" for arg in endpoint.input_args] - inputs = f"{_get_input_model_name(chainlet_name)}({', '.join(args)}).model_dump()" + if args: + inputs = ( + f"{_get_input_model_name(chainlet_name)}({', '.join(args)}).model_dump()" + ) + else: + inputs = "{}" # Invoke remote. - if endpoint.is_async: - remote_call = f"await self._remote.predict_async({inputs})" + if not endpoint.is_generator: + if endpoint.is_async: + remote_call = f"await self._remote.predict_async({inputs})" + else: + remote_call = f"self._remote.predict_sync({inputs})" + + parts = [f"json_result = {remote_call}"] + # Unpack response and parse as pydantic models if needed. + output_model_name = _get_output_model_name(chainlet_name) + parts.append(f"return {output_model_name}.model_validate(json_result).root") else: - remote_call = f"self._remote.predict_sync({inputs})" + if endpoint.is_async: + parts = [ + f"async for data in await self._remote.predict_async_stream({inputs}):", + _indent("yield data"), + ] + else: + raise NotImplementedError( + "Streaming/Generator only supported for async `run_remote`." + ) - parts = [f"json_result = {remote_call}"] - # Unpack response and parse as pydantic models if needed. 
- output_model_name = _get_output_model_name(chainlet_name) - parts.append(f"return {output_model_name}.model_validate(json_result).root") return _Source(src="\n".join(parts), imports=imports) @@ -290,8 +325,9 @@ async def run_remote( src_parts: list[str] = [] input_src = _gen_truss_input_pydantic(chainlet) _update_src(input_src, src_parts, imports) - output_src = _gen_truss_output_pydantic(chainlet) - _update_src(output_src, src_parts, imports) + if not chainlet.endpoint.is_generator: + output_src = _gen_truss_output_pydantic(chainlet) + _update_src(output_src, src_parts, imports) signature = _stub_endpoint_signature_src(chainlet.endpoint) imports.update(signature.imports) body = _stub_endpoint_body_src(chainlet.endpoint, chainlet.name) @@ -396,42 +432,48 @@ def _gen_load_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _So def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _Source: """Generates AST for the `predict` method of the truss model.""" - if chainlet_descriptor.endpoint.is_generator: - raise NotImplementedError("Generator.") - imports: set[str] = {"from truss_chains import utils"} parts: list[str] = [] def_str = "async def" if chainlet_descriptor.endpoint.is_async else "def" input_model_name = _get_input_model_name(chainlet_descriptor.name) - output_model_name = _get_output_model_name(chainlet_descriptor.name) + if chainlet_descriptor.endpoint.is_generator: + generator_src = _gen_generator_type_import_and_ref(chainlet_descriptor.endpoint) + imports.update(generator_src.imports) + output_type_name = generator_src.src + else: + output_type_name = _get_output_model_name(chainlet_descriptor.name) imports.add("import starlette.requests") imports.add("from truss_chains import stub") parts.append( f"{def_str} predict(self, inputs: {input_model_name}, " - f"request: starlette.requests.Request) -> {output_model_name}:" + f"request: starlette.requests.Request) -> {output_type_name}:" ) # Add error handling context manager: parts.append( _indent( f"with stub.trace_parent(request), utils.exception_to_http_error(" - f'include_stack=True, chainlet_name="{chainlet_descriptor.name}"):' + f'chainlet_name="{chainlet_descriptor.name}"):' ) ) # Invoke Chainlet. - maybe_await = "await " if chainlet_descriptor.endpoint.is_async else "" + if ( + chainlet_descriptor.endpoint.is_async + and not chainlet_descriptor.endpoint.is_generator + ): + maybe_await = "await " + else: + maybe_await = "" run_remote = chainlet_descriptor.endpoint.name - # `exclude_unset` is important to handle arguments where `run_remote` has a default - # correctly. In that case the pydantic model has an optional field and defaults to - # `None`. But there might also be situations where the user explicitly passes a - # value of `None`. So the condition whether to pass that argument or not is - # whether it was *set* in the model. It is considered unset, if the incoming JSON - # (from which the model was parsed/initialized) does not have that key. + # See docs of `pydantic_set_field_dict` for why this is needed. 
args = "**utils.pydantic_set_field_dict(inputs)" parts.append( _indent(f"result = {maybe_await}self._chainlet.{run_remote}({args})", 2) ) - result_pydantic = f"{output_model_name}(result)" - parts.append(_indent(f"return {result_pydantic}")) + if chainlet_descriptor.endpoint.is_generator: + parts.append(_indent("return result")) + else: + result_pydantic = f"{output_type_name}(result)" + parts.append(_indent(f"return {result_pydantic}")) return _Source(src="\n".join(parts), imports=imports) @@ -496,8 +538,9 @@ def _gen_truss_chainlet_file( input_src = _gen_truss_input_pydantic(chainlet_descriptor) _update_src(input_src, src_parts, imports) - output_src = _gen_truss_output_pydantic(chainlet_descriptor) - _update_src(output_src, src_parts, imports) + if not chainlet_descriptor.endpoint.is_generator: + output_src = _gen_truss_output_pydantic(chainlet_descriptor) + _update_src(output_src, src_parts, imports) model_src = _gen_truss_chainlet_model(chainlet_descriptor) _update_src(model_src, src_parts, imports) diff --git a/truss-chains/truss_chains/framework.py b/truss-chains/truss_chains/framework.py index 771ee73d0..4ea5648ad 100644 --- a/truss-chains/truss_chains/framework.py +++ b/truss-chains/truss_chains/framework.py @@ -36,11 +36,13 @@ _SIMPLE_TYPES = {int, float, complex, bool, str, bytes, None} _SIMPLE_CONTAINERS = {list, dict} +_STREAM_TYPES = {bytes, str} _DOCS_URL_CHAINING = ( "https://docs.baseten.co/chains/concepts#depends-call-other-chainlets" ) _DOCS_URL_LOCAL = "https://docs.baseten.co/chains/guide#local-development" +_DOCS_URL_STREAMING = "https://docs.baseten.co/chains/guide#streaming" _ENTRYPOINT_ATTR_NAME = "_chains_entrypoint" @@ -48,6 +50,7 @@ _P = ParamSpec("_P") _R = TypeVar("_R") + # Error Collector ###################################################################### @@ -296,6 +299,24 @@ def _validate_io_type( _collect_error(error_msg, _ErrorKind.IO_TYPE_ERROR, location) +def _validate_generator_output_type( + annotation: Any, param_name: str, location: _ErrorLocation +) -> None: + """ + For Chainlet I/O (both data or parameters), we allow simple types + (int, str, float...) and `list` or `dict` containers of these. + Any deeper nested and structured data must be typed as a pydantic model. + """ + origin = get_origin(annotation) + assert origin in (collections.abc.AsyncIterator, collections.abc.Iterator) + args = get_args(annotation) + assert len(args) == 1, "AsyncIterator cannot have more than 1 arg." 
+ arg = args[0] + if arg not in _SIMPLE_TYPES: + msg = "TODO TODO" + _collect_error(msg, _ErrorKind.IO_TYPE_ERROR, location) + + def _validate_endpoint_params( params: list[inspect.Parameter], location: _ErrorLocation ) -> list[definitions.InputArg]: @@ -346,11 +367,15 @@ def _validate_endpoint_output_types( location, ) return [] - if get_origin(annotation) is tuple: + origin = get_origin(annotation) + if origin is tuple: output_types = [] for i, arg in enumerate(get_args(annotation)): _validate_io_type(arg, f"return_type[{i}]", location) output_types.append(definitions.TypeDescriptor(raw=arg)) + if origin in (collections.abc.AsyncIterator, collections.abc.Iterator): + _validate_generator_output_type(annotation, "return_type", location) + output_types = [definitions.TypeDescriptor(raw=annotation)] else: _validate_io_type(annotation, "return_type", location) output_types = [definitions.TypeDescriptor(raw=annotation)] @@ -995,7 +1020,7 @@ def __init_local__(self: definitions.ABCChainlet, **kwargs) -> None: assert chainlet_cls._init_is_patched # Dependency chainlets are instantiated here, using their __init__ # that is patched for local. - logging.warning(f"Making first {dep.name}.") + logging.info(f"Making first {dep.name}.") instance = chainlet_cls() # type: ignore # Here init args are patched. cls_to_instance[chainlet_cls] = instance kwargs_mod[arg_name] = instance diff --git a/truss-chains/truss_chains/model_skeleton.py b/truss-chains/truss_chains/model_skeleton.py index 6f637e8d9..4aa053178 100644 --- a/truss-chains/truss_chains/model_skeleton.py +++ b/truss-chains/truss_chains/model_skeleton.py @@ -16,9 +16,8 @@ def __init__( config: dict, data_dir: pathlib.Path, secrets: secrets_resolver.Secrets, - environment: Optional[ - dict - ] = None, # TODO: Remove the default value once all truss versions are synced up. + # TODO: Remove the default value once all truss versions are synced up. + environment: Optional[dict] = None, ) -> None: truss_metadata: definitions.TrussMetadata = ( definitions.TrussMetadata.model_validate( diff --git a/truss-chains/truss_chains/remote.py b/truss-chains/truss_chains/remote.py index 91304c4ba..2b5f73863 100644 --- a/truss-chains/truss_chains/remote.py +++ b/truss-chains/truss_chains/remote.py @@ -44,8 +44,7 @@ class DockerTrussService(b10_service.TrussService): """This service is for Chainlets (not for Chains).""" def __init__(self, port: int, is_draft: bool, **kwargs): - # http://localhost:{port} seems to only work *sometimes* with docker. 
- remote_url = f"http://host.docker.internal:{port}" + remote_url = f"http://localhost:{port}" self._port = port super().__init__(remote_url, is_draft, **kwargs) @@ -411,8 +410,11 @@ def push( is_draft=True, port=port, ) + docker_internal_url = service.predict_url.replace( + "localhost", "host.docker.internal" + ) chainlet_to_predict_url[chainlet_artifact.display_name] = { - "predict_url": service.predict_url, + "predict_url": docker_internal_url, } chainlet_to_service[chainlet_artifact.name] = service diff --git a/truss-chains/truss_chains/streaming.py b/truss-chains/truss_chains/streaming.py new file mode 100644 index 000000000..6bcd8d946 --- /dev/null +++ b/truss-chains/truss_chains/streaming.py @@ -0,0 +1,192 @@ +import asyncio +import enum +import struct +import sys +from collections.abc import AsyncIterator +from typing import Optional, Protocol, overload + +import pydantic +from truss.templates.shared import serialization +from typing_extensions import Generic, Type, TypeVar, runtime_checkable + +_T = TypeVar("_T") + +if sys.version_info < (3, 10): + + async def anext(iterable: AsyncIterator[_T]) -> _T: + return await iterable.__anext__() + + +# Type variables for Header, Item, and Footer +TItem = TypeVar("TItem", bound=pydantic.BaseModel) +THeader = TypeVar("THeader", pydantic.BaseModel, None) +TFooter = TypeVar("TFooter", pydantic.BaseModel, None) + +TAG_SIZE = 5 # uint8 + uint32. + + +@runtime_checkable +class _StreamReaderLike(Protocol): + async def readexactly(self, num_bytes: int) -> bytes: ... + + +class _ByteReader: + def __init__(self, source: AsyncIterator[bytes]) -> None: + self._source = source + self._buffer = bytearray() + self._lock = asyncio.Lock() + + async def readexactly(self, num_bytes: int) -> bytes: + async with self._lock: + while len(self._buffer) < num_bytes: + try: + chunk = await anext(self._source) + except StopAsyncIteration: + break + self._buffer.extend(chunk) + + if len(self._buffer) < num_bytes: + raise EOFError("TODO") + + result = bytes(self._buffer[:num_bytes]) + del self._buffer[:num_bytes] + return result + + +class Delimiter(enum.IntEnum): + HEADER = enum.auto() + ITEM = enum.auto() + FOOTER = enum.auto() + END = enum.auto() + + +class Streamer(Generic[TItem, THeader, TFooter]): + _item_t: Type[TItem] + _header_t: Optional[Type[THeader]] + _footer_t: Optional[Type[TFooter]] + + def __init__( + self, + item_t: Type[TItem], + header_t: Optional[Type[THeader]], + footer_t: Optional[Type[TFooter]], + ) -> None: + self._item_t = item_t + self._header_t = header_t + self._footer_t = footer_t + + +class StreamReader(Streamer[TItem, THeader, TFooter]): + _stream: _StreamReaderLike + + def __init__( + self, + stream: AsyncIterator[bytes], + item_t: Type[TItem], + header_t: Optional[Type[THeader]], + footer_t: Optional[Type[TFooter]], + ) -> None: + super().__init__(item_t, header_t, footer_t) + self._stream = _ByteReader(stream) + self._footer = None + + @staticmethod + def _unpack_tag(tag: bytes) -> tuple[Delimiter, int]: + enum_value, length = struct.unpack(">BI", tag) + return Delimiter(enum_value), length + + async def _read(self) -> tuple[Delimiter, serialization.MsgPackType]: + tag = await self._stream.readexactly(TAG_SIZE) + delimiter, length = self._unpack_tag(tag) + if not length: + return delimiter, None + data_bytes = await self._stream.readexactly(length) + return delimiter, serialization.truss_msgpack_deserialize(data_bytes) + + async def read_header(self) -> THeader: + delimiter, data_dict = await self._read() + assert delimiter == 
Delimiter.HEADER + return self._header_t.parse_obj(data_dict) + + async def read_items(self) -> AsyncIterator[TItem]: + delimiter, data_dict = await self._read() + assert delimiter == Delimiter.ITEM + + while delimiter == Delimiter.ITEM: + yield self._item_t.parse_obj(data_dict) + # Read next: either item, footer, or end. + delimiter, data_dict = await self._read() + if delimiter == Delimiter.END: + return + if delimiter == Delimiter.FOOTER: + self._footer = self._footer_t.parse_obj(data_dict) + return + + async def read_footer(self) -> TFooter: + if self._footer_t is None: + raise ValueError() + footer = self._footer_t + self._footer_t = None + return footer + + +class StreamWriter(Streamer[TItem, THeader, TFooter]): + @staticmethod + def _pack_tag(delimiter: Delimiter, length: int) -> bytes: + return struct.pack(">BI", delimiter.value, length) + + def _serialize(self, obj: pydantic.BaseModel, delimiter: Delimiter) -> bytes: + data_dict = obj.dict() + data_bytes = serialization.truss_msgpack_serialize(data_dict) + data = bytearray(self._pack_tag(delimiter, len(data_bytes))) + data.extend(data_bytes) + print(data) + # Starlette cannot handle byte array. + return memoryview(data) + + def yield_header(self, header: THeader) -> bytes: + return self._serialize(header, Delimiter.HEADER) + + def yield_item(self, item: TItem) -> bytes: + return self._serialize(item, Delimiter.ITEM) + + def yield_footer(self, footer: TFooter) -> bytes: + return self._serialize(footer, Delimiter.FOOTER) + + +@overload +def stream_writer( + item_t: Type[TItem], + *, + header_t: Type[THeader], + footer_t: Type[TFooter], +) -> StreamWriter[TItem, THeader, TFooter]: ... + + +@overload +def stream_writer( + item_t: Type[TItem], + *, + header_t: Type[THeader], +) -> StreamWriter[TItem, THeader, None]: ... + + +@overload +def stream_writer( + item_t: Type[TItem], + *, + footer_t: Type[TFooter], +) -> StreamWriter[TItem, None, TFooter]: ... + + +@overload +def stream_writer(item_t: Type[TItem]) -> StreamWriter[TItem, None, None]: ... + + +def stream_writer( + item_t: Type[TItem], + *, + header_t: Optional[Type[THeader]] = None, + footer_t: Optional[Type[TFooter]] = None, +) -> StreamWriter: + return StreamWriter(item_t, header_t, footer_t) diff --git a/truss-chains/truss_chains/stub.py b/truss-chains/truss_chains/stub.py index 6e0927a30..1091462e4 100644 --- a/truss-chains/truss_chains/stub.py +++ b/truss-chains/truss_chains/stub.py @@ -6,7 +6,17 @@ import ssl import threading import time -from typing import Any, ClassVar, Iterator, Mapping, Optional, Type, TypeVar, final +from typing import ( + Any, + AsyncIterator, + ClassVar, + Iterator, + Mapping, + Optional, + Type, + TypeVar, + final, +) import aiohttp import httpx @@ -139,14 +149,16 @@ def predict_sync(self, json_payload): try: with self._sync_num_requests as num_requests: self._maybe_warn_for_overload(num_requests) - resp = self._client_sync().post( + response = self._client_sync().post( self._service_descriptor.predict_url, json=json_payload, headers={ definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() }, ) - return utils.handle_response(resp, self.name) + utils.response_raise_errors(response, self.name) + return response.json() + # As a special case we invalidate the client in case of certificate # errors. This has happened in the past and is a defensive measure. 
except ssl.SSLError: @@ -167,14 +179,45 @@ async def predict_async(self, json_payload): client = await self._client_async() async with self._async_num_requests as num_requests: self._maybe_warn_for_overload(num_requests) - resp = await client.post( + async with client.post( + self._service_descriptor.predict_url, + json=json_payload, + headers={ + definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() + }, + ) as response: + await utils.async_response_raise_errors(response, self.name) + return await response.json() + # As a special case we invalidate the client in case of certificate + # errors. This has happened in the past and is a defensive measure. + except ssl.SSLError: + self._cached_async_client = None + raise + + async def predict_async_stream(self, json_payload) -> AsyncIterator[bytes]: + retrying = tenacity.AsyncRetrying( + stop=tenacity.stop_after_attempt(self._service_descriptor.options.retries), + retry=tenacity.retry_if_exception_type(Exception), + reraise=True, + ) + async for attempt in retrying: + with attempt: + if (num := attempt.retry_state.attempt_number) > 1: + logging.info(f"Retrying `{self.name}`, " f"attempt {num}") + try: + client = await self._client_async() + async with self._async_num_requests as num_requests: + self._maybe_warn_for_overload(num_requests) + response = await client.post( self._service_descriptor.predict_url, json=json_payload, headers={ definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() }, ) - return await utils.handle_async_response(resp, self.name) + await utils.async_response_raise_errors(response, self.name) + return response.content.iter_any() + # As a special case we invalidate the client in case of certificate # errors. This has happened in the past and is a defensive measure. except ssl.SSLError: diff --git a/truss-chains/truss_chains/utils.py b/truss-chains/truss_chains/utils.py index f29853542..28a485451 100644 --- a/truss-chains/truss_chains/utils.py +++ b/truss-chains/truss_chains/utils.py @@ -186,28 +186,27 @@ def populate_chainlet_service_predict_urls( # Error Propagation Utils. ############################################################# +# TODO: move request related code into `stub.py`. -def _handle_exception( - exception: Exception, include_stack: bool, chainlet_name: str -) -> NoReturn: +def _handle_exception(exception: Exception, chainlet_name: str) -> NoReturn: """Raises `fastapi.HTTPException` with `RemoteErrorDetail` as detail.""" if hasattr(exception, "__module__"): exception_module_name = exception.__module__ else: exception_module_name = None - if include_stack: - error_stack = traceback.extract_tb(exception.__traceback__) - # Exclude the error handling functions from the stack trace. - exclude_frames = {exception_to_http_error.__name__, handle_response.__name__} - final_tb = [frame for frame in error_stack if frame.name not in exclude_frames] - stack = list( - [definitions.StackFrame.from_frame_summary(frame) for frame in final_tb] - ) - else: - stack = [] - + error_stack = traceback.extract_tb(exception.__traceback__) + # Exclude the error handling functions from the stack trace. 
+ exclude_frames = { + exception_to_http_error.__name__, + response_raise_errors.__name__, + async_response_raise_errors.__name__, + } + final_tb = [frame for frame in error_stack if frame.name not in exclude_frames] + stack = list( + [definitions.StackFrame.from_frame_summary(frame) for frame in final_tb] + ) error = definitions.RemoteErrorDetail( remote_name=chainlet_name, exception_cls_name=exception.__class__.__name__, @@ -221,11 +220,12 @@ def _handle_exception( @contextlib.contextmanager -def exception_to_http_error(include_stack: bool, chainlet_name: str) -> Iterator[None]: +def exception_to_http_error(chainlet_name: str) -> Iterator[None]: + # TODO: move chainlet name from here to caller side. try: yield except Exception as e: - _handle_exception(e, include_stack, chainlet_name) + _handle_exception(e, chainlet_name) def _resolve_exception_class( @@ -279,8 +279,8 @@ def _handle_response_error(response_json: dict, remote_name: str): raise exception_cls(msg) -def handle_response(response: httpx.Response, remote_name: str) -> Any: - """For successful requests returns JSON, otherwise raises error. +def response_raise_errors(response: httpx.Response, remote_name: str) -> None: + """In case of error, raise it. If the response error contains `RemoteErrorDetail`, it tries to re-raise the same exception that was raised remotely and falls back to @@ -334,17 +334,11 @@ def handle_response(response: httpx.Response, remote_name: str) -> Any: ) from e _handle_response_error(response_json=response_json, remote_name=remote_name) - return response.json() - -async def handle_async_response( +async def async_response_raise_errors( response: aiohttp.ClientResponse, remote_name: str -) -> Any: - """For successful requests returns JSON, otherwise raises error. - - See `handle_response` for more details on the specifics of the error-handling - here. - """ +) -> None: + """Async version of `async_response_raise_errors`.""" if response.status >= 400: try: response_json = await response.json() @@ -353,10 +347,10 @@ async def handle_async_response( "Could not get JSON from error response. Status: " f"`{response.status}`." ) from e - _handle_response_error(response_json=response_json, remote_name=remote_name) - return await response.json() + +######################################################################################## class InjectedError(Exception): @@ -417,7 +411,21 @@ def issubclass_safe(x: Any, cls: type) -> bool: def pydantic_set_field_dict(obj: pydantic.BaseModel) -> dict[str, pydantic.BaseModel]: - """Like `BaseModel.model_dump(exclude_unset=True), but only top-level.""" + """Like `BaseModel.model_dump(exclude_unset=True), but only top-level. + + This is used to get kwargs for invoking a function, while dropping fields for which + there is no value explicitly set in the pydantic model. A field is considered unset + if the key was not present in the incoming JSON request (from which the model was + parsed/initialized) and the pydantic model has a default value, such as `None`. + + By dropping these unset fields, the default values from the function definition + will be used instead. This behavior ensures correct handling of arguments where + the function has a default, such as in the case of `run_remote`. If the model has + an optional field defaulting to `None`, this approach differentiates between + the user explicitly passing a value of `None` and the field being unset in the + request. 
+ + """ return {name: getattr(obj, name) for name in obj.__fields_set__} diff --git a/truss/templates/server/common/schema.py b/truss/templates/server/common/schema.py index 89e7060f4..0201af824 100644 --- a/truss/templates/server/common/schema.py +++ b/truss/templates/server/common/schema.py @@ -2,8 +2,10 @@ from typing import ( Any, AsyncGenerator, + AsyncIterator, Awaitable, Generator, + Iterator, List, Optional, Type, @@ -83,7 +85,7 @@ def _annotation_is_pydantic_model(annotation: Any) -> bool: def _parse_output_type(output_annotation: Any) -> Optional[OutputType]: """ - Therea are 4 possible cases for output_annotation: + There are 4 possible cases for output_annotation: 1. Data object -- represented by a Pydantic BaseModel 2. Streaming -- represented by a Generator or AsyncGenerator 3. Async -- represented by an Awaitable @@ -117,7 +119,7 @@ def _parse_output_type(output_annotation: Any) -> Optional[OutputType]: def _is_generator_type(annotation: Any) -> bool: base_type = get_origin(annotation) return isinstance(base_type, type) and issubclass( - base_type, (Generator, AsyncGenerator) + base_type, (Generator, AsyncGenerator, Iterator, AsyncIterator) ) From 9e48c887e57a86422e219eb06dc341adbc0ddcd2 Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Fri, 22 Nov 2024 18:37:52 -0800 Subject: [PATCH 02/11] Crazy typing WIP --- docs/examples/streaming.mdx | 2 +- .../examples/streaming/streaming_chain.py | 33 ++- truss-chains/truss_chains/streaming.py | 253 +++++++++++------- truss-chains/truss_chains/type_experiment.py | 53 ++++ 4 files changed, 232 insertions(+), 109 deletions(-) create mode 100644 truss-chains/truss_chains/type_experiment.py diff --git a/docs/examples/streaming.mdx b/docs/examples/streaming.mdx index 1767396b2..6c2376941 100644 --- a/docs/examples/streaming.mdx +++ b/docs/examples/streaming.mdx @@ -292,7 +292,7 @@ class Model: # Kick off a new thread to execute the model generation. # As the model generates outputs, they will be readable - # from the Streamer object. + # from the _Streamer object. 
thread = Thread( target=self.model.generate, kwargs=generation_kwargs diff --git a/truss-chains/examples/streaming/streaming_chain.py b/truss-chains/examples/streaming/streaming_chain.py index 48069a4e9..6907a9530 100644 --- a/truss-chains/examples/streaming/streaming_chain.py +++ b/truss-chains/examples/streaming/streaming_chain.py @@ -1,5 +1,5 @@ import logging -from typing import AsyncIterator +from typing import AsyncIterator, reveal_type LOG_FORMAT = ( "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s" @@ -39,12 +39,25 @@ class Footer(pydantic.BaseModel): msg: str +STREAM_TYPES = streaming.stream_types(MyDataChunk, header_t=Header) + + +reveal_type(STREAM_TYPES) +reveal_type(STREAM_TYPES.header_t) +reveal_type(STREAM_TYPES.footer_t) +# header_instance = STREAM_TYPES.header_t() +# print(header_instance.time) +# +# +# footer_instance = STREAM_TYPES.footer_t() +# print(footer_instance.time) + + class Generator(chains.ChainletBase): async def run_remote(self) -> AsyncIterator[bytes]: print("Entering Generator") - streamer = streaming.stream_writer( - MyDataChunk, header_t=Header, footer_t=Footer - ) + streamer = streaming.StreamWriter(STREAM_TYPES) + reveal_type(streamer) header = Header(time=time.time(), msg="Start.") yield streamer.yield_header(header) for i in range(1, 5): @@ -54,7 +67,8 @@ async def run_remote(self) -> AsyncIterator[bytes]: # numbers=numbers ) print("Yield") - # await streamer.yield_header(item) # TyeError because type mismatch. + yield streamer.yield_header(data) # TyeError because type mismatch. + yield streamer.yield_item("ASdf") yield streamer.yield_item(data) # if i >2: # raise ValueError() @@ -72,17 +86,15 @@ def __init__(self, generator=chains.depends(Generator)): async def run_remote(self) -> None: print("Entering Consumer") - reader = streaming.StreamReader( - self._generator.run_remote(), MyDataChunk, header_t=Header, footer_t=Footer - ) + reader = streaming.stream_reader(STREAM_TYPES, self._generator.run_remote()) print("Consuming...") header = await reader.read_header() print(header) async for data in reader.read_items(): print(f"Read: {data}") - # reader.yield_item() # Type error, is reader, not writer. - # footer = await generator.reader_footer() # Example does not have a footer. + footer = await reader.read_footer() # Example does not have a footer. + print(footer) print("Exiting Consumer") @@ -94,7 +106,6 @@ async def run_remote(self) -> None: result = asyncio.run(chain.run_remote()) print(result) - from truss_chains import definitions, remote service = remote.push( diff --git a/truss-chains/truss_chains/streaming.py b/truss-chains/truss_chains/streaming.py index 6bcd8d946..a51014c9d 100644 --- a/truss-chains/truss_chains/streaming.py +++ b/truss-chains/truss_chains/streaming.py @@ -1,14 +1,14 @@ -import asyncio import enum import struct import sys from collections.abc import AsyncIterator -from typing import Optional, Protocol, overload +from typing import NamedTuple, Optional, overload import pydantic from truss.templates.shared import serialization -from typing_extensions import Generic, Type, TypeVar, runtime_checkable +from typing_extensions import Generic, Type, TypeVar +TAG_SIZE = 5 # uint8 + uint32. 
_T = TypeVar("_T") if sys.version_info < (3, 10): @@ -18,39 +18,78 @@ async def anext(iterable: AsyncIterator[_T]) -> _T: # Type variables for Header, Item, and Footer -TItem = TypeVar("TItem", bound=pydantic.BaseModel) -THeader = TypeVar("THeader", pydantic.BaseModel, None) -TFooter = TypeVar("TFooter", pydantic.BaseModel, None) +ItemT = TypeVar("ItemT", bound=pydantic.BaseModel) +HeaderT = TypeVar("HeaderT", bound=pydantic.BaseModel) +FooterT = TypeVar("FooterT", bound=pydantic.BaseModel) -TAG_SIZE = 5 # uint8 + uint32. +HeaderTT = TypeVar("HeaderTT") +FooterTT = TypeVar("FooterTT") + + +class StreamTypes(NamedTuple, Generic[ItemT, HeaderTT, FooterTT]): + item_t: Type[ItemT] + header_t: HeaderTT = None + footer_t: FooterTT = None + + +@overload +def stream_types( + item_t: Type[ItemT], + *, + header_t: Type[HeaderT], + footer_t: Type[FooterT], +) -> StreamTypes[ItemT, HeaderT, FooterT]: ... + + +@overload +def stream_types( + item_t: Type[ItemT], + *, + header_t: Type[HeaderT], +) -> StreamTypes[ItemT, HeaderT, None]: ... + + +@overload +def stream_types( + item_t: Type[ItemT], + *, + footer_t: Type[FooterT], +) -> StreamTypes[ItemT, None, FooterT]: ... + + +@overload +def stream_types(item_t: Type[ItemT]) -> StreamTypes[ItemT, None, None]: ... -@runtime_checkable -class _StreamReaderLike(Protocol): - async def readexactly(self, num_bytes: int) -> bytes: ... +def stream_types( + item_t: Type[ItemT], + *, + header_t: Optional[Type[HeaderT]] = None, + footer_t: Optional[Type[FooterT]] = None, +) -> StreamTypes: + # This indirection for creating `StreamTypes` is needed to get generic typing. + return StreamTypes(item_t, header_t, footer_t) class _ByteReader: def __init__(self, source: AsyncIterator[bytes]) -> None: self._source = source self._buffer = bytearray() - self._lock = asyncio.Lock() async def readexactly(self, num_bytes: int) -> bytes: - async with self._lock: - while len(self._buffer) < num_bytes: - try: - chunk = await anext(self._source) - except StopAsyncIteration: - break - self._buffer.extend(chunk) + while len(self._buffer) < num_bytes: + try: + chunk = await anext(self._source) + except StopAsyncIteration: + break + self._buffer.extend(chunk) - if len(self._buffer) < num_bytes: - raise EOFError("TODO") + if len(self._buffer) < num_bytes: + raise EOFError("TODO") - result = bytes(self._buffer[:num_bytes]) - del self._buffer[:num_bytes] - return result + result = bytes(self._buffer[:num_bytes]) + del self._buffer[:num_bytes] + return result class Delimiter(enum.IntEnum): @@ -60,35 +99,24 @@ class Delimiter(enum.IntEnum): END = enum.auto() -class Streamer(Generic[TItem, THeader, TFooter]): - _item_t: Type[TItem] - _header_t: Optional[Type[THeader]] - _footer_t: Optional[Type[TFooter]] +class _Streamer(Generic[ItemT, HeaderT, FooterT]): + _stream_types: StreamTypes[ItemT, HeaderT, FooterT] - def __init__( - self, - item_t: Type[TItem], - header_t: Optional[Type[THeader]], - footer_t: Optional[Type[TFooter]], - ) -> None: - self._item_t = item_t - self._header_t = header_t - self._footer_t = footer_t + def __init__(self, stream_types: StreamTypes[ItemT, HeaderT, FooterT]) -> None: + self._stream_types = stream_types -class StreamReader(Streamer[TItem, THeader, TFooter]): - _stream: _StreamReaderLike +class StreamReader(_Streamer[ItemT, HeaderT, FooterT]): + _stream: _ByteReader def __init__( self, + stream_types: StreamTypes[ItemT, HeaderT, FooterT], stream: AsyncIterator[bytes], - item_t: Type[TItem], - header_t: Optional[Type[THeader]], - footer_t: Optional[Type[TFooter]], ) 
-> None: - super().__init__(item_t, header_t, footer_t) + super().__init__(stream_types) self._stream = _ByteReader(stream) - self._footer = None + self._footer_data = None @staticmethod def _unpack_tag(tag: bytes) -> tuple[Delimiter, int]: @@ -103,90 +131,121 @@ async def _read(self) -> tuple[Delimiter, serialization.MsgPackType]: data_bytes = await self._stream.readexactly(length) return delimiter, serialization.truss_msgpack_deserialize(data_bytes) - async def read_header(self) -> THeader: - delimiter, data_dict = await self._read() - assert delimiter == Delimiter.HEADER - return self._header_t.parse_obj(data_dict) - - async def read_items(self) -> AsyncIterator[TItem]: + async def read_items(self) -> AsyncIterator[ItemT]: delimiter, data_dict = await self._read() assert delimiter == Delimiter.ITEM while delimiter == Delimiter.ITEM: - yield self._item_t.parse_obj(data_dict) + yield self._stream_types.item_t.model_validate(data_dict) # Read next: either item, footer, or end. delimiter, data_dict = await self._read() if delimiter == Delimiter.END: return if delimiter == Delimiter.FOOTER: - self._footer = self._footer_t.parse_obj(data_dict) + self._footer_data = data_dict return - async def read_footer(self) -> TFooter: - if self._footer_t is None: + +class _HeaderReadMixin(_Streamer[ItemT, HeaderT, FooterT]): + async def read_header(self: StreamReader[ItemT, HeaderT, FooterT]) -> HeaderT: + delimiter, data_dict = await self._read() + assert delimiter == Delimiter.HEADER + return self._stream_types.header_t.model_validate(data_dict) + + +class _FooterReadMixin(_Streamer[ItemT, HeaderT, FooterT]): + async def read_footer(self: StreamReader[ItemT, HeaderT, FooterT]) -> FooterT: + if self._footer_data is None: raise ValueError() - footer = self._footer_t - self._footer_t = None + footer = self._stream_types.footer_t.model_validate(self._footer_data) + self._footer_data = None return footer -class StreamWriter(Streamer[TItem, THeader, TFooter]): - @staticmethod - def _pack_tag(delimiter: Delimiter, length: int) -> bytes: - return struct.pack(">BI", delimiter.value, length) +class StreamReaderWithHeader( + StreamReader[ItemT, HeaderT, FooterT], _HeaderReadMixin[ItemT, HeaderT, FooterT] +): ... - def _serialize(self, obj: pydantic.BaseModel, delimiter: Delimiter) -> bytes: - data_dict = obj.dict() - data_bytes = serialization.truss_msgpack_serialize(data_dict) - data = bytearray(self._pack_tag(delimiter, len(data_bytes))) - data.extend(data_bytes) - print(data) - # Starlette cannot handle byte array. - return memoryview(data) - def yield_header(self, header: THeader) -> bytes: - return self._serialize(header, Delimiter.HEADER) +class StreamReaderWithFooter( + StreamReader[ItemT, HeaderT, FooterT], _HeaderReadMixin[ItemT, HeaderT, FooterT] +): ... - def yield_item(self, item: TItem) -> bytes: - return self._serialize(item, Delimiter.ITEM) - def yield_footer(self, footer: TFooter) -> bytes: - return self._serialize(footer, Delimiter.FOOTER) +class StreamReaderFull( + StreamReader[ItemT, HeaderT, FooterT], + _HeaderReadMixin[ItemT, HeaderT, FooterT], + _FooterReadMixin[ItemT, HeaderT, FooterT], +): ... @overload -def stream_writer( - item_t: Type[TItem], - *, - header_t: Type[THeader], - footer_t: Type[TFooter], -) -> StreamWriter[TItem, THeader, TFooter]: ... +def stream_reader( + stream_types: StreamTypes[ItemT, None, None], + stream: AsyncIterator[bytes], +) -> StreamReader[ItemT, None, None]: ... 
@overload -def stream_writer( - item_t: Type[TItem], - *, - header_t: Type[THeader], -) -> StreamWriter[TItem, THeader, None]: ... +def stream_reader( + stream_types: StreamTypes[ItemT, HeaderT, None], + stream: AsyncIterator[bytes], +) -> StreamReaderWithFooter[ItemT, HeaderT, None]: ... @overload -def stream_writer( - item_t: Type[TItem], - *, - footer_t: Type[TFooter], -) -> StreamWriter[TItem, None, TFooter]: ... +def stream_reader( + stream_types: StreamTypes[ItemT, None, FooterT], + stream: AsyncIterator[bytes], +) -> StreamReaderWithFooter[ItemT, None, FooterT]: ... @overload -def stream_writer(item_t: Type[TItem]) -> StreamWriter[TItem, None, None]: ... +def stream_reader( + stream_types: StreamTypes[ItemT, HeaderT, FooterT], + stream: AsyncIterator[bytes], +) -> StreamReaderFull[ItemT, HeaderT, FooterT]: ... -def stream_writer( - item_t: Type[TItem], - *, - header_t: Optional[Type[THeader]] = None, - footer_t: Optional[Type[TFooter]] = None, -) -> StreamWriter: - return StreamWriter(item_t, header_t, footer_t) +def stream_reader( + stream_types: StreamTypes[ItemT, HeaderT, FooterT], + stream: AsyncIterator[bytes], +) -> StreamReader: + if stream_types.header_t is None and stream_types.footer_t is None: + return StreamReader(stream_types, stream) + if stream_types.header_t is None: + return StreamReaderWithFooter(stream_types, stream) + if stream_types.footer_t is None: + return StreamReaderWithHeader(stream_types, stream) + return StreamReaderFull(stream_types, stream) + + +######################################################################################## + + +class StreamWriter(_Streamer[ItemT, HeaderT, FooterT]): + @staticmethod + def _pack_tag(delimiter: Delimiter, length: int) -> bytes: + return struct.pack(">BI", delimiter.value, length) + + def _serialize(self, obj: pydantic.BaseModel, delimiter: Delimiter) -> bytes: + data_dict = obj.model_dump() + data_bytes = serialization.truss_msgpack_serialize(data_dict) + data = bytearray(self._pack_tag(delimiter, len(data_bytes))) + data.extend(data_bytes) + print(data) + # Starlette cannot handle byte array. + return memoryview(data) + + def yield_header(self, header: HeaderT) -> bytes: + if self._stream_types.header_t is None or header is None: + raise ValueError() + return self._serialize(header, Delimiter.HEADER) + + def yield_item(self, item: ItemT) -> bytes: + return self._serialize(item, Delimiter.ITEM) + + def yield_footer(self, footer: FooterT) -> bytes: + if self._stream_types.header_t is None or footer is None: + raise ValueError() + return self._serialize(footer, Delimiter.FOOTER) diff --git a/truss-chains/truss_chains/type_experiment.py b/truss-chains/truss_chains/type_experiment.py new file mode 100644 index 000000000..646d5e99d --- /dev/null +++ b/truss-chains/truss_chains/type_experiment.py @@ -0,0 +1,53 @@ +from typing import Generic, NamedTuple, Optional, Type, TypeVar, overload + +import pydantic +from typing_extensions import reveal_type + +ItemT = TypeVar("ItemT", bound=pydantic.BaseModel) +HeaderT = TypeVar("HeaderT") + + +class StreamTypes(NamedTuple, Generic[ItemT, HeaderT]): + item_t: Type[ItemT] + header_t: HeaderT + + +@overload +def stream_types(item_t: Type[ItemT]) -> StreamTypes[ItemT, None]: ... + + +@overload +def stream_types( + item_t: Type[ItemT], *, header_t: Type[HeaderT] +) -> StreamTypes[ItemT, Type[HeaderT]]: ... 
+ + +def stream_types(item_t: Type[ItemT], *, header_t: Optional[Type[HeaderT]] = None): + return StreamTypes(item_t, header_t) + + +class _Streamer(Generic[ItemT, HeaderT]): + _stream_types: StreamTypes[ItemT, HeaderT] + + def __init__(self, stream_types_: StreamTypes[ItemT, HeaderT]) -> None: + self._stream_types = stream_types_ + + +if __name__ == "__main__": + + class Header(pydantic.BaseModel): + time: float + msg: str + + class MyDataChunk(pydantic.BaseModel): + words: list[str] + + NONE_TYPES = stream_types(MyDataChunk) + FULL_TYPES = stream_types(MyDataChunk, header_t=Header) + + streamer_none = _Streamer(NONE_TYPES) + reveal_type(streamer_none._stream_types.item_t) + reveal_type(streamer_none._stream_types.header_t) # Revealed type is 'None' + + streamer_full = _Streamer(FULL_TYPES) + reveal_type(streamer_full._stream_types.header_t) # Revealed type is 'Type[Header]' From 97266161c00ecb746e6ab1d6f7d1cb18f6f20674 Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Mon, 25 Nov 2024 11:58:08 -0800 Subject: [PATCH 03/11] WIP clean types --- truss-chains/truss_chains/streaming.py | 67 ++++++++++++++++++-------- 1 file changed, 46 insertions(+), 21 deletions(-) diff --git a/truss-chains/truss_chains/streaming.py b/truss-chains/truss_chains/streaming.py index a51014c9d..f3ae3a5eb 100644 --- a/truss-chains/truss_chains/streaming.py +++ b/truss-chains/truss_chains/streaming.py @@ -2,7 +2,7 @@ import struct import sys from collections.abc import AsyncIterator -from typing import NamedTuple, Optional, overload +from typing import NamedTuple, Optional, Protocol, overload import pydantic from truss.templates.shared import serialization @@ -17,19 +17,20 @@ async def anext(iterable: AsyncIterator[_T]) -> _T: return await iterable.__anext__() -# Type variables for Header, Item, and Footer ItemT = TypeVar("ItemT", bound=pydantic.BaseModel) HeaderT = TypeVar("HeaderT", bound=pydantic.BaseModel) FooterT = TypeVar("FooterT", bound=pydantic.BaseModel) +# Since header/footer could also be None, we need an extra type variable that +# can assume either `Type[HeaderT]` or `None` - `Type[None]` would not work. HeaderTT = TypeVar("HeaderTT") FooterTT = TypeVar("FooterTT") class StreamTypes(NamedTuple, Generic[ItemT, HeaderTT, FooterTT]): item_t: Type[ItemT] - header_t: HeaderTT = None - footer_t: FooterTT = None + header_t: HeaderTT # Is either `Type[HeaderT]` or `None`. + footer_t: FooterTT # Is either `Type[FooterT]` or `None`. @overload @@ -67,10 +68,32 @@ def stream_types( header_t: Optional[Type[HeaderT]] = None, footer_t: Optional[Type[FooterT]] = None, ) -> StreamTypes: + """Creates a bundle of item type and potentially header/footer types, + each as pydantic model.""" # This indirection for creating `StreamTypes` is needed to get generic typing. 
return StreamTypes(item_t, header_t, footer_t) +# Reading ############################################################################## + + +class Delimiter(enum.IntEnum): + HEADER = enum.auto() + ITEM = enum.auto() + FOOTER = enum.auto() + END = enum.auto() + + +class _Streamer(Generic[ItemT, HeaderT, FooterT]): + _stream_types: StreamTypes[ItemT, HeaderT, FooterT] + + def __init__(self, stream_types: StreamTypes[ItemT, HeaderT, FooterT]) -> None: + self._stream_types = stream_types + + +# Reading ############################################################################## + + class _ByteReader: def __init__(self, source: AsyncIterator[bytes]) -> None: self._source = source @@ -81,33 +104,29 @@ async def readexactly(self, num_bytes: int) -> bytes: try: chunk = await anext(self._source) except StopAsyncIteration: + if len(self._buffer) < num_bytes: + raise EOFError( + f"Requested to read `{num_bytes}` bytes, " + f"but only `{len(self._buffer)}` available" + ) break self._buffer.extend(chunk) - if len(self._buffer) < num_bytes: - raise EOFError("TODO") - result = bytes(self._buffer[:num_bytes]) del self._buffer[:num_bytes] return result -class Delimiter(enum.IntEnum): - HEADER = enum.auto() - ITEM = enum.auto() - FOOTER = enum.auto() - END = enum.auto() - +class _StreamReaderProtocol(Protocol[ItemT, HeaderT, FooterT]): + async def _read(self) -> tuple[Delimiter, serialization.MsgPackType]: ... -class _Streamer(Generic[ItemT, HeaderT, FooterT]): + _footer_data: Optional[serialization.MsgPackType] _stream_types: StreamTypes[ItemT, HeaderT, FooterT] - def __init__(self, stream_types: StreamTypes[ItemT, HeaderT, FooterT]) -> None: - self._stream_types = stream_types - class StreamReader(_Streamer[ItemT, HeaderT, FooterT]): _stream: _ByteReader + _footer_data: Optional[serialization.MsgPackType] def __init__( self, @@ -147,14 +166,20 @@ async def read_items(self) -> AsyncIterator[ItemT]: class _HeaderReadMixin(_Streamer[ItemT, HeaderT, FooterT]): - async def read_header(self: StreamReader[ItemT, HeaderT, FooterT]) -> HeaderT: + async def read_header( + self: _StreamReaderProtocol[ItemT, HeaderT, FooterT], + ) -> HeaderT: delimiter, data_dict = await self._read() assert delimiter == Delimiter.HEADER return self._stream_types.header_t.model_validate(data_dict) class _FooterReadMixin(_Streamer[ItemT, HeaderT, FooterT]): - async def read_footer(self: StreamReader[ItemT, HeaderT, FooterT]) -> FooterT: + _footer_data: Optional[serialization.MsgPackType] + + async def read_footer( + self: _StreamReaderProtocol[ItemT, HeaderT, FooterT], + ) -> FooterT: if self._footer_data is None: raise ValueError() footer = self._stream_types.footer_t.model_validate(self._footer_data) @@ -190,7 +215,7 @@ def stream_reader( def stream_reader( stream_types: StreamTypes[ItemT, HeaderT, None], stream: AsyncIterator[bytes], -) -> StreamReaderWithFooter[ItemT, HeaderT, None]: ... +) -> StreamReaderWithHeader[ItemT, HeaderT, None]: ... 
@overload @@ -220,7 +245,7 @@ def stream_reader( return StreamReaderFull(stream_types, stream) -######################################################################################## +# Writing ############################################################################## class StreamWriter(_Streamer[ItemT, HeaderT, FooterT]): From bc3f2486885513b2d7be82e5cebf1503d15259ae Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Mon, 25 Nov 2024 12:00:36 -0800 Subject: [PATCH 04/11] WIP clean types --- truss-chains/truss_chains/streaming.py | 36 +++++++++++++------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/truss-chains/truss_chains/streaming.py b/truss-chains/truss_chains/streaming.py index f3ae3a5eb..582b0deba 100644 --- a/truss-chains/truss_chains/streaming.py +++ b/truss-chains/truss_chains/streaming.py @@ -84,10 +84,10 @@ class Delimiter(enum.IntEnum): END = enum.auto() -class _Streamer(Generic[ItemT, HeaderT, FooterT]): - _stream_types: StreamTypes[ItemT, HeaderT, FooterT] +class _Streamer(Generic[ItemT, HeaderTT, FooterTT]): + _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT] - def __init__(self, stream_types: StreamTypes[ItemT, HeaderT, FooterT]) -> None: + def __init__(self, stream_types: StreamTypes[ItemT, HeaderTT, FooterTT]) -> None: self._stream_types = stream_types @@ -117,20 +117,20 @@ async def readexactly(self, num_bytes: int) -> bytes: return result -class _StreamReaderProtocol(Protocol[ItemT, HeaderT, FooterT]): +class _StreamReaderProtocol(Protocol[ItemT, HeaderTT, FooterTT]): async def _read(self) -> tuple[Delimiter, serialization.MsgPackType]: ... _footer_data: Optional[serialization.MsgPackType] - _stream_types: StreamTypes[ItemT, HeaderT, FooterT] + _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT] -class StreamReader(_Streamer[ItemT, HeaderT, FooterT]): +class StreamReader(_Streamer[ItemT, HeaderTT, FooterTT]): _stream: _ByteReader _footer_data: Optional[serialization.MsgPackType] def __init__( self, - stream_types: StreamTypes[ItemT, HeaderT, FooterT], + stream_types: StreamTypes[ItemT, HeaderTT, FooterTT], stream: AsyncIterator[bytes], ) -> None: super().__init__(stream_types) @@ -150,7 +150,7 @@ async def _read(self) -> tuple[Delimiter, serialization.MsgPackType]: data_bytes = await self._stream.readexactly(length) return delimiter, serialization.truss_msgpack_deserialize(data_bytes) - async def read_items(self) -> AsyncIterator[ItemT]: + async def read_items(self) -> AsyncIterator[ItemTT]: delimiter, data_dict = await self._read() assert delimiter == Delimiter.ITEM @@ -165,20 +165,20 @@ async def read_items(self) -> AsyncIterator[ItemT]: return -class _HeaderReadMixin(_Streamer[ItemT, HeaderT, FooterT]): +class _HeaderReadMixin(_Streamer[ItemT, HeaderTT, FooterTT]): async def read_header( - self: _StreamReaderProtocol[ItemT, HeaderT, FooterT], - ) -> HeaderT: + self: _StreamReaderProtocol[ItemT, HeaderTT, FooterTT], + ) -> HeaderTT: delimiter, data_dict = await self._read() assert delimiter == Delimiter.HEADER return self._stream_types.header_t.model_validate(data_dict) -class _FooterReadMixin(_Streamer[ItemT, HeaderT, FooterT]): +class _FooterReadMixin(_Streamer[ItemT, HeaderTT, FooterTT]): _footer_data: Optional[serialization.MsgPackType] async def read_footer( - self: _StreamReaderProtocol[ItemT, HeaderT, FooterT], + self: _StreamReaderProtocol[ItemT, HeaderTT, FooterTT], ) -> FooterT: if self._footer_data is None: raise ValueError() @@ -188,19 +188,19 @@ async def read_footer( class StreamReaderWithHeader( - 
StreamReader[ItemT, HeaderT, FooterT], _HeaderReadMixin[ItemT, HeaderT, FooterT] + StreamReader[ItemT, HeaderTT, FooterTT], _HeaderReadMixin[ItemT, HeaderTT, FooterTT] ): ... class StreamReaderWithFooter( - StreamReader[ItemT, HeaderT, FooterT], _HeaderReadMixin[ItemT, HeaderT, FooterT] + StreamReader[ItemT, HeaderTT, FooterTT], _HeaderReadMixin[ItemT, HeaderTT, FooterTT] ): ... class StreamReaderFull( - StreamReader[ItemT, HeaderT, FooterT], - _HeaderReadMixin[ItemT, HeaderT, FooterT], - _FooterReadMixin[ItemT, HeaderT, FooterT], + StreamReader[ItemT, HeaderTT, FooterTT], + _HeaderReadMixin[ItemT, HeaderTT, FooterTT], + _FooterReadMixin[ItemT, HeaderTT, FooterTT], ): ... From de2e0bd366ac579a707126f2bab97c23e1e216c3 Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Mon, 25 Nov 2024 12:07:02 -0800 Subject: [PATCH 05/11] WIP fixed reader types, except for stream_reader overloaded function. --- truss-chains/truss_chains/streaming.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/truss-chains/truss_chains/streaming.py b/truss-chains/truss_chains/streaming.py index 582b0deba..be8d8619c 100644 --- a/truss-chains/truss_chains/streaming.py +++ b/truss-chains/truss_chains/streaming.py @@ -150,7 +150,7 @@ async def _read(self) -> tuple[Delimiter, serialization.MsgPackType]: data_bytes = await self._stream.readexactly(length) return delimiter, serialization.truss_msgpack_deserialize(data_bytes) - async def read_items(self) -> AsyncIterator[ItemTT]: + async def read_items(self) -> AsyncIterator[ItemT]: delimiter, data_dict = await self._read() assert delimiter == Delimiter.ITEM @@ -165,20 +165,20 @@ async def read_items(self) -> AsyncIterator[ItemTT]: return -class _HeaderReadMixin(_Streamer[ItemT, HeaderTT, FooterTT]): +class _HeaderReadMixin(_Streamer[ItemT, HeaderT, FooterTT]): async def read_header( - self: _StreamReaderProtocol[ItemT, HeaderTT, FooterTT], - ) -> HeaderTT: + self: _StreamReaderProtocol[ItemT, HeaderT, FooterTT], + ) -> HeaderT: delimiter, data_dict = await self._read() assert delimiter == Delimiter.HEADER return self._stream_types.header_t.model_validate(data_dict) -class _FooterReadMixin(_Streamer[ItemT, HeaderTT, FooterTT]): +class _FooterReadMixin(_Streamer[ItemT, HeaderTT, FooterT]): _footer_data: Optional[serialization.MsgPackType] async def read_footer( - self: _StreamReaderProtocol[ItemT, HeaderTT, FooterTT], + self: _StreamReaderProtocol[ItemT, HeaderTT, FooterT], ) -> FooterT: if self._footer_data is None: raise ValueError() @@ -188,19 +188,19 @@ async def read_footer( class StreamReaderWithHeader( - StreamReader[ItemT, HeaderTT, FooterTT], _HeaderReadMixin[ItemT, HeaderTT, FooterTT] + StreamReader[ItemT, HeaderT, FooterTT], _HeaderReadMixin[ItemT, HeaderT, FooterTT] ): ... class StreamReaderWithFooter( - StreamReader[ItemT, HeaderTT, FooterTT], _HeaderReadMixin[ItemT, HeaderTT, FooterTT] + StreamReader[ItemT, HeaderTT, FooterT], _FooterReadMixin[ItemT, HeaderTT, FooterT] ): ... class StreamReaderFull( - StreamReader[ItemT, HeaderTT, FooterTT], - _HeaderReadMixin[ItemT, HeaderTT, FooterTT], - _FooterReadMixin[ItemT, HeaderTT, FooterTT], + StreamReader[ItemT, HeaderT, FooterT], + _HeaderReadMixin[ItemT, HeaderT, FooterT], + _FooterReadMixin[ItemT, HeaderT, FooterT], ): ... 
@@ -233,7 +233,7 @@ def stream_reader( def stream_reader( - stream_types: StreamTypes[ItemT, HeaderT, FooterT], + stream_types: StreamTypes[ItemT, HeaderTT, FooterTT], stream: AsyncIterator[bytes], ) -> StreamReader: if stream_types.header_t is None and stream_types.footer_t is None: From 5a026f9f320dc47fb3d24b2222286797dfd66635 Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Mon, 25 Nov 2024 17:14:56 -0800 Subject: [PATCH 06/11] Streaming with integration and local test --- docs/examples/streaming.mdx | 2 +- .../examples/streaming/streaming_chain.py | 98 ++++---- truss-chains/tests/chains_e2e_test.py | 53 ++++- truss-chains/truss_chains/code_gen.py | 65 +++--- truss-chains/truss_chains/definitions.py | 25 +- truss-chains/truss_chains/framework.py | 80 +++++-- truss-chains/truss_chains/streaming.py | 216 +++++++++++++----- truss-chains/truss_chains/stub.py | 23 +- truss-chains/truss_chains/type_experiment.py | 53 ----- 9 files changed, 379 insertions(+), 236 deletions(-) delete mode 100644 truss-chains/truss_chains/type_experiment.py diff --git a/docs/examples/streaming.mdx b/docs/examples/streaming.mdx index 6c2376941..1767396b2 100644 --- a/docs/examples/streaming.mdx +++ b/docs/examples/streaming.mdx @@ -292,7 +292,7 @@ class Model: # Kick off a new thread to execute the model generation. # As the model generates outputs, they will be readable - # from the _Streamer object. + # from the Streamer object. thread = Thread( target=self.model.generate, kwargs=generation_kwargs diff --git a/truss-chains/examples/streaming/streaming_chain.py b/truss-chains/examples/streaming/streaming_chain.py index 6907a9530..e09e8b8ac 100644 --- a/truss-chains/examples/streaming/streaming_chain.py +++ b/truss-chains/examples/streaming/streaming_chain.py @@ -1,27 +1,12 @@ -import logging -from typing import AsyncIterator, reveal_type - -LOG_FORMAT = ( - "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s" -) - -logging.basicConfig( - level=logging.INFO, format=LOG_FORMAT, datefmt="%Y-%m-%d %H:%M:%S", force=True -) - -logging.info("Start") - import asyncio import time +from typing import AsyncIterator import pydantic -logging.info("Import Chains") import truss_chains as chains from truss_chains import streaming -logging.info("Chains imported") - class Header(pydantic.BaseModel): time: float @@ -30,7 +15,6 @@ class Header(pydantic.BaseModel): class MyDataChunk(pydantic.BaseModel): words: list[str] - # numbers: np.ndarray class Footer(pydantic.BaseModel): @@ -39,79 +23,83 @@ class Footer(pydantic.BaseModel): msg: str -STREAM_TYPES = streaming.stream_types(MyDataChunk, header_t=Header) +class ConsumerOutput(pydantic.BaseModel): + header: Header + chunks: list[MyDataChunk] + footer: Footer + strings: str -reveal_type(STREAM_TYPES) -reveal_type(STREAM_TYPES.header_t) -reveal_type(STREAM_TYPES.footer_t) -# header_instance = STREAM_TYPES.header_t() -# print(header_instance.time) -# -# -# footer_instance = STREAM_TYPES.footer_t() -# print(footer_instance.time) +STREAM_TYPES = streaming.stream_types(MyDataChunk, header_t=Header, footer_t=Footer) class Generator(chains.ChainletBase): + """Example that streams fully structured pydantic items with header and footer.""" + async def run_remote(self) -> AsyncIterator[bytes]: print("Entering Generator") - streamer = streaming.StreamWriter(STREAM_TYPES) - reveal_type(streamer) + streamer = streaming.stream_writer(STREAM_TYPES) header = Header(time=time.time(), msg="Start.") yield streamer.yield_header(header) for i in range(1, 5): - # 
numbers = np.full((3, 4), i) data = MyDataChunk( words=[chr(x + 70) * x for x in range(1, i + 1)], - # numbers=numbers ) print("Yield") - yield streamer.yield_header(data) # TyeError because type mismatch. - yield streamer.yield_item("ASdf") yield streamer.yield_item(data) - # if i >2: - # raise ValueError() - await asyncio.sleep(0.2) + await asyncio.sleep(0.05) end_time = time.time() footer = Footer(time=end_time, duration_sec=end_time - header.time, msg="Done.") - yield streamer.yield_footer(footer) # TyeError because footer type is None. + yield streamer.yield_footer(footer) print("Exiting Generator") +class StringGenerator(chains.ChainletBase): + """Minimal streaming example with raw strings (e.g. for LLM).""" + + async def run_remote(self) -> AsyncIterator[str]: + # Note: the "chunk" boundaries are lost, when streaming raw strings. You must + # add spaces and linebreaks to the items. + yield "First " + yield "second " + yield "last." + + class Consumer(chains.ChainletBase): - def __init__(self, generator=chains.depends(Generator)): + """Consume that reads the raw streams and parses them.""" + + def __init__( + self, + generator=chains.depends(Generator), + string_generator=chains.depends(StringGenerator), + ): self._generator = generator + self._string_generator = string_generator - async def run_remote(self) -> None: + async def run_remote(self) -> ConsumerOutput: print("Entering Consumer") reader = streaming.stream_reader(STREAM_TYPES, self._generator.run_remote()) print("Consuming...") header = await reader.read_header() - print(header) + chunks = [] async for data in reader.read_items(): print(f"Read: {data}") - # reader.yield_item() # Type error, is reader, not writer. - footer = await reader.read_footer() # Example does not have a footer. 
- print(footer) - print("Exiting Consumer") + chunks.append(data) + footer = await reader.read_footer() + strings = [] + async for part in self._string_generator.run_remote(): + strings.append(part) + + print("Exiting Consumer") + return ConsumerOutput( + header=header, chunks=chunks, footer=footer, strings="".join(strings) + ) -logging.info("Module initialized") if __name__ == "__main__": with chains.run_local(): chain = Consumer() result = asyncio.run(chain.run_remote()) print(result) - - from truss_chains import definitions, remote - - service = remote.push( - Consumer, - options=definitions.PushOptionsLocalDocker( - chain_name="stream", only_generate_trusses=False, use_local_chains_src=True - ), - ) - service.run_remote({}) diff --git a/truss-chains/tests/chains_e2e_test.py b/truss-chains/tests/chains_e2e_test.py index a64adc6f1..31ecc2461 100644 --- a/truss-chains/tests/chains_e2e_test.py +++ b/truss-chains/tests/chains_e2e_test.py @@ -13,8 +13,8 @@ @pytest.mark.integration def test_chain(): with ensure_kill_all(): - root = Path(__file__).parent.resolve() - chain_root = root / "itest_chain" / "itest_chain.py" + tests_root = Path(__file__).parent.resolve() + chain_root = tests_root / "itest_chain" / "itest_chain.py" with framework.import_target(chain_root, "ItestChain") as entrypoint: options = definitions.PushOptionsLocalDocker( chain_name="integration-test", use_local_chains_src=True @@ -81,8 +81,8 @@ def test_chain(): @pytest.mark.asyncio async def test_chain_local(): - root = Path(__file__).parent.resolve() - chain_root = root / "itest_chain" / "itest_chain.py" + tests_root = Path(__file__).parent.resolve() + chain_root = tests_root / "itest_chain" / "itest_chain.py" with framework.import_target(chain_root, "ItestChain") as entrypoint: with public_api.run_local(): with pytest.raises(ValueError): @@ -119,3 +119,48 @@ async def test_chain_local(): match="Chainlets cannot be naively instantiated", ): await entrypoint().run_remote(length=20, num_partitions=5) + + +@pytest.mark.integration +def test_streaming_chain(): + examples_root = Path(__file__).parent.parent.resolve() / "examples" + chain_root = examples_root / "streaming" / "streaming_chain.py" + with framework.import_target(chain_root, "Consumer") as entrypoint: + service = remote.push( + entrypoint, + options=definitions.PushOptionsLocalDocker( + chain_name="stream", + only_generate_trusses=False, + use_local_chains_src=True, + ), + ) + assert service is not None + response = service.run_remote({}) + assert response.status_code == 200 + print(response.json()) + result = response.json() + print(result) + assert result["header"]["msg"] == "Start." + assert result["chunks"][0]["words"] == ["G"] + assert result["chunks"][1]["words"] == ["G", "HH"] + assert result["chunks"][2]["words"] == ["G", "HH", "III"] + assert result["chunks"][3]["words"] == ["G", "HH", "III", "JJJJ"] + assert result["footer"]["duration_sec"] > 0 + assert result["strings"] == ["First second last."] + + +@pytest.mark.asyncio +async def test_streaming_chain_local(): + examples_root = Path(__file__).parent.parent.resolve() / "examples" + chain_root = examples_root / "streaming" / "streaming_chain.py" + with framework.import_target(chain_root, "Consumer") as entrypoint: + with public_api.run_local(): + result = await entrypoint().run_remote() + print(result) + assert result.header.msg == "Start." 
+ assert result.chunks[0].words == ["G"] + assert result.chunks[1].words == ["G", "HH"] + assert result.chunks[2].words == ["G", "HH", "III"] + assert result.chunks[3].words == ["G", "HH", "III", "JJJJ"] + assert result.footer.duration_sec > 0 + assert result.strings == ["First second last."] diff --git a/truss-chains/truss_chains/code_gen.py b/truss-chains/truss_chains/code_gen.py index 58c9c3269..6ec2e98ca 100644 --- a/truss-chains/truss_chains/code_gen.py +++ b/truss-chains/truss_chains/code_gen.py @@ -23,7 +23,6 @@ requirement (site-package), it will not be copied from the local host. """ -import collections import logging import os import pathlib @@ -33,7 +32,7 @@ import subprocess import sys import textwrap -from typing import Any, Iterable, Mapping, Optional, get_args, get_origin +from typing import Any, Iterable, Mapping, Optional import libcst import truss @@ -135,21 +134,14 @@ def _gen_type_import_and_ref(type_descr: definitions.TypeDescriptor) -> _Source: return _Source(src=str(type_descr.raw)) -def _gen_generator_type_import_and_ref( - endpoint: definitions.EndpointAPIDescriptor, +def _gen_streaming_type_import_and_ref( + stream_type: definitions.StreamingTypeDescriptor, ) -> _Source: """Unlike other `_gen`-helpers, this does not define a type, it creates a symbol.""" - assert len(endpoint.output_types) == 1 - output_type = endpoint.output_types[0] - assert not output_type.is_pydantic - origin = get_origin(output_type.raw) - assert origin in (collections.abc.AsyncIterator, collections.abc.Iterator), origin - args = get_args(output_type.raw) - assert len(args) == 1, "AsyncIterator cannot have more than 1 arg." - - arg = args[0] - type_src = f"{origin.__module__}.{origin.__name__}[{arg.__name__}]" - return _Source(src=type_src, imports={f"import {origin.__module__}"}) + mod = stream_type.origin_type.__module__ + arg = stream_type.arg_type.__name__ + type_src = f"{mod}.{stream_type.origin_type.__name__}[{arg}]" + return _Source(src=type_src, imports={f"import {mod}"}) def _gen_chainlet_import_and_ref( @@ -231,10 +223,10 @@ async def run_remote( imports.update(arg_ref.imports) args.append(f"{arg.name}: {arg_ref.src}") - if endpoint.is_generator: - generator_src = _gen_generator_type_import_and_ref(endpoint) - imports.update(generator_src.imports) - output = generator_src.src + if endpoint.is_streaming: + streaming_src = _gen_streaming_type_import_and_ref(endpoint.streaming_type) + imports.update(streaming_src.imports) + output = streaming_src.src else: outputs: list[str] = [] for output_type in endpoint.output_types: @@ -245,7 +237,6 @@ async def run_remote( else: output = f"tuple[{', '.join(outputs)}]" - # If we produce an async generator, we just pass it through. def_str = "async def" if endpoint.is_async else "def" return _Source( src=f"{def_str} {endpoint.name}({','.join(args)}) -> {output}:", @@ -274,8 +265,9 @@ def _stub_endpoint_body_src( else: inputs = "{}" + parts = [] # Invoke remote. 
- if not endpoint.is_generator: + if not endpoint.is_streaming: if endpoint.is_async: remote_call = f"await self._remote.predict_async({inputs})" else: @@ -287,13 +279,17 @@ def _stub_endpoint_body_src( parts.append(f"return {output_model_name}.model_validate(json_result).root") else: if endpoint.is_async: - parts = [ + parts.append( f"async for data in await self._remote.predict_async_stream({inputs}):", - _indent("yield data"), - ] + ) + if endpoint.streaming_type.is_string: + parts.append(_indent("yield data.decode()")) + else: + parts.append(_indent("yield data")) else: raise NotImplementedError( - "Streaming/Generator only supported for async `run_remote`." + "`Streaming endpoints (containing `yield` statements) are only " + "supported for async endpoints." ) return _Source(src="\n".join(parts), imports=imports) @@ -325,7 +321,7 @@ async def run_remote( src_parts: list[str] = [] input_src = _gen_truss_input_pydantic(chainlet) _update_src(input_src, src_parts, imports) - if not chainlet.endpoint.is_generator: + if not chainlet.endpoint.is_streaming: output_src = _gen_truss_output_pydantic(chainlet) _update_src(output_src, src_parts, imports) signature = _stub_endpoint_signature_src(chainlet.endpoint) @@ -436,10 +432,12 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> parts: list[str] = [] def_str = "async def" if chainlet_descriptor.endpoint.is_async else "def" input_model_name = _get_input_model_name(chainlet_descriptor.name) - if chainlet_descriptor.endpoint.is_generator: - generator_src = _gen_generator_type_import_and_ref(chainlet_descriptor.endpoint) - imports.update(generator_src.imports) - output_type_name = generator_src.src + if chainlet_descriptor.endpoint.is_streaming: + streaming_src = _gen_streaming_type_import_and_ref( + chainlet_descriptor.endpoint.streaming_type + ) + imports.update(streaming_src.imports) + output_type_name = streaming_src.src else: output_type_name = _get_output_model_name(chainlet_descriptor.name) imports.add("import starlette.requests") @@ -458,7 +456,7 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> # Invoke Chainlet. if ( chainlet_descriptor.endpoint.is_async - and not chainlet_descriptor.endpoint.is_generator + and not chainlet_descriptor.endpoint.is_streaming ): maybe_await = "await " else: @@ -469,7 +467,8 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> parts.append( _indent(f"result = {maybe_await}self._chainlet.{run_remote}({args})", 2) ) - if chainlet_descriptor.endpoint.is_generator: + if chainlet_descriptor.endpoint.is_streaming: + # Streaming returns raw iterator, no pydantic model. 
parts.append(_indent("return result")) else: result_pydantic = f"{output_type_name}(result)" @@ -538,7 +537,7 @@ def _gen_truss_chainlet_file( input_src = _gen_truss_input_pydantic(chainlet_descriptor) _update_src(input_src, src_parts, imports) - if not chainlet_descriptor.endpoint.is_generator: + if not chainlet_descriptor.endpoint.is_streaming: output_src = _gen_truss_output_pydantic(chainlet_descriptor) _update_src(output_src, src_parts, imports) model_src = _gen_truss_chainlet_model(chainlet_descriptor) diff --git a/truss-chains/truss_chains/definitions.py b/truss-chains/truss_chains/definitions.py index efe3c1095..0510f9f4c 100644 --- a/truss-chains/truss_chains/definitions.py +++ b/truss-chains/truss_chains/definitions.py @@ -524,6 +524,19 @@ def is_pydantic(self) -> bool: ) +class StreamingTypeDescriptor(TypeDescriptor): + origin_type: type + arg_type: type + + @property + def is_string(self) -> bool: + return self.arg_type is str + + @property + def is_pydantic(self) -> bool: + return False + + class InputArg(SafeModelNonSerializable): name: str type: TypeDescriptor @@ -535,7 +548,17 @@ class EndpointAPIDescriptor(SafeModelNonSerializable): input_args: list[InputArg] output_types: list[TypeDescriptor] is_async: bool - is_generator: bool + is_streaming: bool + + @property + def streaming_type(self) -> StreamingTypeDescriptor: + if ( + not self.is_streaming + or len(self.output_types) != 1 + or not isinstance(self.output_types[0], StreamingTypeDescriptor) + ): + raise ValueError(f"{self} is not a streaming endpoint.") + return cast(StreamingTypeDescriptor, self.output_types[0]) class DependencyDescriptor(SafeModelNonSerializable): diff --git a/truss-chains/truss_chains/framework.py b/truss-chains/truss_chains/framework.py index 4ea5648ad..90ef39973 100644 --- a/truss-chains/truss_chains/framework.py +++ b/truss-chains/truss_chains/framework.py @@ -299,23 +299,27 @@ def _validate_io_type( _collect_error(error_msg, _ErrorKind.IO_TYPE_ERROR, location) -def _validate_generator_output_type( - annotation: Any, param_name: str, location: _ErrorLocation -) -> None: - """ - For Chainlet I/O (both data or parameters), we allow simple types - (int, str, float...) and `list` or `dict` containers of these. - Any deeper nested and structured data must be typed as a pydantic model. - """ +def _validate_streaming_output_type( + annotation: Any, location: _ErrorLocation +) -> definitions.StreamingTypeDescriptor: origin = get_origin(annotation) assert origin in (collections.abc.AsyncIterator, collections.abc.Iterator) args = get_args(annotation) assert len(args) == 1, "AsyncIterator cannot have more than 1 arg." arg = args[0] - if arg not in _SIMPLE_TYPES: - msg = "TODO TODO" + if arg not in _STREAM_TYPES: + msg = ( + "Streaming endpoints (containing `yield` statements) can only yield string " + "or byte items. 
For streaming structured pydantic data, use `stream_writer`" + "and `stream_reader` helpers.\n" + f"See streaming docs: {_DOCS_URL_STREAMING}" + ) _collect_error(msg, _ErrorKind.IO_TYPE_ERROR, location) + return definitions.StreamingTypeDescriptor( + raw=annotation, origin_type=origin, arg_type=arg + ) + def _validate_endpoint_params( params: list[inspect.Parameter], location: _ErrorLocation @@ -357,8 +361,9 @@ def _validate_endpoint_params( def _validate_endpoint_output_types( - annotation: Any, signature, location: _ErrorLocation + annotation: Any, signature, location: _ErrorLocation, is_streaming: bool ) -> list[definitions.TypeDescriptor]: + has_streaming_type = False if annotation == inspect.Parameter.empty: _collect_error( "Return values of endpoints must be type annotated. Got:\n" @@ -374,11 +379,28 @@ def _validate_endpoint_output_types( _validate_io_type(arg, f"return_type[{i}]", location) output_types.append(definitions.TypeDescriptor(raw=arg)) if origin in (collections.abc.AsyncIterator, collections.abc.Iterator): - _validate_generator_output_type(annotation, "return_type", location) - output_types = [definitions.TypeDescriptor(raw=annotation)] + output_types = [_validate_streaming_output_type(annotation, location)] + has_streaming_type = True + if not is_streaming: + _collect_error( + "If the endpoint returns an iterator (streaming), it must have `yield` " + "statements.", + _ErrorKind.IO_TYPE_ERROR, + location, + ) else: _validate_io_type(annotation, "return_type", location) output_types = [definitions.TypeDescriptor(raw=annotation)] + + if is_streaming and not has_streaming_type: + _collect_error( + "If the endpoint is streaming (has `yield` statements), the return type must" + "be an iterator (e.g. `AsyncIterator[bytes]`). Got:\n" + f"\t{location.method_name}{signature} -> {annotation}", + _ErrorKind.IO_TYPE_ERROR, + location, + ) + return output_types @@ -409,7 +431,7 @@ def _validate_and_describe_endpoint( # Return a "neutral dummy" if validation fails, this allows to safely # continue checking for more errors. return definitions.EndpointAPIDescriptor( - input_args=[], output_types=[], is_async=False, is_generator=False + input_args=[], output_types=[], is_async=False, is_streaming=False ) # This is the unbound method. @@ -427,26 +449,38 @@ def _validate_and_describe_endpoint( # Return a "neutral dummy" if validation fails, this allows to safely # continue checking for more errors. 
return definitions.EndpointAPIDescriptor( - input_args=[], output_types=[], is_async=False, is_generator=False + input_args=[], output_types=[], is_async=False, is_streaming=False ) signature = inspect.signature(endpoint_method) input_args = _validate_endpoint_params( list(signature.parameters.values()), location ) - output_types = _validate_endpoint_output_types( - signature.return_annotation, signature, location - ) - if inspect.isasyncgenfunction(endpoint_method): is_async = True - is_generator = True + is_streaming = True elif inspect.iscoroutinefunction(endpoint_method): is_async = True - is_generator = False + is_streaming = False else: is_async = False - is_generator = inspect.isgeneratorfunction(endpoint_method) + is_streaming = inspect.isgeneratorfunction(endpoint_method) + + output_types = _validate_endpoint_output_types( + signature.return_annotation, + signature, + location, + is_streaming, + ) + + if is_streaming: + if not is_async: + _collect_error( + "`Streaming endpoints (containing `yield` statements) are only " + "supported for async endpoints.", + _ErrorKind.TYPE_ERROR, + location, + ) if not is_async: warnings.warn( @@ -471,7 +505,7 @@ def _validate_and_describe_endpoint( input_args=input_args, output_types=output_types, is_async=is_async, - is_generator=is_generator, + is_streaming=is_streaming, ) diff --git a/truss-chains/truss_chains/streaming.py b/truss-chains/truss_chains/streaming.py index be8d8619c..9b39ed4b9 100644 --- a/truss-chains/truss_chains/streaming.py +++ b/truss-chains/truss_chains/streaming.py @@ -1,3 +1,4 @@ +import asyncio import enum import struct import sys @@ -17,12 +18,25 @@ async def anext(iterable: AsyncIterator[_T]) -> _T: return await iterable.__anext__() +# Note on the (verbose) typing in this module: we want exact typing of the reader and +# writer helpers, while also allowing flexibility to users to leave out header/footer +# if not needed. +# Putting both a constraint on the header/footer types to be pydantic +# models, but also letting them be optional is not well-supported by typing tools, +# (missing feature is using type variables a constraints on other type variables). +# +# A functional, yet verbose workaround that gives correct variadic type inference, +# is using intermediate type variables `HeaderT` <-> `HeaderTT` and in conjunction with +# mapping out all usage combinations with overloads (the overloads essentially allow +# "conditional" binding of type vars). These overloads also allow to use granular +# reader/writer sub-classes conditionally, that have the read/write methods only for the +# data types configured, and implemented DRY with mixin classes. ItemT = TypeVar("ItemT", bound=pydantic.BaseModel) HeaderT = TypeVar("HeaderT", bound=pydantic.BaseModel) FooterT = TypeVar("FooterT", bound=pydantic.BaseModel) -# Since header/footer could also be None, we need an extra type variable that -# can assume either `Type[HeaderT]` or `None` - `Type[None]` would not work. +# Since header/footer could also be `None`, we need an extra type variable that +# can assume either `Type[HeaderT]` or `None` - `Type[None]` causes issues. 
HeaderTT = TypeVar("HeaderTT") FooterTT = TypeVar("FooterTT") @@ -77,7 +91,7 @@ def stream_types( # Reading ############################################################################## -class Delimiter(enum.IntEnum): +class _Delimiter(enum.IntEnum): HEADER = enum.auto() ITEM = enum.auto() FOOTER = enum.auto() @@ -87,14 +101,16 @@ class Delimiter(enum.IntEnum): class _Streamer(Generic[ItemT, HeaderTT, FooterTT]): _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT] - def __init__(self, stream_types: StreamTypes[ItemT, HeaderTT, FooterTT]) -> None: - self._stream_types = stream_types + def __init__(self, types: StreamTypes[ItemT, HeaderTT, FooterTT]) -> None: + self._stream_types = types # Reading ############################################################################## class _ByteReader: + """Helper to provide `readexactly` API for an async bytes iterator.""" + def __init__(self, source: AsyncIterator[bytes]) -> None: self._source = source self._buffer = bytearray() @@ -104,63 +120,78 @@ async def readexactly(self, num_bytes: int) -> bytes: try: chunk = await anext(self._source) except StopAsyncIteration: - if len(self._buffer) < num_bytes: - raise EOFError( - f"Requested to read `{num_bytes}` bytes, " - f"but only `{len(self._buffer)}` available" - ) break self._buffer.extend(chunk) + if len(self._buffer) < num_bytes: + if len(self._buffer) == 0: + raise EOFError() + raise asyncio.IncompleteReadError(self._buffer, num_bytes) + result = bytes(self._buffer[:num_bytes]) del self._buffer[:num_bytes] return result class _StreamReaderProtocol(Protocol[ItemT, HeaderTT, FooterTT]): - async def _read(self) -> tuple[Delimiter, serialization.MsgPackType]: ... - - _footer_data: Optional[serialization.MsgPackType] _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT] + _footer_data: Optional[serialization.MsgPackType] + + async def _read(self) -> tuple[_Delimiter, serialization.MsgPackType]: ... -class StreamReader(_Streamer[ItemT, HeaderTT, FooterTT]): +class _StreamReader(_Streamer[ItemT, HeaderTT, FooterTT]): _stream: _ByteReader _footer_data: Optional[serialization.MsgPackType] def __init__( self, - stream_types: StreamTypes[ItemT, HeaderTT, FooterTT], + types: StreamTypes[ItemT, HeaderTT, FooterTT], stream: AsyncIterator[bytes], ) -> None: - super().__init__(stream_types) + super().__init__(types) self._stream = _ByteReader(stream) self._footer_data = None @staticmethod - def _unpack_tag(tag: bytes) -> tuple[Delimiter, int]: + def _unpack_tag(tag: bytes) -> tuple[_Delimiter, int]: enum_value, length = struct.unpack(">BI", tag) - return Delimiter(enum_value), length + return _Delimiter(enum_value), length + + async def _read(self) -> tuple[_Delimiter, serialization.MsgPackType]: + try: + tag = await self._stream.readexactly(TAG_SIZE) + # It's ok to read nothing (end of stream), but unexpected to read partial. 
+ except asyncio.IncompleteReadError: + raise + except EOFError: + return _Delimiter.END, None - async def _read(self) -> tuple[Delimiter, serialization.MsgPackType]: - tag = await self._stream.readexactly(TAG_SIZE) delimiter, length = self._unpack_tag(tag) if not length: return delimiter, None data_bytes = await self._stream.readexactly(length) + print(f"Read Delimiter: {delimiter}") return delimiter, serialization.truss_msgpack_deserialize(data_bytes) async def read_items(self) -> AsyncIterator[ItemT]: delimiter, data_dict = await self._read() - assert delimiter == Delimiter.ITEM - - while delimiter == Delimiter.ITEM: + if delimiter == _Delimiter.HEADER: + raise ValueError( + "Called `read_items`, but there the stream contains header data, which " + "is not consumed. Call `read_header` or remove sending a header." + ) + if delimiter in (_Delimiter.FOOTER, _Delimiter.END): + return + + assert delimiter == _Delimiter.ITEM + while True: yield self._stream_types.item_t.model_validate(data_dict) - # Read next: either item, footer, or end. + # We don't know if the next data is another item, footer or the end. delimiter, data_dict = await self._read() - if delimiter == Delimiter.END: + if delimiter == _Delimiter.END: return - if delimiter == Delimiter.FOOTER: + if delimiter == _Delimiter.FOOTER: self._footer_data = data_dict return @@ -170,7 +201,8 @@ async def read_header( self: _StreamReaderProtocol[ItemT, HeaderT, FooterTT], ) -> HeaderT: delimiter, data_dict = await self._read() - assert delimiter == Delimiter.HEADER + if delimiter != _Delimiter.HEADER: + raise ValueError("Stream does not contain header.") return self._stream_types.header_t.model_validate(data_dict) @@ -181,24 +213,28 @@ async def read_footer( self: _StreamReaderProtocol[ItemT, HeaderTT, FooterT], ) -> FooterT: if self._footer_data is None: - raise ValueError() + delimiter, data_dict = await self._read() + if delimiter != _Delimiter.FOOTER: + raise ValueError("Stream does not contain footer.") + self._footer_data = data_dict + footer = self._stream_types.footer_t.model_validate(self._footer_data) self._footer_data = None return footer class StreamReaderWithHeader( - StreamReader[ItemT, HeaderT, FooterTT], _HeaderReadMixin[ItemT, HeaderT, FooterTT] + _StreamReader[ItemT, HeaderT, FooterTT], _HeaderReadMixin[ItemT, HeaderT, FooterTT] ): ... class StreamReaderWithFooter( - StreamReader[ItemT, HeaderTT, FooterT], _FooterReadMixin[ItemT, HeaderTT, FooterT] + _StreamReader[ItemT, HeaderTT, FooterT], _FooterReadMixin[ItemT, HeaderTT, FooterT] ): ... class StreamReaderFull( - StreamReader[ItemT, HeaderT, FooterT], + _StreamReader[ItemT, HeaderT, FooterT], _HeaderReadMixin[ItemT, HeaderT, FooterT], _FooterReadMixin[ItemT, HeaderT, FooterT], ): ... @@ -206,71 +242,139 @@ class StreamReaderFull( @overload def stream_reader( - stream_types: StreamTypes[ItemT, None, None], + types: StreamTypes[ItemT, None, None], stream: AsyncIterator[bytes], -) -> StreamReader[ItemT, None, None]: ... +) -> _StreamReader[ItemT, None, None]: ... @overload def stream_reader( - stream_types: StreamTypes[ItemT, HeaderT, None], + types: StreamTypes[ItemT, HeaderT, None], stream: AsyncIterator[bytes], ) -> StreamReaderWithHeader[ItemT, HeaderT, None]: ... @overload def stream_reader( - stream_types: StreamTypes[ItemT, None, FooterT], + types: StreamTypes[ItemT, None, FooterT], stream: AsyncIterator[bytes], ) -> StreamReaderWithFooter[ItemT, None, FooterT]: ... 
@overload def stream_reader( - stream_types: StreamTypes[ItemT, HeaderT, FooterT], + types: StreamTypes[ItemT, HeaderT, FooterT], stream: AsyncIterator[bytes], ) -> StreamReaderFull[ItemT, HeaderT, FooterT]: ... def stream_reader( - stream_types: StreamTypes[ItemT, HeaderTT, FooterTT], + types: StreamTypes[ItemT, HeaderTT, FooterTT], stream: AsyncIterator[bytes], -) -> StreamReader: - if stream_types.header_t is None and stream_types.footer_t is None: - return StreamReader(stream_types, stream) - if stream_types.header_t is None: - return StreamReaderWithFooter(stream_types, stream) - if stream_types.footer_t is None: - return StreamReaderWithHeader(stream_types, stream) - return StreamReaderFull(stream_types, stream) +) -> _StreamReader: + if types.header_t is None and types.footer_t is None: + return _StreamReader(types, stream) + if types.header_t is None: + return StreamReaderWithFooter(types, stream) + if types.footer_t is None: + return StreamReaderWithHeader(types, stream) + + return StreamReaderFull(types, stream) # Writing ############################################################################## -class StreamWriter(_Streamer[ItemT, HeaderT, FooterT]): +class _StreamWriterProtocol(Protocol[ItemT, HeaderTT, FooterTT]): + _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT] + + def _serialize(self, obj: pydantic.BaseModel, delimiter: _Delimiter) -> bytes: ... + + +class _StreamWriter(_Streamer[ItemT, HeaderTT, FooterTT]): @staticmethod - def _pack_tag(delimiter: Delimiter, length: int) -> bytes: + def _pack_tag(delimiter: _Delimiter, length: int) -> bytes: return struct.pack(">BI", delimiter.value, length) - def _serialize(self, obj: pydantic.BaseModel, delimiter: Delimiter) -> bytes: + def _serialize(self, obj: pydantic.BaseModel, delimiter: _Delimiter) -> bytes: data_dict = obj.model_dump() data_bytes = serialization.truss_msgpack_serialize(data_dict) data = bytearray(self._pack_tag(delimiter, len(data_bytes))) data.extend(data_bytes) - print(data) - # Starlette cannot handle byte array. + # Starlette cannot handle byte array, but view works.. return memoryview(data) - def yield_header(self, header: HeaderT) -> bytes: + def yield_item(self, item: ItemT) -> bytes: + return self._serialize(item, _Delimiter.ITEM) + + +class _HeaderWriteMixin(_Streamer[ItemT, HeaderT, FooterTT]): + def yield_header( + self: _StreamWriterProtocol[ItemT, HeaderT, FooterTT], header: HeaderT + ) -> bytes: if self._stream_types.header_t is None or header is None: raise ValueError() - return self._serialize(header, Delimiter.HEADER) + return self._serialize(header, _Delimiter.HEADER) - def yield_item(self, item: ItemT) -> bytes: - return self._serialize(item, Delimiter.ITEM) - def yield_footer(self, footer: FooterT) -> bytes: +class _FooterWriteMixin(_Streamer[ItemT, HeaderTT, FooterT]): + def yield_footer( + self: _StreamWriterProtocol[ItemT, HeaderTT, FooterT], footer: FooterT + ) -> bytes: if self._stream_types.header_t is None or footer is None: raise ValueError() - return self._serialize(footer, Delimiter.FOOTER) + return self._serialize(footer, _Delimiter.FOOTER) + + +class StreamWriterWithHeader( + _StreamWriter[ItemT, HeaderT, FooterTT], _HeaderWriteMixin[ItemT, HeaderT, FooterTT] +): ... + + +class StreamWriterWithFooter( + _StreamWriter[ItemT, HeaderTT, FooterT], _FooterWriteMixin[ItemT, HeaderTT, FooterT] +): ... + + +class StreamWriterFull( + _StreamWriter[ItemT, HeaderT, FooterT], + _HeaderWriteMixin[ItemT, HeaderT, FooterT], + _FooterWriteMixin[ItemT, HeaderT, FooterT], +): ... 
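# A minimal sketch (illustrative, not a line of this patch) of what the writer classes
# above put on the wire: every `yield_*` call returns one self-contained frame, i.e. a
# 5-byte tag (uint8 delimiter + uint32 big-endian payload length) followed by the
# serialized pydantic payload. `Chunk` is a hypothetical model.
import struct

import pydantic

from truss_chains import streaming


class Chunk(pydantic.BaseModel):
    words: list[str]


writer = streaming.stream_writer(streaming.stream_types(Chunk))
frame = bytes(writer.yield_item(Chunk(words=["hello"])))
delimiter_value, payload_len = struct.unpack(">BI", frame[:5])
assert payload_len == len(frame) - 5  # The tag records the payload length.
# `delimiter_value` tells the reader whether the payload is a header, item, or footer.
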
+ + +@overload +def stream_writer( + types: StreamTypes[ItemT, None, None], +) -> _StreamWriter[ItemT, None, None]: ... + + +@overload +def stream_writer( + types: StreamTypes[ItemT, HeaderT, None], +) -> StreamWriterWithHeader[ItemT, HeaderT, None]: ... + + +@overload +def stream_writer( + types: StreamTypes[ItemT, None, FooterT], +) -> StreamWriterWithFooter[ItemT, None, FooterT]: ... + + +@overload +def stream_writer( + types: StreamTypes[ItemT, HeaderT, FooterT], +) -> StreamWriterFull[ItemT, HeaderT, FooterT]: ... + + +def stream_writer( + types: StreamTypes[ItemT, HeaderTT, FooterTT], +) -> _StreamWriter: + if types.header_t is None and types.footer_t is None: + return _StreamWriter(types) + if types.header_t is None: + return StreamWriterWithFooter(types) + if types.footer_t is None: + return StreamWriterWithHeader(types) + + return StreamWriterFull(types) diff --git a/truss-chains/truss_chains/stub.py b/truss-chains/truss_chains/stub.py index 1091462e4..5de4f66de 100644 --- a/truss-chains/truss_chains/stub.py +++ b/truss-chains/truss_chains/stub.py @@ -137,6 +137,9 @@ async def _client_async(self) -> aiohttp.ClientSession: return self._cached_async_client[0] def predict_sync(self, json_payload): + headers = { + definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() + } retrying = tenacity.Retrying( stop=tenacity.stop_after_attempt(self._service_descriptor.options.retries), retry=tenacity.retry_if_exception_type(Exception), @@ -152,9 +155,7 @@ def predict_sync(self, json_payload): response = self._client_sync().post( self._service_descriptor.predict_url, json=json_payload, - headers={ - definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() - }, + headers=headers, ) utils.response_raise_errors(response, self.name) return response.json() @@ -166,6 +167,9 @@ def predict_sync(self, json_payload): raise async def predict_async(self, json_payload): + headers = { + definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() + } retrying = tenacity.AsyncRetrying( stop=tenacity.stop_after_attempt(self._service_descriptor.options.retries), retry=tenacity.retry_if_exception_type(Exception), @@ -182,9 +186,7 @@ async def predict_async(self, json_payload): async with client.post( self._service_descriptor.predict_url, json=json_payload, - headers={ - definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() - }, + headers=headers, ) as response: await utils.async_response_raise_errors(response, self.name) return await response.json() @@ -194,7 +196,10 @@ async def predict_async(self, json_payload): self._cached_async_client = None raise - async def predict_async_stream(self, json_payload) -> AsyncIterator[bytes]: + async def predict_async_stream(self, json_payload) -> AsyncIterator[bytes]: # type: ignore[return] # Handled by retries. 
+ headers = { + definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() + } retrying = tenacity.AsyncRetrying( stop=tenacity.stop_after_attempt(self._service_descriptor.options.retries), retry=tenacity.retry_if_exception_type(Exception), @@ -211,9 +216,7 @@ async def predict_async_stream(self, json_payload) -> AsyncIterator[bytes]: response = await client.post( self._service_descriptor.predict_url, json=json_payload, - headers={ - definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() - }, + headers=headers, ) await utils.async_response_raise_errors(response, self.name) return response.content.iter_any() diff --git a/truss-chains/truss_chains/type_experiment.py b/truss-chains/truss_chains/type_experiment.py deleted file mode 100644 index 646d5e99d..000000000 --- a/truss-chains/truss_chains/type_experiment.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Generic, NamedTuple, Optional, Type, TypeVar, overload - -import pydantic -from typing_extensions import reveal_type - -ItemT = TypeVar("ItemT", bound=pydantic.BaseModel) -HeaderT = TypeVar("HeaderT") - - -class StreamTypes(NamedTuple, Generic[ItemT, HeaderT]): - item_t: Type[ItemT] - header_t: HeaderT - - -@overload -def stream_types(item_t: Type[ItemT]) -> StreamTypes[ItemT, None]: ... - - -@overload -def stream_types( - item_t: Type[ItemT], *, header_t: Type[HeaderT] -) -> StreamTypes[ItemT, Type[HeaderT]]: ... - - -def stream_types(item_t: Type[ItemT], *, header_t: Optional[Type[HeaderT]] = None): - return StreamTypes(item_t, header_t) - - -class _Streamer(Generic[ItemT, HeaderT]): - _stream_types: StreamTypes[ItemT, HeaderT] - - def __init__(self, stream_types_: StreamTypes[ItemT, HeaderT]) -> None: - self._stream_types = stream_types_ - - -if __name__ == "__main__": - - class Header(pydantic.BaseModel): - time: float - msg: str - - class MyDataChunk(pydantic.BaseModel): - words: list[str] - - NONE_TYPES = stream_types(MyDataChunk) - FULL_TYPES = stream_types(MyDataChunk, header_t=Header) - - streamer_none = _Streamer(NONE_TYPES) - reveal_type(streamer_none._stream_types.item_t) - reveal_type(streamer_none._stream_types.header_t) # Revealed type is 'None' - - streamer_full = _Streamer(FULL_TYPES) - reveal_type(streamer_full._stream_types.header_t) # Revealed type is 'Type[Header]' From 9b360b81ae48ebf2cdc4f35513138195117923b9 Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Mon, 25 Nov 2024 18:02:32 -0800 Subject: [PATCH 07/11] Fix failures --- .github/workflows/pr.yml | 35 -------------------------- truss-chains/tests/chains_e2e_test.py | 4 +-- truss-chains/truss_chains/framework.py | 7 +++--- truss-chains/truss_chains/streaming.py | 17 ++++++++----- 4 files changed, 17 insertions(+), 46 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index f1955c325..963da7544 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -51,38 +51,3 @@ jobs: with: use-verbose-mode: "yes" folder-path: "docs" - - enforce-chains-example-docs-sync: - runs-on: ubuntu-20.04 - steps: - - uses: actions/checkout@v4 - with: - lfs: true - fetch-depth: 2 - - - name: Fetch main branch - run: git fetch origin main - - - name: Check if chains examples were modified - id: check_files - run: | - if git diff --name-only origin/main | grep -q '^truss-chains/examples/.*'; then - echo "chains_docs_update_needed=true" >> $GITHUB_ENV - echo "Chains examples were modified." - else - echo "chains_docs_update_needed=false" >> $GITHUB_ENV - echo "Chains examples were not modified." 
- echo "::notice file=truss-chains/examples/::Chains examples not modified." - fi - - - name: Enforce acknowledgment in PR description - if: env.chains_docs_update_needed == 'true' - env: - DESCRIPTION: ${{ github.event.pull_request.body }} - run: | - if [[ "$DESCRIPTION" != *"UPDATE_DOCS=done"* && "$DESCRIPTION" != *"UPDATE_DOCS=not_needed"* ]]; then - echo "::error file=truss-chains/examples/::Chains examples were modified and ack not found in PR description. Verify whether docs need to be update (https://github.com/basetenlabs/docs.baseten.co/tree/main/chains) and add an ack tag `UPDATE_DOCS={done|not_needed}` to the PR description." - exit 1 - else - echo "::notice file=truss-chains/examples/::Chains examples modified and ack found int PR description." - fi diff --git a/truss-chains/tests/chains_e2e_test.py b/truss-chains/tests/chains_e2e_test.py index 31ecc2461..29d7ca894 100644 --- a/truss-chains/tests/chains_e2e_test.py +++ b/truss-chains/tests/chains_e2e_test.py @@ -146,7 +146,7 @@ def test_streaming_chain(): assert result["chunks"][2]["words"] == ["G", "HH", "III"] assert result["chunks"][3]["words"] == ["G", "HH", "III", "JJJJ"] assert result["footer"]["duration_sec"] > 0 - assert result["strings"] == ["First second last."] + assert result["strings"] == "First second last." @pytest.mark.asyncio @@ -163,4 +163,4 @@ async def test_streaming_chain_local(): assert result.chunks[2].words == ["G", "HH", "III"] assert result.chunks[3].words == ["G", "HH", "III", "JJJJ"] assert result.footer.duration_sec > 0 - assert result.strings == ["First second last."] + assert result.strings == "First second last." diff --git a/truss-chains/truss_chains/framework.py b/truss-chains/truss_chains/framework.py index 90ef39973..560d4a795 100644 --- a/truss-chains/truss_chains/framework.py +++ b/truss-chains/truss_chains/framework.py @@ -378,7 +378,8 @@ def _validate_endpoint_output_types( for i, arg in enumerate(get_args(annotation)): _validate_io_type(arg, f"return_type[{i}]", location) output_types.append(definitions.TypeDescriptor(raw=arg)) - if origin in (collections.abc.AsyncIterator, collections.abc.Iterator): + + elif origin in (collections.abc.AsyncIterator, collections.abc.Iterator): output_types = [_validate_streaming_output_type(annotation, location)] has_streaming_type = True if not is_streaming: @@ -394,8 +395,8 @@ def _validate_endpoint_output_types( if is_streaming and not has_streaming_type: _collect_error( - "If the endpoint is streaming (has `yield` statements), the return type must" - "be an iterator (e.g. `AsyncIterator[bytes]`). Got:\n" + "If the endpoint is streaming (has `yield` statements), the return type " + "must be an iterator (e.g. `AsyncIterator[bytes]`). Got:\n" f"\t{location.method_name}{signature} -> {annotation}", _ErrorKind.IO_TYPE_ERROR, location, diff --git a/truss-chains/truss_chains/streaming.py b/truss-chains/truss_chains/streaming.py index 9b39ed4b9..6431d6d83 100644 --- a/truss-chains/truss_chains/streaming.py +++ b/truss-chains/truss_chains/streaming.py @@ -3,11 +3,10 @@ import struct import sys from collections.abc import AsyncIterator -from typing import NamedTuple, Optional, Protocol, overload +from typing import Generic, Optional, Protocol, Type, TypeVar, overload import pydantic from truss.templates.shared import serialization -from typing_extensions import Generic, Type, TypeVar TAG_SIZE = 5 # uint8 + uint32. 
_T = TypeVar("_T") @@ -41,10 +40,16 @@ async def anext(iterable: AsyncIterator[_T]) -> _T: FooterTT = TypeVar("FooterTT") -class StreamTypes(NamedTuple, Generic[ItemT, HeaderTT, FooterTT]): - item_t: Type[ItemT] - header_t: HeaderTT # Is either `Type[HeaderT]` or `None`. - footer_t: FooterTT # Is either `Type[FooterT]` or `None`. +class StreamTypes(Generic[ItemT, HeaderTT, FooterTT]): + def __init__( + self, + item_t: Type[ItemT], + header_t: HeaderTT, + footer_t: FooterTT, + ) -> None: + self.item_t = item_t + self.header_t = header_t # Is either `Type[HeaderT]` or `None`. + self.footer_t = footer_t # Is either `Type[FooterT]` or `None`. @overload From 36df2d4e6568ad6d3289252835ba6f22d2bf8d10 Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Mon, 25 Nov 2024 18:13:49 -0800 Subject: [PATCH 08/11] dataclass --- truss-chains/truss_chains/streaming.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/truss-chains/truss_chains/streaming.py b/truss-chains/truss_chains/streaming.py index 6431d6d83..1225aa06b 100644 --- a/truss-chains/truss_chains/streaming.py +++ b/truss-chains/truss_chains/streaming.py @@ -1,4 +1,5 @@ import asyncio +import dataclasses import enum import struct import sys @@ -40,16 +41,11 @@ async def anext(iterable: AsyncIterator[_T]) -> _T: FooterTT = TypeVar("FooterTT") +@dataclasses.dataclass class StreamTypes(Generic[ItemT, HeaderTT, FooterTT]): - def __init__( - self, - item_t: Type[ItemT], - header_t: HeaderTT, - footer_t: FooterTT, - ) -> None: - self.item_t = item_t - self.header_t = header_t # Is either `Type[HeaderT]` or `None`. - self.footer_t = footer_t # Is either `Type[FooterT]` or `None`. + item_t: Type[ItemT] + header_t: HeaderTT # Is either `Type[HeaderT]` or `None`. + footer_t: FooterTT # Is either `Type[FooterT]` or `None`. @overload From 139942108da4c7770f94c4cebd2060118d47660d Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Mon, 25 Nov 2024 18:29:55 -0800 Subject: [PATCH 09/11] Add streaming unittest --- .../examples/streaming/streaming_chain.py | 4 +- truss-chains/tests/test_streaming.py | 181 ++++++++++++++++++ truss-chains/truss_chains/streaming.py | 2 +- 3 files changed, 184 insertions(+), 3 deletions(-) create mode 100644 truss-chains/tests/test_streaming.py diff --git a/truss-chains/examples/streaming/streaming_chain.py b/truss-chains/examples/streaming/streaming_chain.py index e09e8b8ac..5ef979f50 100644 --- a/truss-chains/examples/streaming/streaming_chain.py +++ b/truss-chains/examples/streaming/streaming_chain.py @@ -56,11 +56,11 @@ async def run_remote(self) -> AsyncIterator[bytes]: class StringGenerator(chains.ChainletBase): - """Minimal streaming example with raw strings (e.g. for LLM).""" + """Minimal streaming example with strings (e.g. for raw LLM output).""" async def run_remote(self) -> AsyncIterator[str]: # Note: the "chunk" boundaries are lost, when streaming raw strings. You must - # add spaces and linebreaks to the items. + # add spaces and linebreaks to the items yourself.. yield "First " yield "second " yield "last." 
diff --git a/truss-chains/tests/test_streaming.py b/truss-chains/tests/test_streaming.py new file mode 100644 index 000000000..1eee6c031 --- /dev/null +++ b/truss-chains/tests/test_streaming.py @@ -0,0 +1,181 @@ +import asyncio +from typing import AsyncIterator + +import pydantic +import pytest + +from truss_chains import streaming + + +class Header(pydantic.BaseModel): + time: float + msg: str + + +class MyDataChunk(pydantic.BaseModel): + words: list[str] + + +class Footer(pydantic.BaseModel): + time: float + duration_sec: float + msg: str + + +async def to_bytes_iterator(data_stream) -> AsyncIterator[bytes]: + for data in data_stream: + yield data + await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_streaming_with_header_and_footer(): + types = streaming.stream_types(item_t=MyDataChunk, header_t=Header, footer_t=Footer) + + writer = streaming.stream_writer(types) + header = Header(time=123.456, msg="Start of stream") + items = [ + MyDataChunk(words=["hello", "world"]), + MyDataChunk(words=["foo", "bar"]), + MyDataChunk(words=["baz"]), + ] + footer = Footer(time=789.012, duration_sec=665.556, msg="End of stream") + + data_stream = [] + data_stream.append(writer.yield_header(header)) + for item in items: + data_stream.append(writer.yield_item(item)) + data_stream.append(writer.yield_footer(footer)) + + reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) + # Assert that serialization roundtrip works. + read_header = await reader.read_header() + assert read_header == header + read_items = [] + async for item in reader.read_items(): + read_items.append(item) + assert read_items == items + read_footer = await reader.read_footer() + assert read_footer == footer + + +@pytest.mark.asyncio +async def test_streaming_with_items_only(): + types = streaming.stream_types(item_t=MyDataChunk) + writer = streaming.stream_writer(types) + + items = [ + MyDataChunk(words=["hello", "world"]), + MyDataChunk(words=["foo", "bar"]), + MyDataChunk(words=["baz"]), + ] + + data_stream = [] + for item in items: + data_stream.append(writer.yield_item(item)) + + reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) + read_items = [] + async for item in reader.read_items(): + read_items.append(item) + + assert read_items == items + + +@pytest.mark.asyncio +async def test_reading_header_when_none_sent(): + types = streaming.stream_types(item_t=MyDataChunk, header_t=Header) + writer = streaming.stream_writer(streaming.stream_types(item_t=MyDataChunk)) + items = [MyDataChunk(words=["hello", "world"])] + + data_stream = [] + for item in items: + data_stream.append(writer.yield_item(item)) + + reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) + with pytest.raises(ValueError, match="Stream does not contain header."): + await reader.read_header() + + +@pytest.mark.asyncio +async def test_reading_items_with_wrong_model(): + types_writer = streaming.stream_types(item_t=MyDataChunk) + types_reader = streaming.stream_types(item_t=Header) # Wrong item type + writer = streaming.stream_writer(types_writer) + items = [MyDataChunk(words=["hello", "world"])] + data_stream = [] + for item in items: + data_stream.append(writer.yield_item(item)) + + reader = streaming.stream_reader(types_reader, to_bytes_iterator(data_stream)) + + with pytest.raises(pydantic.ValidationError): + async for item in reader.read_items(): + pass + + +@pytest.mark.asyncio +async def test_streaming_with_wrong_order(): + types = streaming.stream_types( + item_t=MyDataChunk, + header_t=Header, 
+ footer_t=Footer, + ) + + writer = streaming.stream_writer(types) + header = Header(time=123.456, msg="Start of stream") + items = [MyDataChunk(words=["hello", "world"])] + footer = Footer(time=789.012, duration_sec=665.556, msg="End of stream") + data_stream = [] + for item in items: + data_stream.append(writer.yield_item(item)) + data_stream.append(writer.yield_header(header)) + data_stream.append(writer.yield_footer(footer)) + + reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) + + # Try to read header, should fail because the first data is an item + with pytest.raises(ValueError, match="Stream does not contain header."): + await reader.read_header() + + +@pytest.mark.asyncio +async def test_reading_items_without_consuming_header(): + types = streaming.stream_types(item_t=MyDataChunk, header_t=Header) + writer = streaming.stream_writer(types) + header = Header(time=123.456, msg="Start of stream") + items = [MyDataChunk(words=["hello", "world"])] + + data_stream = [] + data_stream.append(writer.yield_header(header)) + for item in items: + data_stream.append(writer.yield_item(item)) + + reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) + # Try to read items without consuming header + with pytest.raises( + ValueError, + match="Called `read_items`, but there the stream contains header data", + ): + async for item in reader.read_items(): + pass + + +@pytest.mark.asyncio +async def test_reading_footer_when_none_sent(): + types = streaming.stream_types(item_t=MyDataChunk, footer_t=Footer) + writer = streaming.stream_writer(streaming.stream_types(item_t=MyDataChunk)) + items = [MyDataChunk(words=["hello", "world"])] + data_stream = [] + for item in items: + data_stream.append(writer.yield_item(item)) + + reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) + read_items = [] + async for item in reader.read_items(): + read_items.append(item) + assert read_items == items + + # Try to read footer, expect an error + with pytest.raises(ValueError, match="Stream does not contain footer."): + await reader.read_footer() diff --git a/truss-chains/truss_chains/streaming.py b/truss-chains/truss_chains/streaming.py index 1225aa06b..85abc7589 100644 --- a/truss-chains/truss_chains/streaming.py +++ b/truss-chains/truss_chains/streaming.py @@ -180,7 +180,7 @@ async def read_items(self) -> AsyncIterator[ItemT]: if delimiter == _Delimiter.HEADER: raise ValueError( "Called `read_items`, but there the stream contains header data, which " - "is not consumed. Call `read_header` or remove sending a header." + "is not consumed. Call `read_header` first or remove sending a header." ) if delimiter in (_Delimiter.FOOTER, _Delimiter.END): return From b14aa3b7d949f2bc80c67e6a2a64d8f8cf96c1cf Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Tue, 26 Nov 2024 10:03:19 -0800 Subject: [PATCH 10/11] Add tests for validation. Make streaming not depend on msg pack. 
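
With this change the framed payload bytes are produced and parsed with pydantic's JSON
methods instead of the msgpack helpers, so `truss_chains.streaming` no longer imports
`truss.templates.shared.serialization`. Roughly (an illustrative sketch, assuming an
item model `Chunk`):

    payload = chunk.model_dump_json().encode()    # written by the stream writer
    parsed = Chunk.model_validate_json(payload)   # recovered by the stream reader
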
--- truss-chains/tests/test_framework.py | 54 ++++++++++++++++++++++++- truss-chains/truss_chains/framework.py | 12 +++++- truss-chains/truss_chains/streaming.py | 41 +++++++++---------- truss/templates/server/model_wrapper.py | 23 ++++++++--- truss/templates/server/truss_server.py | 6 +-- truss/templates/shared/serialization.py | 23 +---------- 6 files changed, 105 insertions(+), 54 deletions(-) diff --git a/truss-chains/tests/test_framework.py b/truss-chains/tests/test_framework.py index 5f33a3c00..c29324606 100644 --- a/truss-chains/tests/test_framework.py +++ b/truss-chains/tests/test_framework.py @@ -2,7 +2,7 @@ import contextlib import logging import re -from typing import List +from typing import AsyncIterator, Iterator, List import pydantic import pytest @@ -505,3 +505,55 @@ def run_remote(argument: object): ... with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): with public_api.run_local(): MultiIssue() + + +def test_raises_iterator_no_yield(): + match = ( + rf"{TEST_FILE}:\d+ \(IteratorNoYield\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"If the endpoint returns an iterator \(streaming\), it must have `yield` statements" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class IteratorNoYield(chains.ChainletBase): + async def run_remote(self) -> AsyncIterator[str]: + return "123" # type: ignore[return-value] + + +def test_raises_yield_no_iterator(): + match = ( + rf"{TEST_FILE}:\d+ \(YieldNoIterator\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"If the endpoint is streaming \(has `yield` statements\), the return type must be an iterator" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class YieldNoIterator(chains.ChainletBase): + async def run_remote(self) -> str: # type: ignore[misc] + yield "123" + + +def test_raises_iterator_sync(): + match = ( + rf"{TEST_FILE}:\d+ \(IteratorSync\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"Streaming endpoints \(containing `yield` statements\) are only supported for async endpoints" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class IteratorSync(chains.ChainletBase): + def run_remote(self) -> Iterator[str]: + yield "123" + + +def test_raises_iterator_no_arg(): + match = ( + rf"{TEST_FILE}:\d+ \(IteratorNoArg\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"Iterators must be annotated with type \(one of \['str', 'bytes'\]\)" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class IteratorNoArg(chains.ChainletBase): + async def run_remote(self) -> AsyncIterator: + yield "123" diff --git a/truss-chains/truss_chains/framework.py b/truss-chains/truss_chains/framework.py index 560d4a795..d86015974 100644 --- a/truss-chains/truss_chains/framework.py +++ b/truss-chains/truss_chains/framework.py @@ -305,6 +305,16 @@ def _validate_streaming_output_type( origin = get_origin(annotation) assert origin in (collections.abc.AsyncIterator, collections.abc.Iterator) args = get_args(annotation) + if len(args) < 1: + _collect_error( + f"Iterators must be annotated with type (one of {list(x.__name__ for x in _STREAM_TYPES)}).", + _ErrorKind.IO_TYPE_ERROR, + location, + ) + return definitions.StreamingTypeDescriptor( + raw=annotation, origin_type=origin, arg_type=bytes + ) + assert len(args) == 1, "AsyncIterator cannot have more than 1 arg." 
arg = args[0] if arg not in _STREAM_TYPES: @@ -479,7 +489,7 @@ def _validate_and_describe_endpoint( _collect_error( "`Streaming endpoints (containing `yield` statements) are only " "supported for async endpoints.", - _ErrorKind.TYPE_ERROR, + _ErrorKind.IO_TYPE_ERROR, location, ) diff --git a/truss-chains/truss_chains/streaming.py b/truss-chains/truss_chains/streaming.py index 85abc7589..6dde0d473 100644 --- a/truss-chains/truss_chains/streaming.py +++ b/truss-chains/truss_chains/streaming.py @@ -4,12 +4,12 @@ import struct import sys from collections.abc import AsyncIterator -from typing import Generic, Optional, Protocol, Type, TypeVar, overload +from typing import Generic, Optional, Protocol, Type, TypeVar, Union, overload import pydantic -from truss.templates.shared import serialization TAG_SIZE = 5 # uint8 + uint32. +JSONType = Union[str, int, float, bool, None, list["JSONType"], dict[str, "JSONType"]] _T = TypeVar("_T") if sys.version_info < (3, 10): @@ -136,14 +136,14 @@ async def readexactly(self, num_bytes: int) -> bytes: class _StreamReaderProtocol(Protocol[ItemT, HeaderTT, FooterTT]): _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT] - _footer_data: Optional[serialization.MsgPackType] + _footer_data: Optional[bytes] - async def _read(self) -> tuple[_Delimiter, serialization.MsgPackType]: ... + async def _read(self) -> tuple[_Delimiter, bytes]: ... class _StreamReader(_Streamer[ItemT, HeaderTT, FooterTT]): _stream: _ByteReader - _footer_data: Optional[serialization.MsgPackType] + _footer_data: Optional[bytes] def __init__( self, @@ -159,24 +159,24 @@ def _unpack_tag(tag: bytes) -> tuple[_Delimiter, int]: enum_value, length = struct.unpack(">BI", tag) return _Delimiter(enum_value), length - async def _read(self) -> tuple[_Delimiter, serialization.MsgPackType]: + async def _read(self) -> tuple[_Delimiter, bytes]: try: tag = await self._stream.readexactly(TAG_SIZE) # It's ok to read nothing (end of stream), but unexpected to read partial. except asyncio.IncompleteReadError: raise except EOFError: - return _Delimiter.END, None + return _Delimiter.END, b"" delimiter, length = self._unpack_tag(tag) if not length: - return delimiter, None + return delimiter, b"" data_bytes = await self._stream.readexactly(length) print(f"Read Delimiter: {delimiter}") - return delimiter, serialization.truss_msgpack_deserialize(data_bytes) + return delimiter, data_bytes async def read_items(self) -> AsyncIterator[ItemT]: - delimiter, data_dict = await self._read() + delimiter, data_bytes = await self._read() if delimiter == _Delimiter.HEADER: raise ValueError( "Called `read_items`, but there the stream contains header data, which " @@ -187,13 +187,13 @@ async def read_items(self) -> AsyncIterator[ItemT]: assert delimiter == _Delimiter.ITEM while True: - yield self._stream_types.item_t.model_validate(data_dict) + yield self._stream_types.item_t.model_validate_json(data_bytes) # We don't know if the next data is another item, footer or the end. 
- delimiter, data_dict = await self._read() + delimiter, data_bytes = await self._read() if delimiter == _Delimiter.END: return if delimiter == _Delimiter.FOOTER: - self._footer_data = data_dict + self._footer_data = data_bytes return @@ -201,25 +201,25 @@ class _HeaderReadMixin(_Streamer[ItemT, HeaderT, FooterTT]): async def read_header( self: _StreamReaderProtocol[ItemT, HeaderT, FooterTT], ) -> HeaderT: - delimiter, data_dict = await self._read() + delimiter, data_bytes = await self._read() if delimiter != _Delimiter.HEADER: raise ValueError("Stream does not contain header.") - return self._stream_types.header_t.model_validate(data_dict) + return self._stream_types.header_t.model_validate_json(data_bytes) class _FooterReadMixin(_Streamer[ItemT, HeaderTT, FooterT]): - _footer_data: Optional[serialization.MsgPackType] + _footer_data: Optional[bytes] async def read_footer( self: _StreamReaderProtocol[ItemT, HeaderTT, FooterT], ) -> FooterT: if self._footer_data is None: - delimiter, data_dict = await self._read() + delimiter, data_bytes = await self._read() if delimiter != _Delimiter.FOOTER: raise ValueError("Stream does not contain footer.") - self._footer_data = data_dict + self._footer_data = data_bytes - footer = self._stream_types.footer_t.model_validate(self._footer_data) + footer = self._stream_types.footer_t.model_validate_json(self._footer_data) self._footer_data = None return footer @@ -298,8 +298,7 @@ def _pack_tag(delimiter: _Delimiter, length: int) -> bytes: return struct.pack(">BI", delimiter.value, length) def _serialize(self, obj: pydantic.BaseModel, delimiter: _Delimiter) -> bytes: - data_dict = obj.model_dump() - data_bytes = serialization.truss_msgpack_serialize(data_dict) + data_bytes = obj.model_dump_json().encode() data = bytearray(self._pack_tag(delimiter, len(data_bytes))) data.extend(data_bytes) # Starlette cannot handle byte array, but view works.. diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py index 82bab57d4..ab28713d2 100644 --- a/truss/templates/server/model_wrapper.py +++ b/truss/templates/server/model_wrapper.py @@ -27,6 +27,7 @@ ) import opentelemetry.sdk.trace as sdk_trace +import pydantic import starlette.requests import starlette.responses from anyio import Semaphore, to_thread @@ -56,6 +57,15 @@ TRT_LLM_EXTENSION_NAME = "trt_llm" POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS = 30 +InputType = Union[serialization.JSONType, serialization.MsgPackType, pydantic.BaseModel] +OutputType = Union[ + serialization.JSONType, + serialization.MsgPackType, + Generator[bytes, None, None], + AsyncGenerator[bytes, None], + "starlette.responses.Response", +] + @asynccontextmanager async def deferred_semaphore_and_span( @@ -520,7 +530,7 @@ async def poll_for_environment_updates(self) -> None: async def preprocess( self, - inputs: serialization.InputType, + inputs: InputType, request: starlette.requests.Request, ) -> Any: descriptor = self.model_descriptor.preprocess @@ -538,7 +548,7 @@ async def predict( self, inputs: Any, request: starlette.requests.Request, - ) -> Union[serialization.OutputType, Any]: + ) -> Union[OutputType, Any]: # The result can be a serializable data structure, byte-generator, a request, # or, if `postprocessing` is used, anything. In the last case postprocessing # must convert the result to something serializable. 
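# Illustrative sketch with a hypothetical user model, not part of the diff above:
# the `OutputType` union introduced in model_wrapper.py includes sync and async
# byte generators, so a predict implementation can stream raw bytes instead of
# returning a single JSON-serializable result.
from typing import AsyncGenerator


class Model:
    async def predict(self, inputs: dict) -> AsyncGenerator[bytes, None]:
        # Emit the response incrementally; the server can then send it as a
        # streaming HTTP response rather than serializing one payload.
        for word in inputs.get("words", ["hello", "world"]):
            yield f"{word}\n".encode()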
@@ -555,9 +565,9 @@ async def predict( async def postprocess( self, - result: Union[serialization.InputType, Any], + result: Union[InputType, Any], request: starlette.requests.Request, - ) -> serialization.OutputType: + ) -> OutputType: # The postprocess function can handle outputs of `predict`, but not # generators and responses - in that case predict must return directly # and postprocess is skipped. @@ -642,9 +652,9 @@ async def _buffered_response_generator() -> AsyncGenerator[bytes, None]: async def __call__( self, - inputs: Optional[serialization.InputType], + inputs: Optional[InputType], request: starlette.requests.Request, - ) -> serialization.OutputType: + ) -> OutputType: """ Returns result from: preprocess -> predictor -> postprocess. """ @@ -726,6 +736,7 @@ async def __call__( ), tracing.detach_context(): postprocess_result = await self.postprocess(predict_result, request) + final_result: OutputType if isinstance(postprocess_result, BaseModel): # If we return a pydantic object, convert it back to a dict with tracing.section_as_event(span_post, "dump-pydantic"): diff --git a/truss/templates/server/truss_server.py b/truss/templates/server/truss_server.py index 42b2293ae..37ab4c223 100644 --- a/truss/templates/server/truss_server.py +++ b/truss/templates/server/truss_server.py @@ -16,7 +16,7 @@ from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.responses import ORJSONResponse, StreamingResponse from fastapi.routing import APIRoute as FastAPIRoute -from model_wrapper import ModelWrapper +from model_wrapper import InputType, ModelWrapper from opentelemetry import propagate as otel_propagate from opentelemetry import trace from opentelemetry.sdk import trace as sdk_trace @@ -104,7 +104,7 @@ async def _parse_body( body_raw: bytes, truss_schema: Optional[TrussSchema], span: trace.Span, - ) -> serialization.InputType: + ) -> InputType: if self.is_binary(request): with tracing.section_as_event(span, "binary-deserialize"): inputs = serialization.truss_msgpack_deserialize(body_raw) @@ -157,7 +157,7 @@ async def predict( with self._tracer.start_as_current_span( "predict-endpoint", context=trace_ctx ) as span: - inputs: Optional[serialization.InputType] + inputs: Optional[InputType] if model.model_descriptor.skip_input_parsing: inputs = None else: diff --git a/truss/templates/shared/serialization.py b/truss/templates/shared/serialization.py index a1281d4d4..21b099892 100644 --- a/truss/templates/shared/serialization.py +++ b/truss/templates/shared/serialization.py @@ -2,22 +2,9 @@ import uuid from datetime import date, datetime, time, timedelta from decimal import Decimal -from typing import ( - TYPE_CHECKING, - Any, - AsyncGenerator, - Callable, - Dict, - Generator, - List, - Optional, - Union, -) - -import pydantic +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union if TYPE_CHECKING: - import starlette.responses from numpy.typing import NDArray @@ -38,14 +25,6 @@ List["MsgPackType"], Dict[str, "MsgPackType"], ] -InputType = Union[JSONType, MsgPackType, pydantic.BaseModel] -OutputType = Union[ - JSONType, - MsgPackType, - Generator[bytes, None, None], - AsyncGenerator[bytes, None], - "starlette.responses.Response", -] # mostly cribbed from django.core.serializer.DjangoJSONEncoder From 680ba4e1c50d9baf5881d66132dd173a9038e495 Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Wed, 27 Nov 2024 10:21:47 -0800 Subject: [PATCH 11/11] Review comments --- .../examples/streaming/streaming_chain.py | 4 +- truss-chains/tests/test_streaming.py | 50 
++++++++---- truss-chains/truss_chains/framework.py | 2 +- truss-chains/truss_chains/streaming.py | 79 +++++++++++-------- 4 files changed, 87 insertions(+), 48 deletions(-) diff --git a/truss-chains/examples/streaming/streaming_chain.py b/truss-chains/examples/streaming/streaming_chain.py index 5ef979f50..4b1b2488d 100644 --- a/truss-chains/examples/streaming/streaming_chain.py +++ b/truss-chains/examples/streaming/streaming_chain.py @@ -30,7 +30,9 @@ class ConsumerOutput(pydantic.BaseModel): strings: str -STREAM_TYPES = streaming.stream_types(MyDataChunk, header_t=Header, footer_t=Footer) +STREAM_TYPES = streaming.stream_types( + MyDataChunk, header_type=Header, footer_type=Footer +) class Generator(chains.ChainletBase): diff --git a/truss-chains/tests/test_streaming.py b/truss-chains/tests/test_streaming.py index 1eee6c031..88dd5421a 100644 --- a/truss-chains/tests/test_streaming.py +++ b/truss-chains/tests/test_streaming.py @@ -30,7 +30,9 @@ async def to_bytes_iterator(data_stream) -> AsyncIterator[bytes]: @pytest.mark.asyncio async def test_streaming_with_header_and_footer(): - types = streaming.stream_types(item_t=MyDataChunk, header_t=Header, footer_t=Footer) + types = streaming.stream_types( + item_type=MyDataChunk, header_type=Header, footer_type=Footer + ) writer = streaming.stream_writer(types) header = Header(time=123.456, msg="Start of stream") @@ -61,7 +63,7 @@ async def test_streaming_with_header_and_footer(): @pytest.mark.asyncio async def test_streaming_with_items_only(): - types = streaming.stream_types(item_t=MyDataChunk) + types = streaming.stream_types(item_type=MyDataChunk) writer = streaming.stream_writer(types) items = [ @@ -84,8 +86,8 @@ async def test_streaming_with_items_only(): @pytest.mark.asyncio async def test_reading_header_when_none_sent(): - types = streaming.stream_types(item_t=MyDataChunk, header_t=Header) - writer = streaming.stream_writer(streaming.stream_types(item_t=MyDataChunk)) + types = streaming.stream_types(item_type=MyDataChunk, header_type=Header) + writer = streaming.stream_writer(types) items = [MyDataChunk(words=["hello", "world"])] data_stream = [] @@ -99,8 +101,8 @@ async def test_reading_header_when_none_sent(): @pytest.mark.asyncio async def test_reading_items_with_wrong_model(): - types_writer = streaming.stream_types(item_t=MyDataChunk) - types_reader = streaming.stream_types(item_t=Header) # Wrong item type + types_writer = streaming.stream_types(item_type=MyDataChunk) + types_reader = streaming.stream_types(item_type=Header) # Wrong item type writer = streaming.stream_writer(types_writer) items = [MyDataChunk(words=["hello", "world"])] data_stream = [] @@ -117,9 +119,9 @@ async def test_reading_items_with_wrong_model(): @pytest.mark.asyncio async def test_streaming_with_wrong_order(): types = streaming.stream_types( - item_t=MyDataChunk, - header_t=Header, - footer_t=Footer, + item_type=MyDataChunk, + header_type=Header, + footer_type=Footer, ) writer = streaming.stream_writer(types) @@ -129,11 +131,14 @@ async def test_streaming_with_wrong_order(): data_stream = [] for item in items: data_stream.append(writer.yield_item(item)) - data_stream.append(writer.yield_header(header)) + + with pytest.raises( + ValueError, match="Cannot yield header after other data has been sent." 
+ ): + data_stream.append(writer.yield_header(header)) data_stream.append(writer.yield_footer(footer)) reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) - # Try to read header, should fail because the first data is an item with pytest.raises(ValueError, match="Stream does not contain header."): await reader.read_header() @@ -141,7 +146,7 @@ async def test_streaming_with_wrong_order(): @pytest.mark.asyncio async def test_reading_items_without_consuming_header(): - types = streaming.stream_types(item_t=MyDataChunk, header_t=Header) + types = streaming.stream_types(item_type=MyDataChunk, header_type=Header) writer = streaming.stream_writer(types) header = Header(time=123.456, msg="Start of stream") items = [MyDataChunk(words=["hello", "world"])] @@ -163,8 +168,8 @@ async def test_reading_items_without_consuming_header(): @pytest.mark.asyncio async def test_reading_footer_when_none_sent(): - types = streaming.stream_types(item_t=MyDataChunk, footer_t=Footer) - writer = streaming.stream_writer(streaming.stream_types(item_t=MyDataChunk)) + types = streaming.stream_types(item_type=MyDataChunk, footer_type=Footer) + writer = streaming.stream_writer(types) items = [MyDataChunk(words=["hello", "world"])] data_stream = [] for item in items: @@ -179,3 +184,20 @@ async def test_reading_footer_when_none_sent(): # Try to read footer, expect an error with pytest.raises(ValueError, match="Stream does not contain footer."): await reader.read_footer() + + +@pytest.mark.asyncio +async def test_reading_footer_with_no_items(): + types = streaming.stream_types(item_type=MyDataChunk, footer_type=Footer) + writer = streaming.stream_writer(types) + footer = Footer(time=789.012, duration_sec=665.556, msg="End of stream") + data_stream = [writer.yield_footer(footer)] + + reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) + read_items = [] + async for item in reader.read_items(): + read_items.append(item) + assert len(read_items) == 0 + + read_footer = await reader.read_footer() + assert read_footer == footer diff --git a/truss-chains/truss_chains/framework.py b/truss-chains/truss_chains/framework.py index d86015974..db2f822fa 100644 --- a/truss-chains/truss_chains/framework.py +++ b/truss-chains/truss_chains/framework.py @@ -315,7 +315,7 @@ def _validate_streaming_output_type( raw=annotation, origin_type=origin, arg_type=bytes ) - assert len(args) == 1, "AsyncIterator cannot have more than 1 arg." + assert len(args) == 1, "Iterator type annotations cannot have more than 1 arg." arg = args[0] if arg not in _STREAM_TYPES: msg = ( diff --git a/truss-chains/truss_chains/streaming.py b/truss-chains/truss_chains/streaming.py index 6dde0d473..9d9a1cae8 100644 --- a/truss-chains/truss_chains/streaming.py +++ b/truss-chains/truss_chains/streaming.py @@ -8,8 +8,10 @@ import pydantic -TAG_SIZE = 5 # uint8 + uint32. -JSONType = Union[str, int, float, bool, None, list["JSONType"], dict[str, "JSONType"]] +_TAG_SIZE = 5 # uint8 + uint32. +_JSONType = Union[ + str, int, float, bool, None, list["_JSONType"], dict[str, "_JSONType"] +] _T = TypeVar("_T") if sys.version_info < (3, 10): @@ -43,56 +45,57 @@ async def anext(iterable: AsyncIterator[_T]) -> _T: @dataclasses.dataclass class StreamTypes(Generic[ItemT, HeaderTT, FooterTT]): - item_t: Type[ItemT] - header_t: HeaderTT # Is either `Type[HeaderT]` or `None`. - footer_t: FooterTT # Is either `Type[FooterT]` or `None`. + item_type: Type[ItemT] + header_type: HeaderTT # Is either `Type[HeaderT]` or `None`. 
+ footer_type: FooterTT # Is either `Type[FooterT]` or `None`. @overload def stream_types( - item_t: Type[ItemT], + item_type: Type[ItemT], *, - header_t: Type[HeaderT], - footer_t: Type[FooterT], + header_type: Type[HeaderT], + footer_type: Type[FooterT], ) -> StreamTypes[ItemT, HeaderT, FooterT]: ... @overload def stream_types( - item_t: Type[ItemT], + item_type: Type[ItemT], *, - header_t: Type[HeaderT], + header_type: Type[HeaderT], ) -> StreamTypes[ItemT, HeaderT, None]: ... @overload def stream_types( - item_t: Type[ItemT], + item_type: Type[ItemT], *, - footer_t: Type[FooterT], + footer_type: Type[FooterT], ) -> StreamTypes[ItemT, None, FooterT]: ... @overload -def stream_types(item_t: Type[ItemT]) -> StreamTypes[ItemT, None, None]: ... +def stream_types(item_type: Type[ItemT]) -> StreamTypes[ItemT, None, None]: ... def stream_types( - item_t: Type[ItemT], + item_type: Type[ItemT], *, - header_t: Optional[Type[HeaderT]] = None, - footer_t: Optional[Type[FooterT]] = None, + header_type: Optional[Type[HeaderT]] = None, + footer_type: Optional[Type[FooterT]] = None, ) -> StreamTypes: """Creates a bundle of item type and potentially header/footer types, each as pydantic model.""" # This indirection for creating `StreamTypes` is needed to get generic typing. - return StreamTypes(item_t, header_t, footer_t) + return StreamTypes(item_type, header_type, footer_type) # Reading ############################################################################## class _Delimiter(enum.IntEnum): + NOT_SET = enum.auto() HEADER = enum.auto() ITEM = enum.auto() FOOTER = enum.auto() @@ -161,7 +164,7 @@ def _unpack_tag(tag: bytes) -> tuple[_Delimiter, int]: async def _read(self) -> tuple[_Delimiter, bytes]: try: - tag = await self._stream.readexactly(TAG_SIZE) + tag = await self._stream.readexactly(_TAG_SIZE) # It's ok to read nothing (end of stream), but unexpected to read partial. except asyncio.IncompleteReadError: raise @@ -182,12 +185,13 @@ async def read_items(self) -> AsyncIterator[ItemT]: "Called `read_items`, but there the stream contains header data, which " "is not consumed. Call `read_header` first or remove sending a header." ) - if delimiter in (_Delimiter.FOOTER, _Delimiter.END): + if delimiter in (_Delimiter.FOOTER, _Delimiter.END): # In case of 0 items. + self._footer_data = data_bytes return assert delimiter == _Delimiter.ITEM while True: - yield self._stream_types.item_t.model_validate_json(data_bytes) + yield self._stream_types.item_type.model_validate_json(data_bytes) # We don't know if the next data is another item, footer or the end. 
delimiter, data_bytes = await self._read() if delimiter == _Delimiter.END: @@ -204,7 +208,7 @@ async def read_header( delimiter, data_bytes = await self._read() if delimiter != _Delimiter.HEADER: raise ValueError("Stream does not contain header.") - return self._stream_types.header_t.model_validate_json(data_bytes) + return self._stream_types.header_type.model_validate_json(data_bytes) class _FooterReadMixin(_Streamer[ItemT, HeaderTT, FooterT]): @@ -219,7 +223,7 @@ async def read_footer( raise ValueError("Stream does not contain footer.") self._footer_data = data_bytes - footer = self._stream_types.footer_t.model_validate_json(self._footer_data) + footer = self._stream_types.footer_type.model_validate_json(self._footer_data) self._footer_data = None return footer @@ -273,11 +277,11 @@ def stream_reader( types: StreamTypes[ItemT, HeaderTT, FooterTT], stream: AsyncIterator[bytes], ) -> _StreamReader: - if types.header_t is None and types.footer_t is None: + if types.header_type is None and types.footer_type is None: return _StreamReader(types, stream) - if types.header_t is None: + if types.header_type is None: return StreamReaderWithFooter(types, stream) - if types.footer_t is None: + if types.footer_type is None: return StreamReaderWithHeader(types, stream) return StreamReaderFull(types, stream) @@ -288,11 +292,17 @@ def stream_reader( class _StreamWriterProtocol(Protocol[ItemT, HeaderTT, FooterTT]): _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT] + _last_sent: _Delimiter def _serialize(self, obj: pydantic.BaseModel, delimiter: _Delimiter) -> bytes: ... class _StreamWriter(_Streamer[ItemT, HeaderTT, FooterTT]): + def __init__(self, types: StreamTypes[ItemT, HeaderTT, FooterTT]) -> None: + super().__init__(types) + self._last_sent = _Delimiter.NOT_SET + self._stream_types = types + @staticmethod def _pack_tag(delimiter: _Delimiter, length: int) -> bytes: return struct.pack(">BI", delimiter.value, length) @@ -305,6 +315,9 @@ def _serialize(self, obj: pydantic.BaseModel, delimiter: _Delimiter) -> bytes: return memoryview(data) def yield_item(self, item: ItemT) -> bytes: + if self._last_sent in (_Delimiter.FOOTER, _Delimiter.END): + raise ValueError("Cannot yield item after sending footer / closing stream.") + self._last_sent = _Delimiter.ITEM return self._serialize(item, _Delimiter.ITEM) @@ -312,8 +325,9 @@ class _HeaderWriteMixin(_Streamer[ItemT, HeaderT, FooterTT]): def yield_header( self: _StreamWriterProtocol[ItemT, HeaderT, FooterTT], header: HeaderT ) -> bytes: - if self._stream_types.header_t is None or header is None: - raise ValueError() + if self._last_sent != _Delimiter.NOT_SET: + raise ValueError("Cannot yield header after other data has been sent.") + self._last_sent = _Delimiter.HEADER return self._serialize(header, _Delimiter.HEADER) @@ -321,8 +335,9 @@ class _FooterWriteMixin(_Streamer[ItemT, HeaderTT, FooterT]): def yield_footer( self: _StreamWriterProtocol[ItemT, HeaderTT, FooterT], footer: FooterT ) -> bytes: - if self._stream_types.header_t is None or footer is None: - raise ValueError() + if self._last_sent == _Delimiter.END: + raise ValueError("Cannot yield footer after closing stream.") + self._last_sent = _Delimiter.FOOTER return self._serialize(footer, _Delimiter.FOOTER) @@ -370,11 +385,11 @@ def stream_writer( def stream_writer( types: StreamTypes[ItemT, HeaderTT, FooterTT], ) -> _StreamWriter: - if types.header_t is None and types.footer_t is None: + if types.header_type is None and types.footer_type is None: return _StreamWriter(types) - if types.header_t is 
None: + if types.header_type is None: return StreamWriterWithFooter(types) - if types.footer_t is None: + if types.footer_type is None: return StreamWriterWithHeader(types) return StreamWriterFull(types)
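# Illustrative end-to-end sketch (hypothetical model and function names, not part
# of the patch series): with the renamed keyword arguments, the writer produces
# framed bytes in the enforced order header -> items -> footer, and the reader
# rebuilds the pydantic objects from the byte stream -- mirroring the round trips
# exercised in test_streaming.py.
import asyncio
from typing import AsyncIterator

import pydantic

from truss_chains import streaming


class Chunk(pydantic.BaseModel):
    words: list[str]


class Head(pydantic.BaseModel):
    msg: str


class Foot(pydantic.BaseModel):
    msg: str


async def demo() -> None:
    types = streaming.stream_types(Chunk, header_type=Head, footer_type=Foot)
    writer = streaming.stream_writer(types)
    frames = [
        writer.yield_header(Head(msg="start")),  # must come before any item
        writer.yield_item(Chunk(words=["hello"])),
        writer.yield_footer(Foot(msg="done")),  # nothing may follow the footer
    ]

    async def byte_stream() -> AsyncIterator[bytes]:
        for frame in frames:
            # The writer returns memoryviews; copy to bytes for a plain byte stream.
            yield bytes(frame)

    reader = streaming.stream_reader(types, byte_stream())
    print(await reader.read_header())
    async for item in reader.read_items():
        print(item)
    print(await reader.read_footer())


asyncio.run(demo())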