Skip to content

Commit

Permalink
Introduce new ModelWatcher, slight refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
nnarayen committed Jan 27, 2025
1 parent 10d8ecb commit 98c86e3
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 61 deletions.
160 changes: 133 additions & 27 deletions truss-chains/truss_chains/deployment/deployment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from truss.remote.baseten import custom_types as b10_types
from truss.remote.baseten import remote as b10_remote
from truss.remote.baseten import service as b10_service
from truss.remote.baseten.error import RemoteError
from truss.truss_handle import truss_handle
from truss.util import log_utils
from truss.util import path as truss_path
Expand Down Expand Up @@ -495,6 +496,114 @@ def _create_chains_secret_if_missing(remote_provider: b10_remote.BasetenRemote)
# Watch / Live Patching ################################################################


def _create_watch_filter(root_dir: pathlib.Path):
ignore_patterns = truss_path.load_trussignore_patterns_from_truss_dir(root_dir)

def watch_filter(_: watchfiles.Change, path: str) -> bool:
return not truss_path.is_ignored(pathlib.Path(path), ignore_patterns)

logging.getLogger("watchfiles.main").disabled = True
return ignore_patterns, watch_filter


def _handle_intercepted_logs(logs: list[str], console: "rich_console.Console"):
if logs:
formatted_logs = textwrap.indent("\n".join(logs), " " * 4)
console.print(f"Intercepted logs from importing source code:\n{formatted_logs}")


def _handle_import_error(
exception: Exception,
console: "rich_console.Console",
error_console: "rich_console.Console",
stack_trace: Optional[str] = None,
):
error_console.print(
"Source files were changed, but pre-conditions for "
"live patching are not given. Most likely there is a "
"syntax error in the source files or names changed. "
"Try to fix the issue and save the file. Error:\n"
f"{textwrap.indent(str(exception), ' ' * 4)}"
)
if stack_trace:
error_console.print(stack_trace)

console.print(
"The watcher will continue and if you can resolve the "
"issue, subsequent patches might succeed.",
style="blue",
)


class _ModelWatcher:
_source: pathlib.Path
_model_name: str
_remote_provider: b10_remote.BasetenRemote
_ignore_patterns: list[str]
_watch_filter: Callable[[watchfiles.Change, str], bool]
_console: "rich_console.Console"
_error_console: "rich_console.Console"

def __init__(
self,
source: pathlib.Path,
model_name: str,
remote_provider: b10_remote.BasetenRemote,
console: "rich_console.Console",
error_console: "rich_console.Console",
) -> None:
self._source = source
self._model_name = model_name
self._remote_provider = remote_provider
self._console = console
self._error_console = error_console
self._ignore_patterns, self._watch_filter = _create_watch_filter(
source.absolute().parent
)

dev_version = b10_core.get_dev_version(self._remote_provider.api, model_name)
if not dev_version:
raise RemoteError(
"No development model found. Run `truss push` then try again."
)

def _patch(self) -> None:
exception_raised = None
with log_utils.LogInterceptor() as log_interceptor, self._console.status(
" Live Patching Model.\n", spinner="arrow3"
):
try:
gen_truss_path = code_gen.gen_truss_model_from_source(self._source)
return self._remote_provider.patch_model(
gen_truss_path,
self._ignore_patterns,
self._console,
self._error_console,
)
except Exception as e:
exception_raised = e
finally:
logs = log_interceptor.get_logs()

_handle_intercepted_logs(logs, self._console)
if exception_raised:
_handle_import_error(exception_raised, self._console, self._error_console)

def watch(self) -> None:
# Perform one initial patch at startup.
self._patch()
self._console.print("👀 Watching for new changes.", style="blue")

# TODO(nikhil): Improve detection of directory structure, since right now
# we assume a flat structure
root_dir = self._source.absolute().parent
for _ in watchfiles.watch(
root_dir, watch_filter=self._watch_filter, raise_interrupt=False
):
self._patch()
self._console.print("👀 Watching for new changes.", style="blue")


