Skip to content

Commit

Permalink
Release 0.9.56 (#1276)
Browse files Browse the repository at this point in the history
* update (#1267)

Co-authored-by: Tianshu Cheng <[email protected]>

* send truss version on patch (#1268)

* Speculative Decoding Interface refactor (#1270)

* spec dec config

* add optional dict of trt llm configs

* fix bad merge

* add extensions support

* fix fixture

* cli push fixes

* constants

* fix ordering

* fix merge

* refactor interface

* add tp validation error

* self review

* use constant

* fix tests

* fix tests

* add request_default_max_tokens

* fix default on trtllm runtime

* update copy

* bump to 54rc0

* add total token limit to toplevel config

* bump briton to 0.3.10

* fix import

* 54rc2

* fix rc3

* rc4

* bump briton server image

* bump rc6 for briton 0.3.12.dev3

* bump rc7

* revert trtllm serialization changes

* bump briton

* interface refactor

* add validation + tests

* 56rc0

* reduce property

* Update trt_llm_config.py (#1274)

* Update trt_llm_config.py -> revision (#1269)

* Better chains error propagation (+various fixes). (#1271)

* Bump briton in truss library (#1273)

* bump briton to briton==0.3.12.dev8

* bump truss to 0.9.56rc1

* Bump version to 0.9.56

---------

Co-authored-by: Tianshu <[email protected]>
Co-authored-by: Tianshu Cheng <[email protected]>
Co-authored-by: rcano-baseten <[email protected]>
Co-authored-by: joostinyi <[email protected]>
Co-authored-by: Michael Feil <[email protected]>
Co-authored-by: Marius Killinger <[email protected]>
  • Loading branch information
7 people authored Dec 10, 2024
1 parent e8e1664 commit 901fce9
Show file tree
Hide file tree
Showing 18 changed files with 325 additions and 247 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.55"
version = "0.9.56"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
12 changes: 8 additions & 4 deletions truss-chains/examples/streaming/streaming_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ConsumerOutput(pydantic.BaseModel):
class Generator(chains.ChainletBase):
"""Example that streams fully structured pydantic items with header and footer."""

async def run_remote(self) -> AsyncIterator[bytes]:
async def run_remote(self, cause_error: bool) -> AsyncIterator[bytes]:
print("Entering Generator")
streamer = streaming.stream_writer(STREAM_TYPES)
header = Header(time=time.time(), msg="Start.")
Expand All @@ -49,6 +49,8 @@ async def run_remote(self) -> AsyncIterator[bytes]:
)
print("Yield")
yield streamer.yield_item(data)
if cause_error and i > 2:
raise RuntimeError("Test Error")
await asyncio.sleep(0.05)

end_time = time.time()
Expand Down Expand Up @@ -79,9 +81,11 @@ def __init__(
self._generator = generator
self._string_generator = string_generator

async def run_remote(self) -> ConsumerOutput:
async def run_remote(self, cause_error: bool) -> ConsumerOutput:
print("Entering Consumer")
reader = streaming.stream_reader(STREAM_TYPES, self._generator.run_remote())
reader = streaming.stream_reader(
STREAM_TYPES, self._generator.run_remote(cause_error)
)
print("Consuming...")
header = await reader.read_header()
chunks = []
Expand All @@ -103,5 +107,5 @@ async def run_remote(self) -> ConsumerOutput:
if __name__ == "__main__":
with chains.run_local():
chain = Consumer()
result = asyncio.run(chain.run_remote())
result = asyncio.run(chain.run_remote(False))
print(result)
18 changes: 13 additions & 5 deletions truss-chains/tests/itest_chain/itest_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def run_remote(self, length: int) -> str:
return (template * repetitions)[:length]


def validate_data(data):
if len(data) > 30:
raise ValueError(f"This input is too long: {len(data)}.")


class TextReplicator(chains.ChainletBase):
remote_config = chains.RemoteConfig(docker_image=IMAGE_CUSTOM)

Expand All @@ -44,8 +49,7 @@ def __init__(self):
self.multiplier = 2

def run_remote(self, data: str) -> str:
if len(data) > 30:
raise ValueError(f"This input is too long: {len(data)}.")
validate_data(data)
return data * self.multiplier


Expand Down Expand Up @@ -123,13 +127,17 @@ async def run_remote(
extra_arg=123,
)
print(pydantic_default_arg, simple_default_arg)
value = 0
for part in text_parts.parts:
value += self._text_to_num.run_remote(part)
value = self._accumulate_parts(text_parts.parts)
return (
value,
data,
number,
pydantic_default_arg,
simple_default_arg,
)

def _accumulate_parts(self, parts) -> int:
value = 0
for part in parts:
value += self._text_to_num.run_remote(part)
return value
42 changes: 38 additions & 4 deletions truss-chains/tests/test_e2e.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
import re
import time
from pathlib import Path

import pytest
Expand All @@ -23,6 +25,7 @@ def test_chain():
service = deployment_client.push(entrypoint, options)

url = service.run_remote_url.replace("host.docker.internal", "localhost")
time.sleep(1.0) # Wait for models to be ready.

# Call without providing values for default arguments.
response = requests.post(
Expand Down Expand Up @@ -73,11 +76,34 @@ def test_chain():
url, json={"length": 300, "num_partitions": 3}, stream=True
)
print(response)
assert response.status_code == 500

error = definitions.RemoteErrorDetail.model_validate(response.json()["error"])
error_str = error.format()
print(error_str)
assert "ValueError: This input is too long: 100." in error_str
assert response.status_code == 500

error_regex = r"""
Chainlet-Traceback \(most recent call last\):
File \".*?/itest_chain\.py\", line \d+, in run_remote
value = self\._accumulate_parts\(text_parts\.parts\)
File \".*?/itest_chain\.py\", line \d+, in _accumulate_parts
value \+= self\._text_to_num\.run_remote\(part\)
ValueError: \(showing chained remote errors, root error at the bottom\)
├─ Error in dependency Chainlet `TextToNum`:
│ Chainlet-Traceback \(most recent call last\):
│ File \".*?/itest_chain\.py\", line \d+, in run_remote
│ generated_text = self\._replicator\.run_remote\(data\)
│ ValueError: \(showing chained remote errors, root error at the bottom\)
│ ├─ Error in dependency Chainlet `TextReplicator`:
│ │ Chainlet-Traceback \(most recent call last\):
│ │ File \".*?/itest_chain\.py\", line \d+, in run_remote
│ │ validate_data\(data\)
│ │ File \".*?/itest_chain\.py\", line \d+, in validate_data
│ │ raise ValueError\(f\"This input is too long: \{len\(data\)\}\.\"\)
╰ ╰ ValueError: This input is too long: \d+\.
"""

assert re.match(error_regex.strip(), error_str.strip(), re.MULTILINE), error_str


@pytest.mark.asyncio
Expand Down Expand Up @@ -137,7 +163,8 @@ def test_streaming_chain():
),
)
assert service is not None
response = service.run_remote({})

response = service.run_remote({"cause_error": False})
assert response.status_code == 200
print(response.json())
result = response.json()
Expand All @@ -150,14 +177,21 @@ def test_streaming_chain():
assert result["footer"]["duration_sec"] > 0
assert result["strings"] == "First second last."

# TODO: build error handling for stream reader.
# response = service.run_remote({"cause_error": True})
# assert response.status_code == 200
# print(response.json())
# result = response.json()
# print(result)


@pytest.mark.asyncio
async def test_streaming_chain_local():
examples_root = Path(__file__).parent.parent.resolve() / "examples"
chain_root = examples_root / "streaming" / "streaming_chain.py"
with framework.import_target(chain_root, "Consumer") as entrypoint:
with public_api.run_local():
result = await entrypoint().run_remote()
result = await entrypoint().run_remote(cause_error=False)
print(result)
assert result.header.msg == "Start."
assert result.chunks[0].words == ["G"]
Expand Down
4 changes: 1 addition & 3 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,6 @@ class RemoteErrorDetail(SafeModel):
error response.
"""

remote_name: str
exception_cls_name: str
exception_module_name: Optional[str]
exception_message: str
Expand All @@ -654,8 +653,7 @@ def format(self) -> str:
else ""
)
error = (
f"{RemoteErrorDetail.__name__} in `{self.remote_name}`\n"
f"Traceback (most recent call last):\n"
f"Chainlet-Traceback (most recent call last):\n"
f"{stack}{self.exception_cls_name}: {self.exception_message}{exc_info}"
)
return error
Expand Down
5 changes: 1 addition & 4 deletions truss-chains/truss_chains/deployment/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,7 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) ->
)
# Add error handling context manager:
parts.append(
_indent(
f"with stub.trace_parent(request), utils.exception_to_http_error("
f'chainlet_name="{chainlet_descriptor.name}"):'
)
_indent("with stub.trace_parent(request), utils.exception_to_http_error():")
)
# Invoke Chainlet.
if (
Expand Down
55 changes: 32 additions & 23 deletions truss-chains/truss_chains/remote_chainlet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,43 +140,46 @@ def pydantic_set_field_dict(obj: pydantic.BaseModel) -> dict[str, pydantic.BaseM
# Error Propagation Utils. #############################################################


def _handle_exception(exception: Exception, chainlet_name: str) -> NoReturn:
"""Raises `starlette.exceptions.HTTPExceptionn` with `RemoteErrorDetail`."""
def _handle_exception(exception: Exception) -> NoReturn:
"""Raises `HTTPException` with `RemoteErrorDetail`."""
if hasattr(exception, "__module__"):
exception_module_name = exception.__module__
else:
exception_module_name = None

error_stack = traceback.extract_tb(exception.__traceback__)
# Exclude the error handling functions from the stack trace.
exclude_frames = {
exception_to_http_error.__name__,
response_raise_errors.__name__,
async_response_raise_errors.__name__,
}
final_tb = [frame for frame in error_stack if frame.name not in exclude_frames]
stack = list(
[definitions.StackFrame.from_frame_summary(frame) for frame in final_tb]
)
# Filter everything before (model.py) and after (stubs, error handling) so that only
# user-defined code remains. See test_e2e.py::test_chain for expected results.
model_predict_index = 0
first_stub_index = len(error_stack)
for i, frame in enumerate(error_stack):
if frame.filename.endswith("model/model.py") and frame.name == "predict":
model_predict_index = i + 1
if frame.filename.endswith("remote_chainlet/stub.py") and frame.name.startswith(
"predict" # predict sycnc|async|stream.
):
first_stub_index = i - 1
break

final_tb = error_stack[model_predict_index:first_stub_index]
stack = [definitions.StackFrame.from_frame_summary(frame) for frame in final_tb]
error = definitions.RemoteErrorDetail(
remote_name=chainlet_name,
exception_cls_name=exception.__class__.__name__,
exception_module_name=exception_module_name,
exception_message=str(exception),
user_stack_trace=stack,
user_stack_trace=list(stack),
)
raise fastapi.HTTPException(
status_code=500, detail=error.model_dump()
) from exception


@contextlib.contextmanager
def exception_to_http_error(chainlet_name: str) -> Iterator[None]:
# TODO: move chainlet name from here to caller side.
def exception_to_http_error() -> Iterator[None]:
try:
yield
except Exception as e:
_handle_exception(e, chainlet_name)
_handle_exception(e)


def _resolve_exception_class(
Expand Down Expand Up @@ -213,24 +216,30 @@ def _handle_response_error(response_json: dict, remote_name: str):
except KeyError as e:
logging.error(f"response_json: {response_json}")
raise ValueError(
"Could not get `error` field from JSON from error response"
"Could not get `error` field from JSON from chainlet error response"
) from e

try:
error = definitions.RemoteErrorDetail.model_validate(error_json)
except pydantic.ValidationError as e:
if isinstance(error_json, str):
msg = f"Remote error occurred in `{remote_name}`: '{error_json}'"
raise definitions.GenericRemoteException(msg) from None
raise ValueError(
"Could not parse error. Error details are expected to be either a "
"Could not parse chainlet error. Error details are expected to be either a "
"plain string (old truss models) or a serialized "
f"`definitions.RemoteErrorDetail.__name__`, got:\n{repr(error_json)}"
f"`{definitions.RemoteErrorDetail.__name__}`, got:\n{repr(error_json)}"
) from e

exception_cls = _resolve_exception_class(error)
error_format = textwrap.indent(error.format(), "│ ")
*lines, last_line = error_format.splitlines()
last_line = f"╰{last_line[1:]}" if last_line.startswith("│") else last_line
error_format = "\n".join(lines + [last_line])
msg = (
f"(showing remote errors, root message at the bottom)\n"
f"--> Preceding Remote Cause:\n"
f"{textwrap.indent(error.format(), ' ')}"
f"(showing chained remote errors, root error at the bottom)\n"
f"├─ Error in dependency Chainlet `{remote_name}`:\n"
f"{error_format}"
)
raise exception_cls(msg)

Expand Down
1 change: 0 additions & 1 deletion truss-chains/truss_chains/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ async def _read(self) -> tuple[_Delimiter, bytes]:
if not length:
return delimiter, b""
data_bytes = await self._stream.readexactly(length)
print(f"Read Delimiter: {delimiter}")
return delimiter, data_bytes

async def read_items(self) -> AsyncIterator[ItemT]:
Expand Down
2 changes: 1 addition & 1 deletion truss/base/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
TRTLLM_SPEC_DEC_DRAFT_MODEL_NAME = "draft"
TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.13.0-4fd8a10-5e5c3d7"
TRTLLM_PYTHON_EXECUTABLE = "/usr/bin/python3"
BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.12.dev4"]
BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.12.dev8"]
AUDIO_MODEL_TRTLLM_REQUIREMENTS = [
"--extra-index-url https://pypi.nvidia.com",
"tensorrt_cu12_bindings==10.2.0.post1",
Expand Down
Loading

0 comments on commit 901fce9

Please sign in to comment.