Chains Streaming, fixes BT-10339 #1261

Merged · 11 commits · Dec 2, 2024
35 changes: 0 additions & 35 deletions .github/workflows/pr.yml
@@ -51,38 +51,3 @@ jobs:
with:
use-verbose-mode: "yes"
folder-path: "docs"

enforce-chains-example-docs-sync:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v4
with:
lfs: true
fetch-depth: 2

- name: Fetch main branch
run: git fetch origin main

- name: Check if chains examples were modified
id: check_files
run: |
if git diff --name-only origin/main | grep -q '^truss-chains/examples/.*'; then
echo "chains_docs_update_needed=true" >> $GITHUB_ENV
echo "Chains examples were modified."
else
echo "chains_docs_update_needed=false" >> $GITHUB_ENV
echo "Chains examples were not modified."
echo "::notice file=truss-chains/examples/::Chains examples not modified."
fi

- name: Enforce acknowledgment in PR description
if: env.chains_docs_update_needed == 'true'
env:
DESCRIPTION: ${{ github.event.pull_request.body }}
run: |
if [[ "$DESCRIPTION" != *"UPDATE_DOCS=done"* && "$DESCRIPTION" != *"UPDATE_DOCS=not_needed"* ]]; then
echo "::error file=truss-chains/examples/::Chains examples were modified and ack not found in PR description. Verify whether docs need to be update (https://github.com/basetenlabs/docs.baseten.co/tree/main/chains) and add an ack tag `UPDATE_DOCS={done|not_needed}` to the PR description."
exit 1
else
echo "::notice file=truss-chains/examples/::Chains examples modified and ack found int PR description."
fi
105 changes: 105 additions & 0 deletions truss-chains/examples/streaming/streaming_chain.py
@@ -0,0 +1,105 @@
import asyncio
import time
from typing import AsyncIterator

import pydantic

import truss_chains as chains
from truss_chains import streaming


class Header(pydantic.BaseModel):
time: float
msg: str


class MyDataChunk(pydantic.BaseModel):
words: list[str]


class Footer(pydantic.BaseModel):
time: float
duration_sec: float
msg: str


class ConsumerOutput(pydantic.BaseModel):
header: Header
chunks: list[MyDataChunk]
footer: Footer
strings: str


STREAM_TYPES = streaming.stream_types(MyDataChunk, header_t=Header, footer_t=Footer)


class Generator(chains.ChainletBase):
"""Example that streams fully structured pydantic items with header and footer."""

async def run_remote(self) -> AsyncIterator[bytes]:
print("Entering Generator")
streamer = streaming.stream_writer(STREAM_TYPES)
header = Header(time=time.time(), msg="Start.")
yield streamer.yield_header(header)
for i in range(1, 5):
data = MyDataChunk(
words=[chr(x + 70) * x for x in range(1, i + 1)],
)
print("Yield")
yield streamer.yield_item(data)
await asyncio.sleep(0.05)

end_time = time.time()
footer = Footer(time=end_time, duration_sec=end_time - header.time, msg="Done.")
yield streamer.yield_footer(footer)
print("Exiting Generator")


class StringGenerator(chains.ChainletBase):
"""Minimal streaming example with strings (e.g. for raw LLM output)."""

async def run_remote(self) -> AsyncIterator[str]:
# Note: the "chunk" boundaries are lost when streaming raw strings. You must
# add spaces and linebreaks to the items yourself.
yield "First "
yield "second "
yield "last."


class Consumer(chains.ChainletBase):
"""Consume that reads the raw streams and parses them."""

def __init__(
self,
generator=chains.depends(Generator),
string_generator=chains.depends(StringGenerator),
):
self._generator = generator
self._string_generator = string_generator

async def run_remote(self) -> ConsumerOutput:
print("Entering Consumer")
reader = streaming.stream_reader(STREAM_TYPES, self._generator.run_remote())
print("Consuming...")
header = await reader.read_header()
chunks = []
async for data in reader.read_items():
print(f"Read: {data}")
chunks.append(data)

footer = await reader.read_footer()
strings = []
async for part in self._string_generator.run_remote():
strings.append(part)

print("Exiting Consumer")
return ConsumerOutput(
header=header, chunks=chunks, footer=footer, strings="".join(strings)
)


if __name__ == "__main__":
with chains.run_local():
chain = Consumer()
result = asyncio.run(chain.run_remote())
Collaborator: if the final chain streams output, what would be an easy way of consuming that?

Contributor Author: You have 3 choices:

  • raw strings: works directly.
  • raw bytes: works directly.
  • structured/typed pydantic models: you create stream_reader with the same model definitions client side.

For the last one, we can discuss ways to distribute that implementation. It depends only on pydantic and builtins, so you wouldn't need to install the whole truss package, you just need that source file. Or we could even generate a "client".

Collaborator:

> For the last one, we can discuss ways to distribute that implementation. It depends only on pydantic and builtins, so you wouldn't need to install the whole truss package, you just need that source file. Or we could even generate a "client".

Sounds good

print(result)
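The review thread above asks how a client would consume a structured stream. The actual truss-chains wire format is internal to the `streaming` module; purely as an illustration of the general header/item/footer pattern, here is a hypothetical length-prefixed framing in plain Python (the tag/length layout is an assumption for this sketch, not the real protocol):

```python
import json
import struct

# Hypothetical framing: 1 tag byte (0=header, 1=item, 2=footer),
# a 4-byte big-endian payload length, then the JSON payload.
# NOT the actual truss-chains protocol -- illustration only.

def frame(tag: int, obj: dict) -> bytes:
    payload = json.dumps(obj).encode()
    return struct.pack(">BI", tag, len(payload)) + payload

def read_frames(buf: bytes):
    """Parse a complete byte buffer back into (tag, obj) pairs."""
    offset = 0
    while offset < len(buf):
        tag, length = struct.unpack_from(">BI", buf, offset)
        offset += 5  # tag byte + 4-byte length
        obj = json.loads(buf[offset : offset + length])
        offset += length
        yield tag, obj

stream = b"".join([
    frame(0, {"msg": "Start."}),
    frame(1, {"words": ["G"]}),
    frame(1, {"words": ["G", "HH"]}),
    frame(2, {"msg": "Done."}),
])
messages = list(read_frames(stream))
print(messages[0])  # (0, {'msg': 'Start.'})
```

A real client for this chain would instead build a `stream_reader` from the same pydantic model definitions, as the author notes above.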
53 changes: 49 additions & 4 deletions truss-chains/tests/chains_e2e_test.py
@@ -13,8 +13,8 @@
@pytest.mark.integration
def test_chain():
with ensure_kill_all():
root = Path(__file__).parent.resolve()
chain_root = root / "itest_chain" / "itest_chain.py"
tests_root = Path(__file__).parent.resolve()
chain_root = tests_root / "itest_chain" / "itest_chain.py"
with framework.import_target(chain_root, "ItestChain") as entrypoint:
options = definitions.PushOptionsLocalDocker(
chain_name="integration-test", use_local_chains_src=True
@@ -81,8 +81,8 @@ def test_chain():

@pytest.mark.asyncio
async def test_chain_local():
root = Path(__file__).parent.resolve()
chain_root = root / "itest_chain" / "itest_chain.py"
tests_root = Path(__file__).parent.resolve()
chain_root = tests_root / "itest_chain" / "itest_chain.py"
with framework.import_target(chain_root, "ItestChain") as entrypoint:
with public_api.run_local():
with pytest.raises(ValueError):
@@ -119,3 +119,48 @@ async def test_chain_local():
match="Chainlets cannot be naively instantiated",
):
await entrypoint().run_remote(length=20, num_partitions=5)


@pytest.mark.integration
def test_streaming_chain():
examples_root = Path(__file__).parent.parent.resolve() / "examples"
chain_root = examples_root / "streaming" / "streaming_chain.py"
with framework.import_target(chain_root, "Consumer") as entrypoint:
service = remote.push(
entrypoint,
options=definitions.PushOptionsLocalDocker(
chain_name="stream",
only_generate_trusses=False,
use_local_chains_src=True,
),
)
assert service is not None
response = service.run_remote({})
assert response.status_code == 200
print(response.json())
result = response.json()
print(result)
assert result["header"]["msg"] == "Start."
assert result["chunks"][0]["words"] == ["G"]
assert result["chunks"][1]["words"] == ["G", "HH"]
assert result["chunks"][2]["words"] == ["G", "HH", "III"]
assert result["chunks"][3]["words"] == ["G", "HH", "III", "JJJJ"]
assert result["footer"]["duration_sec"] > 0
assert result["strings"] == "First second last."


@pytest.mark.asyncio
async def test_streaming_chain_local():
examples_root = Path(__file__).parent.parent.resolve() / "examples"
chain_root = examples_root / "streaming" / "streaming_chain.py"
with framework.import_target(chain_root, "Consumer") as entrypoint:
with public_api.run_local():
result = await entrypoint().run_remote()
print(result)
assert result.header.msg == "Start."
assert result.chunks[0].words == ["G"]
assert result.chunks[1].words == ["G", "HH"]
assert result.chunks[2].words == ["G", "HH", "III"]
assert result.chunks[3].words == ["G", "HH", "III", "JJJJ"]
assert result.footer.duration_sec > 0
assert result.strings == "First second last."
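The asserted `words` values follow directly from the generator's comprehension `[chr(x + 70) * x for x in range(1, i + 1)]`, which can be checked standalone:

```python
# Reproduce the chunk payloads from the example generator:
# chr(71) == "G", chr(72) == "H", ..., each letter repeated x times.
chunks = [[chr(x + 70) * x for x in range(1, i + 1)] for i in range(1, 5)]
print(chunks)
# [['G'], ['G', 'HH'], ['G', 'HH', 'III'], ['G', 'HH', 'III', 'JJJJ']]
```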
54 changes: 53 additions & 1 deletion truss-chains/tests/test_framework.py
@@ -2,7 +2,7 @@
import contextlib
import logging
import re
from typing import List
from typing import AsyncIterator, Iterator, List

import pydantic
import pytest
@@ -505,3 +505,55 @@ def run_remote(argument: object): ...
with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():
with public_api.run_local():
MultiIssue()


def test_raises_iterator_no_yield():
match = (
rf"{TEST_FILE}:\d+ \(IteratorNoYield\.run_remote\) \[kind: IO_TYPE_ERROR\].*"
r"If the endpoint returns an iterator \(streaming\), it must have `yield` statements"
)

with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():

class IteratorNoYield(chains.ChainletBase):
async def run_remote(self) -> AsyncIterator[str]:
return "123" # type: ignore[return-value]


def test_raises_yield_no_iterator():
match = (
rf"{TEST_FILE}:\d+ \(YieldNoIterator\.run_remote\) \[kind: IO_TYPE_ERROR\].*"
r"If the endpoint is streaming \(has `yield` statements\), the return type must be an iterator"
)

with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():

class YieldNoIterator(chains.ChainletBase):
async def run_remote(self) -> str: # type: ignore[misc]
yield "123"


def test_raises_iterator_sync():
match = (
rf"{TEST_FILE}:\d+ \(IteratorSync\.run_remote\) \[kind: IO_TYPE_ERROR\].*"
r"Streaming endpoints \(containing `yield` statements\) are only supported for async endpoints"
)

with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():

class IteratorSync(chains.ChainletBase):
def run_remote(self) -> Iterator[str]:
yield "123"


def test_raises_iterator_no_arg():
match = (
rf"{TEST_FILE}:\d+ \(IteratorNoArg\.run_remote\) \[kind: IO_TYPE_ERROR\].*"
r"Iterators must be annotated with type \(one of \['str', 'bytes'\]\)"
)

with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():

class IteratorNoArg(chains.ChainletBase):
async def run_remote(self) -> AsyncIterator:
yield "123"