Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KEP-2170: Implement MPI Plugin for Kubeflow Trainer #2394

Merged
merged 5 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/openapi-spec/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@
"type": "boolean"
},
"sshAuthMountPath": {
"description": "Directory where SSH keys are mounted.",
"description": "Directory where SSH keys are mounted. Defaults to /root/.ssh.",
"type": "string"
}
}
Expand Down
6 changes: 3 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/onsi/gomega v1.35.1
github.com/open-policy-agent/cert-controller v0.12.0
go.uber.org/zap v1.27.0
golang.org/x/crypto v0.31.0
k8s.io/api v0.31.3
k8s.io/apimachinery v0.31.3
k8s.io/client-go v0.31.3
Expand All @@ -27,8 +28,7 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/emicklei/go-restful/v3 v3.12.1 // indirect
github.com/evanphx/json-patch v5.9.0+incompatible // indirect
github.com/emicklei/go-restful/v3 v3.12.0 // indirect
github.com/evanphx/json-patch/v5 v5.9.0 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
Expand Down Expand Up @@ -63,7 +63,7 @@ require (
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect
golang.org/x/mod v0.20.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.28.0 // indirect
Expand Down
14 changes: 8 additions & 6 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/emicklei/go-restful/v3 v3.12.1 h1:PJMDIM/ak7btuL8Ex0iYET9hxM3CI2sjZtzpL63nKAU=
github.com/emicklei/go-restful/v3 v3.12.1/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
github.com/evanphx/json-patch v5.9.0+incompatible h1:fBXyNpNMuTTDdquAq/uisOr2lShz4oaXpDTX2bLe7ls=
github.com/evanphx/json-patch v5.9.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
github.com/emicklei/go-restful/v3 v3.12.0 h1:y2DdzBAURM29NFF94q6RaY4vjIH1rtwDapwQtU84iWk=
github.com/emicklei/go-restful/v3 v3.12.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
github.com/evanphx/json-patch v5.6.0+incompatible h1:jBYDEEiFBPxA0v50tFdvOzQQTCvpL6mnFh5mB2/l16U=
github.com/evanphx/json-patch v5.6.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0/FOJfg=
github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
Expand Down Expand Up @@ -115,6 +115,8 @@ go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
Expand All @@ -125,8 +127,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ spec:
Defaults to false.
type: boolean
sshAuthMountPath:
description: Directory where SSH keys are mounted.
description: |-
Directory where SSH keys are mounted.
Defaults to /root/.ssh.
type: string
type: object
numNodes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ spec:
Defaults to false.
type: boolean
sshAuthMountPath:
description: Directory where SSH keys are mounted.
description: |-
Directory where SSH keys are mounted.
Defaults to /root/.ssh.
type: string
type: object
numNodes:
Expand Down
2 changes: 2 additions & 0 deletions manifests/base/rbac/role.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ rules:
- apiGroups:
- ""
resources:
- configmaps
- secrets
verbs:
- create
- get
- list
- update
Expand Down
1 change: 1 addition & 0 deletions manifests/base/runtimes/pretraining/kustomization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ apiVersion: kustomize.config.k8s.io/v1beta1
kind: Kustomization
resources:
- torch_distributed.yaml
- mpi_distributed.yaml
45 changes: 45 additions & 0 deletions manifests/base/runtimes/pretraining/mpi_distributed.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# TODO (andreyvelich): Change this to DeepSpeed or MLX runtime.
apiVersion: kubeflow.org/v1alpha1
kind: ClusterTrainingRuntime
metadata:
name: mpi-distributed
labels:
training.kubeflow.org/phase: pre-training
spec:
mlPolicy:
numNodes: 1
mpi:
numProcPerNode: 1
mpiImplementation: OpenMPI
sshAuthMountPath: /root/.ssh
template:
spec:
# TODO (andreyvelich): Use dependsOn when it is released.
startupPolicy:
startupPolicyOrder: InOrder
replicatedJobs:
- name: launcher
template:
spec:
template:
spec:
# TODO (andreyvelich): Change the command with mpirun.
containers:
- name: launcher
image: busybox
command:
- /bin/sh
- -c
- "echo 'launcher runs for 10 seconds' && sleep 100"
- name: trainer-node
template:
spec:
template:
spec:
containers:
- name: trainer
image: busybox
command:
- /bin/sh
- -c
- "echo 'launcher runs for 10 seconds' && sleep 100"
5 changes: 3 additions & 2 deletions pkg/apis/trainer/v1alpha1/trainingruntime_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,11 @@ type MPIMLPolicySource struct {

// Implementation name for the MPI to create the appropriate hostfile.
// Defaults to OpenMPI.
MPIImplementation *MPIImplementation `json:"mpiImplementation,omitempty"`
MPIImplementation MPIImplementation `json:"mpiImplementation,omitempty"`

// Directory where SSH keys are mounted.
SSHAuthMountPath *string `json:"sshAuthMountPath,omitempty"`
// Defaults to /root/.ssh.
SSHAuthMountPath string `json:"sshAuthMountPath,omitempty"`

// Whether to run training process on the launcher Job.
// Defaults to false.
Expand Down
10 changes: 0 additions & 10 deletions pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/apis/trainer/v1alpha1/zz_generated.openapi.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

70 changes: 55 additions & 15 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,26 @@ const (
// PodGroupKind is the Kind name for the PodGroup.
PodGroupKind string = "PodGroup"

// TrainJobJobsCreationSucceededMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsCreationSucceeded"} condition.
TrainJobJobsCreationSucceededMessage = "Succeeded to create Jobs"

// TrainJobJobsBuildFailedMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsBuildFailed"} condition.
TrainJobJobsBuildFailedMessage = "Failed to build Jobs"

// TrainJobJobsCreationFailedMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsCreationFailed"} condition.
TrainJobJobsCreationFailedMessage = "Failed to create Jobs"

// TrainJobSuspendedMessage is status condition message for the
// {"type": "Suspended", "status": "True", "reason": "Suspended"} condition.
TrainJobSuspendedMessage = "TrainJob is suspended"

// TrainJobResumedMessage is status condition message for the
// {"type": "Suspended", "status": "True", "reason": "Resumed"} condition.
TrainJobResumedMessage = "TrainJob is resumed"

// Distributed envs for torchrun.
// Ref: https://github.com/pytorch/pytorch/blob/3a0d0885171376ed610c8175a19ba40411fc6f3f/torch/distributed/argparse_util.py#L45
// TorchEnvNumNodes is the env name for the number of training nodes.
Expand All @@ -52,25 +72,45 @@ const (
// TorchEnvMasterPort is the env name for the master node port.
TorchEnvMasterPort string = "PET_MASTER_PORT"

// TrainJobJobsCreationSucceededMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsCreationSucceeded"} condition.
TrainJobJobsCreationSucceededMessage = "Succeeded to create Jobs"
// JobLauncher is the Job name for the launcher.
JobLauncher string = "launcher"

// TrainJobJobsBuildFailedMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsBuildFailed"} condition.
TrainJobJobsBuildFailedMessage = "Failed to build Jobs"
// ContainerLauncher is the container name for the launcher.
ContainerLauncher string = "launcher"

// TrainJobJobsCreationFailedMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsCreationFailed"} condition.
TrainJobJobsCreationFailedMessage = "Failed to create Jobs"
// MPISSHAuthSecretSuffix is the name suffix for Secret with MPI SSH keys.
MPISSHAuthSecretSuffix string = "-mpi-ssh-auth"

// TrainJobSuspendedMessage is status condition message for the
// {"type": "Suspended", "status": "True", "reason": "Suspended"} condition.
TrainJobSuspendedMessage = "TrainJob is suspended"
// MPISSHAuthVolumeName is the volume name for Secret with MPI SSH keys.
MPISSHAuthVolumeName string = "mpi-ssh-auth"

// TrainJobResumedMessage is status condition message for the
// {"type": "Suspended", "status": "True", "reason": "Resumed"} condition.
TrainJobResumedMessage = "TrainJob is resumed"
// MPISSHPrivateKeyFile is the file name for the private key.
MPISSHPrivateKeyFile string = "id_rsa"

// MPISSHPublicKey is the key in Secret data for the public key.
MPISSHPublicKey string = "ssh-publickey"

// MPISSHPublicKeyFile is the file name for the public key.
MPISSHPublicKeyFile string = MPISSHPrivateKeyFile + ".pub"

// MPISSHAuthorizedKeys is the file name for authorized keys.
MPISSHAuthorizedKeys string = "authorized_keys"

// MPIHostfileDir is the directory for the MPI hostfile.
MPIHostfileDir string = "/etc/mpi"

// MPIHostfileName is the file name for the MPI hostfile.
MPIHostfileName string = "hostfile"

// MPIHostfileConfigMapSuffix is the name suffix for ConfigMap with MPI hostfile.
MPIHostfileConfigMapSuffix string = "-mpi-hostfile"

// MPIHostfileVolumeName is the volume name for ConfigMap with MPI hostfile.
MPIHostfileVolumeName string = "mpi-hostfile"

// Distributed envs for mpirun.
// Values for OpenMPI implementation.
OpenMPIEnvHostFileLocation string = "OMPI_MCA_orte_default_hostfile"
)

var (
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime/framework/core/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, runtimeJobTe
return nil, err
}
if obj != nil {
objs = append(objs, obj)
objs = append(objs, obj...)
}
}
return objs, nil
Expand Down
7 changes: 5 additions & 2 deletions pkg/runtime/framework/core/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,12 @@ func TestNew(t *testing.T) {
watchExtensionPlugins: []framework.WatchExtensionPlugin{
&coscheduling.CoScheduling{},
&jobset.JobSet{},
&mpi.MPI{},
},
componentBuilderPlugins: []framework.ComponentBuilderPlugin{
&coscheduling.CoScheduling{},
&jobset.JobSet{},
&mpi.MPI{},
},
terminalConditionPlugins: []framework.TerminalConditionPlugin{
&jobset.JobSet{},
Expand Down Expand Up @@ -528,11 +530,12 @@ func TestWatchExtensionPlugins(t *testing.T) {
registry fwkplugins.Registry
wantPlugins []framework.WatchExtensionPlugin
}{
"coscheding and jobset are performed": {
"coscheduling, jobset, and mpi are performed": {
registry: fwkplugins.NewRegistry(),
wantPlugins: []framework.WatchExtensionPlugin{
&coscheduling.CoScheduling{},
&jobset.JobSet{},
&mpi.MPI{},
},
},
"an empty registry": {
Expand All @@ -541,7 +544,7 @@ func TestWatchExtensionPlugins(t *testing.T) {
}
cmpOpts := []cmp.Option{
cmpopts.SortSlices(func(a, b framework.Plugin) bool { return a.Name() < b.Name() }),
cmpopts.IgnoreUnexported(coscheduling.CoScheduling{}, jobset.JobSet{}),
cmpopts.IgnoreUnexported(coscheduling.CoScheduling{}, jobset.JobSet{}, mpi.MPI{}),
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
Expand Down
12 changes: 6 additions & 6 deletions pkg/runtime/framework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ type Plugin interface {
Name() string
}

type CustomValidationPlugin interface {
Plugin
Validate(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList)
}

type WatchExtensionPlugin interface {
Plugin
ReconcilerBuilders() []runtime.ReconcilerBuilder
Expand All @@ -47,14 +52,9 @@ type EnforceMLPolicyPlugin interface {
EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error
}

type CustomValidationPlugin interface {
Plugin
Validate(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList)
}

type ComponentBuilderPlugin interface {
Plugin
Build(ctx context.Context, runtimeJobTemplate client.Object, info *runtime.Info, trainJob *trainer.TrainJob) (client.Object, error)
Build(ctx context.Context, runtimeJobTemplate client.Object, info *runtime.Info, trainJob *trainer.TrainJob) ([]client.Object, error)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to change it to slice type?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To allow plugins to return multiple objects to be created.
E.g. MPI Plugin creates Secret + ConfigMap:

return []client.Object{secret, configMap}, nil

}

type TerminalConditionPlugin interface {
Expand Down
Loading
Loading