From 713f74d80dd01a907fe395e647f21a9c5b9a9ddb Mon Sep 17 00:00:00 2001 From: Basetenbot <96544894+basetenbot@users.noreply.github.com> Date: Fri, 14 Jun 2024 10:18:22 -0700 Subject: [PATCH] Release 0.9.15 (#977) * Set default predict_concurrency when using trt-llm to 512 (#954) * Set default predict_concurrency when using trt-llm to 512 * update tests * Truss changes to support lazy data that reads bptr secret and fetches from remote (#963) * lazy data resolution support * add support for lazy data resolver in truss * remove lazy loader reference from template * fetch in model wrapper * duplicate download util for shared template * concurrent download * fix path reference * use updated expiration_timestamp type --------- Co-authored-by: Pankaj Gupta * Update push docs. (#965) * Adding initial code to implement build commands (#961) * Adding initial code to implement build commands * Adding some tests * Adding docker integration tests * making build command an empty list by default * removing unnecessary build_commands list for loop thing * correct secrets str in docs (#968) * Fix lazy data resolver error handling (#967) * [chains] Add external_package_dirs option. Usage in Whiper model chainlet. (#966) * add truss chains init (#973) * [BT-10657] Wire up truss chains deploy (#969) * Wire up the new chains mutations to truss chains deploy. * Add comment. * Respond to PR feedback. * * Prune docker build cache in integration tests. (#976) * Show requirement file content before pip install. * For all tests running docker containers, show container logs if an exception was raised. * Update control requirements to truss 0.9.14 (required also incrementing httpx version). * Bump version to 0.9.15 --------- Co-authored-by: Bryce Dubayah Co-authored-by: joostinyi <63941848+joostinyi@users.noreply.github.com> Co-authored-by: Pankaj Gupta Co-authored-by: Sidharth Shanker Co-authored-by: Het Trivedi Co-authored-by: rcano-baseten Co-authored-by: Marius Killinger <155577904+marius-baseten@users.noreply.github.com> --- .github/workflows/integration-tests.yml | 4 + docs/chains/getting-started.mdx | 1 + docs/chains/guide.mdx | 8 +- docs/reference/cli/push.mdx | 7 + pyproject.toml | 2 +- truss-chains/examples/whisper/whisper.py | 89 ++++++++++++ truss-chains/truss_chains/code_gen.py | 3 + truss-chains/truss_chains/definitions.py | 5 +- truss-chains/truss_chains/deploy.py | 53 ++++++- truss-chains/truss_chains/example_chainlet.py | 50 ++++--- truss-chains/truss_chains/framework.py | 6 +- truss/cli/cli.py | 69 +++++++++ truss/constants.py | 1 + .../image_builder/serving_image_builder.py | 6 + truss/local/local_config_handler.py | 6 + truss/remote/baseten/api.py | 92 +++++++++++- truss/remote/baseten/core.py | 34 +++++ truss/remote/baseten/remote.py | 15 ++ truss/remote/baseten/types.py | 7 + truss/templates/base.Dockerfile.jinja | 2 + truss/templates/control/requirements.txt | 4 +- truss/templates/server.Dockerfile.jinja | 7 + truss/templates/server/model_wrapper.py | 3 + truss/templates/shared/lazy_data_resolver.py | 70 ++++++++++ truss/templates/shared/util.py | 17 +++ .../test_data/test_build_commands/config.yaml | 13 ++ .../test_build_commands/model/model.py | 17 +++ .../test_build_commands_failure/config.yaml | 14 ++ .../model/model.py | 17 +++ truss/tests/conftest.py | 29 ++++ .../test_serving_image_builder.py | 43 ++++++ .../core/server/test_lazy_data_resolver.py | 132 ++++++++++++++++++ truss/tests/test_config.py | 4 +- .../test_testing_utilities_for_other_tests.py | 38 ++++- truss/tests/test_truss_handle.py | 22 +++ truss/truss_config.py | 3 + truss/truss_handle.py | 7 +- truss/truss_spec.py | 4 + truss/util/download.py | 4 +- 39 files changed, 867 insertions(+), 41 deletions(-) create mode 100644 truss-chains/examples/whisper/whisper.py create mode 100644 truss/remote/baseten/types.py create mode 100644 truss/templates/shared/lazy_data_resolver.py create mode 100644 truss/test_data/test_build_commands/config.yaml create mode 100644 truss/test_data/test_build_commands/model/model.py create mode 100644 truss/test_data/test_build_commands_failure/config.yaml create mode 100644 truss/test_data/test_build_commands_failure/model/model.py create mode 100644 truss/tests/templates/core/server/test_lazy_data_resolver.py diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 7baa1c768..629ec07ae 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -65,6 +65,8 @@ jobs: matrix: split_group: ["1", "2", "3", "4", "5"] steps: + - name: Purge Docker cache + run: docker builder prune -af - uses: actions/checkout@v3 - uses: ./.github/actions/setup-python/ - run: poetry install @@ -77,6 +79,8 @@ jobs: strategy: fail-fast: false steps: + - name: Purge Docker cache + run: docker builder prune -af - uses: actions/checkout@v3 - uses: ./.github/actions/setup-python/ - run: poetry install diff --git a/docs/chains/getting-started.mdx b/docs/chains/getting-started.mdx index 28e2601f7..a7ef1f054 100644 --- a/docs/chains/getting-started.mdx +++ b/docs/chains/getting-started.mdx @@ -37,6 +37,7 @@ More details are in the [concepts section](/chains/concepts). Create a Chain project directory with a python file in it. You can chose a name and location, in this example we assume the file is named `hello.py`. +-- Note: If you are changing this snippet, please update the example code in example_chainlet.py accordingly ```python import random import truss_chains as chains diff --git a/docs/chains/guide.mdx b/docs/chains/guide.mdx index c2d797d8b..db4036e5d 100644 --- a/docs/chains/guide.mdx +++ b/docs/chains/guide.mdx @@ -110,6 +110,8 @@ MISTRAL_HF_MODEL = "mistralai/Mistral-7B-Instruct-v0.2" MISTRAL_CACHE = truss_config.ModelRepo( repo_id=MISTRAL_HF_MODEL, allow_patterns=["*.json", "*.safetensors", ".model"] ) +# This name should correspond to a secret "name" in https://app.baseten.co/settings/secrets +HF_ACCESS_TOKEN_NAME = "hf_access_token" class MistralLLM(chains.ChainletBase): # The RemoteConfig object defines the resources required for this chainlet. @@ -127,7 +129,7 @@ class MistralLLM(chains.ChainletBase): compute=chains.Compute(cpu_count=2, gpu="A10G"), # Cache the model weights in the image and make the huggingface # access token secret available to the model. - assets=chains.Assets(cached=[MISTRAL_CACHE], secret_keys=["hf_access_token"]), + assets=chains.Assets(cached=[MISTRAL_CACHE], secret_keys=[HF_ACCESS_TOKEN_NAME]), ) def __init__( @@ -147,14 +149,14 @@ class MistralLLM(chains.ChainletBase): MISTRAL_HF_MODEL, torch_dtype=torch.float16, device_map="auto", - use_auth_token=context.secrets["HF_ACCESS_TOKEN"], + use_auth_token=context.secrets[HF_ACCESS_TOKEN_NAME], ) self._tokenizer = transformers.AutoTokenizer.from_pretrained( MISTRAL_HF_MODEL, device_map="auto", torch_dtype=torch.float16, - use_auth_token=context.secrets["HF_ACCESS_TOKEN"], + use_auth_token=context.secrets[HF_ACCESS_TOKEN_NAME], ) self._generate_args = { diff --git a/docs/reference/cli/push.mdx b/docs/reference/cli/push.mdx index 4c9401157..41245e47b 100644 --- a/docs/reference/cli/push.mdx +++ b/docs/reference/cli/push.mdx @@ -15,6 +15,9 @@ Name of the remote in .trussrc to patch changes to. Push the truss as a published deployment. If no production deployment exists, promote the truss to production after deploy completes. + +Name of the model + Push the truss as a published deployment. Even if a production deployment exists, promote the truss to production after deploy completes. @@ -30,6 +33,10 @@ Name of the deployment created by the push. Can only be used in combination with Whether to wait for deployment to complete before returning. If the deploy or build fails, will return with a non-zero exit code. + +Maximum time to wait for deployment to complete in seconds. Without specifying, the command will not complete until the deployment is complete. + + Show help message and exit. diff --git a/pyproject.toml b/pyproject.toml index 498d756fc..ceb496978 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.14" +version = "0.9.15" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" diff --git a/truss-chains/examples/whisper/whisper.py b/truss-chains/examples/whisper/whisper.py new file mode 100644 index 000000000..4e5381426 --- /dev/null +++ b/truss-chains/examples/whisper/whisper.py @@ -0,0 +1,89 @@ +from typing import Optional + +# flake8: noqa F402 +# This location assumes `fde`-repo is checked out at the same level as `truss`-repo. +_LOCAL_WHISPER_LIB = "../../../../fde/whisper-trt/src" +import sys + +sys.path.append(_LOCAL_WHISPER_LIB) + +import base64 + +import pydantic +import truss_chains as chains +from huggingface_hub import snapshot_download + + +# TODO: The I/O types below should actually be taken from `whisper_trt.types`. +# But that cannot be imported without having `tensorrt_llm` installed. +# It could be fixed, by making that module importable without any special requirements. +class Segment(pydantic.BaseModel): + start_time_sec: float + end_time_sec: float + text: str + start: float # TODO: deprecate, use field with unit (seconds). + end: float # TODO: deprecate, use field with unit (seconds). + + +class WhisperResult(pydantic.BaseModel): + segments: list[Segment] + language: Optional[str] + language_code: Optional[str] = pydantic.Field( + ..., + description="IETF language tag, e.g. 'en', see. " + "https://en.wikipedia.org/wiki/IETF_language_tag.", + ) + + +class WhisperInput(pydantic.BaseModel): + audio_b64: str + + +@chains.mark_entrypoint +class WhisperModel(chains.ChainletBase): + + remote_config = chains.RemoteConfig( + docker_image=chains.DockerImage( + base_image="baseten/truss-server-base:3.10-gpu-v0.9.0", + apt_requirements=["python3.10-venv", "openmpi-bin", "libopenmpi-dev"], + pip_requirements=[ + "--extra-index-url https://pypi.nvidia.com", + "tensorrt_llm==0.10.0.dev2024042300", + "hf_transfer", + "janus", + "kaldialign", + "librosa", + "mpi4py==3.1.4", + "safetensors", + "soundfile", + "tiktoken", + "torchaudio", + "async-batcher>=0.2.0", + "pydantic>=2.7.1", + ], + external_package_dirs=[chains.make_abs_path_here(_LOCAL_WHISPER_LIB)], + ), + compute=chains.Compute(gpu="A10G", predict_concurrency=128), + assets=chains.Assets(secret_keys=["hf_access_token"]), + ) + + def __init__( + self, + context: chains.DeploymentContext = chains.depends_context(), + ) -> None: + snapshot_download( + repo_id="baseten/whisper_trt_large-v3_A10G_i224_o512_bs8_bw5", + local_dir=context.data_dir, + allow_patterns=["**"], + token=context.secrets["hf_access_token"], + ) + from whisper_trt import WhisperModel + + self._model = WhisperModel(str(context.data_dir), max_queue_time=0.050) + + async def run_remote(self, request: WhisperInput) -> WhisperResult: + binary_data = base64.b64decode(request.audio_b64.encode("utf-8")) + waveform = self._model.preprocess_audio(binary_data) + return await self._model.transcribe( + waveform, timestamps=True, raise_when_trimmed=True + ) diff --git a/truss-chains/truss_chains/code_gen.py b/truss-chains/truss_chains/code_gen.py index 9408cc85f..cb9510b90 100644 --- a/truss-chains/truss_chains/code_gen.py +++ b/truss-chains/truss_chains/code_gen.py @@ -585,6 +585,9 @@ def _make_truss_config( # Absolute paths don't work with remote build. config.requirements_file = _REQUIREMENTS_FILENAME config.system_packages = image.apt_requirements + if image.external_package_dirs: + for ext_dir in image.external_package_dirs: + config.external_package_dirs.append(ext_dir.abs_path) # Assets. assets = chains_config.get_asset_spec() config.secrets = assets.secrets diff --git a/truss-chains/truss_chains/definitions.py b/truss-chains/truss_chains/definitions.py index 062fa5c0c..b1c6ec205 100644 --- a/truss-chains/truss_chains/definitions.py +++ b/truss-chains/truss_chains/definitions.py @@ -1,7 +1,6 @@ # TODO: this file contains too much implementation -> restructure. import abc import logging -import os import pathlib import traceback from types import GenericAlias @@ -103,7 +102,8 @@ def __init__( self._original_path = original_path def _raise_if_not_exists(self, abs_path: str) -> None: - if not os.path.isfile(abs_path): + path = pathlib.Path(abs_path) + if not (path.is_file() or (path.is_dir() and any(path.iterdir()))): raise MissingDependencyError( f"With the file path `{self._original_path}` an absolute path relative " f"to the calling module `{self._creating_module}` was created, " @@ -129,6 +129,7 @@ class DockerImage(SafeModelNonSerializable): pip_requirements: list[str] = [] apt_requirements: list[str] = [] data_dir: Optional[AbsPath] = None + external_package_dirs: Optional[list[AbsPath]] = None class ComputeSpec(pydantic.BaseModel): diff --git a/truss-chains/truss_chains/deploy.py b/truss-chains/truss_chains/deploy.py index dc97385ea..977732ace 100644 --- a/truss-chains/truss_chains/deploy.py +++ b/truss-chains/truss_chains/deploy.py @@ -2,10 +2,22 @@ import inspect import logging import pathlib -from typing import Any, Dict, Iterable, Iterator, MutableMapping, Optional, Type, cast +import uuid +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + MutableMapping, + Optional, + Type, + cast, +) import truss from truss.remote.baseten import service as b10_service +from truss.remote.baseten import types as b10_types from truss_chains import code_gen, definitions, framework, utils @@ -19,10 +31,18 @@ def _deploy_to_baseten( f"Deploying chainlet `{model_name}` as truss model on Baseten " f"(publish={options.publish}, promote={options.promote})." ) + + # Since we are deploying a model independently of the chain, we add a random suffix to + # prevent us from running into issues with existing models with the same name. + # + # This is a bit of a hack for now. Once we support model_origin for Chains models, we + # can drop the requirement for names on models. + model_suffix = str(uuid.uuid4()).split("-")[0] + # Models must be trusted to use the API KEY secret. service = options.remote_provider.push( truss_handle, - model_name=model_name, + model_name=model_name + model_suffix, trusted=True, publish=options.publish, promote=options.promote, @@ -158,6 +178,14 @@ def get_entrypoint(self) -> b10_service.TrussService: ) return service + @property + def services(self) -> MutableMapping[str, b10_service.TrussService]: + return self._services + + @property + def entrypoint_name(self) -> str: + return self._entrypoint + @property def run_url(self) -> str: return self.get_entrypoint.predict_url @@ -221,4 +249,25 @@ def deploy_remotely( chainlet_name_to_url[chainlet_descriptor.name] = service.predict_url else: chainlet_name_to_url[chainlet_descriptor.name] = "http://dummy" + + if isinstance(options, definitions.DeploymentOptionsBaseten): + chainlets: List[b10_types.ChainletData] = [] + entrypoint_name = chain_service.entrypoint_name + + for chainlet_name, truss_service in chain_service.services.items(): + baseten_service = cast(b10_service.BasetenService, truss_service) + chainlets.append( + b10_types.ChainletData( + name=chainlet_name, + oracle_version_id=baseten_service.model_version_id, + is_entrypoint=chainlet_name == entrypoint_name, + ) + ) + + chain_id = options.remote_provider.create_chain( + chain_name=chain_service.name, chainlets=chainlets, publish=options.publish + ) + + print(f"Newly Created Chain: {chain_id}") + return chain_service diff --git a/truss-chains/truss_chains/example_chainlet.py b/truss-chains/truss_chains/example_chainlet.py index f024a8104..d39715f6b 100644 --- a/truss-chains/truss_chains/example_chainlet.py +++ b/truss-chains/truss_chains/example_chainlet.py @@ -1,31 +1,35 @@ +import random + +# For more on chains, check out https://truss.baseten.co/chains/intro. import truss_chains as chains -class DummyGenerateData(chains.ChainletBase): - def run_remote(self) -> str: - return "abc" +# By inhereting chains.ChainletBase, the chains framework will know to create a chainlet that hosts the RandInt class. +class RandInt(chains.ChainletBase): + + # run_remote must be implemented by all chainlets. This is the code that will be executed at inference time. + def run_remote(self, max_value: int) -> int: + return random.randint(1, max_value) + +# The @chains.mark_entrypoint decorator indicates that this Chainlet is the entrypoint. +# Each chain must have exactly one entrypoint. +@chains.mark_entrypoint +class HelloWorld(chains.ChainletBase): + # chains.depends indicates that the HelloWorld chainlet depends on the RandInt Chainlet + # this enables the HelloWorld chainlet to call the RandInt chainlet + def __init__(self, rand_int=chains.depends(RandInt, retries=3)) -> None: + self._rand_int = rand_int -# Nesting the classes is a hack to make it *appear* like SplitText is from a different -# module. -class shared_chainlet: - class DummySplitText(chains.ChainletBase): - def run_remote(self, data: str) -> list[str]: - return [data[:2], data[2:]] + def run_remote(self, max_value: int) -> str: + num_repetitions = self._rand_int.run_remote(max_value) + return "Hello World! " * num_repetitions -class DummyExample(chains.ChainletBase): - def __init__( - self, - data_generator: DummyGenerateData = chains.depends(DummyGenerateData), - splitter: shared_chainlet.DummySplitText = chains.depends( - shared_chainlet.DummySplitText - ), - context: chains.DeploymentContext = chains.depends_context(), - ) -> None: - self._data_generator = data_generator - self._data_splitter = splitter - self._context = context +if __name__ == "__main__": + with chains.run_local(): + hello_world_chain = HelloWorld() + result = hello_world_chain.run_remote(max_value=5) - def run_remote(self) -> list[str]: - return self._data_splitter.run_remote(self._data_generator.run_remote()) + print(result) + # Hello World! Hello World! Hello World! diff --git a/truss-chains/truss_chains/framework.py b/truss-chains/truss_chains/framework.py index 89dd72773..7312b9a50 100644 --- a/truss-chains/truss_chains/framework.py +++ b/truss-chains/truss_chains/framework.py @@ -104,7 +104,7 @@ def _example_chainlet_code() -> str: logging.error("example_chainlet` is broken.", exc_info=True, stack_info=True) return "" - example_name = example_chainlet.DummyExample.__name__ + example_name = example_chainlet.HelloWorld.__name__ source = pathlib.Path(example_chainlet.__file__).read_text() tree = ast.parse(source) class_code = "" @@ -720,6 +720,10 @@ def import_target( ) -> Iterator[Type[definitions.ABCChainlet]]: module_path = pathlib.Path(module_path).resolve() module_name = module_path.stem # Use the file's name as the module name + if not os.path.isfile(module_path): + raise ImportError( + f"`{module_path}` is not a file. You must point to a file where the entrypoint is defined." + ) error_msg = f"Could not import `{target_name}` from `{module_path}`. Check path." spec = importlib.util.spec_from_file_location(module_name, module_path) diff --git a/truss/cli/cli.py b/truss/cli/cli.py index aa1e084b6..520d043fc 100644 --- a/truss/cli/cli.py +++ b/truss/cli/cli.py @@ -14,6 +14,7 @@ import rich.table import rich_click as click import truss +from InquirerPy import inquirer from truss.cli.console import console from truss.cli.create import ask_name from truss.remote.baseten.core import ( @@ -344,6 +345,12 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]: is_flag=True, help="Produces only generated files, but doesn't deploy anything.", ) +@click.option( + "--remote", + type=str, + required=False, + help="Name of the remote in .trussrc to push to.", +) @error_handling def deploy( source: Path, @@ -353,6 +360,7 @@ def deploy( promote: bool, wait: bool, dryrun: bool, + remote: Optional[str], ) -> None: """ Deploys a chain remotely. @@ -374,6 +382,7 @@ def deploy( promote=promote, publish=publish, only_generate_trusses=dryrun, + remote=remote, ) service = chains_deploy.deploy_remotely(entrypoint_cls, options) @@ -417,6 +426,66 @@ def deploy( ) +@chains.command(name="init") # type: ignore +@click.option( + "--target_directory", + type=str, + required=False, + help="Name of the chain to be deployed, if not given, the user will be prompted for a name", +) +@error_handling +def chains_init( + target_directory: Optional[str], +) -> None: + """ + Initializes a chains project with hello.py + """ + FILENAME = "main.py" + if not target_directory: + target_directory = inquirer.text( + qmark="", message="Enter the target directory for the chain" + ).execute() + # Ensure that `None` is replaced to empty string. This will write a main.py + # file to the cwd. + if not target_directory: + target_directory = "" + # we do this cast to satisfy mypy when handling the output of the inquirer.text call + target_directory = str(target_directory) + cur_path = os.getcwd() + abs_path = os.path.join(cur_path, target_directory) + filename = os.path.join(abs_path, FILENAME) + if os.path.exists(filename): + raise click.UsageError( + f"Cannot init chains project with {filename}. Path already exists" + ) + user_friendly_path = os.path.join(target_directory, FILENAME) + rich.print(f"Creating {user_friendly_path}...\n") + + source_code = _load_example_chainlet_code() + + if not os.path.exists(abs_path): + os.mkdir(abs_path) + with open(filename, "w") as f: + f.write(source_code) + + rich.print( + "Next steps:\n", + f"💻 Run `python {user_friendly_path}` to execute the code locally\n", + f"🚢 Run `truss chains deploy {user_friendly_path}` to deploy the chain to Baseten\n", + ) + + +def _load_example_chainlet_code() -> str: + try: + from truss_chains import example_chainlet + # if the example is faulty, a validation error would be raised + except Exception as e: + raise Exception("Failed to load starter code. Please notify support.") from e + + source = Path(example_chainlet.__file__).read_text() + return source + + def _extract_and_validate_model_identifier( target_directory: str, model_id: Optional[str], diff --git a/truss/constants.py b/truss/constants.py index 7aab8a743..d52581607 100644 --- a/truss/constants.py +++ b/truss/constants.py @@ -32,6 +32,7 @@ SUPPORTED_PYTHON_VERSIONS = {"3.8", "3.9", "3.10", "3.11"} +TRTLLM_PREDICT_CONCURRENCY = 512 # Alias for TEMPLATES_DIR SERVING_DIR: pathlib.Path = TEMPLATES_DIR diff --git a/truss/contexts/image_builder/serving_image_builder.py b/truss/contexts/image_builder/serving_image_builder.py index 57ccfae7f..52c5678ec 100644 --- a/truss/contexts/image_builder/serving_image_builder.py +++ b/truss/contexts/image_builder/serving_image_builder.py @@ -30,6 +30,7 @@ SYSTEM_PACKAGES_TXT_FILENAME, TEMPLATES_DIR, TRTLLM_BASE_IMAGE, + TRTLLM_PREDICT_CONCURRENCY, TRTLLM_PYTHON_EXECUTABLE, TRTLLM_TRUSS_DIR, USE_BRITON, @@ -353,6 +354,8 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str): "Tensor parallelism and GPU count must be the same for TRT-LLM" ) + config.runtime.predict_concurrency = TRTLLM_PREDICT_CONCURRENCY + config.base_image = BaseImage( image=BRITON_TRTLLM_BASE_IMAGE if USE_BRITON else TRTLLM_BASE_IMAGE, python_executable_path=TRTLLM_PYTHON_EXECUTABLE, @@ -447,6 +450,7 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str): use_hf_secret, cached_files, external_data_files, + self._spec.build_commands, ) def _render_dockerfile( @@ -457,6 +461,7 @@ def _render_dockerfile( use_hf_secret: bool, cached_files: List[str], external_data_files: List[Tuple[str, str]], + build_commands: List[str], ): config = self._spec.config data_dir = build_dir / config.data_dir @@ -509,6 +514,7 @@ def _render_dockerfile( hf_access_token=hf_access_token, hf_access_token_file_name=HF_ACCESS_TOKEN_FILE_NAME, external_data_files=external_data_files, + build_commands=build_commands, **FILENAME_CONSTANTS_MAP, ) docker_file_path = build_dir / MODEL_DOCKERFILE_NAME diff --git a/truss/local/local_config_handler.py b/truss/local/local_config_handler.py index 1817da672..14f316614 100644 --- a/truss/local/local_config_handler.py +++ b/truss/local/local_config_handler.py @@ -69,6 +69,12 @@ def _config_path(): def secrets_dir_path(): return LocalConfigHandler.TRUSS_CONFIG_DIR / "secrets" + @staticmethod + def bptr_data_resolution_dir_path(): + bptr_data_dir = LocalConfigHandler.TRUSS_CONFIG_DIR / "bptr" + bptr_data_dir.mkdir(exist_ok=True, parents=True) + return bptr_data_dir + @staticmethod def _signatures_dir_path(): return LocalConfigHandler.TRUSS_CONFIG_DIR / "signatures" diff --git a/truss/remote/baseten/api.py b/truss/remote/baseten/api.py index 286d4303f..9130df78e 100644 --- a/truss/remote/baseten/api.py +++ b/truss/remote/baseten/api.py @@ -1,8 +1,9 @@ import logging from enum import Enum -from typing import Any, Optional +from typing import Any, List, Optional import requests +from truss.remote.baseten import types as b10_types from truss.remote.baseten.auth import ApiKey, AuthService from truss.remote.baseten.error import ApiError from truss.remote.baseten.utils.transfer import base64_encoded_json_str @@ -22,6 +23,16 @@ DEFAULT_API_DOMAIN = "https://api.baseten.co" +def _chainlet_data_to_graphql_mutation(chainlet: b10_types.ChainletData): + return f""" + {{ + name: "{chainlet.name}", + oracle_version_id: "{chainlet.oracle_version_id}", + is_entrypoint: {'true' if chainlet.is_entrypoint else 'false'} + }} + """ + + class BasetenApi: class GraphQLErrorCodes(Enum): RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND" @@ -51,6 +62,7 @@ def auth_token(self) -> ApiKey: def _post_graphql_query(self, query_string: str) -> dict: headers = self._auth_token.header() + resp = requests.post( self._graphql_api_url, data={"query": query_string}, @@ -174,6 +186,84 @@ def create_development_model_from_truss( resp = self._post_graphql_query(query_string) return resp["data"]["deploy_draft_truss"] + def deploy_chain(self, name: str, chainlet_data: List[b10_types.ChainletData]): + chainlet_data_strings = [ + _chainlet_data_to_graphql_mutation(chainlet) for chainlet in chainlet_data + ] + + chainlets_string = ", ".join(chainlet_data_strings) + query_string = f""" + mutation {{ + deploy_chain( + name: "{name}", + chainlets: [{chainlets_string}] + ) {{ + id + }} + }} + """ + resp = self._post_graphql_query(query_string) + return resp["data"]["deploy_chain"] + + def deploy_draft_chain( + self, name: str, chainlet_data: List[b10_types.ChainletData] + ): + chainlet_data_strings = [ + _chainlet_data_to_graphql_mutation(chainlet) for chainlet in chainlet_data + ] + chainlets_string = ", ".join(chainlet_data_strings) + query_string = f""" + mutation {{ + deploy_draft_chain( + name: "{name}", + chainlets: [{chainlets_string}] + ) {{ + chain_id + }} + }} + """ + resp = self._post_graphql_query(query_string) + return resp["data"]["deploy_draft_chain"] + + def deploy_chain_deployment( + self, chain_id: str, chainlet_data: List[b10_types.ChainletData] + ): + chainlet_data_strings = [ + _chainlet_data_to_graphql_mutation(chainlet) for chainlet in chainlet_data + ] + chainlets_string = ", ".join(chainlet_data_strings) + query_string = f""" + mutation {{ + deploy_chain_deployment( + chain_id: "{chain_id}", + chainlets: [{chainlets_string}] + ) {{ + chain_id + chain_deployment_id + }} + }} + """ + resp = self._post_graphql_query(query_string) + return resp["data"]["deploy_chain_deployment"] + + def get_chain_by_id(self, id: str): + + # TODO: Implement + pass + + def get_chains(self): + query_string = """ + { + chains { + id + name + } + } + """ + + resp = self._post_graphql_query(query_string) + return resp["data"]["chains"] + def models(self): query_string = """ { diff --git a/truss/remote/baseten/core.py b/truss/remote/baseten/core.py index 1c93133f7..1c5594035 100644 --- a/truss/remote/baseten/core.py +++ b/truss/remote/baseten/core.py @@ -2,6 +2,7 @@ from typing import IO, List, Optional, Tuple import truss +from truss.remote.baseten import types as b10_types from truss.remote.baseten.api import BasetenApi from truss.remote.baseten.error import ApiError from truss.remote.baseten.utils.tar import create_tar_with_progress_bar @@ -35,6 +36,23 @@ def __init__(self, model_version_id: str): self.value = model_version_id +def get_chain_id_by_name(api: BasetenApi, chain_name: str) -> Optional[str]: + """ + Check if a chain with the given name exists in the Baseten remote. + + Args: + api: BasetenApi instance + chain_name: Name of the chain to check for existence + + Returns: + chain_id if present, otherwise None + """ + chains = api.get_chains() + + chain_name_id_mapping = {chain["name"]: chain["id"] for chain in chains} + return chain_name_id_mapping.get(chain_name) + + def exists_model(api: BasetenApi, model_name: str) -> Optional[str]: """ Check if a model with the given name exists in the Baseten remote. @@ -60,6 +78,22 @@ def exists_model(api: BasetenApi, model_name: str) -> Optional[str]: return model["model"]["id"] +def create_chain( + api: BasetenApi, + chain_id: Optional[str], + chain_name: str, + chainlets: List[b10_types.ChainletData], + is_draft: bool = False, +) -> str: + if is_draft: + return api.deploy_draft_chain(chain_name, chainlets)["chain_id"] + + if chain_id: + return api.deploy_chain_deployment(chain_id, chainlets)["chain_id"] + + return api.deploy_chain(chain_name, chainlets)["id"] + + def get_model_versions(api: BasetenApi, model_name: ModelName) -> Tuple[str, List]: query_result = api.get_model(model_name.value)["model"] return (query_result["id"], query_result["versions"]) diff --git a/truss/remote/baseten/remote.py b/truss/remote/baseten/remote.py index 1282aeade..25ba28a4f 100644 --- a/truss/remote/baseten/remote.py +++ b/truss/remote/baseten/remote.py @@ -16,8 +16,10 @@ ModelName, ModelVersionId, archive_truss, + create_chain, create_truss_service, exists_model, + get_chain_id_by_name, get_dev_version, get_dev_version_from_versions, get_model_versions, @@ -26,6 +28,7 @@ ) from truss.remote.baseten.error import ApiError from truss.remote.baseten.service import BasetenService +from truss.remote.baseten.types import ChainletData from truss.remote.baseten.utils.transfer import base64_encoded_json_str from truss.remote.truss_remote import TrussRemote from truss.truss_config import ModelServer @@ -44,6 +47,18 @@ def __init__(self, remote_url: str, api_key: str, **kwargs): def api(self) -> BasetenApi: return self._api + def create_chain( + self, chain_name: str, chainlets: List[ChainletData], publish: bool = False + ) -> str: + chain_id = get_chain_id_by_name(self._api, chain_name) + return create_chain( + self._api, + chain_id=chain_id, + chain_name=chain_name, + chainlets=chainlets, + is_draft=not publish, + ) + def push( # type: ignore self, truss_handle: TrussHandle, diff --git a/truss/remote/baseten/types.py b/truss/remote/baseten/types.py new file mode 100644 index 000000000..1331ab61f --- /dev/null +++ b/truss/remote/baseten/types.py @@ -0,0 +1,7 @@ +import pydantic + + +class ChainletData(pydantic.BaseModel): + name: str + oracle_version_id: str + is_entrypoint: bool diff --git a/truss/templates/base.Dockerfile.jinja b/truss/templates/base.Dockerfile.jinja index 604fdd975..d7a11b144 100644 --- a/truss/templates/base.Dockerfile.jinja +++ b/truss/templates/base.Dockerfile.jinja @@ -39,10 +39,12 @@ RUN apt-get update && apt-get install --yes --no-install-recommends $(cat {{syst {% block install_requirements %} {%- if should_install_user_requirements_file %} COPY ./{{user_supplied_requirements_filename}} {{user_supplied_requirements_filename}} +RUN cat {{user_supplied_requirements_filename}} RUN pip install -r {{user_supplied_requirements_filename}} --no-cache-dir && rm -rf /root/.cache/pip {%- endif %} {%- if should_install_requirements %} COPY ./{{config_requirements_filename}} {{config_requirements_filename}} +RUN cat {{config_requirements_filename}} RUN pip install -r {{config_requirements_filename}} --no-cache-dir && rm -rf /root/.cache/pip {%- endif %} {% endblock %} diff --git a/truss/templates/control/requirements.txt b/truss/templates/control/requirements.txt index f9bc7e9a3..9e7163895 100644 --- a/truss/templates/control/requirements.txt +++ b/truss/templates/control/requirements.txt @@ -1,9 +1,9 @@ dataclasses-json==0.5.7 -truss==0.9.10 # This needs to support py3.8. It can be incremented to releases after #955 (i.e. >0.9.12). +truss==0.9.14 fastapi==0.109.1 uvicorn==0.24.0 uvloop==0.19.0 tenacity==8.1.0 -httpx==0.24.1 +httpx==0.27.0 python-json-logger==2.0.2 loguru==0.7.2 diff --git a/truss/templates/server.Dockerfile.jinja b/truss/templates/server.Dockerfile.jinja index 690980bcd..896279901 100644 --- a/truss/templates/server.Dockerfile.jinja +++ b/truss/templates/server.Dockerfile.jinja @@ -62,6 +62,13 @@ RUN mkdir -p {{ dst.parent }}; curl -L "{{ url }}" -o {{ dst }} {% endfor %} {%- endif %} + +{%- if build_commands %} +{% for command in build_commands %} +RUN {{ command }} +{% endfor %} +{%- endif %} + # Copy data before code for better caching {%- if data_dir_exists %} COPY ./{{config.data_dir}} /app/data diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py index 1563cadc3..c4ed58125 100644 --- a/truss/templates/server/model_wrapper.py +++ b/truss/templates/server/model_wrapper.py @@ -31,6 +31,7 @@ from common.schema import TrussSchema from fastapi import HTTPException from pydantic import BaseModel +from shared.lazy_data_resolver import LazyDataResolver from shared.secrets_resolver import SecretsResolver from typing_extensions import ParamSpec @@ -164,6 +165,8 @@ def try_load(self): model_init_params["data_dir"] = data_dir if _signature_accepts_keyword_arg(model_class_signature, "secrets"): model_init_params["secrets"] = SecretsResolver.get_secrets(self._config) + if _signature_accepts_keyword_arg(model_class_signature, "lazy_data_resolver"): + model_init_params["lazy_data_resolver"] = LazyDataResolver(data_dir).fetch() apply_patches( self._config.get("apply_library_patches", True), self._config["requirements"], diff --git a/truss/templates/shared/lazy_data_resolver.py b/truss/templates/shared/lazy_data_resolver.py new file mode 100644 index 000000000..426f5cad4 --- /dev/null +++ b/truss/templates/shared/lazy_data_resolver.py @@ -0,0 +1,70 @@ +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime, timezone +from pathlib import Path +from typing import Dict, List + +import pydantic +import yaml + +try: + from shared.util import download_from_url_using_requests +except ModuleNotFoundError: + from truss.templates.shared.util import download_from_url_using_requests + +LAZY_DATA_RESOLVER_PATH = Path("/bptr/bptr-manifest") +NUM_WORKERS = 4 + + +class Resolution(pydantic.BaseModel): + url: str + expiration_timestamp: int + + +class BasetenPointer(pydantic.BaseModel): + """Specification for lazy data resolution for download of large files, similar to Git LFS pointers""" + + resolution: Resolution + uid: str + file_name: str + hashtype: str + hash: str + size: int + + +class BasetenPointerManifest(pydantic.BaseModel): + pointers: List[BasetenPointer] + + +class LazyDataResolver: + def __init__(self, data_dir: Path): + self._data_dir: Path = data_dir + self._bptr_resolution: Dict[str, str] = _read_bptr_resolution() + + def fetch(self): + with ThreadPoolExecutor(NUM_WORKERS) as executor: + futures = {} + for file_name, resolved_url in self._bptr_resolution.items(): + futures[file_name] = executor.submit( + download_from_url_using_requests, + resolved_url, + self._data_dir / file_name, + ) + for file_name, future in futures.items(): + if not future: + raise RuntimeError(f"Download failure for file {file_name}") + + +def _read_bptr_resolution() -> Dict[str, str]: + if not LAZY_DATA_RESOLVER_PATH.is_file(): + return {} + bptr_manifest = BasetenPointerManifest( + **yaml.safe_load(LAZY_DATA_RESOLVER_PATH.read_text()) + ) + resolution_map = {} + for bptr in bptr_manifest.pointers: + if bptr.resolution.expiration_timestamp < int( + datetime.now(timezone.utc).timestamp() + ): + raise RuntimeError("Baseten pointer lazy data resolution has expired") + resolution_map[bptr.file_name] = bptr.resolution.url + return resolution_map diff --git a/truss/templates/shared/util.py b/truss/templates/shared/util.py index 6fc8f245d..302198c9d 100644 --- a/truss/templates/shared/util.py +++ b/truss/templates/shared/util.py @@ -1,10 +1,14 @@ import multiprocessing import os +import shutil import sys +from pathlib import Path from typing import Callable, Dict, List, TypeVar import psutil +import requests +BLOB_DOWNLOAD_TIMEOUT_SECS = 600 # 10 minutes # number of seconds to wait for truss server child processes before sending kill signal CHILD_PROCESS_WAIT_TIMEOUT_SECONDS = 120 @@ -85,3 +89,16 @@ def kill_child_processes(parent_pid: int): def transform_keys(d: Dict[X, Z], fn: Callable[[X], Y]) -> Dict[Y, Z]: return {fn(key): value for key, value in d.items()} + + +def download_from_url_using_requests(URL: str, download_to: Path): + # Streaming download to keep memory usage low + resp = requests.get( + URL, + allow_redirects=True, + stream=True, + timeout=BLOB_DOWNLOAD_TIMEOUT_SECS, + ) + resp.raise_for_status() + with download_to.open("wb") as file: + shutil.copyfileobj(resp.raw, file) diff --git a/truss/test_data/test_build_commands/config.yaml b/truss/test_data/test_build_commands/config.yaml new file mode 100644 index 000000000..6bbf66e02 --- /dev/null +++ b/truss/test_data/test_build_commands/config.yaml @@ -0,0 +1,13 @@ +build_commands: + - mkdir example_dir + - cd example_dir && touch testing.py +model_metadata: {} +model_name: null +model_type: custom +python_version: py39 +requirements: [] +resources: + accelerator: null + cpu: 500m + memory: 512Mi + use_gpu: false diff --git a/truss/test_data/test_build_commands/model/model.py b/truss/test_data/test_build_commands/model/model.py new file mode 100644 index 000000000..ebca785a0 --- /dev/null +++ b/truss/test_data/test_build_commands/model/model.py @@ -0,0 +1,17 @@ +from typing import Any, Dict, List + + +class Model: + def __init__(self, **kwargs) -> None: + self._data_dir = kwargs["data_dir"] + self._config = kwargs["config"] + self._secrets = kwargs["secrets"] + self._model = None + + def load(self): + # Load model here and assign to self._model. + pass + + def predict(self, model_input: Any) -> Dict[str, List]: + # Invoke model on model_input and calculate predictions here. + return {"predictions": [1, 2]} diff --git a/truss/test_data/test_build_commands_failure/config.yaml b/truss/test_data/test_build_commands_failure/config.yaml new file mode 100644 index 000000000..d6ba09d51 --- /dev/null +++ b/truss/test_data/test_build_commands_failure/config.yaml @@ -0,0 +1,14 @@ +build_commands: + - mkdir example_dir + - cd example_dir && touch testing.py + - haha lol +model_metadata: {} +model_name: null +model_type: custom +python_version: py39 +requirements: [] +resources: + accelerator: null + cpu: 500m + memory: 512Mi + use_gpu: false diff --git a/truss/test_data/test_build_commands_failure/model/model.py b/truss/test_data/test_build_commands_failure/model/model.py new file mode 100644 index 000000000..ebca785a0 --- /dev/null +++ b/truss/test_data/test_build_commands_failure/model/model.py @@ -0,0 +1,17 @@ +from typing import Any, Dict, List + + +class Model: + def __init__(self, **kwargs) -> None: + self._data_dir = kwargs["data_dir"] + self._config = kwargs["config"] + self._secrets = kwargs["secrets"] + self._model = None + + def load(self): + # Load model here and assign to self._model. + pass + + def predict(self, model_input: Any) -> Dict[str, List]: + # Invoke model on model_input and calculate predictions here. + return {"predictions": [1, 2]} diff --git a/truss/tests/conftest.py b/truss/tests/conftest.py index 333db8c2a..3afccd09e 100644 --- a/truss/tests/conftest.py +++ b/truss/tests/conftest.py @@ -16,6 +16,7 @@ ) from truss.contexts.local_loader.docker_build_emulator import DockerBuildEmulator from truss.truss_config import DEFAULT_BUNDLED_PACKAGES_DIR +from truss.truss_handle import TrussHandle from truss.types import Example CUSTOM_MODEL_CODE = """ @@ -367,6 +368,34 @@ def no_params_init_custom_model(tmp_path): ) +@pytest.fixture +def custom_model_trt_llm(tmp_path): + def modify_handle(h: TrussHandle): + with _modify_yaml(h.spec.config_path) as content: + h.enable_gpu() + content["trt_llm"] = { + "build": { + "base_model": "llama", + "max_input_len": 1024, + "max_output_len": 1024, + "max_batch_size": 512, + "max_beam_width": 1, + "checkpoint_repository": { + "source": "LOCAL", + "repo": "/path/to/checkpoint", + }, + } + } + content["resources"]["accelerator"] = "H100:1" + + yield _custom_model_from_code( + tmp_path, + "my_trt_llm_model", + CUSTOM_MODEL_CODE, + handle_ops=modify_handle, + ) + + @pytest.fixture def useless_file(tmp_path): f = tmp_path / "useless.py" diff --git a/truss/tests/contexts/image_builder/test_serving_image_builder.py b/truss/tests/contexts/image_builder/test_serving_image_builder.py index 95ea83b8c..6cd80059f 100644 --- a/truss/tests/contexts/image_builder/test_serving_image_builder.py +++ b/truss/tests/contexts/image_builder/test_serving_image_builder.py @@ -1,9 +1,19 @@ +import filecmp +import os import time from pathlib import Path from tempfile import TemporaryDirectory from unittest.mock import patch import pytest +from truss.constants import ( + BASE_TRTLLM_REQUIREMENTS, + OPENAI_COMPATIBLE_TAG, + TRTLLM_BASE_IMAGE, + TRTLLM_PREDICT_CONCURRENCY, + TRTLLM_PYTHON_EXECUTABLE, + TRTLLM_TRUSS_DIR, +) from truss.contexts.image_builder.serving_image_builder import ( HF_ACCESS_TOKEN_FILE_NAME, ServingImageBuilderContext, @@ -288,3 +298,36 @@ def test_ignore_files_during_build_setup(custom_model_truss_dir_with_truss_ignor assert not (build_path / ignore_folder).exists() assert (build_path / do_not_ignore_folder).exists() + + +def test_trt_llm_build_dir(custom_model_trt_llm): + th = TrussHandle(custom_model_trt_llm) + builder_context = ServingImageBuilderContext + image_builder = builder_context.run(th.spec.truss_dir) + with TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + image_builder.prepare_image_build_dir(tmp_path) + build_th = TrussHandle(tmp_path) + + # Check that all files were copied + for dirpath, dirnames, filenames in os.walk(TRTLLM_TRUSS_DIR): + rel_path = os.path.relpath(dirpath, TRTLLM_TRUSS_DIR) + for filename in filenames: + src_file = os.path.join(dirpath, filename) + dest_file = os.path.join(tmp_path, rel_path, filename) + assert os.path.exists(dest_file), f"{dest_file} was not copied" + assert filecmp.cmp( + src_file, dest_file, shallow=False + ), f"{src_file} and {dest_file} are not the same" + + assert ( + build_th.spec.config.runtime.predict_concurrency + == TRTLLM_PREDICT_CONCURRENCY + ) + assert build_th.spec.config.base_image.image == TRTLLM_BASE_IMAGE + assert ( + build_th.spec.config.base_image.python_executable_path + == TRTLLM_PYTHON_EXECUTABLE + ) + assert BASE_TRTLLM_REQUIREMENTS == build_th.spec.config.requirements + assert OPENAI_COMPATIBLE_TAG in build_th.spec.config.model_metadata["tags"] diff --git a/truss/tests/templates/core/server/test_lazy_data_resolver.py b/truss/tests/templates/core/server/test_lazy_data_resolver.py new file mode 100644 index 000000000..030fc4c31 --- /dev/null +++ b/truss/tests/templates/core/server/test_lazy_data_resolver.py @@ -0,0 +1,132 @@ +import datetime +import json +from contextlib import nullcontext +from pathlib import Path +from typing import Callable +from unittest.mock import patch + +import pytest +import requests_mock +from truss.templates.shared.lazy_data_resolver import ( + LAZY_DATA_RESOLVER_PATH, + LazyDataResolver, +) + + +@pytest.fixture +def baseten_pointer_manifest_mock() -> Callable: + def _baseten_pointer_manifest_mock( + foo_expiry_timestamp: int, bar_expiry_timestamp: int + ): + return f""" +pointers: +- uid: foo + file_name: foo-name + hashtype: hash-type + hash: foo-hash + size: 100 + resolution: + url: https://foo-rl + expiration_timestamp: {foo_expiry_timestamp} +- uid: bar + file_name: bar-name + hashtype: hash-type + hash: bar-hash + size: 1000 + resolution: + url: https://bar-rl + expiration_timestamp: {bar_expiry_timestamp} +""" + + return _baseten_pointer_manifest_mock + + +def test_lazy_data_resolution_not_found(): + ldr = LazyDataResolver(Path("foo")) + assert not LAZY_DATA_RESOLVER_PATH.exists() + assert ldr._bptr_resolution == {} + + +@pytest.mark.parametrize( + "foo_expiry,bar_expiry,expectation", + [ + ( + int( + datetime.datetime(3000, 1, 1, tzinfo=datetime.timezone.utc).timestamp() + ), + int( + datetime.datetime(3000, 1, 1, tzinfo=datetime.timezone.utc).timestamp() + ), + nullcontext(), + ), + ( + int( + datetime.datetime(2020, 1, 1, tzinfo=datetime.timezone.utc).timestamp() + ), + int( + datetime.datetime(2020, 1, 1, tzinfo=datetime.timezone.utc).timestamp() + ), + pytest.raises(RuntimeError), + ), + ], +) +def test_lazy_data_resolution( + baseten_pointer_manifest_mock, foo_expiry, bar_expiry, expectation, tmp_path +): + baseten_pointer_manifest_mock = baseten_pointer_manifest_mock( + foo_expiry, bar_expiry + ) + manifest_path = tmp_path / "bptr" / "bptr-manifest" + manifest_path.parent.mkdir() + manifest_path.touch() + manifest_path.write_text(baseten_pointer_manifest_mock) + with patch( + "truss.templates.shared.lazy_data_resolver.LAZY_DATA_RESOLVER_PATH", + manifest_path, + ): + with expectation: + ldr = LazyDataResolver(Path("foo")) + assert ldr._bptr_resolution == { + "foo-name": "https://foo-rl", + "bar-name": "https://bar-rl", + } + + +@pytest.mark.parametrize( + "foo_expiry,bar_expiry", + [ + ( + int( + datetime.datetime(3000, 1, 1, tzinfo=datetime.timezone.utc).timestamp() + ), + int( + datetime.datetime(3000, 1, 1, tzinfo=datetime.timezone.utc).timestamp() + ), + ) + ], +) +def test_lazy_data_fetch( + baseten_pointer_manifest_mock, foo_expiry, bar_expiry, tmp_path +): + baseten_pointer_manifest_mock = baseten_pointer_manifest_mock( + foo_expiry, bar_expiry + ) + manifest_path = tmp_path / "bptr" / "bptr-manifest" + manifest_path.parent.mkdir() + manifest_path.touch() + manifest_path.write_text(baseten_pointer_manifest_mock) + with patch( + "truss.templates.shared.lazy_data_resolver.LAZY_DATA_RESOLVER_PATH", + manifest_path, + ): + data_dir = Path(tmp_path) + ldr = LazyDataResolver(data_dir) + with requests_mock.Mocker() as m: + for file_name, url in ldr._bptr_resolution.items(): + resp = {"file_name": file_name, "url": url} + m.get(url, json=resp) + ldr.fetch() + for file_name, url in ldr._bptr_resolution.items(): + assert (ldr._data_dir / file_name).read_text() == json.dumps( + {"file_name": file_name, "url": url} + ) diff --git a/truss/tests/test_config.py b/truss/tests/test_config.py index f57c39ea3..ff383e8c5 100644 --- a/truss/tests/test_config.py +++ b/truss/tests/test_config.py @@ -152,6 +152,7 @@ def test_parse_base_image(input_dict, expect_base_image, output_dict): def generate_default_config(): config = { + "build_commands": [], "environment_variables": {}, "external_package_dirs": [], "model_metadata": {}, @@ -176,7 +177,8 @@ def test_default_config_not_crowded_end_to_end(): requirements=[], ) - config_yaml = """environment_variables: {} + config_yaml = """build_commands: [] +environment_variables: {} external_package_dirs: [] model_metadata: {} model_name: null diff --git a/truss/tests/test_testing_utilities_for_other_tests.py b/truss/tests/test_testing_utilities_for_other_tests.py index 8e856307f..7ec085a84 100644 --- a/truss/tests/test_testing_utilities_for_other_tests.py +++ b/truss/tests/test_testing_utilities_for_other_tests.py @@ -1,7 +1,7 @@ # This file contains shared code to be used in other tests # TODO(pankaj): Using a tests file for shared code is not ideal, we should # move it to a regular file. This is a short term hack. - +import json import shutil import subprocess import time @@ -17,12 +17,46 @@ @contextmanager def ensure_kill_all(): try: - yield + with _show_container_logs_if_raised(): + yield finally: kill_all_with_retries() ensure_free_disk_space() +def _human_readable_json_logs(raw_logs: str) -> str: + output = [] + for line in raw_logs.splitlines(): + try: + log_entry = json.loads(line) + human_readable_log = " ".join( + f"{key}: {value}" for key, value in log_entry.items() + ) + output.append(f"\t{human_readable_log}") + except Exception: + output.append(line) + return "\n".join(output) + + +@contextmanager +def _show_container_logs_if_raised(): + initial_ids = {c.id for c in get_containers({TRUSS: True})} + exception_raised = False + try: + yield + except Exception: + exception_raised = True + raise + finally: + if exception_raised: + print("An exception was raised, showing logs of all containers.") + containers = get_containers({TRUSS: True}) + new_containers = [c for c in containers if c.id not in initial_ids] + for container in new_containers: + print(f"Logs for container {container.name} ({container.id}):") + print(_human_readable_json_logs(container.logs())) + + def kill_all_with_retries(num_retries: int = 10): kill_all() attempts = 0 diff --git a/truss/tests/test_truss_handle.py b/truss/tests/test_truss_handle.py index 13d577e94..fddebc233 100644 --- a/truss/tests/test_truss_handle.py +++ b/truss/tests/test_truss_handle.py @@ -458,6 +458,27 @@ def test_add_environment_variable(custom_model_truss_dir_with_pre_and_post): Docker.client().kill(container) +@pytest.mark.integration +def test_build_commands(): + truss_root = Path(__file__).parent.parent.parent.resolve() / "truss" + truss_dir = truss_root / "test_data" / "test_build_commands" + tr = TrussHandle(truss_dir) + with ensure_kill_all(): + r1 = tr.docker_predict([1, 2]) + assert r1 == {"predictions": [1, 2]} + + +@pytest.mark.integration +def test_build_commands_failure(): + truss_root = Path(__file__).parent.parent.parent.resolve() / "truss" + truss_dir = truss_root / "test_data" / "test_build_commands_failure" + tr = TrussHandle(truss_dir) + try: + tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True) + except DockerException as exc: + assert "It returned with code 1" in str(exc) + + def test_add_data_file(custom_model_truss_dir_with_pre_and_post, tmp_path): th = TrussHandle(custom_model_truss_dir_with_pre_and_post) data_filepath = tmp_path / "test_data.txt" @@ -800,6 +821,7 @@ def _read_readme(filename: str) -> str: def generate_default_config(): config = { + "build_commands": [], "environment_variables": {}, "external_package_dirs": [], "model_metadata": {}, diff --git a/truss/truss_config.py b/truss/truss_config.py index a0d9f227f..5056be36d 100644 --- a/truss/truss_config.py +++ b/truss/truss_config.py @@ -498,6 +498,7 @@ class TrussConfig: base_image: Optional[BaseImage] = None model_cache: ModelCache = field(default_factory=ModelCache) trt_llm: Optional[TRTLLMConfiguration] = None + build_commands: Optional[List[str]] = field(default_factory=list) @property def canonical_python_version(self) -> str: @@ -553,6 +554,7 @@ def from_dict(d): trt_llm=transform_optional( d.get("trt_llm"), lambda x: TRTLLMConfiguration(**x) ), + build_commands=d.get("build_commands", []), ) config.validate() return config @@ -613,6 +615,7 @@ def validate(self): "resources", "secrets", "system_packages", + "build_commands", }, BaseImage: {"image", "python_executable_path"}, } diff --git a/truss/truss_handle.py b/truss/truss_handle.py index e7e785251..23a557d00 100644 --- a/truss/truss_handle.py +++ b/truss/truss_handle.py @@ -219,7 +219,12 @@ def _run_docker(gpus: Optional[str] = None): "type=bind", f"src={str(secrets_mount_dir_path)}", "target=/secrets", - ] + ], + [ + "type=bind", + f"src={str(LocalConfigHandler.bptr_data_resolution_dir_path())}", + "target=/bptr", + ], ], gpus=gpus, envs=envs, diff --git a/truss/truss_spec.py b/truss/truss_spec.py index d98d399d8..72fc7f359 100644 --- a/truss/truss_spec.py +++ b/truss/truss_spec.py @@ -34,6 +34,10 @@ def data_dir(self) -> Path: def external_data(self) -> Optional[ExternalData]: return self._config.external_data + @property + def build_commands(self) -> List[str]: + return self._config.build_commands + @property def model_module_dir(self) -> Path: return self._truss_dir / self._config.model_module_dir diff --git a/truss/util/download.py b/truss/util/download.py index c96ac3fd0..4a27520c0 100644 --- a/truss/util/download.py +++ b/truss/util/download.py @@ -74,12 +74,12 @@ def _download_from_url_using_b10cp( def _download_external_data_using_requests(data_dir: Path, external_data: ExternalData): for item in external_data.items: - _download_from_url_using_requests( + download_from_url_using_requests( item.url, (data_dir / item.local_data_path).resolve() ) -def _download_from_url_using_requests(URL: str, download_to: Path): +def download_from_url_using_requests(URL: str, download_to: Path): # Streaming download to keep memory usage low resp = requests.get( URL,