From e2709c846588d3857fa026c350ae3621e1ab500f Mon Sep 17 00:00:00 2001 From: Bryce Dubayah Date: Fri, 31 May 2024 18:33:36 +0000 Subject: [PATCH] update tests --- .../test_serving_image_builder.py | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/truss/tests/contexts/image_builder/test_serving_image_builder.py b/truss/tests/contexts/image_builder/test_serving_image_builder.py index 065f50f1a..6cd80059f 100644 --- a/truss/tests/contexts/image_builder/test_serving_image_builder.py +++ b/truss/tests/contexts/image_builder/test_serving_image_builder.py @@ -1,10 +1,19 @@ +import filecmp +import os import time from pathlib import Path from tempfile import TemporaryDirectory from unittest.mock import patch import pytest -from truss.constants import TRTLLM_PREDICT_CONCURRENCY +from truss.constants import ( + BASE_TRTLLM_REQUIREMENTS, + OPENAI_COMPATIBLE_TAG, + TRTLLM_BASE_IMAGE, + TRTLLM_PREDICT_CONCURRENCY, + TRTLLM_PYTHON_EXECUTABLE, + TRTLLM_TRUSS_DIR, +) from truss.contexts.image_builder.serving_image_builder import ( HF_ACCESS_TOKEN_FILE_NAME, ServingImageBuilderContext, @@ -299,7 +308,26 @@ def test_trt_llm_build_dir(custom_model_trt_llm): tmp_path = Path(tmp_dir) image_builder.prepare_image_build_dir(tmp_path) build_th = TrussHandle(tmp_path) + + # Check that all files were copied + for dirpath, dirnames, filenames in os.walk(TRTLLM_TRUSS_DIR): + rel_path = os.path.relpath(dirpath, TRTLLM_TRUSS_DIR) + for filename in filenames: + src_file = os.path.join(dirpath, filename) + dest_file = os.path.join(tmp_path, rel_path, filename) + assert os.path.exists(dest_file), f"{dest_file} was not copied" + assert filecmp.cmp( + src_file, dest_file, shallow=False + ), f"{src_file} and {dest_file} are not the same" + assert ( build_th.spec.config.runtime.predict_concurrency == TRTLLM_PREDICT_CONCURRENCY ) + assert build_th.spec.config.base_image.image == TRTLLM_BASE_IMAGE + assert ( + build_th.spec.config.base_image.python_executable_path + == TRTLLM_PYTHON_EXECUTABLE + ) + assert BASE_TRTLLM_REQUIREMENTS == build_th.spec.config.requirements + assert OPENAI_COMPATIBLE_TAG in build_th.spec.config.model_metadata["tags"]