-
Notifications
You must be signed in to change notification settings - Fork 76
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Set default predict_concurrency when using trt-llm to 512 (#954) * Set default predict_concurrency when using trt-llm to 512 * update tests * Truss changes to support lazy data that reads bptr secret and fetches from remote (#963) * lazy data resolution support * add support for lazy data resolver in truss * remove lazy loader reference from template * fetch in model wrapper * duplicate download util for shared template * concurrent download * fix path reference * use updated expiration_timestamp type --------- Co-authored-by: Pankaj Gupta <[email protected]> * Update push docs. (#965) * Adding initial code to implement build commands (#961) * Adding initial code to implement build commands * Adding some tests * Adding docker integration tests * making build command an empty list by default * removing unnecessary build_commands list for loop thing * correct secrets str in docs (#968) * Fix lazy data resolver error handling (#967) * [chains] Add external_package_dirs option. Usage in Whiper model chainlet. (#966) * add truss chains init (#973) * [BT-10657] Wire up truss chains deploy (#969) * Wire up the new chains mutations to truss chains deploy. * Add comment. * Respond to PR feedback. * * Prune docker build cache in integration tests. (#976) * Show requirement file content before pip install. * For all tests running docker containers, show container logs if an exception was raised. * Update control requirements to truss 0.9.14 (required also incrementing httpx version). * Bump version to 0.9.15 --------- Co-authored-by: Bryce Dubayah <[email protected]> Co-authored-by: joostinyi <[email protected]> Co-authored-by: Pankaj Gupta <[email protected]> Co-authored-by: Sidharth Shanker <[email protected]> Co-authored-by: Het Trivedi <[email protected]> Co-authored-by: rcano-baseten <[email protected]> Co-authored-by: Marius Killinger <[email protected]>
- Loading branch information
1 parent
7e17bf7
commit 713f74d
Showing
39 changed files
with
867 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
from typing import Optional | ||
|
||
# flake8: noqa F402 | ||
# This location assumes `fde`-repo is checked out at the same level as `truss`-repo. | ||
_LOCAL_WHISPER_LIB = "../../../../fde/whisper-trt/src" | ||
import sys | ||
|
||
sys.path.append(_LOCAL_WHISPER_LIB) | ||
|
||
import base64 | ||
|
||
import pydantic | ||
import truss_chains as chains | ||
from huggingface_hub import snapshot_download | ||
|
||
|
||
# TODO: The I/O types below should actually be taken from `whisper_trt.types`. | ||
# But that cannot be imported without having `tensorrt_llm` installed. | ||
# It could be fixed, by making that module importable without any special requirements. | ||
class Segment(pydantic.BaseModel): | ||
start_time_sec: float | ||
end_time_sec: float | ||
text: str | ||
start: float # TODO: deprecate, use field with unit (seconds). | ||
end: float # TODO: deprecate, use field with unit (seconds). | ||
|
||
|
||
class WhisperResult(pydantic.BaseModel): | ||
segments: list[Segment] | ||
language: Optional[str] | ||
language_code: Optional[str] = pydantic.Field( | ||
..., | ||
description="IETF language tag, e.g. 'en', see. " | ||
"https://en.wikipedia.org/wiki/IETF_language_tag.", | ||
) | ||
|
||
|
||
class WhisperInput(pydantic.BaseModel): | ||
audio_b64: str | ||
|
||
|
||
@chains.mark_entrypoint | ||
class WhisperModel(chains.ChainletBase): | ||
|
||
remote_config = chains.RemoteConfig( | ||
docker_image=chains.DockerImage( | ||
base_image="baseten/truss-server-base:3.10-gpu-v0.9.0", | ||
apt_requirements=["python3.10-venv", "openmpi-bin", "libopenmpi-dev"], | ||
pip_requirements=[ | ||
"--extra-index-url https://pypi.nvidia.com", | ||
"tensorrt_llm==0.10.0.dev2024042300", | ||
"hf_transfer", | ||
"janus", | ||
"kaldialign", | ||
"librosa", | ||
"mpi4py==3.1.4", | ||
"safetensors", | ||
"soundfile", | ||
"tiktoken", | ||
"torchaudio", | ||
"async-batcher>=0.2.0", | ||
"pydantic>=2.7.1", | ||
], | ||
external_package_dirs=[chains.make_abs_path_here(_LOCAL_WHISPER_LIB)], | ||
), | ||
compute=chains.Compute(gpu="A10G", predict_concurrency=128), | ||
assets=chains.Assets(secret_keys=["hf_access_token"]), | ||
) | ||
|
||
def __init__( | ||
self, | ||
context: chains.DeploymentContext = chains.depends_context(), | ||
) -> None: | ||
snapshot_download( | ||
repo_id="baseten/whisper_trt_large-v3_A10G_i224_o512_bs8_bw5", | ||
local_dir=context.data_dir, | ||
allow_patterns=["**"], | ||
token=context.secrets["hf_access_token"], | ||
) | ||
from whisper_trt import WhisperModel | ||
|
||
self._model = WhisperModel(str(context.data_dir), max_queue_time=0.050) | ||
|
||
async def run_remote(self, request: WhisperInput) -> WhisperResult: | ||
binary_data = base64.b64decode(request.audio_b64.encode("utf-8")) | ||
waveform = self._model.preprocess_audio(binary_data) | ||
return await self._model.transcribe( | ||
waveform, timestamps=True, raise_when_trimmed=True | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,35 @@ | ||
import random | ||
|
||
# For more on chains, check out https://truss.baseten.co/chains/intro. | ||
import truss_chains as chains | ||
|
||
|
||
class DummyGenerateData(chains.ChainletBase): | ||
def run_remote(self) -> str: | ||
return "abc" | ||
# By inhereting chains.ChainletBase, the chains framework will know to create a chainlet that hosts the RandInt class. | ||
class RandInt(chains.ChainletBase): | ||
|
||
# run_remote must be implemented by all chainlets. This is the code that will be executed at inference time. | ||
def run_remote(self, max_value: int) -> int: | ||
return random.randint(1, max_value) | ||
|
||
|
||
# The @chains.mark_entrypoint decorator indicates that this Chainlet is the entrypoint. | ||
# Each chain must have exactly one entrypoint. | ||
@chains.mark_entrypoint | ||
class HelloWorld(chains.ChainletBase): | ||
# chains.depends indicates that the HelloWorld chainlet depends on the RandInt Chainlet | ||
# this enables the HelloWorld chainlet to call the RandInt chainlet | ||
def __init__(self, rand_int=chains.depends(RandInt, retries=3)) -> None: | ||
self._rand_int = rand_int | ||
|
||
# Nesting the classes is a hack to make it *appear* like SplitText is from a different | ||
# module. | ||
class shared_chainlet: | ||
class DummySplitText(chains.ChainletBase): | ||
def run_remote(self, data: str) -> list[str]: | ||
return [data[:2], data[2:]] | ||
def run_remote(self, max_value: int) -> str: | ||
num_repetitions = self._rand_int.run_remote(max_value) | ||
return "Hello World! " * num_repetitions | ||
|
||
|
||
class DummyExample(chains.ChainletBase): | ||
def __init__( | ||
self, | ||
data_generator: DummyGenerateData = chains.depends(DummyGenerateData), | ||
splitter: shared_chainlet.DummySplitText = chains.depends( | ||
shared_chainlet.DummySplitText | ||
), | ||
context: chains.DeploymentContext = chains.depends_context(), | ||
) -> None: | ||
self._data_generator = data_generator | ||
self._data_splitter = splitter | ||
self._context = context | ||
if __name__ == "__main__": | ||
with chains.run_local(): | ||
hello_world_chain = HelloWorld() | ||
result = hello_world_chain.run_remote(max_value=5) | ||
|
||
def run_remote(self) -> list[str]: | ||
return self._data_splitter.run_remote(self._data_generator.run_remote()) | ||
print(result) | ||
# Hello World! Hello World! Hello World! |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.