Skip to content

Commit

Permalink
Update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruturaj4 committed Dec 10, 2024
1 parent 527efae commit 6f6297c
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 27 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ Some standouts:
| CPU | `pip install -U jax` |
| NVIDIA GPU | `pip install -U "jax[cuda12]"` |
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU (Linux) | Use [Docker](https://hub.docker.com/r/rocm/jax-community/tags), [pre-built wheels](https://github.com/ROCm/jax/releases), or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| AMD GPU (Linux) | `pip install -U "jax[rocm]"` |
| Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |
| Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). |

Expand Down
56 changes: 37 additions & 19 deletions build/rocm/README.md
Original file line number Diff line number Diff line change
@@ -1,38 +1,56 @@
# JAX Builds on ROCm
This directory contains files and setup instructions to build and test JAX for ROCm in Docker environment (runtime and CI). You can build, test and run JAX on ROCm yourself!
This directory contains files and setup instructions to build and test JAX for ROCm in a Docker environment (runtime and CI). You can build, test, and run JAX on ROCm yourself!
***
### Build JAX-ROCm in docker for the runtime
## JAX ROCm Releases

1. Install Docker: Follow the [instructions on the docker website](https://docs.docker.com/engine/installation/).
### Overview
We aim to push all ROCm-related changes to the OpenXLA repository. However, there may be times when certain JAX/jaxlib updates for ROCm are not yet reflected in the upstream JAX repository.

To address this, we maintain ROCm-specific JAX/jaxlib branches tied to JAX releases.
These branches are hosted in the ROCm fork of JAX and XLA:
* https://github.com/ROCm/jax
* https://github.com/ROCm/xla

### Branch Naming Convention
Branches are named in the format rocm-jaxlib-[jaxlib-version]. For example:
* For JAX version 0.4.35, the corresponding branch is `rocm-jaxlib-v0.4.35`.
You can access it at: https://github.com/ROCm/jax/tree/rocm-jaxlib-v0.4.35.

2. Build a runtime JAX-ROCm docker container and keep this image by running the following command. Note: must pass in appropriate
options. The example below builds Python 3.12 container.
### Latest JAX Releases for ROCm

GitHub Releases:

```Bash
./build/rocm/ci_build.sh --py_version 3.12
https://github.com/ROCm/jax/releases
```

3. To launch a JAX-ROCm container: If the build was successful, there should be a docker image with name "jax-rocm:latest" in list of docker images (use "docker images" command to list them).
Docker Images:
Prebuilt ROCm JAX Docker images are available on Docker Hub:
```Bash
https://hub.docker.com/r/rocm/jax-community/tags
```

PyPI Installation:
JAX can also be installed via PyPI using the following command:
```Bash
docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v ./:/jax --name rocm_jax jax-rocm:latest /bin/bash
pip install jax[rocm]
```

Note: Earlier versions of jaxlib for ROCm are available on PyPI [jaxlib-rocm PyPI History](https://pypi.org/project/jaxlib-rocm/#history).

***
### JAX ROCm Releases
We aim to push all ROCm-related changes to the OpenXLA repository. However, there may be times when certain JAX/jaxlib updates for
ROCm are not yet reflected in the upstream JAX repository. To address this, we maintain ROCm-specific JAX/jaxlib branches tied to JAX
releases. These branches are available in the ROCm fork of JAX at https://github.com/ROCm/jax. Look for branches named in the format
rocm-jaxlib-[jaxlib-version]. You can also find corresponding branches in https://github.com/ROCm/xla. For example, for JAX version
0.4.33, the branch is named rocm-jaxlib-v0.4.33, which can be accessed at https://github.com/ROCm/jax/tree/rocm-jaxlib-v0.4.33.
## Build JAX-ROCm in docker for the runtime

1. Install Docker: Follow the [instructions on the docker website](https://docs.docker.com/engine/installation/).

JAX source-code and related wheels for ROCm are available here
2. Build a runtime JAX-ROCm docker container and keep this image by running the following command. Note: must pass in appropriate options. The example below builds Python 3.12 container.

```Bash
https://github.com/ROCm/jax/releases
./build/rocm/ci_build.sh --py_version 3.12
```

***Note:*** Some earlier jaxlib versions on ROCm were released on ***PyPi***.
```
https://pypi.org/project/jaxlib-rocm/#history
3. To launch a JAX-ROCm container: If the build was successful, there should be a docker image with name "jax-rocm:latest" in list of docker images (use "docker images" command to list them).

```Bash
docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v ./:/jax --name rocm_jax jax-rocm:latest /bin/bash
```
29 changes: 24 additions & 5 deletions docs/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,17 +205,36 @@ sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \
rccl-dev rccl hip-dev rocfft-dev roctracer-dev hipblas-dev rocm-device-libs
```

The recommended way to install these dependencies is by running our script, `jax/build/rocm/tools/get_rocm.py`,
Alternatively, the recommended way to install these dependencies is by running our script, `jax/build/rocm/tools/get_rocm.py`,
and selecting the appropriate options.

To build jaxlib with ROCM support, you can run the following build commands,
suitably adjusted for your paths and ROCM version.
You can also use prebuilt ROCm JAX images available on Docker Hub:
```Bash
https://hub.docker.com/r/rocm/jax-community/tags
```

#### Building JAX for ROCm

Follow these steps to build JAX for ROCm:

1. Clone the Repository
Clone the ROCm-specific fork of JAX:

```Bash
git clone https://github.com/ROCm/jax -b <branch_name>
cd jax
```

2. Build the Wheels
Use the following command to build the wheels for JAX:

```Bash
python3 ./build/build.py build --wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt --rocm_version=60 --rocm_path=/opt/rocm-6.2.3
```
to generate three wheels (jaxlib without rocm, jax-rocm-plugin, and
jax-rocm-pjrt)
jax-rocm-pjrt). The generated wheels will be located in the `dist/` directory.

#### Additional Information

AMD's fork of the XLA repository may include fixes not present in the upstream
XLA repository. If you experience problems with the upstream repository, you can
Expand All @@ -228,7 +247,7 @@ git clone https://github.com/ROCm/xla.git
and override the XLA repository with which JAX is built:

```
python3 ./build/build.py build --wheels=jax-rocm-plugin --rocm_version=60 --rocm_path=/opt/rocm-6.2.3 --local_xla_path=/rel/xla/
python3 ./build/build.py build --wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt --rocm_version=60 --rocm_path=/opt/rocm-6.2.3 --local_xla_path=/rel/xla/
```

For a simplified installation process, we also recommend checking out the `jax/build/rocm/dev_build_rocm.py script`.
Expand Down
4 changes: 2 additions & 2 deletions docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ refer to

JAX has experimental ROCm support. There are two ways to install JAX:

* Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax); or
* Build from source (refer to {ref}`building-from-source` — a section called _Additional notes for building a ROCM `jaxlib` for AMD GPUs_).
* Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax-community/tags); or
* Build from source. Refer to the section [Additional notes for building a ROCm jaxlib for AMD GPUs](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus).

(install-intel-gpu)=
## Intel GPU
Expand Down

0 comments on commit 6f6297c

Please sign in to comment.