From a02e8e924de307b9dfdbf07d1c58c1f6fa5cac40 Mon Sep 17 00:00:00 2001 From: Nikhil Narayen Date: Tue, 14 Jan 2025 14:54:01 +0000 Subject: [PATCH 1/5] POC for new model DX --- truss-chains/truss_chains/__init__.py | 2 + truss-chains/truss_chains/definitions.py | 17 ++ .../truss_chains/deployment/code_gen.py | 5 +- truss-chains/truss_chains/framework.py | 151 ++++++++++++++++-- truss-chains/truss_chains/public_api.py | 12 ++ truss/cli/cli.py | 22 +++ 6 files changed, 194 insertions(+), 15 deletions(-) diff --git a/truss-chains/truss_chains/__init__.py b/truss-chains/truss_chains/__init__.py index 252218e41..0f2c37013 100644 --- a/truss-chains/truss_chains/__init__.py +++ b/truss-chains/truss_chains/__init__.py @@ -35,6 +35,7 @@ ) from truss_chains.public_api import ( ChainletBase, + ModelBase, depends, depends_context, mark_entrypoint, @@ -50,6 +51,7 @@ "Assets", "BasetenImage", "ChainletBase", + "ModelBase", "ChainletOptions", "Compute", "CustomImage", diff --git a/truss-chains/truss_chains/definitions.py b/truss-chains/truss_chains/definitions.py index fd7908c99..52b683956 100644 --- a/truss-chains/truss_chains/definitions.py +++ b/truss-chains/truss_chains/definitions.py @@ -529,6 +529,23 @@ def display_name(cls) -> str: # ... +class ABCModel(abc.ABC): + remote_config: ClassVar[RemoteConfig] = RemoteConfig() + + @classproperty + @classmethod + def name(cls) -> str: + return cls.__name__ + + @classproperty + @classmethod + def display_name(cls) -> str: + return cls.remote_config.name or cls.name + + @abc.abstractmethod + def predict(self, request: Any) -> Any: ... + + class TypeDescriptor(SafeModelNonSerializable): """For describing I/O types of Chainlets.""" diff --git a/truss-chains/truss_chains/deployment/code_gen.py b/truss-chains/truss_chains/deployment/code_gen.py index e569c0318..552e36a33 100644 --- a/truss-chains/truss_chains/deployment/code_gen.py +++ b/truss-chains/truss_chains/deployment/code_gen.py @@ -648,7 +648,7 @@ def _inplace_fill_base_image( ) -def _make_truss_config( +def write_truss_config_yaml( chainlet_dir: pathlib.Path, chains_config: definitions.RemoteConfig, chainlet_to_service: Mapping[str, definitions.ServiceDescriptor], @@ -707,7 +707,6 @@ def _make_truss_config( config.write_to_yaml_file( chainlet_dir / serving_image_builder.CONFIG_FILE, verbose=True ) - return config def gen_truss_chainlet( @@ -731,7 +730,7 @@ def gen_truss_chainlet( f"Code generation for Chainlet `{chainlet_descriptor.name}` " f"in `{chainlet_dir}`." ) - _make_truss_config( + write_truss_config_yaml( chainlet_dir, chainlet_descriptor.chainlet_cls.remote_config, dep_services, diff --git a/truss-chains/truss_chains/framework.py b/truss-chains/truss_chains/framework.py index 34e314e0d..e293f6492 100644 --- a/truss-chains/truss_chains/framework.py +++ b/truss-chains/truss_chains/framework.py @@ -776,6 +776,38 @@ def validate_and_register_class(cls: Type[definitions.ABCChainlet]) -> None: _global_chainlet_registry.register_chainlet(chainlet_descriptor) +def validate_base_model(cls: Type[definitions.ABCModel]) -> None: + # NB(nikhil): Following lines cause ERROR TypeError: is a built-in class + # src_path = os.path.abspath(inspect.getfile(cls)) + # line = inspect.getsourcelines(cls)[1] + location = _ErrorLocation(src_path="model.py", line=10, chainlet_name=cls.__name__) + + # NB(nikhil): This seems to pass even when my definition doesn't have remote_config, likely + # pulling from default + if not hasattr(cls, definitions.REMOTE_CONFIG_NAME): + _collect_error( + f"Models must have a `{definitions.REMOTE_CONFIG_NAME}` class variable " + f"`{definitions.REMOTE_CONFIG_NAME} = {definitions.RemoteConfig.__name__}" + f"(...)`. Missing for `{cls}`.", + _ErrorKind.MISSING_API_ERROR, + location, + ) + return + + if not isinstance( + remote_config := getattr(cls, definitions.REMOTE_CONFIG_NAME), + definitions.RemoteConfig, + ): + _collect_error( + f"Models must have a `{definitions.REMOTE_CONFIG_NAME}` class variable " + f"of type `{definitions.RemoteConfig}`. Got `{type(remote_config)}` " + f"for `{cls}`.", + _ErrorKind.TYPE_ERROR, + location, + ) + return + + # Dependency-Injection / Registry ###################################################### @@ -1176,23 +1208,23 @@ def import_target( I.e. aiming at making the import idempotent for common usages, although there could be additional side effects not accounted for by this implementation.""" - module_path = pathlib.Path(module_path).resolve() - module_name = module_path.stem # Use the file's name as the module name - if not os.path.isfile(module_path): + resolved_module_path = pathlib.Path(module_path).resolve() + module_name = resolved_module_path.stem # Use the file's name as the module name + if not os.path.isfile(resolved_module_path): raise ImportError( - f"`{module_path}` is not a file. You must point to a python file where " + f"`{resolved_module_path}` is not a file. You must point to a python file where " "the entrypoint Chainlet is defined." ) - import_error_msg = f"Could not import `{module_path}`. Check path." - spec = importlib.util.spec_from_file_location(module_name, module_path) + import_error_msg = f"Could not import `{resolved_module_path}`. Check path." + spec = importlib.util.spec_from_file_location(module_name, resolved_module_path) if not spec: raise ImportError(import_error_msg) if not spec.loader: raise ImportError(import_error_msg) module = importlib.util.module_from_spec(spec) - module.__file__ = str(module_path) + module.__file__ = str(resolved_module_path) # Since the framework depends on tracking the source files via `inspect` and this # depends on the modules bein properly registered in `sys.modules`, we have to # manually do this here (because importlib does not do it automatically). This @@ -1206,7 +1238,7 @@ def import_target( modules_before = set(sys.modules.keys()) sys.modules[module_name] = module # Add path for making absolute imports relative to the source_module's dir. - sys.path.insert(0, str(module_path.parent)) + sys.path.insert(0, str(resolved_module_path.parent)) chainlets_before = _global_chainlet_registry.get_chainlet_names() chainlets_after = set() modules_after = set() @@ -1223,7 +1255,7 @@ def import_target( if not target_cls: raise AttributeError( f"Target Chainlet class `{target_name}` not found " - f"in `{module_path}`." + f"in `{resolved_module_path}`." ) if not utils.issubclass_safe(target_cls, definitions.ABCChainlet): raise TypeError( @@ -1235,13 +1267,13 @@ def import_target( if len(entrypoints) == 0: raise ValueError( "No `target_name` was specified and no Chainlet in " - "`{module_path}` was tagged with `@chains.mark_entrypoint`. Tag " + f"`{module_path}` was tagged with `@chains.mark_entrypoint`. Tag " "one Chainlet or provide the Chainlet class name." ) elif len(entrypoints) > 1: raise ValueError( "`target_name` was not specified and multiple Chainlets in " - f"`{module_path}` were tagged with `@chains.mark_entrypoint`. Tag " + f"`{resolved_module_path}` were tagged with `@chains.mark_entrypoint`. Tag " "one Chainlet or provide the Chainlet class name. Found Chainlets: " f"\n{list(cls.name for cls in entrypoints)}" ) @@ -1277,6 +1309,101 @@ def import_target( for mod in modules_to_delete: del sys.modules[mod] try: - sys.path.remove(str(module_path.parent)) + sys.path.remove(str(resolved_module_path.parent)) except ValueError: # In case the value was already removed for whatever reason. pass + + +# NB(nikhil): mainly taken from above, but with some dependency logic removed +@contextlib.contextmanager +def import_model_target( + module_path: pathlib.Path, +) -> Iterator[Type[definitions.ABCModel]]: + resolved_module_path = pathlib.Path(module_path).resolve() + module_name = resolved_module_path.stem # Use the file's name as the module name + if not os.path.isfile(resolved_module_path): + raise ImportError( + f"`{resolved_module_path}` is not a file. You must point to a python file where " + "the Model is defined." + ) + + import_error_msg = f"Could not import `{resolved_module_path}`. Check path." + spec = importlib.util.spec_from_file_location(module_name, resolved_module_path) + if not spec: + raise ImportError(import_error_msg) + if not spec.loader: + raise ImportError(import_error_msg) + + module = importlib.util.module_from_spec(spec) + module.__file__ = str(resolved_module_path) + # Since the framework depends on tracking the source files via `inspect` and this + # depends on the modules bein properly registered in `sys.modules`, we have to + # manually do this here (because importlib does not do it automatically). This + # registration has to stay at least until the push command has finished. + if module_name in sys.modules: + raise ImportError( + f"{import_error_msg} There is already a module in `sys.modules` " + f"with name `{module_name}`. Overwriting that value is unsafe. " + "Try renaming your source file." + ) + + modules_before = set(sys.modules.keys()) + try: + try: + spec.loader.exec_module(module) + raise_validation_errors() + finally: + modules_after = set(sys.modules.keys()) + + module_vars = (getattr(module, name) for name in dir(module)) + models: set[Type[definitions.ABCModel]] = [ + sym + for sym in module_vars + if utils.issubclass_safe(sym, definitions.ABCModel) + ] + if len(models) == 0: + raise ValueError( + "No `target_name` was specified and no class in " + f"`{module_path}` extends `ModelBase`." + ) + elif len(models) > 1: + raise ValueError( + "`target_name` was not specified and multiple classes in " + f"`{resolved_module_path}` extend `ModelBase`. Ensure " + "exactly one class serves as entrypoint; found classes: " + f"\n{list(cls.name for cls in models)}" + ) + target_cls = utils.expect_one(models) + if not utils.issubclass_safe(target_cls, definitions.ABCModel): + raise TypeError( + f"Target `{target_cls}` is not a {definitions.ABCModel}." + ) + + yield target_cls + finally: + _cleanup_module_imports(modules_before, modules_after, resolved_module_path) + + +def _cleanup_module_imports( + modules_before: set[str], modules_after: set[str], module_path: pathlib.Path +): + modules_diff = modules_after - modules_before + # Apparently torch import leaves some side effects that cannot be reverted + # by deleting the modules and would lead to a crash when another import + # is attempted. Since torch is a common lib, we make this explicit special + # case and just leave those modules. + # TODO: this seems still brittle and other modules might cause similar problems. + # it would be good to find a more principled solution. + modules_to_delete = { + s for s in modules_diff if not (s.startswith("torch.") or s == "torch") + } + if torch_modules := modules_diff - modules_to_delete: + logging.debug(f"Keeping torch modules after import context: {torch_modules}") + + logging.debug(f"Deleting modules when exiting import context: {modules_to_delete}") + for mod in modules_to_delete: + del sys.modules[mod] + try: + sys.path.remove(str(module_path.parent)) + except ValueError: # In case the value was already removed for whatever reason. + pass diff --git a/truss-chains/truss_chains/public_api.py b/truss-chains/truss_chains/public_api.py index ecb1be35a..0b9d37ca7 100644 --- a/truss-chains/truss_chains/public_api.py +++ b/truss-chains/truss_chains/public_api.py @@ -122,6 +122,18 @@ def __init_with_arg_check__(self, *args, **kwargs): cls.__init__ = __init_with_arg_check__ # type: ignore[method-assign] +class ModelBase(definitions.ABCModel): + """Base class for all singular truss models. + + Inheriting from this class adds validations to make sure subclasses adhere to the + truss model pattern. + """ + + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + framework.validate_base_model(cls) + + @overload def mark_entrypoint( cls_or_chain_name: Type[framework.ChainletT], diff --git a/truss/cli/cli.py b/truss/cli/cli.py index a2ff6fec3..6f6b66e3e 100644 --- a/truss/cli/cli.py +++ b/truss/cli/cli.py @@ -4,6 +4,7 @@ import os import sys import time +import traceback import warnings from functools import wraps from pathlib import Path @@ -1132,11 +1133,32 @@ def push( TARGET_DIRECTORY: A Truss directory. If none, use current directory. """ + from truss_chains import framework + from truss_chains.deployment.code_gen import write_truss_config_yaml + if not remote: remote = inquire_remote_name(RemoteFactory.get_available_config_names()) remote_provider = RemoteFactory.create(remote=remote) + try: + # Check whether the model file extends our new base class type, if so write + # the config file so _get_truss_from_directory will pick it up + target_path = Path(target_directory) + with framework.import_model_target( + target_path / "model/model.py" + ) as entrypoint_cls: + write_truss_config_yaml( + target_path, + entrypoint_cls.remote_config, + {}, + entrypoint_cls.remote_config.name, + False, # use_local_chains_src + ) + except Exception as e: + print(traceback.print_exc()) + raise (e) + tr = _get_truss_from_directory(target_directory=target_directory) model_name = model_name or tr.spec.config.model_name From 6d9c698ad779ccbc3a1d38094771fe318a4c5caf Mon Sep 17 00:00:00 2001 From: Nikhil Narayen Date: Wed, 15 Jan 2025 19:34:12 +0000 Subject: [PATCH 2/5] Production refactors --- truss-chains/truss_chains/definitions.py | 7 +- .../truss_chains/deployment/code_gen.py | 16 +- truss-chains/truss_chains/framework.py | 231 ++++++------------ truss-chains/truss_chains/public_api.py | 2 +- truss/cli/cli.py | 33 +-- truss/remote/baseten/core.py | 2 +- truss/remote/baseten/remote.py | 9 +- truss/truss_handle/build.py | 34 +++ truss/truss_handle/truss_handle.py | 2 +- 9 files changed, 137 insertions(+), 199 deletions(-) diff --git a/truss-chains/truss_chains/definitions.py b/truss-chains/truss_chains/definitions.py index 52b683956..93ff87a88 100644 --- a/truss-chains/truss_chains/definitions.py +++ b/truss-chains/truss_chains/definitions.py @@ -542,8 +542,11 @@ def name(cls) -> str: def display_name(cls) -> str: return cls.remote_config.name or cls.name - @abc.abstractmethod - def predict(self, request: Any) -> Any: ... + # Cannot add this abstract method to API, because we want to allow arbitrary + # arg/kwarg names and specifying any function signature here would give type errors + # @abc.abstractmethod + # def predict(self, *args, **kwargs) -> Any: + # ... class TypeDescriptor(SafeModelNonSerializable): diff --git a/truss-chains/truss_chains/deployment/code_gen.py b/truss-chains/truss_chains/deployment/code_gen.py index 552e36a33..53db5c998 100644 --- a/truss-chains/truss_chains/deployment/code_gen.py +++ b/truss-chains/truss_chains/deployment/code_gen.py @@ -651,10 +651,10 @@ def _inplace_fill_base_image( def write_truss_config_yaml( chainlet_dir: pathlib.Path, chains_config: definitions.RemoteConfig, - chainlet_to_service: Mapping[str, definitions.ServiceDescriptor], model_name: str, - use_local_chains_src: bool, -) -> truss_config.TrussConfig: + chainlet_to_service: Mapping[str, definitions.ServiceDescriptor] = {}, + use_local_chains_src: bool = False, +): """Generate a truss config for a Chainlet.""" config = truss_config.TrussConfig() config.model_name = model_name @@ -731,11 +731,11 @@ def gen_truss_chainlet( f"in `{chainlet_dir}`." ) write_truss_config_yaml( - chainlet_dir, - chainlet_descriptor.chainlet_cls.remote_config, - dep_services, - model_name, - use_local_chains_src, + chainlet_dir=chainlet_dir, + chains_config=chainlet_descriptor.chainlet_cls.remote_config, + model_name=model_name, + chainlet_to_service=dep_services, + use_local_chains_src=use_local_chains_src, ) # This assumes all imports are absolute w.r.t chain root (or site-packages). truss_path.copy_tree_path( diff --git a/truss-chains/truss_chains/framework.py b/truss-chains/truss_chains/framework.py index e293f6492..04093478a 100644 --- a/truss-chains/truss_chains/framework.py +++ b/truss-chains/truss_chains/framework.py @@ -14,11 +14,13 @@ import sys import types import warnings +from importlib.abc import Loader from typing import ( Any, Callable, Iterable, Iterator, + Literal, Mapping, MutableMapping, Optional, @@ -132,7 +134,9 @@ def _collect_error(msg: str, kind: _ErrorKind, location: _ErrorLocation): ) -def raise_validation_errors() -> None: +def raise_validation_errors( + truss_type: Literal["Chainlet", "Model"] = "Chainlet", +) -> None: """Raises validation errors as combined ``ChainsUsageError``""" if _global_error_collector.has_errors: error_msg = _global_error_collector.format_errors() @@ -143,7 +147,7 @@ def raise_validation_errors() -> None: ) _global_error_collector.clear() # Clear errors so `atexit` won't display them raise definitions.ChainsUsageError( - f"The Chainlet definitions contain {errors_count}:\n{error_msg}" + f"The {truss_type} definitions contain {errors_count}:\n{error_msg}" ) @@ -725,34 +729,25 @@ def _validate_dependencies( return dependencies -def _validate_chainlet_cls( - cls: Type[definitions.ABCChainlet], location: _ErrorLocation -) -> None: - if not hasattr(cls, definitions.REMOTE_CONFIG_NAME): - _collect_error( - f"Chainlets must have a `{definitions.REMOTE_CONFIG_NAME}` class variable " - f"`{definitions.REMOTE_CONFIG_NAME} = {definitions.RemoteConfig.__name__}" - f"(...)`. Missing for `{cls}`.", - _ErrorKind.MISSING_API_ERROR, - location, - ) - return - +def _validate_remote_config( + cls: Union[Type[definitions.ABCChainlet], Type[definitions.ABCModel]], + truss_type: Literal["Chainlet", "Model"], + location: _ErrorLocation, +): if not isinstance( remote_config := getattr(cls, definitions.REMOTE_CONFIG_NAME), definitions.RemoteConfig, ): _collect_error( - f"Chainlets must have a `{definitions.REMOTE_CONFIG_NAME}` class variable " + f"{truss_type} must have a `{definitions.REMOTE_CONFIG_NAME}` class variable " f"of type `{definitions.RemoteConfig}`. Got `{type(remote_config)}` " f"for `{cls}`.", _ErrorKind.TYPE_ERROR, location, ) - return -def validate_and_register_class(cls: Type[definitions.ABCChainlet]) -> None: +def validate_and_register_chain(cls: Type[definitions.ABCChainlet]) -> None: """Note that validation errors will only be collected, not raised, and Chainlets. with issues, are still added to the registry. Use `raise_validation_errors` to assert all Chainlets are valid and before performing operations that depend on @@ -761,7 +756,7 @@ def validate_and_register_class(cls: Type[definitions.ABCChainlet]) -> None: line = inspect.getsourcelines(cls)[1] location = _ErrorLocation(src_path=src_path, line=line, chainlet_name=cls.__name__) - _validate_chainlet_cls(cls, location) + _validate_remote_config(cls, "Chainlet", location) init_validator = _ChainletInitValidator(cls, location) chainlet_descriptor = definitions.ChainletAPIDescriptor( chainlet_cls=cls, @@ -777,35 +772,10 @@ def validate_and_register_class(cls: Type[definitions.ABCChainlet]) -> None: def validate_base_model(cls: Type[definitions.ABCModel]) -> None: - # NB(nikhil): Following lines cause ERROR TypeError: is a built-in class - # src_path = os.path.abspath(inspect.getfile(cls)) - # line = inspect.getsourcelines(cls)[1] - location = _ErrorLocation(src_path="model.py", line=10, chainlet_name=cls.__name__) - - # NB(nikhil): This seems to pass even when my definition doesn't have remote_config, likely - # pulling from default - if not hasattr(cls, definitions.REMOTE_CONFIG_NAME): - _collect_error( - f"Models must have a `{definitions.REMOTE_CONFIG_NAME}` class variable " - f"`{definitions.REMOTE_CONFIG_NAME} = {definitions.RemoteConfig.__name__}" - f"(...)`. Missing for `{cls}`.", - _ErrorKind.MISSING_API_ERROR, - location, - ) - return - - if not isinstance( - remote_config := getattr(cls, definitions.REMOTE_CONFIG_NAME), - definitions.RemoteConfig, - ): - _collect_error( - f"Models must have a `{definitions.REMOTE_CONFIG_NAME}` class variable " - f"of type `{definitions.RemoteConfig}`. Got `{type(remote_config)}` " - f"for `{cls}`.", - _ErrorKind.TYPE_ERROR, - location, - ) - return + src_path = os.path.abspath(inspect.getfile(cls)) + line = inspect.getsourcelines(cls)[1] + location = _ErrorLocation(src_path=src_path, line=line) + _validate_remote_config(cls, "Model", location) # Dependency-Injection / Registry ###################################################### @@ -1203,48 +1173,16 @@ def _get_entrypoint_chainlets(symbols) -> set[Type[definitions.ABCChainlet]]: def import_target( module_path: pathlib.Path, target_name: Optional[str] ) -> Iterator[Type[definitions.ABCChainlet]]: - """The context manager ensures that modules imported by the chain and - Chainlets registered in ``_global_chainlet_registry`` are removed upon exit. - - I.e. aiming at making the import idempotent for common usages, although there could - be additional side effects not accounted for by this implementation.""" resolved_module_path = pathlib.Path(module_path).resolve() - module_name = resolved_module_path.stem # Use the file's name as the module name - if not os.path.isfile(resolved_module_path): - raise ImportError( - f"`{resolved_module_path}` is not a file. You must point to a python file where " - "the entrypoint Chainlet is defined." - ) - - import_error_msg = f"Could not import `{resolved_module_path}`. Check path." - spec = importlib.util.spec_from_file_location(module_name, resolved_module_path) - if not spec: - raise ImportError(import_error_msg) - if not spec.loader: - raise ImportError(import_error_msg) - - module = importlib.util.module_from_spec(spec) - module.__file__ = str(resolved_module_path) - # Since the framework depends on tracking the source files via `inspect` and this - # depends on the modules bein properly registered in `sys.modules`, we have to - # manually do this here (because importlib does not do it automatically). This - # registration has to stay at least until the push command has finished. - if module_name in sys.modules: - raise ImportError( - f"{import_error_msg} There is already a module in `sys.modules` " - f"with name `{module_name}`. Overwriting that value is unsafe. " - "Try renaming your source file." - ) + module, loader = _load_module(module_path, "Chainlet") modules_before = set(sys.modules.keys()) - sys.modules[module_name] = module - # Add path for making absolute imports relative to the source_module's dir. - sys.path.insert(0, str(resolved_module_path.parent)) + modules_after = set() + chainlets_before = _global_chainlet_registry.get_chainlet_names() chainlets_after = set() - modules_after = set() try: try: - spec.loader.exec_module(module) + loader.exec_module(module) raise_validation_errors() finally: modules_after = set(sys.modules.keys()) @@ -1285,57 +1223,67 @@ def import_target( yield target_cls finally: + _cleanup_module_imports(modules_before, modules_after, resolved_module_path) for chainlet_name in chainlets_after - chainlets_before: _global_chainlet_registry.unregister_chainlet(chainlet_name) - modules_diff = modules_after - modules_before - # Apparently torch import leaves some side effects that cannot be reverted - # by deleting the modules and would lead to a crash when another import - # is attempted. Since torch is a common lib, we make this explicit special - # case and just leave those modules. - # TODO: this seems still brittle and other modules might cause similar problems. - # it would be good to find a more principled solution. - modules_to_delete = { - s for s in modules_diff if not (s.startswith("torch.") or s == "torch") - } - if torch_modules := modules_diff - modules_to_delete: - logging.debug( - f"Keeping torch modules after import context: {torch_modules}" - ) - - logging.debug( - f"Deleting modules when exiting import context: {modules_to_delete}" - ) - for mod in modules_to_delete: - del sys.modules[mod] - try: - sys.path.remove(str(resolved_module_path.parent)) - except ValueError: # In case the value was already removed for whatever reason. - pass - -# NB(nikhil): mainly taken from above, but with some dependency logic removed @contextlib.contextmanager def import_model_target( module_path: pathlib.Path, ) -> Iterator[Type[definitions.ABCModel]]: resolved_module_path = pathlib.Path(module_path).resolve() - module_name = resolved_module_path.stem # Use the file's name as the module name - if not os.path.isfile(resolved_module_path): + module, loader = _load_module(resolved_module_path, "Model") + modules_before = set(sys.modules.keys()) + modules_after = set() + try: + try: + loader.exec_module(module) + raise_validation_errors("Model") + finally: + modules_after = set(sys.modules.keys()) + + module_vars = (getattr(module, name) for name in dir(module)) + models: set[Type[definitions.ABCModel]] = { + sym + for sym in module_vars + if utils.issubclass_safe(sym, definitions.ABCModel) + } + if len(models) == 0: + raise ValueError(f"No class in `{module_path}` extends `ModelBase`.") + + target_cls = utils.expect_one(models) + if not utils.issubclass_safe(target_cls, definitions.ABCModel): + raise TypeError(f"Target `{target_cls}` is not a {definitions.ABCModel}.") + + yield target_cls + finally: + _cleanup_module_imports(modules_before, modules_after, resolved_module_path) + + +def _load_module( + module_path: pathlib.Path, + truss_type: Literal["Chainlet", "Model"], +) -> tuple[types.ModuleType, Loader]: + """The context manager ensures that modules imported by the Model/Chain + are removed upon exit. + + I.e. aiming at making the import idempotent for common usages, although there could + be additional side effects not accounted for by this implementation.""" + module_name = module_path.stem # Use the file's name as the module name + if not os.path.isfile(module_path): raise ImportError( - f"`{resolved_module_path}` is not a file. You must point to a python file where " - "the Model is defined." + f"`{module_path}` is not a file. You must point to a python file where " + f"the entrypoint {truss_type} is defined." ) - import_error_msg = f"Could not import `{resolved_module_path}`. Check path." - spec = importlib.util.spec_from_file_location(module_name, resolved_module_path) - if not spec: - raise ImportError(import_error_msg) - if not spec.loader: + import_error_msg = f"Could not import `{module_path}`. Check path." + spec = importlib.util.spec_from_file_location(module_name, module_path) + if not spec or not spec.loader: raise ImportError(import_error_msg) module = importlib.util.module_from_spec(spec) - module.__file__ = str(resolved_module_path) + module.__file__ = str(module_path) # Since the framework depends on tracking the source files via `inspect` and this # depends on the modules bein properly registered in `sys.modules`, we have to # manually do this here (because importlib does not do it automatically). This @@ -1347,43 +1295,14 @@ def import_model_target( "Try renaming your source file." ) - modules_before = set(sys.modules.keys()) - try: - try: - spec.loader.exec_module(module) - raise_validation_errors() - finally: - modules_after = set(sys.modules.keys()) - - module_vars = (getattr(module, name) for name in dir(module)) - models: set[Type[definitions.ABCModel]] = [ - sym - for sym in module_vars - if utils.issubclass_safe(sym, definitions.ABCModel) - ] - if len(models) == 0: - raise ValueError( - "No `target_name` was specified and no class in " - f"`{module_path}` extends `ModelBase`." - ) - elif len(models) > 1: - raise ValueError( - "`target_name` was not specified and multiple classes in " - f"`{resolved_module_path}` extend `ModelBase`. Ensure " - "exactly one class serves as entrypoint; found classes: " - f"\n{list(cls.name for cls in models)}" - ) - target_cls = utils.expect_one(models) - if not utils.issubclass_safe(target_cls, definitions.ABCModel): - raise TypeError( - f"Target `{target_cls}` is not a {definitions.ABCModel}." - ) + sys.modules[module_name] = module + # Add path for making absolute imports relative to the source_module's dir. + sys.path.insert(0, str(module_path.parent)) - yield target_cls - finally: - _cleanup_module_imports(modules_before, modules_after, resolved_module_path) + return module, spec.loader +# Ensures the loaded system modules are restored to before we executed user defined code. def _cleanup_module_imports( modules_before: set[str], modules_after: set[str], module_path: pathlib.Path ): @@ -1403,7 +1322,3 @@ def _cleanup_module_imports( logging.debug(f"Deleting modules when exiting import context: {modules_to_delete}") for mod in modules_to_delete: del sys.modules[mod] - try: - sys.path.remove(str(module_path.parent)) - except ValueError: # In case the value was already removed for whatever reason. - pass diff --git a/truss-chains/truss_chains/public_api.py b/truss-chains/truss_chains/public_api.py index 0b9d37ca7..2b476e562 100644 --- a/truss-chains/truss_chains/public_api.py +++ b/truss-chains/truss_chains/public_api.py @@ -107,7 +107,7 @@ def __init_subclass__(cls, **kwargs) -> None: # Each sub-class has own, isolated metadata, e.g. we don't want # `mark_entrypoint` to propagate to subclasses. cls.meta_data = definitions.ChainletMetadata() - framework.validate_and_register_class(cls) # Errors are collected, not raised! + framework.validate_and_register_chain(cls) # Errors are collected, not raised! # For default init (from `object`) we don't need to check anything. if cls.has_custom_init(): original_init = cls.__init__ diff --git a/truss/cli/cli.py b/truss/cli/cli.py index 6f6b66e3e..5494fc0e7 100644 --- a/truss/cli/cli.py +++ b/truss/cli/cli.py @@ -4,7 +4,6 @@ import os import sys import time -import traceback import warnings from functools import wraps from pathlib import Path @@ -51,7 +50,10 @@ ) from truss.truss_handle.build import cleanup as _cleanup from truss.truss_handle.build import init as _init -from truss.truss_handle.build import load +from truss.truss_handle.build import ( + load, + load_from_code_config, +) from truss.util import docker from truss.util.log_utils import LogInterceptor @@ -1133,32 +1135,11 @@ def push( TARGET_DIRECTORY: A Truss directory. If none, use current directory. """ - from truss_chains import framework - from truss_chains.deployment.code_gen import write_truss_config_yaml if not remote: remote = inquire_remote_name(RemoteFactory.get_available_config_names()) remote_provider = RemoteFactory.create(remote=remote) - - try: - # Check whether the model file extends our new base class type, if so write - # the config file so _get_truss_from_directory will pick it up - target_path = Path(target_directory) - with framework.import_model_target( - target_path / "model/model.py" - ) as entrypoint_cls: - write_truss_config_yaml( - target_path, - entrypoint_cls.remote_config, - {}, - entrypoint_cls.remote_config.name, - False, # use_local_chains_src - ) - except Exception as e: - print(traceback.print_exc()) - raise (e) - tr = _get_truss_from_directory(target_directory=target_directory) model_name = model_name or tr.spec.config.model_name @@ -1359,7 +1340,11 @@ def _get_truss_from_directory(target_directory: Optional[str] = None): """Gets Truss from directory. If none, use the current directory""" if target_directory is None: target_directory = os.getcwd() - return load(target_directory) + if not os.path.isfile(target_directory): + return load(target_directory) + # NB(nikhil): if target_directory points to a specific file, assume they are using + # the Python driven DX for configuring a truss + return load_from_code_config(Path(target_directory)) truss_cli.add_command(container) diff --git a/truss/remote/baseten/core.py b/truss/remote/baseten/core.py index 20a90420a..b1c2f549b 100644 --- a/truss/remote/baseten/core.py +++ b/truss/remote/baseten/core.py @@ -280,7 +280,7 @@ def archive_truss( Returns: A file-like object containing the tar file """ - truss_dir = truss_handle._spec.truss_dir + truss_dir = truss_handle._truss_dir # check for a truss_ignore file and read the ignore patterns if it exists ignore_patterns = load_trussignore_patterns_from_truss_dir(truss_dir) diff --git a/truss/remote/baseten/remote.py b/truss/remote/baseten/remote.py index 660b671e2..ef16cea6c 100644 --- a/truss/remote/baseten/remote.py +++ b/truss/remote/baseten/remote.py @@ -139,9 +139,10 @@ def _prepare_push( if model_name.isspace(): raise ValueError("Model name cannot be empty") - gathered_truss = TrussHandle(truss_handle.gather()) + if truss_handle.is_scattered(): + truss_handle = TrussHandle(truss_handle.gather()) - if gathered_truss.spec.model_server != ModelServer.TrussServer: + if truss_handle.spec.model_server != ModelServer.TrussServer: publish = True if promote: @@ -176,10 +177,10 @@ def _prepare_push( if model_id is not None and disable_truss_download: raise ValueError("disable-truss-download can only be used for new models") - temp_file = archive_truss(gathered_truss, progress_bar) + temp_file = archive_truss(truss_handle, progress_bar) s3_key = upload_truss(self._api, temp_file, progress_bar) encoded_config_str = base64_encoded_json_str( - gathered_truss._spec._config.to_dict() + truss_handle._spec._config.to_dict() ) validate_truss_config(self._api, encoded_config_str) diff --git a/truss/truss_handle/build.py b/truss/truss_handle/build.py index 5eae8821d..aa5cdd9cd 100644 --- a/truss/truss_handle/build.py +++ b/truss/truss_handle/build.py @@ -1,7 +1,9 @@ import logging import os +import shutil import sys from pathlib import Path +from tempfile import gettempdir from typing import List, Optional import yaml @@ -105,6 +107,38 @@ def load(truss_directory: str) -> TrussHandle: return TrussHandle(Path(truss_directory)) +# NB(nikhil): Generates a TrussHandle whose spec points to a generated +# directory that contains data dumped from the configuration in code. +def load_from_code_config(model_file: Path) -> TrussHandle: + # These imports are delayed, to handle pydantic v1 envs gracefully. + from truss_chains import framework + from truss_chains.deployment.code_gen import write_truss_config_yaml + + # TODO(nikhil): Improve detection of directory structure, since right now + # we assume the traditional model/model.py format. + root_dir = model_file.absolute().parent.parent + with framework.import_model_target(model_file) as entrypoint_cls: + tmp_dir = _copy_to_generated_dir(root_dir) + write_truss_config_yaml( + chainlet_dir=tmp_dir, + chains_config=entrypoint_cls.remote_config, + model_name=entrypoint_cls.display_name, + ) + + return TrussHandle(truss_dir=tmp_dir) + + +def _copy_to_generated_dir(root_dir: Path) -> Path: + # These imports are delayed, to handle pydantic v1 envs gracefully. + from truss_chains.definitions import GENERATED_CODE_DIR + + tmp_dir = Path(gettempdir()) / GENERATED_CODE_DIR + if tmp_dir.exists(): + shutil.rmtree(tmp_dir) + shutil.copytree(root_dir, tmp_dir) + return tmp_dir + + def cleanup() -> None: """ Cleans up .truss directory. diff --git a/truss/truss_handle/truss_handle.py b/truss/truss_handle/truss_handle.py index f80f55ab9..c624d6b1b 100644 --- a/truss/truss_handle/truss_handle.py +++ b/truss/truss_handle/truss_handle.py @@ -106,7 +106,7 @@ def wait(self): class TrussHandle: def __init__(self, truss_dir: Path, validate: bool = True) -> None: self._truss_dir = truss_dir - self._spec = TrussSpec(truss_dir) + self._spec = TrussSpec(self._truss_dir) self._hash_for_mod_time: Optional[Tuple[float, str]] = None if validate: self.validate() From bd313b57dd5d23d75eaef4150f994ca1969e19b0 Mon Sep 17 00:00:00 2001 From: Nikhil Narayen Date: Fri, 17 Jan 2025 12:53:30 -0500 Subject: [PATCH 3/5] Move code generation dependency to chains --- truss-chains/truss_chains/definitions.py | 24 ++++-------- .../truss_chains/deployment/code_gen.py | 9 +++-- .../deployment/deployment_client.py | 8 +--- truss-chains/truss_chains/framework.py | 37 ++++++++++++++++--- truss-chains/truss_chains/public_api.py | 1 + .../truss_chains/remote_chainlet/utils.py | 3 ++ truss/templates/server/model_wrapper.py | 3 ++ truss/truss_handle/build.py | 29 +-------------- 8 files changed, 53 insertions(+), 61 deletions(-) diff --git a/truss-chains/truss_chains/definitions.py b/truss-chains/truss_chains/definitions.py index 93ff87a88..810df19f3 100644 --- a/truss-chains/truss_chains/definitions.py +++ b/truss-chains/truss_chains/definitions.py @@ -512,6 +512,10 @@ class ABCChainlet(abc.ABC): def has_custom_init(cls) -> bool: return cls.__init__ is not object.__init__ + @classmethod + def truss_type(cls) -> str: + return "Chainlet" + @classproperty @classmethod def name(cls) -> str: @@ -529,24 +533,10 @@ def display_name(cls) -> str: # ... -class ABCModel(abc.ABC): - remote_config: ClassVar[RemoteConfig] = RemoteConfig() - - @classproperty +class ABCModel(ABCChainlet): @classmethod - def name(cls) -> str: - return cls.__name__ - - @classproperty - @classmethod - def display_name(cls) -> str: - return cls.remote_config.name or cls.name - - # Cannot add this abstract method to API, because we want to allow arbitrary - # arg/kwarg names and specifying any function signature here would give type errors - # @abc.abstractmethod - # def predict(self, *args, **kwargs) -> Any: - # ... + def truss_type(cls) -> str: + return "Chainlet" class TypeDescriptor(SafeModelNonSerializable): diff --git a/truss-chains/truss_chains/deployment/code_gen.py b/truss-chains/truss_chains/deployment/code_gen.py index 53db5c998..080d003f4 100644 --- a/truss-chains/truss_chains/deployment/code_gen.py +++ b/truss-chains/truss_chains/deployment/code_gen.py @@ -31,6 +31,7 @@ import shutil import subprocess import sys +import tempfile import textwrap from typing import Any, Iterable, Mapping, Optional, get_args, get_origin @@ -711,11 +712,10 @@ def write_truss_config_yaml( def gen_truss_chainlet( chain_root: pathlib.Path, - gen_root: pathlib.Path, chain_name: str, chainlet_descriptor: definitions.ChainletAPIDescriptor, - model_name: str, - use_local_chains_src: bool, + model_name: Optional[str] = None, + use_local_chains_src: bool = False, ) -> pathlib.Path: # Filter needed services and customize options. dep_services = {} @@ -725,6 +725,7 @@ def gen_truss_chainlet( display_name=dep.display_name, options=dep.options, ) + gen_root = pathlib.Path(tempfile.gettempdir()) chainlet_dir = _make_chainlet_dir(chain_name, chainlet_descriptor, gen_root) logging.info( f"Code generation for Chainlet `{chainlet_descriptor.name}` " @@ -733,7 +734,7 @@ def gen_truss_chainlet( write_truss_config_yaml( chainlet_dir=chainlet_dir, chains_config=chainlet_descriptor.chainlet_cls.remote_config, - model_name=model_name, + model_name=model_name or chain_name, chainlet_to_service=dep_services, use_local_chains_src=use_local_chains_src, ) diff --git a/truss-chains/truss_chains/deployment/deployment_client.py b/truss-chains/truss_chains/deployment/deployment_client.py index 9f4290250..ba954d661 100644 --- a/truss-chains/truss_chains/deployment/deployment_client.py +++ b/truss-chains/truss_chains/deployment/deployment_client.py @@ -4,7 +4,6 @@ import json import logging import pathlib -import tempfile import textwrap import traceback import uuid @@ -138,10 +137,8 @@ class _ChainSourceGenerator: def __init__( self, options: definitions.PushOptions, - gen_root: pathlib.Path, ) -> None: self._options = options - self._gen_root = gen_root or pathlib.Path(tempfile.gettempdir()) @property def _use_local_chains_src(self) -> bool: @@ -175,7 +172,6 @@ def generate_chainlet_artifacts( chainlet_dir = code_gen.gen_truss_chainlet( chain_root, - self._gen_root, self._options.chain_name, chainlet_descriptor, model_name, @@ -205,11 +201,10 @@ def generate_chainlet_artifacts( def push( entrypoint: Type[definitions.ABCChainlet], options: definitions.PushOptions, - gen_root: pathlib.Path = pathlib.Path(tempfile.gettempdir()), progress_bar: Optional[Type["progress.Progress"]] = None, ) -> Optional[ChainService]: entrypoint_artifact, dependency_artifacts = _ChainSourceGenerator( - options, gen_root + options ).generate_chainlet_artifacts( entrypoint, ) @@ -632,7 +627,6 @@ def _code_gen_and_patch_thread( # TODO: Maybe try-except code_gen errors explicitly. chainlet_dir = code_gen.gen_truss_chainlet( self._chain_root, - pathlib.Path(tempfile.gettempdir()), self._deployed_chain_name, descr, self._chainlet_data[descr.display_name].oracle_name, diff --git a/truss-chains/truss_chains/framework.py b/truss-chains/truss_chains/framework.py index 04093478a..5135f2339 100644 --- a/truss-chains/truss_chains/framework.py +++ b/truss-chains/truss_chains/framework.py @@ -34,9 +34,11 @@ ) import pydantic +from truss.truss_handle.truss_handle import TrussHandle from typing_extensions import ParamSpec from truss_chains import definitions, utils +from truss_chains.deployment import code_gen _SIMPLE_TYPES = {int, float, complex, bool, str, bytes, None, pydantic.BaseModel} _SIMPLE_CONTAINERS = {list, dict} @@ -730,8 +732,7 @@ def _validate_dependencies( def _validate_remote_config( - cls: Union[Type[definitions.ABCChainlet], Type[definitions.ABCModel]], - truss_type: Literal["Chainlet", "Model"], + cls: Type[definitions.ABCChainlet], location: _ErrorLocation, ): if not isinstance( @@ -739,7 +740,7 @@ def _validate_remote_config( definitions.RemoteConfig, ): _collect_error( - f"{truss_type} must have a `{definitions.REMOTE_CONFIG_NAME}` class variable " + f"{cls.truss_type}s must have a `{definitions.REMOTE_CONFIG_NAME}` class variable " f"of type `{definitions.RemoteConfig}`. Got `{type(remote_config)}` " f"for `{cls}`.", _ErrorKind.TYPE_ERROR, @@ -756,7 +757,7 @@ def validate_and_register_chain(cls: Type[definitions.ABCChainlet]) -> None: line = inspect.getsourcelines(cls)[1] location = _ErrorLocation(src_path=src_path, line=line, chainlet_name=cls.__name__) - _validate_remote_config(cls, "Chainlet", location) + _validate_remote_config(cls, location) init_validator = _ChainletInitValidator(cls, location) chainlet_descriptor = definitions.ChainletAPIDescriptor( chainlet_cls=cls, @@ -775,7 +776,16 @@ def validate_base_model(cls: Type[definitions.ABCModel]) -> None: src_path = os.path.abspath(inspect.getfile(cls)) line = inspect.getsourcelines(cls)[1] location = _ErrorLocation(src_path=src_path, line=line) - _validate_remote_config(cls, "Model", location) + _validate_remote_config(cls, location) + + base_model_descriptor = definitions.ChainletAPIDescriptor( + chainlet_cls=cls, + dependencies={}, + has_context=False, + endpoint=_validate_and_describe_endpoint(cls, location), + src_path=src_path, + ) + _global_chainlet_registry.register_chainlet(base_model_descriptor) # Dependency-Injection / Registry ###################################################### @@ -1322,3 +1332,20 @@ def _cleanup_module_imports( logging.debug(f"Deleting modules when exiting import context: {modules_to_delete}") for mod in modules_to_delete: del sys.modules[mod] + + +# NB(nikhil): Generates a TrussHandle whose spec points to a generated +# directory that contains data dumped from the configuration in code. +def truss_handle_from_code_config(model_file: pathlib.Path) -> TrussHandle: + # TODO(nikhil): Improve detection of directory structure, since right now + # we assume a flat structure + root_dir = model_file.absolute().parent + with import_model_target(model_file) as entrypoint_cls: + descriptor = _global_chainlet_registry.get_descriptor(entrypoint_cls) + generated_dir = code_gen.gen_truss_chainlet( + chain_root=root_dir, + chain_name=entrypoint_cls.display_name, + chainlet_descriptor=descriptor, + ) + + return TrussHandle(truss_dir=generated_dir) diff --git a/truss-chains/truss_chains/public_api.py b/truss-chains/truss_chains/public_api.py index 2b476e562..979fe8782 100644 --- a/truss-chains/truss_chains/public_api.py +++ b/truss-chains/truss_chains/public_api.py @@ -131,6 +131,7 @@ class ModelBase(definitions.ABCModel): def __init_subclass__(cls, **kwargs) -> None: super().__init_subclass__(**kwargs) + cls.meta_data = definitions.ChainletMetadata(is_entrypoint=True) framework.validate_base_model(cls) diff --git a/truss-chains/truss_chains/remote_chainlet/utils.py b/truss-chains/truss_chains/remote_chainlet/utils.py index ebb6f6244..1ce1c5e08 100644 --- a/truss-chains/truss_chains/remote_chainlet/utils.py +++ b/truss-chains/truss_chains/remote_chainlet/utils.py @@ -31,6 +31,9 @@ def populate_chainlet_service_predict_urls( chainlet_to_service: Mapping[str, definitions.ServiceDescriptor], ) -> Mapping[str, definitions.DeployedServiceDescriptor]: chainlet_to_deployed_service: Dict[str, definitions.DeployedServiceDescriptor] = {} + # If there are no dependencies of this chainlet, no need to derive dynamic URLs + if len(chainlet_to_service) == 0: + return chainlet_to_deployed_service dynamic_chainlet_config_str = dynamic_config_resolver.get_dynamic_config_value_sync( definitions.DYNAMIC_CHAINLET_CONFIG_KEY diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py index cff58f2ca..2e459c878 100644 --- a/truss/templates/server/model_wrapper.py +++ b/truss/templates/server/model_wrapper.py @@ -406,6 +406,9 @@ def _load_impl(self): if not spec.loader: raise ImportError(import_error_msg) module = importlib.util.module_from_spec(spec) + module.__file__ = str(module_path) + sys.modules[module_name] = module + sys.path.insert(0, str(module_path.parent)) try: spec.loader.exec_module(module) except ImportError as e: diff --git a/truss/truss_handle/build.py b/truss/truss_handle/build.py index aa5cdd9cd..ccb3978e7 100644 --- a/truss/truss_handle/build.py +++ b/truss/truss_handle/build.py @@ -1,9 +1,7 @@ import logging import os -import shutil import sys from pathlib import Path -from tempfile import gettempdir from typing import List, Optional import yaml @@ -107,36 +105,11 @@ def load(truss_directory: str) -> TrussHandle: return TrussHandle(Path(truss_directory)) -# NB(nikhil): Generates a TrussHandle whose spec points to a generated -# directory that contains data dumped from the configuration in code. def load_from_code_config(model_file: Path) -> TrussHandle: # These imports are delayed, to handle pydantic v1 envs gracefully. from truss_chains import framework - from truss_chains.deployment.code_gen import write_truss_config_yaml - # TODO(nikhil): Improve detection of directory structure, since right now - # we assume the traditional model/model.py format. - root_dir = model_file.absolute().parent.parent - with framework.import_model_target(model_file) as entrypoint_cls: - tmp_dir = _copy_to_generated_dir(root_dir) - write_truss_config_yaml( - chainlet_dir=tmp_dir, - chains_config=entrypoint_cls.remote_config, - model_name=entrypoint_cls.display_name, - ) - - return TrussHandle(truss_dir=tmp_dir) - - -def _copy_to_generated_dir(root_dir: Path) -> Path: - # These imports are delayed, to handle pydantic v1 envs gracefully. - from truss_chains.definitions import GENERATED_CODE_DIR - - tmp_dir = Path(gettempdir()) / GENERATED_CODE_DIR - if tmp_dir.exists(): - shutil.rmtree(tmp_dir) - shutil.copytree(root_dir, tmp_dir) - return tmp_dir + return framework.truss_handle_from_code_config(model_file) def cleanup() -> None: From f4b327708143a284aad00af348b593eb1baabbd2 Mon Sep 17 00:00:00 2001 From: Nikhil Narayen Date: Fri, 17 Jan 2025 19:11:21 -0500 Subject: [PATCH 4/5] Add unit test --- truss-chains/tests/test_e2e.py | 25 +++++++++++++++++++ .../tests/traditional_truss/truss_model.py | 20 +++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 truss-chains/tests/traditional_truss/truss_model.py diff --git a/truss-chains/tests/test_e2e.py b/truss-chains/tests/test_e2e.py index 86da35eff..c572f182e 100644 --- a/truss-chains/tests/test_e2e.py +++ b/truss-chains/tests/test_e2e.py @@ -6,6 +6,7 @@ import pytest import requests from truss.tests.test_testing_utilities_for_other_tests import ensure_kill_all +from truss.truss_handle.build import load_from_code_config from truss_chains import definitions, framework, public_api, utils from truss_chains.deployment import deployment_client @@ -270,3 +271,27 @@ async def test_timeout(): assert re.match( sync_error_regex.strip(), sync_error_str.strip(), re.MULTILINE ), sync_error_str + + +@pytest.mark.integration +def test_traditional_truss(): + with ensure_kill_all(): + chain_root = TEST_ROOT / "traditional_truss" / "truss_model.py" + truss_handle = load_from_code_config(chain_root) + + assert truss_handle.spec.config.resources.cpu == "4" + assert truss_handle.spec.config.model_name == "OverridePassthroughModelName" + + port = utils.get_free_port() + truss_handle.docker_run( + local_port=port, + detach=True, + network="host", + ) + + response = requests.post( + f"http://localhost:{port}/v1/models/model:predict", + json={"call_count_increment": 5}, + ) + assert response.status_code == 200 + assert response.json() == 5 diff --git a/truss-chains/tests/traditional_truss/truss_model.py b/truss-chains/tests/traditional_truss/truss_model.py new file mode 100644 index 000000000..3d5d03108 --- /dev/null +++ b/truss-chains/tests/traditional_truss/truss_model.py @@ -0,0 +1,20 @@ +import truss_chains as chains + + +class PassthroughModel(chains.ModelBase): + remote_config: chains.RemoteConfig = chains.RemoteConfig( # type: ignore + compute=chains.Compute(4, "1Gi"), + name="OverridePassthroughModelName", + docker_image=chains.DockerImage( + pip_requirements=[ + "truss==0.9.59rc2", + ] + ), + ) + + def __init__(self, **kwargs): + self._call_count = 0 + + async def run_remote(self, call_count_increment: int) -> int: + self._call_count += call_count_increment + return self._call_count From c5cad3d06774475ec0c9f3152c03c044f2797d56 Mon Sep 17 00:00:00 2001 From: Nikhil Narayen Date: Fri, 17 Jan 2025 19:52:22 -0500 Subject: [PATCH 5/5] Consolidate more code --- .../tests/traditional_truss/truss_model.py | 2 +- truss-chains/truss_chains/definitions.py | 10 --- .../truss_chains/deployment/code_gen.py | 4 +- truss-chains/truss_chains/framework.py | 67 ++----------------- truss-chains/truss_chains/public_api.py | 4 +- truss/templates/server/model_wrapper.py | 3 - 6 files changed, 12 insertions(+), 78 deletions(-) diff --git a/truss-chains/tests/traditional_truss/truss_model.py b/truss-chains/tests/traditional_truss/truss_model.py index 3d5d03108..12cf5bb85 100644 --- a/truss-chains/tests/traditional_truss/truss_model.py +++ b/truss-chains/tests/traditional_truss/truss_model.py @@ -12,7 +12,7 @@ class PassthroughModel(chains.ModelBase): ), ) - def __init__(self, **kwargs): + def __init__(self): self._call_count = 0 async def run_remote(self, call_count_increment: int) -> int: diff --git a/truss-chains/truss_chains/definitions.py b/truss-chains/truss_chains/definitions.py index 810df19f3..fd7908c99 100644 --- a/truss-chains/truss_chains/definitions.py +++ b/truss-chains/truss_chains/definitions.py @@ -512,10 +512,6 @@ class ABCChainlet(abc.ABC): def has_custom_init(cls) -> bool: return cls.__init__ is not object.__init__ - @classmethod - def truss_type(cls) -> str: - return "Chainlet" - @classproperty @classmethod def name(cls) -> str: @@ -533,12 +529,6 @@ def display_name(cls) -> str: # ... -class ABCModel(ABCChainlet): - @classmethod - def truss_type(cls) -> str: - return "Chainlet" - - class TypeDescriptor(SafeModelNonSerializable): """For describing I/O types of Chainlets.""" diff --git a/truss-chains/truss_chains/deployment/code_gen.py b/truss-chains/truss_chains/deployment/code_gen.py index 080d003f4..cd80f9d66 100644 --- a/truss-chains/truss_chains/deployment/code_gen.py +++ b/truss-chains/truss_chains/deployment/code_gen.py @@ -652,9 +652,9 @@ def _inplace_fill_base_image( def write_truss_config_yaml( chainlet_dir: pathlib.Path, chains_config: definitions.RemoteConfig, + chainlet_to_service: Mapping[str, definitions.ServiceDescriptor], model_name: str, - chainlet_to_service: Mapping[str, definitions.ServiceDescriptor] = {}, - use_local_chains_src: bool = False, + use_local_chains_src: bool, ): """Generate a truss config for a Chainlet.""" config = truss_config.TrussConfig() diff --git a/truss-chains/truss_chains/framework.py b/truss-chains/truss_chains/framework.py index 5135f2339..87386b477 100644 --- a/truss-chains/truss_chains/framework.py +++ b/truss-chains/truss_chains/framework.py @@ -20,7 +20,6 @@ Callable, Iterable, Iterator, - Literal, Mapping, MutableMapping, Optional, @@ -136,9 +135,7 @@ def _collect_error(msg: str, kind: _ErrorKind, location: _ErrorLocation): ) -def raise_validation_errors( - truss_type: Literal["Chainlet", "Model"] = "Chainlet", -) -> None: +def raise_validation_errors() -> None: """Raises validation errors as combined ``ChainsUsageError``""" if _global_error_collector.has_errors: error_msg = _global_error_collector.format_errors() @@ -149,7 +146,7 @@ def raise_validation_errors( ) _global_error_collector.clear() # Clear errors so `atexit` won't display them raise definitions.ChainsUsageError( - f"The {truss_type} definitions contain {errors_count}:\n{error_msg}" + f"The Chainlet definitions contain {errors_count}:\n{error_msg}" ) @@ -740,7 +737,7 @@ def _validate_remote_config( definitions.RemoteConfig, ): _collect_error( - f"{cls.truss_type}s must have a `{definitions.REMOTE_CONFIG_NAME}` class variable " + f"Chainlets must have a `{definitions.REMOTE_CONFIG_NAME}` class variable " f"of type `{definitions.RemoteConfig}`. Got `{type(remote_config)}` " f"for `{cls}`.", _ErrorKind.TYPE_ERROR, @@ -772,22 +769,6 @@ def validate_and_register_chain(cls: Type[definitions.ABCChainlet]) -> None: _global_chainlet_registry.register_chainlet(chainlet_descriptor) -def validate_base_model(cls: Type[definitions.ABCModel]) -> None: - src_path = os.path.abspath(inspect.getfile(cls)) - line = inspect.getsourcelines(cls)[1] - location = _ErrorLocation(src_path=src_path, line=line) - _validate_remote_config(cls, location) - - base_model_descriptor = definitions.ChainletAPIDescriptor( - chainlet_cls=cls, - dependencies={}, - has_context=False, - endpoint=_validate_and_describe_endpoint(cls, location), - src_path=src_path, - ) - _global_chainlet_registry.register_chainlet(base_model_descriptor) - - # Dependency-Injection / Registry ###################################################### @@ -1181,10 +1162,10 @@ def _get_entrypoint_chainlets(symbols) -> set[Type[definitions.ABCChainlet]]: @contextlib.contextmanager def import_target( - module_path: pathlib.Path, target_name: Optional[str] + module_path: pathlib.Path, target_name: Optional[str] = None ) -> Iterator[Type[definitions.ABCChainlet]]: resolved_module_path = pathlib.Path(module_path).resolve() - module, loader = _load_module(module_path, "Chainlet") + module, loader = _load_module(module_path) modules_before = set(sys.modules.keys()) modules_after = set() @@ -1238,42 +1219,8 @@ def import_target( _global_chainlet_registry.unregister_chainlet(chainlet_name) -@contextlib.contextmanager -def import_model_target( - module_path: pathlib.Path, -) -> Iterator[Type[definitions.ABCModel]]: - resolved_module_path = pathlib.Path(module_path).resolve() - module, loader = _load_module(resolved_module_path, "Model") - modules_before = set(sys.modules.keys()) - modules_after = set() - try: - try: - loader.exec_module(module) - raise_validation_errors("Model") - finally: - modules_after = set(sys.modules.keys()) - - module_vars = (getattr(module, name) for name in dir(module)) - models: set[Type[definitions.ABCModel]] = { - sym - for sym in module_vars - if utils.issubclass_safe(sym, definitions.ABCModel) - } - if len(models) == 0: - raise ValueError(f"No class in `{module_path}` extends `ModelBase`.") - - target_cls = utils.expect_one(models) - if not utils.issubclass_safe(target_cls, definitions.ABCModel): - raise TypeError(f"Target `{target_cls}` is not a {definitions.ABCModel}.") - - yield target_cls - finally: - _cleanup_module_imports(modules_before, modules_after, resolved_module_path) - - def _load_module( module_path: pathlib.Path, - truss_type: Literal["Chainlet", "Model"], ) -> tuple[types.ModuleType, Loader]: """The context manager ensures that modules imported by the Model/Chain are removed upon exit. @@ -1284,7 +1231,7 @@ def _load_module( if not os.path.isfile(module_path): raise ImportError( f"`{module_path}` is not a file. You must point to a python file where " - f"the entrypoint {truss_type} is defined." + f"the entrypoint Chainlet is defined." ) import_error_msg = f"Could not import `{module_path}`. Check path." @@ -1340,7 +1287,7 @@ def truss_handle_from_code_config(model_file: pathlib.Path) -> TrussHandle: # TODO(nikhil): Improve detection of directory structure, since right now # we assume a flat structure root_dir = model_file.absolute().parent - with import_model_target(model_file) as entrypoint_cls: + with import_target(model_file) as entrypoint_cls: descriptor = _global_chainlet_registry.get_descriptor(entrypoint_cls) generated_dir = code_gen.gen_truss_chainlet( chain_root=root_dir, diff --git a/truss-chains/truss_chains/public_api.py b/truss-chains/truss_chains/public_api.py index 979fe8782..9545c6945 100644 --- a/truss-chains/truss_chains/public_api.py +++ b/truss-chains/truss_chains/public_api.py @@ -122,7 +122,7 @@ def __init_with_arg_check__(self, *args, **kwargs): cls.__init__ = __init_with_arg_check__ # type: ignore[method-assign] -class ModelBase(definitions.ABCModel): +class ModelBase(definitions.ABCChainlet): """Base class for all singular truss models. Inheriting from this class adds validations to make sure subclasses adhere to the @@ -132,7 +132,7 @@ class ModelBase(definitions.ABCModel): def __init_subclass__(cls, **kwargs) -> None: super().__init_subclass__(**kwargs) cls.meta_data = definitions.ChainletMetadata(is_entrypoint=True) - framework.validate_base_model(cls) + framework.validate_and_register_chain(cls) @overload diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py index 2e459c878..cff58f2ca 100644 --- a/truss/templates/server/model_wrapper.py +++ b/truss/templates/server/model_wrapper.py @@ -406,9 +406,6 @@ def _load_impl(self): if not spec.loader: raise ImportError(import_error_msg) module = importlib.util.module_from_spec(spec) - module.__file__ = str(module_path) - sys.modules[module_name] = module - sys.path.insert(0, str(module_path.parent)) try: spec.loader.exec_module(module) except ImportError as e: