diff --git a/src/streamdiffusion/tools/install-tensorrt.py b/src/streamdiffusion/tools/install-tensorrt.py index 182871c4..b2a32c3d 100644 --- a/src/streamdiffusion/tools/install-tensorrt.py +++ b/src/streamdiffusion/tools/install-tensorrt.py @@ -31,7 +31,7 @@ def install(cu: Optional[Literal["11", "12"]] = get_cuda_version_from_torch()): if not is_installed("tensorrt"): run_pip(f"install {cudnn_name} --no-cache-dir") run_pip( - "install --pre --extra-index-url https://pypi.nvidia.com tensorrt==9.0.1.post11.dev4 --no-cache-dir" + f"install --pre --extra-index-url https://pypi.nvidia.com tensorrt==9.0.1.post{cu}.dev4 --no-cache-dir" ) if not is_installed("polygraphy"):