forked from dusty-nv/jetson-containers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuild.sh
executable file
·40 lines (31 loc) · 1.7 KB
/
build.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/env bash
# JAX builder for Jetson (architecture: ARM64, CUDA support)
#
# Builds jaxlib, the CUDA PJRT/plugin wheels and the jax wheel from source
# into /opt/wheels, attempts to upload them with twine (best effort), and
# installs them with pip.
#
# Environment read by this script:
#   JAX_BUILD_VERSION     - forms the git tag "jax-v<version>" that is cloned
#   TWINE_REPOSITORY_URL  - only echoed in the upload-failure message here
set -ex

# Install LLVM/clang dev packages (llvm.sh is resolved relative to the
# current working directory).
./llvm.sh 18 all

echo "Building JAX for Jetson"

# Clone the requested release tag; if that clone fails for any reason,
# fall back to a shallow clone of the repository's default branch.
git clone --branch "jax-v${JAX_BUILD_VERSION}" --depth=1 --recursive https://github.com/google/jax /opt/jax || \
git clone --depth=1 --recursive https://github.com/google/jax /opt/jax

cd /opt/jax

# Build jaxlib from source with the pinned CUDA/cuDNN versions below.
# The flags live in an array so each one reaches build.py as exactly one
# argument. Building them as a string with embedded quotes and expanding it
# unquoted would pass the quote characters through literally
# (e.g. --cuda_compute_capabilities="sm_87" including the quotes).
BUILD_FLAGS=(
  --enable_cuda
  --enable_nccl=False
  --cuda_compute_capabilities=sm_87
  --cuda_version=12.6.1
  --cudnn_version=9.4.0
  # --bazel_options=--repo_env=LOCAL_CUDA_PATH=/usr/local/cuda-12.6
  # --bazel_options=--repo_env=LOCAL_CUDNN_PATH=/opt/nvidia/cudnn/
  --output_path=/opt/wheels
)

# First invocation uses the base flags only; the second adds the CUDA
# kernel-plugin and PJRT-plugin build switches.
python3 build/build.py "${BUILD_FLAGS[@]}"
python3 build/build.py "${BUILD_FLAGS[@]}" --build_gpu_kernel_plugin=cuda --build_gpu_plugin

# Build the jax pip wheels
pip3 wheel --wheel-dir=/opt/wheels --no-deps --verbose .

# Upload the wheels to mirror. Upload failures are deliberately non-fatal:
# the "|| echo" keeps "set -e" from aborting the build.
for wheel_name in jaxlib jax_cuda12_pjrt jax_cuda12_plugin jax; do
  twine upload --verbose /opt/wheels/"${wheel_name}"-*.whl || echo "failed to upload wheel to ${TWINE_REPOSITORY_URL}"
done

# Install them into the container
cd /opt/wheels/
pip3 install --verbose --no-cache-dir jaxlib*.whl jax_cuda12_plugin*.whl jax_cuda12_pjrt*.whl opt_einsum
# "jax-*.whl" (with the hyphen) selects only the jax wheel; a bare "jax*.whl"
# would also match the jaxlib/jax_cuda12_* wheels installed just above.
pip3 install --verbose --no-cache-dir --no-dependencies jax-*.whl

cd /opt/jax