diff --git a/truss-chains/truss_chains/deployment/deployment_client.py b/truss-chains/truss_chains/deployment/deployment_client.py index 977a960c0..dc594ce56 100644 --- a/truss-chains/truss_chains/deployment/deployment_client.py +++ b/truss-chains/truss_chains/deployment/deployment_client.py @@ -29,6 +29,7 @@ from truss.remote.baseten import custom_types as b10_types from truss.remote.baseten import remote as b10_remote from truss.remote.baseten import service as b10_service +from truss.remote.baseten.error import RemoteError from truss.truss_handle import truss_handle from truss.util import log_utils from truss.util import path as truss_path @@ -495,6 +496,114 @@ def _create_chains_secret_if_missing(remote_provider: b10_remote.BasetenRemote) # Watch / Live Patching ################################################################ +def _create_watch_filter(root_dir: pathlib.Path): + ignore_patterns = truss_path.load_trussignore_patterns_from_truss_dir(root_dir) + + def watch_filter(_: watchfiles.Change, path: str) -> bool: + return not truss_path.is_ignored(pathlib.Path(path), ignore_patterns) + + logging.getLogger("watchfiles.main").disabled = True + return ignore_patterns, watch_filter + + +def _handle_intercepted_logs(logs: list[str], console: "rich_console.Console"): + if logs: + formatted_logs = textwrap.indent("\n".join(logs), " " * 4) + console.print(f"Intercepted logs from importing source code:\n{formatted_logs}") + + +def _handle_import_error( + exception: Exception, + console: "rich_console.Console", + error_console: "rich_console.Console", + stack_trace: Optional[str] = None, +): + error_console.print( + "Source files were changed, but pre-conditions for " + "live patching are not given. Most likely there is a " + "syntax error in the source files or names changed. " + "Try to fix the issue and save the file. Error:\n" + f"{textwrap.indent(str(exception), ' ' * 4)}" + ) + if stack_trace: + error_console.print(stack_trace) + + console.print( + "The watcher will continue and if you can resolve the " + "issue, subsequent patches might succeed.", + style="blue", + ) + + +class _ModelWatcher: + _source: pathlib.Path + _model_name: str + _remote_provider: b10_remote.BasetenRemote + _ignore_patterns: list[str] + _watch_filter: Callable[[watchfiles.Change, str], bool] + _console: "rich_console.Console" + _error_console: "rich_console.Console" + + def __init__( + self, + source: pathlib.Path, + model_name: str, + remote_provider: b10_remote.BasetenRemote, + console: "rich_console.Console", + error_console: "rich_console.Console", + ) -> None: + self._source = source + self._model_name = model_name + self._remote_provider = remote_provider + self._console = console + self._error_console = error_console + self._ignore_patterns, self._watch_filter = _create_watch_filter( + source.absolute().parent + ) + + dev_version = b10_core.get_dev_version(self._remote_provider.api, model_name) + if not dev_version: + raise RemoteError( + "No development model found. Run `truss push` then try again." + ) + + def _patch(self) -> None: + exception_raised = None + with log_utils.LogInterceptor() as log_interceptor, self._console.status( + " Live Patching Model.\n", spinner="arrow3" + ): + try: + gen_truss_path = code_gen.gen_truss_model_from_source(self._source) + return self._remote_provider.patch_model( + gen_truss_path, + self._ignore_patterns, + self._console, + self._error_console, + ) + except Exception as e: + exception_raised = e + finally: + logs = log_interceptor.get_logs() + + _handle_intercepted_logs(logs, self._console) + if exception_raised: + _handle_import_error(exception_raised, self._console, self._error_console) + + def watch(self) -> None: + # Perform one initial patch at startup. + self._patch() + self._console.print("👀 Watching for new changes.", style="blue") + + # TODO(nikhil): Improve detection of directory structure, since right now + # we assume a flat structure + root_dir = self._source.absolute().parent + for _ in watchfiles.watch( + root_dir, watch_filter=self._watch_filter, raise_interrupt=False + ): + self._patch() + self._console.print("👀 Watching for new changes.", style="blue") + + class _Watcher: _source: pathlib.Path _entrypoint: Optional[str] @@ -573,16 +682,10 @@ def __init__( self._chainlet_data = {c.name: c for c in deployed_chainlets} self._assert_chainlet_names_same(chainlet_names) - self._ignore_patterns = truss_path.load_trussignore_patterns_from_truss_dir( + self._ignore_patterns, self._watch_filter = _create_watch_filter( self._chain_root ) - def watch_filter(_: watchfiles.Change, path: str) -> bool: - return not truss_path.is_ignored(pathlib.Path(path), self._ignore_patterns) - - logging.getLogger("watchfiles.main").disabled = True - self._watch_filter = watch_filter - @property def _original_chainlet_names(self) -> set[str]: return set(self._chainlet_data.keys()) @@ -665,27 +768,13 @@ def _patch(self, executor: concurrent.futures.Executor) -> None: finally: logs = log_interceptor.get_logs() - if logs: - formatted_logs = textwrap.indent("\n".join(logs), " " * 4) - self._console.print( - f"Intercepted logs from importing chain source code:\n{formatted_logs}" - ) - + _handle_intercepted_logs(logs, self._console) if exception_raised: - self._error_console.print( - "Source files were changed, but pre-conditions for " - "live patching are not given. Most likely there is a " - "syntax in the source files or chainlet names changed. " - "Try to fix the issue and save the file. Error:\n" - f"{textwrap.indent(str(exception_raised), ' ' * 4)}" - ) - if self._show_stack_trace: - self._error_console.print(stack_trace) - - self._console.print( - "The watcher will continue and if you can resolve the " - "issue, subsequent patches might succeed.", - style="blue", + _handle_import_error( + exception_raised, + self._console, + self._error_console, + stack_trace=stack_trace if self._show_stack_trace else None, ) return @@ -775,3 +864,20 @@ def watch( included_chainlets, ) patcher.watch() + + +def watch_model( + source: pathlib.Path, + model_name: str, + remote_provider: b10_remote.TrussRemote, + console: "rich_console.Console", + error_console: "rich_console.Console", +): + patcher = _ModelWatcher( + source=source, + model_name=model_name, + remote_provider=cast(b10_remote.BasetenRemote, remote_provider), + console=console, + error_console=error_console, + ) + patcher.watch() diff --git a/truss/cli/cli.py b/truss/cli/cli.py index 59cc91705..39fa8fee0 100644 --- a/truss/cli/cli.py +++ b/truss/cli/cli.py @@ -397,9 +397,22 @@ def watch(target_directory: str, remote: str) -> None: console.print( f"🪵 View logs for your deployment at {_format_link(service.logs_url)}" ) - remote_provider.sync_truss_to_dev_version_by_name( - model_name, target_directory, console, error_console - ) + + if not os.path.isfile(target_directory): + remote_provider.sync_truss_to_dev_version_by_name( + model_name, target_directory, console, error_console + ) + else: + # These imports are delayed, to handle pydantic v1 envs gracefully. + from truss_chains.deployment import deployment_client + + deployment_client.watch_model( + source=Path(target_directory), + model_name=model_name, + remote_provider=remote_provider, + console=console, + error_console=error_console, + ) # Chains Stuff ######################################################################### diff --git a/truss/remote/baseten/remote.py b/truss/remote/baseten/remote.py index 7d5bbe9cb..1f4d0fe13 100644 --- a/truss/remote/baseten/remote.py +++ b/truss/remote/baseten/remote.py @@ -1,6 +1,5 @@ import enum import logging -import os import re from pathlib import Path from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Type @@ -406,16 +405,8 @@ def sync_truss_to_dev_version_by_name( "No development model found. Run `truss push` then try again." ) - patch_path = Path(target_directory) - truss_ignore_patterns = load_trussignore_patterns_from_truss_dir(patch_path) - - # For trusses that use the new chains DX, we need to watch the original source code - # but patch generated code over to the control server. - watch_path = patch_path - patch_fn = self._patch_model - if os.path.isfile(watch_path): - patch_fn = self._patch_code_gen_model - watch_path = patch_path.absolute().parent + watch_path = Path(target_directory) + truss_ignore_patterns = load_trussignore_patterns_from_truss_dir(watch_path) def watch_filter(_, path): return not is_ignored(Path(path), truss_ignore_patterns) @@ -424,11 +415,11 @@ def watch_filter(_, path): logging.getLogger("watchfiles.main").disabled = True console.print(f"🚰 Attempting to sync truss at '{watch_path}' with remote") - patch_fn(patch_path, truss_ignore_patterns, console, error_console) + self.patch_model(watch_path, truss_ignore_patterns, console, error_console) console.print(f"👀 Watching for changes to truss at '{watch_path}' ...") for _ in watch(watch_path, watch_filter=watch_filter, raise_interrupt=False): - patch_fn(patch_path, truss_ignore_patterns, console, error_console) + self.patch_model(watch_path, truss_ignore_patterns, console, error_console) def _patch( self, @@ -551,29 +542,14 @@ def do_patch(): ), ) - def _patch_code_gen_model( - self, - patch_path: Path, - truss_ignore_patterns: List[str], - console: "rich_console.Console", - error_console: "rich_console.Console", - ): - # These imports are delayed, to handle pydantic v1 envs gracefully. - from truss_chains.deployment import code_gen - - gen_truss_path = code_gen.gen_truss_model_from_source(patch_path) - return self._patch_model( - gen_truss_path, truss_ignore_patterns, console, error_console - ) - - def _patch_model( + def patch_model( self, - patch_path: Path, + watch_path: Path, truss_ignore_patterns: List[str], console: "rich_console.Console", error_console: "rich_console.Console", ): - result = self._patch(patch_path, truss_ignore_patterns) + result = self._patch(watch_path, truss_ignore_patterns) if result.status in (PatchStatus.SUCCESS, PatchStatus.SKIPPED): console.print(result.message, style="green") else: