Skip to content

Commit

Permalink
Organize chains code. Separate 'remote' code. (#1265)
Browse files (browse the repository at this point in the history)
* Organize chains code. Separate 'remote' code.

* Fix exceptions

* bump rc

* Fix streaming trace context
  • Loading branch information
marius-baseten authored Dec 5, 2024
1 parent 5da6b48 commit 77343fb
Show file tree
Hide file tree
Showing 17 changed files with 784 additions and 690 deletions.
699 changes: 367 additions & 332 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.54rc9"
version = "0.9.55rc2"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down Expand Up @@ -84,6 +84,7 @@ aiohttp = { version = "^3.10.10", optional = false }
blake3 = { version = "^0.3.3", optional = false }
boto3 = { version = "^1.34.85", optional = false }
click = { version = "^8.0.3", optional = false }
fastapi = { version =">=0.109.1", optional = false }
google-cloud-storage = { version = "2.10.0", optional = false }
httpx = { version = ">=0.24.1", optional = false }
inquirerpy = { version = "^0.3.4", optional = false }
Expand Down
9 changes: 5 additions & 4 deletions truss-chains/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import requests
from truss.tests.test_testing_utilities_for_other_tests import ensure_kill_all

from truss_chains import definitions, framework, public_api, remote, utils
from truss_chains import definitions, framework, public_api, utils
from truss_chains.deployment import deployment_client

utils.setup_dev_logging(logging.DEBUG)

Expand All @@ -19,7 +20,7 @@ def test_chain():
options = definitions.PushOptionsLocalDocker(
chain_name="integration-test", use_local_chains_src=True
)
service = remote.push(entrypoint, options)
service = deployment_client.push(entrypoint, options)

url = service.run_remote_url.replace("host.docker.internal", "localhost")

Expand Down Expand Up @@ -127,7 +128,7 @@ def test_streaming_chain():
examples_root = Path(__file__).parent.parent.resolve() / "examples"
chain_root = examples_root / "streaming" / "streaming_chain.py"
with framework.import_target(chain_root, "Consumer") as entrypoint:
service = remote.push(
service = deployment_client.push(
entrypoint,
options=definitions.PushOptionsLocalDocker(
chain_name="integration-test-stream",
Expand Down Expand Up @@ -178,7 +179,7 @@ def test_numpy_chain(mode):
examples_root = Path(__file__).parent.parent.resolve() / "examples"
chain_root = examples_root / "numpy_and_binary" / "chain.py"
with framework.import_target(chain_root, target) as entrypoint:
service = remote.push(
service = deployment_client.push(
entrypoint,
options=definitions.PushOptionsLocalDocker(
chain_name=f"integration-test-numpy-{mode}",
Expand Down
2 changes: 1 addition & 1 deletion truss-chains/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from truss_chains import definitions
from truss_chains.utils import populate_chainlet_service_predict_urls
from truss_chains.remote_chainlet.utils import populate_chainlet_service_predict_urls

DYNAMIC_CHAINLET_CONFIG_VALUE = {
"Hello World!": {
Expand Down
4 changes: 3 additions & 1 deletion truss-chains/truss_chains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
push,
run_local,
)
from truss_chains.stub import StubBase

# TODO: make this optional (remove aiohttp, httpx and starlette deps).
from truss_chains.remote_chainlet.stub import StubBase
from truss_chains.utils import make_abs_path_here

__all__ = [
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@
from truss.contexts.image_builder import serving_image_builder
from truss.util import path as truss_path

from truss_chains import definitions, framework, model_skeleton, utils
from truss_chains import definitions, framework, utils

INDENT = " " * 4
_INDENT = " " * 4
_REQUIREMENTS_FILENAME = "pip_requirements.txt"
_MODEL_FILENAME = "model.py"
_MODEL_CLS_NAME = model_skeleton.TrussChainletModel.__name__
_MODEL_CLS_NAME = "TrussChainletModel"
_TRUSS_GIT = "git+https://github.com/basetenlabs/truss.git"
_TRUSS_PIP_PATTERN = re.compile(
r"""
Expand All @@ -63,9 +63,15 @@
re.VERBOSE,
)

_MODEL_SKELETON_FILE = (
pathlib.Path(__file__).parent.parent.resolve()
/ "remote_chainlet"
/ "model_skeleton.py"
)


def _indent(text: str, num: int = 1) -> str:
return textwrap.indent(text, INDENT * num)
return textwrap.indent(text, _INDENT * num)


def _run_simple_subprocess(cmd: str) -> None:
Expand Down Expand Up @@ -312,7 +318,7 @@ async def run_remote(
SplitTextInput(inputs=inputs, extra_arg=extra_arg), SplitTextOutput).root
```
"""
imports = {"from truss_chains import stub"}
imports = {"from truss_chains.remote_chainlet import stub"}
src_parts: list[str] = []
input_src = _gen_truss_input_pydantic(chainlet)
_update_src(input_src, src_parts, imports)
Expand Down Expand Up @@ -395,7 +401,7 @@ def leave_SimpleStatementLine(

def _gen_load_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _Source:
"""Generates AST for the `load` method of the truss model."""
imports = {"from truss_chains import stub", "import logging"}
imports = {"from truss_chains.remote_chainlet import stub", "import logging"}
stub_args = []
for name, dep in chainlet_descriptor.dependencies.items():
# `dep.name` is the class name, while `name` is the argument name.
Expand Down Expand Up @@ -423,7 +429,10 @@ def _gen_load_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _So

def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _Source:
"""Generates AST for the `predict` method of the truss model."""
imports: set[str] = {"from truss_chains import stub"}
imports: set[str] = {
"from truss_chains.remote_chainlet import stub",
"from truss_chains.remote_chainlet import utils",
}
parts: list[str] = []
def_str = "async def" if chainlet_descriptor.endpoint.is_async else "def"
input_model_name = _get_input_model_name(chainlet_descriptor.name)
Expand All @@ -444,7 +453,7 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) ->
# Add error handling context manager:
parts.append(
_indent(
f"with stub.trace_parent(request), stub.exception_to_http_error("
f"with stub.trace_parent(request), utils.exception_to_http_error("
f'chainlet_name="{chainlet_descriptor.name}"):'
)
)
Expand All @@ -458,13 +467,15 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) ->
maybe_await = ""
run_remote = chainlet_descriptor.endpoint.name
# See docs of `pydantic_set_field_dict` for why this is needed.
args = "**stub.pydantic_set_field_dict(inputs)"
args = "**utils.pydantic_set_field_dict(inputs)"
parts.append(
_indent(f"result = {maybe_await}self._chainlet.{run_remote}({args})", 2)
)
if chainlet_descriptor.endpoint.is_streaming:
# Streaming returns raw iterator, no pydantic model.
parts.append(_indent("return result"))
# This needs to be nested inside the `trace_parent` context!
parts.append(_indent("async for chunk in result:", 2))
parts.append(_indent("yield chunk", 3))
else:
result_pydantic = f"{output_type_name}(result)"
parts.append(_indent(f"return {result_pydantic}"))
Expand All @@ -474,9 +485,7 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) ->
def _gen_truss_chainlet_model(
chainlet_descriptor: definitions.ChainletAPIDescriptor,
) -> _Source:
skeleton_tree = libcst.parse_module(
pathlib.Path(model_skeleton.__file__).read_text()
)
skeleton_tree = libcst.parse_module(_MODEL_SKELETON_FILE.read_text())
imports: set[str] = set(
libcst.Module(body=[node]).code
for node in skeleton_tree.body
Expand All @@ -489,8 +498,7 @@ def _gen_truss_chainlet_model(
class_definition: libcst.ClassDef = utils.expect_one(
node
for node in skeleton_tree.body
if isinstance(node, libcst.ClassDef)
and node.name.value == model_skeleton.TrussChainletModel.__name__
if isinstance(node, libcst.ClassDef) and node.name.value == _MODEL_CLS_NAME
)

load_src = _gen_load_src(chainlet_descriptor)
Expand Down Expand Up @@ -561,14 +569,32 @@ def _make_requirements(image: definitions.DockerImage) -> list[str]:
)
pip_requirements.update(image.pip_requirements)

has_truss_pypy = any(
bool(_TRUSS_PIP_PATTERN.match(req)) for req in pip_requirements
truss_pypy = next(
(req for req in pip_requirements if _TRUSS_PIP_PATTERN.match(req)), None
)
has_truss_git = any(_TRUSS_GIT in req for req in pip_requirements)

if not (has_truss_git or has_truss_pypy):
truss_git = next((req for req in pip_requirements if _TRUSS_GIT in req), None)

if truss_git:
logging.warning(
"The chainlet contains a truss version from github as a pip_requirement:\n"
f"\t{truss_git}\n"
"This could result in inconsistencies between the deploying client and the "
"deployed chainlet. This is not recommended for production chains."
)
if truss_pypy:
logging.warning(
"The chainlet contains a pinned truss version as a pip_requirement:\n"
f"\t{truss_pypy}\n"
"This could result in inconsistencies between the deploying client and the "
"deployed chainlet. This is not recommended for production chains. If "
"`truss` is not manually added as a requirement, the same version as "
"locally installed will be automatically added and ensure compatibility."
)

if not (truss_git or truss_pypy):
truss_pip = f"truss=={truss.version()}"
logging.info(
logging.debug(
f"Truss not found in pip requirements, auto-adding: `{truss_pip}`."
)
pip_requirements.add(truss_pip)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@

import tenacity
import watchfiles

if TYPE_CHECKING:
from rich import console as rich_console
from rich import progress
from truss.local import local_config_handler
from truss.remote import remote_factory
from truss.remote.baseten import core as b10_core
Expand All @@ -37,7 +33,12 @@
from truss.util import log_utils
from truss.util import path as truss_path

from truss_chains import code_gen, definitions, framework, utils
from truss_chains import definitions, framework, utils
from truss_chains.deployment import code_gen

if TYPE_CHECKING:
from rich import console as rich_console
from rich import progress


class DockerTrussService(b10_service.TrussService):
Expand Down
14 changes: 8 additions & 6 deletions truss-chains/truss_chains/public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import pathlib
from typing import TYPE_CHECKING, ContextManager, Mapping, Optional, Type, Union

from truss_chains import definitions, framework
from truss_chains.deployment import deployment_client

if TYPE_CHECKING:
from rich import progress

from truss_chains import definitions, framework
from truss_chains import remote as chains_remote


def depends_context() -> definitions.DeploymentContext:
"""Sets a "symbolic marker" for injecting a context object at runtime.
Expand Down Expand Up @@ -137,7 +137,7 @@ def push(
remote: str = "baseten",
environment: Optional[str] = None,
progress_bar: Optional[Type["progress.Progress"]] = None,
) -> chains_remote.BasetenChainService:
) -> deployment_client.BasetenChainService:
"""
Deploys a chain remotely (with all dependent chainlets).
Expand Down Expand Up @@ -168,8 +168,10 @@ def push(
remote=remote,
environment=environment,
)
service = chains_remote.push(entrypoint, options, progress_bar=progress_bar)
assert isinstance(service, chains_remote.BasetenChainService) # Per options above.
service = deployment_client.push(entrypoint, options, progress_bar=progress_bar)
assert isinstance(
service, deployment_client.BasetenChainService
) # Per options above.
return service


Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from truss.templates.shared import secrets_resolver

from truss_chains import definitions
from truss_chains.utils import populate_chainlet_service_predict_urls
from truss_chains.remote_chainlet import utils


class TrussChainletModel:
Expand All @@ -27,7 +27,7 @@ def __init__(
deployment_environment: Optional[definitions.Environment] = (
definitions.Environment.model_validate(environment) if environment else None
)
chainlet_to_deployed_service = populate_chainlet_service_predict_urls(
chainlet_to_deployed_service = utils.populate_chainlet_service_predict_urls(
truss_metadata.chainlet_to_service
)

Expand All @@ -42,12 +42,16 @@ def __init__(

# def load(self) -> None:
# logging.info(f"Loading Chainlet `TextToNum`.")
# self._chainlet = main.TextToNum(
# mistral=stub.factory(MistralLLM, self._context))
# self._chainlet = itest_chain.TextToNum(
# replicator=stub.factory(TextReplicator, self._context),
# side_effect=stub.factory(SideEffectOnlySubclass, self._context),
# )
#
# def predict(self, inputs: TextToNumInput) -> TextToNumOutput:
# with utils.exception_to_http_error(
# def predict(
# self, inputs: TextToNumInput, request: starlette.requests.Request
# ) -> TextToNumOutput:
# with stub.trace_parent(request), utils.exception_to_http_error(
# include_stack=True, chainlet_name="TextToNum"
# ):
# result = self._chainlet.run_remote(data=inputs.data)
# return TextToNumOutput((result,))
# result = self._chainlet.run_remote(**utils.pydantic_set_field_dict(inputs))
# return TextToNumOutput(result)
Loading

0 comments on commit 77343fb

Please sign in to comment.