Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
nnarayen committed Jan 29, 2025
1 parent 7f7d449 commit 2568a20
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 80 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
class PassthroughModel:
class ClassWithoutModelInheritance:
def __init__(self):
self._call_count = 0

Expand Down
6 changes: 3 additions & 3 deletions truss-chains/tests/test_framework.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import contextlib
import logging
import pathlib
import re
from pathlib import Path
from typing import AsyncIterator, Iterator, List

import pydantic
Expand All @@ -13,7 +13,7 @@

utils.setup_dev_logging(logging.DEBUG)

TEST_ROOT = Path(__file__).parent.resolve()
TEST_ROOT = pathlib.Path(__file__).parent.resolve()

# Assert that naive chainlet initialization is detected and prevented. #################

Expand Down Expand Up @@ -673,7 +673,7 @@ async def run_remote(self) -> str:


def test_import_model_requires_entrypoint():
model_src = TEST_ROOT / "import" / "standalone_without_entrypoint.py"
model_src = TEST_ROOT / "import" / "model_without_inheritance.py"
match = r"No Model class in `.+` inherits from"
with pytest.raises(ValueError, match=match), _raise_errors():
with framework.ModelImporter.import_target(model_src):
Expand Down
3 changes: 1 addition & 2 deletions truss-chains/truss_chains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@
RemoteErrorDetail,
RPCOptions,
)
from truss_chains.framework import ChainletBase, ModelBase
from truss_chains.public_api import (
ChainletBase,
ModelBase,
depends,
depends_context,
mark_entrypoint,
Expand Down
92 changes: 73 additions & 19 deletions truss-chains/truss_chains/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import pydantic
from typing_extensions import ParamSpec

from truss_chains import definitions, public_api, utils
from truss_chains import definitions, utils

_SIMPLE_TYPES = {int, float, complex, bool, str, bytes, None, pydantic.BaseModel}
_SIMPLE_CONTAINERS = {list, dict}
Expand Down Expand Up @@ -1215,27 +1215,27 @@ def decorator(cls: Type[ChainletT]) -> Type[ChainletT]:
class _ABCImporter(abc.ABC):
@classmethod
@abc.abstractmethod
def no_entrypoint_error(cls, module_path: pathlib.Path) -> ValueError:
def _no_entrypoint_error(cls, module_path: pathlib.Path) -> ValueError:
pass

@classmethod
@abc.abstractmethod
def multiple_entrypoints_error(
def _multiple_entrypoints_error(
cls, module_path: pathlib.Path, entrypoints: set[type[definitions.ABCChainlet]]
) -> ValueError:
pass

@classmethod
@abc.abstractmethod
def target_cls_type(cls) -> Type[definitions.ABCChainlet]:
def _target_cls_type(cls) -> Type[definitions.ABCChainlet]:
pass

@classmethod
def _get_entrypoint_chainlets(cls, symbols) -> set[Type[definitions.ABCChainlet]]:
return {
sym
for sym in symbols
if utils.issubclass_safe(sym, cls.target_cls_type())
if utils.issubclass_safe(sym, cls._target_cls_type())
and cast(definitions.ABCChainlet, sym).meta_data.is_entrypoint
}

Expand Down Expand Up @@ -1336,17 +1336,17 @@ def import_target(
f"Target class `{target_name}` not found "
f"in `{resolved_module_path}`."
)
if not utils.issubclass_safe(target_cls, cls.target_cls_type()):
if not utils.issubclass_safe(target_cls, cls._target_cls_type()):
raise TypeError(
f"Target `{target_cls}` is not a {cls.target_cls_type()}."
f"Target `{target_cls}` is not a {cls._target_cls_type()}."
)
else:
module_vars = (getattr(module, name) for name in dir(module))
entrypoints = cls._get_entrypoint_chainlets(module_vars)
if len(entrypoints) == 0:
raise cls.no_entrypoint_error(module_path)
raise cls._no_entrypoint_error(module_path)
elif len(entrypoints) > 1:
raise cls.multiple_entrypoints_error(module_path, entrypoints)
raise cls._multiple_entrypoints_error(module_path, entrypoints)
target_cls = utils.expect_one(entrypoints)
yield target_cls
finally:
Expand All @@ -1359,15 +1359,15 @@ def import_target(

class ChainletImporter(_ABCImporter):
@classmethod
def no_entrypoint_error(cls, module_path: pathlib.Path) -> ValueError:
def _no_entrypoint_error(cls, module_path: pathlib.Path) -> ValueError:
return ValueError(
"No `target_name` was specified and no Chainlet in "
f"`{module_path}` was tagged with `@chains.mark_entrypoint`. Tag "
"one Chainlet or provide the Chainlet class name."
)

@classmethod
def multiple_entrypoints_error(
def _multiple_entrypoints_error(
cls, module_path: pathlib.Path, entrypoints: set[type[definitions.ABCChainlet]]
) -> ValueError:
return ValueError(
Expand All @@ -1378,27 +1378,81 @@ def multiple_entrypoints_error(
)

@classmethod
def target_cls_type(cls) -> Type[definitions.ABCChainlet]:
return public_api.ChainletBase
def _target_cls_type(cls) -> Type[definitions.ABCChainlet]:
return ChainletBase


class ModelImporter(_ABCImporter):
@classmethod
def no_entrypoint_error(cls, module_path: pathlib.Path) -> ValueError:
def _no_entrypoint_error(cls, module_path: pathlib.Path) -> ValueError:
return ValueError(
f"No Model class in `{module_path}` inherits from {cls.target_cls_type()}."
f"No Model class in `{module_path}` inherits from {cls._target_cls_type()}."
)

@classmethod
def multiple_entrypoints_error(
def _multiple_entrypoints_error(
cls, module_path: pathlib.Path, entrypoints: set[type[definitions.ABCChainlet]]
) -> ValueError:
return ValueError(
f"Multiple Model classes in `{module_path}` inherit from {cls.target_cls_type()}, "
f"Multiple Model classes in `{module_path}` inherit from {cls._target_cls_type()}, "
"but only one allowed. Found classes: "
f"\n{list(cls.name for cls in entrypoints)}"
)

@classmethod
def target_cls_type(cls) -> Type[definitions.ABCChainlet]:
return public_api.ModelBase
def _target_cls_type(cls) -> Type[definitions.ABCChainlet]:
return ModelBase


class ChainletBase(definitions.ABCChainlet):
"""Base class for all chainlets.
Inheriting from this class adds validations to make sure subclasses adhere to the
chainlet pattern and facilitates remote chainlet deployment.
Refer to `the docs <https://docs.baseten.co/chains/getting-started>`_ and this
`example chainlet <https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py>`_
for more guidance on how to create subclasses.
"""

def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
cls._framework_config = definitions.FrameworkConfig(
entity_type="Chainlet",
supports_dependencies=True,
endpoint_method_name=definitions.RUN_REMOTE_METHOD_NAME,
)
# Each sub-class has own, isolated metadata, e.g. we don't want
# `mark_entrypoint` to propagate to subclasses.
cls.meta_data = definitions.ChainletMetadata()
validate_and_register_cls(cls) # Errors are collected, not raised!
# For default init (from `object`) we don't need to check anything.
if cls.has_custom_init():
original_init = cls.__init__

@functools.wraps(original_init)
def __init_with_arg_check__(self, *args, **kwargs):
if args:
raise definitions.ChainsRuntimeError("Only kwargs are allowed.")
ensure_args_are_injected(cls, original_init, kwargs)
original_init(self, *args, **kwargs)

cls.__init__ = __init_with_arg_check__ # type: ignore[method-assign]


class ModelBase(definitions.ABCChainlet):
"""Base class for all standalone models.
Inheriting from this class adds validations to make sure subclasses adhere to the
truss model pattern.
"""

def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
cls._framework_config = definitions.FrameworkConfig(
entity_type="Model",
supports_dependencies=False,
endpoint_method_name=definitions.MODEL_ENDPOINT_METHOD_NAME,
)
cls.meta_data = definitions.ChainletMetadata(is_entrypoint=True)
validate_and_register_cls(cls)
55 changes: 0 additions & 55 deletions truss-chains/truss_chains/public_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import functools
import pathlib
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -91,60 +90,6 @@ def depends(
return framework.ChainletDependencyMarker(chainlet_cls, options) # type: ignore


class ChainletBase(definitions.ABCChainlet):
"""Base class for all chainlets.
Inheriting from this class adds validations to make sure subclasses adhere to the
chainlet pattern and facilitates remote chainlet deployment.
Refer to `the docs <https://docs.baseten.co/chains/getting-started>`_ and this
`example chainlet <https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py>`_
for more guidance on how to create subclasses.
"""

def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
cls._framework_config = definitions.FrameworkConfig(
entity_type="Chainlet",
supports_dependencies=True,
endpoint_method_name=definitions.RUN_REMOTE_METHOD_NAME,
)
# Each sub-class has own, isolated metadata, e.g. we don't want
# `mark_entrypoint` to propagate to subclasses.
cls.meta_data = definitions.ChainletMetadata()
framework.validate_and_register_cls(cls) # Errors are collected, not raised!
# For default init (from `object`) we don't need to check anything.
if cls.has_custom_init():
original_init = cls.__init__

@functools.wraps(original_init)
def __init_with_arg_check__(self, *args, **kwargs):
if args:
raise definitions.ChainsRuntimeError("Only kwargs are allowed.")
framework.ensure_args_are_injected(cls, original_init, kwargs)
original_init(self, *args, **kwargs)

cls.__init__ = __init_with_arg_check__ # type: ignore[method-assign]


class ModelBase(definitions.ABCChainlet):
"""Base class for all singular truss models.
Inheriting from this class adds validations to make sure subclasses adhere to the
truss model pattern.
"""

def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
cls._framework_config = definitions.FrameworkConfig(
entity_type="Model",
supports_dependencies=False,
endpoint_method_name=definitions.MODEL_ENDPOINT_METHOD_NAME,
)
cls.meta_data = definitions.ChainletMetadata(is_entrypoint=True)
framework.validate_and_register_cls(cls)


@overload
def mark_entrypoint(
cls_or_chain_name: Type[framework.ChainletT],
Expand Down

0 comments on commit 2568a20

Please sign in to comment.