
Commit

add bei specific user migrations
michaelfeil committed Feb 3, 2025
1 parent 53db85f commit fcf4ca2
Showing 1 changed file with 31 additions and 4 deletions.
35 changes: 31 additions & 4 deletions truss/base/trt_llm_config.py
@@ -11,6 +11,7 @@
from huggingface_hub.utils import validate_repo_id
from pydantic import BaseModel, PydanticDeprecatedSince20, model_validator, validator

from truss.base.constants import BEI_REQUIRED_MAX_NUM_TOKENS
logger = logging.getLogger(__name__)
# Suppress Pydantic V1 warnings, because we have to use it for backwards compat.
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
@@ -118,6 +119,12 @@ class TrussTRTLLMBuildConfiguration(BaseModel):

    class Config:
        extra = "forbid"

    def __init__(self, **data):
        super().__init__(**data)
        self._validate_kv_cache_flags()
        self._validate_speculator_config()
        self._bei_specfic_migration()
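Note: the __init__ override above runs the extra checks immediately after Pydantic has parsed the fields. Below is a minimal, self-contained sketch of the same post-init hook expressed with Pydantic v2's @model_validator(mode="after"), which is already imported at the top of this file; the field names and the 16384 threshold are illustrative assumptions, not the real TrussTRTLLMBuildConfiguration or the real BEI_REQUIRED_MAX_NUM_TOKENS value.

# --- illustrative sketch, not part of the diff ---
from pydantic import BaseModel, model_validator


class ExampleBuildConfig(BaseModel):
    # Illustrative fields only; the real config has many more.
    is_encoder: bool = False
    max_num_tokens: int = 8192

    @model_validator(mode="after")
    def _post_init_migration(self) -> "ExampleBuildConfig":
        # Runs once all fields are set, like the __init__ override in this diff.
        if self.is_encoder and self.max_num_tokens < 16384:  # assumed threshold
            self.max_num_tokens = 16384
        return self


cfg = ExampleBuildConfig(is_encoder=True, max_num_tokens=4096)
assert cfg.max_num_tokens == 16384
# --- end sketch ---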

    @validator("max_beam_width")
    def check_max_beam_width(cls, v: int):
@@ -127,6 +134,29 @@ def check_max_beam_width(cls, v: int):
                "max_beam_width greater than 1 is not currently supported"
            )
        return v

    def _bei_specfic_migration(self):
        """Performs embedding-specific optimizations (no kv-cache, high batch size)."""
        if self.base_model == TrussTRTLLMModel.ENCODER:
            # Encoder-specific settings
            logger.info(
                f"Your setting of `build.max_seq_len={self.max_seq_len}` is not used and is "
                "automatically inferred from the model repo config.json -> `max_position_embeddings`"
            )

            if self.max_num_tokens < BEI_REQUIRED_MAX_NUM_TOKENS:
                logger.warning(
                    f"build.max_num_tokens={self.max_num_tokens}, upgrading to {BEI_REQUIRED_MAX_NUM_TOKENS}"
                )
                self.max_num_tokens = BEI_REQUIRED_MAX_NUM_TOKENS
            self.plugin_configuration.paged_kv_cache = False
            self.plugin_configuration.use_paged_context_fmha = False

if "_kv" in self.quantization_type.value:
raise ValueError(
"encoder does not have a kv-cache, therefore a kv specfic datatype is not valid"
f"you selected build.quantization_type {self.quantization_type}"
)
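The kv-cache guard above keys off the `_kv` suffix in the quantization enum's value, since an encoder build has no kv-cache to quantize. Below is a standalone sketch of that check, using a hypothetical ExampleQuantType enum; the real quantization_type enum lives elsewhere in truss and is not part of this diff.

# --- illustrative sketch, not part of the diff ---
from enum import Enum


class ExampleQuantType(Enum):
    # Hypothetical values for illustration only.
    NO_QUANT = "no_quant"
    FP8 = "fp8"
    FP8_KV = "fp8_kv"


def check_encoder_quantization(quant: ExampleQuantType) -> None:
    # Encoders have no kv-cache, so kv-specific dtypes are rejected.
    if "_kv" in quant.value:
        raise ValueError(
            "encoder does not have a kv-cache, therefore a kv-specific datatype is not valid; "
            f"you selected {quant}"
        )


check_encoder_quantization(ExampleQuantType.FP8)  # passes silently
# check_encoder_quantization(ExampleQuantType.FP8_KV)  # would raise ValueError
# --- end sketch ---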

    def _validate_kv_cache_flags(self):
        if not self.plugin_configuration.paged_kv_cache and (
@@ -176,10 +206,7 @@ def max_draft_len(self) -> Optional[int]:
            return self.speculator.num_draft_tokens
        return None

    def __init__(self, **data):
        super().__init__(**data)
        self._validate_kv_cache_flags()
        self._validate_speculator_config()



class TrussSpeculatorConfiguration(BaseModel):
