forked from kubeflow/training-operator
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Testing CI in JAX example (kubeflow#2385)
* Add MNIST example with SPMD for JAX Illustrate how to use JAX's `pmap` to express and execute single-program multiple-data (SPMD) programs for data parallelism along a batch dimension Signed-off-by: Sandipan Panda <[email protected]> * Update CONTRIBUTING.md Use -- server-side to install the latest local changes of Training Operator control plane Signed-off-by: Sandipan Panda <[email protected]> * Add JAXJob output Signed-off-by: Sandipan Panda <[email protected]> * Update JAXJob CI images Signed-off-by: Sandipan Panda <[email protected]> * Adjust jaxjob spmd example batch size Signed-off-by: Sandipan Panda <[email protected]> * Add JAX Example Docker Image Build in CI Signed-off-by: sailesh duddupudi <[email protected]> * Fix script name typo Signed-off-by: sailesh duddupudi <[email protected]> * Update script permissions Signed-off-by: sailesh duddupudi <[email protected]> * Add KIND_CLUSTER env var Signed-off-by: sailesh duddupudi <[email protected]> * Increase timeouts Signed-off-by: sailesh duddupudi <[email protected]> * Test higher resources Signed-off-by: sailesh duddupudi <[email protected]> * Increase Timeout Signed-off-by: sailesh duddupudi <[email protected]> * remove resource reqs Signed-off-by: sailesh duddupudi <[email protected]> * test low batch size Signed-off-by: sailesh duddupudi <[email protected]> * test small batch size Signed-off-by: sailesh duddupudi <[email protected]> * Hardcode number of batches Signed-off-by: sailesh duddupudi <[email protected]> --------- Signed-off-by: Sandipan Panda <[email protected]> Signed-off-by: sailesh duddupudi <[email protected]> Co-authored-by: Sandipan Panda <[email protected]>
- Loading branch information
1 parent
1dfa40c
commit 6d58ea9
Showing
13 changed files
with
494 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
FROM python:3.13 | ||
|
||
RUN pip install --upgrade pip | ||
RUN pip install --upgrade jax[k8s] absl-py | ||
|
||
RUN apt-get update && apt-get install -y \ | ||
build-essential \ | ||
cmake \ | ||
git \ | ||
libgoogle-glog-dev \ | ||
libgflags-dev \ | ||
libprotobuf-dev \ | ||
protobuf-compiler \ | ||
&& rm -rf /var/lib/apt/lists/* | ||
|
||
RUN git clone https://github.com/facebookincubator/gloo.git \ | ||
&& cd gloo \ | ||
&& git checkout 43b7acbf372cdce14075f3526e39153b7e433b53 \ | ||
&& mkdir build \ | ||
&& cd build \ | ||
&& cmake ../ \ | ||
&& make \ | ||
&& make install | ||
|
||
WORKDIR /app | ||
|
||
ADD datasets.py spmd_mnist_classifier_fromscratch.py /app/ | ||
|
||
ENTRYPOINT ["python3", "spmd_mnist_classifier_fromscratch.py"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
## An MNIST example with single-program multiple-data (SPMD) data parallelism. | ||
|
||
The aim here is to illustrate how to use JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) to express and execute | ||
[SPMD](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) programs for data parallelism along a batch dimension, while also | ||
minimizing dependencies by avoiding the use of higher-level layers and | ||
optimizers libraries. | ||
|
||
Adapted from https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py. | ||
|
||
```bash | ||
$ kubectl apply -f examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml | ||
``` | ||
|
||
--- | ||
|
||
```bash | ||
$ kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-mnist | ||
``` | ||
|
||
``` | ||
NAME READY STATUS RESTARTS AGE | ||
jaxjob-mnist-worker-0 0/1 Completed 0 108m | ||
jaxjob-mnist-worker-1 0/1 Completed 0 108m | ||
``` | ||
|
||
--- | ||
```bash | ||
$ PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o name -n kubeflow) | ||
$ kubectl logs -f ${PODNAME} -n kubeflow | ||
``` | ||
|
||
``` | ||
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/ | ||
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/jax_example_data/ | ||
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/jax_example_data/ | ||
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/jax_example_data/ | ||
JAX global devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=131072), CpuDevice(id=131073), CpuDevice(id=131074), CpuDevice(id=131075), CpuDevice(id=131076), CpuDevice(id=131077), CpuDevice(id=131078), CpuDevice(id=131079)] | ||
JAX local devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)] | ||
JAX device count:16 | ||
JAX local device count:8 | ||
JAX process count:2 | ||
Epoch 0 in 1809.25 sec | ||
Training set accuracy 0.09871666878461838 | ||
Test set accuracy 0.09799999743700027 | ||
Epoch 1 in 0.51 sec | ||
Training set accuracy 0.09871666878461838 | ||
Test set accuracy 0.09799999743700027 | ||
Epoch 2 in 0.69 sec | ||
Training set accuracy 0.09871666878461838 | ||
Test set accuracy 0.09799999743700027 | ||
Epoch 3 in 0.81 sec | ||
Training set accuracy 0.09871666878461838 | ||
Test set accuracy 0.09799999743700027 | ||
Epoch 4 in 0.91 sec | ||
Training set accuracy 0.09871666878461838 | ||
Test set accuracy 0.09799999743700027 | ||
Epoch 5 in 0.97 sec | ||
Training set accuracy 0.09871666878461838 | ||
Test set accuracy 0.09799999743700027 | ||
Epoch 6 in 1.12 sec | ||
Training set accuracy 0.09035000205039978 | ||
Test set accuracy 0.08919999748468399 | ||
Epoch 7 in 1.11 sec | ||
Training set accuracy 0.09871666878461838 | ||
Test set accuracy 0.09799999743700027 | ||
Epoch 8 in 1.21 sec | ||
Training set accuracy 0.09871666878461838 | ||
Test set accuracy 0.09799999743700027 | ||
Epoch 9 in 1.29 sec | ||
Training set accuracy 0.09871666878461838 | ||
Test set accuracy 0.09799999743700027 | ||
``` | ||
|
||
--- | ||
|
||
```bash | ||
$ kubectl get -o yaml jaxjobs jaxjob-mnist -n kubeflow | ||
``` | ||
|
||
``` | ||
apiVersion: kubeflow.org/v1 | ||
kind: JAXJob | ||
metadata: | ||
annotations: | ||
kubectl.kubernetes.io/last-applied-configuration: | | ||
{"apiVersion":"kubeflow.org/v1","kind":"JAXJob","metadata":{"annotations":{},"name":"jaxjob-mnist","namespace":"kubeflow"},"spec":{"jaxReplicaSpecs":{"Worker":{"replicas":2,"restartPolicy":"OnFailure","template":{"spec":{"containers":[{"image":"docker.io/sandipanify/jaxjob-spmd-mnist:latest","imagePullPolicy":"Always","name":"jax"}]}}}}}} | ||
creationTimestamp: "2024-12-18T16:47:28Z" | ||
generation: 1 | ||
name: jaxjob-mnist | ||
namespace: kubeflow | ||
resourceVersion: "3620" | ||
uid: 15f1db77-3326-405d-95e6-3d9a0d581611 | ||
spec: | ||
jaxReplicaSpecs: | ||
Worker: | ||
replicas: 2 | ||
restartPolicy: OnFailure | ||
template: | ||
spec: | ||
containers: | ||
- image: docker.io/sandipanify/jaxjob-spmd-mnist:latest | ||
imagePullPolicy: Always | ||
name: jax | ||
status: | ||
completionTime: "2024-12-18T17:22:11Z" | ||
conditions: | ||
- lastTransitionTime: "2024-12-18T16:47:28Z" | ||
lastUpdateTime: "2024-12-18T16:47:28Z" | ||
message: JAXJob jaxjob-mnist is created. | ||
reason: JAXJobCreated | ||
status: "True" | ||
type: Created | ||
- lastTransitionTime: "2024-12-18T16:50:57Z" | ||
lastUpdateTime: "2024-12-18T16:50:57Z" | ||
message: JAXJob kubeflow/jaxjob-mnist is running. | ||
reason: JAXJobRunning | ||
status: "False" | ||
type: Running | ||
- lastTransitionTime: "2024-12-18T17:22:11Z" | ||
lastUpdateTime: "2024-12-18T17:22:11Z" | ||
message: JAXJob kubeflow/jaxjob-mnist successfully completed. | ||
reason: JAXJobSucceeded | ||
status: "True" | ||
type: Succeeded | ||
replicaStatuses: | ||
Worker: | ||
selector: training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/operator-name=jaxjob-controller,training.kubeflow.org/replica-type=worker | ||
succeeded: 2 | ||
startTime: "2024-12-18T16:47:28Z" | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright 2018 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Datasets used in examples.""" | ||
|
||
|
||
import array | ||
import gzip | ||
import os | ||
import struct | ||
import urllib.request | ||
from os import path | ||
|
||
import numpy as np | ||
|
||
_DATA = "/tmp/jax_example_data/" | ||
|
||
|
||
def _download(url, filename): | ||
"""Download a url to a file in the JAX data temp directory.""" | ||
if not path.exists(_DATA): | ||
os.makedirs(_DATA) | ||
out_file = path.join(_DATA, filename) | ||
if not path.isfile(out_file): | ||
urllib.request.urlretrieve(url, out_file) | ||
print(f"downloaded {url} to {_DATA}") | ||
|
||
|
||
def _partial_flatten(x): | ||
"""Flatten all but the first dimension of an ndarray.""" | ||
return np.reshape(x, (x.shape[0], -1)) | ||
|
||
|
||
def _one_hot(x, k, dtype=np.float32): | ||
"""Create a one-hot encoding of x of size k.""" | ||
return np.array(x[:, None] == np.arange(k), dtype) | ||
|
||
|
||
def mnist_raw(): | ||
"""Download and parse the raw MNIST dataset.""" | ||
# CVDF mirror of http://yann.lecun.com/exdb/mnist/ | ||
base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/" | ||
|
||
def parse_labels(filename): | ||
with gzip.open(filename, "rb") as fh: | ||
_ = struct.unpack(">II", fh.read(8)) | ||
return np.array(array.array("B", fh.read()), dtype=np.uint8) | ||
|
||
def parse_images(filename): | ||
with gzip.open(filename, "rb") as fh: | ||
_, num_data, rows, cols = struct.unpack(">IIII", fh.read(16)) | ||
return np.array(array.array("B", fh.read()), dtype=np.uint8).reshape( | ||
num_data, rows, cols | ||
) | ||
|
||
for filename in [ | ||
"train-images-idx3-ubyte.gz", | ||
"train-labels-idx1-ubyte.gz", | ||
"t10k-images-idx3-ubyte.gz", | ||
"t10k-labels-idx1-ubyte.gz", | ||
]: | ||
_download(base_url + filename, filename) | ||
|
||
train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz")) | ||
train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz")) | ||
test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz")) | ||
test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz")) | ||
|
||
return train_images, train_labels, test_images, test_labels | ||
|
||
|
||
def mnist(permute_train=False): | ||
"""Download, parse and process MNIST data to unit scale and one-hot labels.""" | ||
train_images, train_labels, test_images, test_labels = mnist_raw() | ||
|
||
train_images = _partial_flatten(train_images) / np.float32(255.0) | ||
test_images = _partial_flatten(test_images) / np.float32(255.0) | ||
train_labels = _one_hot(train_labels, 10) | ||
test_labels = _one_hot(test_labels, 10) | ||
|
||
if permute_train: | ||
perm = np.random.RandomState(0).permutation(train_images.shape[0]) | ||
train_images = train_images[perm] | ||
train_labels = train_labels[perm] | ||
|
||
return train_images, train_labels, test_images, test_labels |
16 changes: 16 additions & 0 deletions
16
examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
apiVersion: "kubeflow.org/v1" | ||
kind: JAXJob | ||
metadata: | ||
name: jaxjob-mnist | ||
namespace: kubeflow | ||
spec: | ||
jaxReplicaSpecs: | ||
Worker: | ||
replicas: 2 | ||
restartPolicy: OnFailure | ||
template: | ||
spec: | ||
containers: | ||
- name: jax | ||
image: docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest | ||
imagePullPolicy: Always |
Oops, something went wrong.