Skip to content

Commit

Permalink
POC for new model DX
Browse files Browse the repository at this point in the history
  • Loading branch information
nnarayen committed Jan 14, 2025
1 parent 3f411cd commit 5ecb0dd
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 15 deletions.
2 changes: 2 additions & 0 deletions truss-chains/truss_chains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)
from truss_chains.public_api import (
ChainletBase,
ModelBase,
depends,
depends_context,
mark_entrypoint,
Expand All @@ -50,6 +51,7 @@
"Assets",
"BasetenImage",
"ChainletBase",
"ModelBase",
"ChainletOptions",
"Compute",
"CustomImage",
Expand Down
17 changes: 17 additions & 0 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,23 @@ def display_name(cls) -> str:
# ...


class ABCModel(abc.ABC):
remote_config: ClassVar[RemoteConfig] = RemoteConfig()

@classproperty
@classmethod
def name(cls) -> str:
return cls.__name__

@classproperty
@classmethod
def display_name(cls) -> str:
return cls.remote_config.name or cls.name

@abc.abstractmethod
def predict(self, request: Any) -> Any: ...


class TypeDescriptor(SafeModelNonSerializable):
"""For describing I/O types of Chainlets."""

Expand Down
5 changes: 2 additions & 3 deletions truss-chains/truss_chains/deployment/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ def _inplace_fill_base_image(
)


def _make_truss_config(
def write_truss_config_yaml(
chainlet_dir: pathlib.Path,
chains_config: definitions.RemoteConfig,
chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
Expand Down Expand Up @@ -706,7 +706,6 @@ def _make_truss_config(
config.write_to_yaml_file(
chainlet_dir / serving_image_builder.CONFIG_FILE, verbose=True
)
return config


def gen_truss_chainlet(
Expand All @@ -730,7 +729,7 @@ def gen_truss_chainlet(
f"Code generation for Chainlet `{chainlet_descriptor.name}` "
f"in `{chainlet_dir}`."
)
_make_truss_config(
write_truss_config_yaml(
chainlet_dir,
chainlet_descriptor.chainlet_cls.remote_config,
dep_services,
Expand Down
151 changes: 139 additions & 12 deletions truss-chains/truss_chains/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,38 @@ def validate_and_register_class(cls: Type[definitions.ABCChainlet]) -> None:
_global_chainlet_registry.register_chainlet(chainlet_descriptor)


def validate_base_model(cls: Type[definitions.ABCModel]) -> None:
# NB(nikhil): Following lines cause ERROR TypeError: <class 'model.Model'> is a built-in class
# src_path = os.path.abspath(inspect.getfile(cls))
# line = inspect.getsourcelines(cls)[1]
location = _ErrorLocation(src_path="model.py", line=10, chainlet_name=cls.__name__)

# NB(nikhil): This seems to pass even when my definition doesn't have remote_config, likely
# pulling from default
if not hasattr(cls, definitions.REMOTE_CONFIG_NAME):
_collect_error(
f"Models must have a `{definitions.REMOTE_CONFIG_NAME}` class variable "
f"`{definitions.REMOTE_CONFIG_NAME} = {definitions.RemoteConfig.__name__}"
f"(...)`. Missing for `{cls}`.",
_ErrorKind.MISSING_API_ERROR,
location,
)
return

if not isinstance(
remote_config := getattr(cls, definitions.REMOTE_CONFIG_NAME),
definitions.RemoteConfig,
):
_collect_error(
f"Models must have a `{definitions.REMOTE_CONFIG_NAME}` class variable "
f"of type `{definitions.RemoteConfig}`. Got `{type(remote_config)}` "
f"for `{cls}`.",
_ErrorKind.TYPE_ERROR,
location,
)
return


# Dependency-Injection / Registry ######################################################


Expand Down Expand Up @@ -1177,23 +1209,23 @@ def import_target(
I.e. aiming at making the import idempotent for common usages, although there could
be additional side effects not accounted for by this implementation."""
module_path = pathlib.Path(module_path).resolve()
module_name = module_path.stem # Use the file's name as the module name
if not os.path.isfile(module_path):
resolved_module_path = pathlib.Path(module_path).resolve()
module_name = resolved_module_path.stem # Use the file's name as the module name
if not os.path.isfile(resolved_module_path):
raise ImportError(
f"`{module_path}` is not a file. You must point to a python file where "
f"`{resolved_module_path}` is not a file. You must point to a python file where "
"the entrypoint Chainlet is defined."
)

import_error_msg = f"Could not import `{module_path}`. Check path."
spec = importlib.util.spec_from_file_location(module_name, module_path)
import_error_msg = f"Could not import `{resolved_module_path}`. Check path."
spec = importlib.util.spec_from_file_location(module_name, resolved_module_path)
if not spec:
raise ImportError(import_error_msg)
if not spec.loader:
raise ImportError(import_error_msg)

module = importlib.util.module_from_spec(spec)
module.__file__ = str(module_path)
module.__file__ = str(resolved_module_path)
# Since the framework depends on tracking the source files via `inspect` and this
# depends on the modules bein properly registered in `sys.modules`, we have to
# manually do this here (because importlib does not do it automatically). This
Expand All @@ -1207,7 +1239,7 @@ def import_target(
modules_before = set(sys.modules.keys())
sys.modules[module_name] = module
# Add path for making absolute imports relative to the source_module's dir.
sys.path.insert(0, str(module_path.parent))
sys.path.insert(0, str(resolved_module_path.parent))
chainlets_before = _global_chainlet_registry.get_chainlet_names()
chainlets_after = set()
modules_after = set()
Expand All @@ -1224,7 +1256,7 @@ def import_target(
if not target_cls:
raise AttributeError(
f"Target Chainlet class `{target_name}` not found "
f"in `{module_path}`."
f"in `{resolved_module_path}`."
)
if not utils.issubclass_safe(target_cls, definitions.ABCChainlet):
raise TypeError(
Expand All @@ -1236,13 +1268,13 @@ def import_target(
if len(entrypoints) == 0:
raise ValueError(
"No `target_name` was specified and no Chainlet in "
"`{module_path}` was tagged with `@chains.mark_entrypoint`. Tag "
f"`{module_path}` was tagged with `@chains.mark_entrypoint`. Tag "
"one Chainlet or provide the Chainlet class name."
)
elif len(entrypoints) > 1:
raise ValueError(
"`target_name` was not specified and multiple Chainlets in "
f"`{module_path}` were tagged with `@chains.mark_entrypoint`. Tag "
f"`{resolved_module_path}` were tagged with `@chains.mark_entrypoint`. Tag "
"one Chainlet or provide the Chainlet class name. Found Chainlets: "
f"\n{list(cls.name for cls in entrypoints)}"
)
Expand Down Expand Up @@ -1278,6 +1310,101 @@ def import_target(
for mod in modules_to_delete:
del sys.modules[mod]
try:
sys.path.remove(str(module_path.parent))
sys.path.remove(str(resolved_module_path.parent))
except ValueError: # In case the value was already removed for whatever reason.
pass


# NB(nikhil): mainly taken from above, but with some dependency logic removed
@contextlib.contextmanager
def import_model_target(
module_path: pathlib.Path,
) -> Iterator[Type[definitions.ABCModel]]:
resolved_module_path = pathlib.Path(module_path).resolve()
module_name = resolved_module_path.stem # Use the file's name as the module name
if not os.path.isfile(resolved_module_path):
raise ImportError(
f"`{resolved_module_path}` is not a file. You must point to a python file where "
"the Model is defined."
)

import_error_msg = f"Could not import `{resolved_module_path}`. Check path."
spec = importlib.util.spec_from_file_location(module_name, resolved_module_path)
if not spec:
raise ImportError(import_error_msg)
if not spec.loader:
raise ImportError(import_error_msg)

module = importlib.util.module_from_spec(spec)
module.__file__ = str(resolved_module_path)
# Since the framework depends on tracking the source files via `inspect` and this
# depends on the modules bein properly registered in `sys.modules`, we have to
# manually do this here (because importlib does not do it automatically). This
# registration has to stay at least until the push command has finished.
if module_name in sys.modules:
raise ImportError(
f"{import_error_msg} There is already a module in `sys.modules` "
f"with name `{module_name}`. Overwriting that value is unsafe. "
"Try renaming your source file."
)

modules_before = set(sys.modules.keys())
try:
try:
spec.loader.exec_module(module)
raise_validation_errors()
finally:
modules_after = set(sys.modules.keys())

module_vars = (getattr(module, name) for name in dir(module))
models: set[Type[definitions.ABCModel]] = [
sym
for sym in module_vars
if utils.issubclass_safe(sym, definitions.ABCModel)
]
if len(models) == 0:
raise ValueError(
"No `target_name` was specified and no class in "
f"`{module_path}` extends `ModelBase`."
)
elif len(models) > 1:
raise ValueError(
"`target_name` was not specified and multiple classes in "
f"`{resolved_module_path}` extend `ModelBase`. Ensure "
"exactly one class serves as entrypoint; found classes: "
f"\n{list(cls.name for cls in models)}"
)
target_cls = utils.expect_one(models)
if not utils.issubclass_safe(target_cls, definitions.ABCModel):
raise TypeError(
f"Target `{target_cls}` is not a {definitions.ABCModel}."
)

yield target_cls
finally:
_cleanup_module_imports(modules_before, modules_after, resolved_module_path)


def _cleanup_module_imports(
modules_before: set[str], modules_after: set[str], module_path: pathlib.Path
):
modules_diff = modules_after - modules_before
# Apparently torch import leaves some side effects that cannot be reverted
# by deleting the modules and would lead to a crash when another import
# is attempted. Since torch is a common lib, we make this explicit special
# case and just leave those modules.
# TODO: this seems still brittle and other modules might cause similar problems.
# it would be good to find a more principled solution.
modules_to_delete = {
s for s in modules_diff if not (s.startswith("torch.") or s == "torch")
}
if torch_modules := modules_diff - modules_to_delete:
logging.debug(f"Keeping torch modules after import context: {torch_modules}")

logging.debug(f"Deleting modules when exiting import context: {modules_to_delete}")
for mod in modules_to_delete:
del sys.modules[mod]
try:
sys.path.remove(str(module_path.parent))
except ValueError: # In case the value was already removed for whatever reason.
pass
12 changes: 12 additions & 0 deletions truss-chains/truss_chains/public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,18 @@ def __init_with_arg_check__(self, *args, **kwargs):
cls.__init__ = __init_with_arg_check__ # type: ignore[method-assign]


class ModelBase(definitions.ABCModel):
"""Base class for all singular truss models.
Inheriting from this class adds validations to make sure subclasses adhere to the
truss model pattern.
"""

def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
framework.validate_base_model(cls)


@overload
def mark_entrypoint(
cls_or_chain_name: Type[framework.ChainletT],
Expand Down
22 changes: 22 additions & 0 deletions truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import sys
import time
import traceback
import warnings
from functools import wraps
from pathlib import Path
Expand Down Expand Up @@ -1132,11 +1133,32 @@ def push(
TARGET_DIRECTORY: A Truss directory. If none, use current directory.
"""
from truss_chains import framework
from truss_chains.deployment.code_gen import write_truss_config_yaml

if not remote:
remote = inquire_remote_name(RemoteFactory.get_available_config_names())

remote_provider = RemoteFactory.create(remote=remote)

try:
# Check whether the model file extends our new base class type, if so write
# the config file so _get_truss_from_directory will pick it up
target_path = Path(target_directory)
with framework.import_model_target(
target_path / "model/model.py"
) as entrypoint_cls:
write_truss_config_yaml(
target_path,
entrypoint_cls.remote_config,
{},
entrypoint_cls.remote_config.name,
False, # use_local_chains_src
)
except Exception as e:
print(traceback.print_exc())
raise (e)

tr = _get_truss_from_directory(target_directory=target_directory)

model_name = model_name or tr.spec.config.model_name
Expand Down

0 comments on commit 5ecb0dd

Please sign in to comment.