diff --git a/pkg/apis/internal.admin.acorn.io/v1/default.go b/pkg/apis/internal.admin.acorn.io/v1/default.go index 90b803fe3..afc27a444 100644 --- a/pkg/apis/internal.admin.acorn.io/v1/default.go +++ b/pkg/apis/internal.admin.acorn.io/v1/default.go @@ -10,8 +10,8 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" ) -func getCurrentClusterComputeClassDefault(ctx context.Context, c client.Client, projectDefaultComputeClass string) (*ClusterComputeClassInstance, error) { - clusterComputeClasses := ClusterComputeClassInstanceList{} +func getCurrentClusterComputeClassDefault(ctx context.Context, c client.Client, projectDefault string) (*ClusterComputeClassInstance, error) { + var clusterComputeClasses ClusterComputeClassInstanceList if err := c.List(ctx, &clusterComputeClasses, &client.ListOptions{}); err != nil { return nil, err } @@ -30,12 +30,10 @@ func getCurrentClusterComputeClassDefault(ctx context.Context, c client.Client, } // Create a new variable that isn't being iterated on to get a pointer - if projectDefaultComputeClass != "" { - defaultCCC = z.Pointer(clusterComputeClass) - } + defaultCCC = z.Pointer(clusterComputeClass) } - if clusterComputeClass.Name == projectDefaultComputeClass { + if clusterComputeClass.Name == projectDefault { projectDefaultCCC = z.Pointer(clusterComputeClass) } } @@ -47,8 +45,8 @@ func getCurrentClusterComputeClassDefault(ctx context.Context, c client.Client, return defaultCCC, nil } -func getCurrentProjectComputeClassDefault(ctx context.Context, c client.Client, projectDefaultComputeClass, namespace string) (*ProjectComputeClassInstance, error) { - projectComputeClasses := ProjectComputeClassInstanceList{} +func getCurrentProjectComputeClassDefault(ctx context.Context, c client.Client, projectDefault, namespace string) (*ProjectComputeClassInstance, error) { + var projectComputeClasses ProjectComputeClassInstanceList if err := c.List(ctx, &projectComputeClasses, &client.ListOptions{Namespace: namespace}); err != nil { return nil, err } @@ -67,12 +65,10 @@ func getCurrentProjectComputeClassDefault(ctx context.Context, c client.Client, } // Create a new variable that isn't being iterated on to get a pointer - if projectDefaultComputeClass != "" { - defaultPCC = z.Pointer(projectComputeClass) - } + defaultPCC = z.Pointer(projectComputeClass) } - if projectComputeClass.Name == projectDefaultComputeClass { + if projectDefault projectComputeClass.Name == projectDefault { projectDefaultPCC = z.Pointer(projectComputeClass) } } @@ -94,15 +90,18 @@ func GetDefaultComputeClass(ctx context.Context, c client.Client, namespace stri pcc, err := getCurrentProjectComputeClassDefault(ctx, c, projectDefault, namespace) if err != nil { return "", err - } else if pcc != nil { + } + if pcc != nil { return pcc.Name, nil } ccc, err := getCurrentClusterComputeClassDefault(ctx, c, projectDefault) if err != nil { return "", err - } else if ccc != nil { + } + if ccc != nil { return ccc.Name, nil } + return "", nil } diff --git a/pkg/server/registry/apigroups/acorn/projects/validator.go b/pkg/server/registry/apigroups/acorn/projects/validator.go index 8a7021fe8..e0b2da2e5 100644 --- a/pkg/server/registry/apigroups/acorn/projects/validator.go +++ b/pkg/server/registry/apigroups/acorn/projects/validator.go @@ -6,6 +6,8 @@ import ( "strings" apiv1 "github.com/acorn-io/runtime/pkg/apis/api.acorn.io/v1" + "github.com/acorn-io/runtime/pkg/computeclasses" + apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/validation/field" @@ -22,15 +24,28 @@ type Validator struct { Client kclient.Client } -func (v *Validator) Validate(_ context.Context, obj runtime.Object) field.ErrorList { +func (v *Validator) Validate(ctx context.Context, obj runtime.Object) field.ErrorList { var result field.ErrorList project := obj.(*apiv1.Project) - if project.Spec.DefaultRegion != "" && !slices.Contains(project.Spec.SupportedRegions, project.Spec.DefaultRegion) && !slices.Contains(project.Spec.SupportedRegions, apiv1.AllRegions) { - return append(result, field.Invalid(field.NewPath("spec", "defaultRegion"), project.Spec.DefaultRegion, "default region is not in the supported regions list")) + if project.Spec.DefaultRegion != "" && + !slices.Contains(project.Spec.SupportedRegions, project.Spec.DefaultRegion) && + !slices.Contains(project.Spec.SupportedRegions, apiv1.AllRegions) { + result = append(result, field.Invalid(field.NewPath("spec", "defaultRegion"), project.Spec.DefaultRegion, "default region is not in the supported regions list")) } - return nil + if defaultComputeClass := project.Spec.DefaultComputeClass; defaultComputeClass != "" { + if _, err := computeclasses.GetAsProjectComputeClassInstance(ctx, v.Client, project.Name, defaultComputeClass); apierrors.IsNotFound(err) { + // The compute class does not exist, return an invalid error + result = append(result, field.Invalid(field.NewPath("spec", "defaultComputeClass"), defaultComputeClass, "default compute class does not exist")) + } else if err != nil { + // Some other error occurred while trying to get the compute class, return an internal error. + result = append(result, field.InternalError(field.NewPath("spec", "defaultComputeClass"), err)) + } + // TODO(njhale): Validate that the compute class shares the project's supported regions? + } + + return result } func (v *Validator) ValidateUpdate(ctx context.Context, newObj, _ runtime.Object) field.ErrorList {