Skip to content

Commit

Permalink
Merge branch 'main' into mf/bump-truss-transfer
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelfeil authored Jan 28, 2025
2 parents 766fee8 + 934bcbc commit 42a8775
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 33 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.59"
version = "0.9.60rc001"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
1 change: 1 addition & 0 deletions truss-chains/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def test_numpy_chain(mode):
print(response.json())


@pytest.mark.integration
@pytest.mark.asyncio
async def test_timeout():
with ensure_kill_all():
Expand Down
1 change: 1 addition & 0 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def migrate_fields(cls, values):
class ComputeSpec(pydantic.BaseModel):
"""Parsed and validated compute. See ``Compute`` for more information."""

# TODO[rcano] add node count
cpu_count: int = 1
predict_concurrency: int = 1
memory: str = "2Gi"
Expand Down
160 changes: 133 additions & 27 deletions truss-chains/truss_chains/deployment/deployment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from truss.remote import remote_factory
from truss.remote.baseten import core as b10_core
from truss.remote.baseten import custom_types as b10_types
from truss.remote.baseten import error as b10_errors
from truss.remote.baseten import remote as b10_remote
from truss.remote.baseten import service as b10_service
from truss.truss_handle import truss_handle
Expand Down Expand Up @@ -495,6 +496,114 @@ def _create_chains_secret_if_missing(remote_provider: b10_remote.BasetenRemote)
# Watch / Live Patching ################################################################


def _create_watch_filter(root_dir: pathlib.Path):
ignore_patterns = truss_path.load_trussignore_patterns_from_truss_dir(root_dir)

def watch_filter(_: watchfiles.Change, path: str) -> bool:
return not truss_path.is_ignored(pathlib.Path(path), ignore_patterns)

logging.getLogger("watchfiles.main").disabled = True
return ignore_patterns, watch_filter


def _handle_intercepted_logs(logs: list[str], console: "rich_console.Console"):
if logs:
formatted_logs = textwrap.indent("\n".join(logs), " " * 4)
console.print(f"Intercepted logs from importing source code:\n{formatted_logs}")


def _handle_import_error(
exception: Exception,
console: "rich_console.Console",
error_console: "rich_console.Console",
stack_trace: Optional[str] = None,
):
error_console.print(
"Source files were changed, but pre-conditions for "
"live patching are not given. Most likely there is a "
"syntax error in the source files or names changed. "
"Try to fix the issue and save the file. Error:\n"
f"{textwrap.indent(str(exception), ' ' * 4)}"
)
if stack_trace:
error_console.print(stack_trace)

console.print(
"The watcher will continue and if you can resolve the "
"issue, subsequent patches might succeed.",
style="blue",
)


class _ModelWatcher:
_source: pathlib.Path
_model_name: str
_remote_provider: b10_remote.BasetenRemote
_ignore_patterns: list[str]
_watch_filter: Callable[[watchfiles.Change, str], bool]
_console: "rich_console.Console"
_error_console: "rich_console.Console"

def __init__(
self,
source: pathlib.Path,
model_name: str,
remote_provider: b10_remote.BasetenRemote,
console: "rich_console.Console",
error_console: "rich_console.Console",
) -> None:
self._source = source
self._model_name = model_name
self._remote_provider = remote_provider
self._console = console
self._error_console = error_console
self._ignore_patterns, self._watch_filter = _create_watch_filter(
source.absolute().parent
)

dev_version = b10_core.get_dev_version(self._remote_provider.api, model_name)
if not dev_version:
raise b10_errors.RemoteError(
"No development model found. Run `truss push` then try again."
)

def _patch(self) -> None:
exception_raised = None
with log_utils.LogInterceptor() as log_interceptor, self._console.status(
" Live Patching Model.\n", spinner="arrow3"
):
try:
gen_truss_path = code_gen.gen_truss_model_from_source(self._source)
return self._remote_provider.patch(
gen_truss_path,
self._ignore_patterns,
self._console,
self._error_console,
)
except Exception as e:
exception_raised = e
finally:
logs = log_interceptor.get_logs()

_handle_intercepted_logs(logs, self._console)
if exception_raised:
_handle_import_error(exception_raised, self._console, self._error_console)

def watch(self) -> None:
# Perform one initial patch at startup.
self._patch()
self._console.print("👀 Watching for new changes.", style="blue")

# TODO(nikhil): Improve detection of directory structure, since right now
# we assume a flat structure
root_dir = self._source.absolute().parent
for _ in watchfiles.watch(
root_dir, watch_filter=self._watch_filter, raise_interrupt=False
):
self._patch()
self._console.print("👀 Watching for new changes.", style="blue")


