Commit d15a9c5

Rename kubeflowv1 to trainer pkg
Signed-off-by: Andrey Velichkevich <[email protected]>
andreyvelich committed Feb 6, 2025
1 parent f6dd14b · commit d15a9c5
Showing 30 changed files with 309 additions and 309 deletions.
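
The rename is mechanical: every file that imported the v1alpha1 API types under the legacy kubeflowv1 alias now imports them as trainer, so call sites read trainer.TrainJob rather than kubeflowv1.TrainJob. A minimal before/after sketch (the file body here is illustrative, not taken from the diff):

// Before the rename:
//	import kubeflowv1 "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
//	var job kubeflowv1.TrainJob
//
// After the rename: the alias matches the trainer.kubeflow.org API group.
package main

import (
	"fmt"

	trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
)

func main() {
	var job trainer.TrainJob // same type, clearer qualifier
	fmt.Println(job.GetName())
}
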
10 changes: 5 additions & 5 deletions cmd/trainer-controller-manager/main.go
@@ -29,16 +29,16 @@ import (
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
clientgoscheme "k8s.io/client-go/kubernetes/scheme"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/controller"
ctrlpkg "sigs.k8s.io/controller-runtime/pkg/controller"
"sigs.k8s.io/controller-runtime/pkg/healthz"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server"
"sigs.k8s.io/controller-runtime/pkg/webhook"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"

- kubeflowv1 "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
- kubeflowcontroller "github.com/kubeflow/trainer/pkg/controller"
+ trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
+ "github.com/kubeflow/trainer/pkg/controller"
"github.com/kubeflow/trainer/pkg/runtime"
runtimecore "github.com/kubeflow/trainer/pkg/runtime/core"
"github.com/kubeflow/trainer/pkg/util/cert"
@@ -56,7 +56,7 @@ var (

func init() {
utilruntime.Must(clientgoscheme.AddToScheme(scheme))
- utilruntime.Must(kubeflowv1.AddToScheme(scheme))
+ utilruntime.Must(trainer.AddToScheme(scheme))
utilruntime.Must(jobsetv1alpha2.AddToScheme(scheme))
utilruntime.Must(schedulerpluginsv1alpha1.AddToScheme(scheme))
}
@@ -160,7 +160,7 @@ func setupControllers(mgr ctrl.Manager, runtimes map[string]runtime.Runtime, cer
<-certsReady
setupLog.Info("Certs ready")

- if failedCtrlName, err := kubeflowcontroller.SetupControllers(mgr, runtimes, controller.Options{}); err != nil {
+ if failedCtrlName, err := controller.SetupControllers(mgr, runtimes, ctrlpkg.Options{}); err != nil {
setupLog.Error(err, "Could not create controller", "controller", failedCtrlName)
os.Exit(1)
}
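
One knock-on change in main.go: with the kubeflowcontroller alias gone, the project's own pkg/controller takes the bare identifier controller, so controller-runtime's package of the same name moves to a fresh alias, ctrlpkg. A short sketch of the disambiguated imports, assuming only what the diff itself shows:

package main

import (
	"fmt"

	"github.com/kubeflow/trainer/pkg/controller"            // bare name now belongs to this repo
	ctrlpkg "sigs.k8s.io/controller-runtime/pkg/controller" // same last path element, so it needs an alias
)

func main() {
	// ctrlpkg.Options configures controller-runtime controllers, while
	// controller.SetupControllers is this repo's own wiring entry point.
	fmt.Printf("%T %T\n", ctrlpkg.Options{}, controller.SetupControllers)
}
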
4 changes: 2 additions & 2 deletions hack/swagger/main.go
@@ -25,7 +25,7 @@ import (
"k8s.io/kube-openapi/pkg/common"
"k8s.io/kube-openapi/pkg/validation/spec"

- kubeflowv1 "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
+ trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
)

// Generate Kubeflow Training OpenAPI specification.
@@ -38,7 +38,7 @@ func main() {
return spec.MustCreateRef("#/definitions/" + common.EscapeJsonPointer(swaggify(name)))
}

- for k, v := range kubeflowv1.GetOpenAPIDefinitions(refCallback) {
+ for k, v := range trainer.GetOpenAPIDefinitions(refCallback) {
oAPIDefs[k] = v
}

44 changes: 22 additions & 22 deletions pkg/controller/trainjob_controller.go
@@ -35,7 +35,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client/apiutil"
"sigs.k8s.io/controller-runtime/pkg/controller"

- kubeflowv1 "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
+ trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
jobruntimes "github.com/kubeflow/trainer/pkg/runtime"
)

@@ -71,7 +71,7 @@ func NewTrainJobReconciler(client client.Client, recorder record.EventRecorder,
// +kubebuilder:rbac:groups=trainer.kubeflow.org,resources=trainjobs/finalizers,verbs=get;update;patch

func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
- var trainJob kubeflowv1.TrainJob
+ var trainJob trainer.TrainJob
if err := r.client.Get(ctx, req.NamespacedName, &trainJob); err != nil {
return ctrl.Result{}, client.IgnoreNotFound(err)
}
@@ -102,7 +102,7 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
return ctrl.Result{}, err
}

- func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobruntimes.Runtime, trainJob *kubeflowv1.TrainJob) (objsOpState, error) {
+ func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobruntimes.Runtime, trainJob *trainer.TrainJob) (objsOpState, error) {
log := ctrl.LoggerFrom(ctx)

objs, err := runtime.NewObjects(ctx, trainJob)
@@ -144,61 +144,61 @@ func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobru
return creationSucceeded, nil
}

- func setCreatedCondition(trainJob *kubeflowv1.TrainJob, opState objsOpState) {
+ func setCreatedCondition(trainJob *trainer.TrainJob, opState objsOpState) {
var newCond metav1.Condition
switch opState {
case creationSucceeded:
newCond = metav1.Condition{
- Type: kubeflowv1.TrainJobCreated,
+ Type: trainer.TrainJobCreated,
Status: metav1.ConditionTrue,
Message: constants.TrainJobJobsCreationSucceededMessage,
- Reason: kubeflowv1.TrainJobJobsCreationSucceededReason,
+ Reason: trainer.TrainJobJobsCreationSucceededReason,
}
case buildFailed:
newCond = metav1.Condition{
- Type: kubeflowv1.TrainJobCreated,
+ Type: trainer.TrainJobCreated,
Status: metav1.ConditionFalse,
Message: constants.TrainJobJobsBuildFailedMessage,
- Reason: kubeflowv1.TrainJobJobsBuildFailedReason,
+ Reason: trainer.TrainJobJobsBuildFailedReason,
}
// TODO (tenzen-y): Provide more granular message based on creation or update failure.
case creationFailed, updateFailed:
newCond = metav1.Condition{
- Type: kubeflowv1.TrainJobCreated,
+ Type: trainer.TrainJobCreated,
Status: metav1.ConditionFalse,
Message: constants.TrainJobJobsCreationFailedMessage,
- Reason: kubeflowv1.TrainJobJobsCreationFailedReason,
+ Reason: trainer.TrainJobJobsCreationFailedReason,
}
default:
return
}
meta.SetStatusCondition(&trainJob.Status.Conditions, newCond)
}

- func setSuspendedCondition(trainJob *kubeflowv1.TrainJob) {
+ func setSuspendedCondition(trainJob *trainer.TrainJob) {
var newCond metav1.Condition
switch {
case ptr.Deref(trainJob.Spec.Suspend, false):
newCond = metav1.Condition{
- Type: kubeflowv1.TrainJobSuspended,
+ Type: trainer.TrainJobSuspended,
Status: metav1.ConditionTrue,
Message: constants.TrainJobSuspendedMessage,
- Reason: kubeflowv1.TrainJobSuspendedReason,
+ Reason: trainer.TrainJobSuspendedReason,
}
- case meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv1.TrainJobSuspended):
+ case meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobSuspended):
newCond = metav1.Condition{
- Type: kubeflowv1.TrainJobSuspended,
+ Type: trainer.TrainJobSuspended,
Status: metav1.ConditionFalse,
Message: constants.TrainJobResumedMessage,
- Reason: kubeflowv1.TrainJobResumedReason,
+ Reason: trainer.TrainJobResumedReason,
}
default:
return
}
meta.SetStatusCondition(&trainJob.Status.Conditions, newCond)
}

- func setTerminalCondition(ctx context.Context, runtime jobruntimes.Runtime, trainJob *kubeflowv1.TrainJob) error {
+ func setTerminalCondition(ctx context.Context, runtime jobruntimes.Runtime, trainJob *trainer.TrainJob) error {
terminalCond, err := runtime.TerminalCondition(ctx, trainJob)
if err != nil {
return err
@@ -209,12 +209,12 @@ func setTerminalCondition(ctx context.Context, runtime jobruntimes.Runtime, trai
return nil
}

- func isTrainJobFinished(trainJob *kubeflowv1.TrainJob) bool {
- return meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv1.TrainJobComplete) ||
- meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv1.TrainJobFailed)
+ func isTrainJobFinished(trainJob *trainer.TrainJob) bool {
+ return meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobComplete) ||
+ meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobFailed)
}

- func runtimeRefToGroupKind(runtimeRef kubeflowv1.RuntimeRef) schema.GroupKind {
+ func runtimeRefToGroupKind(runtimeRef trainer.RuntimeRef) schema.GroupKind {
return schema.GroupKind{
Group: ptr.Deref(runtimeRef.APIGroup, ""),
Kind: ptr.Deref(runtimeRef.Kind, ""),
@@ -224,7 +224,7 @@ func runtimeRefToGroupKind(runtimeRef kubeflowv1.RuntimeRef) schema.GroupKind {
func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, options controller.Options) error {
b := ctrl.NewControllerManagedBy(mgr).
WithOptions(options).
- For(&kubeflowv1.TrainJob{})
+ For(&trainer.TrainJob{})
for _, runtime := range r.runtimes {
for _, registrar := range runtime.EventHandlerRegistrars() {
if registrar != nil {
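
The condition helpers above all funnel through apimachinery's meta.SetStatusCondition, which upserts by Type and refreshes LastTransitionTime only when Status actually changes. A self-contained sketch of the same pattern with the renamed constants (field values here are illustrative):

package main

import (
	"fmt"

	"k8s.io/apimachinery/pkg/api/meta"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

	trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
)

func main() {
	var conditions []metav1.Condition
	// Upsert keyed on Type, as setCreatedCondition does for creationSucceeded.
	meta.SetStatusCondition(&conditions, metav1.Condition{
		Type:   trainer.TrainJobCreated,
		Status: metav1.ConditionTrue,
		Reason: trainer.TrainJobJobsCreationSucceededReason,
	})
	fmt.Println(meta.IsStatusConditionTrue(conditions, trainer.TrainJobCreated)) // true
}
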
16 changes: 8 additions & 8 deletions pkg/runtime/core/clustertrainingruntime.go
@@ -27,7 +27,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

- kubeflowv1 "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
+ trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/pkg/runtime"
)

@@ -42,8 +42,8 @@ type ClusterTrainingRuntime struct {
var _ runtime.Runtime = (*ClusterTrainingRuntime)(nil)

var ClusterTrainingRuntimeGroupKind = schema.GroupKind{
- Group: kubeflowv1.GroupVersion.Group,
- Kind: kubeflowv1.ClusterTrainingRuntimeKind,
+ Group: trainer.GroupVersion.Group,
+ Kind: trainer.ClusterTrainingRuntimeKind,
}.String()

func NewClusterTrainingRuntime(context.Context, client.Client, client.FieldIndexer) (runtime.Runtime, error) {
Expand All @@ -52,27 +52,27 @@ func NewClusterTrainingRuntime(context.Context, client.Client, client.FieldIndex
}, nil
}

- func (r *ClusterTrainingRuntime) NewObjects(ctx context.Context, trainJob *kubeflowv1.TrainJob) ([]client.Object, error) {
- var clTrainingRuntime kubeflowv1.ClusterTrainingRuntime
+ func (r *ClusterTrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]client.Object, error) {
+ var clTrainingRuntime trainer.ClusterTrainingRuntime
if err := r.client.Get(ctx, client.ObjectKey{Name: trainJob.Spec.RuntimeRef.Name}, &clTrainingRuntime); err != nil {
return nil, fmt.Errorf("%w: %w", errorNotFoundSpecifiedClusterTrainingRuntime, err)
}
return r.buildObjects(ctx, trainJob, clTrainingRuntime.Spec.Template, clTrainingRuntime.Spec.MLPolicy, clTrainingRuntime.Spec.PodGroupPolicy)
}

- func (r *ClusterTrainingRuntime) TerminalCondition(ctx context.Context, trainJob *kubeflowv1.TrainJob) (*metav1.Condition, error) {
+ func (r *ClusterTrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) {
return r.TrainingRuntime.TerminalCondition(ctx, trainJob)
}

func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
return nil
}

- func (r *ClusterTrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv1.TrainJob) (admission.Warnings, field.ErrorList) {
+ func (r *ClusterTrainingRuntime) ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
if err := r.client.Get(ctx, client.ObjectKey{
Namespace: old.Namespace,
Name: old.Spec.RuntimeRef.Name,
- }, &kubeflowv1.ClusterTrainingRuntime{}); err != nil {
+ }, &trainer.ClusterTrainingRuntime{}); err != nil {
return nil, field.ErrorList{
field.Invalid(field.NewPath("spec", "RuntimeRef"), old.Spec.RuntimeRef,
fmt.Sprintf("%v: specified clusterTrainingRuntime must be created before the TrainJob is created", err)),
14 changes: 7 additions & 7 deletions pkg/runtime/core/clustertrainingruntime_test.go
@@ -28,7 +28,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"

- kubeflowv1 "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
+ trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
testingutil "github.com/kubeflow/trainer/pkg/util/testing"
)

@@ -39,8 +39,8 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
}

cases := map[string]struct {
- trainJob *kubeflowv1.TrainJob
- clusterTrainingRuntime *kubeflowv1.ClusterTrainingRuntime
+ trainJob *trainer.TrainJob
+ clusterTrainingRuntime *trainer.ClusterTrainingRuntime
wantObjs []client.Object
wantError error
}{
@@ -56,7 +56,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
Suspend(true).
UID("uid").
- RuntimeRef(kubeflowv1.SchemeGroupVersion.WithKind(kubeflowv1.ClusterTrainingRuntimeKind), "test-runtime").
+ RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), "test-runtime").
Trainer(
testingutil.MakeTrainJobTrainerWrapper().
Container("test:trainjob", []string{"trainjob"}, []string{"trainjob"}, resRequests).
@@ -70,10 +70,10 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
ContainerTrainer("test:trainjob", []string{"trainjob"}, []string{"trainjob"}, resRequests).
Suspend(true).
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
- ControllerReference(kubeflowv1.SchemeGroupVersion.WithKind(kubeflowv1.TrainJobKind), "test-job", "uid").
+ ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
Obj(),
testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job").
- ControllerReference(kubeflowv1.SchemeGroupVersion.WithKind(kubeflowv1.TrainJobKind), "test-job", "uid").
+ ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
MinMember(101). // 101 replicas = 100 Trainer nodes + 1 Initializer.
MinResources(corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("101"), // Every replica has 1 CPU = 101 CPUs in total.
@@ -85,7 +85,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
"missing trainingRuntime resource": {
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
UID("uid").
- RuntimeRef(kubeflowv1.SchemeGroupVersion.WithKind(kubeflowv1.ClusterTrainingRuntimeKind), "test-runtime").
+ RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), "test-runtime").
Trainer(
testingutil.MakeTrainJobTrainerWrapper().
Obj(),
22 changes: 11 additions & 11 deletions pkg/runtime/core/trainingruntime.go
@@ -28,7 +28,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

- kubeflowv1 "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
+ trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/pkg/runtime"
fwkcore "github.com/kubeflow/trainer/pkg/runtime/framework/core"
fwkplugins "github.com/kubeflow/trainer/pkg/runtime/framework/plugins"
@@ -45,19 +45,19 @@ type TrainingRuntime struct {
}

var TrainingRuntimeGroupKind = schema.GroupKind{
- Group: kubeflowv1.GroupVersion.Group,
- Kind: kubeflowv1.TrainingRuntimeKind,
+ Group: trainer.GroupVersion.Group,
+ Kind: trainer.TrainingRuntimeKind,
}.String()

var _ runtime.Runtime = (*TrainingRuntime)(nil)

var trainingRuntimeFactory *TrainingRuntime

func NewTrainingRuntime(ctx context.Context, c client.Client, indexer client.FieldIndexer) (runtime.Runtime, error) {
- if err := indexer.IndexField(ctx, &kubeflowv1.TrainJob{}, idxer.TrainJobRuntimeRefKey, idxer.IndexTrainJobTrainingRuntime); err != nil {
+ if err := indexer.IndexField(ctx, &trainer.TrainJob{}, idxer.TrainJobRuntimeRefKey, idxer.IndexTrainJobTrainingRuntime); err != nil {
return nil, fmt.Errorf("setting index on TrainingRuntime for TrainJob: %w", err)
}
- if err := indexer.IndexField(ctx, &kubeflowv1.TrainJob{}, idxer.TrainJobClusterRuntimeRefKey, idxer.IndexTrainJobClusterTrainingRuntime); err != nil {
+ if err := indexer.IndexField(ctx, &trainer.TrainJob{}, idxer.TrainJobClusterRuntimeRefKey, idxer.IndexTrainJobClusterTrainingRuntime); err != nil {
return nil, fmt.Errorf("setting index on ClusterTrainingRuntime for TrainJob: %w", err)
}
fwk, err := fwkcore.New(ctx, c, fwkplugins.NewRegistry(), indexer)
@@ -71,8 +71,8 @@ func NewTrainingRuntime(ctx context.Context, c client.Client, indexer client.Fie
return trainingRuntimeFactory, nil
}

- func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *kubeflowv1.TrainJob) ([]client.Object, error) {
- var trainingRuntime kubeflowv1.TrainingRuntime
+ func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]client.Object, error) {
+ var trainingRuntime trainer.TrainingRuntime
err := r.client.Get(ctx, client.ObjectKey{Namespace: trainJob.Namespace, Name: trainJob.Spec.RuntimeRef.Name}, &trainingRuntime)
if err != nil {
return nil, fmt.Errorf("%w: %w", errorNotFoundSpecifiedTrainingRuntime, err)
@@ -81,7 +81,7 @@ func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *kubeflowv1.T
}

func (r *TrainingRuntime) buildObjects(
- ctx context.Context, trainJob *kubeflowv1.TrainJob, jobSetTemplateSpec kubeflowv1.JobSetTemplateSpec, mlPolicy *kubeflowv1.MLPolicy, podGroupPolicy *kubeflowv1.PodGroupPolicy,
+ ctx context.Context, trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy,
) ([]client.Object, error) {
propagationLabels := jobSetTemplateSpec.Labels
if propagationLabels == nil && trainJob.Spec.Labels != nil {
@@ -128,7 +128,7 @@ func (r *TrainingRuntime) buildObjects(
return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob)
}

- func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *kubeflowv1.TrainJob) (*metav1.Condition, error) {
+ func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) {
return r.framework.RunTerminalConditionPlugins(ctx, trainJob)
}

@@ -140,11 +140,11 @@ func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
return builders
}

- func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv1.TrainJob) (admission.Warnings, field.ErrorList) {
+ func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
if err := r.client.Get(ctx, client.ObjectKey{
Namespace: old.Namespace,
Name: old.Spec.RuntimeRef.Name,
- }, &kubeflowv1.TrainingRuntime{}); err != nil {
+ }, &trainer.TrainingRuntime{}); err != nil {
return nil, field.ErrorList{
field.Invalid(field.NewPath("spec", "runtimeRef"), old.Spec.RuntimeRef,
fmt.Sprintf("%v: specified trainingRuntime must be created before the TrainJob is created", err)),
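
NewTrainingRuntime also registers field indexes so TrainJobs can be looked up by the runtime they reference. A sketch of how such an index is typically consumed with controller-runtime's client; the idxer import path and the TrainJobList type are assumptions, since neither appears in this diff:

package sketch

import (
	"context"

	"sigs.k8s.io/controller-runtime/pkg/client"

	trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
	idxer "github.com/kubeflow/trainer/pkg/runtime/indexer" // assumed path
)

// listJobsForRuntime lists the TrainJobs whose runtimeRef names the given
// TrainingRuntime, using the field index registered in NewTrainingRuntime.
func listJobsForRuntime(ctx context.Context, c client.Client, name string) (*trainer.TrainJobList, error) {
	var jobs trainer.TrainJobList // assumed generated List type
	if err := c.List(ctx, &jobs, client.MatchingFields{idxer.TrainJobRuntimeRefKey: name}); err != nil {
		return nil, err
	}
	return &jobs, nil
}
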
… (diff for the remaining 24 of the 30 changed files not shown)
