diff --git a/truss-chains/tests/test_e2e.py b/truss-chains/tests/test_e2e.py
index 86da35eff..c572f182e 100644
--- a/truss-chains/tests/test_e2e.py
+++ b/truss-chains/tests/test_e2e.py
@@ -6,6 +6,7 @@ import pytest
 import requests
 
 from truss.tests.test_testing_utilities_for_other_tests import ensure_kill_all
+from truss.truss_handle.build import load_from_code_config
 from truss_chains import definitions, framework, public_api, utils
 from truss_chains.deployment import deployment_client
 
@@ -270,3 +271,27 @@ async def test_timeout():
     assert re.match(
         sync_error_regex.strip(), sync_error_str.strip(), re.MULTILINE
     ), sync_error_str
+
+
+@pytest.mark.integration
+def test_traditional_truss():
+    with ensure_kill_all():
+        chain_root = TEST_ROOT / "traditional_truss" / "truss_model.py"
+        truss_handle = load_from_code_config(chain_root)
+
+        assert truss_handle.spec.config.resources.cpu == "4"
+        assert truss_handle.spec.config.model_name == "OverridePassthroughModelName"
+
+        port = utils.get_free_port()
+        truss_handle.docker_run(
+            local_port=port,
+            detach=True,
+            network="host",
+        )
+
+        response = requests.post(
+            f"http://localhost:{port}/v1/models/model:predict",
+            json={"call_count_increment": 5},
+        )
+        assert response.status_code == 200
+        assert response.json() == 5
diff --git a/truss-chains/tests/traditional_truss/truss_model.py b/truss-chains/tests/traditional_truss/truss_model.py
new file mode 100644
index 000000000..12cf5bb85
--- /dev/null
+++ b/truss-chains/tests/traditional_truss/truss_model.py
@@ -0,0 +1,20 @@
+import truss_chains as chains
+
+
+class PassthroughModel(chains.ModelBase):
+    remote_config: chains.RemoteConfig = chains.RemoteConfig(  # type: ignore
+        compute=chains.Compute(4, "1Gi"),
+        name="OverridePassthroughModelName",
+        docker_image=chains.DockerImage(
+            pip_requirements=[
+                "truss==0.9.59rc2",
+            ]
+        ),
+    )
+
+    def __init__(self):
+        self._call_count = 0
+
+    async def run_remote(self, call_count_increment: int) -> int:
+        self._call_count += call_count_increment
+        return self._call_count
diff --git a/truss-chains/truss_chains/__init__.py b/truss-chains/truss_chains/__init__.py
index 252218e41..0f2c37013 100644
--- a/truss-chains/truss_chains/__init__.py
+++ b/truss-chains/truss_chains/__init__.py
@@ -35,6 +35,7 @@
 )
 from truss_chains.public_api import (
     ChainletBase,
+    ModelBase,
     depends,
     depends_context,
     mark_entrypoint,
@@ -50,6 +51,7 @@
     "Assets",
     "BasetenImage",
     "ChainletBase",
+    "ModelBase",
     "ChainletOptions",
     "Compute",
     "CustomImage",
diff --git a/truss-chains/truss_chains/deployment/code_gen.py b/truss-chains/truss_chains/deployment/code_gen.py
index e569c0318..cd80f9d66 100644
--- a/truss-chains/truss_chains/deployment/code_gen.py
+++ b/truss-chains/truss_chains/deployment/code_gen.py
@@ -31,6 +31,7 @@
 import shutil
 import subprocess
 import sys
+import tempfile
 import textwrap
 from typing import Any, Iterable, Mapping, Optional, get_args, get_origin
 
@@ -648,13 +649,13 @@ def _inplace_fill_base_image(
     )
 
 
-def _make_truss_config(
+def write_truss_config_yaml(
     chainlet_dir: pathlib.Path,
     chains_config: definitions.RemoteConfig,
     chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
     model_name: str,
     use_local_chains_src: bool,
-) -> truss_config.TrussConfig:
+):
     """Generate a truss config for a Chainlet."""
     config = truss_config.TrussConfig()
     config.model_name = model_name
@@ -707,16 +708,14 @@ def _make_truss_config(
     config.write_to_yaml_file(
         chainlet_dir / serving_image_builder.CONFIG_FILE, verbose=True
     )
-    return config
 
 
 def gen_truss_chainlet(
     chain_root: pathlib.Path,
-    gen_root: pathlib.Path,
     chain_name: str,
     chainlet_descriptor: definitions.ChainletAPIDescriptor,
-    model_name: str,
-    use_local_chains_src: bool,
+    model_name: Optional[str] = None,
+    use_local_chains_src: bool = False,
 ) -> pathlib.Path:
     # Filter needed services and customize options.
     dep_services = {}
@@ -726,17 +725,18 @@ def gen_truss_chainlet(
             display_name=dep.display_name,
             options=dep.options,
         )
+    gen_root = pathlib.Path(tempfile.gettempdir())
     chainlet_dir = _make_chainlet_dir(chain_name, chainlet_descriptor, gen_root)
     logging.info(
         f"Code generation for Chainlet `{chainlet_descriptor.name}` "
         f"in `{chainlet_dir}`."
     )
-    _make_truss_config(
-        chainlet_dir,
-        chainlet_descriptor.chainlet_cls.remote_config,
-        dep_services,
-        model_name,
-        use_local_chains_src,
+    write_truss_config_yaml(
+        chainlet_dir=chainlet_dir,
+        chains_config=chainlet_descriptor.chainlet_cls.remote_config,
+        model_name=model_name or chain_name,
+        chainlet_to_service=dep_services,
+        use_local_chains_src=use_local_chains_src,
     )
     # This assumes all imports are absolute w.r.t chain root (or site-packages).
     truss_path.copy_tree_path(
diff --git a/truss-chains/truss_chains/deployment/deployment_client.py b/truss-chains/truss_chains/deployment/deployment_client.py
index 9f4290250..ba954d661 100644
--- a/truss-chains/truss_chains/deployment/deployment_client.py
+++ b/truss-chains/truss_chains/deployment/deployment_client.py
@@ -4,7 +4,6 @@
 import json
 import logging
 import pathlib
-import tempfile
 import textwrap
 import traceback
 import uuid
@@ -138,10 +137,8 @@ class _ChainSourceGenerator:
     def __init__(
         self,
         options: definitions.PushOptions,
-        gen_root: pathlib.Path,
     ) -> None:
         self._options = options
-        self._gen_root = gen_root or pathlib.Path(tempfile.gettempdir())
 
     @property
     def _use_local_chains_src(self) -> bool:
@@ -175,7 +172,6 @@ def generate_chainlet_artifacts(
 
             chainlet_dir = code_gen.gen_truss_chainlet(
                 chain_root,
-                self._gen_root,
                 self._options.chain_name,
                 chainlet_descriptor,
                 model_name,
@@ -205,11 +201,10 @@
 def push(
     entrypoint: Type[definitions.ABCChainlet],
     options: definitions.PushOptions,
-    gen_root: pathlib.Path = pathlib.Path(tempfile.gettempdir()),
     progress_bar: Optional[Type["progress.Progress"]] = None,
 ) -> Optional[ChainService]:
     entrypoint_artifact, dependency_artifacts = _ChainSourceGenerator(
-        options, gen_root
+        options
     ).generate_chainlet_artifacts(
         entrypoint,
     )
@@ -632,7 +627,6 @@ def _code_gen_and_patch_thread(
         # TODO: Maybe try-except code_gen errors explicitly.
         chainlet_dir = code_gen.gen_truss_chainlet(
             self._chain_root,
-            pathlib.Path(tempfile.gettempdir()),
             self._deployed_chain_name,
             descr,
             self._chainlet_data[descr.display_name].oracle_name,
diff --git a/truss-chains/truss_chains/framework.py b/truss-chains/truss_chains/framework.py
index 34e314e0d..87386b477 100644
--- a/truss-chains/truss_chains/framework.py
+++ b/truss-chains/truss_chains/framework.py
@@ -14,6 +14,7 @@
 import sys
 import types
 import warnings
+from importlib.abc import Loader
 from typing import (
     Any,
     Callable,
@@ -32,9 +33,11 @@
 )
 
 import pydantic
+from truss.truss_handle.truss_handle import TrussHandle
 from typing_extensions import ParamSpec
 
 from truss_chains import definitions, utils
+from truss_chains.deployment import code_gen
 
 _SIMPLE_TYPES = {int, float, complex, bool, str, bytes, None, pydantic.BaseModel}
 _SIMPLE_CONTAINERS = {list, dict}
@@ -725,19 +728,10 @@ def _validate_dependencies(
     return dependencies
 
 
-def _validate_chainlet_cls(
-    cls: Type[definitions.ABCChainlet], location: _ErrorLocation
-) -> None:
-    if not hasattr(cls, definitions.REMOTE_CONFIG_NAME):
-        _collect_error(
-            f"Chainlets must have a `{definitions.REMOTE_CONFIG_NAME}` class variable "
-            f"`{definitions.REMOTE_CONFIG_NAME} = {definitions.RemoteConfig.__name__}"
-            f"(...)`. Missing for `{cls}`.",
-            _ErrorKind.MISSING_API_ERROR,
-            location,
-        )
-        return
-
+def _validate_remote_config(
+    cls: Type[definitions.ABCChainlet],
+    location: _ErrorLocation,
+):
     if not isinstance(
         remote_config := getattr(cls, definitions.REMOTE_CONFIG_NAME),
         definitions.RemoteConfig,
@@ -749,10 +743,9 @@
             _ErrorKind.TYPE_ERROR,
             location,
         )
-        return
 
 
-def validate_and_register_class(cls: Type[definitions.ABCChainlet]) -> None:
+def validate_and_register_chain(cls: Type[definitions.ABCChainlet]) -> None:
     """Note that validation errors will only be collected, not raised, and Chainlets.
     with issues, are still added to the registry. Use `raise_validation_errors` to
     assert all Chainlets are valid and before performing operations that depend on
@@ -761,7 +754,7 @@ def validate_and_register_chain(cls: Type[definitions.ABCChainlet]) -> None:
     line = inspect.getsourcelines(cls)[1]
     location = _ErrorLocation(src_path=src_path, line=line, chainlet_name=cls.__name__)
 
-    _validate_chainlet_cls(cls, location)
+    _validate_remote_config(cls, location)
     init_validator = _ChainletInitValidator(cls, location)
     chainlet_descriptor = definitions.ChainletAPIDescriptor(
         chainlet_cls=cls,
@@ -1169,50 +1162,18 @@ def _get_entrypoint_chainlets(symbols) -> set[Type[definitions.ABCChainlet]]:
 
 @contextlib.contextmanager
 def import_target(
-    module_path: pathlib.Path, target_name: Optional[str]
+    module_path: pathlib.Path, target_name: Optional[str] = None
 ) -> Iterator[Type[definitions.ABCChainlet]]:
-    """The context manager ensures that modules imported by the chain and
-    Chainlets registered in ``_global_chainlet_registry`` are removed upon exit.
-
-    I.e. aiming at making the import idempotent for common usages, although there could
-    be additional side effects not accounted for by this implementation."""
-    module_path = pathlib.Path(module_path).resolve()
-    module_name = module_path.stem  # Use the file's name as the module name
-    if not os.path.isfile(module_path):
-        raise ImportError(
-            f"`{module_path}` is not a file. You must point to a python file where "
-            "the entrypoint Chainlet is defined."
-        )
-
-    import_error_msg = f"Could not import `{module_path}`. Check path."
-    spec = importlib.util.spec_from_file_location(module_name, module_path)
-    if not spec:
-        raise ImportError(import_error_msg)
-    if not spec.loader:
-        raise ImportError(import_error_msg)
-
-    module = importlib.util.module_from_spec(spec)
-    module.__file__ = str(module_path)
-    # Since the framework depends on tracking the source files via `inspect` and this
-    # depends on the modules bein properly registered in `sys.modules`, we have to
-    # manually do this here (because importlib does not do it automatically). This
-    # registration has to stay at least until the push command has finished.
-    if module_name in sys.modules:
-        raise ImportError(
-            f"{import_error_msg} There is already a module in `sys.modules` "
-            f"with name `{module_name}`. Overwriting that value is unsafe. "
-            "Try renaming your source file."
-        )
+    resolved_module_path = pathlib.Path(module_path).resolve()
+    module, loader = _load_module(module_path)
     modules_before = set(sys.modules.keys())
-    sys.modules[module_name] = module
-    # Add path for making absolute imports relative to the source_module's dir.
-    sys.path.insert(0, str(module_path.parent))
+    modules_after = set()
+
     chainlets_before = _global_chainlet_registry.get_chainlet_names()
     chainlets_after = set()
-    modules_after = set()
+
     try:
         try:
-            spec.loader.exec_module(module)
+            loader.exec_module(module)
             raise_validation_errors()
         finally:
             modules_after = set(sys.modules.keys())
@@ -1223,7 +1184,7 @@ def import_target(
             if not target_cls:
                 raise AttributeError(
                     f"Target Chainlet class `{target_name}` not found "
-                    f"in `{module_path}`."
+                    f"in `{resolved_module_path}`."
                 )
             if not utils.issubclass_safe(target_cls, definitions.ABCChainlet):
                 raise TypeError(
@@ -1235,13 +1196,13 @@ def import_target(
             if len(entrypoints) == 0:
                 raise ValueError(
                     "No `target_name` was specified and no Chainlet in "
-                    "`{module_path}` was tagged with `@chains.mark_entrypoint`. Tag "
+                    f"`{module_path}` was tagged with `@chains.mark_entrypoint`. Tag "
                     "one Chainlet or provide the Chainlet class name."
                 )
             elif len(entrypoints) > 1:
                 raise ValueError(
                     "`target_name` was not specified and multiple Chainlets in "
-                    f"`{module_path}` were tagged with `@chains.mark_entrypoint`. Tag "
+                    f"`{resolved_module_path}` were tagged with `@chains.mark_entrypoint`. Tag "
                     "one Chainlet or provide the Chainlet class name. Found Chainlets: "
                     f"\n{list(cls.name for cls in entrypoints)}"
                 )
@@ -1253,30 +1214,85 @@ def import_target(
 
         yield target_cls
     finally:
+        _cleanup_module_imports(modules_before, modules_after, resolved_module_path)
         for chainlet_name in chainlets_after - chainlets_before:
             _global_chainlet_registry.unregister_chainlet(chainlet_name)
-        modules_diff = modules_after - modules_before
-        # Apparently torch import leaves some side effects that cannot be reverted
-        # by deleting the modules and would lead to a crash when another import
-        # is attempted. Since torch is a common lib, we make this explicit special
-        # case and just leave those modules.
-        # TODO: this seems still brittle and other modules might cause similar problems.
-        # it would be good to find a more principled solution.
-        modules_to_delete = {
-            s for s in modules_diff if not (s.startswith("torch.") or s == "torch")
-        }
-        if torch_modules := modules_diff - modules_to_delete:
-            logging.debug(
-                f"Keeping torch modules after import context: {torch_modules}"
-            )
-        logging.debug(
-            f"Deleting modules when exiting import context: {modules_to_delete}"
-        )
-        for mod in modules_to_delete:
-            del sys.modules[mod]
-        try:
-            sys.path.remove(str(module_path.parent))
-        except ValueError:  # In case the value was already removed for whatever reason.
-            pass
+
+
+def _load_module(
+    module_path: pathlib.Path,
+) -> tuple[types.ModuleType, Loader]:
+    """The context manager ensures that modules imported by the Model/Chain
+    are removed upon exit.
+
+    I.e. aiming at making the import idempotent for common usages, although there could
+    be additional side effects not accounted for by this implementation."""
+    module_name = module_path.stem  # Use the file's name as the module name
+    if not os.path.isfile(module_path):
+        raise ImportError(
+            f"`{module_path}` is not a file. You must point to a python file where "
+            f"the entrypoint Chainlet is defined."
+        )
+
+    import_error_msg = f"Could not import `{module_path}`. Check path."
+    spec = importlib.util.spec_from_file_location(module_name, module_path)
+    if not spec or not spec.loader:
+        raise ImportError(import_error_msg)
+
+    module = importlib.util.module_from_spec(spec)
+    module.__file__ = str(module_path)
+    # Since the framework depends on tracking the source files via `inspect` and this
+    # depends on the modules bein properly registered in `sys.modules`, we have to
+    # manually do this here (because importlib does not do it automatically). This
+    # registration has to stay at least until the push command has finished.
+    if module_name in sys.modules:
+        raise ImportError(
+            f"{import_error_msg} There is already a module in `sys.modules` "
+            f"with name `{module_name}`. Overwriting that value is unsafe. "
+            "Try renaming your source file."
+        )
+
+    sys.modules[module_name] = module
+    # Add path for making absolute imports relative to the source_module's dir.
+    sys.path.insert(0, str(module_path.parent))
+
+    return module, spec.loader
+
+
+# Ensures the loaded system modules are restored to before we executed user defined code.
+def _cleanup_module_imports(
+    modules_before: set[str], modules_after: set[str], module_path: pathlib.Path
+):
+    modules_diff = modules_after - modules_before
+    # Apparently torch import leaves some side effects that cannot be reverted
+    # by deleting the modules and would lead to a crash when another import
+    # is attempted. Since torch is a common lib, we make this explicit special
+    # case and just leave those modules.
+    # TODO: this seems still brittle and other modules might cause similar problems.
+    # it would be good to find a more principled solution.
+    modules_to_delete = {
+        s for s in modules_diff if not (s.startswith("torch.") or s == "torch")
+    }
+    if torch_modules := modules_diff - modules_to_delete:
+        logging.debug(f"Keeping torch modules after import context: {torch_modules}")
+
+    logging.debug(f"Deleting modules when exiting import context: {modules_to_delete}")
+    for mod in modules_to_delete:
+        del sys.modules[mod]
+
+
+# NB(nikhil): Generates a TrussHandle whose spec points to a generated
+# directory that contains data dumped from the configuration in code.
+def truss_handle_from_code_config(model_file: pathlib.Path) -> TrussHandle:
+    # TODO(nikhil): Improve detection of directory structure, since right now
+    # we assume a flat structure
+    root_dir = model_file.absolute().parent
+    with import_target(model_file) as entrypoint_cls:
+        descriptor = _global_chainlet_registry.get_descriptor(entrypoint_cls)
+        generated_dir = code_gen.gen_truss_chainlet(
+            chain_root=root_dir,
+            chain_name=entrypoint_cls.display_name,
+            chainlet_descriptor=descriptor,
+        )
+
+    return TrussHandle(truss_dir=generated_dir)
diff --git a/truss-chains/truss_chains/public_api.py b/truss-chains/truss_chains/public_api.py
index ecb1be35a..9545c6945 100644
--- a/truss-chains/truss_chains/public_api.py
+++ b/truss-chains/truss_chains/public_api.py
@@ -107,7 +107,7 @@ def __init_subclass__(cls, **kwargs) -> None:
         # Each sub-class has own, isolated metadata, e.g. we don't want
         # `mark_entrypoint` to propagate to subclasses.
         cls.meta_data = definitions.ChainletMetadata()
-        framework.validate_and_register_class(cls)  # Errors are collected, not raised!
+        framework.validate_and_register_chain(cls)  # Errors are collected, not raised!
         # For default init (from `object`) we don't need to check anything.
         if cls.has_custom_init():
             original_init = cls.__init__
@@ -122,6 +122,19 @@ def __init_with_arg_check__(self, *args, **kwargs):
             cls.__init__ = __init_with_arg_check__  # type: ignore[method-assign]
 
 
+class ModelBase(definitions.ABCChainlet):
+    """Base class for all singular truss models.
+
+    Inheriting from this class adds validations to make sure subclasses adhere to the
+    truss model pattern.
+    """
+
+    def __init_subclass__(cls, **kwargs) -> None:
+        super().__init_subclass__(**kwargs)
+        cls.meta_data = definitions.ChainletMetadata(is_entrypoint=True)
+        framework.validate_and_register_chain(cls)
+
+
 @overload
 def mark_entrypoint(
     cls_or_chain_name: Type[framework.ChainletT],
diff --git a/truss-chains/truss_chains/remote_chainlet/utils.py b/truss-chains/truss_chains/remote_chainlet/utils.py
index ebb6f6244..1ce1c5e08 100644
--- a/truss-chains/truss_chains/remote_chainlet/utils.py
+++ b/truss-chains/truss_chains/remote_chainlet/utils.py
@@ -31,6 +31,9 @@ def populate_chainlet_service_predict_urls(
     chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
 ) -> Mapping[str, definitions.DeployedServiceDescriptor]:
     chainlet_to_deployed_service: Dict[str, definitions.DeployedServiceDescriptor] = {}
+    # If there are no dependencies of this chainlet, no need to derive dynamic URLs
+    if len(chainlet_to_service) == 0:
+        return chainlet_to_deployed_service
 
     dynamic_chainlet_config_str = dynamic_config_resolver.get_dynamic_config_value_sync(
         definitions.DYNAMIC_CHAINLET_CONFIG_KEY
diff --git a/truss/cli/cli.py b/truss/cli/cli.py
index a2ff6fec3..5494fc0e7 100644
--- a/truss/cli/cli.py
+++ b/truss/cli/cli.py
@@ -50,7 +50,10 @@
 )
 from truss.truss_handle.build import cleanup as _cleanup
 from truss.truss_handle.build import init as _init
-from truss.truss_handle.build import load
+from truss.truss_handle.build import (
+    load,
+    load_from_code_config,
+)
 from truss.util import docker
 from truss.util.log_utils import LogInterceptor
 
@@ -1132,11 +1135,11 @@ def push(
 
     TARGET_DIRECTORY: A Truss directory. If none, use current directory.
""" + if not remote: remote = inquire_remote_name(RemoteFactory.get_available_config_names()) remote_provider = RemoteFactory.create(remote=remote) - tr = _get_truss_from_directory(target_directory=target_directory) model_name = model_name or tr.spec.config.model_name @@ -1337,7 +1340,11 @@ def _get_truss_from_directory(target_directory: Optional[str] = None): """Gets Truss from directory. If none, use the current directory""" if target_directory is None: target_directory = os.getcwd() - return load(target_directory) + if not os.path.isfile(target_directory): + return load(target_directory) + # NB(nikhil): if target_directory points to a specific file, assume they are using + # the Python driven DX for configuring a truss + return load_from_code_config(Path(target_directory)) truss_cli.add_command(container) diff --git a/truss/remote/baseten/core.py b/truss/remote/baseten/core.py index 20a90420a..b1c2f549b 100644 --- a/truss/remote/baseten/core.py +++ b/truss/remote/baseten/core.py @@ -280,7 +280,7 @@ def archive_truss( Returns: A file-like object containing the tar file """ - truss_dir = truss_handle._spec.truss_dir + truss_dir = truss_handle._truss_dir # check for a truss_ignore file and read the ignore patterns if it exists ignore_patterns = load_trussignore_patterns_from_truss_dir(truss_dir) diff --git a/truss/remote/baseten/remote.py b/truss/remote/baseten/remote.py index 660b671e2..ef16cea6c 100644 --- a/truss/remote/baseten/remote.py +++ b/truss/remote/baseten/remote.py @@ -139,9 +139,10 @@ def _prepare_push( if model_name.isspace(): raise ValueError("Model name cannot be empty") - gathered_truss = TrussHandle(truss_handle.gather()) + if truss_handle.is_scattered(): + truss_handle = TrussHandle(truss_handle.gather()) - if gathered_truss.spec.model_server != ModelServer.TrussServer: + if truss_handle.spec.model_server != ModelServer.TrussServer: publish = True if promote: @@ -176,10 +177,10 @@ def _prepare_push( if model_id is not None and disable_truss_download: raise ValueError("disable-truss-download can only be used for new models") - temp_file = archive_truss(gathered_truss, progress_bar) + temp_file = archive_truss(truss_handle, progress_bar) s3_key = upload_truss(self._api, temp_file, progress_bar) encoded_config_str = base64_encoded_json_str( - gathered_truss._spec._config.to_dict() + truss_handle._spec._config.to_dict() ) validate_truss_config(self._api, encoded_config_str) diff --git a/truss/truss_handle/build.py b/truss/truss_handle/build.py index 5eae8821d..ccb3978e7 100644 --- a/truss/truss_handle/build.py +++ b/truss/truss_handle/build.py @@ -105,6 +105,13 @@ def load(truss_directory: str) -> TrussHandle: return TrussHandle(Path(truss_directory)) +def load_from_code_config(model_file: Path) -> TrussHandle: + # These imports are delayed, to handle pydantic v1 envs gracefully. + from truss_chains import framework + + return framework.truss_handle_from_code_config(model_file) + + def cleanup() -> None: """ Cleans up .truss directory. diff --git a/truss/truss_handle/truss_handle.py b/truss/truss_handle/truss_handle.py index f80f55ab9..c624d6b1b 100644 --- a/truss/truss_handle/truss_handle.py +++ b/truss/truss_handle/truss_handle.py @@ -106,7 +106,7 @@ def wait(self): class TrussHandle: def __init__(self, truss_dir: Path, validate: bool = True) -> None: self._truss_dir = truss_dir - self._spec = TrussSpec(truss_dir) + self._spec = TrussSpec(self._truss_dir) self._hash_for_mod_time: Optional[Tuple[float, str]] = None if validate: self.validate()