Skip to content

Commit

Permalink
Merge pull request #85 from coreweave/es/torch-updates
Browse files Browse the repository at this point in the history
feat(torch): Update `torch` libraries to v2.5.0, bundle `triton`, patch TransformerEngine
  • Loading branch information
wbrown authored Oct 23, 2024
2 parents 9d8ad52 + 3a941b9 commit f575d1b
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 19 deletions.
6 changes: 3 additions & 3 deletions .github/configurations/torch-base.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cuda: [ 12.6.1, 12.4.1, 12.2.2 ]
os: [ ubuntu22.04, ubuntu20.04 ]
include:
- torch: 2.4.1
vision: 0.19.1
audio: 2.4.1
- torch: 2.5.0
vision: 0.20.0
audio: 2.5.0
6 changes: 3 additions & 3 deletions .github/configurations/torch-nccl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ image:
nccl: 2.21.5-1
nccl-tests-hash: 2ff05b2
include:
- torch: 2.4.1
vision: 0.19.1
audio: 2.4.1
- torch: 2.5.0
vision: 0.20.0
audio: 2.5.0
47 changes: 34 additions & 13 deletions torch/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
ARG BUILDER_BASE_IMAGE="nvidia/cuda:12.4.1-devel-ubuntu22.04"
ARG FINAL_BASE_IMAGE="nvidia/cuda:12.4.1-base-ubuntu22.04"

ARG BUILD_TORCH_VERSION="2.4.1"
ARG BUILD_TORCH_VISION_VERSION="0.19.1"
ARG BUILD_TORCH_AUDIO_VERSION="2.4.1"
ARG BUILD_TRANSFORMERENGINE_VERSION="458c7de038ed34bdaf471ced4e3162a28055def7"
ARG BUILD_TORCH_VERSION="2.5.0"
ARG BUILD_TORCH_VISION_VERSION="0.20.0"
ARG BUILD_TORCH_AUDIO_VERSION="2.5.0"
ARG BUILD_TRANSFORMERENGINE_VERSION="1.11"
ARG BUILD_FLASH_ATTN_VERSION="2.6.3"
ARG BUILD_TRITON_VERSION=""
ARG BUILD_TRITON="1"
ARG BUILD_TORCH_CUDA_ARCH_LIST="6.0 6.1 6.2 7.0 7.2 7.5 8.0 8.6 8.9 9.0+PTX"

# 8.7 is supported in the PyTorch main branch, but not 2.0.0
Expand All @@ -32,7 +33,7 @@ COPY <<-"EOT" /git/clone.sh

# Try cloning REF as a tag prefixed with "v", otherwise fall back
# to git checkout for commit hashes
CLONE --recurse-submodules --shallow-submodules --also-filter-submodules \
CLONE --recurse-submodules --shallow-submodules --also-filter-submodules --no-tags \
"$REPO" -b "v$REF" "$DEST" || { \
CLONE --no-single-branch --no-checkout "$REPO" "$DEST" && \
git -C "$DEST" checkout "$REF" && \
Expand All @@ -59,27 +60,47 @@ RUN ./clone.sh pytorch/audio audio "${BUILD_TORCH_AUDIO_VERSION}"
# The torchaudio build requires that this directory remain a full git repository,
# so no rm -rf audio/.git is done for this one.

# torchaudio is broken for CUDA 12.5+ without this patch (as of v2.4.1)
# torchaudio is broken for CUDA 12.5+ without this patch (up to and including v2.5.0)
# See https://github.com/pytorch/audio/pull/3811
# Fixed as a side effect of https://github.com/pytorch/audio/pull/3843 in versions after v2.5.0
COPY torchaudio-cu125-pr3811.patch /git/patch
RUN git -C audio apply --index /git/patch && rm /git/patch
RUN if grep -qF '#include <float.h>' \
'audio/src/libtorchaudio/cuctc/src/ctc_prefix_decoder_kernel_v2.cu'; \
then :; else git -C audio apply -v --stat --apply /git/patch; \
fi && \
rm /git/patch

FROM downloader-base as transformerengine-downloader
ARG BUILD_TRANSFORMERENGINE_VERSION
RUN ./clone.sh NVIDIA/TransformerEngine TransformerEngine "${BUILD_TRANSFORMERENGINE_VERSION}"

# Include a patch commit that is sort-of part of v1.11 but isn't in their v1.11 release git tag
# See https://github.com/NVIDIA/TransformerEngine/pull/1222
RUN if [ "${BUILD_TRANSFORMERENGINE_VERSION}" = '1.11' ]; then \
wget 'https://github.com/NVIDIA/TransformerEngine/commit/fc034785f5e3a5bc5600a88766d9a1d75137ce77.patch' -qO- \
| git -C TransformerEngine apply -v --stat --apply -; \
fi

FROM downloader-base as flash-attn-downloader
WORKDIR /git
ARG BUILD_FLASH_ATTN_VERSION
RUN ./clone.sh Dao-AILab/flash-attention flash-attention "${BUILD_FLASH_ATTN_VERSION}"

FROM downloader-base as triton-downloader
FROM downloader-base as triton-version
ENV TRITON_COMMIT_FILE='.ci/docker/ci_commit_pins/triton.txt'
COPY --link --from=pytorch-downloader "/git/pytorch/${TRITON_COMMIT_FILE}" /git/version.txt
ARG BUILD_TRITON_VERSION
RUN if [ -n "${BUILD_TRITON_VERSION}" ]; then \
./clone.sh openai/triton triton "${BUILD_TRITON_VERSION}"; \
echo "${BUILD_TRITON_VERSION}" > /git/version.txt; \
fi

FROM downloader-base as triton-downloader
COPY --link --from=triton-version /git/version.txt /git/version.txt
ARG BUILD_TRITON
RUN if [ "${BUILD_TRITON}" = '1' ]; then \
./clone.sh openai/triton triton "$(cat /git/version.txt)"; \
else \
mkdir triton; \
fi;
fi

FROM alpine/curl:8.7.1 as aocl-downloader
WORKDIR /tmp/install
Expand Down Expand Up @@ -258,10 +279,10 @@ ENV CMAKE_PREFIX_PATH=/usr/bin/ \
CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda/ \
CUDNN_LIB_DIR=/usr/local/cuda/lib64

ARG BUILD_TRITON_VERSION
ARG BUILD_TRITON
RUN --mount=type=bind,from=triton-downloader,source=/git/triton,target=triton/,rw \
--mount=type=cache,target=/ccache \
if [ -n "$BUILD_TRITON_VERSION" ]; then \
if [ "$BUILD_TRITON" = '1' ]; then \
export MAX_JOBS="$(./scale.sh "$(./effective_cpu_count.sh)" 3 32)" && \
cd triton/python && \
python3 -m pip wheel -w ../../dist/ --no-build-isolation --no-deps -vv . && \
Expand Down

0 comments on commit f575d1b

Please sign in to comment.