diff --git a/truss-chains/tests/chains_test.py b/truss-chains/tests/chains_e2e_test.py similarity index 56% rename from truss-chains/tests/chains_test.py rename to truss-chains/tests/chains_e2e_test.py index e3c7fcb89..9e403a91b 100644 --- a/truss-chains/tests/chains_test.py +++ b/truss-chains/tests/chains_e2e_test.py @@ -1,14 +1,10 @@ import logging -import re from pathlib import Path -from typing import List -import pydantic import pytest import requests from truss.tests.test_testing_utilities_for_other_tests import ensure_kill_all -import truss_chains as chains from truss_chains import definitions, framework, public_api, remote, utils utils.setup_dev_logging(logging.DEBUG) @@ -128,114 +124,3 @@ async def test_chain_local(): match="Chainlets cannot be naively instantiated", ): await entrypoint().run_remote(length=20, num_partitions=5) - - -# Chainlet Initialization Guarding ##################################################### - - -def test_raises_without_depends(): - with pytest.raises(definitions.ChainsUsageError, match="chains.provide"): - - class WithoutDepends(chains.ChainletBase): - def __init__(self, chainlet1): - self.chainlet1 = chainlet1 - - def run_remote(self) -> str: - return self.chainlet1.run_remote() - - -class Chainlet1(chains.ChainletBase): - def run_remote(self) -> str: - return self.__class__.name - - -class Chainlet2(chains.ChainletBase): - def run_remote(self) -> str: - return self.__class__.name - - -class InitInInit(chains.ChainletBase): - def __init__(self, chainlet2=chains.depends(Chainlet2)): - self.chainlet1 = Chainlet1() - self.chainlet2 = chainlet2 - - def run_remote(self) -> str: - return self.chainlet1.run_remote() - - -class InitInRun(chains.ChainletBase): - def run_remote(self) -> str: - Chainlet1() - return "abc" - - -def foo(): - return Chainlet1() - - -class InitWithFn(chains.ChainletBase): - def __init__(self): - foo() - - def run_remote(self) -> str: - return self.__class__.name - - -def test_raises_init_in_init(): - match = "Chainlets cannot be naively instantiated" - with pytest.raises(definitions.ChainsRuntimeError, match=match): - with chains.run_local(): - InitInInit() - - -def test_raises_init_in_run(): - match = "Chainlets cannot be naively instantiated" - with pytest.raises(definitions.ChainsRuntimeError, match=match): - with chains.run_local(): - chain = InitInRun() - chain.run_remote() - - -def test_raises_init_in_function(): - match = "Chainlets cannot be naively instantiated" - with pytest.raises(definitions.ChainsRuntimeError, match=match): - with chains.run_local(): - InitWithFn() - - -def test_raises_depends_usage(): - class InlinedDepends(chains.ChainletBase): - def __init__(self): - self.chainlet1 = chains.depends(Chainlet1) - - def run_remote(self) -> str: - return self.chainlet1.run_remote() - - match = ( - "`chains.depends(Chainlet1)` was used, but not as " - "an argument to the `__init__`" - ) - with pytest.raises(definitions.ChainsRuntimeError, match=re.escape(match)): - with chains.run_local(): - chain = InlinedDepends() - chain.run_remote() - - -class SomeModel(pydantic.BaseModel): - foo: int - - -def test_raises_unsupported_arg_type_list_object(): - with pytest.raises(definitions.ChainsUsageError, match="Unsupported I/O type"): - - class UnsupportedArgType(chains.ChainletBase): - def run_remote(self) -> list[pydantic.BaseModel]: - return [SomeModel(foo=0)] - - -def test_raises_unsupported_arg_type_list_object_legacy(): - with pytest.raises(definitions.ChainsUsageError, match="Unsupported I/O type"): - - class UnsupportedArgType(chains.ChainletBase): - def run_remote(self) -> List[pydantic.BaseModel]: - return [SomeModel(foo=0)] diff --git a/truss-chains/tests/test_framework.py b/truss-chains/tests/test_framework.py new file mode 100644 index 000000000..45ef50c73 --- /dev/null +++ b/truss-chains/tests/test_framework.py @@ -0,0 +1,419 @@ +import contextlib +import logging +import re +from typing import List + +import pydantic +import pytest + +import truss_chains as chains +from truss_chains import definitions, framework, public_api, utils + +utils.setup_dev_logging(logging.DEBUG) + + +# Assert that naive chainlet initialization is detected and prevented. ################# + + +class Chainlet1(chains.ChainletBase): + def run_remote(self) -> str: + return self.__class__.name + + +class Chainlet2(chains.ChainletBase): + def run_remote(self) -> str: + return self.__class__.name + + +class InitInInit(chains.ChainletBase): + def __init__(self, chainlet2=chains.depends(Chainlet2)): + self.chainlet1 = Chainlet1() + self.chainlet2 = chainlet2 + + def run_remote(self) -> str: + return self.chainlet1.run_remote() + + +class InitInRun(chains.ChainletBase): + def run_remote(self) -> str: + Chainlet1() + return "abc" + + +def foo(): + return Chainlet1() + + +class InitWithFn(chains.ChainletBase): + def __init__(self): + foo() + + def run_remote(self) -> str: + return self.__class__.name + + +def test_raises_init_in_init(): + match = "Chainlets cannot be naively instantiated" + with pytest.raises(definitions.ChainsRuntimeError, match=match): + with chains.run_local(): + InitInInit() + + +def test_raises_init_in_run(): + match = "Chainlets cannot be naively instantiated" + with pytest.raises(definitions.ChainsRuntimeError, match=match): + with chains.run_local(): + chain = InitInRun() + chain.run_remote() + + +def test_raises_init_in_function(): + match = "Chainlets cannot be naively instantiated" + with pytest.raises(definitions.ChainsRuntimeError, match=match): + with chains.run_local(): + InitWithFn() + + +def test_raises_depends_usage(): + class InlinedDepends(chains.ChainletBase): + def __init__(self): + self.chainlet1 = chains.depends(Chainlet1) + + def run_remote(self) -> str: + return self.chainlet1.run_remote() + + match = ( + "`chains.depends(Chainlet1)` was used, but not as " + "an argument to the `__init__`" + ) + with pytest.raises(definitions.ChainsRuntimeError, match=re.escape(match)): + with chains.run_local(): + chain = InlinedDepends() + chain.run_remote() + + +# Assert that Chain(let) definitions are validated ################################# + + +@contextlib.contextmanager +def _raise_errors(): + framework._global_chainlet_registry.clear() + framework.raise_validation_errors() + yield + framework._global_chainlet_registry.clear() + framework.raise_validation_errors() + + +TEST_FILE = __file__ + + +def test_raises_without_depends(): + match = ( + rf"{TEST_FILE}:\d+ \(WithoutDepends\.__init__\) \[kind: TYPE_ERROR\].*must " + r"have dependency Chainlets with default values from `chains.depends`" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class WithoutDepends(chains.ChainletBase): + def __init__(self, chainlet1): + self.chainlet1 = chainlet1 + + def run_remote(self) -> str: + return self.chainlet1.run_remote() + + +class SomeModel(pydantic.BaseModel): + foo: int + + +def test_raises_unsupported_return_type_list_object(): + match = ( + rf"{TEST_FILE}:\d+ \(UnsupportedArgType\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"Unsupported I/O type for `return_type`" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class UnsupportedArgType(chains.ChainletBase): + def run_remote(self) -> list[pydantic.BaseModel]: + return [SomeModel(foo=0)] + + +def test_raises_unsupported_return_type_list_object_legacy(): + match = ( + rf"{TEST_FILE}:\d+ \(UnsupportedArgType\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"Unsupported I/O type for `return_type`" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class UnsupportedArgType(chains.ChainletBase): + def run_remote(self) -> List[pydantic.BaseModel]: + return [SomeModel(foo=0)] + + +def test_raises_unsupported_arg_type_list_object(): + match = ( + rf"{TEST_FILE}:\d+ \(UnsupportedArgType\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"Unsupported I/O type for `arg`" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class UnsupportedArgType(chains.ChainletBase): + def run_remote(self, arg: list[pydantic.BaseModel]) -> None: + return + + +def test_raises_unsupported_arg_type_object(): + match = ( + rf"{TEST_FILE}:\d+ \(UnsupportedArgType\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"Unsupported I/O type for `arg` of type `<class 'object'>`" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class UnsupportedArgType(chains.ChainletBase): + def run_remote(self, arg: object) -> None: + return + + +def test_raises_unsupported_arg_type_str_annot(): + match = ( + rf"{TEST_FILE}:\d+ \(UnsupportedArgType\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"A string-valued type annotation was found for `arg` of type `SomeModel`" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class UnsupportedArgType(chains.ChainletBase): + def run_remote(self, arg: "SomeModel") -> None: + return + + +def test_raises_endpoint_no_method(): + match = ( + rf"{TEST_FILE}:\d+ \(StaticMethod\.run_remote\) \[kind: TYPE_ERROR\].*" + r"Endpoint must be a method" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class StaticMethod(chains.ChainletBase): + @staticmethod + def run_remote() -> None: + return + + +def test_raises_endpoint_no_method_arg(): + match = ( + rf"{TEST_FILE}:\d+ \(StaticMethod\.run_remote\) \[kind: TYPE_ERROR\].*" + r"Endpoint must be a method" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class StaticMethod(chains.ChainletBase): + @staticmethod + def run_remote(arg: "SomeModel") -> None: + return + + +def test_raises_endpoint_not_annotated(): + match = ( + rf"{TEST_FILE}:\d+ \(NoArgAnnot\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"Arguments of endpoints must have type annotations." + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class NoArgAnnot(chains.ChainletBase): + def run_remote(self, arg) -> None: + return + + +def test_raises_endpoint_return_not_annotated(): + match = ( + rf"{TEST_FILE}:\d+ \(NoReturnAnnot\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"Return values of endpoints must be type annotated." + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class NoReturnAnnot(chains.ChainletBase): + def run_remote(self): + return + + +def test_raises_endpoint_return_not_supported(): + match = ( + rf"{TEST_FILE}:\d+ \(ReturnNotSupported\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"Unsupported I/O type for `return_type` of type `<class 'object'>`" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class ReturnNotSupported(chains.ChainletBase): + def run_remote(self) -> object: + return object() + + +def test_raises_no_endpoint(): + match = ( + rf"{TEST_FILE}:\d+ \(NoEndpoint\) \[kind: MISSING_API_ERROR\].*" + r"Chainlets must have a `run_remote` method." + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class NoEndpoint(chains.ChainletBase): + def rum_remote(self) -> object: + return object() + + +def test_raises_context_not_trailing(): + match = ( + rf"{TEST_FILE}:\d+ \(ContextNotTrailing\.__init__\) \[kind: TYPE_ERROR\].*" + r"The init argument name `context` is reserved for the optional context " + f"argument, which must be trailing" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class Chainlet1(chains.ChainletBase): + def run_remote(self) -> str: + return self.__class__.name + + class ContextNotTrailing(chains.ChainletBase): + def __init__(self, context, chainlet1=chains.depends(Chainlet1)): ... + + +def test_raises_not_dep_marker(): + match = ( + rf"{TEST_FILE}:\d+ \(NoDepMarker\.__init__\) \[kind: TYPE_ERROR\].*" + r"Any arguments of a Chainlet\'s __init__ \(besides `context`\) must have " + f"dependency Chainlets with default values from `chains.depends`-directive" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class Chainlet1(chains.ChainletBase): + def run_remote(self) -> str: + return self.__class__.name + + class NoDepMarker(chains.ChainletBase): + def __init__(self, chainlet1=Chainlet1): ... + + +def test_raises_dep_not_chainlet(): + match = ( + rf"{TEST_FILE}:\d+ \(DepNotChainlet\.__init__\) \[kind: TYPE_ERROR\].*" + r"`chains.depends` must be used with a Chainlet class as argument, got <class " + f"'truss_chains.definitions.RPCOptions'>" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class Chainlet1(chains.ChainletBase): + def run_remote(self) -> str: + return self.__class__.name + + class DepNotChainlet(chains.ChainletBase): + def __init__(self, chainlet1=chains.depends(definitions.RPCOptions)): ... + + +def test_raises_dep_not_chainlet_annot(): + match = ( + rf"{TEST_FILE}:\d+ \(DepNotChainletAnnot\.__init__\) \[kind: TYPE_ERROR\].*" + r"The type annotation for `chainlet1` must be a class/subclass of the " + "Chainlet type" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class Chainlet1(chains.ChainletBase): + def run_remote(self) -> str: + return self.__class__.name + + class DepNotChainletAnnot(chains.ChainletBase): + def __init__( + self, + chainlet1: definitions.RPCOptions = chains.depends(Chainlet1), # type: ignore + ): ... + + +def test_raises_context_missing_default(): + match = ( + rf"{TEST_FILE}:\d+ \(ContextMissingDefault\.__init__\) \[kind: TYPE_ERROR\].*" + r"f `<class \'truss_chains.definitions.ABCChainlet\'>` uses context for " + r"initialization, it must have `context` argument of type `<class " + f"'truss_chains.definitions.DeploymentContext'>`" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class ContextMissingDefault(chains.ChainletBase): + def __init__(self, context=None): ... + + +def test_raises_context_wrong_annot(): + match = ( + rf"{TEST_FILE}:\d+ \(ConextWrongAnnot\.__init__\) \[kind: TYPE_ERROR\].*" + r"f `<class \'truss_chains.definitions.ABCChainlet\'>` uses context for " + r"initialization, it must have `context` argument of type `<class " + f"'truss_chains.definitions.DeploymentContext'>`" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class ConextWrongAnnot(chains.ChainletBase): + def __init__(self, context: object = chains.depends_context()): ... + + +def test_raises_chainlet_reuse(): + match = ( + rf"{TEST_FILE}:\d+ \(ChainletReuse\.__init__\) \[kind: TYPE_ERROR\].*" + r"The same Chainlet class cannot be used multiple times for different arguments" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class Chainlet1(chains.ChainletBase): + def run_remote(self) -> str: + return self.__class__.name + + class ChainletReuse(chains.ChainletBase): + def __init__( + self, dep1=chains.depends(Chainlet1), dep2=chains.depends(Chainlet1) + ): ... + + def run_remote(self) -> None: + return + + +def test_collects_multiple_errors(): + match = r"The Chainlet definitions contain 5 errors:" + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class MultiIssue(chains.ChainletBase): + def __init__(self, context, chainlet1): + self.chainlet1 = chainlet1 + + def run_remote(argument: object): ... + + assert len(framework._global_error_collector._errors) == 5 + + +def test_collects_multiple_errors_run_local(): + class MultiIssue(chains.ChainletBase): + def __init__(self, context, chainlet1): + self.chainlet1 = chainlet1 + + def run_remote(argument: object): ... + + match = r"The Chainlet definitions contain 5 errors:" + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + with public_api.run_local(): + MultiIssue() diff --git a/truss-chains/truss_chains/code_gen.py b/truss-chains/truss_chains/code_gen.py index 22fa8f18f..01cacd22a 100644 --- a/truss-chains/truss_chains/code_gen.py +++ b/truss-chains/truss_chains/code_gen.py @@ -636,9 +636,6 @@ def gen_truss_chainlet( chainlet_display_name_to_url: Mapping[str, str], user_env: Mapping[str, str], ) -> pathlib.Path: - dependencies = framework.global_chainlet_registry.get_dependencies( - chainlet_descriptor - ) # Filter needed services and customize options. dep_services = {} for dep in chainlet_descriptor.dependencies.values(): @@ -672,7 +669,9 @@ def gen_truss_chainlet( f"Python file name `{_MODEL_FILENAME}` is reserved and cannot be used." ) chainlet_file = _gen_truss_chainlet_file( - chainlet_dir, chainlet_descriptor, dependencies + chainlet_dir, + chainlet_descriptor, + framework.get_dependencies(chainlet_descriptor), ) remote_config = chainlet_descriptor.chainlet_cls.remote_config if remote_config.docker_image.data_dir: diff --git a/truss-chains/truss_chains/framework.py b/truss-chains/truss_chains/framework.py index 3563f6f09..a65028a63 100644 --- a/truss-chains/truss_chains/framework.py +++ b/truss-chains/truss_chains/framework.py @@ -1,7 +1,9 @@ import ast +import atexit import collections import contextlib import contextvars +import enum import functools import importlib.util import inspect @@ -28,6 +30,7 @@ ) import pydantic +from typing_extensions import ParamSpec from truss_chains import definitions, utils @@ -40,6 +43,112 @@ _ENTRYPOINT_ATTR_NAME = "_chains_entrypoint" ChainletT = TypeVar("ChainletT", bound=definitions.ABCChainlet) +_P = ParamSpec("_P") +_R = TypeVar("_R") + +# Error Collector ###################################################################### + + +class _ErrorKind(str, enum.Enum): + TYPE_ERROR = enum.auto() + IO_TYPE_ERROR = enum.auto() + MISSING_API_ERROR = enum.auto() + + +class _ErrorLocation(definitions.SafeModel): + src_path: str + line: Optional[int] = None + chainlet_name: Optional[str] = None + method_name: Optional[str] = None + + def __str__(self) -> str: + value = f"{self.src_path}:{self.line}" + if self.chainlet_name and self.method_name: + value = f"{value} ({self.chainlet_name}.{self.method_name})" + elif self.chainlet_name: + value = f"{value} ({self.chainlet_name})" + else: + assert not self.chainlet_name + return value + + +class _ValidationError(definitions.SafeModel): + msg: str + kind: _ErrorKind + location: _ErrorLocation + + def __str__(self) -> str: + return f"{self.location} [kind: {self.kind.name}]: {self.msg}" + + +class _ErrorCollector: + _errors: list[_ValidationError] + + def __init__(self) -> None: + self._errors = [] + # This hook is for the case of just running the Chainlet file, without + # making a push - we want to surface the errors at exit. + atexit.register(self.maybe_display_errors) + + def clear(self) -> None: + self._errors.clear() + + def collect(self, error): + self._errors.append(error) + + @property + def has_errors(self) -> bool: + return bool(self._errors) + + @property + def num_errors(self) -> int: + return len(self._errors) + + def format_errors(self) -> str: + parts = [] + for error in self._errors: + parts.append(str(error)) + + return "\n".join(parts) + + def maybe_display_errors(self) -> None: + if self.has_errors: + sys.stderr.write(self.format_errors()) + + +_global_error_collector = _ErrorCollector() + + +def _collect_error(msg: str, kind: _ErrorKind, location: _ErrorLocation): + _global_error_collector.collect( + _ValidationError(msg=msg, kind=kind, location=location) + ) + + +def raise_validation_errors() -> None: + """Raises validation errors as combined ``ChainsUsageError``""" + if _global_error_collector.has_errors: + error_msg = _global_error_collector.format_errors() + errors_count = ( + "an error" + if _global_error_collector.num_errors == 1 + else f"{_global_error_collector.num_errors} errors" + ) + _global_error_collector.clear() # Clear errors so `atexit` won't display them + raise definitions.ChainsUsageError( + f"The Chainlet definitions contain {errors_count}:\n{error_msg}" + ) + + +def raise_validation_errors_before(f: Callable[_P, _R]) -> Callable[_P, _R]: + """Raises validation errors as combined ``ChainsUsageError`` before invoking `f`.""" + + @functools.wraps(f) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + raise_validation_errors() + return f(*args, **kwargs) + + return wrapper class _BaseProvisionMarker: @@ -54,7 +163,7 @@ def __getattr__(self, item: str) -> Any: logging.error(f"Attempting to access attribute `{item}` on `{self}`.") raise definitions.ChainsRuntimeError( "It seems `chains.depends_context()` was used, but not as an argument " - "to the `__init__` method of a chainlet - This is not supported." + "to the `__init__` method of a Chainlet - This is not supported." f"See {_DOCS_URL_CHAINING}.\n" "Example of correct `__init__` with context:\n" f"{_example_chainlet_code()}" @@ -80,15 +189,15 @@ def __getattr__(self, item: str) -> Any: logging.error(f"Attempting to access attribute `{item}` on `{self}`.") raise definitions.ChainsRuntimeError( f"It seems `chains.depends({self.chainlet_cls.name})` was used, but " - "not as an argument to the `__init__` method of a chainlet - This is not " - "supported. Dependency chainlets must be passed as init arguments.\n" + "not as an argument to the `__init__` method of a Chainlet - This is not " + "supported. Dependency Chainlets must be passed as init arguments.\n" f"See {_DOCS_URL_CHAINING}.\n" "Example of correct `__init__` with dependencies:\n" f"{_example_chainlet_code()}" ) -# Checking of Chainlet class definition ############################################### +# Validation of Chainlet class definition ############################################## @functools.cache @@ -122,9 +231,9 @@ def _example_chainlet_code() -> str: def _instantiation_error_msg(cls_name: str): return ( - f"Error when instantiating chainlet `{cls_name}`. " + f"Error when instantiating Chainlet `{cls_name}`. " "Chainlets cannot be naively instantiated. Possible fixes:\n" - "1. To use chainlets as dependencies in other chainlets 'chaining'), " + "1. To use Chainlets as dependencies in other Chainlets 'chaining'), " f"add them as init argument. See {_DOCS_URL_CHAINING}.\n" f"2. For local / debug execution, use the `{run_local.__name__}`-" f"context. See {_DOCS_URL_LOCAL}.\n" @@ -134,7 +243,9 @@ def _instantiation_error_msg(cls_name: str): ) -def _validate_io_type(annotation: Any, name: str) -> None: +def _validate_io_type( + annotation: Any, param_name: str, location: _ErrorLocation +) -> None: """ For Chainlet I/O (both data or parameters), we allow simple types (int, str, float...) and `list` or `dict` containers of these. @@ -142,88 +253,106 @@ def _validate_io_type(annotation: Any, name: str) -> None: """ containers_str = [c.__name__ for c in _SIMPLE_CONTAINERS] types_str = [c.__name__ if c is not None else "None" for c in _SIMPLE_TYPES] - error_msg = ( - f"Unsupported I/O type `{name}` of type `{annotation}`. Supported are:\n" - f"\t* simple types: {types_str}\n" - f"\t* containers of these simple types, with annotated items: {containers_str}" - ", e.g. `dict[str, int]` (use built-in types, not `typing.Dict`).\n" - "\t* For complicated / nested data structures: `pydantic` models." - ) if isinstance(annotation, str): - raise definitions.ChainsUsageError( - f"A string-valued type annotation was found for `{name}` of type " - f"`{annotation}`. Use only actual types and avoid `from __future__ import " - "annotations` (upgrade python)." + _collect_error( + f"A string-valued type annotation was found for `{param_name}` of type " + f"`{annotation}`. Use only actual types objects and avoid " + "`from __future__ import annotations` (if needed upgrade python).", + _ErrorKind.IO_TYPE_ERROR, + location, ) + return if annotation in _SIMPLE_TYPES: return + + error_msg = ( + f"Unsupported I/O type for `{param_name}` of type `{annotation}`. " + "Supported are:\n" + f"\t* simple types: {types_str}\n" + "\t* containers of these simple types, with annotated item types: " + f"{containers_str}, e.g. `dict[str, int]` (use built-in types, not " + "`typing.Dict`).\n" + "\t* For complicated / nested data structures: `pydantic` models." + ) if isinstance(annotation, types.GenericAlias): if get_origin(annotation) not in _SIMPLE_CONTAINERS: - raise definitions.ChainsUsageError(error_msg) + _collect_error(error_msg, _ErrorKind.IO_TYPE_ERROR, location) + return args = get_args(annotation) for arg in args: if arg not in _SIMPLE_TYPES: - raise definitions.ChainsUsageError(error_msg) + _collect_error(error_msg, _ErrorKind.IO_TYPE_ERROR, location) + return return if utils.issubclass_safe(annotation, pydantic.BaseModel): return - raise definitions.ChainsUsageError(error_msg) + _collect_error(error_msg, _ErrorKind.IO_TYPE_ERROR, location) def _validate_endpoint_params( - params: list[inspect.Parameter], cls_name: str + params: list[inspect.Parameter], location: _ErrorLocation ) -> list[definitions.InputArg]: if len(params) == 0: - raise definitions.ChainsUsageError( - f"`{cls_name}.{definitions.ENDPOINT_METHOD_NAME}` must be a method, i.e. " - f"with `{definitions.SELF_ARG_NAME}` argument." + _collect_error( + f"`Endpoint must be a method, i.e. with `{definitions.SELF_ARG_NAME}` as " + "first argument. Got function with no arguments.", + _ErrorKind.TYPE_ERROR, + location, ) + return [] if params[0].name != definitions.SELF_ARG_NAME: - raise definitions.ChainsUsageError( - f"`{cls_name}.{definitions.ENDPOINT_METHOD_NAME}` must be a method, i.e. " - f"with `{definitions.SELF_ARG_NAME}` argument." + _collect_error( + f"`Endpoint must be a method, i.e. with `{definitions.SELF_ARG_NAME}` as " + f"first argument. Got `{params[0].name}` as first argument.", + _ErrorKind.TYPE_ERROR, + location, ) input_args = [] for param in params[1:]: # Skip self argument. if param.annotation == inspect.Parameter.empty: - raise definitions.ChainsUsageError( - "Inputs of endpoints must have type annotations. For " - f"`{cls_name}.{definitions.ENDPOINT_METHOD_NAME}` parameter " - f"`{param.name}` has no type annotation." + _collect_error( + "Arguments of endpoints must have type annotations. " + f"Parameter `{param.name}` has no type annotation.", + _ErrorKind.IO_TYPE_ERROR, + location, ) - _validate_io_type(param.annotation, param.name) - type_descriptor = definitions.TypeDescriptor(raw=param.annotation) - is_optional = param.default != inspect.Parameter.empty - input_args.append( - definitions.InputArg( - name=param.name, type=type_descriptor, is_optional=is_optional + else: + _validate_io_type(param.annotation, param.name, location) + type_descriptor = definitions.TypeDescriptor(raw=param.annotation) + is_optional = param.default != inspect.Parameter.empty + input_args.append( + definitions.InputArg( + name=param.name, type=type_descriptor, is_optional=is_optional + ) ) - ) return input_args def _validate_endpoint_output_types( - annotation: Any, cls_name: str, signature + annotation: Any, signature, location: _ErrorLocation ) -> list[definitions.TypeDescriptor]: if annotation == inspect.Parameter.empty: - raise definitions.ChainsUsageError( + _collect_error( "Return values of endpoints must be type annotated. Got:\n" - f"{cls_name}.{definitions.ENDPOINT_METHOD_NAME}{signature} -> !MISSING!" + f"\t{location.method_name}{signature} -> !MISSING!", + _ErrorKind.IO_TYPE_ERROR, + location, ) + return [] if get_origin(annotation) is tuple: output_types = [] for i, arg in enumerate(get_args(annotation)): - _validate_io_type(arg, f"return_type[{i}]") + _validate_io_type(arg, f"return_type[{i}]", location) output_types.append(definitions.TypeDescriptor(raw=arg)) else: - _validate_io_type(annotation, "return_type") + _validate_io_type(annotation, "return_type", location) output_types = [definitions.TypeDescriptor(raw=annotation)] return output_types def _validate_and_describe_endpoint( - cls: Type[definitions.ABCChainlet], + cls: Type[definitions.ABCChainlet], location: _ErrorLocation ) -> definitions.EndpointAPIDescriptor: """The "endpoint method" of a Chainlet must have the following signature: @@ -241,22 +370,41 @@ def _validate_and_describe_endpoint( * Generators are allowed, too (but not yet supported). """ if not hasattr(cls, definitions.ENDPOINT_METHOD_NAME): - raise definitions.ChainsUsageError( - f"`{cls.name}` must have a {definitions.ENDPOINT_METHOD_NAME}` method." + _collect_error( + f"Chainlets must have a `{definitions.ENDPOINT_METHOD_NAME}` method.", + _ErrorKind.MISSING_API_ERROR, + location, + ) + # Return a "neutral dummy" if validation fails, this allows to safely + # continue checking for more errors. + return definitions.EndpointAPIDescriptor( + input_args=[], output_types=[], is_async=False, is_generator=False ) + # This is the unbound method. endpoint_method = getattr(cls, definitions.ENDPOINT_METHOD_NAME) + + line = inspect.getsourcelines(endpoint_method)[1] + location = location.model_copy( + update={"line": line, "method_name": definitions.ENDPOINT_METHOD_NAME} + ) + if not inspect.isfunction(endpoint_method): - raise definitions.ChainsUsageError( - f"`{cls.name}.{definitions.ENDPOINT_METHOD_NAME}` must be a method." + _collect_error("`Endpoints must be a method.", _ErrorKind.TYPE_ERROR, location) + # If it's not a function, it might be a class var and subsequent inspections + # fail. + # Return a "neutral dummy" if validation fails, this allows to safely + # continue checking for more errors. + return definitions.EndpointAPIDescriptor( + input_args=[], output_types=[], is_async=False, is_generator=False ) signature = inspect.signature(endpoint_method) input_args = _validate_endpoint_params( - list(signature.parameters.values()), cls.name + list(signature.parameters.values()), location ) output_types = _validate_endpoint_output_types( - signature.return_annotation, cls.name, signature + signature.return_annotation, signature, location ) if inspect.isasyncgenfunction(endpoint_method): @@ -302,29 +450,42 @@ def _get_generic_class_type(var): return origin if origin is not None else var -def _validate_dependency_arg(param) -> ChainletDependencyMarker: +def _validate_dependency_arg( + param, location: _ErrorLocation +) -> Optional[ChainletDependencyMarker]: + # Returns `None` if unvalidated. # TODO: handle subclasses, unions, optionals, check default value etc. if param.name == definitions.CONTEXT_ARG_NAME: - raise definitions.ChainsUsageError( + _collect_error( f"The init argument name `{definitions.CONTEXT_ARG_NAME}` is reserved for " "the optional context argument, which must be trailing if used. Example " "of correct `__init__` with context:\n" - f"{_example_chainlet_code()}" + f"{_example_chainlet_code()}", + _ErrorKind.TYPE_ERROR, + location, ) if not isinstance(param.default, ChainletDependencyMarker): - raise definitions.ChainsUsageError( - f"Any arguments of a chainlet's __init__ (besides `context`) must have " - "dependency chainlets with default values from `chains.provide`-directive. " + _collect_error( + f"Any arguments of a Chainlet's __init__ (besides `context`) must have " + "dependency Chainlets with default values from `chains.depends`-directive. " f"Got `{param}`.\n" f"Example of correct `__init__` with dependencies:\n" - f"{_example_chainlet_code()}" + f"{_example_chainlet_code()}", + _ErrorKind.TYPE_ERROR, + location, ) + return None + chainlet_cls = param.default.chainlet_cls if not utils.issubclass_safe(chainlet_cls, definitions.ABCChainlet): - raise definitions.ChainsUsageError( - f"`{chainlet_cls}` must be a subclass of `{definitions.ABCChainlet}`." + _collect_error( + f"`chains.depends` must be used with a Chainlet class as argument, got " + f"{chainlet_cls} instead.", + _ErrorKind.TYPE_ERROR, + location, ) + return None # Check type annotation. # Also lenient with type annotation: since the RHS / default is asserted to be a # chainlet class, proper type inference is possible even without annotation. @@ -335,10 +496,12 @@ def _validate_dependency_arg(param) -> ChainletDependencyMarker: or utils.issubclass_safe(param.annotation, Protocol) # type: ignore[arg-type] or utils.issubclass_safe(chainlet_cls, param.annotation) ): - raise definitions.ChainsUsageError( - f"The type annotation for `{param.name}` must either be a `{Protocol}` " - "or a class/subclass of the Chainlet type used as default value. " - f"Got `{param.annotation}`." + _collect_error( + f"The type annotation for `{param.name}` must be a class/subclass of the " + "Chainlet type specified by `chains.provides` or a compatible " + f"typing.Protocol`. Got `{param.annotation}`.", + _ErrorKind.TYPE_ERROR, + location, ) return param.default # The Marker. @@ -349,19 +512,19 @@ class _ChainletInitValidator: ``` def __init__( self, - [dep_0: dep_0_type = truss_chains.provide(dep_0_class),] - [dep_1: dep_1_type = truss_chains.provide(dep_1_class),] + [dep_0: dep_0_type = truss_chains.depends(dep_0_class),] + [dep_1: dep_1_type = truss_chains.depends(dep_1_class),] ... - [dep_N: dep_N_type = truss_chains.provide(dep_N_class),] + [dep_N: dep_N_type = truss_chains.provides(dep_N_class),] [context: truss_chains.Context[UserConfig] = truss_chains.provide_context()] ) -> None: ``` * The context argument is optionally trailing and must have a default constructed with the `provide_context` directive. The type can be templated by a user defined config e.g. `truss_chains.Context[UserConfig]`. - * The names and number of chainlet "dependency" arguments are arbitrary. - * Default values for dependencies must be constructed with the `provide` directive - to make the dependency injection work. The argument to `provide` must be a + * The names and number of Chainlet "dependency" arguments are arbitrary. + * Default values for dependencies must be constructed with the `depends` directive + to make the dependency injection work. The argument to `depends` must be a Chainlet class. * The type annotation for dependencies can be a Chainlet class, but it can also be a `Protocol` with an equivalent `run` method (e.g. for getting correct type @@ -369,39 +532,51 @@ def __init__( the type is clear from the RHS. """ + _location: _ErrorLocation has_context: bool validated_dependencies: Mapping[str, definitions.DependencyDescriptor] - def __init__(self, cls: Type[definitions.ABCChainlet]) -> None: + def __init__( + self, cls: Type[definitions.ABCChainlet], location: _ErrorLocation + ) -> None: if not cls.has_custom_init(): self.has_context = False self.validated_dependencies = {} return # Each validation pops of "processed" arguments from the list. + line = inspect.getsourcelines(cls.__init__)[1] + self._location = location.model_copy( + update={"line": line, "method_name": "__init__"} + ) params = list(inspect.signature(cls.__init__).parameters.values()) params = self._validate_self_arg(list(params)) params, self.has_context = self._validate_context_arg(params) self.validated_dependencies = self._validate_dependencies(params) - @staticmethod - def _validate_self_arg(params: list[inspect.Parameter]) -> list[inspect.Parameter]: + def _validate_self_arg( + self, params: list[inspect.Parameter] + ) -> list[inspect.Parameter]: if len(params) == 0: - raise definitions.ChainsUsageError( - "Methods must have first argument `self`, got no arguments." + _collect_error( + "Methods must have first argument `self`, got no arguments.", + _ErrorKind.TYPE_ERROR, + self._location, ) + return params param = params.pop(0) if param.name != definitions.SELF_ARG_NAME: - raise definitions.ChainsUsageError( - f"Methods must have first argument `self`, got `{param.name}`." + _collect_error( + f"Methods must have first argument `self`, got `{param.name}`.", + _ErrorKind.TYPE_ERROR, + self._location, ) return params - @staticmethod def _validate_context_arg( - params: list[inspect.Parameter], + self, params: list[inspect.Parameter] ) -> tuple[list[inspect.Parameter], bool]: - def make_context_exception(): - return definitions.ChainsUsageError( + def make_context_error_msg(): + return ( f"If `{definitions.ABCChainlet}` uses context for initialization, it " f"must have `{definitions.CONTEXT_ARG_NAME}` argument of type " f"`{definitions.DeploymentContext}` as the last argument.\n" @@ -416,7 +591,9 @@ def make_context_exception(): has_context = params[-1].name == definitions.CONTEXT_ARG_NAME has_context_marker = isinstance(params[-1].default, ContextDependencyMarker) if has_context ^ has_context_marker: - raise make_context_exception() + _collect_error( + make_context_error_msg(), _ErrorKind.TYPE_ERROR, self._location + ) if not has_context: return params, has_context @@ -429,75 +606,95 @@ def make_context_exception(): and (param_type != inspect.Parameter.empty) and (not utils.issubclass_safe(param_type, definitions.DeploymentContext)) ): - raise make_context_exception() + _collect_error( + make_context_error_msg(), _ErrorKind.TYPE_ERROR, self._location + ) if not isinstance(param.default, ContextDependencyMarker): - raise definitions.ChainsUsageError( + _collect_error( f"Incorrect default value `{param.default}` for `context` argument. " "Example of correct `__init__` with dependencies:\n" - f"{_example_chainlet_code()}" + f"{_example_chainlet_code()}", + _ErrorKind.TYPE_ERROR, + self._location, ) return params, has_context - @staticmethod def _validate_dependencies( - params, + self, params ) -> Mapping[str, definitions.DependencyDescriptor]: used = set() dependencies = {} for param in params: - marker = _validate_dependency_arg(param) + marker = _validate_dependency_arg(param, self._location) + if marker is None: + continue if marker.chainlet_cls in used: - raise definitions.ChainsUsageError( + _collect_error( f"The same Chainlet class cannot be used multiple times for " f"different arguments. Got previously used " - f"`{marker.chainlet_cls}` for `{param.name}`." + f"`{marker.chainlet_cls}` for `{param.name}`.", + _ErrorKind.TYPE_ERROR, + self._location, ) + dependencies[param.name] = definitions.DependencyDescriptor( chainlet_cls=marker.chainlet_cls, options=marker.options ) - used.add(marker) + used.add(marker.chainlet_cls) return dependencies -def _validate_chainlet_cls(cls: Type[definitions.ABCChainlet]) -> None: - # TODO: ensure Chainlets are only accessed via `provided` in `__init__`, - # not from manual instantiations on module-level or nested in a Chainlet. - # See other constraints listed in: - # https://www.notion.so/ml-infra/WIP-Orchestration-a8cb4dad00dd488191be374b469ffd0a?pvs=4#7df299eb008f467a80f7ee3c0eccf0f0 +def _validate_chainlet_cls( + cls: Type[definitions.ABCChainlet], location: _ErrorLocation +) -> None: if not hasattr(cls, definitions.REMOTE_CONFIG_NAME): - raise definitions.ChainsUsageError( + _collect_error( f"Chainlets must have a `{definitions.REMOTE_CONFIG_NAME}` class variable " f"`{definitions.REMOTE_CONFIG_NAME} = {definitions.RemoteConfig.__name__}" - f"(...)`. Missing for `{cls}`." + f"(...)`. Missing for `{cls}`.", + _ErrorKind.MISSING_API_ERROR, + location, ) + return + if not isinstance( remote_config := getattr(cls, definitions.REMOTE_CONFIG_NAME), definitions.RemoteConfig, ): - raise definitions.ChainsUsageError( + _collect_error( f"Chainlets must have a `{definitions.REMOTE_CONFIG_NAME}` class variable " f"of type `{definitions.RemoteConfig}`. Got `{type(remote_config)}` " - f"for `{cls}`." + f"for `{cls}`.", + _ErrorKind.TYPE_ERROR, + location, ) + return -def check_and_register_class(cls: Type[definitions.ABCChainlet]) -> None: - _validate_chainlet_cls(cls) +def validate_and_register_class(cls: Type[definitions.ABCChainlet]) -> None: + """Note that validation errors will only be collected, not raised, and Chainlets. + with issues, are still added to the registry. Use `raise_validation_errors` to + assert all Chainlets are valid and before performing operations that depend on + these constraints.""" + src_path = os.path.abspath(inspect.getfile(cls)) + line = inspect.getsourcelines(cls)[1] + location = _ErrorLocation(src_path=src_path, line=line, chainlet_name=cls.__name__) - init_validator = _ChainletInitValidator(cls) + _validate_chainlet_cls(cls, location) + init_validator = _ChainletInitValidator(cls, location) chainlet_descriptor = definitions.ChainletAPIDescriptor( chainlet_cls=cls, dependencies=init_validator.validated_dependencies, has_context=init_validator.has_context, - endpoint=_validate_and_describe_endpoint(cls), - src_path=os.path.abspath(inspect.getfile(cls)), + endpoint=_validate_and_describe_endpoint(cls, location), + src_path=src_path, user_config_type=definitions.TypeDescriptor(raw=type(cls.default_user_config)), ) logging.debug( f"Descriptor for {cls}:\n{pprint.pformat(chainlet_descriptor, indent=4)}\n" ) - global_chainlet_registry.register_chainlet(chainlet_descriptor) + _global_chainlet_registry.register_chainlet(chainlet_descriptor) # Dependency-Injection / Registry ###################################################### @@ -515,13 +712,21 @@ def __init__(self) -> None: self._chainlets = collections.OrderedDict() self._name_to_cls = {} + def clear(self): + self._chainlets = collections.OrderedDict() + self._name_to_cls = {} + def register_chainlet(self, chainlet_descriptor: definitions.ChainletAPIDescriptor): for dep in chainlet_descriptor.dependencies.values(): # To depend on a Chainlet, the class must be defined (module initialized) # which entails that is has already been added to the registry. - if dep.chainlet_cls not in self._chainlets: - logging.error(f"Available chainlets: {list(self._chainlets.keys())}") - raise KeyError(dep.chainlet_cls) + # This is an assertion, because unless users meddle with the internal + # registry, it's not possible to depend on another chainlet before it's + # also added to the registry. + assert dep.chainlet_cls in self._chainlets, ( + "Cannot depend on Chainlet. Available Chainlets: " + f"{list(self._chainlets.keys())}" + ) # Because class are globally unique, to prevent re-use / overwriting of names, # We must check this in addition. @@ -529,8 +734,8 @@ def register_chainlet(self, chainlet_descriptor: definitions.ChainletAPIDescript conflict = self._name_to_cls[chainlet_descriptor.name] existing_source_path = self._chainlets[conflict].src_path raise definitions.ChainsUsageError( - f"A chainlet with name `{chainlet_descriptor.name}` was already " - f"defined, chainlet names must be globally unique.\n" + f"A Chainlet with name `{chainlet_descriptor.name}` was already " + f"defined, Chainlet names must be globally unique.\n" f"Pre-existing in: `{existing_source_path}`\n" f"New conflict in: `{chainlet_descriptor.src_path}`." ) @@ -563,7 +768,23 @@ def get_chainlet_names(self) -> set[str]: return set(self._name_to_cls.keys()) -global_chainlet_registry = _ChainletRegistry() +_global_chainlet_registry = _ChainletRegistry() + + +def get_dependencies( + chainlet: definitions.ChainletAPIDescriptor, +) -> Iterable[definitions.ChainletAPIDescriptor]: + return _global_chainlet_registry.get_dependencies(chainlet) + + +def get_descriptor( + chainlet_cls: Type[definitions.ABCChainlet], +) -> definitions.ChainletAPIDescriptor: + return _global_chainlet_registry.get_descriptor(chainlet_cls) + + +def get_ordered_descriptors() -> list[definitions.ChainletAPIDescriptor]: + return _global_chainlet_registry.chainlet_descriptors # Chainlet class runtime utils ######################################################### @@ -591,7 +812,7 @@ def ensure_args_are_injected(cls, original_init: Callable, kwargs) -> None: # The argument is a dependency chainlet. elif isinstance(value, _BaseProvisionMarker): logging.error( - f"When initializing Chainlet `{cls.name}`, for dependency chainlet" + f"When initializing Chainlet `{cls.name}`, for dependency Chainlet" f"argument `{name}` an incompatible value was passed, value: `{value}`." ) raise definitions.ChainsRuntimeError(_instantiation_error_msg(cls.name)) @@ -698,6 +919,7 @@ def __init_local__(self: definitions.ABCChainlet, **kwargs) -> None: @contextlib.contextmanager +@raise_validation_errors_before def run_local( secrets: Mapping[str, str], data_dir: Optional[pathlib.Path], @@ -714,7 +936,7 @@ def run_local( # Capture the stack depth when entering the context manager stack_depth = len(inspect.stack()) token = None - for chainlet_descriptor in global_chainlet_registry.chainlet_descriptors: + for chainlet_descriptor in _global_chainlet_registry.chainlet_descriptors: original_inits[chainlet_descriptor.chainlet_cls] = ( chainlet_descriptor.chainlet_cls.__init__ ) @@ -745,10 +967,15 @@ def run_local( def entrypoint(cls: Type[ChainletT]) -> Type[ChainletT]: - """Decorator to tag a chainlet as an entrypoint.""" + """Decorator to tag a Chainlet as an entrypoint.""" if not (utils.issubclass_safe(cls, definitions.ABCChainlet)): - raise definitions.ChainsUsageError( - "Only chainlet classes can be marked as entrypoint." + src_path = os.path.abspath(inspect.getfile(cls)) + line = inspect.getsourcelines(cls)[1] + location = _ErrorLocation(src_path=src_path, line=line) + _collect_error( + "Only Chainlet classes can be marked as entrypoint.", + _ErrorKind.TYPE_ERROR, + location, ) setattr(cls, _ENTRYPOINT_ATTR_NAME, True) return cls @@ -768,16 +995,16 @@ def import_target( module_path: pathlib.Path, target_name: Optional[str] ) -> Iterator[Type[definitions.ABCChainlet]]: """The context manager ensures that modules imported by the chain and - chainlets registered in ``global_chainlet_registry`` are removed upon exit. + Chainlets registered in ``_global_chainlet_registry`` are removed upon exit. I.e. aiming at making the import idempotent for common usages, although there could - be additional side-effects not accounted for by this implementation.""" + be additional side effects not accounted for by this implementation.""" module_path = pathlib.Path(module_path).resolve() module_name = module_path.stem # Use the file's name as the module name if not os.path.isfile(module_path): raise ImportError( f"`{module_path}` is not a file. You must point to a python file where " - "the entrypoint chainlet is defined." + "the entrypoint Chainlet is defined." ) import_error_msg = f"Could not import `{module_path}`. Check path." @@ -803,21 +1030,22 @@ def import_target( sys.modules[module_name] = module # Add path for making absolute imports relative to the source_module's dir. sys.path.insert(0, str(module_path.parent)) - chainlets_before = global_chainlet_registry.get_chainlet_names() + chainlets_before = _global_chainlet_registry.get_chainlet_names() chainlets_after = set() modules_after = set() try: try: spec.loader.exec_module(module) + raise_validation_errors() finally: modules_after = set(sys.modules.keys()) - chainlets_after = global_chainlet_registry.get_chainlet_names() + chainlets_after = _global_chainlet_registry.get_chainlet_names() if target_name: target_cls = getattr(module, target_name, None) if not target_cls: raise AttributeError( - f"Target chainlet class `{target_name}` not found in `{module_path}`." + f"Target Chainlet class `{target_name}` not found in `{module_path}`." ) if not utils.issubclass_safe(target_cls, definitions.ABCChainlet): raise TypeError( @@ -828,16 +1056,16 @@ def import_target( entrypoints = _get_entrypoint_chainlets(module_vars) if len(entrypoints) == 0: raise ValueError( - f"No `target_name` was specified and no chainlet in `{module_path}` " - "was tagged with `@chains.mark_entrypoint`. Tag one chainlet or provide " - "the chainlet class name." + "No `target_name` was specified and no Chainlet in " + "`{module_path}` was tagged with `@chains.mark_entrypoint`. Tag " + "one Chainlet or provide the Chainlet class name." ) elif len(entrypoints) > 1: raise ValueError( - "`target_name` was not specified and multiple chainlets in " - f"`{module_path}` were tagged with `@chains.mark_entrypoint`. Tag one " - "chainlet or provide the chainlet class name. Found chainlets: \n" - f"{entrypoints}" + "`target_name` was not specified and multiple Chainlets in " + f"`{module_path}` were tagged with `@chains.mark_entrypoint`. Tag " + "one Chainlet or provide the Chainlet class name. Found Chainlets: " + f"\n{entrypoints}" ) target_cls = utils.expect_one(entrypoints) if not utils.issubclass_safe(target_cls, definitions.ABCChainlet): @@ -848,7 +1076,7 @@ def import_target( yield target_cls finally: for chainlet_name in chainlets_after - chainlets_before: - global_chainlet_registry.unregister_chainlet(chainlet_name) + _global_chainlet_registry.unregister_chainlet(chainlet_name) modules_diff = modules_after - modules_before # Apparently torch import leaves some side effects that cannot be reverted diff --git a/truss-chains/truss_chains/public_api.py b/truss-chains/truss_chains/public_api.py index 1b503fa72..d8817decb 100644 --- a/truss-chains/truss_chains/public_api.py +++ b/truss-chains/truss_chains/public_api.py @@ -84,7 +84,7 @@ class ChainletBase(definitions.ABCChainlet): def __init_subclass__(cls, **kwargs) -> None: super().__init_subclass__(**kwargs) - framework.check_and_register_class(cls) + framework.validate_and_register_class(cls) # Errors are collected, not raised! # For default init (from `object`) we don't need to check anything. if cls.has_custom_init(): original_init = cls.__init__ diff --git a/truss-chains/truss_chains/remote.py b/truss-chains/truss_chains/remote.py index 24a618e5a..1ccb7f667 100644 --- a/truss-chains/truss_chains/remote.py +++ b/truss-chains/truss_chains/remote.py @@ -50,7 +50,8 @@ def _push_to_baseten( assert model_name is not None assert bool(_MODEL_NAME_RE.match(model_name)) logging.info( - f"Pushing chainlet `{model_name}` as a truss model on Baseten (publish={options.publish})" + f"Pushing chainlet `{model_name}` as a truss model on " + f"Baseten (publish={options.publish})" ) # Models must be trusted to use the API KEY secret. service = options.remote_provider.push( @@ -147,20 +148,16 @@ def _get_ordered_dependencies( def add_needed_chainlets(chainlet: definitions.ChainletAPIDescriptor): needed_chainlets.add(chainlet) - for chainlet_descriptor in framework.global_chainlet_registry.get_dependencies( - chainlet - ): + for chainlet_descriptor in framework.get_dependencies(chainlet): needed_chainlets.add(chainlet_descriptor) add_needed_chainlets(chainlet_descriptor) for chainlet_cls in chainlets: - add_needed_chainlets( - framework.global_chainlet_registry.get_descriptor(chainlet_cls) - ) - # Iterating over the registry ensures topological ordering. + add_needed_chainlets(framework.get_descriptor(chainlet_cls)) + # Get dependencies in topological order. return [ descr - for descr in framework.global_chainlet_registry.chainlet_descriptors + for descr in framework.get_ordered_descriptors() if descr in needed_chainlets ] @@ -433,6 +430,7 @@ def push( raise NotImplementedError(self._options) +@framework.raise_validation_errors_before def push( entrypoint: Type[definitions.ABCChainlet], options: definitions.PushOptions, @@ -692,6 +690,7 @@ def watch(self, user_env: Optional[Mapping[str, str]]) -> None: self._console.print("👀 Watching for new changes.", style="blue") +@framework.raise_validation_errors_before def watch( source: pathlib.Path, entrypoint: Optional[str], diff --git a/truss/cli/cli.py b/truss/cli/cli.py index d505ed41b..b35aa09ac 100644 --- a/truss/cli/cli.py +++ b/truss/cli/cli.py @@ -96,7 +96,10 @@ def wrapper(*args, **kwargs): raise e # You can re-raise the exception or handle it different except Exception as e: if is_humanfriendly_log_level: - click.secho(f"ERROR: {type(e).__name__}: {e}", fg="red") + console.print( + f"[bold red]ERROR {type(e).__name__}[/bold red]: {e}", + highlight=True, + ) else: console.print_exception(show_locals=True)