diff --git a/examples/jax/jax-dist-spmd-mnist/Dockerfile b/examples/jax/jax-dist-spmd-mnist/Dockerfile
deleted file mode 100644
index 1538d26507..0000000000
--- a/examples/jax/jax-dist-spmd-mnist/Dockerfile
+++ /dev/null
@@ -1,29 +0,0 @@
-FROM python:3.13
-
-RUN pip install --upgrade pip
-RUN pip install --upgrade jax[k8s] absl-py
-
-RUN apt-get update && apt-get install -y \
-    build-essential \
-    cmake \
-    git \
-    libgoogle-glog-dev \
-    libgflags-dev \
-    libprotobuf-dev \
-    protobuf-compiler \
-    && rm -rf /var/lib/apt/lists/*
-
-RUN git clone https://github.com/facebookincubator/gloo.git \
-    && cd gloo \
-    && git checkout 43b7acbf372cdce14075f3526e39153b7e433b53 \
-    && mkdir build \
-    && cd build \
-    && cmake ../ \
-    && make \
-    && make install
-
-WORKDIR /app
-
-ADD datasets.py spmd_mnist_classifier_fromscratch.py /app/
-
-ENTRYPOINT ["python3", "spmd_mnist_classifier_fromscratch.py"]
diff --git a/examples/jax/jax-dist-spmd-mnist/README.md b/examples/jax/jax-dist-spmd-mnist/README.md
deleted file mode 100644
index d57a4d80fc..0000000000
--- a/examples/jax/jax-dist-spmd-mnist/README.md
+++ /dev/null
@@ -1,132 +0,0 @@
-## An MNIST example with single-program multiple-data (SPMD) data parallelism.
-
-The aim here is to illustrate how to use JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) to express and execute
-[SPMD](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) programs for data parallelism along a batch dimension, while also
-minimizing dependencies by avoiding the use of higher-level layer and
-optimizer libraries.
-
-Adapted from https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py.
-
-```bash
-$ kubectl apply -f examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml
-```
-
----
-
-```bash
-$ kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-mnist
-```
-
-```
-NAME                    READY   STATUS      RESTARTS   AGE
-jaxjob-mnist-worker-0   0/1     Completed   0          108m
-jaxjob-mnist-worker-1   0/1     Completed   0          108m
-```
-
----
-```bash
-$ PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o name -n kubeflow)
-$ kubectl logs -f ${PODNAME} -n kubeflow
-```
-
-```
-downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/
-downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/jax_example_data/
-downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/jax_example_data/
-downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/jax_example_data/
-JAX global devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=131072), CpuDevice(id=131073), CpuDevice(id=131074), CpuDevice(id=131075), CpuDevice(id=131076), CpuDevice(id=131077), CpuDevice(id=131078), CpuDevice(id=131079)]
-JAX local devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
-JAX device count:16
-JAX local device count:8
-JAX process count:2
-Epoch 0 in 1809.25 sec
-Training set accuracy 0.09871666878461838
-Test set accuracy 0.09799999743700027
-Epoch 1 in 0.51 sec
-Training set accuracy 0.09871666878461838
-Test set accuracy 0.09799999743700027
-Epoch 2 in 0.69 sec
-Training set accuracy 0.09871666878461838
-Test set accuracy 0.09799999743700027
-Epoch 3 in 0.81 sec
-Training set accuracy 0.09871666878461838
-Test set accuracy 0.09799999743700027
-Epoch 4 in 0.91 sec
-Training set accuracy 0.09871666878461838
-Test set accuracy 0.09799999743700027
-Epoch 5 in 0.97 sec
-Training set accuracy 0.09871666878461838
-Test set accuracy 0.09799999743700027
-Epoch 6 in 1.12 sec
-Training set accuracy 0.09035000205039978
-Test set accuracy 0.08919999748468399
-Epoch 7 in 1.11 sec
-Training set accuracy 0.09871666878461838
-Test set accuracy 0.09799999743700027
-Epoch 8 in 1.21 sec
-Training set accuracy 0.09871666878461838
-Test set accuracy 0.09799999743700027
-Epoch 9 in 1.29 sec
-Training set accuracy 0.09871666878461838
-Test set accuracy 0.09799999743700027
-
-```
-
----
-
-```bash
-$ kubectl get -o yaml jaxjobs jaxjob-mnist -n kubeflow
-```
-
-```
-apiVersion: kubeflow.org/v1
-kind: JAXJob
-metadata:
-  annotations:
-    kubectl.kubernetes.io/last-applied-configuration: |
-      {"apiVersion":"kubeflow.org/v1","kind":"JAXJob","metadata":{"annotations":{},"name":"jaxjob-mnist","namespace":"kubeflow"},"spec":{"jaxReplicaSpecs":{"Worker":{"replicas":2,"restartPolicy":"OnFailure","template":{"spec":{"containers":[{"image":"docker.io/sandipanify/jaxjob-spmd-mnist:latest","imagePullPolicy":"Always","name":"jax"}]}}}}}}
-  creationTimestamp: "2024-12-18T16:47:28Z"
-  generation: 1
-  name: jaxjob-mnist
-  namespace: kubeflow
-  resourceVersion: "3620"
-  uid: 15f1db77-3326-405d-95e6-3d9a0d581611
-spec:
-  jaxReplicaSpecs:
-    Worker:
-      replicas: 2
-      restartPolicy: OnFailure
-      template:
-        spec:
-          containers:
-          - image: docker.io/sandipanify/jaxjob-spmd-mnist:latest
-            imagePullPolicy: Always
-            name: jax
-status:
-  completionTime: "2024-12-18T17:22:11Z"
-  conditions:
-  - lastTransitionTime: "2024-12-18T16:47:28Z"
-    lastUpdateTime: "2024-12-18T16:47:28Z"
-    message: JAXJob jaxjob-mnist is created.
-    reason: JAXJobCreated
-    status: "True"
-    type: Created
-  - lastTransitionTime: "2024-12-18T16:50:57Z"
-    lastUpdateTime: "2024-12-18T16:50:57Z"
-    message: JAXJob kubeflow/jaxjob-mnist is running.
-    reason: JAXJobRunning
-    status: "False"
-    type: Running
-  - lastTransitionTime: "2024-12-18T17:22:11Z"
-    lastUpdateTime: "2024-12-18T17:22:11Z"
-    message: JAXJob kubeflow/jaxjob-mnist successfully completed.
-    reason: JAXJobSucceeded
-    status: "True"
-    type: Succeeded
-  replicaStatuses:
-    Worker:
-      selector: training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/operator-name=jaxjob-controller,training.kubeflow.org/replica-type=worker
-      succeeded: 2
-  startTime: "2024-12-18T16:47:28Z"
-
-```
diff --git a/examples/jax/jax-dist-spmd-mnist/datasets.py b/examples/jax/jax-dist-spmd-mnist/datasets.py
deleted file mode 100644
index 60fb8ce25b..0000000000
--- a/examples/jax/jax-dist-spmd-mnist/datasets.py
+++ /dev/null
@@ -1,97 +0,0 @@
-# Copyright 2018 The JAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -"""Datasets used in examples.""" - - -import array -import gzip -import os -import struct -import urllib.request -from os import path - -import numpy as np - -_DATA = "/tmp/jax_example_data/" - - -def _download(url, filename): - """Download a url to a file in the JAX data temp directory.""" - if not path.exists(_DATA): - os.makedirs(_DATA) - out_file = path.join(_DATA, filename) - if not path.isfile(out_file): - urllib.request.urlretrieve(url, out_file) - print(f"downloaded {url} to {_DATA}") - - -def _partial_flatten(x): - """Flatten all but the first dimension of an ndarray.""" - return np.reshape(x, (x.shape[0], -1)) - - -def _one_hot(x, k, dtype=np.float32): - """Create a one-hot encoding of x of size k.""" - return np.array(x[:, None] == np.arange(k), dtype) - - -def mnist_raw(): - """Download and parse the raw MNIST dataset.""" - # CVDF mirror of http://yann.lecun.com/exdb/mnist/ - base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/" - - def parse_labels(filename): - with gzip.open(filename, "rb") as fh: - _ = struct.unpack(">II", fh.read(8)) - return np.array(array.array("B", fh.read()), dtype=np.uint8) - - def parse_images(filename): - with gzip.open(filename, "rb") as fh: - _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16)) - return np.array(array.array("B", fh.read()), dtype=np.uint8).reshape( - num_data, rows, cols - ) - - for filename in [ - "train-images-idx3-ubyte.gz", - "train-labels-idx1-ubyte.gz", - "t10k-images-idx3-ubyte.gz", - "t10k-labels-idx1-ubyte.gz", - ]: - _download(base_url + filename, filename) - - train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz")) - train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz")) - test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz")) - test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz")) - - return train_images, train_labels, test_images, test_labels - - -def mnist(permute_train=False): - """Download, parse and process MNIST data to unit scale and one-hot labels.""" - train_images, train_labels, test_images, test_labels = mnist_raw() - - train_images = _partial_flatten(train_images) / np.float32(255.0) - test_images = _partial_flatten(test_images) / np.float32(255.0) - train_labels = _one_hot(train_labels, 10) - test_labels = _one_hot(test_labels, 10) - - if permute_train: - perm = np.random.RandomState(0).permutation(train_images.shape[0]) - train_images = train_images[perm] - train_labels = train_labels[perm] - - return train_images, train_labels, test_images, test_labels diff --git a/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml b/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml deleted file mode 100644 index e124b2efef..0000000000 --- a/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml +++ /dev/null @@ -1,16 +0,0 @@ -apiVersion: "kubeflow.org/v1" -kind: JAXJob -metadata: - name: jaxjob-mnist - namespace: kubeflow -spec: - jaxReplicaSpecs: - Worker: - replicas: 2 - restartPolicy: OnFailure - template: - spec: - containers: - - name: jax - image: docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest - imagePullPolicy: Always diff --git a/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py deleted file mode 100644 index ca0e9f5165..0000000000 --- a/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2024 kubeflow.org. 
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""An MNIST example with single-program multiple-data (SPMD) data parallelism.
-
-The aim here is to illustrate how to use JAX's `pmap` to express and execute
-SPMD programs for data parallelism along a batch dimension, while also
-minimizing dependencies by avoiding the use of higher-level layer and
-optimizer libraries.
-"""
-
-import multiprocessing
-import os
-import time
-from functools import partial
-
-import numpy as np
-import numpy.random as npr
-
-# JAX will treat your CPU as a single device by default, regardless of the number
-# of cores available. Unfortunately, this means that using `pmap` is not possible
-# out of the box – we first need to instruct JAX to split the CPU into multiple
-# devices. This variable has to be set before JAX, or any library that imports
-# it, is imported.
-os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
-    multiprocessing.cpu_count()
-)
-
-import datasets  # noqa
-import jax  # noqa
-import jax.numpy as jnp  # noqa
-from jax import grad, jit, lax, pmap  # noqa
-from jax.scipy.special import logsumexp  # noqa
-from jax.tree_util import tree_map  # noqa
-
-jax.config.update("jax_cpu_collectives_implementation", "gloo")
-
-process_id = int(os.getenv("PROCESS_ID"))
-num_processes = int(os.getenv("NUM_PROCESSES"))
-coordinator_address = (
-    f"{os.getenv('COORDINATOR_ADDRESS')}:{int(os.getenv('COORDINATOR_PORT'))}"
-)
-
-jax.distributed.initialize(
-    coordinator_address=coordinator_address,
-    num_processes=num_processes,
-    process_id=process_id,
-)
-
-
-def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
-    return [
-        (scale * rng.randn(m, n), scale * rng.randn(n))
-        for m, n in zip(layer_sizes[:-1], layer_sizes[1:])
-    ]
-
-
-def predict(params, inputs):
-    activations = inputs
-    for w, b in params[:-1]:
-        outputs = jnp.dot(activations, w) + b
-        activations = jnp.tanh(outputs)
-
-    final_w, final_b = params[-1]
-    logits = jnp.dot(activations, final_w) + final_b
-    return logits - logsumexp(logits, axis=1, keepdims=True)
-
-
-def loss(params, batch):
-    inputs, targets = batch
-    preds = predict(params, inputs)
-    return -jnp.mean(jnp.sum(preds * targets, axis=1))
-
-
-@jit
-def accuracy(params, batch):
-    inputs, targets = batch
-    target_class = jnp.argmax(targets, axis=1)
-    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
-    return jnp.mean(predicted_class == target_class)
-
-
-if __name__ == "__main__":
-    layer_sizes = [784, 1024, 1024, 10]
-    param_scale = 0.1
-    step_size = 0.001
-    num_epochs = 10
-    # For this manual SPMD example, we get the number of devices (e.g. CPU devices,
-    # GPUs or TPU cores) that we're using, and use it to reshape data minibatches.
-    num_devices = jax.local_device_count()
-    batch_size = num_devices * 5
-
-    train_images, train_labels, test_images, test_labels = datasets.mnist()
-    num_train = train_images.shape[0]
-    num_complete_batches, leftover = divmod(num_train, batch_size)
-
-    # Increasing the number of batches requires more resources.
-    num_batches = 10
-
-    def data_stream():
-        rng = npr.RandomState(0)
-        while True:
-            perm = rng.permutation(num_train)
-            for i in range(num_batches):
-                batch_idx = perm[i * batch_size : (i + 1) * batch_size]  # noqa
-                images, labels = train_images[batch_idx], train_labels[batch_idx]
-                # For this SPMD example, we reshape the data batch dimension into two
-                # batch dimensions, one of which is mapped over parallel devices.
-                batch_size_per_device, ragged = divmod(images.shape[0], num_devices)
-                if ragged:
-                    msg = "batch size must be divisible by device count, got {} and {}."
-                    raise ValueError(msg.format(batch_size, num_devices))
-                shape_prefix = (num_devices, batch_size_per_device)
-                images = images.reshape(shape_prefix + images.shape[1:])
-                labels = labels.reshape(shape_prefix + labels.shape[1:])
-                yield images, labels
-
-    batches = data_stream()
-
-    @partial(pmap, axis_name="batch")
-    def spmd_update(params, batch):
-        grads = grad(loss)(params, batch)
-        # We compute the total gradients, summing across the device-mapped axis,
-        # using the `lax.psum` SPMD primitive, which does a fast all-reduce-sum.
-        grads = [(lax.psum(dw, "batch"), lax.psum(db, "batch")) for dw, db in grads]
-        return [
-            (w - step_size * dw, b - step_size * db)
-            for (w, b), (dw, db) in zip(params, grads)
-        ]

-    # We replicate the parameters so that the constituent arrays have a leading
-    # dimension of size equal to the number of devices we're pmapping over.
-    init_params = init_random_params(param_scale, layer_sizes)
-
-    def replicate_array(x):
-        return np.broadcast_to(x, (num_devices,) + x.shape)
-
-    replicated_params = tree_map(replicate_array, init_params)
-
-    print(f"JAX global devices:{jax.devices()}")
-    print(f"JAX local devices:{jax.local_devices()}")
-
-    print(f"JAX device count:{jax.device_count()}")
-    print(f"JAX local device count:{jax.local_device_count()}")
-    print(f"JAX process count:{jax.process_count()}")
-
-    for epoch in range(num_epochs):
-        start_time = time.time()
-        for _ in range(num_batches):
-            replicated_params = spmd_update(replicated_params, next(batches))
-        epoch_time = time.time() - start_time
-
-        # We evaluate using the jitted `accuracy` function (not using pmap) by
-        # grabbing just one of the replicated parameter values.
-        params = tree_map(lambda x: x[0], replicated_params)
-        train_acc = accuracy(params, (train_images, train_labels))
-        test_acc = accuracy(params, (test_images, test_labels))
-        print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
-        print(f"Training set accuracy {train_acc}")
-        print(f"Test set accuracy {test_acc}")
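
For reference, the core trick in the deleted example is simulating multiple XLA devices on a single CPU host and all-reducing across them with `pmap`/`lax.psum`. Below is a minimal single-process sketch of that pattern that runs without a Kubernetes cluster, JAXJob, or gloo coordinator; the device count of 4 and the axis name `"i"` are illustrative choices, not taken from the example.

```python
# Minimal sketch: split one CPU host into several XLA devices, then
# all-reduce across them with pmap/psum, as spmd_update does for gradients.
import os

# Must be set before JAX is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

from functools import partial

import jax
import jax.numpy as jnp
from jax import lax, pmap

n = jax.local_device_count()  # 4 simulated CPU devices


@partial(pmap, axis_name="i")
def all_reduce_sum(x):
    # Every device contributes its shard; each gets back the global sum.
    return lax.psum(x, "i")


shards = jnp.arange(float(n))  # one scalar per device: [0., 1., 2., 3.]
print(all_reduce_sum(shards))  # [6. 6. 6. 6.]
```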