diff --git a/docs/chains/doc_gen/README.md b/docs/chains/doc_gen/README.md new file mode 100644 index 000000000..f13a18490 --- /dev/null +++ b/docs/chains/doc_gen/README.md @@ -0,0 +1,18 @@ +This generation process of the documentation is *extremely* scrappy and just +an interim solution. It requires significant manual oversight and the code +quality in this directory is non-existent. + +The general process is: +1. Document as much as possible in the code, including usage examples, links + etc. +2. Auto-generate `generated-API-reference.mdx` with `poetry run python + docs/chains/doc_gen/generate_reference.py`. This applies the patch file and + launches meld to resolve conflicts. +4. Proofread `docs/snippets/chains/API-reference.mdx`. +5. If proofreading leads to edits or the upstream docstrings changed lot, + update the patch file: `diff -u \ + docs/chains/doc_gen/generated-reference.mdx \ + docs/snippets/chains/API-reference.mdx > \ + docs/chains/doc_gen/reference.patch` + +For questions, please reach out to @marius-baseten. diff --git a/docs/chains/doc_gen/generate_reference.py b/docs/chains/doc_gen/generate_reference.py new file mode 100644 index 000000000..de9c477c8 --- /dev/null +++ b/docs/chains/doc_gen/generate_reference.py @@ -0,0 +1,215 @@ +# type: ignore # This tool is only for Marius. +"""Script to auot-generate the API reference for Truss Chains.""" +import inspect +import pathlib +import shutil +import subprocess +import tempfile +from pathlib import Path + +import truss_chains as chains +from sphinx import application + +DUMMY_INDEX_RST = """ +.. Dummy + +Welcome to Truss Chains's documentation! +======================================== + +.. 
toctree:: + :maxdepth: 2 + :caption: Contents: + + modules +""" + + +BUILDER = "mdx_adapter" # "html" +NON_PUBLIC_SYMBOLS = [ + # "truss_chains.definitions.AssetSpec", + # "truss_chains.definitions.ComputeSpec", + "truss_chains.deploy.ChainService", +] + + +SECTION_CHAINLET = ( + "Chainlets", + "APIs for creating user-defined Chainlets.", + [ + "truss_chains.ChainletBase", + "truss_chains.depends", + "truss_chains.depends_context", + "truss_chains.DeploymentContext", + "truss_chains.RPCOptions", + "truss_chains.mark_entrypoint", + ], +) +SECTION_CONFIG = ( + "Remote Configuration", + ( + "These data structures specify for each chainlet how it gets deployed " + "remotely, e.g. dependencies and compute resources." + ), + [ + "truss_chains.RemoteConfig", + "truss_chains.DockerImage", + "truss_chains.Compute", + "truss_chains.Assets", + ], +) +SECTION_UTILITIES = ( + "Core", + "General framework and helper functions.", + [ + "truss_chains.deploy_remotely", + "truss_chains.deploy.ChainService", + "truss_chains.make_abs_path_here", + "truss_chains.run_local", + "truss_chains.ServiceDescriptor", + "truss_chains.StubBase", + "truss_chains.RemoteErrorDetail", + # "truss_chains.ChainsRuntimeError", + ], +) + +SECTIONS = [SECTION_CHAINLET, SECTION_CONFIG, SECTION_UTILITIES] + + +def _list_imported_symbols(module: object) -> dict[str, str]: + imported_symbols = { + f"truss_chains.{name}": ( + "autoclass" + if inspect.isclass(obj) + else "autofunction" + if inspect.isfunction(obj) + else "autodata" + ) + for name, obj in inspect.getmembers(module) + if not name.startswith("_") and not inspect.ismodule(obj) + } + # Extra classes that are not really exported as public API, but are still relevant. 
+ imported_symbols.update({sym: "autoclass" for sym in NON_PUBLIC_SYMBOLS}) + return imported_symbols + + +def _make_rst_structure(chains): + exported_symbols = _list_imported_symbols(chains) + rst_parts = ["API Reference"] + rst_parts.append("=" * len(rst_parts[-1]) + "\n") + + for name, descr, symbols in SECTIONS: + rst_parts.append(name) + rst_parts.append("=" * len(rst_parts[-1]) + "\n") + rst_parts.append(descr) + rst_parts.append("\n") + + for symbol in symbols: + kind = exported_symbols.pop(symbol) + rst_parts.append(f".. {kind}:: {symbol}") + rst_parts.append("\n") + + if exported_symbols: + raise ValueError( + "All symbols must be mapped to a section. Left over:" + f"{list(exported_symbols.keys())}." + ) + return "\n".join(rst_parts) + + +def _clean_build_directory(build_dir: Path) -> None: + if build_dir.exists() and build_dir.is_dir(): + shutil.rmtree(build_dir) + build_dir.mkdir(parents=True, exist_ok=True) + + +def _apply_patch( + original_file_path: str, patch_file_path: str, output_file_path: str +) -> None: + original_file = Path(original_file_path) + patch_file = Path(patch_file_path) + output_file = Path(output_file_path) + + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + temp_output_file_path = Path(temp_file.name) + + try: + subprocess.run( + [ + "patch", + str(original_file), + "-o", + str(temp_output_file_path), + str(patch_file), + ], + check=True, + capture_output=True, + text=True, + ) + + # Copy temp file to final output if no errors + shutil.copy(temp_output_file_path, output_file) + + except subprocess.CalledProcessError as e: + reject_file = temp_output_file_path.with_suffix(".rej") + if reject_file.exists(): + print(f"Conflicts found, saved to {reject_file}") + subprocess.run( + [ + "meld", + str(original_file_path), + str(output_file), + str(temp_output_file_path), + ], + check=True, + ) + else: + print(f"Patch failed: {e.stderr}") + + finally: + if temp_output_file_path.exists(): + temp_output_file_path.unlink() + + 
+def generate_sphinx_docs( + output_dir: pathlib.Path, + snippets_dir: pathlib.Path, +) -> None: + _clean_build_directory(output_dir) + config_file = pathlib.Path(__file__).parent / "sphinx_config.py" + docs_dir = output_dir / "docs" + conf_dir = docs_dir + doctree_dir = docs_dir / "doctrees" + + docs_dir.mkdir(parents=True, exist_ok=True) + (docs_dir / "conf.py").write_text(config_file.read_text()) + (docs_dir / "index.rst").write_text(DUMMY_INDEX_RST) + (docs_dir / "modules.rst").write_text(_make_rst_structure(chains)) + + app = application.Sphinx( + srcdir=str(docs_dir), + confdir=str(conf_dir), + outdir=str(Path(output_dir).resolve()), + doctreedir=str(doctree_dir), + buildername=BUILDER, + ) + app.build() + if BUILDER == "mdx_adapter": + dog_gen_dir = pathlib.Path(__file__).parent.absolute() + generated_reference_path = dog_gen_dir / "generated-reference.mdx" + shutil.copy(output_dir / "modules.mdx", generated_reference_path) + patch_file_path = dog_gen_dir / "reference.patch" + # Apply patch to generated_reference_path + snippets_reference_path = snippets_dir / "chains/API-reference.mdx" + _apply_patch( + str(generated_reference_path), + str(patch_file_path), + str(snippets_reference_path), + ) + + +if __name__ == "__main__": + snippets_dir = pathlib.Path(__file__).parent.parent.parent.absolute() / "snippets" + generate_sphinx_docs( + output_dir=pathlib.Path("/tmp/doc_gen"), + snippets_dir=snippets_dir, + ) diff --git a/docs/chains/doc_gen/generated-reference.mdx b/docs/chains/doc_gen/generated-reference.mdx new file mode 100644 index 000000000..910c10995 --- /dev/null +++ b/docs/chains/doc_gen/generated-reference.mdx @@ -0,0 +1,630 @@ +# API Reference + +# Chainlets + +APIs for creating user-defined Chainlets. + +### *class* `truss_chains.ChainletBase` + +Base class for all chainlets. + +Inheriting from this class adds validations to make sure subclasses adhere to the +chainlet pattern and facilitates remote chainlet deployment. 
+ +Refer to [the docs](https://truss.baseten.co/chains/getting-started) and this +[example chainlet](https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py) +for more guidance on how to create subclasses. + +### `truss_chains.depends` + +Sets a “symbolic marker” to indicate to the framework that a chainlet is a +dependency of another chainlet. The return value of `depends` is intended to be +used as a default argument in a chainlet’s `__init__`-method. +When deploying a chain remotely, a corresponding stub to the remote is injected in +its place. In `run_local` mode an instance of a local chainlet is injected. + +Refer to [the docs](https://truss.baseten.co/chains/getting-started) and this +[example chainlet](https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py) +for more guidance on how make one chainlet depend on another chainlet. + +#### WARNING +Despite the type annotation, this does *not* immediately provide a +chainlet instance. Only when deploying remotely or using `run_local` a +chainlet instance is provided. + + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `chainlet_cls` | *Type[ChainletT]* | The chainlet class of the dependency. | +| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). | + +* **Returns:** + A “symbolic marker” to be used as a default argument in a chainlet’s + initializer. +* **Return type:** + *ChainletT* + +### `truss_chains.depends_context` + +Sets a “symbolic marker” for injecting a context object at runtime. + +Refer to [the docs](https://truss.baseten.co/chains/getting-started) and this +[example chainlet](https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py) +for more guidance on the `__init__`-signature of chainlets. 
+ +#### WARNING +Despite the type annotation, this does *not* immediately provide a +context instance. Only when deploying remotely or using `run_local` a +context instance is provided. + +* **Returns:** + A “symbolic marker” to be used as a default argument in a chainlet’s + initializer. +* **Return type:** + [*DeploymentContext*](#truss_chains.DeploymentContext) + +### *class* `truss_chains.DeploymentContext` + +Bases: `pydantic.BaseModel`, `Generic`[`UserConfigT`] + +Bundles config values and resources needed to instantiate Chainlets. + +This is provided at runtime to the Chainlet’s `__init__` method. + + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `data_dir` | *Path\|None* | The directory where the chainlet can store and access data, e.g. for downloading model weights. | +| `user_config` | *UserConfigT* | User-defined configuration for the chainlet. | +| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#truss_chains.ServiceDescriptor* | A mapping from chainlet names to service descriptors. This is used create RPCs sessions to dependency chainlets. It contains only the chainlet services that are dependencies of the current chainlet. | +| `secrets` | *MappingNoIter[str,str]* | A mapping from secret names to secret values. It contains only the secrets that are listed in `remote_config.assets.secret_keys` of the current chainlet. | + +#### chainlet_to_service *: Mapping[str, [ServiceDescriptor](#truss_chains.ServiceDescriptor)]* + +#### data_dir *: Path | None* + +#### get_baseten_api_key() + +* **Return type:** + str + +#### get_service_descriptor(chainlet_name) + +* **Parameters:** + **chainlet_name** (*str*) +* **Return type:** + [*ServiceDescriptor*](#truss_chains.ServiceDescriptor) + +#### secrets *: MappingNoIter[str, str]* + +#### user_config *: UserConfigT* + +### *class* `truss_chains.RPCOptions` + +Bases: `pydantic.BaseModel` + +Options to customize RPCs to dependency chainlets. 
+ +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `timeout_sec` | *int* | | +| `retries` | *int* | | + + +#### retries *: int* + +#### timeout_sec *: int* + +### `truss_chains.mark_entrypoint` + +Decorator to mark a chainlet as the entrypoint of a chain. + +This decorator can be applied to *one* chainlet in a source file and then the +CLI deploy command simplifies because only the file, but not the chainlet class +in the file needs to be specified. + +Example usage: + +```default +import truss_chains as chains + +@chains.mark_entrypoint +class MyChainlet(ChainletBase): + ... +``` + +* **Parameters:** + **cls** (*Type* *[**ChainletT* *]*) +* **Return type:** + *Type*[*ChainletT*] + +# Remote Configuration + +These data structures specify for each chainlet how it gets deployed remotely, e.g. dependencies and compute resources. + +### *class* `truss_chains.RemoteConfig` + +Bases: `pydantic.BaseModel` + +Bundles config values needed to deploy a chainlet remotely.. + +This is specified as a class variable for each chainlet class, e.g.: + +```default +import truss_chains as chains + + +class MyChainlet(chains.ChainletBase): + remote_config = chains.RemoteConfig( + docker_image=chains.DockerImage( + pip_requirements=["torch==2.0.1", ... 
] + ), + compute=chains.Compute(cpu_count=2, gpu="A10G", ...), + assets=chains.Assets(secret_keys=["hf_access_token"], ...), + ) +``` + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `docker_image` | *[DockerImage](#truss_chains.DockerImage* | | +| `compute` | *[Compute](#truss_chains.Compute* | | +| `assets` | *[Assets](#truss_chains.Assets* | | +| `name` | *str\|None* | | + + +#### assets *: [Assets](#truss_chains.Assets)* + +#### compute *: [Compute](#truss_chains.Compute)* + +#### docker_image *: [DockerImage](#truss_chains.DockerImage)* + +#### get_asset_spec() + +* **Return type:** + *AssetSpec* + +#### get_compute_spec() + +* **Return type:** + *ComputeSpec* + +#### name *: str | None* + +### *class* `truss_chains.DockerImage` + +Bases: `pydantic.BaseModel` + +Configures the docker image in which a remoted chainlet is deployed. + +#### NOTE +Any paths are relative to the source file where `DockerImage` is +defined and must be created with the helper function `make_abs_path_here`. +This allows you for example organize chainlets in different (potentially nested) +modules and keep their requirement files right next their python source files. + + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `base_image` | *str* | The base image to use for the chainlet. Default is `python:3.11-slim`. | +| `pip_requirements_file` | *AbsPath\|None* | Path to a file containing pip requirements. The file content is naively concatenated with `pip_requirements`. | +| `pip_requirements` | *list[str]* | A list of pip requirements to install. The items are naively concatenated with the content of the `pip_requirements_file`. | +| `apt_requirements` | *list[str]* | A list of apt requirements to install. | +| `data_dir` | *AbsPath\|None* | Data from this directory is copied into the docker image and accessible to the remote chainlet at runtime. 
| +| `external_package_dirs` | *list[AbsPath]\|None* | A list of directories containing additional python packages outside the chain’s workspace dir, e.g. a shared library. This code is copied into the docker image and importable at runtime. | + +#### apt_requirements *: list[str]* + +#### base_image *: str* + +#### data_dir *: AbsPath | None* + +#### external_package_dirs *: list[AbsPath] | None* + +#### pip_requirements *: list[str]* + +#### pip_requirements_file *: AbsPath | None* + +### *class* `truss_chains.Compute` + +Specifies which compute resources a chainlet has in the *remote* deployment. + +#### NOTE +Not all combinations can be exactly satisfied by available hardware, in some +cases more powerful machine types are chosen to make sure requirements are met or +over-provisioned. Refer to the +[baseten instance reference](https://docs.baseten.co/performance/instances). + + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `cpu_count` | *int* | Minimum number of CPUs to allocate. | +| `memory` | *str* | Minimum memory to allocate, e.g. “2Gi” (2 gibibytes). | +| `gpu` | *str\|Accelerator\|None* | GPU accelerator type, e.g. “A10G”, “A100”, refer to the [truss config](https://truss.baseten.co/reference/config#resources-accelerator) for more choices. | +| `gpu_count` | *int* | Number of GPUs to allocate. | +| `predict_concurrency` | *int\|Literal['cpu_count']* | Number of concurrent requests a single replica of a deployed chainlet handles. | + + +Concurrency concepts are explained in [this guide](https://truss.baseten.co/guides/concurrency). +It is important to understand the difference between predict_concurrency and +the concurrency target (used for autoscaling, i.e. adding or removing replicas). +Furthermore, the `predict_concurrency` of a single instance is implemented in +two ways: + +- Via python’s `asyncio`, if `run_remote` is an async def. This + requires that `run_remote` yields to the event loop. 
+- With a threadpool if it’s a synchronous function. This requires + that the threads don’t have significant CPU load (due to the GIL). + +#### get_spec() + +* **Return type:** + *ComputeSpec* + +### *class* `truss_chains.Assets` + +Specifies which assets a chainlet can access in the remote deployment. + +Model weight caching can be used like this: + +```default +import truss_chains as chains +from truss import truss_config + +mistral_cache = truss_config.ModelRepo( + repo_id="mistralai/Mistral-7B-Instruct-v0.2", + allow_patterns=["*.json", "*.safetensors", ".model"] + ) +chains.Assets(cached=[mistral_cache], ...) +``` + +See [truss caching guide](https://truss.baseten.co/guides/model-cache#enabling-caching-for-a-model) +for more details on caching. + + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `cached` | *Iterable[ModelRepo]* | One or more `truss_config.ModelRepo` objects. | +| `secret_keys` | *Iterable[str]* | Names of secrets stored on baseten, that the chainlet should have access to. You can manage secrets on baseten [here](https://app.baseten.co/settings/secrets). | + + +#### get_spec() + +Returns parsed and validated assets. + +* **Return type:** + *AssetSpec* + +# Core + +General framework and helper functions. + +### `truss_chains.deploy_remotely` + +Deploys a chain remotely (with all dependent chainlets). + + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `entrypoint` | *Type[ABCChainlet]* | The chainlet class that serves as the entrypoint to the chain. | +| `chain_name` | *str* | The name of the chain. | +| `publish` | *bool* | Whether to publish the chain as a published deployment (it is a draft deployment otherwise) | +| `promote` | *bool* | Whether to promote the chain to be the production deployment (this implies publishing as well). | +| `only_generate_trusses` | *bool* | Used for debugging purposes. 
If set to True, only the the underlying truss models for the chainlets are generated in `/tmp/.chains_generated`. | + +* **Returns:** + A chain service handle to the deployed chain. +* **Return type:** + [*ChainService*](#truss_chains.deploy.ChainService) + +### *class* `truss_chains.deploy.ChainService` + +Handle for a deployed chain. + +A `ChainService` is created and returned when using `deploy_remotely`. It +bundles the individual services for each chainlet in the chain, and provides +utilities to query their status, invoke the entrypoint etc. + + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `entrypoint` | *str* | Name of the entrypoint chainlet. | +| `name` | *str* | Name of the chain. | + + +#### add_service(name, service) + +Used to add a chainlet service during the deployment sequence of the chain. + + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `name` | *str* | Chainlet name. | +| `service` | *TrussService* | Service object for the chainlet. | + +* **Return type:** + None + +#### *property* entrypoint_fake_json_data *: Any* + +Fake JSON example data that matches the entrypoint’s input schema. +This property must be externally populated. + +* **Raises:** + **ValueError** – If fake data was not set. + +#### *property* entrypoint_name *: str* + +#### *property* get_entrypoint *: TrussService* + +Returns the entrypoint’s service handle. + +* **Raises:** + **MissingDependencyError** – If the entrypoint service was not added. + +#### get_info() + +Queries the statuses of all chainlets in the chain. + +* **Returns:** + List with elements `(name, status, logs_url)` for each chainlet. +* **Return type:** + list[tuple[str, str, str]] + +#### name *: str* + +#### run_remote(json) + +Invokes the entrypoint with JSON data. + +* **Returns:** + The JSON response. 
+* **Parameters:** + **json** (*Dict*) +* **Return type:** + *Any* + +#### *property* run_url *: str* + +URL to invoke the entrypoint. + +#### *property* services *: MutableMapping[str, TrussService]* + +### `truss_chains.make_abs_path_here` + +Helper to specify file paths relative to the *immediately calling* module. + +E.g. in you have a project structure like this: + +```default +root/ + chain.py + common_requirements.text + sub_package/ + chainlet.py + chainlet_requirements.txt +``` + +You can now in `root/sub_package/chainlet.py` point to the requirements +file like this: + +```default +shared = RelativePathToHere("../common_requirements.text") +specific = RelativePathToHere("chainlet_requirements.text") +``` + +#### WARNING +This helper uses the directory of the immediately calling module as an +absolute reference point for resolving the file location. Therefore, +you MUST NOT wrap the instantiation of `make_abs_path_here` into a +function (e.g. applying decorators) or use dynamic code execution. + +Ok: + +```default +def foo(path: AbsPath): + abs_path = path.abs_path + +foo(make_abs_path_here("./somewhere")) +``` + +Not Ok: + +```default +def foo(path: str): + dangerous_value = make_abs_path_here(path).abs_path + +foo("./somewhere") +``` + +* **Parameters:** + **file_path** (*str*) +* **Return type:** + *AbsPath* + +### `truss_chains.run_local` + +Context manager local debug execution of a chain. + +The arguments only need to be provided if the chainlets explicitly access any the +corresponding fields of `DeploymentContext`. + + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `secrets` | *Mapping[str,str]\|None* | A dict of secrets keys and values to provide to the chainlets. | +| `data_dir` | *Path\|str\|None* | Path to a directory with data files. | +| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#truss_chains.ServiceDescriptor* | A dict of chainlet names to service descriptors. 
| + +* **Return type:** + *ContextManager*[None] + +Example usage (as trailing main section in a chain file): + +```default +import os +import truss_chains as chains + + +class HelloWorld(chains.ChainletBase): + ... + + +if __name__ == "__main__": + with chains.run_local( + secrets={"some_token": os.environ["SOME_TOKEN"]}, + chainlet_to_service={ + "SomeChainlet": chains.ServiceDescriptor( + name="SomeChainlet", + predict_url="https://...", + options=chains.RPCOptions(), + ) + }, + ): + hello_world_chain = HelloWorld() + result = hello_world_chain.run_remote(max_value=5) + + print(result) +``` + +Refer to the [local debugging guide](https://truss.baseten.co/chains/guide#local-debugging) +for more details. + +### *class* `truss_chains.ServiceDescriptor` + +Bases: `pydantic.BaseModel` + +Bundles values to establish an RPC session to a dependency chainlet, +specifically with `StubBase`. + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `name` | *str* | | +| `predict_url` | *str* | | +| `options` | *[RPCOptions](#truss_chains.RPCOptions* | | + + +#### name *: str* + +#### options *: [RPCOptions](#truss_chains.RPCOptions)* + +#### predict_url *: str* + +### *class* `truss_chains.StubBase` + +Bases: `ABC` + +Base class for stubs that invoke remote chainlets. + +It is used internally for RPCs to dependency chainlets, but it can also be used +in user-code for wrapping a deployed truss model into the chains framework, e.g. +like that: + +```default +import pydantic +import truss_chains as chains + +class WhisperOutput(pydantic.BaseModel): + ... + + +class DeployedWhisper(chains.StubBase): + + async def run_remote(self, audio_b64: str) -> WhisperOutput: + resp = await self._remote.predict_async(json_payload={"audio": audio_b64}) + return WhisperOutput(text=resp["text"], language==resp["language"]) + + +class MyChainlet(chains.ChainletBase): + + def __init__(self, ..., context = chains.depends_context()): + ... 
+ self._whisper = DeployedWhisper.from_url( + WHISPER_URL, + context, + options=chains.RPCOptions(retries=3), + ) +``` + + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `service_descriptor` | *[ServiceDescriptor](#truss_chains.ServiceDescriptor* | Contains the URL and other configuration. | +| `api_key` | *str* | A baseten API key to authorize requests. | + + +#### *classmethod* from_url(predict_url, context, options=None) + +Factory method, convenient to be used in chainlet’s `__init__`-method. + + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `predict_url` | *str* | URL to predict endpoint of another chain / truss model. | +| `context` | *[DeploymentContext](#truss_chains.DeploymentContext* | Deployment context object, obtained in the chainlet’s `__init__`. | +| `options` | *[RPCOptions](#truss_chains.RPCOptions* | RPC options, e.g. retries. | + + +### *class* `truss_chains.RemoteErrorDetail` + +Bases: `pydantic.BaseModel` + +When a remote chainlet raises an exception, this pydantic model contains +information about the error and stack trace and is included in JSON form in the +error response. + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `remote_name` | *str* | | +| `exception_cls_name` | *str* | | +| `exception_module_name` | *str\|None* | | +| `exception_message` | *str* | | +| `user_stack_trace` | *list[StackFrame]* | | + + +#### exception_cls_name *: str* + +#### exception_message *: str* + +#### exception_module_name *: str | None* + +#### format() + +Format the error for printing, similar to how Python formats exceptions +with stack traces. 
+ +* **Return type:** + str + +#### remote_name *: str* + +#### user_stack_trace *: list[StackFrame]* diff --git a/docs/chains/doc_gen/mdx_adapter.py b/docs/chains/doc_gen/mdx_adapter.py new file mode 100644 index 000000000..53a40adce --- /dev/null +++ b/docs/chains/doc_gen/mdx_adapter.py @@ -0,0 +1,135 @@ +# type: ignore # This tool is only for Marius. +"""Super hacky plugin to make the generated markdown more suitable for +rendering in mintlify as an mdx doc.""" +import os +import re +from typing import Any, Dict + +from docutils import nodes +from docutils.io import StringOutput +from generate_reference import NON_PUBLIC_SYMBOLS +from sphinx.util.osutil import ensuredir, os_path +from sphinx_markdown_builder import MarkdownBuilder +from sphinx_markdown_builder.translator import MarkdownTranslator + +PYDANTIC_DEFAULT_DOCSTRING = """ + +Create a new model by parsing and validating input data from keyword arguments. + +Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be +validated to form a valid model. + +self is explicitly positional-only to allow self as a field name. +""" + + +class MDXAdapterTranslator(MarkdownTranslator): + ... 
+ + +def extract_and_format_parameters_section(content: str) -> str: + def format_as_table(items: list[tuple[str, str, str]]) -> str: + header = "| Name | Type | Description |\n|------|------|-------------|\n" + rows = [ + f"| {name} | {typ} | {description} |" for name, typ, description in items + ] + return header + "\n".join(rows) + + pattern = r"(\* \*\*Parameters:\*\*\n((?: {2}\* .+(?:\n {4}.+)*\n?)+))" + matches = re.findall(pattern, content) + + for match in matches: + list_block = match[1] + list_items = re.findall(r"( {2}\* .+(?:\n {4}.+)*)", list_block) + + extracted_items = [] + for item in list_items: + item = item.replace("\n ", " ") + parts = item.split(" – ", 1) + if len(parts) == 2: + name_type, description = parts + else: + name_type = parts[0] + description = "" + + name_match = re.search(r"\*\*(.+?)\*\*", name_type) + type_match = re.search(r"\((.+?)\)", name_type) + name = name_match.group(1) if name_match else "" + typ = type_match.group(1) if type_match else "" + typ = typ.replace("*", "").replace(" ", "").replace("|", r"\|") + typ = f"*{typ}*" + name = f"`{name}`" + extracted_items.append((name, typ, description.strip())) + + table = format_as_table(extracted_items) + content = content.replace(match[0], f"\n**Parameters:**\n\n{table}\n\n") + + return content + + +def _line_replacements(line: str) -> str: + if line.startswith("### *class*"): + line = line.replace("### *class*", "").strip() + if not any(sym in line for sym in NON_PUBLIC_SYMBOLS): + line = line.replace("truss_chains.definitions", "truss_chains") + first_brace = line.find("(") + if first_brace > 0: + line = line[:first_brace] + return f"### *class* `{line}`" + elif line.startswith("### "): + line = line.replace("### ", "").strip() + if not any(sym in line for sym in NON_PUBLIC_SYMBOLS): + line = line.replace("truss_chains.definitions", "truss_chains") + first_brace = line.find("(") + if first_brace > 0: + line = line[:first_brace] + return f"### `{line}`" + + return line + + +def 
_raw_text_replacements(doc_text: str) -> str: + doc_text = doc_text.replace(PYDANTIC_DEFAULT_DOCSTRING, "") + doc_text = doc_text.replace("Bases: `object`\n\n", "") + doc_text = doc_text.replace("Bases: `ABCChainlet`\n\n", "") + doc_text = doc_text.replace("Bases: `SafeModel`", "Bases: `pydantic.BaseModel`") + doc_text = doc_text.replace( + "Bases: `SafeModelNonSerializable`", "Bases: `pydantic.BaseModel`" + ) + doc_text = doc_text.replace("<", "<").replace(">", ">") + doc_text = "\n".join(_line_replacements(line) for line in doc_text.split("\n")) + doc_text = extract_and_format_parameters_section(doc_text) + return doc_text + + +class MDXAdapterBuilder(MarkdownBuilder): + name = "mdx_adapter" + out_suffix = ".mdx" + default_translator_class = MDXAdapterTranslator + + def get_translator_class(self): + return MDXAdapterTranslator + + def get_target_uri(self, docname: str, typ: str = None) -> str: + return f"{docname}.mdx" + + def write_doc(self, docname: str, doctree: nodes.document): + self.current_doc_name = docname + self.sec_numbers = self.env.toc_secnumbers.get(docname, {}) + destination = StringOutput(encoding="utf-8") + self.writer.write(doctree, destination) + out_filename = os.path.join(self.outdir, f"{os_path(docname)}{self.out_suffix}") + ensuredir(os.path.dirname(out_filename)) + + with open(out_filename, "w", encoding="utf-8") as file: + # These replacements are custom, the rest of this method is unchanged. 
+ file.write(_raw_text_replacements(self.writer.output)) + + +def setup(app: Any) -> Dict[str, Any]: + app.add_builder(MDXAdapterBuilder) + return { + "version": "0.1", + "parallel_read_safe": True, + "parallel_write_safe": True, + } diff --git a/docs/chains/doc_gen/reference.patch b/docs/chains/doc_gen/reference.patch new file mode 100644 index 000000000..e0c857965 --- /dev/null +++ b/docs/chains/doc_gen/reference.patch @@ -0,0 +1,497 @@ +--- docs/chains/doc_gen/generated-reference.mdx 2024-06-17 14:57:58.625022632 -0700 ++++ docs/snippets/chains/API-reference.mdx 2024-06-17 15:03:31.415948834 -0700 +@@ -21,30 +21,28 @@ + dependency of another chainlet. The return value of `depends` is intended to be + used as a default argument in a chainlet’s `__init__`-method. + When deploying a chain remotely, a corresponding stub to the remote is injected in +-its place. In `run_local` mode an instance of a local chainlet is injected. ++its place. In [`run_local`](#truss-chains-run-local) mode an instance of a local chainlet is injected. + + Refer to [the docs](https://truss.baseten.co/chains/getting-started) and this + [example chainlet](https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py) + for more guidance on how make one chainlet depend on another chainlet. + +-#### WARNING ++ + Despite the type annotation, this does *not* immediately provide a + chainlet instance. Only when deploying remotely or using `run_local` a + chainlet instance is provided. +- ++ + + **Parameters:** + +-| Name | Type | Description | +-|------|------|-------------| +-| `chainlet_cls` | *Type[ChainletT]* | The chainlet class of the dependency. | +-| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). | ++| Name | Type | Description | ++|------|----------------------|-------------| ++| `chainlet_cls` | *Type[ChainletBase]* | The chainlet class of the dependency. 
| ++| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). | + + * **Returns:** + A “symbolic marker” to be used as a default argument in a chainlet’s + initializer. +-* **Return type:** +- *ChainletT* + + ### `truss_chains.depends_context` + +@@ -54,20 +52,19 @@ + [example chainlet](https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py) + for more guidance on the `__init__`-signature of chainlets. + +-#### WARNING ++ + Despite the type annotation, this does *not* immediately provide a + context instance. Only when deploying remotely or using `run_local` a + context instance is provided. ++ + + * **Returns:** + A “symbolic marker” to be used as a default argument in a chainlet’s + initializer. +-* **Return type:** +- [*DeploymentContext*](#truss_chains.DeploymentContext) + +-### *class* `truss_chains.DeploymentContext` ++### *class* `truss_chains.DeploymentContext(Generic[UserConfigT])` + +-Bases: `pydantic.BaseModel`, `Generic`[`UserConfigT`] ++Bases: `pydantic.BaseModel` + + Bundles config values and resources needed to instantiate Chainlets. + +@@ -76,14 +73,14 @@ + + **Parameters:** + +-| Name | Type | Description | +-|------|------|-------------| +-| `data_dir` | *Path\|None* | The directory where the chainlet can store and access data, e.g. for downloading model weights. | +-| `user_config` | *UserConfigT* | User-defined configuration for the chainlet. | +-| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#truss_chains.ServiceDescriptor* | A mapping from chainlet names to service descriptors. This is used create RPCs sessions to dependency chainlets. It contains only the chainlet services that are dependencies of the current chainlet. | +-| `secrets` | *MappingNoIter[str,str]* | A mapping from secret names to secret values. It contains only the secrets that are listed in `remote_config.assets.secret_keys` of the current chainlet. 
| ++| Name | Type | Description | ++|------|---------------------------------------------------------------------|-------------| ++| `data_dir` | *Path\|None* | The directory where the chainlet can store and access data, e.g. for downloading model weights. | ++| `user_config` | *UserConfigT* | User-defined configuration for the chainlet. | ++| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#class-truss-chains-servicedescriptor)]* | A mapping from chainlet names to service descriptors. This is used create RPCs sessions to dependency chainlets. It contains only the chainlet services that are dependencies of the current chainlet. | ++| `secrets` | *Mapping[str,str]* | A mapping from secret names to secret values. It contains only the secrets that are listed in `remote_config.assets.secret_keys` of the current chainlet. | + +-#### chainlet_to_service *: Mapping[str, [ServiceDescriptor](#truss_chains.ServiceDescriptor)]* ++#### chainlet_to_service *: Mapping[str, [ServiceDescriptor](#class-truss-chains-servicedescriptor)]* + + #### data_dir *: Path | None* + +@@ -94,12 +91,16 @@ + + #### get_service_descriptor(chainlet_name) + +-* **Parameters:** +- **chainlet_name** (*str*) ++**Parameters:** ++ ++| Name | Type | Description | ++|-------------------|---------|---------------------------| ++| `chainlet_name` | *str* | The name of the chainlet. | ++ + * **Return type:** +- [*ServiceDescriptor*](#truss_chains.ServiceDescriptor) ++ [*ServiceDescriptor*](#class-truss-chains-servicedescriptor) + +-#### secrets *: MappingNoIter[str, str]* ++#### secrets *: Mapping[str, str]* + + #### user_config *: UserConfigT* + +@@ -117,10 +118,6 @@ + | `retries` | *int* | | + + +-#### retries *: int* +- +-#### timeout_sec *: int* +- + ### `truss_chains.mark_entrypoint` + + Decorator to mark a chainlet as the entrypoint of a chain. 
+@@ -131,7 +128,7 @@ + + Example usage: + +-```default ++```python + import truss_chains as chains + + @chains.mark_entrypoint +@@ -139,10 +136,14 @@ + ... + ``` + +-* **Parameters:** +- **cls** (*Type* *[**ChainletT* *]*) ++**Parameters:** ++ ++| Name | Type | Description | ++|-------------------|---------------------------|---------------------| ++| `cls` | *Type[ChainletBase]* | The chainlet class. | ++ + * **Return type:** +- *Type*[*ChainletT*] ++ *Type*[*ChainletBase*] + + # Remote Configuration + +@@ -156,7 +157,7 @@ + + This is specified as a class variable for each chainlet class, e.g.: + +-```default ++```python + import truss_chains as chains + + +@@ -172,31 +173,12 @@ + + **Parameters:** + +-| Name | Type | Description | +-|------|------|-------------| +-| `docker_image` | *[DockerImage](#truss_chains.DockerImage* | | +-| `compute` | *[Compute](#truss_chains.Compute* | | +-| `assets` | *[Assets](#truss_chains.Assets* | | +-| `name` | *str\|None* | | +- +- +-#### assets *: [Assets](#truss_chains.Assets)* +- +-#### compute *: [Compute](#truss_chains.Compute)* +- +-#### docker_image *: [DockerImage](#truss_chains.DockerImage)* +- +-#### get_asset_spec() +- +-* **Return type:** +- *AssetSpec* +- +-#### get_compute_spec() +- +-* **Return type:** +- *ComputeSpec* +- +-#### name *: str | None* ++| Name | Type | Description | ++|------|--------------------------------------------------|-------------| ++| `docker_image` | *[DockerImage](#class-truss-chains-dockerimage)* | | ++| `compute` | *[Compute](#class-truss-chains-compute)* | | ++| `assets` | *[Assets](#class-truss-chains-assets)* | | ++| `name` | *str\|None* | | + + ### *class* `truss_chains.DockerImage` + +@@ -204,12 +186,12 @@ + + Configures the docker image in which a remoted chainlet is deployed. + +-#### NOTE ++ + Any paths are relative to the source file where `DockerImage` is +-defined and must be created with the helper function `make_abs_path_here`. 
++defined and must be created with the helper function [`make_abs_path_here`](#truss-chains-make-abs-path-here). + This allows you for example organize chainlets in different (potentially nested) + modules and keep their requirement files right next their python source files. +- ++ + + **Parameters:** + +@@ -222,28 +204,16 @@ + | `data_dir` | *AbsPath\|None* | Data from this directory is copied into the docker image and accessible to the remote chainlet at runtime. | + | `external_package_dirs` | *list[AbsPath]\|None* | A list of directories containing additional python packages outside the chain’s workspace dir, e.g. a shared library. This code is copied into the docker image and importable at runtime. | + +-#### apt_requirements *: list[str]* +- +-#### base_image *: str* +- +-#### data_dir *: AbsPath | None* +- +-#### external_package_dirs *: list[AbsPath] | None* +- +-#### pip_requirements *: list[str]* +- +-#### pip_requirements_file *: AbsPath | None* +- + ### *class* `truss_chains.Compute` + + Specifies which compute resources a chainlet has in the *remote* deployment. + +-#### NOTE ++ + Not all combinations can be exactly satisfied by available hardware, in some + cases more powerful machine types are chosen to make sure requirements are met or + over-provisioned. Refer to the + [baseten instance reference](https://docs.baseten.co/performance/instances). +- ++ + + **Parameters:** + +@@ -278,7 +248,7 @@ + + Model weight caching can be used like this: + +-```default ++```python + import truss_chains as chains + from truss import truss_config + +@@ -321,7 +291,7 @@ + + | Name | Type | Description | + |------|------|-------------| +-| `entrypoint` | *Type[ABCChainlet]* | The chainlet class that serves as the entrypoint to the chain. | ++| `entrypoint` | *Type[ChainletBase]* | The chainlet class that serves as the entrypoint to the chain. | + | `chain_name` | *str* | The name of the chain. 
| + | `publish` | *bool* | Whether to publish the chain as a published deployment (it is a draft deployment otherwise) | + | `promote` | *bool* | Whether to promote the chain to be the production deployment (this implies publishing as well). | +@@ -330,14 +300,14 @@ + * **Returns:** + A chain service handle to the deployed chain. + * **Return type:** +- [*ChainService*](#truss_chains.deploy.ChainService) ++ [*ChainService*](#class-truss-chains-deploy-chainservice) + + ### *class* `truss_chains.deploy.ChainService` + + Handle for a deployed chain. + +-A `ChainService` is created and returned when using `deploy_remotely`. It +-bundles the individual services for each chainlet in the chain, and provides ++A `ChainService` is created and returned when using [`deploy_remotely`](#truss-chains-deploy-remotely). ++It bundles the individual services for each chainlet in the chain, and provides + utilities to query their status, invoke the entrypoint etc. + + +@@ -364,14 +334,6 @@ + * **Return type:** + None + +-#### *property* entrypoint_fake_json_data *: Any* +- +-Fake JSON example data that matches the entrypoint’s input schema. +-This property must be externally populated. +- +-* **Raises:** +- **ValueError** – If fake data was not set. +- + #### *property* entrypoint_name *: str* + + #### *property* get_entrypoint *: TrussService* +@@ -390,18 +352,19 @@ + * **Return type:** + list[tuple[str, str, str]] + +-#### name *: str* +- + #### run_remote(json) + + Invokes the entrypoint with JSON data. + ++**Parameters:** ++| Name | Type | Description | ++|------|------|-------------| ++| `json` | *JSON Dict* | Request payload. | ++ + * **Returns:** + The JSON response. 
+-* **Parameters:** +- **json** (*Dict*) + * **Return type:** +- *Any* ++ *JSON Dict* + + #### *property* run_url *: str* + +@@ -427,12 +390,12 @@ + You can now in `root/sub_package/chainlet.py` point to the requirements + file like this: + +-```default ++```python + shared = RelativePathToHere("../common_requirements.text") + specific = RelativePathToHere("chainlet_requirements.text") + ``` + +-#### WARNING ++ + This helper uses the directory of the immediately calling module as an + absolute reference point for resolving the file location. Therefore, + you MUST NOT wrap the instantiation of `make_abs_path_here` into a +@@ -440,7 +403,7 @@ + + Ok: + +-```default ++```python + def foo(path: AbsPath): + abs_path = path.abs_path + +@@ -449,15 +412,20 @@ + + Not Ok: + +-```default ++```python + def foo(path: str): + dangerous_value = make_abs_path_here(path).abs_path + + foo("./somewhere") + ``` ++ + +-* **Parameters:** +- **file_path** (*str*) ++**Parameters:** ++ ++| Name | Type | Description | ++|-------------|---------|----------------------------| ++| `file_path` | *str* | Absolute or relative path. | ++* + * **Return type:** + *AbsPath* + +@@ -466,23 +434,23 @@ + Context manager local debug execution of a chain. + + The arguments only need to be provided if the chainlets explicitly access any the +-corresponding fields of `DeploymentContext`. ++corresponding fields of [`DeploymentContext`](#class-truss-chains-deploymentcontext-generic-userconfigt). + + + **Parameters:** + +-| Name | Type | Description | +-|------|------|-------------| +-| `secrets` | *Mapping[str,str]\|None* | A dict of secrets keys and values to provide to the chainlets. | +-| `data_dir` | *Path\|str\|None* | Path to a directory with data files. | +-| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#truss_chains.ServiceDescriptor* | A dict of chainlet names to service descriptors. 
| ++| Name | Type | Description | ++|------|--------------------------------------------------------------------------|-------------| ++| `secrets` | *Mapping[str,str]\|None* | A dict of secrets keys and values to provide to the chainlets. | ++| `data_dir` | *Path\|str\|None* | Path to a directory with data files. | ++| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#class-truss-chains-servicedescriptor)* | A dict of chainlet names to service descriptors. | + + * **Return type:** + *ContextManager*[None] + + Example usage (as trailing main section in a chain file): + +-```default ++```python + import os + import truss_chains as chains + +@@ -520,22 +488,13 @@ + + **Parameters:** + +-| Name | Type | Description | +-|------|------|-------------| +-| `name` | *str* | | +-| `predict_url` | *str* | | +-| `options` | *[RPCOptions](#truss_chains.RPCOptions* | | +- +- +-#### name *: str* +- +-#### options *: [RPCOptions](#truss_chains.RPCOptions)* ++| Name | Type | Description | ++|------|------------------------------------------------|-------------| ++| `name` | *str* | | ++| `predict_url` | *str* | | ++| `options` | *[RPCOptions](#class-truss-chains-rpcoptions)* | | + +-#### predict_url *: str* +- +-### *class* `truss_chains.StubBase` +- +-Bases: `ABC` ++## *class* `truss_chains.StubBase` + + Base class for stubs that invoke remote chainlets. + +@@ -543,7 +502,7 @@ + in user-code for wrapping a deployed truss model into the chains framework, e.g. 
+ like that: + +-```default ++```python + import pydantic + import truss_chains as chains + +@@ -553,7 +512,7 @@ + + class DeployedWhisper(chains.StubBase): + +- async def run_remote(self, audio_b64: str) -> WhisperOutput: ++ async def run_remote(self, audio_b64: str) -> WhisperOutput: + resp = await self._remote.predict_async(json_payload={"audio": audio_b64}) + return WhisperOutput(text=resp["text"], language==resp["language"]) + +@@ -572,10 +531,10 @@ + + **Parameters:** + +-| Name | Type | Description | +-|------|------|-------------| +-| `service_descriptor` | *[ServiceDescriptor](#truss_chains.ServiceDescriptor* | Contains the URL and other configuration. | +-| `api_key` | *str* | A baseten API key to authorize requests. | ++| Name | Type | Description | ++|------|--------------------------------------------------------------|-------------| ++| `service_descriptor` | *[ServiceDescriptor](#class-truss-chains-servicedescriptor)* | Contains the URL and other configuration. | ++| `api_key` | *str* | A baseten API key to authorize requests. | + + + #### *classmethod* from_url(predict_url, context, options=None) +@@ -585,12 +544,11 @@ + + **Parameters:** + +-| Name | Type | Description | +-|------|------|-------------| +-| `predict_url` | *str* | URL to predict endpoint of another chain / truss model. | +-| `context` | *[DeploymentContext](#truss_chains.DeploymentContext* | Deployment context object, obtained in the chainlet’s `__init__`. | +-| `options` | *[RPCOptions](#truss_chains.RPCOptions* | RPC options, e.g. retries. | +- ++| Name | Type | Description | ++|------|------------------------------------------------------------------------------|-------------| ++| `predict_url` | *str* | URL to predict endpoint of another chain / truss model. | ++| `context` | *[DeploymentContext](#class-truss-chains-deploymentcontext-generic-userconfigt)* | Deployment context object, obtained in the chainlet’s `__init__`. 
| ++| `options` | *[RPCOptions](#class-truss-chains-rpcoptions)* | RPC options, e.g. retries. | + + ### *class* `truss_chains.RemoteErrorDetail` + +@@ -610,13 +568,6 @@ + | `exception_message` | *str* | | + | `user_stack_trace` | *list[StackFrame]* | | + +- +-#### exception_cls_name *: str* +- +-#### exception_message *: str* +- +-#### exception_module_name *: str | None* +- + #### format() + + Format the error for printing, similar to how Python formats exceptions +@@ -624,7 +575,3 @@ + + * **Return type:** + str +- +-#### remote_name *: str* +- +-#### user_stack_trace *: list[StackFrame]* diff --git a/docs/chains/doc_gen/sphinx_config.py b/docs/chains/doc_gen/sphinx_config.py new file mode 100644 index 000000000..13ac481ca --- /dev/null +++ b/docs/chains/doc_gen/sphinx_config.py @@ -0,0 +1,101 @@ +import os +import sys + +import sphinx_rtd_theme + +sys.path.insert(0, os.path.abspath(".")) + + +project = "Dummy" +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinx_markdown_builder", + # "sphinx_markdown_parser", + "sphinx-pydantic", + # "myst_parser", + "mdx_adapter", +] +myst_enable_extensions = [ + "colon_fence", + "deflist", + "html_admonition", + "html_image", + "linkify", + "replacements", + "smartquotes", + "substitution", + "tasklist", +] +autodoc_default_options = { + "members": True, + "undoc-members": False, + "private-members": False, + "special-members": False, + "exclude-members": "__*", + "inherited-members": False, + "show-inheritance": True, +} + +# Other Options. +autodoc_typehints = "description" +always_document_param_types = True +# Include both class-level and __init__ docstrings +autoclass_content = "both" +# Napoleon (docstring parsing) +napoleon_google_docstring = True +napoleon_numpy_docstring = False +napoleon_use_param = True +napoleon_use_rtype = True +# HTML output. 
+html_theme = "sphinx_rtd_theme" +html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] + + +def skip_member(app, what, name, obj, skip, options): + if name == "Config" and isinstance(obj, type): + # print(options.parent) + return True + # Exclude Pydantic's Config class and internal attributes + pydantic_internal_attributes = { + "model_computed_fields", + "model_fields", + "model_json_schema", + "model_config", + } + if name in pydantic_internal_attributes: + # This shadows user defined usage of those names... + return True + return skip + + +def dump_doctree(app, doctree, docname: str): + output_file = f"/tmp/doc_gen/doctree_{docname}.txt" # Define the output file name + + def visit_node(node, depth=0, file=sys.stdout): + # Create a visual guide with indentation and vertical lines + indent = "│ " * depth + "├── " + newl = "\n" + node_text = node.astext()[:100].replace(newl, " ") + file.write( + f"{indent}{node.__class__.__name__}: `{node_text}`.\n" + ) # Write to file + + if not node.children: # Check if the node is a leaf node + empty_indent = "│ " * depth + file.write(f"{empty_indent}\n") + + for child in node.children: + visit_node(child, depth + 1, file) + + with open(output_file, "w") as file: # Open the file for writing + file.write(f"Dumping doctree for: {docname}\n") + visit_node( + doctree, file=file + ) # Pass the file handle to the visit_node function + file.write("\nFinished dumping doctree\n") + + +def setup(app): + app.connect("autodoc-skip-member", skip_member) + # app.connect("doctree-resolved", dump_doctree) diff --git a/docs/chains/full-reference.mdx b/docs/chains/full-reference.mdx index ec7cb2565..86619a658 100644 --- a/docs/chains/full-reference.mdx +++ b/docs/chains/full-reference.mdx @@ -10,6 +10,6 @@ import TOC from '/snippets/chains/TOC.mdx'; -# Coming soon +import Reference from '/snippets/chains/API-reference.mdx'; -For the time being, you can use IDE code-completions for reference. 
+ diff --git a/docs/snippets/chains/API-reference.mdx b/docs/snippets/chains/API-reference.mdx new file mode 100644 index 000000000..6d1427b34 --- /dev/null +++ b/docs/snippets/chains/API-reference.mdx @@ -0,0 +1,577 @@ +# API Reference + +# Chainlets + +APIs for creating user-defined Chainlets. + +### *class* `truss_chains.ChainletBase` + +Base class for all chainlets. + +Inheriting from this class adds validations to make sure subclasses adhere to the +chainlet pattern and facilitates remote chainlet deployment. + +Refer to [the docs](https://truss.baseten.co/chains/getting-started) and this +[example chainlet](https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py) +for more guidance on how to create subclasses. + +### `truss_chains.depends` + +Sets a “symbolic marker” to indicate to the framework that a chainlet is a +dependency of another chainlet. The return value of `depends` is intended to be +used as a default argument in a chainlet’s `__init__`-method. +When deploying a chain remotely, a corresponding stub to the remote is injected in +its place. In [`run_local`](#truss-chains-run-local) mode an instance of a local chainlet is injected. + +Refer to [the docs](https://truss.baseten.co/chains/getting-started) and this +[example chainlet](https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py) +for more guidance on how make one chainlet depend on another chainlet. + + +Despite the type annotation, this does *not* immediately provide a +chainlet instance. Only when deploying remotely or using `run_local` a +chainlet instance is provided. + + +**Parameters:** + +| Name | Type | Description | +|------|----------------------|-------------| +| `chainlet_cls` | *Type[ChainletBase]* | The chainlet class of the dependency. | +| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). 
|
+
+* **Returns:**
+  A “symbolic marker” to be used as a default argument in a chainlet’s
+  initializer.
+
+### `truss_chains.depends_context`
+
+Sets a “symbolic marker” for injecting a context object at runtime.
+
+Refer to [the docs](https://truss.baseten.co/chains/getting-started) and this
+[example chainlet](https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py)
+for more guidance on the `__init__`-signature of chainlets.
+
+
+Despite the type annotation, this does *not* immediately provide a
+context instance. Only when deploying remotely or using `run_local` a
+context instance is provided.
+
+
+* **Returns:**
+  A “symbolic marker” to be used as a default argument in a chainlet’s
+  initializer.
+
+### *class* `truss_chains.DeploymentContext(Generic[UserConfigT])`
+
+Bases: `pydantic.BaseModel`
+
+Bundles config values and resources needed to instantiate Chainlets.
+
+This is provided at runtime to the Chainlet’s `__init__` method.
+
+
+**Parameters:**
+
+| Name | Type | Description |
+|------|---------------------------------------------------------------------|-------------|
+| `data_dir` | *Path\|None* | The directory where the chainlet can store and access data, e.g. for downloading model weights. |
+| `user_config` | *UserConfigT* | User-defined configuration for the chainlet. |
+| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#class-truss-chains-servicedescriptor)]* | A mapping from chainlet names to service descriptors. This is used to create RPC sessions to dependency chainlets. It contains only the chainlet services that are dependencies of the current chainlet. |
+| `secrets` | *Mapping[str,str]* | A mapping from secret names to secret values. It contains only the secrets that are listed in `remote_config.assets.secret_keys` of the current chainlet. 
| + +#### chainlet_to_service *: Mapping[str, [ServiceDescriptor](#class-truss-chains-servicedescriptor)]* + +#### data_dir *: Path | None* + +#### get_baseten_api_key() + +* **Return type:** + str + +#### get_service_descriptor(chainlet_name) + +**Parameters:** + +| Name | Type | Description | +|-------------------|---------|---------------------------| +| `chainlet_name` | *str* | The name of the chainlet. | + +* **Return type:** + [*ServiceDescriptor*](#class-truss-chains-servicedescriptor) + +#### secrets *: Mapping[str, str]* + +#### user_config *: UserConfigT* + +### *class* `truss_chains.RPCOptions` + +Bases: `pydantic.BaseModel` + +Options to customize RPCs to dependency chainlets. + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `timeout_sec` | *int* | | +| `retries` | *int* | | + + +### `truss_chains.mark_entrypoint` + +Decorator to mark a chainlet as the entrypoint of a chain. + +This decorator can be applied to *one* chainlet in a source file and then the +CLI deploy command simplifies because only the file, but not the chainlet class +in the file needs to be specified. + +Example usage: + +```python +import truss_chains as chains + +@chains.mark_entrypoint +class MyChainlet(ChainletBase): + ... +``` + +**Parameters:** + +| Name | Type | Description | +|-------------------|---------------------------|---------------------| +| `cls` | *Type[ChainletBase]* | The chainlet class. | + +* **Return type:** + *Type*[*ChainletBase*] + +# Remote Configuration + +These data structures specify for each chainlet how it gets deployed remotely, e.g. dependencies and compute resources. + +### *class* `truss_chains.RemoteConfig` + +Bases: `pydantic.BaseModel` + +Bundles config values needed to deploy a chainlet remotely.. 
+ +This is specified as a class variable for each chainlet class, e.g.: + +```python +import truss_chains as chains + + +class MyChainlet(chains.ChainletBase): + remote_config = chains.RemoteConfig( + docker_image=chains.DockerImage( + pip_requirements=["torch==2.0.1", ... ] + ), + compute=chains.Compute(cpu_count=2, gpu="A10G", ...), + assets=chains.Assets(secret_keys=["hf_access_token"], ...), + ) +``` + +**Parameters:** + +| Name | Type | Description | +|------|--------------------------------------------------|-------------| +| `docker_image` | *[DockerImage](#class-truss-chains-dockerimage)* | | +| `compute` | *[Compute](#class-truss-chains-compute)* | | +| `assets` | *[Assets](#class-truss-chains-assets)* | | +| `name` | *str\|None* | | + +### *class* `truss_chains.DockerImage` + +Bases: `pydantic.BaseModel` + +Configures the docker image in which a remoted chainlet is deployed. + + +Any paths are relative to the source file where `DockerImage` is +defined and must be created with the helper function [`make_abs_path_here`](#truss-chains-make-abs-path-here). +This allows you for example organize chainlets in different (potentially nested) +modules and keep their requirement files right next their python source files. + + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `base_image` | *str* | The base image to use for the chainlet. Default is `python:3.11-slim`. | +| `pip_requirements_file` | *AbsPath\|None* | Path to a file containing pip requirements. The file content is naively concatenated with `pip_requirements`. | +| `pip_requirements` | *list[str]* | A list of pip requirements to install. The items are naively concatenated with the content of the `pip_requirements_file`. | +| `apt_requirements` | *list[str]* | A list of apt requirements to install. | +| `data_dir` | *AbsPath\|None* | Data from this directory is copied into the docker image and accessible to the remote chainlet at runtime. 
| +| `external_package_dirs` | *list[AbsPath]\|None* | A list of directories containing additional python packages outside the chain’s workspace dir, e.g. a shared library. This code is copied into the docker image and importable at runtime. | + +### *class* `truss_chains.Compute` + +Specifies which compute resources a chainlet has in the *remote* deployment. + + +Not all combinations can be exactly satisfied by available hardware, in some +cases more powerful machine types are chosen to make sure requirements are met or +over-provisioned. Refer to the +[baseten instance reference](https://docs.baseten.co/performance/instances). + + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `cpu_count` | *int* | Minimum number of CPUs to allocate. | +| `memory` | *str* | Minimum memory to allocate, e.g. “2Gi” (2 gibibytes). | +| `gpu` | *str\|Accelerator\|None* | GPU accelerator type, e.g. “A10G”, “A100”, refer to the [truss config](https://truss.baseten.co/reference/config#resources-accelerator) for more choices. | +| `gpu_count` | *int* | Number of GPUs to allocate. | +| `predict_concurrency` | *int\|Literal['cpu_count']* | Number of concurrent requests a single replica of a deployed chainlet handles. | + + +Concurrency concepts are explained in [this guide](https://truss.baseten.co/guides/concurrency). +It is important to understand the difference between predict_concurrency and +the concurrency target (used for autoscaling, i.e. adding or removing replicas). +Furthermore, the `predict_concurrency` of a single instance is implemented in +two ways: + +- Via python’s `asyncio`, if `run_remote` is an async def. This + requires that `run_remote` yields to the event loop. +- With a threadpool if it’s a synchronous function. This requires + that the threads don’t have significant CPU load (due to the GIL). 
+
+#### get_spec()
+
+* **Return type:**
+  *ComputeSpec*
+
+### *class* `truss_chains.Assets`
+
+Specifies which assets a chainlet can access in the remote deployment.
+
+Model weight caching can be used like this:
+
+```python
+import truss_chains as chains
+from truss import truss_config
+
+mistral_cache = truss_config.ModelRepo(
+    repo_id="mistralai/Mistral-7B-Instruct-v0.2",
+    allow_patterns=["*.json", "*.safetensors", ".model"]
+  )
+chains.Assets(cached=[mistral_cache], ...)
+```
+
+See [truss caching guide](https://truss.baseten.co/guides/model-cache#enabling-caching-for-a-model)
+for more details on caching.
+
+
+**Parameters:**
+
+| Name | Type | Description |
+|------|------|-------------|
+| `cached` | *Iterable[ModelRepo]* | One or more `truss_config.ModelRepo` objects. |
+| `secret_keys` | *Iterable[str]* | Names of secrets stored on baseten, that the chainlet should have access to. You can manage secrets on baseten [here](https://app.baseten.co/settings/secrets). |
+
+
+#### get_spec()
+
+Returns parsed and validated assets.
+
+* **Return type:**
+  *AssetSpec*
+
+# Core
+
+General framework and helper functions.
+
+### `truss_chains.deploy_remotely`
+
+Deploys a chain remotely (with all dependent chainlets).
+
+
+**Parameters:**
+
+| Name | Type | Description |
+|------|------|-------------|
+| `entrypoint` | *Type[ChainletBase]* | The chainlet class that serves as the entrypoint to the chain. |
+| `chain_name` | *str* | The name of the chain. |
+| `publish` | *bool* | Whether to publish the chain as a published deployment (it is a draft deployment otherwise) |
+| `promote` | *bool* | Whether to promote the chain to be the production deployment (this implies publishing as well). |
+| `only_generate_trusses` | *bool* | Used for debugging purposes. If set to True, only the underlying truss models for the chainlets are generated in `/tmp/.chains_generated`. |
+
+* **Returns:**
+  A chain service handle to the deployed chain. 
+* **Return type:** + [*ChainService*](#class-truss-chains-deploy-chainservice) + +### *class* `truss_chains.deploy.ChainService` + +Handle for a deployed chain. + +A `ChainService` is created and returned when using [`deploy_remotely`](#truss-chains-deploy-remotely). +It bundles the individual services for each chainlet in the chain, and provides +utilities to query their status, invoke the entrypoint etc. + + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `entrypoint` | *str* | Name of the entrypoint chainlet. | +| `name` | *str* | Name of the chain. | + + +#### add_service(name, service) + +Used to add a chainlet service during the deployment sequence of the chain. + + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `name` | *str* | Chainlet name. | +| `service` | *TrussService* | Service object for the chainlet. | + +* **Return type:** + None + +#### *property* entrypoint_name *: str* + +#### *property* get_entrypoint *: TrussService* + +Returns the entrypoint’s service handle. + +* **Raises:** + **MissingDependencyError** – If the entrypoint service was not added. + +#### get_info() + +Queries the statuses of all chainlets in the chain. + +* **Returns:** + List with elements `(name, status, logs_url)` for each chainlet. +* **Return type:** + list[tuple[str, str, str]] + +#### run_remote(json) + +Invokes the entrypoint with JSON data. + +**Parameters:** +| Name | Type | Description | +|------|------|-------------| +| `json` | *JSON Dict* | Request payload. | + +* **Returns:** + The JSON response. +* **Return type:** + *JSON Dict* + +#### *property* run_url *: str* + +URL to invoke the entrypoint. + +#### *property* services *: MutableMapping[str, TrussService]* + +### `truss_chains.make_abs_path_here` + +Helper to specify file paths relative to the *immediately calling* module. + +E.g. 
if you have a project structure like this:
+
+```default
+root/
+    chain.py
+    common_requirements.text
+    sub_package/
+        chainlet.py
+        chainlet_requirements.txt
+```
+
+You can now in `root/sub_package/chainlet.py` point to the requirements
+file like this:
+
+```python
+shared = RelativePathToHere("../common_requirements.text")
+specific = RelativePathToHere("chainlet_requirements.text")
+```
+
+
+This helper uses the directory of the immediately calling module as an
+absolute reference point for resolving the file location. Therefore,
+you MUST NOT wrap the instantiation of `make_abs_path_here` into a
+function (e.g. applying decorators) or use dynamic code execution.
+
+Ok:
+
+```python
+def foo(path: AbsPath):
+    abs_path = path.abs_path
+
+foo(make_abs_path_here("./somewhere"))
+```
+
+Not Ok:
+
+```python
+def foo(path: str):
+    dangerous_value = make_abs_path_here(path).abs_path
+
+foo("./somewhere")
+```
+
+
+**Parameters:**
+
+| Name | Type | Description |
+|-------------|---------|----------------------------|
+| `file_path` | *str* | Absolute or relative path. |
+
+* **Return type:**
+  *AbsPath*
+
+### `truss_chains.run_local`
+
+Context manager for local debug execution of a chain.
+
+The arguments only need to be provided if the chainlets explicitly access any of the
+corresponding fields of [`DeploymentContext`](#class-truss-chains-deploymentcontext-generic-userconfigt).
+
+
+**Parameters:**
+
+| Name | Type | Description |
+|------|--------------------------------------------------------------------------|-------------|
+| `secrets` | *Mapping[str,str]\|None* | A dict of secrets keys and values to provide to the chainlets. |
+| `data_dir` | *Path\|str\|None* | Path to a directory with data files. |
+| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#class-truss-chains-servicedescriptor)]* | A dict of chainlet names to service descriptors. 
| 

* **Return type:**
  *ContextManager*[None]

Example usage (as trailing main section in a chain file):

```python
import os
import truss_chains as chains


class HelloWorld(chains.ChainletBase):
    ...


if __name__ == "__main__":
    with chains.run_local(
        secrets={"some_token": os.environ["SOME_TOKEN"]},
        chainlet_to_service={
            "SomeChainlet": chains.ServiceDescriptor(
                name="SomeChainlet",
                predict_url="https://...",
                options=chains.RPCOptions(),
            )
        },
    ):
        hello_world_chain = HelloWorld()
        result = hello_world_chain.run_remote(max_value=5)

    print(result)
```

Refer to the [local debugging guide](https://truss.baseten.co/chains/guide#local-debugging)
for more details.

### *class* `truss_chains.ServiceDescriptor`

Bases: `pydantic.BaseModel`

Bundles values to establish an RPC session to a dependency chainlet,
specifically with `StubBase`.

**Parameters:**

| Name | Type | Description |
|------|------------------------------------------------|-------------|
| `name` | *str* | |
| `predict_url` | *str* | |
| `options` | *[RPCOptions](#class-truss-chains-rpcoptions)* | |

### *class* `truss_chains.StubBase`

Base class for stubs that invoke remote chainlets.

It is used internally for RPCs to dependency chainlets, but it can also be used
in user-code for wrapping a deployed truss model into the chains framework, e.g.
like this:

```python
import pydantic
import truss_chains as chains

class WhisperOutput(pydantic.BaseModel):
    ...


class DeployedWhisper(chains.StubBase):

    async def run_remote(self, audio_b64: str) -> WhisperOutput:
        resp = await self._remote.predict_async(json_payload={"audio": audio_b64})
        return WhisperOutput(text=resp["text"], language=resp["language"])


class MyChainlet(chains.ChainletBase):

    def __init__(self, ..., context = chains.depends_context()):
        ...
+ self._whisper = DeployedWhisper.from_url( + WHISPER_URL, + context, + options=chains.RPCOptions(retries=3), + ) +``` + + +**Parameters:** + +| Name | Type | Description | +|------|--------------------------------------------------------------|-------------| +| `service_descriptor` | *[ServiceDescriptor](#class-truss-chains-servicedescriptor)* | Contains the URL and other configuration. | +| `api_key` | *str* | A baseten API key to authorize requests. | + + +#### *classmethod* from_url(predict_url, context, options=None) + +Factory method, convenient to be used in chainlet’s `__init__`-method. + + +**Parameters:** + +| Name | Type | Description | +|------|------------------------------------------------------------------------------|-------------| +| `predict_url` | *str* | URL to predict endpoint of another chain / truss model. | +| `context` | *[DeploymentContext](#class-truss-chains-deploymentcontext-generic-userconfigt)* | Deployment context object, obtained in the chainlet’s `__init__`. | +| `options` | *[RPCOptions](#class-truss-chains-rpcoptions)* | RPC options, e.g. retries. | + +### *class* `truss_chains.RemoteErrorDetail` + +Bases: `pydantic.BaseModel` + +When a remote chainlet raises an exception, this pydantic model contains +information about the error and stack trace and is included in JSON form in the +error response. + +**Parameters:** + +| Name | Type | Description | +|------|------|-------------| +| `remote_name` | *str* | | +| `exception_cls_name` | *str* | | +| `exception_module_name` | *str\|None* | | +| `exception_message` | *str* | | +| `user_stack_trace` | *list[StackFrame]* | | + +#### format() + +Format the error for printing, similar to how Python formats exceptions +with stack traces. 
+ +* **Return type:** + str diff --git a/docs/snippets/chains/TOC.mdx b/docs/snippets/chains/TOC.mdx index 48546670a..1f4c58012 100644 --- a/docs/snippets/chains/TOC.mdx +++ b/docs/snippets/chains/TOC.mdx @@ -9,3 +9,4 @@ * [Advanced Guide](/chains/guide) * [Testing, Debugging, Mocking](/chains/guide#development-experience) * [Dependencies & Resources](/chains/guide#changing-compute-resources) +* [API Reference](/chains/full-reference) diff --git a/truss-chains/truss_chains/__init__.py b/truss-chains/truss_chains/__init__.py index 55dddbe28..efb2ccee8 100644 --- a/truss-chains/truss_chains/__init__.py +++ b/truss-chains/truss_chains/__init__.py @@ -15,19 +15,19 @@ "You can still use other Truss functionality." ) -del pydantic +del pydantic, pydantic_major_version # flake8: noqa F401 from truss_chains.definitions import ( Assets, - ChainsRuntimeError, Compute, DeploymentContext, DockerImage, RemoteConfig, RemoteErrorDetail, RPCOptions, + ServiceDescriptor, ) from truss_chains.public_api import ( ChainletBase, diff --git a/truss-chains/truss_chains/code_gen.py b/truss-chains/truss_chains/code_gen.py index cb9510b90..1af6505b9 100644 --- a/truss-chains/truss_chains/code_gen.py +++ b/truss-chains/truss_chains/code_gen.py @@ -2,7 +2,6 @@ Chains currently assumes that everything from the directory in which the entrypoint is defined (i.e. sibling files and nested dirs) could be imported/used. e.g.: -``` workspace/ entrypoint.py helper.py @@ -10,18 +9,15 @@ utils.py sub_package/ ... -``` These sources are copied into truss's `/packages` and can be imported on the remote. Using code *outside* of the workspace is not supported: -``` shared_lib/ common.py workspace/ entrypoint.py ... -``` `shared_lib` can only be imported on the remote if its installed as a pip requirement (site-package), it will not be copied from the local host. 
diff --git a/truss-chains/truss_chains/definitions.py b/truss-chains/truss_chains/definitions.py index b1c6ec205..fb0f341e9 100644 --- a/truss-chains/truss_chains/definitions.py +++ b/truss-chains/truss_chains/definitions.py @@ -123,6 +123,29 @@ def abs_path(self) -> str: class DockerImage(SafeModelNonSerializable): + """Configures the docker image in which a remoted chainlet is deployed. + + Note: + Any paths are relative to the source file where ``DockerImage`` is + defined and must be created with the helper function ``make_abs_path_here``. + This allows you for example organize chainlets in different (potentially nested) + modules and keep their requirement files right next their python source files. + + Args: + base_image: The base image to use for the chainlet. Default is + ``python:3.11-slim``. + pip_requirements_file: Path to a file containing pip requirements. The file + content is naively concatenated with ``pip_requirements``. + pip_requirements: A list of pip requirements to install. The items are + naively concatenated with the content of the ``pip_requirements_file``. + apt_requirements: A list of apt requirements to install. + data_dir: Data from this directory is copied into the docker image and + accessible to the remote chainlet at runtime. + external_package_dirs: A list of directories containing additional python + packages outside the chain's workspace dir, e.g. a shared library. This code + is copied into the docker image and importable at runtime. + """ + # TODO: this is not stable yet and might change or refer back to truss. base_image: str = "python:3.11-slim" pip_requirements_file: Optional[AbsPath] = None @@ -133,6 +156,8 @@ class DockerImage(SafeModelNonSerializable): class ComputeSpec(pydantic.BaseModel): + """Parsed and validated compute. See ``Compute`` for more information.""" + # TODO: this is not stable yet and might change or refer back to truss. 
cpu_count: int = 1 predict_concurrency: int = 1 @@ -145,8 +170,16 @@ class ComputeSpec(pydantic.BaseModel): class Compute: - """Builder to create ComputeSpec.""" + """Specifies which compute resources a chainlet has in the *remote* deployment. + Note: + Not all combinations can be exactly satisfied by available hardware, in some + cases more powerful machine types are chosen to make sure requirements are met or + over-provisioned. Refer to the + `baseten instance reference `_. + """ + + # Builder to create ComputeSpec. # This extra layer around `ComputeSpec` is needed to parse the accelerator options. _spec: ComputeSpec @@ -159,6 +192,29 @@ def __init__( gpu_count: int = 1, predict_concurrency: Union[int, CpuCountT] = 1, ) -> None: + """ + Args: + cpu_count: Minimum number of CPUs to allocate. + memory: Minimum memory to allocate, e.g. "2Gi" (2 gibibytes). + gpu: GPU accelerator type, e.g. "A10G", "A100", refer to the + `truss config `_ + for more choices. + gpu_count: Number of GPUs to allocate. + predict_concurrency: Number of concurrent requests a single replica of a + deployed chainlet handles. + + Concurrency concepts are explained in `this guide `_. + It is important to understand the difference between `predict_concurrency` and + the concurrency target (used for autoscaling, i.e. adding or removing replicas). + Furthermore, the ``predict_concurrency`` of a single instance is implemented in + two ways: + + - Via python's ``asyncio``, if ``run_remote`` is an async def. This + requires that ``run_remote`` yields to the event loop. + + - With a threadpool if it's a synchronous function. This requires + that the threads don't have significant CPU load (due to the GIL). + """ accelerator = truss_config.AcceleratorSpec() if gpu: accelerator.accelerator = truss_config.Accelerator(gpu) @@ -184,32 +240,74 @@ def get_spec(self) -> ComputeSpec: class AssetSpec(SafeModel): + """Parsed and validated assets. 
See ``Assets`` for more information.""" + # TODO: this is not stable yet and might change or refer back to truss. - secrets: dict[str, str] = {} - cached: list[Any] = [] + secrets: dict[str, str] = pydantic.Field({}) + cached: list[truss_config.ModelRepo] = [] class Assets: - """Builder to create asset spec.""" + """Specifies which assets a chainlet can access in the remote deployment. + + Model weight caching can be used like this:: + + import truss_chains as chains + from truss import truss_config + + mistral_cache = truss_config.ModelRepo( + repo_id="mistralai/Mistral-7B-Instruct-v0.2", + allow_patterns=["*.json", "*.safetensors", ".model"] + ) + chains.Assets(cached=[mistral_cache], ...) + + See `truss caching guide `_ + for more details on caching. + """ - # This extra layer around `ComputeSpec` is needed to add secret_keys. + # Builder to create asset spec. + # This extra layer around `AssetSpec` is needed to add secret_keys. _spec: AssetSpec def __init__( self, - cached: Iterable[Any] = (), + cached: Iterable[truss_config.ModelRepo] = (), secret_keys: Iterable[str] = (), ) -> None: + """ + Args: + cached: One or more ``truss_config.ModelRepo`` objects. + secret_keys: Names of secrets stored on baseten, that the + chainlet should have access to. You can manage secrets on baseten + `here `_. + """ self._spec = AssetSpec( cached=list(cached), secrets={k: SECRET_DUMMY for k in secret_keys} ) def get_spec(self) -> AssetSpec: + """Returns parsed and validated assets.""" return self._spec.copy(deep=True) class RemoteConfig(SafeModelNonSerializable): - """Bundles config values needed to deploy a Chainlet.""" + """Bundles config values needed to deploy a chainlet remotely.. + + This is specified as a class variable for each chainlet class, e.g.:: + + import truss_chains as chains + + + class MyChainlet(chains.ChainletBase): + remote_config = chains.RemoteConfig( + docker_image=chains.DockerImage( + pip_requirements=["torch==2.0.1", ... 
] + ), + compute=chains.Compute(cpu_count=2, gpu="A10G", ...), + assets=chains.Assets(secret_keys=["hf_access_token"], ...), + ) + + """ docker_image: DockerImage = DockerImage() compute: Compute = Compute() @@ -224,18 +322,37 @@ def get_asset_spec(self) -> AssetSpec: class RPCOptions(SafeModel): + """Options to customize RPCs to dependency chainlets.""" + timeout_sec: int = 600 retries: int = 1 class ServiceDescriptor(SafeModel): + """Bundles values to establish an RPC session to a dependency chainlet, + specifically with ``StubBase``.""" + name: str predict_url: str options: RPCOptions class DeploymentContext(SafeModelNonSerializable, Generic[UserConfigT]): - """Bundles config values and resources needed to instantiate Chainlets.""" + """Bundles config values and resources needed to instantiate Chainlets. + + This is provided at runtime to the Chainlet's ``__init__`` method. + + Args: + data_dir: The directory where the chainlet can store and access data, + e.g. for downloading model weights. + user_config: User-defined configuration for the chainlet. + chainlet_to_service: A mapping from chainlet names to service descriptors. + This is used create RPCs sessions to dependency chainlets. It contains only + the chainlet services that are dependencies of the current chainlet. + secrets: A mapping from secret names to secret values. It contains only the + secrets that are listed in ``remote_config.assets.secret_keys`` of the + current chainlet. + """ data_dir: Optional[pathlib.Path] = None user_config: UserConfigT @@ -361,19 +478,26 @@ def to_frame_summary(self) -> traceback.FrameSummary: class RemoteErrorDetail(SafeModel): + """When a remote chainlet raises an exception, this pydantic model contains + information about the error and stack trace and is included in JSON form in the + error response. 
+ """ + remote_name: str exception_cls_name: str exception_module_name: Optional[str] exception_message: str user_stack_trace: list[StackFrame] - def to_stack_summary(self) -> traceback.StackSummary: + def _to_stack_summary(self) -> traceback.StackSummary: return traceback.StackSummary.from_list( frame.to_frame_summary() for frame in self.user_stack_trace ) def format(self) -> str: - stack = "".join(traceback.format_list(self.to_stack_summary())) + """Format the error for printing, similar to how Python formats exceptions + with stack traces.""" + stack = "".join(traceback.format_list(self._to_stack_summary())) exc_info = ( f"\n(Exception class defined in `{self.exception_module_name}`.)" if self.exception_module_name diff --git a/truss-chains/truss_chains/deploy.py b/truss-chains/truss_chains/deploy.py index 1a92c84b6..9b1880233 100644 --- a/truss-chains/truss_chains/deploy.py +++ b/truss-chains/truss_chains/deploy.py @@ -146,22 +146,49 @@ def add_needed_chainlets(chainlet: definitions.ChainletAPIDescriptor): class ChainService: + # TODO: this exposes methods to users that should be internal (e.g. `add_service`). + """Handle for a deployed chain. + + A ``ChainService`` is created and returned when using ``deploy_remotely``. It + bundles the individual services for each chainlet in the chain, and provides + utilities to query their status, invoke the entrypoint etc. + """ + name: str _entrypoint: str _services: MutableMapping[str, b10_service.TrussService] _entrypoint_fake_json_data = Any def __init__(self, entrypoint: str, name: str) -> None: + """ + Args: + entrypoint: Name of the entrypoint chainlet. + name: Name of the chain. + """ self.name = name self._entrypoint = entrypoint self._services = collections.OrderedDict() # Preserve order. self.entrypoint_fake_json_data = None def add_service(self, name: str, service: b10_service.TrussService) -> None: + """ + Used to add a chainlet service during the deployment sequence of the chain. 
+ + + Args: + name: Chainlet name. + service: Service object for the chainlet. + """ self._services[name] = service @property def entrypoint_fake_json_data(self) -> Any: + """Fake JSON example data that matches the entrypoint's input schema. + This property must be externally populated. + + Raises: + ValueError: If fake data was not set. + """ if self._entrypoint_fake_json_data is None: raise ValueError("Fake data was not set.") return self._entrypoint_fake_json_data @@ -172,6 +199,11 @@ def entrypoint_fake_json_data(self, fake_data: Any) -> None: @property def get_entrypoint(self) -> b10_service.TrussService: + """Returns the entrypoint's service handle. + + Raises: + MissingDependencyError: If the entrypoint service was not added. + """ service = self._services.get(self._entrypoint) if not service: raise definitions.MissingDependencyError( @@ -189,13 +221,21 @@ def entrypoint_name(self) -> str: @property def run_url(self) -> str: + """URL to invoke the entrypoint.""" return self.get_entrypoint.predict_url def run_remote(self, json: Dict) -> Any: + """Invokes the entrypoint with JSON data. + + Returns: + The JSON response.""" return self.get_entrypoint.predict(json) def get_info(self) -> list[tuple[str, str, str]]: - """Return list with elements (name, status, logs_url) for each chainlet.""" + """Queries the statuses of all chainlets in the chain. + + Returns: + List with elements ``(name, status, logs_url)`` for each chainlet.""" return list( (name, next(service.poll_deployment_status(sleep_secs=0)), service.logs_url) for name, service in self._services.items() diff --git a/truss-chains/truss_chains/public_api.py b/truss-chains/truss_chains/public_api.py index 97e61a821..f1b55ba85 100644 --- a/truss-chains/truss_chains/public_api.py +++ b/truss-chains/truss_chains/public_api.py @@ -6,11 +6,22 @@ def depends_context() -> definitions.DeploymentContext: - """Sets a 'symbolic marker' for injecting a Context object at runtime. 
+ """Sets a "symbolic marker" for injecting a context object at runtime. + + Refer to `the docs `_ and this + `example chainlet `_ + for more guidance on the ``__init__``-signature of chainlets. + + Warning: + Despite the type annotation, this does *not* immediately provide a + context instance. Only when deploying remotely or using ``run_local`` a + context instance is provided. + + Returns: + A "symbolic marker" to be used as a default argument in a chainlet's + initializer. + - WARNING: Despite the type annotation, this does *not* immediately provide a - context instance. - Only when deploying remotely or using `run_local` a context instance is provided. """ # The type error is silenced to because chains framework will at runtime inject # a corresponding instance. Nonetheless, we want to use a type annotation here, @@ -22,11 +33,30 @@ def depends_context() -> definitions.DeploymentContext: def depends( chainlet_cls: Type[framework.ChainletT], retries: int = 1 ) -> framework.ChainletT: - """Sets a 'symbolic marker' for injecting a stub or local Chainlet at runtime. - - WARNING: Despite the type annotation, this does *not* immediately provide a - chainlet instance. - Only when deploying remotely or using `run_local` a chainlet instance is provided. + """Sets a "symbolic marker" to indicate to the framework that a chainlet is a + dependency of another chainlet. The return value of ``depends`` is intended to be + used as a default argument in a chainlet's ``__init__``-method. + When deploying a chain remotely, a corresponding stub to the remote is injected in + its place. In ``run_local`` mode an instance of a local chainlet is injected. + + Refer to `the docs `_ and this + `example chainlet `_ + for more guidance on how make one chainlet depend on another chainlet. + + Warning: + Despite the type annotation, this does *not* immediately provide a + chainlet instance. Only when deploying remotely or using ``run_local`` a + chainlet instance is provided. 
+ + + Args: + chainlet_cls: The chainlet class of the dependency. + retries: The number of times to retry the remote chainlet in case of failures + (e.g. due to transient network issues). + + Returns: + A "symbolic marker" to be used as a default argument in a chainlet's + initializer. """ options = definitions.RPCOptions(retries=retries) # The type error is silenced to because chains framework will at runtime inject @@ -37,6 +67,16 @@ def depends( class ChainletBase(definitions.ABCChainlet): + """Base class for all chainlets. + + Inheriting from this class adds validations to make sure subclasses adhere to the + chainlet pattern and facilitates remote chainlet deployment. + + Refer to `the docs `_ and this + `example chainlet `_ + for more guidance on how to create subclasses. + """ + def __init_subclass__(cls, **kwargs) -> None: super().__init_subclass__(**kwargs) framework.check_and_register_class(cls) @@ -55,6 +95,20 @@ def __init_with_arg_check__(self, *args, **kwargs): def mark_entrypoint(cls: Type[framework.ChainletT]) -> Type[framework.ChainletT]: + """Decorator to mark a chainlet as the entrypoint of a chain. + + This decorator can be applied to *one* chainlet in a source file and then the + CLI deploy command simplifies because only the file, but not the chainlet class + in the file needs to be specified. + + Example usage:: + + import truss_chains as chains + + @chains.mark_entrypoint + class MyChainlet(ChainletBase): + ... + """ return framework.entrypoint(cls) @@ -65,6 +119,24 @@ def deploy_remotely( promote: bool = True, only_generate_trusses: bool = False, ) -> deploy.ChainService: + """ + Deploys a chain remotely (with all dependent chainlets). + + Args: + entrypoint: The chainlet class that serves as the entrypoint to the chain. + chain_name: The name of the chain. 
+ publish: Whether to publish the chain as a published deployment (it is a + draft deployment otherwise) + promote: Whether to promote the chain to be the production deployment (this + implies publishing as well). + only_generate_trusses: Used for debugging purposes. If set to True, only the + the underlying truss models for the chainlets are generated in + ``/tmp/.chains_generated``. + + Returns: + A chain service handle to the deployed chain. + + """ options = definitions.DeploymentOptionsBaseten.create( chain_name=chain_name, publish=publish, @@ -79,6 +151,45 @@ def run_local( data_dir: Optional[Union[pathlib.Path, str]] = None, chainlet_to_service: Optional[Mapping[str, definitions.ServiceDescriptor]] = None, ) -> ContextManager[None]: - """Context manager for using in-process instantiations of Chainlet dependencies.""" + """Context manager local debug execution of a chain. + + The arguments only need to be provided if the chainlets explicitly access any the + corresponding fields of ``DeploymentContext``. + + Args: + secrets: A dict of secrets keys and values to provide to the chainlets. + data_dir: Path to a directory with data files. + chainlet_to_service: A dict of chainlet names to service descriptors. + + Example usage (as trailing main section in a chain file):: + + import os + import truss_chains as chains + + + class HelloWorld(chains.ChainletBase): + ... + + + if __name__ == "__main__": + with chains.run_local( + secrets={"some_token": os.environ["SOME_TOKEN"]}, + chainlet_to_service={ + "SomeChainlet": chains.ServiceDescriptor( + name="SomeChainlet", + predict_url="https://...", + options=chains.RPCOptions(), + ) + }, + ): + hello_world_chain = HelloWorld() + result = hello_world_chain.run_remote(max_value=5) + + print(result) + + + Refer to the `local debugging guide `_ + for more details. 
+ """ data_dir = pathlib.Path(data_dir) if data_dir else None return framework.run_local(secrets, data_dir, chainlet_to_service) diff --git a/truss-chains/truss_chains/stub.py b/truss-chains/truss_chains/stub.py index 4ccfc32b9..04c853088 100644 --- a/truss-chains/truss_chains/stub.py +++ b/truss-chains/truss_chains/stub.py @@ -82,12 +82,49 @@ async def predict_async(self, json_payload): class StubBase(abc.ABC): + """Base class for stubs that invoke remote chainlets. + + It is used internally for RPCs to dependency chainlets, but it can also be used + in user-code for wrapping a deployed truss model into the chains framework, e.g. + like that:: + + import pydantic + import truss_chains as chains + + class WhisperOutput(pydantic.BaseModel): + ... + + + class DeployedWhisper(chains.StubBase): + + async def run_remote(self, audio_b64: str) -> WhisperOutput: + resp = await self._remote.predict_async(json_payload={"audio": audio_b64}) + return WhisperOutput(text=resp["text"], language==resp["language"]) + + + class MyChainlet(chains.ChainletBase): + + def __init__(self, ..., context = chains.depends_context()): + ... + self._whisper = DeployedWhisper.from_url( + WHISPER_URL, + context, + options=chains.RPCOptions(retries=3), + ) + + """ + _remote: BasetenSession @final def __init__( self, service_descriptor: definitions.ServiceDescriptor, api_key: str ) -> None: + """ + Args: + service_descriptor: Contains the URL and other configuration. + api_key: A baseten API key to authorize requests. + """ self._remote = BasetenSession(service_descriptor, api_key) @classmethod @@ -96,13 +133,18 @@ def from_url( predict_url: str, context: definitions.DeploymentContext, options: Optional[definitions.RPCOptions] = None, - name: Optional[str] = None, ): - name = name or cls.__name__ + """Factory method, convenient to be used in chainlet's ``__init__``-method. + + Args: + predict_url: URL to predict endpoint of another chain / truss model. 
+ context: Deployment context object, obtained in the chainlet's ``__init__``. + options: RPC options, e.g. retries. + """ options = options or definitions.RPCOptions() return cls( definitions.ServiceDescriptor( - name=name, predict_url=predict_url, options=options + name=cls.__name__, predict_url=predict_url, options=options ), api_key=context.get_baseten_api_key(), ) @@ -112,7 +154,7 @@ def from_url( def factory(stub_cls: Type[StubT], context: definitions.DeploymentContext) -> StubT: - # Assumes the stub_cls-name and the name of the service in `context` match. + # Assumes the stub_cls-name and the name of the service in ``context` match. return stub_cls( service_descriptor=context.get_service_descriptor(stub_cls.__name__), api_key=context.get_baseten_api_key(), diff --git a/truss-chains/truss_chains/utils.py b/truss-chains/truss_chains/utils.py index 79e3469cf..b02b3a872 100644 --- a/truss-chains/truss_chains/utils.py +++ b/truss-chains/truss_chains/utils.py @@ -22,44 +22,42 @@ def make_abs_path_here(file_path: str) -> definitions.AbsPath: """Helper to specify file paths relative to the *immediately calling* module. - E.g. in you have a project structure like this + E.g. 
in you have a project structure like this:: - root/ - chain.py - common_requirements.text - sub_package/ - Chainlet.py - chainlet_requirements.txt + root/ + chain.py + common_requirements.text + sub_package/ + chainlet.py + chainlet_requirements.txt - Not in `root/sub_package/Chainlet.py` you can point to the requirements - file like this: + You can now in ``root/sub_package/chainlet.py`` point to the requirements + file like this:: - ``` - shared = RelativePathToHere("../common_requirements.text") - specific = RelativePathToHere("chainlet_requirements.text") - ``` + shared = RelativePathToHere("../common_requirements.text") + specific = RelativePathToHere("chainlet_requirements.text") - Caveat: this helper uses the directory of the immediately calling module as an - absolute reference point for resolving the file location. - Therefore, you MUST NOT wrap the instantiation of `RelativePathToHere` into a - function (e.g. applying decorators) or use dynamic code execution. - Ok: - ``` - def foo(path: AbsPath): - abs_path = path.abs_path + Warning: + This helper uses the directory of the immediately calling module as an + absolute reference point for resolving the file location. Therefore, + you MUST NOT wrap the instantiation of ``make_abs_path_here`` into a + function (e.g. applying decorators) or use dynamic code execution. + Ok:: - foo(make_abs_path_here("blabla")) - ``` + def foo(path: AbsPath): + abs_path = path.abs_path - Not Ok: - ``` - def foo(path: str): - badbadbad = make_abs_path_here(path).abs_path + foo(make_abs_path_here("./somewhere")) + + Not Ok:: + + def foo(path: str): + dangerous_value = make_abs_path_here(path).abs_path + + foo("./somewhere") - foo("blabla")) - ``` """ # TODO: the absolute path resolution below uses the calling module as a # reference point. This would not work if users wrap this call in a function @@ -300,6 +298,7 @@ class StrEnum(str, enum.Enum): StrEnum is a Python `enum.Enum` that inherits from `str`. 
The `auto()` behavior uses the member name and lowers it. This is useful for compatibility with pydantic. Example usage: + ``` class Example(StrEnum): SOME_VALUE = enum.auto()