Skip to content

Commit

Permalink
Testing CI in JAX example (kubeflow#2385)
Browse files Browse the repository at this point in the history
* Add MNIST example with SPMD for JAX

Illustrate how to use JAX's `pmap` to express and execute
single-program multiple-data (SPMD) programs for data parallelism
along a batch dimension

Signed-off-by: Sandipan Panda <[email protected]>

* Update CONTRIBUTING.md

Use -- server-side to install the latest local changes of Training
Operator control plane

Signed-off-by: Sandipan Panda <[email protected]>

* Add JAXJob output

Signed-off-by: Sandipan Panda <[email protected]>

* Update JAXJob CI images

Signed-off-by: Sandipan Panda <[email protected]>

* Adjust jaxjob spmd example batch size

Signed-off-by: Sandipan Panda <[email protected]>

* Add JAX Example Docker Image Build in CI

Signed-off-by: sailesh duddupudi <[email protected]>

* Fix script name typo

Signed-off-by: sailesh duddupudi <[email protected]>

* Update script permissions

Signed-off-by: sailesh duddupudi <[email protected]>

* Add KIND_CLUSTER env var

Signed-off-by: sailesh duddupudi <[email protected]>

* Increase timeouts

Signed-off-by: sailesh duddupudi <[email protected]>

* Test higher resources

Signed-off-by: sailesh duddupudi <[email protected]>

* Increase Timeout

Signed-off-by: sailesh duddupudi <[email protected]>

* remove resource reqs

Signed-off-by: sailesh duddupudi <[email protected]>

* test low batch size

Signed-off-by: sailesh duddupudi <[email protected]>

* test small batch size

Signed-off-by: sailesh duddupudi <[email protected]>

* Hardcode number of batches

Signed-off-by: sailesh duddupudi <[email protected]>

---------

Signed-off-by: Sandipan Panda <[email protected]>
Signed-off-by: sailesh duddupudi <[email protected]>
Co-authored-by: Sandipan Panda <[email protected]>
  • Loading branch information
saileshd1402 and sandipanpanda authored Jan 17, 2025
1 parent 1dfa40c commit 6d58ea9
Show file tree
Hide file tree
Showing 13 changed files with 494 additions and 7 deletions.
14 changes: 14 additions & 0 deletions .github/workflows/integration-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,26 @@ jobs:
python-version: ${{ matrix.python-version }}
gang-scheduler-name: ${{ matrix.gang-scheduler-name }}

- name: Build JAX Job Example Image
run: |
./scripts/gha/build-jax-mnist-image.sh
env:
JAX_JOB_CI_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test

- name: Load JAX Job Example Image
run: |
kind load docker-image ${{ env.JAX_JOB_CI_IMAGE }} --name ${{ env.KIND_CLUSTER }}
env:
KIND_CLUSTER: training-operator-cluster
JAX_JOB_CI_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test

- name: Run tests
run: |
pip install pytest
python3 -m pip install -e sdk/python; pytest -s sdk/python/test/e2e --log-cli-level=debug --namespace=default
env:
GANG_SCHEDULER_NAME: ${{ matrix.gang-scheduler-name }}
JAX_JOB_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test

- name: Collect volcano logs
if: ${{ failure() && matrix.gang-scheduler-name == 'volcano' }}
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/publish-example-images.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,7 @@ jobs:
platforms: linux/amd64
dockerfile: examples/pytorch/deepspeed-demo/Dockerfile
context: examples/pytorch/deepspeed-demo
- component-name: jaxjob-dist-spmd-mnist
platforms: linux/amd64,linux/arm64
dockerfile: examples/jax/jax-dist-spmd-mnist/Dockerfile
context: examples/jax/jax-dist-spmd-mnist/
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Note, that for the example job below, the PyTorchJob uses the `kubeflow` namespa

From here we can apply the manifests to the cluster.
```sh
kubectl apply -k "github.com/kubeflow/training-operator/manifests/overlays/standalone"
kubectl apply --server-side -k "github.com/kubeflow/training-operator/manifests/overlays/standalone"
```

Then we can patch it with the latest operator image.
Expand Down
29 changes: 29 additions & 0 deletions examples/jax/jax-dist-spmd-mnist/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
FROM python:3.13

RUN pip install --upgrade pip
RUN pip install --upgrade jax[k8s] absl-py

RUN apt-get update && apt-get install -y \
build-essential \
cmake \
git \
libgoogle-glog-dev \
libgflags-dev \
libprotobuf-dev \
protobuf-compiler \
&& rm -rf /var/lib/apt/lists/*

RUN git clone https://github.com/facebookincubator/gloo.git \
&& cd gloo \
&& git checkout 43b7acbf372cdce14075f3526e39153b7e433b53 \
&& mkdir build \
&& cd build \
&& cmake ../ \
&& make \
&& make install

WORKDIR /app

ADD datasets.py spmd_mnist_classifier_fromscratch.py /app/

ENTRYPOINT ["python3", "spmd_mnist_classifier_fromscratch.py"]
132 changes: 132 additions & 0 deletions examples/jax/jax-dist-spmd-mnist/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
## An MNIST example with single-program multiple-data (SPMD) data parallelism.

The aim here is to illustrate how to use JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) to express and execute
[SPMD](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) programs for data parallelism along a batch dimension, while also
minimizing dependencies by avoiding the use of higher-level layers and
optimizers libraries.

Adapted from https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py.

```bash
$ kubectl apply -f examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml
```

---

```bash
$ kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-mnist
```

```
NAME READY STATUS RESTARTS AGE
jaxjob-mnist-worker-0 0/1 Completed 0 108m
jaxjob-mnist-worker-1 0/1 Completed 0 108m
```

---
```bash
$ PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o name -n kubeflow)
$ kubectl logs -f ${PODNAME} -n kubeflow
```

```
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/jax_example_data/
JAX global devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=131072), CpuDevice(id=131073), CpuDevice(id=131074), CpuDevice(id=131075), CpuDevice(id=131076), CpuDevice(id=131077), CpuDevice(id=131078), CpuDevice(id=131079)]
JAX local devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
JAX device count:16
JAX local device count:8
JAX process count:2
Epoch 0 in 1809.25 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 1 in 0.51 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 2 in 0.69 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 3 in 0.81 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 4 in 0.91 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 5 in 0.97 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 6 in 1.12 sec
Training set accuracy 0.09035000205039978
Test set accuracy 0.08919999748468399
Epoch 7 in 1.11 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 8 in 1.21 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 9 in 1.29 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
```

---

```bash
$ kubectl get -o yaml jaxjobs jaxjob-mnist -n kubeflow
```

```
apiVersion: kubeflow.org/v1
kind: JAXJob
metadata:
annotations:
kubectl.kubernetes.io/last-applied-configuration: |
{"apiVersion":"kubeflow.org/v1","kind":"JAXJob","metadata":{"annotations":{},"name":"jaxjob-mnist","namespace":"kubeflow"},"spec":{"jaxReplicaSpecs":{"Worker":{"replicas":2,"restartPolicy":"OnFailure","template":{"spec":{"containers":[{"image":"docker.io/sandipanify/jaxjob-spmd-mnist:latest","imagePullPolicy":"Always","name":"jax"}]}}}}}}
creationTimestamp: "2024-12-18T16:47:28Z"
generation: 1
name: jaxjob-mnist
namespace: kubeflow
resourceVersion: "3620"
uid: 15f1db77-3326-405d-95e6-3d9a0d581611
spec:
jaxReplicaSpecs:
Worker:
replicas: 2
restartPolicy: OnFailure
template:
spec:
containers:
- image: docker.io/sandipanify/jaxjob-spmd-mnist:latest
imagePullPolicy: Always
name: jax
status:
completionTime: "2024-12-18T17:22:11Z"
conditions:
- lastTransitionTime: "2024-12-18T16:47:28Z"
lastUpdateTime: "2024-12-18T16:47:28Z"
message: JAXJob jaxjob-mnist is created.
reason: JAXJobCreated
status: "True"
type: Created
- lastTransitionTime: "2024-12-18T16:50:57Z"
lastUpdateTime: "2024-12-18T16:50:57Z"
message: JAXJob kubeflow/jaxjob-mnist is running.
reason: JAXJobRunning
status: "False"
type: Running
- lastTransitionTime: "2024-12-18T17:22:11Z"
lastUpdateTime: "2024-12-18T17:22:11Z"
message: JAXJob kubeflow/jaxjob-mnist successfully completed.
reason: JAXJobSucceeded
status: "True"
type: Succeeded
replicaStatuses:
Worker:
selector: training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/operator-name=jaxjob-controller,training.kubeflow.org/replica-type=worker
succeeded: 2
startTime: "2024-12-18T16:47:28Z"
```
97 changes: 97 additions & 0 deletions examples/jax/jax-dist-spmd-mnist/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Datasets used in examples."""


import array
import gzip
import os
import struct
import urllib.request
from os import path

import numpy as np

_DATA = "/tmp/jax_example_data/"


def _download(url, filename):
"""Download a url to a file in the JAX data temp directory."""
if not path.exists(_DATA):
os.makedirs(_DATA)
out_file = path.join(_DATA, filename)
if not path.isfile(out_file):
urllib.request.urlretrieve(url, out_file)
print(f"downloaded {url} to {_DATA}")


def _partial_flatten(x):
"""Flatten all but the first dimension of an ndarray."""
return np.reshape(x, (x.shape[0], -1))


def _one_hot(x, k, dtype=np.float32):
"""Create a one-hot encoding of x of size k."""
return np.array(x[:, None] == np.arange(k), dtype)


def mnist_raw():
"""Download and parse the raw MNIST dataset."""
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

def parse_labels(filename):
with gzip.open(filename, "rb") as fh:
_ = struct.unpack(">II", fh.read(8))
return np.array(array.array("B", fh.read()), dtype=np.uint8)

def parse_images(filename):
with gzip.open(filename, "rb") as fh:
_, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
return np.array(array.array("B", fh.read()), dtype=np.uint8).reshape(
num_data, rows, cols
)

for filename in [
"train-images-idx3-ubyte.gz",
"train-labels-idx1-ubyte.gz",
"t10k-images-idx3-ubyte.gz",
"t10k-labels-idx1-ubyte.gz",
]:
_download(base_url + filename, filename)

train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz"))
train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz"))
test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz"))
test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz"))

return train_images, train_labels, test_images, test_labels


def mnist(permute_train=False):
"""Download, parse and process MNIST data to unit scale and one-hot labels."""
train_images, train_labels, test_images, test_labels = mnist_raw()

train_images = _partial_flatten(train_images) / np.float32(255.0)
test_images = _partial_flatten(test_images) / np.float32(255.0)
train_labels = _one_hot(train_labels, 10)
test_labels = _one_hot(test_labels, 10)

if permute_train:
perm = np.random.RandomState(0).permutation(train_images.shape[0])
train_images = train_images[perm]
train_labels = train_labels[perm]

return train_images, train_labels, test_images, test_labels
16 changes: 16 additions & 0 deletions examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
apiVersion: "kubeflow.org/v1"
kind: JAXJob
metadata:
name: jaxjob-mnist
namespace: kubeflow
spec:
jaxReplicaSpecs:
Worker:
replicas: 2
restartPolicy: OnFailure
template:
spec:
containers:
- name: jax
image: docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest
imagePullPolicy: Always
Loading

0 comments on commit 6d58ea9

Please sign in to comment.