class _Watcher:
_source: pathlib.Path
_entrypoint: Optional[str]
Expand Down Expand Up @@ -573,16 +682,10 @@ def __init__(

self._chainlet_data = {c.name: c for c in deployed_chainlets}
self._assert_chainlet_names_same(chainlet_names)
self._ignore_patterns = truss_path.load_trussignore_patterns_from_truss_dir(
self._ignore_patterns, self._watch_filter = _create_watch_filter(
self._chain_root
)

def watch_filter(_: watchfiles.Change, path: str) -> bool:
return not truss_path.is_ignored(pathlib.Path(path), self._ignore_patterns)

logging.getLogger("watchfiles.main").disabled = True
self._watch_filter = watch_filter

@property
def _original_chainlet_names(self) -> set[str]:
return set(self._chainlet_data.keys())
Expand Down Expand Up @@ -665,27 +768,13 @@ def _patch(self, executor: concurrent.futures.Executor) -> None:
finally:
logs = log_interceptor.get_logs()

if logs:
formatted_logs = textwrap.indent("\n".join(logs), " " * 4)
self._console.print(
f"Intercepted logs from importing chain source code:\n{formatted_logs}"
)

_handle_intercepted_logs(logs, self._console)
if exception_raised:
self._error_console.print(
"Source files were changed, but pre-conditions for "
"live patching are not given. Most likely there is a "
"syntax in the source files or chainlet names changed. "
"Try to fix the issue and save the file. Error:\n"
f"{textwrap.indent(str(exception_raised), ' ' * 4)}"
)
if self._show_stack_trace:
self._error_console.print(stack_trace)

self._console.print(
"The watcher will continue and if you can resolve the "
"issue, subsequent patches might succeed.",
style="blue",
_handle_import_error(
exception_raised,
self._console,
self._error_console,
stack_trace=stack_trace if self._show_stack_trace else None,
)
return

Expand Down Expand Up @@ -775,3 +864,20 @@ def watch(
included_chainlets,
)
patcher.watch()


def watch_model(
source: pathlib.Path,
model_name: str,
remote_provider: b10_remote.TrussRemote,
console: "rich_console.Console",
error_console: "rich_console.Console",
):
patcher = _ModelWatcher(
source=source,
model_name=model_name,
remote_provider=cast(b10_remote.BasetenRemote, remote_provider),
console=console,
error_console=error_console,
)
patcher.watch()
13 changes: 11 additions & 2 deletions truss/base/truss_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
DEFAULT_CPU = "1"
DEFAULT_MEMORY = "2Gi"
DEFAULT_USE_GPU = False
DEFAULT_NODE_COUNT = 1

DEFAULT_BLOB_BACKEND = HTTP_PUBLIC_BLOB_BACKEND

Expand Down Expand Up @@ -259,6 +260,7 @@ class Resources:
memory: str = DEFAULT_MEMORY
use_gpu: bool = DEFAULT_USE_GPU
accelerator: AcceleratorSpec = field(default_factory=AcceleratorSpec)
node_count: int = DEFAULT_NODE_COUNT

@staticmethod
def from_dict(d):
Expand All @@ -270,9 +272,15 @@ def from_dict(d):
use_gpu = d.get("use_gpu", DEFAULT_USE_GPU)
if accelerator.accelerator is not None:
use_gpu = True
# TODO[rcano]: add validation for node count
node_count = d.get("node_count", DEFAULT_NODE_COUNT)

return Resources(
cpu=cpu, memory=memory, use_gpu=use_gpu, accelerator=accelerator
cpu=cpu,
memory=memory,
use_gpu=use_gpu,
accelerator=accelerator,
node_count=node_count,
)

def to_dict(self):
Expand Down Expand Up @@ -521,6 +529,7 @@ class TrussConfig:
memory: 14Gi
use_gpu: true
accelerator: A10G
node_count: 2
```
secrets (Dict[str, str]):
<Warning>
Expand Down Expand Up @@ -765,7 +774,7 @@ def _handle_env_vars(env_vars: Dict[str, Any]) -> Dict[str, str]:


DATACLASS_TO_REQ_KEYS_MAP = {
Resources: {"accelerator", "cpu", "memory", "use_gpu"},
Resources: {"accelerator", "cpu", "memory", "use_gpu", "node_count"},
Runtime: {"predict_concurrency"},
Build: {"model_server"},
TrussConfig: {
Expand Down
19 changes: 16 additions & 3 deletions truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,22 @@ def watch(target_directory: str, remote: str) -> None:
console.print(
f"🪵 View logs for your deployment at {_format_link(service.logs_url)}"
)
remote_provider.sync_truss_to_dev_version_by_name(
model_name, target_directory, console, error_console
)

if not os.path.isfile(target_directory):
remote_provider.sync_truss_to_dev_version_by_name(
model_name, target_directory, console, error_console
)
else:
# These imports are delayed, to handle pydantic v1 envs gracefully.
from truss_chains.deployment import deployment_client

deployment_client.watch_model(
source=Path(target_directory),
model_name=model_name,
remote_provider=remote_provider,
console=console,
error_console=error_console,
)


# Chains Stuff #########################################################################
Expand Down
2 changes: 2 additions & 0 deletions truss/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,7 @@ def default_config() -> Dict[str, Any]:
"cpu": "1",
"memory": "2Gi",
"use_gpu": False,
"node_count": 1,
},
"secrets": {},
"system_packages": [],
Expand All @@ -745,6 +746,7 @@ def trtllm_config(default_config) -> Dict[str, Any]:
"cpu": "1",
"memory": "24Gi",
"use_gpu": True,
"node_count": 1,
}
trtllm_config["trt_llm"] = {
"build": {
Expand Down
1 change: 1 addition & 0 deletions truss/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def test_default_config_not_crowded_end_to_end():
accelerator: null
cpu: '1'
memory: 2Gi
node_count: 1
use_gpu: false
secrets: {}
system_packages: []
Expand Down
1 change: 1 addition & 0 deletions truss/tests/test_truss_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,7 @@ def generate_default_config():
"cpu": "1",
"memory": "2Gi",
"use_gpu": False,
"node_count": 1,
},
"secrets": {},
"system_packages": [],
Expand Down

0 comments on commit 42a8775

Please sign in to comment.