diff --git a/pyproject.toml b/pyproject.toml
index 9cf70cc00..a698d6b1e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.9.55"
+version = "0.9.56"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"
diff --git a/truss-chains/examples/streaming/streaming_chain.py b/truss-chains/examples/streaming/streaming_chain.py
index 4b1b2488d..08369ea84 100644
--- a/truss-chains/examples/streaming/streaming_chain.py
+++ b/truss-chains/examples/streaming/streaming_chain.py
@@ -38,7 +38,7 @@ class ConsumerOutput(pydantic.BaseModel):
 class Generator(chains.ChainletBase):
     """Example that streams fully structured pydantic items with header and footer."""
 
-    async def run_remote(self) -> AsyncIterator[bytes]:
+    async def run_remote(self, cause_error: bool) -> AsyncIterator[bytes]:
         print("Entering Generator")
         streamer = streaming.stream_writer(STREAM_TYPES)
         header = Header(time=time.time(), msg="Start.")
@@ -49,6 +49,8 @@ async def run_remote(self) -> AsyncIterator[bytes]:
             )
             print("Yield")
             yield streamer.yield_item(data)
+            if cause_error and i > 2:
+                raise RuntimeError("Test Error")
             await asyncio.sleep(0.05)
 
         end_time = time.time()
@@ -79,9 +81,11 @@ def __init__(
         self._generator = generator
         self._string_generator = string_generator
 
-    async def run_remote(self) -> ConsumerOutput:
+    async def run_remote(self, cause_error: bool) -> ConsumerOutput:
         print("Entering Consumer")
-        reader = streaming.stream_reader(STREAM_TYPES, self._generator.run_remote())
+        reader = streaming.stream_reader(
+            STREAM_TYPES, self._generator.run_remote(cause_error)
+        )
         print("Consuming...")
         header = await reader.read_header()
         chunks = []
@@ -103,5 +107,5 @@ async def run_remote(self) -> ConsumerOutput:
 if __name__ == "__main__":
     with chains.run_local():
         chain = Consumer()
-        result = asyncio.run(chain.run_remote(False))
+        result = asyncio.run(chain.run_remote(False))
         print(result)
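
To try the new `cause_error` flag locally, the example's own `__main__` block can be adapted. A minimal sketch (assuming the file's existing imports, i.e. `asyncio` and `import truss_chains as chains`; the import path for `Consumer` is hypothetical):

    import asyncio

    import truss_chains as chains
    from streaming_chain import Consumer  # hypothetical import path

    with chains.run_local():
        chain = Consumer()
        # cause_error=False streams normally; True makes the Generator raise
        # RuntimeError("Test Error") after the third item.
        result = asyncio.run(chain.run_remote(cause_error=False))
        print(result)
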
diff --git a/truss-chains/tests/itest_chain/itest_chain.py b/truss-chains/tests/itest_chain/itest_chain.py
index cdff415f0..56ad216d8 100644
--- a/truss-chains/tests/itest_chain/itest_chain.py
+++ b/truss-chains/tests/itest_chain/itest_chain.py
@@ -31,6 +31,11 @@ def run_remote(self, length: int) -> str:
         return (template * repetitions)[:length]
 
 
+def validate_data(data):
+    if len(data) > 30:
+        raise ValueError(f"This input is too long: {len(data)}.")
+
+
 class TextReplicator(chains.ChainletBase):
     remote_config = chains.RemoteConfig(docker_image=IMAGE_CUSTOM)
 
@@ -44,8 +49,7 @@ def __init__(self):
         self.multiplier = 2
 
     def run_remote(self, data: str) -> str:
-        if len(data) > 30:
-            raise ValueError(f"This input is too long: {len(data)}.")
+        validate_data(data)
         return data * self.multiplier
 
 
@@ -123,9 +127,7 @@ async def run_remote(
             extra_arg=123,
         )
         print(pydantic_default_arg, simple_default_arg)
-        value = 0
-        for part in text_parts.parts:
-            value += self._text_to_num.run_remote(part)
+        value = self._accumulate_parts(text_parts.parts)
         return (
             value,
             data,
@@ -133,3 +135,9 @@ async def run_remote(
             pydantic_default_arg,
             simple_default_arg,
         )
+
+    def _accumulate_parts(self, parts) -> int:
+        value = 0
+        for part in parts:
+            value += self._text_to_num.run_remote(part)
+        return value
diff --git a/truss-chains/tests/test_e2e.py b/truss-chains/tests/test_e2e.py
index ca45a32dc..cc1272019 100644
--- a/truss-chains/tests/test_e2e.py
+++ b/truss-chains/tests/test_e2e.py
@@ -1,4 +1,6 @@
 import logging
+import re
+import time
 from pathlib import Path
 
 import pytest
@@ -23,6 +25,7 @@ def test_chain():
     service = deployment_client.push(entrypoint, options)
 
     url = service.run_remote_url.replace("host.docker.internal", "localhost")
+    time.sleep(1.0)  # Wait for models to be ready.
 
     # Call without providing values for default arguments.
     response = requests.post(
@@ -73,11 +76,34 @@ def test_chain():
         url, json={"length": 300, "num_partitions": 3}, stream=True
     )
     print(response)
+    assert response.status_code == 500
+
     error = definitions.RemoteErrorDetail.model_validate(response.json()["error"])
     error_str = error.format()
     print(error_str)
-    assert "ValueError: This input is too long: 100." in error_str
-    assert response.status_code == 500
+
+    error_regex = r"""
+Chainlet-Traceback \(most recent call last\):
+  File \".*?/itest_chain\.py\", line \d+, in run_remote
+    value = self\._accumulate_parts\(text_parts\.parts\)
+  File \".*?/itest_chain\.py\", line \d+, in _accumulate_parts
+    value \+= self\._text_to_num\.run_remote\(part\)
+ValueError: \(showing chained remote errors, root error at the bottom\)
+├─ Error in dependency Chainlet `TextToNum`:
+│ Chainlet-Traceback \(most recent call last\):
+│   File \".*?/itest_chain\.py\", line \d+, in run_remote
+│     generated_text = self\._replicator\.run_remote\(data\)
+│ ValueError: \(showing chained remote errors, root error at the bottom\)
+│ ├─ Error in dependency Chainlet `TextReplicator`:
+│ │ Chainlet-Traceback \(most recent call last\):
+│ │   File \".*?/itest_chain\.py\", line \d+, in run_remote
+│ │     validate_data\(data\)
+│ │   File \".*?/itest_chain\.py\", line \d+, in validate_data
+│ │     raise ValueError\(f\"This input is too long: \{len\(data\)\}\.\"\)
+╰ ╰ ValueError: This input is too long: \d+\.
+    """
+
+    assert re.match(error_regex.strip(), error_str.strip(), re.MULTILINE), error_str
 
 
 @pytest.mark.asyncio
@@ -137,7 +163,8 @@ def test_streaming_chain():
         ),
     )
     assert service is not None
-    response = service.run_remote({})
+
+    response = service.run_remote({"cause_error": False})
     assert response.status_code == 200
     print(response.json())
     result = response.json()
@@ -150,6 +177,13 @@ def test_streaming_chain():
     assert result["footer"]["duration_sec"] > 0
     assert result["strings"] == "First second last."
 
+    # TODO: build error handling for stream reader.
+    # response = service.run_remote({"cause_error": True})
+    # assert response.status_code == 200
+    # print(response.json())
+    # result = response.json()
+    # print(result)
+
 
 @pytest.mark.asyncio
 async def test_streaming_chain_local():
@@ -157,7 +191,7 @@ async def test_streaming_chain_local():
     chain_root = examples_root / "streaming" / "streaming_chain.py"
    with framework.import_target(chain_root, "Consumer") as entrypoint:
         with public_api.run_local():
-            result = await entrypoint().run_remote()
+            result = await entrypoint().run_remote(cause_error=False)
         print(result)
         assert result.header.msg == "Start."
         assert result.chunks[0].words == ["G"]
diff --git a/truss-chains/truss_chains/definitions.py b/truss-chains/truss_chains/definitions.py
index 47c853819..7de3c5188 100644
--- a/truss-chains/truss_chains/definitions.py
+++ b/truss-chains/truss_chains/definitions.py
@@ -633,7 +633,6 @@ class RemoteErrorDetail(SafeModel):
     error response.
     """
 
-    remote_name: str
     exception_cls_name: str
     exception_module_name: Optional[str]
     exception_message: str
@@ -654,8 +653,7 @@ def format(self) -> str:
             else ""
         )
         error = (
-            f"{RemoteErrorDetail.__name__} in `{self.remote_name}`\n"
-            f"Traceback (most recent call last):\n"
+            f"Chainlet-Traceback (most recent call last):\n"
             f"{stack}{self.exception_cls_name}: {self.exception_message}{exc_info}"
         )
         return error
diff --git a/truss-chains/truss_chains/deployment/code_gen.py b/truss-chains/truss_chains/deployment/code_gen.py
index 51f71bf7a..0c883ddcb 100644
--- a/truss-chains/truss_chains/deployment/code_gen.py
+++ b/truss-chains/truss_chains/deployment/code_gen.py
@@ -452,10 +452,7 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) ->
     )
     # Add error handling context manager:
     parts.append(
-        _indent(
-            f"with stub.trace_parent(request), utils.exception_to_http_error("
-            f'chainlet_name="{chainlet_descriptor.name}"):'
-        )
+        _indent("with stub.trace_parent(request), utils.exception_to_http_error():")
     )
     # Invoke Chainlet.
     if (
+ ): + first_stub_index = i - 1 + break + + final_tb = error_stack[model_predict_index:first_stub_index] + stack = [definitions.StackFrame.from_frame_summary(frame) for frame in final_tb] error = definitions.RemoteErrorDetail( - remote_name=chainlet_name, exception_cls_name=exception.__class__.__name__, exception_module_name=exception_module_name, exception_message=str(exception), - user_stack_trace=stack, + user_stack_trace=list(stack), ) raise fastapi.HTTPException( status_code=500, detail=error.model_dump() @@ -171,12 +175,11 @@ def _handle_exception(exception: Exception, chainlet_name: str) -> NoReturn: @contextlib.contextmanager -def exception_to_http_error(chainlet_name: str) -> Iterator[None]: - # TODO: move chainlet name from here to caller side. +def exception_to_http_error() -> Iterator[None]: try: yield except Exception as e: - _handle_exception(e, chainlet_name) + _handle_exception(e) def _resolve_exception_class( @@ -213,8 +216,9 @@ def _handle_response_error(response_json: dict, remote_name: str): except KeyError as e: logging.error(f"response_json: {response_json}") raise ValueError( - "Could not get `error` field from JSON from error response" + "Could not get `error` field from JSON from chainlet error response" ) from e + try: error = definitions.RemoteErrorDetail.model_validate(error_json) except pydantic.ValidationError as e: @@ -222,15 +226,20 @@ def _handle_response_error(response_json: dict, remote_name: str): msg = f"Remote error occurred in `{remote_name}`: '{error_json}'" raise definitions.GenericRemoteException(msg) from None raise ValueError( - "Could not parse error. Error details are expected to be either a " + "Could not parse chainlet error. Error details are expected to be either a " "plain string (old truss models) or a serialized " - f"`definitions.RemoteErrorDetail.__name__`, got:\n{repr(error_json)}" + f"`{definitions.RemoteErrorDetail.__name__}`, got:\n{repr(error_json)}" ) from e + exception_cls = _resolve_exception_class(error) + error_format = textwrap.indent(error.format(), "│ ") + *lines, last_line = error_format.splitlines() + last_line = f"╰{last_line[1:]}" if last_line.startswith("│") else last_line + error_format = "\n".join(lines + [last_line]) msg = ( - f"(showing remote errors, root message at the bottom)\n" - f"--> Preceding Remote Cause:\n" - f"{textwrap.indent(error.format(), ' ')}" + f"(showing chained remote errors, root error at the bottom)\n" + f"├─ Error in dependency Chainlet `{remote_name}`:\n" + f"{error_format}" ) raise exception_cls(msg) diff --git a/truss-chains/truss_chains/streaming.py b/truss-chains/truss_chains/streaming.py index 7b20539ba..8cba9dd98 100644 --- a/truss-chains/truss_chains/streaming.py +++ b/truss-chains/truss_chains/streaming.py @@ -173,7 +173,6 @@ async def _read(self) -> tuple[_Delimiter, bytes]: if not length: return delimiter, b"" data_bytes = await self._stream.readexactly(length) - print(f"Read Delimiter: {delimiter}") return delimiter, data_bytes async def read_items(self) -> AsyncIterator[ItemT]: diff --git a/truss/base/constants.py b/truss/base/constants.py index a514ad7e3..0cce6fa13 100644 --- a/truss/base/constants.py +++ b/truss/base/constants.py @@ -109,7 +109,7 @@ TRTLLM_SPEC_DEC_DRAFT_MODEL_NAME = "draft" TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.13.0-4fd8a10-5e5c3d7" TRTLLM_PYTHON_EXECUTABLE = "/usr/bin/python3" -BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.12.dev4"] +BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.12.dev8"] AUDIO_MODEL_TRTLLM_REQUIREMENTS = [ "--extra-index-url 
diff --git a/truss-chains/truss_chains/streaming.py b/truss-chains/truss_chains/streaming.py
index 7b20539ba..8cba9dd98 100644
--- a/truss-chains/truss_chains/streaming.py
+++ b/truss-chains/truss_chains/streaming.py
@@ -173,7 +173,6 @@ async def _read(self) -> tuple[_Delimiter, bytes]:
         if not length:
             return delimiter, b""
         data_bytes = await self._stream.readexactly(length)
-        print(f"Read Delimiter: {delimiter}")
         return delimiter, data_bytes
 
     async def read_items(self) -> AsyncIterator[ItemT]:
diff --git a/truss/base/constants.py b/truss/base/constants.py
index a514ad7e3..0cce6fa13 100644
--- a/truss/base/constants.py
+++ b/truss/base/constants.py
@@ -109,7 +109,7 @@
 TRTLLM_SPEC_DEC_DRAFT_MODEL_NAME = "draft"
 TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.13.0-4fd8a10-5e5c3d7"
 TRTLLM_PYTHON_EXECUTABLE = "/usr/bin/python3"
-BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.12.dev4"]
+BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.12.dev8"]
 AUDIO_MODEL_TRTLLM_REQUIREMENTS = [
     "--extra-index-url https://pypi.nvidia.com",
     "tensorrt_cu12_bindings==10.2.0.post1",
diff --git a/truss/base/trt_llm_config.py b/truss/base/trt_llm_config.py
index bd589f1ce..1c19dc261 100644
--- a/truss/base/trt_llm_config.py
+++ b/truss/base/trt_llm_config.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 import warnings
 from enum import Enum
@@ -17,6 +19,7 @@ class TrussTRTLLMModel(str, Enum):
     DEEPSEEK = "deepseek"
     WHISPER = "whisper"
     QWEN = "qwen"
+    ENCODER = "encoder"
 
 
 class TrussTRTLLMQuantizationType(str, Enum):
@@ -48,6 +51,20 @@ class CheckpointSource(str, Enum):
 class CheckpointRepository(BaseModel):
     source: CheckpointSource
     repo: str
+    revision: Optional[str] = None
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        if self.source == CheckpointSource.HF:
+            self._validate_hf_repo_id()
+
+    def _validate_hf_repo_id(self):
+        try:
+            validate_repo_id(self.repo)
+        except HFValidationError as e:
+            raise ValueError(
+                f"HuggingFace repository validation failed: {str(e)}"
+            ) from e
 
 
 class TrussTRTLLMBatchSchedulerPolicy(str, Enum):
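
Moving HF repo validation into `CheckpointRepository.__init__` means a malformed repo id now fails when the config is parsed, not when the engine build starts. A sketch (assuming `validate_repo_id`/`HFValidationError` come from `huggingface_hub`, as imported by this module):

    from truss.base.trt_llm_config import CheckpointRepository, CheckpointSource

    CheckpointRepository(source=CheckpointSource.HF, repo="meta/llama4-500B")  # ok
    # A repo id with two path separators is invalid -> ValueError at parse time.
    CheckpointRepository(source=CheckpointSource.HF, repo="bad//repo")
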
@@ -59,6 +76,16 @@ class TrussSpecDecMode(str, Enum):
     DRAFT_EXTERNAL: str = "DRAFT_TOKENS_EXTERNAL"
 
 
+class TrussTRTLLMRuntimeConfiguration(BaseModel):
+    kv_cache_free_gpu_mem_fraction: float = 0.9
+    enable_chunked_context: bool = False
+    batch_scheduler_policy: TrussTRTLLMBatchSchedulerPolicy = (
+        TrussTRTLLMBatchSchedulerPolicy.GUARANTEED_NO_EVICT
+    )
+    request_default_max_tokens: Optional[int] = None
+    total_token_limit: int = 500000
+
+
 class TrussTRTLLMBuildConfiguration(BaseModel):
     base_model: TrussTRTLLMModel
     max_seq_len: int
@@ -78,8 +105,7 @@ class TrussTRTLLMBuildConfiguration(BaseModel):
         TrussTRTLLMPluginConfiguration()
     )
     num_builder_gpus: Optional[int] = None
-    speculative_decoding_mode: Optional[TrussSpecDecMode] = None
-    max_draft_len: Optional[int] = None
+    speculator: Optional[TrussSpeculatorConfiguration] = None
 
     @validator("max_beam_width")
     def check_max_beam_width(cls, v: int):
@@ -90,100 +116,93 @@ def check_max_beam_width(cls, v: int):
         )
         return v
 
-
-class TrussTRTLLMRuntimeConfiguration(BaseModel):
-    kv_cache_free_gpu_mem_fraction: float = 0.9
-    enable_chunked_context: bool = False
-    batch_scheduler_policy: TrussTRTLLMBatchSchedulerPolicy = (
-        TrussTRTLLMBatchSchedulerPolicy.GUARANTEED_NO_EVICT
-    )
-    request_default_max_tokens: Optional[int] = None
-    # Speculative Decoding runtime configuration, ignored for non spec dec configurations
-    num_draft_tokens: Optional[int] = (
-        None  # number of draft tokens to be sampled from draft model in speculative decoding scheme
-    )
-
-
-class TRTLLMConfiguration(BaseModel):
-    runtime: TrussTRTLLMRuntimeConfiguration = TrussTRTLLMRuntimeConfiguration()
-    build: TrussTRTLLMBuildConfiguration
-
-    def __init__(self, **data):
-        super().__init__(**data)
-        self._validate_kv_cache_flags()
-        if self.build.checkpoint_repository.source == CheckpointSource.HF:
-            self._validate_hf_repo_id()
-
     def _validate_kv_cache_flags(self):
-        if self.build is None:
-            return self
-        if not self.build.plugin_configuration.paged_kv_cache and (
-            self.build.plugin_configuration.use_paged_context_fmha
-            or self.build.plugin_configuration.use_fp8_context_fmha
+        if not self.plugin_configuration.paged_kv_cache and (
+            self.plugin_configuration.use_paged_context_fmha
+            or self.plugin_configuration.use_fp8_context_fmha
         ):
             raise ValueError(
                 "Using paged context fmha or fp8 context fmha requires paged kv cache"
             )
         if (
-            self.build.plugin_configuration.use_fp8_context_fmha
-            and not self.build.plugin_configuration.use_paged_context_fmha
+            self.plugin_configuration.use_fp8_context_fmha
+            and not self.plugin_configuration.use_paged_context_fmha
         ):
             raise ValueError("Using fp8 context fmha requires paged context fmha")
         return self
 
-    def _validate_hf_repo_id(self):
-        try:
-            validate_repo_id(self.build.checkpoint_repository.repo)
-        except HFValidationError as e:
-            raise ValueError(
-                f"HuggingFace repository validation failed: {str(e)}"
-            ) from e
+    def _validate_speculator_config(self):
+        if self.speculator:
+            if self.base_model is TrussTRTLLMModel.WHISPER:
+                raise ValueError("Speculative decoding for Whisper is not supported.")
+            if not all(
+                [
+                    self.plugin_configuration.use_paged_context_fmha,
+                    self.plugin_configuration.paged_kv_cache,
+                ]
+            ):
+                raise ValueError(
+                    "KV cache block reuse must be enabled for speculative decoding target model."
+                )
+            if self.speculator.build:
+                if (
+                    self.tensor_parallel_count
+                    != self.speculator.build.tensor_parallel_count
+                ):
+                    raise ValueError(
+                        "Speculative decoding requires the same tensor parallelism for target and draft models."
+                    )
 
     @property
-    def requires_build(self):
-        if self.build is not None:
-            return True
-        return False
+    def max_draft_len(self) -> Optional[int]:
+        if self.speculator:
+            return self.speculator.num_draft_tokens
+        return None
 
-    # TODO(Abu): Replace this with model_dump(json=True)
-    # when pydantic v2 is used here
-    def to_json_dict(self, verbose=True):
-        return json.loads(self.json(exclude_unset=not verbose))
+    def __init__(self, **data):
+        super().__init__(**data)
+        self._validate_kv_cache_flags()
+        self._validate_speculator_config()
 
 
-class TRTLLMSpeculativeDecodingConfiguration(BaseModel):
-    target: TRTLLMConfiguration
-    draft: TRTLLMConfiguration
-    total_token_limit: int = 500000
+class TrussSpeculatorConfiguration(BaseModel):
+    speculative_decoding_mode: TrussSpecDecMode
+    num_draft_tokens: int
+    checkpoint_repository: Optional[CheckpointRepository] = None
+    runtime: TrussTRTLLMRuntimeConfiguration = TrussTRTLLMRuntimeConfiguration()
+    build: Optional[TrussTRTLLMBuildConfiguration] = None
 
     def __init__(self, **data):
         super().__init__(**data)
-        self._spec_dec_configs = [
-            self.target.build.speculative_decoding_mode,
-            self.target.build.max_draft_len,
-        ] + (
-            [self.draft.runtime.num_draft_tokens]
-            if self.draft.runtime and self.draft.runtime.num_draft_tokens
-            else [False]
-        )
-        self._validate_spec_dec()
-
-    def _validate_spec_dec(self):
-        if any(self._spec_dec_configs):
-            if not all(self._spec_dec_configs):
-                raise ValueError(
-                    "Speculative decoding requires all of `target.build.speculative_decoding_mode`, `target.build.max_draft_len`, and `draft.runtime.num_draft_tokens` to be configured."
-                )
-        for trt_llm_config in [self.target, self.draft]:
-            if trt_llm_config.build.base_model is TrussTRTLLMModel.WHISPER:
-                raise ValueError("Speculative decoding for Whisper is not supported.")
-        if (
-            self.target.build.tensor_parallel_count
-            != self.draft.build.tensor_parallel_count
-        ):
+        self._validate_checkpoint()
+
+    def _validate_checkpoint(self):
+        if not (bool(self.checkpoint_repository) ^ bool(self.build)):
             raise ValueError(
-                "Speculative decoding requires the same tensor parallelism for target and draft models."
+                "Speculative decoding requires exactly one of checkpoint_repository or build to be configured."
             )
 
+    @property
+    def resolved_checkpoint_repository(self) -> CheckpointRepository:
+        if self.build:
+            return self.build.checkpoint_repository
+        elif self.checkpoint_repository:
+            return self.checkpoint_repository
+        else:
+            raise ValueError(
+                "Speculative decoding requires exactly one of checkpoint_repository or build to be configured."
+            )
+
+
+class TRTLLMConfiguration(BaseModel):
+    runtime: TrussTRTLLMRuntimeConfiguration = TrussTRTLLMRuntimeConfiguration()
+    build: TrussTRTLLMBuildConfiguration
+
+    @property
+    def requires_build(self):
+        return self.build is not None
+
+    # TODO(Abu): Replace this with model_dump(json=True)
+    # when pydantic v2 is used here
     def to_json_dict(self, verbose=True):
         return json.loads(self.json(exclude_unset=not verbose))
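
With this refactor, speculative decoding is configured inside `build.speculator` instead of the old top-level `target`/`draft` split. A hypothetical config fragment mirroring the test fixtures further below (values illustrative):

    trt_llm = {
        "build": {
            "base_model": "llama",
            "max_seq_len": 2048,
            "checkpoint_repository": {"source": "HF", "repo": "meta/llama4-500B"},
            "plugin_configuration": {
                "paged_kv_cache": True,
                "use_paged_context_fmha": True,
            },
            "speculator": {
                "speculative_decoding_mode": "DRAFT_TOKENS_EXTERNAL",
                "num_draft_tokens": 4,
                # Exactly one of `checkpoint_repository` or `build` may be set
                # here (enforced by `_validate_checkpoint`).
                "checkpoint_repository": {"source": "HF", "repo": "meta/llama4-500B"},
            },
        },
    }
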
diff --git a/truss/base/truss_config.py b/truss/base/truss_config.py
index 8809599d3..be4900aec 100644
--- a/truss/base/truss_config.py
+++ b/truss/base/truss_config.py
@@ -3,19 +3,18 @@
 from dataclasses import _MISSING_TYPE, dataclass, field, fields
 from enum import Enum
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, TypeVar
 
 import yaml
 
 from truss.base.constants import (
     HTTP_PUBLIC_BLOB_BACKEND,
-    TRTLLM_SPEC_DEC_TARGET_MODEL_NAME,
 )
 from truss.base.custom_types import ModelFrameworkType
 from truss.base.errors import ValidationError
 from truss.base.trt_llm_config import (
     TRTLLMConfiguration,
-    TRTLLMSpeculativeDecodingConfiguration,
+    TrussTRTLLMBuildConfiguration,
     TrussTRTLLMQuantizationType,
 )
 from truss.base.validation import (
@@ -562,9 +561,7 @@ class TrussConfig:
     base_image: Optional[BaseImage] = None
     docker_server: Optional[DockerServer] = None
     model_cache: ModelCache = field(default_factory=ModelCache)
-    trt_llm: Optional[
-        Union[TRTLLMConfiguration, TRTLLMSpeculativeDecodingConfiguration]
-    ] = None
+    trt_llm: Optional[TRTLLMConfiguration] = None
     build_commands: List[str] = field(default_factory=list)
     use_local_chains_src: bool = False
 
@@ -578,11 +575,11 @@ def canonical_python_version(self) -> str:
         }[self.python_version]
 
     @property
-    def parsed_trt_llm_configs(self) -> List[TRTLLMConfiguration]:
+    def parsed_trt_llm_build_configs(self) -> List[TrussTRTLLMBuildConfiguration]:
         if self.trt_llm:
-            if isinstance(self.trt_llm, TRTLLMSpeculativeDecodingConfiguration):
-                return [self.trt_llm.target, self.trt_llm.draft]
-            return [self.trt_llm]
+            if self.trt_llm.build.speculator and self.trt_llm.build.speculator.build:
+                return [self.trt_llm.build, self.trt_llm.build.speculator.build]
+            return [self.trt_llm.build]
         return []
 
     @staticmethod
@@ -631,10 +628,7 @@ def from_dict(d):
                 ModelCache.from_list,
             ),
             trt_llm=transform_optional(
-                d.get("trt_llm"),
-                lambda x: (TRTLLMConfiguration(**x))
-                if TRTLLM_SPEC_DEC_TARGET_MODEL_NAME not in d.get("trt_llm")
-                else (TRTLLMSpeculativeDecodingConfiguration(**x)),
+                d.get("trt_llm"), lambda x: (TRTLLMConfiguration(**x))
             ),
             build_commands=d.get("build_commands", []),
             use_local_chains_src=d.get("use_local_chains_src", False),
@@ -688,16 +682,16 @@ def clone(self):
         return TrussConfig.from_dict(self.to_dict())
 
     def _validate_trt_llm_config(self) -> None:
-        for trt_llm_config in self.parsed_trt_llm_configs:
+        if self.trt_llm:
             if (
-                trt_llm_config.build.quantization_type
+                self.trt_llm.build.quantization_type
                 is TrussTRTLLMQuantizationType.WEIGHTS_ONLY_INT8
                 and self.resources.accelerator.accelerator is Accelerator.A100
             ):
                 raise ValueError(
                     "Weight only int8 quantization on A100 accelerators is not currently supported"
                 )
-            elif trt_llm_config.build.quantization_type in [
+            elif self.trt_llm.build.quantization_type in [
                 TrussTRTLLMQuantizationType.FP8,
                 TrussTRTLLMQuantizationType.FP8_KV,
             ] and self.resources.accelerator.accelerator not in [
@@ -708,7 +702,7 @@ def _validate_trt_llm_config(self) -> None:
                 raise ValueError(
                     "FP8 quantization is only supported on L4 and H100 accelerators"
                 )
-            tensor_parallel_count = trt_llm_config.build.tensor_parallel_count
+            tensor_parallel_count = self.trt_llm.build.tensor_parallel_count
 
             if tensor_parallel_count != self.resources.accelerator.count:
                 raise ValueError(
@@ -813,10 +807,6 @@ def obj_to_dict(obj, verbose: bool = False):
             d["trt_llm"] = transform_optional(
                 field_curr_value, lambda data: data.to_json_dict(verbose=verbose)
             )
-        elif isinstance(field_curr_value, TRTLLMSpeculativeDecodingConfiguration):
-            d["trt_llm"] = transform_optional(
-                field_curr_value, lambda data: data.to_json_dict(verbose=verbose)
-            )
         elif isinstance(field_curr_value, BaseImage):
             d["base_image"] = transform_optional(
                 field_curr_value, lambda data: data.to_dict()
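
The renamed `parsed_trt_llm_build_configs` returns build configs directly and includes the draft build only when the speculator carries its own `build` section. A behavior sketch using the fixture dicts defined in test_config.py further below (not standalone-runnable outside that test module):

    config = TrussConfig.from_dict(trtllm_spec_dec_config_full)
    assert len(config.parsed_trt_llm_build_configs) == 2  # target + draft builds

    config = TrussConfig.from_dict(trtllm_spec_dec_config)
    assert len(config.parsed_trt_llm_build_configs) == 1  # draft is checkpoint-only
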
diff --git a/truss/cli/cli.py b/truss/cli/cli.py
index fd5755cde..2687302a4 100644
--- a/truss/cli/cli.py
+++ b/truss/cli/cli.py
@@ -1158,11 +1158,11 @@ def push(
             console.print(
                 f"Automatically increasing memory for trt-llm builder to {TRTLLM_MIN_MEMORY_REQUEST_GI}Gi."
             )
-        for trt_llm_config in tr.spec.config.parsed_trt_llm_configs:
+        for trt_llm_build_config in tr.spec.config.parsed_trt_llm_build_configs:
             if (
-                trt_llm_config.build.quantization_type
+                trt_llm_build_config.quantization_type
                 in [TrussTRTLLMQuantizationType.FP8, TrussTRTLLMQuantizationType.FP8_KV]
-                and not trt_llm_config.build.num_builder_gpus
+                and not trt_llm_build_config.num_builder_gpus
             ):
                 fp8_and_num_builder_gpus_text = (
                     "Warning: build specifies FP8 quantization but does not explicitly specify number of build GPUs. "
diff --git a/truss/remote/baseten/api.py b/truss/remote/baseten/api.py
index af9814f16..fdc96a485 100644
--- a/truss/remote/baseten/api.py
+++ b/truss/remote/baseten/api.py
@@ -432,7 +432,7 @@ def patch_draft_truss_two_step(self, model_name, patch_request):
         query_string = f"""
         mutation {{
             stage_patch_for_draft_truss(name: "{model_name}",
-                client_version: "TRUSS",
+                client_version: "{truss.version()}",
                 patch: "{patch}",
             ) {{
                 id,
@@ -457,7 +457,7 @@ def sync_draft_truss(self, model_name):
         query_string = f"""
         mutation {{
             sync_draft_truss(name: "{model_name}",
-                client_version: "TRUSS",
+                client_version: "{truss.version()}",
             ) {{
                 id,
                 name,
@@ -474,29 +474,6 @@ def sync_draft_truss(self, model_name):
         logging.debug(f"Failed to sync patch: {result}")
         return result
 
-    def patch_draft_truss(self, model_name, patch_request):
-        patch = base64_encoded_json_str(patch_request.to_dict())
-        query_string = f"""
-        mutation {{
-            patch_draft_truss(name: "{model_name}",
-                client_version: "TRUSS",
-                patch: "{patch}",
-            ) {{
-                id,
-                name,
-                version_id
-                succeeded
-                needs_full_deploy
-                error
-            }}
-        }}
-        """
-        resp = self._post_graphql_query(query_string)
-        result = resp["data"]["patch_draft_truss"]
-        if not result["succeeded"]:
-            logging.debug(f"Unsuccessful response: {result}")
-        return result
-
     def get_deployment(self, model_id: str, deployment_id: str) -> Any:
         headers = self._auth_token.header()
         resp = requests.get(
diff --git a/truss/remote/baseten/remote.py b/truss/remote/baseten/remote.py
index afeca69ca..e32edb2ca 100644
--- a/truss/remote/baseten/remote.py
+++ b/truss/remote/baseten/remote.py
@@ -519,22 +519,20 @@ def _patch(
             return PatchResult(
                 PatchStatus.SKIPPED, "No changes observed, skipping patching."
             )
+
+        def do_patch():
+            if should_create_patch:
+                resp = self._api.patch_draft_truss_two_step(model_name, patch_request)
+            else:
+                resp = self._api.sync_draft_truss(model_name)
+            return resp
+
         try:
             if console:
                 with console.status("Applying patch..."):
-                    if should_create_patch:
-                        resp = self._api.patch_draft_truss_two_step(
-                            model_name, patch_request
-                        )
-                    else:
-                        resp = self._api.sync_draft_truss(model_name)
+                    resp = do_patch()
             else:
-                if should_create_patch:
-                    resp = self._api.patch_draft_truss_two_step(
-                        model_name, patch_request
-                    )
-                else:
-                    resp = self._api.sync_draft_truss(model_name)
+                resp = do_patch()
         except ReadTimeout:
             return PatchResult(
diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py
index 6feea1eca..cff58f2ca 100644
--- a/truss/templates/server/model_wrapper.py
+++ b/truss/templates/server/model_wrapper.py
@@ -534,9 +534,7 @@ async def preprocess(
         request: starlette.requests.Request,
     ) -> Any:
         descriptor = self.model_descriptor.preprocess
-        if not descriptor:
-            return inputs
-
+        assert descriptor, "`preprocess` must only be called if model has it."
         args = ArgConfig.prepare_args(descriptor, inputs, request)
         with errors.intercept_exceptions(self._logger, self._model_file_name):
             if descriptor.is_async:
@@ -573,9 +571,7 @@ async def postprocess(
         # and postprocess is skipped.
         # The result type can be the same as for predict.
         descriptor = self.model_descriptor.postprocess
-        if not descriptor:
-            return result
-
+        assert descriptor, "`postprocess` must only be called if model has it."
         args = ArgConfig.prepare_args(descriptor, result, request)
         with errors.intercept_exceptions(self._logger, self._model_file_name):
             if descriptor.is_async:
@@ -658,11 +654,14 @@ async def __call__(
         """
         Returns result from: preprocess -> predictor -> postprocess.
         """
-        with self._tracer.start_as_current_span("call-pre") as span_pre:
-            with tracing.section_as_event(
-                span_pre, "preprocess"
-            ), tracing.detach_context():
-                preprocess_result = await self.preprocess(inputs, request)
+        if self.model_descriptor.preprocess:
+            with self._tracer.start_as_current_span("call-pre") as span_pre:
+                with tracing.section_as_event(
+                    span_pre, "preprocess"
+                ), tracing.detach_context():
+                    preprocess_result = await self.preprocess(inputs, request)
+        else:
+            preprocess_result = inputs
 
         span_predict = self._tracer.start_span("call-predict")
         async with deferred_semaphore_and_span(
@@ -730,12 +729,15 @@ async def __call__(
 
             return predict_result
 
-        with self._tracer.start_as_current_span("call-post") as span_post:
-            with tracing.section_as_event(
-                span_post, "postprocess"
-            ), tracing.detach_context():
-                postprocess_result = await self.postprocess(predict_result, request)
-            return postprocess_result
+        if self.model_descriptor.postprocess:
+            with self._tracer.start_as_current_span("call-post") as span_post:
+                with tracing.section_as_event(
+                    span_post, "postprocess"
+                ), tracing.detach_context():
+                    postprocess_result = await self.postprocess(predict_result, request)
+                return postprocess_result
+        else:
+            return predict_result
 
 
 async def _gather_generator(
diff --git a/truss/templates/trtllm-briton/src/extension.py b/truss/templates/trtllm-briton/src/extension.py
index 6b9cf51d0..164b7e905 100644
--- a/truss/templates/trtllm-briton/src/extension.py
+++ b/truss/templates/trtllm-briton/src/extension.py
@@ -1,10 +1,8 @@
 from briton.spec_dec_truss_model import Model as SpecDecModel
 from briton.trtllm_config import (
     TRTLLMConfiguration,
-    TRTLLMSpeculativeDecodingConfiguration,
 )
 from briton.truss_model import Model
-from pydantic import ValidationError
 
 # TODO(pankaj) Define an ABC base class for this. That baseclass should live in
 # a new, smaller truss sub-library, perhaps called `truss-runtime` for inclusion
@@ -41,12 +39,11 @@ class Extension:
     def __init__(self, *args, **kwargs):
         self._config = kwargs["config"]
         trt_llm_config = self._config.get("trt_llm")
-        try:
-            TRTLLMConfiguration(**trt_llm_config)
-            self._model = Model(*args, **kwargs)
-        except ValidationError as _:
-            TRTLLMSpeculativeDecodingConfiguration(**trt_llm_config)
+        config = TRTLLMConfiguration(**trt_llm_config)
+        if config.build.speculator:
             self._model = SpecDecModel(*args, **kwargs)
+        else:
+            self._model = Model(*args, **kwargs)
 
     def model_override(self):
         """Return a model object.
diff --git a/truss/tests/test_config.py b/truss/tests/test_config.py
index c1401e905..5b5c527a7 100644
--- a/truss/tests/test_config.py
+++ b/truss/tests/test_config.py
@@ -9,7 +9,6 @@
 from truss.base.custom_types import ModelFrameworkType
 from truss.base.trt_llm_config import (
-    TRTLLMSpeculativeDecodingConfiguration,
     TrussSpecDecMode,
     TrussTRTLLMQuantizationType,
 )
@@ -77,34 +76,65 @@ def trtllm_config(default_config) -> Dict[str, Any]:
 
 
 @pytest.fixture
-def trtllm_spec_dec_config(trtllm_config) -> Dict[str, Any]:
+def trtllm_spec_dec_config_full(trtllm_config) -> Dict[str, Any]:
     spec_dec_config = copy.deepcopy(trtllm_config)
     spec_dec_config["trt_llm"] = {
-        "target": {
-            "build": {
-                "base_model": "llama",
-                "max_seq_len": 2048,
-                "max_batch_size": 512,
-                "checkpoint_repository": {
-                    "source": "HF",
-                    "repo": "meta/llama4-500B",
-                },
-                "gather_all_token_logits": False,
+        "build": {
+            "base_model": "llama",
+            "max_seq_len": 2048,
+            "max_batch_size": 512,
+            "checkpoint_repository": {
+                "source": "HF",
+                "repo": "meta/llama4-500B",
+            },
+            "plugin_configuration": {
+                "paged_kv_cache": True,
+                "gemm_plugin": "auto",
+                "use_paged_context_fmha": True,
+            },
+            "speculator": {
                 "speculative_decoding_mode": TrussSpecDecMode.DRAFT_EXTERNAL,
-                "max_draft_len": 10,
+                "num_draft_tokens": 4,
+                "build": {
+                    "base_model": "llama",
+                    "max_seq_len": 2048,
+                    "max_batch_size": 512,
+                    "checkpoint_repository": {
+                        "source": "HF",
+                        "repo": "meta/llama4-500B",
+                    },
+                },
             },
         },
-        "draft": {
-            "build": {
-                "base_model": "llama",
-                "max_seq_len": 2048,
-                "max_batch_size": 512,
+    }
+    return spec_dec_config
+
+
+@pytest.fixture
+def trtllm_spec_dec_config(trtllm_config) -> Dict[str, Any]:
+    spec_dec_config = copy.deepcopy(trtllm_config)
+    spec_dec_config["trt_llm"] = {
+        "build": {
+            "base_model": "llama",
+            "max_seq_len": 2048,
+            "max_batch_size": 512,
+            "checkpoint_repository": {
+                "source": "HF",
+                "repo": "meta/llama4-500B",
+            },
+            "plugin_configuration": {
+                "paged_kv_cache": True,
+                "gemm_plugin": "auto",
+                "use_paged_context_fmha": True,
+            },
+            "speculator": {
+                "speculative_decoding_mode": TrussSpecDecMode.DRAFT_EXTERNAL,
+                "num_draft_tokens": 4,
                 "checkpoint_repository": {
                     "source": "HF",
                     "repo": "meta/llama4-500B",
                 },
             },
-            "runtime": {"num_draft_tokens": 4},
         },
     }
     return spec_dec_config
@@ -549,7 +579,13 @@ def test_plugin_paged_fp8_context_fmha_check(trtllm_config):
 
 
 @pytest.mark.parametrize("verbose, expect_equal", [(False, True), (True, False)])
-def test_to_dict_trtllm(verbose, expect_equal, trtllm_config, trtllm_spec_dec_config):
+def test_to_dict_trtllm(
+    verbose,
+    expect_equal,
+    trtllm_config,
+    trtllm_spec_dec_config,
+    trtllm_spec_dec_config_full,
+):
     assert (
         TrussConfig.from_dict(trtllm_config).to_dict(verbose=verbose) == trtllm_config
     ) == expect_equal
@@ -557,39 +593,49 @@ def test_to_dict_trtllm(verbose, expect_equal, trtllm_config, trtllm_spec_dec_co
         TrussConfig.from_dict(trtllm_spec_dec_config).to_dict(verbose=verbose)
         == trtllm_spec_dec_config
     ) == expect_equal
+    assert (
+        TrussConfig.from_dict(trtllm_spec_dec_config_full).to_dict(verbose=verbose)
+        == trtllm_spec_dec_config_full
+    ) == expect_equal
 
 
 @pytest.mark.parametrize("should_raise", [False, True])
 def test_from_dict_spec_dec_trt_llm(should_raise, trtllm_spec_dec_config):
     test_config = copy.deepcopy(trtllm_spec_dec_config)
     if should_raise:
-        test_config["trt_llm"]["target"]["build"]["speculative_decoding_mode"] = None
+        test_config["trt_llm"]["build"]["speculator"]["speculative_decoding_mode"] = (
+            None
+        )
+        with pytest.raises(ValueError):
+            TrussConfig.from_dict(test_config)
+        test_config["trt_llm"]["build"]["speculator"]["checkpoint_repository"] = None
+        with pytest.raises(ValueError):
+            TrussConfig.from_dict(test_config)
+        test_config["trt_llm"]["build"]["speculator"]["checkpoint_repository"] = (
+            trtllm_spec_dec_config[
+                "trt_llm"
+            ]["build"]["speculator"]["checkpoint_repository"]
+        )
+        test_config["trt_llm"]["build"]["plugin_configuration"][
+            "use_paged_context_fmha"
+        ] = False
         with pytest.raises(ValueError):
             TrussConfig.from_dict(test_config)
-        test_config["trt_llm"]["target"]["build"]["speculative_decoding_mode"] = (
+        test_config["trt_llm"]["build"]["plugin_configuration"][
+            "use_paged_context_fmha"
+        ] = True
+        test_config["trt_llm"]["build"]["speculator"]["speculative_decoding_mode"] = (
             trtllm_spec_dec_config[
                 "trt_llm"
-            ]["target"]["build"]["speculative_decoding_mode"]
+            ]["build"]["speculator"]["speculative_decoding_mode"]
         )
-        test_config["trt_llm"]["draft"]["runtime"]["num_draft_tokens"] = None
+        test_config["trt_llm"]["build"]["speculator"]["num_draft_tokens"] = None
         with pytest.raises(ValueError):
             TrussConfig.from_dict(test_config)
     else:
         TrussConfig.from_dict(trtllm_spec_dec_config)
 
 
-@pytest.mark.parametrize("spec_dec_enabled", [False, True])
-def test_trtllm_spec_dec(spec_dec_enabled, trtllm_config, trtllm_spec_dec_config):
-    config = trtllm_config
-    if spec_dec_enabled:
-        config = trtllm_spec_dec_config
-    truss_config = TrussConfig.from_dict(config)
-    assert (
-        isinstance(truss_config.trt_llm, TRTLLMSpeculativeDecodingConfiguration)
-        == spec_dec_enabled
-    )
-
-
 def test_from_yaml_invalid_requirements_configuration():
     invalid_requirements = {
         "requirements_file": "requirements.txt",
diff --git a/truss/trt_llm/config_checks.py b/truss/trt_llm/config_checks.py
index 2562af6e1..5fe80761d 100644
--- a/truss/trt_llm/config_checks.py
+++ b/truss/trt_llm/config_checks.py
@@ -9,9 +9,9 @@
 def is_missing_secrets_for_trt_llm_builder(tr: TrussHandle) -> bool:
-    for trt_llm_config in tr.spec.config.parsed_trt_llm_configs:
-        source = trt_llm_config.build.checkpoint_repository.source
-        hf_model_id = trt_llm_config.build.checkpoint_repository.repo
+    for trt_llm_build_config in tr.spec.config.parsed_trt_llm_build_configs:
+        source = trt_llm_build_config.checkpoint_repository.source
+        hf_model_id = trt_llm_build_config.checkpoint_repository.repo
         if (
             source == CheckpointSource.HF
             and HF_ACCESS_TOKEN_KEY not in tr.spec.secrets