-
Notifications
You must be signed in to change notification settings - Fork 76
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
POC for new model DX #1311
base: main
Are you sure you want to change the base?
POC for new model DX #1311
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import truss_chains as chains | ||
|
||
|
||
class PassthroughModel(chains.ModelBase): | ||
remote_config: chains.RemoteConfig = chains.RemoteConfig( # type: ignore | ||
compute=chains.Compute(4, "1Gi"), | ||
name="OverridePassthroughModelName", | ||
docker_image=chains.DockerImage( | ||
pip_requirements=[ | ||
"truss==0.9.59rc2", | ||
] | ||
), | ||
) | ||
|
||
def __init__(self): | ||
self._call_count = 0 | ||
|
||
async def run_remote(self, call_count_increment: int) -> int: | ||
self._call_count += call_count_increment | ||
return self._call_count |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,7 @@ | |
import shutil | ||
import subprocess | ||
import sys | ||
import tempfile | ||
import textwrap | ||
from typing import Any, Iterable, Mapping, Optional, get_args, get_origin | ||
|
||
|
@@ -648,13 +649,13 @@ def _inplace_fill_base_image( | |
) | ||
|
||
|
||
def _make_truss_config( | ||
def write_truss_config_yaml( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I renamed this because I thought it was confusing that it returned a |
||
chainlet_dir: pathlib.Path, | ||
chains_config: definitions.RemoteConfig, | ||
chainlet_to_service: Mapping[str, definitions.ServiceDescriptor], | ||
model_name: str, | ||
use_local_chains_src: bool, | ||
) -> truss_config.TrussConfig: | ||
): | ||
"""Generate a truss config for a Chainlet.""" | ||
config = truss_config.TrussConfig() | ||
config.model_name = model_name | ||
|
@@ -707,16 +708,14 @@ def _make_truss_config( | |
config.write_to_yaml_file( | ||
chainlet_dir / serving_image_builder.CONFIG_FILE, verbose=True | ||
) | ||
return config | ||
|
||
|
||
def gen_truss_chainlet( | ||
chain_root: pathlib.Path, | ||
gen_root: pathlib.Path, | ||
chain_name: str, | ||
chainlet_descriptor: definitions.ChainletAPIDescriptor, | ||
model_name: str, | ||
use_local_chains_src: bool, | ||
model_name: Optional[str] = None, | ||
use_local_chains_src: bool = False, | ||
) -> pathlib.Path: | ||
# Filter needed services and customize options. | ||
dep_services = {} | ||
|
@@ -726,17 +725,18 @@ def gen_truss_chainlet( | |
display_name=dep.display_name, | ||
options=dep.options, | ||
) | ||
gen_root = pathlib.Path(tempfile.gettempdir()) | ||
chainlet_dir = _make_chainlet_dir(chain_name, chainlet_descriptor, gen_root) | ||
logging.info( | ||
f"Code generation for Chainlet `{chainlet_descriptor.name}` " | ||
f"in `{chainlet_dir}`." | ||
) | ||
_make_truss_config( | ||
chainlet_dir, | ||
chainlet_descriptor.chainlet_cls.remote_config, | ||
dep_services, | ||
model_name, | ||
use_local_chains_src, | ||
write_truss_config_yaml( | ||
chainlet_dir=chainlet_dir, | ||
chains_config=chainlet_descriptor.chainlet_cls.remote_config, | ||
model_name=model_name or chain_name, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Chain models have their name suffixed, but for traditional trusses we'd want to preserve the name specified in We can push this coalesce down into |
||
chainlet_to_service=dep_services, | ||
use_local_chains_src=use_local_chains_src, | ||
) | ||
# This assumes all imports are absolute w.r.t chain root (or site-packages). | ||
truss_path.copy_tree_path( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,6 @@ | |
import json | ||
import logging | ||
import pathlib | ||
import tempfile | ||
import textwrap | ||
import traceback | ||
import uuid | ||
|
@@ -138,10 +137,8 @@ class _ChainSourceGenerator: | |
def __init__( | ||
self, | ||
options: definitions.PushOptions, | ||
gen_root: pathlib.Path, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Small cleanup, this |
||
) -> None: | ||
self._options = options | ||
self._gen_root = gen_root or pathlib.Path(tempfile.gettempdir()) | ||
|
||
@property | ||
def _use_local_chains_src(self) -> bool: | ||
|
@@ -175,7 +172,6 @@ def generate_chainlet_artifacts( | |
|
||
chainlet_dir = code_gen.gen_truss_chainlet( | ||
chain_root, | ||
self._gen_root, | ||
self._options.chain_name, | ||
chainlet_descriptor, | ||
model_name, | ||
|
@@ -205,11 +201,10 @@ def generate_chainlet_artifacts( | |
def push( | ||
entrypoint: Type[definitions.ABCChainlet], | ||
options: definitions.PushOptions, | ||
gen_root: pathlib.Path = pathlib.Path(tempfile.gettempdir()), | ||
progress_bar: Optional[Type["progress.Progress"]] = None, | ||
) -> Optional[ChainService]: | ||
entrypoint_artifact, dependency_artifacts = _ChainSourceGenerator( | ||
options, gen_root | ||
options | ||
).generate_chainlet_artifacts( | ||
entrypoint, | ||
) | ||
|
@@ -632,7 +627,6 @@ def _code_gen_and_patch_thread( | |
# TODO: Maybe try-except code_gen errors explicitly. | ||
chainlet_dir = code_gen.gen_truss_chainlet( | ||
self._chain_root, | ||
pathlib.Path(tempfile.gettempdir()), | ||
self._deployed_chain_name, | ||
descr, | ||
self._chainlet_data[descr.display_name].oracle_name, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The traditional docker entrypoint doesn't have an equivalent to
use_local_chains_src
, this was a hack to get the e2e test to pass for now