From 6d58ea997c0250d52cb5a5e349fe7e485e36d004 Mon Sep 17 00:00:00 2001 From: saileshd1402 Date: Fri, 17 Jan 2025 23:06:23 +0530 Subject: [PATCH] Testing CI in JAX example (#2385) * Add MNIST example with SPMD for JAX Illustrate how to use JAX's `pmap` to express and execute single-program multiple-data (SPMD) programs for data parallelism along a batch dimension Signed-off-by: Sandipan Panda * Update CONTRIBUTING.md Use -- server-side to install the latest local changes of Training Operator control plane Signed-off-by: Sandipan Panda * Add JAXJob output Signed-off-by: Sandipan Panda * Update JAXJob CI images Signed-off-by: Sandipan Panda * Adjust jaxjob spmd example batch size Signed-off-by: Sandipan Panda * Add JAX Example Docker Image Build in CI Signed-off-by: sailesh duddupudi * Fix script name typo Signed-off-by: sailesh duddupudi * Update script permissions Signed-off-by: sailesh duddupudi * Add KIND_CLUSTER env var Signed-off-by: sailesh duddupudi * Increase timeouts Signed-off-by: sailesh duddupudi * Test higher resources Signed-off-by: sailesh duddupudi * Increase Timeout Signed-off-by: sailesh duddupudi * remove resource reqs Signed-off-by: sailesh duddupudi * test low batch size Signed-off-by: sailesh duddupudi * test small batch size Signed-off-by: sailesh duddupudi * Hardcode number of batches Signed-off-by: sailesh duddupudi --------- Signed-off-by: Sandipan Panda Signed-off-by: sailesh duddupudi Co-authored-by: Sandipan Panda --- .github/workflows/integration-tests.yaml | 14 ++ .github/workflows/publish-example-images.yaml | 4 + CONTRIBUTING.md | 2 +- examples/jax/jax-dist-spmd-mnist/Dockerfile | 29 +++ examples/jax/jax-dist-spmd-mnist/README.md | 132 ++++++++++++++ examples/jax/jax-dist-spmd-mnist/datasets.py | 97 ++++++++++ .../jaxjob_dist_spmd_mnist_gloo.yaml | 16 ++ .../spmd_mnist_classifier_fromscratch.py | 171 ++++++++++++++++++ pkg/controller.v1/jax/envvar_test.go | 2 +- pkg/webhooks/jax/jaxjob_webhook_test.go | 2 +- scripts/gha/build-jax-mnist-image.sh | 25 +++ .../kubeflow/training/constants/constants.py | 2 +- sdk/python/test/e2e/test_e2e_jaxjob.py | 5 +- 13 files changed, 494 insertions(+), 7 deletions(-) create mode 100644 examples/jax/jax-dist-spmd-mnist/Dockerfile create mode 100644 examples/jax/jax-dist-spmd-mnist/README.md create mode 100644 examples/jax/jax-dist-spmd-mnist/datasets.py create mode 100644 examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml create mode 100644 examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py create mode 100755 scripts/gha/build-jax-mnist-image.sh diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index d6fdd6389a..a450a76b16 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -65,12 +65,26 @@ jobs: python-version: ${{ matrix.python-version }} gang-scheduler-name: ${{ matrix.gang-scheduler-name }} + - name: Build JAX Job Example Image + run: | + ./scripts/gha/build-jax-mnist-image.sh + env: + JAX_JOB_CI_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test + + - name: Load JAX Job Example Image + run: | + kind load docker-image ${{ env.JAX_JOB_CI_IMAGE }} --name ${{ env.KIND_CLUSTER }} + env: + KIND_CLUSTER: training-operator-cluster + JAX_JOB_CI_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test + - name: Run tests run: | pip install pytest python3 -m pip install -e sdk/python; pytest -s sdk/python/test/e2e --log-cli-level=debug --namespace=default env: GANG_SCHEDULER_NAME: ${{ matrix.gang-scheduler-name }} + 
JAX_JOB_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test - name: Collect volcano logs if: ${{ failure() && matrix.gang-scheduler-name == 'volcano' }} diff --git a/.github/workflows/publish-example-images.yaml b/.github/workflows/publish-example-images.yaml index 74dc242551..5012714b57 100644 --- a/.github/workflows/publish-example-images.yaml +++ b/.github/workflows/publish-example-images.yaml @@ -74,3 +74,7 @@ jobs: platforms: linux/amd64 dockerfile: examples/pytorch/deepspeed-demo/Dockerfile context: examples/pytorch/deepspeed-demo + - component-name: jaxjob-dist-spmd-mnist + platforms: linux/amd64,linux/arm64 + dockerfile: examples/jax/jax-dist-spmd-mnist/Dockerfile + context: examples/jax/jax-dist-spmd-mnist/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index eca6af84d7..a7bd8ef76e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -66,7 +66,7 @@ Note, that for the example job below, the PyTorchJob uses the `kubeflow` namespa From here we can apply the manifests to the cluster. ```sh -kubectl apply -k "github.com/kubeflow/training-operator/manifests/overlays/standalone" +kubectl apply --server-side -k "github.com/kubeflow/training-operator/manifests/overlays/standalone" ``` Then we can patch it with the latest operator image. diff --git a/examples/jax/jax-dist-spmd-mnist/Dockerfile b/examples/jax/jax-dist-spmd-mnist/Dockerfile new file mode 100644 index 0000000000..1538d26507 --- /dev/null +++ b/examples/jax/jax-dist-spmd-mnist/Dockerfile @@ -0,0 +1,29 @@ +FROM python:3.13 + +RUN pip install --upgrade pip +RUN pip install --upgrade jax[k8s] absl-py + +RUN apt-get update && apt-get install -y \ + build-essential \ + cmake \ + git \ + libgoogle-glog-dev \ + libgflags-dev \ + libprotobuf-dev \ + protobuf-compiler \ + && rm -rf /var/lib/apt/lists/* + +RUN git clone https://github.com/facebookincubator/gloo.git \ + && cd gloo \ + && git checkout 43b7acbf372cdce14075f3526e39153b7e433b53 \ + && mkdir build \ + && cd build \ + && cmake ../ \ + && make \ + && make install + +WORKDIR /app + +ADD datasets.py spmd_mnist_classifier_fromscratch.py /app/ + +ENTRYPOINT ["python3", "spmd_mnist_classifier_fromscratch.py"] diff --git a/examples/jax/jax-dist-spmd-mnist/README.md b/examples/jax/jax-dist-spmd-mnist/README.md new file mode 100644 index 0000000000..d57a4d80fc --- /dev/null +++ b/examples/jax/jax-dist-spmd-mnist/README.md @@ -0,0 +1,132 @@ +## An MNIST example with single-program multiple-data (SPMD) data parallelism. + +The aim here is to illustrate how to use JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) to express and execute +[SPMD](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) programs for data parallelism along a batch dimension, while also +minimizing dependencies by avoiding the use of higher-level layers and +optimizers libraries. + +Adapted from https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py. 
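Before the Kubernetes walkthrough, here is a minimal standalone sketch (written for this README, not part of the example sources) of the `pmap` pattern the script relies on: force the host CPU to present several XLA devices, map a function over a leading device axis, and communicate across that axis with a collective.

```python
import os

# Must be set before JAX is imported; the training script uses the same trick
# with the container's CPU count to give `pmap` devices to map over.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.pmap, axis_name="batch")
def center(x):
    # Each device holds one shard; `lax.pmean` all-reduces the per-shard means
    # so every shard gets centered around the global mean.
    return x - jax.lax.pmean(jnp.mean(x), "batch")


shards = jnp.arange(8.0).reshape(4, 2)  # leading axis == device count
print(center(shards))
```

The training script applies the same pattern to gradients: `spmd_update` runs under `pmap` and all-reduces per-device gradient shards with `lax.psum`.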
Deploy the JAXJob with the provided manifest:

```bash
$ kubectl apply -f examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml
```

---

Check that both worker pods run to completion:

```bash
$ kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-mnist
```

```
NAME                    READY   STATUS      RESTARTS   AGE
jaxjob-mnist-worker-0   0/1     Completed   0          108m
jaxjob-mnist-worker-1   0/1     Completed   0          108m
```

---

Follow the logs of worker 0:

```bash
$ PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o name -n kubeflow)
$ kubectl logs -f ${PODNAME} -n kubeflow
```

```
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/jax_example_data/
JAX global devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=131072), CpuDevice(id=131073), CpuDevice(id=131074), CpuDevice(id=131075), CpuDevice(id=131076), CpuDevice(id=131077), CpuDevice(id=131078), CpuDevice(id=131079)]
JAX local devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
JAX device count:16
JAX local device count:8
JAX process count:2
Epoch 0 in 1809.25 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 1 in 0.51 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 2 in 0.69 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 3 in 0.81 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 4 in 0.91 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 5 in 0.97 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 6 in 1.12 sec
Training set accuracy 0.09035000205039978
Test set accuracy 0.08919999748468399
Epoch 7 in 1.11 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 8 in 1.21 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 9 in 1.29 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
```

---

Inspect the finished JAXJob:

```bash
$ kubectl get -o yaml jaxjobs jaxjob-mnist -n kubeflow
```

```
apiVersion: kubeflow.org/v1
kind: JAXJob
metadata:
  annotations:
    kubectl.kubernetes.io/last-applied-configuration: |
      {"apiVersion":"kubeflow.org/v1","kind":"JAXJob","metadata":{"annotations":{},"name":"jaxjob-mnist","namespace":"kubeflow"},"spec":{"jaxReplicaSpecs":{"Worker":{"replicas":2,"restartPolicy":"OnFailure","template":{"spec":{"containers":[{"image":"docker.io/sandipanify/jaxjob-spmd-mnist:latest","imagePullPolicy":"Always","name":"jax"}]}}}}}}
  creationTimestamp: "2024-12-18T16:47:28Z"
  generation: 1
  name: jaxjob-mnist
  namespace: kubeflow
  resourceVersion: "3620"
  uid: 15f1db77-3326-405d-95e6-3d9a0d581611
spec:
  jaxReplicaSpecs:
    Worker:
      replicas: 2
      restartPolicy: OnFailure
      template:
        spec:
          containers:
          - image: docker.io/sandipanify/jaxjob-spmd-mnist:latest
imagePullPolicy: Always + name: jax +status: + completionTime: "2024-12-18T17:22:11Z" + conditions: + - lastTransitionTime: "2024-12-18T16:47:28Z" + lastUpdateTime: "2024-12-18T16:47:28Z" + message: JAXJob jaxjob-mnist is created. + reason: JAXJobCreated + status: "True" + type: Created + - lastTransitionTime: "2024-12-18T16:50:57Z" + lastUpdateTime: "2024-12-18T16:50:57Z" + message: JAXJob kubeflow/jaxjob-mnist is running. + reason: JAXJobRunning + status: "False" + type: Running + - lastTransitionTime: "2024-12-18T17:22:11Z" + lastUpdateTime: "2024-12-18T17:22:11Z" + message: JAXJob kubeflow/jaxjob-mnist successfully completed. + reason: JAXJobSucceeded + status: "True" + type: Succeeded + replicaStatuses: + Worker: + selector: training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/operator-name=jaxjob-controller,training.kubeflow.org/replica-type=worker + succeeded: 2 + startTime: "2024-12-18T16:47:28Z" + +``` diff --git a/examples/jax/jax-dist-spmd-mnist/datasets.py b/examples/jax/jax-dist-spmd-mnist/datasets.py new file mode 100644 index 0000000000..60fb8ce25b --- /dev/null +++ b/examples/jax/jax-dist-spmd-mnist/datasets.py @@ -0,0 +1,97 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Datasets used in examples.""" + + +import array +import gzip +import os +import struct +import urllib.request +from os import path + +import numpy as np + +_DATA = "/tmp/jax_example_data/" + + +def _download(url, filename): + """Download a url to a file in the JAX data temp directory.""" + if not path.exists(_DATA): + os.makedirs(_DATA) + out_file = path.join(_DATA, filename) + if not path.isfile(out_file): + urllib.request.urlretrieve(url, out_file) + print(f"downloaded {url} to {_DATA}") + + +def _partial_flatten(x): + """Flatten all but the first dimension of an ndarray.""" + return np.reshape(x, (x.shape[0], -1)) + + +def _one_hot(x, k, dtype=np.float32): + """Create a one-hot encoding of x of size k.""" + return np.array(x[:, None] == np.arange(k), dtype) + + +def mnist_raw(): + """Download and parse the raw MNIST dataset.""" + # CVDF mirror of http://yann.lecun.com/exdb/mnist/ + base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/" + + def parse_labels(filename): + with gzip.open(filename, "rb") as fh: + _ = struct.unpack(">II", fh.read(8)) + return np.array(array.array("B", fh.read()), dtype=np.uint8) + + def parse_images(filename): + with gzip.open(filename, "rb") as fh: + _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16)) + return np.array(array.array("B", fh.read()), dtype=np.uint8).reshape( + num_data, rows, cols + ) + + for filename in [ + "train-images-idx3-ubyte.gz", + "train-labels-idx1-ubyte.gz", + "t10k-images-idx3-ubyte.gz", + "t10k-labels-idx1-ubyte.gz", + ]: + _download(base_url + filename, filename) + + train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz")) + train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz")) + test_images = parse_images(path.join(_DATA, 
"t10k-images-idx3-ubyte.gz")) + test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz")) + + return train_images, train_labels, test_images, test_labels + + +def mnist(permute_train=False): + """Download, parse and process MNIST data to unit scale and one-hot labels.""" + train_images, train_labels, test_images, test_labels = mnist_raw() + + train_images = _partial_flatten(train_images) / np.float32(255.0) + test_images = _partial_flatten(test_images) / np.float32(255.0) + train_labels = _one_hot(train_labels, 10) + test_labels = _one_hot(test_labels, 10) + + if permute_train: + perm = np.random.RandomState(0).permutation(train_images.shape[0]) + train_images = train_images[perm] + train_labels = train_labels[perm] + + return train_images, train_labels, test_images, test_labels diff --git a/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml b/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml new file mode 100644 index 0000000000..e124b2efef --- /dev/null +++ b/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml @@ -0,0 +1,16 @@ +apiVersion: "kubeflow.org/v1" +kind: JAXJob +metadata: + name: jaxjob-mnist + namespace: kubeflow +spec: + jaxReplicaSpecs: + Worker: + replicas: 2 + restartPolicy: OnFailure + template: + spec: + containers: + - name: jax + image: docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest + imagePullPolicy: Always diff --git a/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py new file mode 100644 index 0000000000..ca0e9f5165 --- /dev/null +++ b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py @@ -0,0 +1,171 @@ +# Copyright 2024 kubeflow.org. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An MNIST example with single-program multiple-data (SPMD) data parallelism. + +The aim here is to illustrate how to use JAX's `pmap` to express and execute +SPMD programs for data parallelism along a batch dimension, while also +minimizing dependencies by avoiding the use of higher-level layers and +optimizers libraries. +""" + +import multiprocessing +import os +import time +from functools import partial + +import numpy as np +import numpy.random as npr + +# JAX will treat your CPU as a single device by default, regardless of the number +# of cores available. Unfortunately, this means that using `pmap` is not possible out +# of the box – we’ll first need to instruct JAX to split the CPU into multiple devices. 
+# This variable has to be set before JAX or any library that imports it is imported + +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format( + multiprocessing.cpu_count() +) + +import datasets # noqa +import jax # noqa +import jax.numpy as jnp # noqa +from jax import grad, jit, lax, pmap # noqa +from jax.scipy.special import logsumexp # noqa +from jax.tree_util import tree_map # noqa + +jax.config.update("jax_cpu_collectives_implementation", "gloo") + +process_id = int(os.getenv("PROCESS_ID")) +num_processes = int(os.getenv("NUM_PROCESSES")) +coordinator_address = ( + f"{os.getenv('COORDINATOR_ADDRESS')}:{int(os.getenv('COORDINATOR_PORT'))}" +) + +jax.distributed.initialize( + coordinator_address=coordinator_address, + num_processes=num_processes, + process_id=process_id, +) + + +def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): + return [ + (scale * rng.randn(m, n), scale * rng.randn(n)) + for m, n, in zip(layer_sizes[:-1], layer_sizes[1:]) + ] + + +def predict(params, inputs): + activations = inputs + for w, b in params[:-1]: + outputs = jnp.dot(activations, w) + b + activations = jnp.tanh(outputs) + + final_w, final_b = params[-1] + logits = jnp.dot(activations, final_w) + final_b + return logits - logsumexp(logits, axis=1, keepdims=True) + + +def loss(params, batch): + inputs, targets = batch + preds = predict(params, inputs) + return -jnp.mean(jnp.sum(preds * targets, axis=1)) + + +@jit +def accuracy(params, batch): + inputs, targets = batch + target_class = jnp.argmax(targets, axis=1) + predicted_class = jnp.argmax(predict(params, inputs), axis=1) + return jnp.mean(predicted_class == target_class) + + +if __name__ == "__main__": + layer_sizes = [784, 1024, 1024, 10] + param_scale = 0.1 + step_size = 0.001 + num_epochs = 10 + # For this manual SPMD example, we get the number of devices (e.g. CPU, + # GPUs or TPU cores) that we're using, and use it to reshape data minibatches. + num_devices = jax.local_device_count() + batch_size = num_devices * 5 + + train_images, train_labels, test_images, test_labels = datasets.mnist() + num_train = train_images.shape[0] + num_complete_batches, leftover = divmod(num_train, batch_size) + + # Increasing number of batches requires more resources. + num_batches = 10 + + def data_stream(): + rng = npr.RandomState(0) + while True: + perm = rng.permutation(num_train) + for i in range(num_batches): + batch_idx = perm[i * batch_size : (i + 1) * batch_size] # noqa + images, labels = train_images[batch_idx], train_labels[batch_idx] + # For this SPMD example, we reshape the data batch dimension into two + # batch dimensions, one of which is mapped over parallel devices. + batch_size_per_device, ragged = divmod(images.shape[0], num_devices) + if ragged: + msg = "batch size must be divisible by device count, got {} and {}." + raise ValueError(msg.format(batch_size, num_devices)) + shape_prefix = (num_devices, batch_size_per_device) + images = images.reshape(shape_prefix + images.shape[1:]) + labels = labels.reshape(shape_prefix + labels.shape[1:]) + yield images, labels + + batches = data_stream() + + @partial(pmap, axis_name="batch") + def spmd_update(params, batch): + grads = grad(loss)(params, batch) + # We compute the total gradients, summing across the device-mapped axis, + # using the `lax.psum` SPMD primitive, which does a fast all-reduce-sum. 
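        # Because `loss` already averages within each device's shard, the psum
        # below scales the gradient with the device count; `lax.pmean` would
        # keep it an average and leave the effective step size unchanged.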
+ grads = [(lax.psum(dw, "batch"), lax.psum(db, "batch")) for dw, db in grads] + return [ + (w - step_size * dw, b - step_size * db) + for (w, b), (dw, db) in zip(params, grads) + ] + + # We replicate the parameters so that the constituent arrays have a leading + # dimension of size equal to the number of devices we're pmapping over. + init_params = init_random_params(param_scale, layer_sizes) + + def replicate_array(x): + return np.broadcast_to(x, (num_devices,) + x.shape) + + replicated_params = tree_map(replicate_array, init_params) + + print(f"JAX global devices:{jax.devices()}") + print(f"JAX local devices:{jax.local_devices()}") + + print(f"JAX device count:{jax.device_count()}") + print(f"JAX local device count:{jax.local_device_count()}") + print(f"JAX process count:{jax.process_count()}") + + for epoch in range(num_epochs): + start_time = time.time() + for _ in range(num_batches): + replicated_params = spmd_update(replicated_params, next(batches)) + epoch_time = time.time() - start_time + + # We evaluate using the jitted `accuracy` function (not using pmap) by + # grabbing just one of the replicated parameter values. + params = tree_map(lambda x: x[0], replicated_params) + train_acc = accuracy(params, (train_images, train_labels)) + test_acc = accuracy(params, (test_images, test_labels)) + print(f"Epoch {epoch} in {epoch_time:0.2f} sec") + print(f"Training set accuracy {train_acc}") + print(f"Test set accuracy {test_acc}") diff --git a/pkg/controller.v1/jax/envvar_test.go b/pkg/controller.v1/jax/envvar_test.go index 9920e89bbb..3b0f0b5691 100644 --- a/pkg/controller.v1/jax/envvar_test.go +++ b/pkg/controller.v1/jax/envvar_test.go @@ -30,7 +30,7 @@ func TestSetPodEnv(t *testing.T) { Spec: corev1.PodSpec{ Containers: []corev1.Container{{ Name: "jax", - Image: "docker.io/kubeflow/jaxjob-simple:latest", + Image: "docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest", Ports: []corev1.ContainerPort{{ Name: kubeflowv1.JAXJobDefaultPortName, ContainerPort: validPort, diff --git a/pkg/webhooks/jax/jaxjob_webhook_test.go b/pkg/webhooks/jax/jaxjob_webhook_test.go index bfbc0eb29c..a6463fb3aa 100644 --- a/pkg/webhooks/jax/jaxjob_webhook_test.go +++ b/pkg/webhooks/jax/jaxjob_webhook_test.go @@ -156,7 +156,7 @@ func TestValidateV1JAXJob(t *testing.T) { Containers: []corev1.Container{ { Name: "", - Image: "gcr.io/kubeflow-ci/jaxjob-simple_test:1.0", + Image: "gcr.io/kubeflow-ci/jaxjob-dist-spmd-mnist_test:1.0", }, }, }, diff --git a/scripts/gha/build-jax-mnist-image.sh b/scripts/gha/build-jax-mnist-image.sh new file mode 100755 index 0000000000..b9a30fa18f --- /dev/null +++ b/scripts/gha/build-jax-mnist-image.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The script is used to build images needed to run JAX Job E2E test. 
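#
# Example local invocation (illustrative tag; in CI the workflow sets
# JAX_JOB_CI_IMAGE itself before calling this script):
#
#   JAX_JOB_CI_IMAGE=kubeflow/jaxjob-dist-spmd-mnist:test ./scripts/gha/build-jax-mnist-image.sh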

set -o errexit
set -o nounset
set -o pipefail

# Build Image for MNIST example with SPMD for JAX
docker build examples/jax/jax-dist-spmd-mnist -t "${JAX_JOB_CI_IMAGE}" -f examples/jax/jax-dist-spmd-mnist/Dockerfile
diff --git a/sdk/python/kubeflow/training/constants/constants.py b/sdk/python/kubeflow/training/constants/constants.py
index dba4d49681..2a5415ea26 100644
--- a/sdk/python/kubeflow/training/constants/constants.py
+++ b/sdk/python/kubeflow/training/constants/constants.py
@@ -153,7 +153,7 @@
 JAXJOB_PLURAL = "jaxjobs"
 JAXJOB_CONTAINER = "jax"
 JAXJOB_REPLICA_TYPES = REPLICA_TYPE_WORKER.lower()
-JAXJOB_BASE_IMAGE = "docker.io/kubeflow/jaxjob-simple:latest"
+JAXJOB_BASE_IMAGE = "docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest"
 
 # Dictionary to get plural, model, and container for each Job kind.
 JOB_PARAMETERS = {
diff --git a/sdk/python/test/e2e/test_e2e_jaxjob.py b/sdk/python/test/e2e/test_e2e_jaxjob.py
index 6223c8a988..7471f67338 100644
--- a/sdk/python/test/e2e/test_e2e_jaxjob.py
+++ b/sdk/python/test/e2e/test_e2e_jaxjob.py
@@ -155,7 +155,6 @@ def generate_jaxjob(
 def generate_container() -> V1Container:
     return V1Container(
         name=CONTAINER_NAME,
-        image="docker.io/kubeflow/jaxjob-simple:latest",
-        command=["python", "train.py"],
-        resources=V1ResourceRequirements(limits={"memory": "2Gi", "cpu": "0.8"}),
+        image=os.getenv("JAX_JOB_IMAGE", "docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest"),
+        resources=V1ResourceRequirements(limits={"memory": "3Gi", "cpu": "1.2"}),
     )
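To sanity-check the updated e2e path outside CI, something like the following should work with the `kubeflow-training` SDK's `TrainingClient`, assuming the JAXJob from the README above was created in the `kubeflow` namespace; the method names follow `sdk/python/kubeflow/training/api/training_client.py` at the time of this patch, but verify the signatures against your SDK version.

```python
from kubeflow.training import TrainingClient

client = TrainingClient()
# Block until the JAXJob reaches a terminal condition, then report the result.
client.wait_for_job_conditions(name="jaxjob-mnist", namespace="kubeflow", job_kind="JAXJob")
print(client.is_job_succeeded(name="jaxjob-mnist", namespace="kubeflow", job_kind="JAXJob"))
```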