Skip to content

Commit

Permalink
Set default predict_concurrency when using trt-llm to 512 (#954)
Browse files Browse the repository at this point in the history
* Set default predict_concurrency when using trt-llm to 512

* update tests
  • Loading branch information
bdubayah authored May 31, 2024
1 parent 7e17bf7 commit 473a06b
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 0 deletions.
1 change: 1 addition & 0 deletions truss/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

SUPPORTED_PYTHON_VERSIONS = {"3.8", "3.9", "3.10", "3.11"}

TRTLLM_PREDICT_CONCURRENCY = 512

# Alias for TEMPLATES_DIR
SERVING_DIR: pathlib.Path = TEMPLATES_DIR
Expand Down
3 changes: 3 additions & 0 deletions truss/contexts/image_builder/serving_image_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
SYSTEM_PACKAGES_TXT_FILENAME,
TEMPLATES_DIR,
TRTLLM_BASE_IMAGE,
TRTLLM_PREDICT_CONCURRENCY,
TRTLLM_PYTHON_EXECUTABLE,
TRTLLM_TRUSS_DIR,
USE_BRITON,
Expand Down Expand Up @@ -353,6 +354,8 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
"Tensor parallelism and GPU count must be the same for TRT-LLM"
)

config.runtime.predict_concurrency = TRTLLM_PREDICT_CONCURRENCY

config.base_image = BaseImage(
image=BRITON_TRTLLM_BASE_IMAGE if USE_BRITON else TRTLLM_BASE_IMAGE,
python_executable_path=TRTLLM_PYTHON_EXECUTABLE,
Expand Down
29 changes: 29 additions & 0 deletions truss/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from truss.contexts.local_loader.docker_build_emulator import DockerBuildEmulator
from truss.truss_config import DEFAULT_BUNDLED_PACKAGES_DIR
from truss.truss_handle import TrussHandle
from truss.types import Example

CUSTOM_MODEL_CODE = """
Expand Down Expand Up @@ -367,6 +368,34 @@ def no_params_init_custom_model(tmp_path):
)


@pytest.fixture
def custom_model_trt_llm(tmp_path):
def modify_handle(h: TrussHandle):
with _modify_yaml(h.spec.config_path) as content:
h.enable_gpu()
content["trt_llm"] = {
"build": {
"base_model": "llama",
"max_input_len": 1024,
"max_output_len": 1024,
"max_batch_size": 512,
"max_beam_width": 1,
"checkpoint_repository": {
"source": "LOCAL",
"repo": "/path/to/checkpoint",
},
}
}
content["resources"]["accelerator"] = "H100:1"

yield _custom_model_from_code(
tmp_path,
"my_trt_llm_model",
CUSTOM_MODEL_CODE,
handle_ops=modify_handle,
)


@pytest.fixture
def useless_file(tmp_path):
f = tmp_path / "useless.py"
Expand Down
43 changes: 43 additions & 0 deletions truss/tests/contexts/image_builder/test_serving_image_builder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import filecmp
import os
import time
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import patch

import pytest
from truss.constants import (
BASE_TRTLLM_REQUIREMENTS,
OPENAI_COMPATIBLE_TAG,
TRTLLM_BASE_IMAGE,
TRTLLM_PREDICT_CONCURRENCY,
TRTLLM_PYTHON_EXECUTABLE,
TRTLLM_TRUSS_DIR,
)
from truss.contexts.image_builder.serving_image_builder import (
HF_ACCESS_TOKEN_FILE_NAME,
ServingImageBuilderContext,
Expand Down Expand Up @@ -288,3 +298,36 @@ def test_ignore_files_during_build_setup(custom_model_truss_dir_with_truss_ignor

assert not (build_path / ignore_folder).exists()
assert (build_path / do_not_ignore_folder).exists()


def test_trt_llm_build_dir(custom_model_trt_llm):
th = TrussHandle(custom_model_trt_llm)
builder_context = ServingImageBuilderContext
image_builder = builder_context.run(th.spec.truss_dir)
with TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
image_builder.prepare_image_build_dir(tmp_path)
build_th = TrussHandle(tmp_path)

# Check that all files were copied
for dirpath, dirnames, filenames in os.walk(TRTLLM_TRUSS_DIR):
rel_path = os.path.relpath(dirpath, TRTLLM_TRUSS_DIR)
for filename in filenames:
src_file = os.path.join(dirpath, filename)
dest_file = os.path.join(tmp_path, rel_path, filename)
assert os.path.exists(dest_file), f"{dest_file} was not copied"
assert filecmp.cmp(
src_file, dest_file, shallow=False
), f"{src_file} and {dest_file} are not the same"

assert (
build_th.spec.config.runtime.predict_concurrency
== TRTLLM_PREDICT_CONCURRENCY
)
assert build_th.spec.config.base_image.image == TRTLLM_BASE_IMAGE
assert (
build_th.spec.config.base_image.python_executable_path
== TRTLLM_PYTHON_EXECUTABLE
)
assert BASE_TRTLLM_REQUIREMENTS == build_th.spec.config.requirements
assert OPENAI_COMPATIBLE_TAG in build_th.spec.config.model_metadata["tags"]

0 comments on commit 473a06b

Please sign in to comment.