class _Watcher:
_source: pathlib.Path
_entrypoint: Optional[str]
Expand Down Expand Up @@ -573,16 +682,10 @@ def __init__(

self._chainlet_data = {c.name: c for c in deployed_chainlets}
self._assert_chainlet_names_same(chainlet_names)
self._ignore_patterns = truss_path.load_trussignore_patterns_from_truss_dir(
self._ignore_patterns, self._watch_filter = _create_watch_filter(
self._chain_root
)

def watch_filter(_: watchfiles.Change, path: str) -> bool:
return not truss_path.is_ignored(pathlib.Path(path), self._ignore_patterns)

logging.getLogger("watchfiles.main").disabled = True
self._watch_filter = watch_filter

@property
def _original_chainlet_names(self) -> set[str]:
return set(self._chainlet_data.keys())
Expand Down Expand Up @@ -665,27 +768,13 @@ def _patch(self, executor: concurrent.futures.Executor) -> None:
finally:
logs = log_interceptor.get_logs()

if logs:
formatted_logs = textwrap.indent("\n".join(logs), " " * 4)
self._console.print(
f"Intercepted logs from importing chain source code:\n{formatted_logs}"
)

_handle_intercepted_logs(logs, self._console)
if exception_raised:
self._error_console.print(
"Source files were changed, but pre-conditions for "
"live patching are not given. Most likely there is a "
"syntax in the source files or chainlet names changed. "
"Try to fix the issue and save the file. Error:\n"
f"{textwrap.indent(str(exception_raised), ' ' * 4)}"
)
if self._show_stack_trace:
self._error_console.print(stack_trace)

self._console.print(
"The watcher will continue and if you can resolve the "
"issue, subsequent patches might succeed.",
style="blue",
_handle_import_error(
exception_raised,
self._console,
self._error_console,
stack_trace=stack_trace if self._show_stack_trace else None,
)
return

Expand Down Expand Up @@ -775,3 +864,20 @@ def watch(
included_chainlets,
)
patcher.watch()


def watch_model(
source: pathlib.Path,
model_name: str,
remote_provider: b10_remote.TrussRemote,
console: "rich_console.Console",
error_console: "rich_console.Console",
):
patcher = _ModelWatcher(
source=source,
model_name=model_name,
remote_provider=cast(b10_remote.BasetenRemote, remote_provider),
console=console,
error_console=error_console,
)
patcher.watch()
19 changes: 16 additions & 3 deletions truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,22 @@ def watch(target_directory: str, remote: str) -> None:
console.print(
f"🪵 View logs for your deployment at {_format_link(service.logs_url)}"
)
remote_provider.sync_truss_to_dev_version_by_name(
model_name, target_directory, console, error_console
)

if not os.path.isfile(target_directory):
remote_provider.sync_truss_to_dev_version_by_name(
model_name, target_directory, console, error_console
)
else:
# These imports are delayed, to handle pydantic v1 envs gracefully.
from truss_chains.deployment import deployment_client

deployment_client.watch_model(
source=Path(target_directory),
model_name=model_name,
remote_provider=remote_provider,
console=console,
error_console=error_console,
)


# Chains Stuff #########################################################################
Expand Down
38 changes: 7 additions & 31 deletions truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import enum
import logging
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Type
Expand Down Expand Up @@ -406,16 +405,8 @@ def sync_truss_to_dev_version_by_name(
"No development model found. Run `truss push` then try again."
)

patch_path = Path(target_directory)
truss_ignore_patterns = load_trussignore_patterns_from_truss_dir(patch_path)

# For trusses that use the new chains DX, we need to watch the original source code
# but patch generated code over to the control server.
watch_path = patch_path
patch_fn = self._patch_model
if os.path.isfile(watch_path):
patch_fn = self._patch_code_gen_model
watch_path = patch_path.absolute().parent
watch_path = Path(target_directory)
truss_ignore_patterns = load_trussignore_patterns_from_truss_dir(watch_path)

def watch_filter(_, path):
return not is_ignored(Path(path), truss_ignore_patterns)
Expand All @@ -424,11 +415,11 @@ def watch_filter(_, path):
logging.getLogger("watchfiles.main").disabled = True

console.print(f"🚰 Attempting to sync truss at '{watch_path}' with remote")
patch_fn(patch_path, truss_ignore_patterns, console, error_console)
self.patch_model(watch_path, truss_ignore_patterns, console, error_console)

console.print(f"👀 Watching for changes to truss at '{watch_path}' ...")
for _ in watch(watch_path, watch_filter=watch_filter, raise_interrupt=False):
patch_fn(patch_path, truss_ignore_patterns, console, error_console)
self.patch_model(watch_path, truss_ignore_patterns, console, error_console)

def _patch(
self,
Expand Down Expand Up @@ -551,29 +542,14 @@ def do_patch():
),
)

def _patch_code_gen_model(
self,
patch_path: Path,
truss_ignore_patterns: List[str],
console: "rich_console.Console",
error_console: "rich_console.Console",
):
# These imports are delayed, to handle pydantic v1 envs gracefully.
from truss_chains.deployment import code_gen

gen_truss_path = code_gen.gen_truss_model_from_source(patch_path)
return self._patch_model(
gen_truss_path, truss_ignore_patterns, console, error_console
)

def _patch_model(
def patch_model(
self,
patch_path: Path,
watch_path: Path,
truss_ignore_patterns: List[str],
console: "rich_console.Console",
error_console: "rich_console.Console",
):
result = self._patch(patch_path, truss_ignore_patterns)
result = self._patch(watch_path, truss_ignore_patterns)
if result.status in (PatchStatus.SUCCESS, PatchStatus.SKIPPED):
console.print(result.message, style="green")
else:
Expand Down

0 comments on commit 98c86e3

Please sign in to comment.