Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC for new model DX #1311

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions truss-chains/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
import requests
from truss.tests.test_testing_utilities_for_other_tests import ensure_kill_all
from truss.truss_handle.build import load_from_code_config

from truss_chains import definitions, framework, public_api, utils
from truss_chains.deployment import deployment_client
Expand Down Expand Up @@ -270,3 +271,27 @@ async def test_timeout():
assert re.match(
sync_error_regex.strip(), sync_error_str.strip(), re.MULTILINE
), sync_error_str


@pytest.mark.integration
def test_traditional_truss():
with ensure_kill_all():
chain_root = TEST_ROOT / "traditional_truss" / "truss_model.py"
truss_handle = load_from_code_config(chain_root)

assert truss_handle.spec.config.resources.cpu == "4"
assert truss_handle.spec.config.model_name == "OverridePassthroughModelName"

port = utils.get_free_port()
truss_handle.docker_run(
local_port=port,
detach=True,
network="host",
)

response = requests.post(
f"http://localhost:{port}/v1/models/model:predict",
json={"call_count_increment": 5},
)
assert response.status_code == 200
assert response.json() == 5
20 changes: 20 additions & 0 deletions truss-chains/tests/traditional_truss/truss_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import truss_chains as chains


class PassthroughModel(chains.ModelBase):
remote_config: chains.RemoteConfig = chains.RemoteConfig( # type: ignore
compute=chains.Compute(4, "1Gi"),
name="OverridePassthroughModelName",
docker_image=chains.DockerImage(
pip_requirements=[
"truss==0.9.59rc2",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The traditional docker entrypoint doesn't have an equivalent to use_local_chains_src, this was a hack to get the e2e test to pass for now

]
),
)

def __init__(self):
self._call_count = 0

async def run_remote(self, call_count_increment: int) -> int:
self._call_count += call_count_increment
return self._call_count
2 changes: 2 additions & 0 deletions truss-chains/truss_chains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)
from truss_chains.public_api import (
ChainletBase,
ModelBase,
depends,
depends_context,
mark_entrypoint,
Expand All @@ -50,6 +51,7 @@
"Assets",
"BasetenImage",
"ChainletBase",
"ModelBase",
"ChainletOptions",
"Compute",
"CustomImage",
Expand Down
24 changes: 12 additions & 12 deletions truss-chains/truss_chains/deployment/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import shutil
import subprocess
import sys
import tempfile
import textwrap
from typing import Any, Iterable, Mapping, Optional, get_args, get_origin

Expand Down Expand Up @@ -648,13 +649,13 @@ def _inplace_fill_base_image(
)


def _make_truss_config(
def write_truss_config_yaml(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed this because I thought it was confusing that it returned a TrussConfig object that was never used, took me a while to figure out it was persisting to yaml which a later flow picked up

chainlet_dir: pathlib.Path,
chains_config: definitions.RemoteConfig,
chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
model_name: str,
use_local_chains_src: bool,
) -> truss_config.TrussConfig:
):
"""Generate a truss config for a Chainlet."""
config = truss_config.TrussConfig()
config.model_name = model_name
Expand Down Expand Up @@ -707,16 +708,14 @@ def _make_truss_config(
config.write_to_yaml_file(
chainlet_dir / serving_image_builder.CONFIG_FILE, verbose=True
)
return config


def gen_truss_chainlet(
chain_root: pathlib.Path,
gen_root: pathlib.Path,
chain_name: str,
chainlet_descriptor: definitions.ChainletAPIDescriptor,
model_name: str,
use_local_chains_src: bool,
model_name: Optional[str] = None,
use_local_chains_src: bool = False,
) -> pathlib.Path:
# Filter needed services and customize options.
dep_services = {}
Expand All @@ -726,17 +725,18 @@ def gen_truss_chainlet(
display_name=dep.display_name,
options=dep.options,
)
gen_root = pathlib.Path(tempfile.gettempdir())
chainlet_dir = _make_chainlet_dir(chain_name, chainlet_descriptor, gen_root)
logging.info(
f"Code generation for Chainlet `{chainlet_descriptor.name}` "
f"in `{chainlet_dir}`."
)
_make_truss_config(
chainlet_dir,
chainlet_descriptor.chainlet_cls.remote_config,
dep_services,
model_name,
use_local_chains_src,
write_truss_config_yaml(
chainlet_dir=chainlet_dir,
chains_config=chainlet_descriptor.chainlet_cls.remote_config,
model_name=model_name or chain_name,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chain models have their name suffixed, but for traditional trusses we'd want to preserve the name specified in RemoteConfig exactly (or the one autoderived from the class name)

We can push this coalesce down into write_truss_config_yaml if we think that's cleaner

chainlet_to_service=dep_services,
use_local_chains_src=use_local_chains_src,
)
# This assumes all imports are absolute w.r.t chain root (or site-packages).
truss_path.copy_tree_path(
Expand Down
8 changes: 1 addition & 7 deletions truss-chains/truss_chains/deployment/deployment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json
import logging
import pathlib
import tempfile
import textwrap
import traceback
import uuid
Expand Down Expand Up @@ -138,10 +137,8 @@ class _ChainSourceGenerator:
def __init__(
self,
options: definitions.PushOptions,
gen_root: pathlib.Path,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small cleanup, this gen_root was passed around to allow callers to optionally pass in a given directory, but every callsite I tracked down defaulted to a generated temp dir

) -> None:
self._options = options
self._gen_root = gen_root or pathlib.Path(tempfile.gettempdir())

@property
def _use_local_chains_src(self) -> bool:
Expand Down Expand Up @@ -175,7 +172,6 @@ def generate_chainlet_artifacts(

chainlet_dir = code_gen.gen_truss_chainlet(
chain_root,
self._gen_root,
self._options.chain_name,
chainlet_descriptor,
model_name,
Expand Down Expand Up @@ -205,11 +201,10 @@ def generate_chainlet_artifacts(
def push(
entrypoint: Type[definitions.ABCChainlet],
options: definitions.PushOptions,
gen_root: pathlib.Path = pathlib.Path(tempfile.gettempdir()),
progress_bar: Optional[Type["progress.Progress"]] = None,
) -> Optional[ChainService]:
entrypoint_artifact, dependency_artifacts = _ChainSourceGenerator(
options, gen_root
options
).generate_chainlet_artifacts(
entrypoint,
)
Expand Down Expand Up @@ -632,7 +627,6 @@ def _code_gen_and_patch_thread(
# TODO: Maybe try-except code_gen errors explicitly.
chainlet_dir = code_gen.gen_truss_chainlet(
self._chain_root,
pathlib.Path(tempfile.gettempdir()),
self._deployed_chain_name,
descr,
self._chainlet_data[descr.display_name].oracle_name,
Expand Down
Loading
Loading