Skip to content

Commit

Permalink
Update unit tests to use apply configurations
Browse files Browse the repository at this point in the history
Signed-off-by: Antonin Stefanutti <[email protected]>
  • Loading branch information
astefanutti committed Feb 14, 2025
1 parent b3a865e commit c9ff23c
Show file tree
Hide file tree
Showing 6 changed files with 330 additions and 209 deletions.
26 changes: 17 additions & 9 deletions pkg/runtime/core/clustertrainingruntime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/client"
"k8s.io/apimachinery/pkg/runtime"
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
Expand All @@ -41,7 +41,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
cases := map[string]struct {
trainJob *trainer.TrainJob
clusterTrainingRuntime *trainer.ClusterTrainingRuntime
wantObjs []client.Object
wantObjs []runtime.Object
wantError error
}{
"succeeded to build PodGroup and JobSet with NumNodes from the Runtime and container from the Trainer.": {
Expand All @@ -63,23 +63,23 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
Obj(),
).
Obj(),
wantObjs: []client.Object{
wantObjs: []runtime.Object{
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
InitContainerDatasetModelInitializer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
NumNodes(100).
ContainerTrainer("test:trainjob", []string{"trainjob"}, []string{"trainjob"}, resRequests).
Suspend(true).
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
Obj(),
Obj(t),
testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job").
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
MinMember(101). // 101 replicas = 100 Trainer nodes + 1 Initializer.
MinResources(corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("101"), // Every replica has 1 CPU = 101 CPUs in total.
}).
SchedulingTimeout(120).
Obj(),
Obj(t),
},
},
"missing trainingRuntime resource": {
Expand All @@ -95,7 +95,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
},
}
cmpOpts := []cmp.Option{
cmpopts.SortSlices(func(a, b client.Object) bool {
cmpopts.SortSlices(func(a, b runtime.Object) bool {
return a.GetObjectKind().GroupVersionKind().String() < b.GetObjectKind().GroupVersionKind().String()
}),
cmpopts.EquateEmpty(),
Expand All @@ -109,8 +109,9 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
if tc.clusterTrainingRuntime != nil {
clientBuilder.WithObjects(tc.clusterTrainingRuntime)
}
c := clientBuilder.Build()

trainingRuntime, err := NewTrainingRuntime(ctx, clientBuilder.Build(), testingutil.AsIndex(clientBuilder))
trainingRuntime, err := NewTrainingRuntime(ctx, c, testingutil.AsIndex(clientBuilder))
if err != nil {
t.Fatal(err)
}
Expand All @@ -120,15 +121,22 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
t.Fatal("Failed type assertion from Runtime interface to TrainingRuntime")
}

clTrainingRuntime, err := NewClusterTrainingRuntime(ctx, clientBuilder.Build(), testingutil.AsIndex(clientBuilder))
clTrainingRuntime, err := NewClusterTrainingRuntime(ctx, c, testingutil.AsIndex(clientBuilder))
if err != nil {
t.Fatal(err)
}

objs, err := clTrainingRuntime.NewObjects(ctx, tc.trainJob)
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
}
if diff := cmp.Diff(tc.wantObjs, objs, cmpOpts...); len(diff) != 0 {

resultObjs, err := testingutil.ToObject(c.Scheme(), objs...)
if err != nil {
t.Fatal(err)
}

if diff := cmp.Diff(tc.wantObjs, resultObjs, cmpOpts...); len(diff) != 0 {
t.Errorf("Unexpected objects (-want,+got):\n%s", diff)
}
})
Expand Down
43 changes: 27 additions & 16 deletions pkg/runtime/core/trainingruntime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/client"
"k8s.io/apimachinery/pkg/runtime"
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
Expand All @@ -44,7 +44,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
cases := map[string]struct {
trainingRuntime *trainer.TrainingRuntime
trainJob *trainer.TrainJob
wantObjs []client.Object
wantObjs []runtime.Object
wantError error
}{
// Test cases for the PlainML MLPolicy.
Expand Down Expand Up @@ -72,7 +72,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
Obj(),
).
Obj(),
wantObjs: []client.Object{
wantObjs: []runtime.Object{
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
InitContainerDatasetModelInitializer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
NumNodes(30).
Expand All @@ -82,7 +82,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
Annotation("conflictAnnotation", "override").
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
Obj(),
Obj(t),
testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job").
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
MinMember(31). // 31 replicas = 30 Trainer nodes + 1 Initializer.
Expand All @@ -93,7 +93,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
corev1.ResourceCPU: resource.MustParse("31"),
}).
SchedulingTimeout(120).
Obj(),
Obj(t),
},
},
"succeeded to build JobSet with NumNodes from the Runtime and container from the TrainJob.": {
Expand Down Expand Up @@ -136,7 +136,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
Obj(),
).
Obj(),
wantObjs: []client.Object{
wantObjs: []runtime.Object{
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
NumNodes(100).
ContainerTrainer("test:trainjob", []string{"trainjob"}, []string{"trainjob"}, resRequests).
Expand All @@ -157,7 +157,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
},
).
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
Obj(),
Obj(t),
},
},
"succeeded to build JobSet with dataset and model initializer from the TrainJob.": {
Expand Down Expand Up @@ -204,7 +204,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
Obj(),
).
Obj(),
wantObjs: []client.Object{
wantObjs: []runtime.Object{
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
NumNodes(100).
ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
Expand Down Expand Up @@ -256,7 +256,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
},
).
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
Obj(),
Obj(t),
},
},
// Test cases for the Torch MLPolicy.
Expand All @@ -277,7 +277,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
Obj(),
).
Obj(),
wantObjs: []client.Object{
wantObjs: []runtime.Object{
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
NumNodes(30).
ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
Expand Down Expand Up @@ -311,7 +311,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
},
).
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
Obj(),
Obj(t),
},
},
"succeeded to build JobSet with Torch values from the Runtime and envs.": {
Expand Down Expand Up @@ -354,7 +354,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
Obj(),
).
Obj(),
wantObjs: []client.Object{
wantObjs: []runtime.Object{
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
NumNodes(100).
ContainerTrainer("test:trainjob", []string{"trainjob"}, []string{"trainjob"}, resRequests).
Expand Down Expand Up @@ -400,7 +400,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
},
).
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
Obj(),
Obj(t),
},
},
// Failed test cases.
Expand All @@ -417,9 +417,12 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
},
}
cmpOpts := []cmp.Option{
cmpopts.SortSlices(func(a, b client.Object) bool {
cmpopts.SortSlices(func(a, b runtime.Object) bool {
return a.GetObjectKind().GroupVersionKind().String() < b.GetObjectKind().GroupVersionKind().String()
}),
cmpopts.SortSlices(func(a, b corev1.EnvVar) bool {
return a.Name < b.Name
}),
cmpopts.EquateEmpty(),
cmpopts.SortMaps(func(a, b string) bool { return a < b }),
}
Expand All @@ -431,16 +434,24 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
if tc.trainingRuntime != nil {
clientBuilder.WithObjects(tc.trainingRuntime)
}
c := clientBuilder.Build()

trainingRuntime, err := NewTrainingRuntime(ctx, clientBuilder.Build(), testingutil.AsIndex(clientBuilder))
trainingRuntime, err := NewTrainingRuntime(ctx, c, testingutil.AsIndex(clientBuilder))
if err != nil {
t.Fatal(err)
}

objs, err := trainingRuntime.NewObjects(ctx, tc.trainJob)
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
}
if diff := cmp.Diff(tc.wantObjs, objs, cmpOpts...); len(diff) != 0 {

resultObjs, err := testingutil.ToObject(c.Scheme(), objs...)
if err != nil {
t.Fatal(err)
}

if diff := cmp.Diff(tc.wantObjs, resultObjs, cmpOpts...); len(diff) != 0 {
t.Errorf("Unexpected objects (-want,+got):\n%s", diff)
}
})
Expand Down
Loading

0 comments on commit c9ff23c

Please sign in to comment.