[RayJob] use ValidateRayClusterSpec instead

Signed-off-by: fscnick <[email protected]>
fscnick committed Feb 24, 2025
1 parent a364c2d commit 864bb09
Showing 7 changed files with 69 additions and 85 deletions.
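At a glance: both helpers now operate on the shared `RayClusterSpec` instead of whole CRD objects, so RayCluster, RayJob, and RayService go through one code path. A sketch of the signature changes, copied from the diffs below (the comments are editorial summary, not part of the commit):

```go
// Before: generic over the three CRD pointer types; panicked on anything else.
func IsAutoscalingEnabled[T *rayv1.RayCluster | *rayv1.RayJob | *rayv1.RayService](obj T) bool
// After: nil-safe check on the embedded spec.
func IsAutoscalingEnabled(spec *rayv1.RayClusterSpec) bool

// Before: RayCluster-only; RayJob and RayService called ValidateGCSFaultTolerance separately.
func ValidateRayClusterSpec(instance *rayv1.RayCluster) error
// After: takes the spec plus the owning object's annotations; ValidateGCSFaultTolerance is inlined and deleted.
func ValidateRayClusterSpec(spec *rayv1.RayClusterSpec, annotations map[string]string) error
```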
2 changes: 1 addition & 1 deletion ray-operator/controllers/ray/batchscheduler/volcano/volcano_scheduler.go
@@ -51,7 +51,7 @@ func (v *VolcanoBatchScheduler) Name() string {
 func (v *VolcanoBatchScheduler) DoBatchSchedulingOnSubmission(ctx context.Context, app *rayv1.RayCluster) error {
     var minMember int32
     var totalResource corev1.ResourceList
-    if !utils.IsAutoscalingEnabled(app) {
+    if !utils.IsAutoscalingEnabled(&app.Spec) {
         minMember = utils.CalculateDesiredReplicas(ctx, app) + 1
         totalResource = utils.CalculateDesiredResources(app)
     } else {
2 changes: 1 addition & 1 deletion ray-operator/controllers/ray/common/pod.go
@@ -179,7 +179,7 @@ func DefaultHeadPodTemplate(ctx context.Context, instance rayv1.RayCluster, head
     initTemplateAnnotations(instance, &podTemplate)

     // if in-tree autoscaling is enabled, then autoscaler container should be injected into head pod.
-    if utils.IsAutoscalingEnabled(&instance) {
+    if utils.IsAutoscalingEnabled(&instance.Spec) {
         // The default autoscaler is not compatible with Kubernetes. As a result, we disable
         // the monitor process by default and inject a KubeRay autoscaler side container into the head pod.
         headSpec.RayStartParams["no-monitor"] = "true"
14 changes: 7 additions & 7 deletions ray-operator/controllers/ray/raycluster_controller.go
@@ -220,7 +220,7 @@ func (r *RayClusterReconciler) rayClusterReconcile(ctx context.Context, instance
         return ctrl.Result{}, nil
     }

-    if err := utils.ValidateRayClusterSpec(instance); err != nil {
+    if err := utils.ValidateRayClusterSpec(&instance.Spec, instance.Annotations); err != nil {
         logger.Error(err, fmt.Sprintf("The RayCluster spec is invalid %s/%s", instance.Namespace, instance.Name))
         r.Recorder.Eventf(instance, corev1.EventTypeWarning, string(utils.InvalidRayClusterSpec),
             "The RayCluster spec is invalid %s/%s: %v", instance.Namespace, instance.Name, err)
@@ -859,7 +859,7 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv
     // diff < 0 indicates the need to delete some Pods to match the desired number of replicas. However,
     // randomly deleting Pods is certainly not ideal. So, if autoscaling is enabled for the cluster, we
     // will disable random Pod deletion, making Autoscaler the sole decision-maker for Pod deletions.
-    enableInTreeAutoscaling := utils.IsAutoscalingEnabled(instance)
+    enableInTreeAutoscaling := utils.IsAutoscalingEnabled(&instance.Spec)

     // TODO (kevin85421): `enableRandomPodDelete` is a feature flag for KubeRay v0.6.0. If users want to use
     // the old behavior, they can set the environment variable `ENABLE_RANDOM_POD_DELETE` to `true`. When the
@@ -1090,7 +1090,7 @@ func (r *RayClusterReconciler) buildHeadPod(ctx context.Context, instance rayv1.
     fqdnRayIP := utils.GenerateFQDNServiceName(ctx, instance, instance.Namespace) // Fully Qualified Domain Name
     // The Ray head port used by workers to connect to the cluster (GCS server port for Ray >= 1.11.0, Redis port for older Ray.)
     headPort := common.GetHeadPort(instance.Spec.HeadGroupSpec.RayStartParams)
-    autoscalingEnabled := utils.IsAutoscalingEnabled(&instance)
+    autoscalingEnabled := utils.IsAutoscalingEnabled(&instance.Spec)
     podConf := common.DefaultHeadPodTemplate(ctx, instance, instance.Spec.HeadGroupSpec, podName, headPort)
     if len(r.headSidecarContainers) > 0 {
         podConf.Spec.Containers = append(podConf.Spec.Containers, r.headSidecarContainers...)
@@ -1118,7 +1118,7 @@ func (r *RayClusterReconciler) buildWorkerPod(ctx context.Context, instance rayv

     // The Ray head port used by workers to connect to the cluster (GCS server port for Ray >= 1.11.0, Redis port for older Ray.)
     headPort := common.GetHeadPort(instance.Spec.HeadGroupSpec.RayStartParams)
-    autoscalingEnabled := utils.IsAutoscalingEnabled(&instance)
+    autoscalingEnabled := utils.IsAutoscalingEnabled(&instance.Spec)
     podTemplateSpec := common.DefaultWorkerPodTemplate(ctx, instance, worker, podName, fqdnRayIP, headPort)
     if len(r.workerSidecarContainers) > 0 {
         podTemplateSpec.Spec.Containers = append(podTemplateSpec.Spec.Containers, r.workerSidecarContainers...)
@@ -1504,7 +1504,7 @@ func (r *RayClusterReconciler) updateHeadInfo(ctx context.Context, instance *ray

 func (r *RayClusterReconciler) reconcileAutoscalerServiceAccount(ctx context.Context, instance *rayv1.RayCluster) error {
     logger := ctrl.LoggerFrom(ctx)
-    if !utils.IsAutoscalingEnabled(instance) {
+    if !utils.IsAutoscalingEnabled(&instance.Spec) {
         return nil
     }

@@ -1561,7 +1561,7 @@ func (r *RayClusterReconciler) reconcileAutoscalerServiceAccount(ctx context.Con

 func (r *RayClusterReconciler) reconcileAutoscalerRole(ctx context.Context, instance *rayv1.RayCluster) error {
     logger := ctrl.LoggerFrom(ctx)
-    if !utils.IsAutoscalingEnabled(instance) {
+    if !utils.IsAutoscalingEnabled(&instance.Spec) {
         return nil
     }

@@ -1603,7 +1603,7 @@ func (r *RayClusterReconciler) reconcileAutoscalerRole(ctx context.Context, inst

 func (r *RayClusterReconciler) reconcileAutoscalerRoleBinding(ctx context.Context, instance *rayv1.RayCluster) error {
     logger := ctrl.LoggerFrom(ctx)
-    if !utils.IsAutoscalingEnabled(instance) {
+    if !utils.IsAutoscalingEnabled(&instance.Spec) {
         return nil
     }

14 changes: 3 additions & 11 deletions ray-operator/controllers/ray/utils/util.go
@@ -630,17 +630,9 @@ func ManagedByExternalController(controllerName *string) *string {
     return nil
 }

-func IsAutoscalingEnabled[T *rayv1.RayCluster | *rayv1.RayJob | *rayv1.RayService](obj T) bool {
-    switch obj := (interface{})(obj).(type) {
-    case *rayv1.RayCluster:
-        return obj.Spec.EnableInTreeAutoscaling != nil && *obj.Spec.EnableInTreeAutoscaling
-    case *rayv1.RayJob:
-        return obj.Spec.RayClusterSpec != nil && obj.Spec.RayClusterSpec.EnableInTreeAutoscaling != nil && *obj.Spec.RayClusterSpec.EnableInTreeAutoscaling
-    case *rayv1.RayService:
-        return obj.Spec.RayClusterSpec.EnableInTreeAutoscaling != nil && *obj.Spec.RayClusterSpec.EnableInTreeAutoscaling
-    default:
-        panic(fmt.Sprintf("unsupported type: %T", obj))
-    }
+func IsAutoscalingEnabled(spec *rayv1.RayClusterSpec) bool {
+    return spec != nil && spec.EnableInTreeAutoscaling != nil &&
+        *spec.EnableInTreeAutoscaling
 }

 // Check if the RayCluster has GCS fault tolerance enabled.
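Because the new helper checks `spec != nil` itself, RayJob call sites can pass their possibly-nil `RayClusterSpec` pointer straight through, mirroring the updated tests below. A runnable sketch with hypothetical local variables, assuming the KubeRay module import paths:

```go
package main

import (
	"fmt"

	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
	"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
	"k8s.io/utils/ptr"
)

func main() {
	// RayCluster: the spec is embedded by value, so callers take its address.
	cluster := &rayv1.RayCluster{
		Spec: rayv1.RayClusterSpec{EnableInTreeAutoscaling: ptr.To(true)},
	}
	fmt.Println(utils.IsAutoscalingEnabled(&cluster.Spec)) // true

	// RayJob: RayClusterSpec is already a pointer and may be nil; the nil
	// check inside the helper replaces the old per-type switch (and the
	// panic on unsupported types).
	job := &rayv1.RayJob{}
	fmt.Println(utils.IsAutoscalingEnabled(job.Spec.RayClusterSpec)) // false

	// RayService: RayClusterSpec is embedded by value, like RayCluster.
	service := &rayv1.RayService{}
	fmt.Println(utils.IsAutoscalingEnabled(&service.Spec.RayClusterSpec)) // false
}
```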
12 changes: 6 additions & 6 deletions ray-operator/controllers/ray/utils/util_test.go
@@ -739,18 +739,18 @@ func TestErrRayClusterReplicaFailureReason(t *testing.T) {
 func TestIsAutoscalingEnabled(t *testing.T) {
     // Test: RayCluster
     cluster := &rayv1.RayCluster{}
-    assert.False(t, IsAutoscalingEnabled(cluster))
+    assert.False(t, IsAutoscalingEnabled(&cluster.Spec))

     cluster = &rayv1.RayCluster{
         Spec: rayv1.RayClusterSpec{
             EnableInTreeAutoscaling: ptr.To[bool](true),
         },
     }
-    assert.True(t, IsAutoscalingEnabled(cluster))
+    assert.True(t, IsAutoscalingEnabled(&cluster.Spec))

     // Test: RayJob
     job := &rayv1.RayJob{}
-    assert.False(t, IsAutoscalingEnabled(job))
+    assert.False(t, IsAutoscalingEnabled(job.Spec.RayClusterSpec))

     job = &rayv1.RayJob{
         Spec: rayv1.RayJobSpec{
@@ -759,11 +759,11 @@ func TestIsAutoscalingEnabled(t *testing.T) {
             },
         },
     }
-    assert.True(t, IsAutoscalingEnabled(job))
+    assert.True(t, IsAutoscalingEnabled(job.Spec.RayClusterSpec))

     // Test: RayService
     service := &rayv1.RayService{}
-    assert.False(t, IsAutoscalingEnabled(service))
+    assert.False(t, IsAutoscalingEnabled(&service.Spec.RayClusterSpec))

     service = &rayv1.RayService{
         Spec: rayv1.RayServiceSpec{
@@ -772,7 +772,7 @@ func TestIsAutoscalingEnabled(t *testing.T) {
             },
         },
     }
-    assert.True(t, IsAutoscalingEnabled(service))
+    assert.True(t, IsAutoscalingEnabled(&service.Spec.RayClusterSpec))
 }

 func TestIsGCSFaultToleranceEnabled(t *testing.T) {
100 changes: 46 additions & 54 deletions ray-operator/controllers/ray/utils/validation.go
@@ -20,31 +20,67 @@ func ValidateRayClusterStatus(instance *rayv1.RayCluster) error {
 }

 // Validation for invalid Ray Cluster configurations.
-func ValidateRayClusterSpec(instance *rayv1.RayCluster) error {
-    if len(instance.Spec.HeadGroupSpec.Template.Spec.Containers) == 0 {
+func ValidateRayClusterSpec(spec *rayv1.RayClusterSpec, annotations map[string]string) error {
+    if len(spec.HeadGroupSpec.Template.Spec.Containers) == 0 {
         return fmt.Errorf("headGroupSpec should have at least one container")
     }

-    for _, workerGroup := range instance.Spec.WorkerGroupSpecs {
+    for _, workerGroup := range spec.WorkerGroupSpecs {
         if len(workerGroup.Template.Spec.Containers) == 0 {
             return fmt.Errorf("workerGroupSpec should have at least one container")
         }
     }

-    if err := ValidateGCSFaultTolerance(&instance.Spec, instance.Annotations); err != nil {
-        return err
+    if annotations[RayFTEnabledAnnotationKey] != "" && spec.GcsFaultToleranceOptions != nil {
+        return fmt.Errorf("%s annotation and GcsFaultToleranceOptions are both set. "+
+            "Please use only GcsFaultToleranceOptions to configure GCS fault tolerance", RayFTEnabledAnnotationKey)
     }
+
+    if !IsGCSFaultToleranceEnabled(spec, annotations) {
+        if EnvVarExists(RAY_REDIS_ADDRESS, spec.HeadGroupSpec.Template.Spec.Containers[RayContainerIndex].Env) {
+            return fmt.Errorf("%s is set which implicitly enables GCS fault tolerance, "+
+                "but GcsFaultToleranceOptions is not set. Please set GcsFaultToleranceOptions "+
+                "to enable GCS fault tolerance", RAY_REDIS_ADDRESS)
+        }
+    }
+
+    headContainer := spec.HeadGroupSpec.Template.Spec.Containers[RayContainerIndex]
+    if spec.GcsFaultToleranceOptions != nil {
+        if redisPassword := spec.HeadGroupSpec.RayStartParams["redis-password"]; redisPassword != "" {
+            return fmt.Errorf("cannot set `redis-password` in rayStartParams when " +
+                "GcsFaultToleranceOptions is enabled - use GcsFaultToleranceOptions.RedisPassword instead")
+        }
+
+        if EnvVarExists(REDIS_PASSWORD, headContainer.Env) {
+            return fmt.Errorf("cannot set `REDIS_PASSWORD` env var in head Pod when " +
+                "GcsFaultToleranceOptions is enabled - use GcsFaultToleranceOptions.RedisPassword instead")
+        }
+
+        if EnvVarExists(RAY_REDIS_ADDRESS, headContainer.Env) {
+            return fmt.Errorf("cannot set `RAY_REDIS_ADDRESS` env var in head Pod when " +
+                "GcsFaultToleranceOptions is enabled - use GcsFaultToleranceOptions.RedisAddress instead")
+        }
+
+        if annotations[RayExternalStorageNSAnnotationKey] != "" {
+            return fmt.Errorf("cannot set `ray.io/external-storage-namespace` annotation when " +
+                "GcsFaultToleranceOptions is enabled - use GcsFaultToleranceOptions.ExternalStorageNamespace instead")
+        }
+    }
+    if spec.HeadGroupSpec.RayStartParams["redis-username"] != "" || EnvVarExists(REDIS_USERNAME, headContainer.Env) {
+        return fmt.Errorf("cannot set redis username in rayStartParams or environment variables" +
+            " - use GcsFaultToleranceOptions.RedisUsername instead")
+    }

     if !features.Enabled(features.RayJobDeletionPolicy) {
-        for _, workerGroup := range instance.Spec.WorkerGroupSpecs {
+        for _, workerGroup := range spec.WorkerGroupSpecs {
             if workerGroup.Suspend != nil && *workerGroup.Suspend {
                 return fmt.Errorf("suspending worker groups is currently available when the RayJobDeletionPolicy feature gate is enabled")
             }
         }
     }

-    if IsAutoscalingEnabled(instance) {
-        for _, workerGroup := range instance.Spec.WorkerGroupSpecs {
+    if IsAutoscalingEnabled(spec) {
+        for _, workerGroup := range spec.WorkerGroupSpecs {
             if workerGroup.Suspend != nil && *workerGroup.Suspend {
                 // TODO (rueian): This can be supported in future Ray. We should check the RayVersion once we know the version.
                 return fmt.Errorf("suspending worker groups is not currently supported with Autoscaler enabled")
@@ -78,7 +114,7 @@ func ValidateRayJobSpec(rayJob *rayv1.RayJob) error {
     }

     if rayJob.Spec.RayClusterSpec != nil {
-        if err := ValidateGCSFaultTolerance(rayJob.Spec.RayClusterSpec, rayJob.Annotations); err != nil {
+        if err := ValidateRayClusterSpec(rayJob.Spec.RayClusterSpec, rayJob.Annotations); err != nil {
             return err
         }
     }
@@ -109,7 +145,7 @@ func ValidateRayJobSpec(rayJob *rayv1.RayJob) error {
         }
     }

-    if policy == rayv1.DeleteWorkersDeletionPolicy && IsAutoscalingEnabled(rayJob) {
+    if policy == rayv1.DeleteWorkersDeletionPolicy && IsAutoscalingEnabled(rayJob.Spec.RayClusterSpec) {
         // TODO (rueian): This can be supported in a future Ray version. We should check the RayVersion once we know it.
         return fmt.Errorf("DeletionPolicy=DeleteWorkers currently does not support RayCluster with autoscaling enabled")
     }
@@ -135,47 +171,3 @@ func ValidateRayServiceSpec(rayService *rayv1.RayService) error {
     }
     return nil
 }
-
-func ValidateGCSFaultTolerance(spec *rayv1.RayClusterSpec, annotations map[string]string) error {
-    if annotations[RayFTEnabledAnnotationKey] != "" && spec.GcsFaultToleranceOptions != nil {
-        return fmt.Errorf("%s annotation and GcsFaultToleranceOptions are both set. "+
-            "Please use only GcsFaultToleranceOptions to configure GCS fault tolerance", RayFTEnabledAnnotationKey)
-    }
-
-    if !IsGCSFaultToleranceEnabled(spec, annotations) {
-        if EnvVarExists(RAY_REDIS_ADDRESS, spec.HeadGroupSpec.Template.Spec.Containers[RayContainerIndex].Env) {
-            return fmt.Errorf("%s is set which implicitly enables GCS fault tolerance, "+
-                "but GcsFaultToleranceOptions is not set. Please set GcsFaultToleranceOptions "+
-                "to enable GCS fault tolerance", RAY_REDIS_ADDRESS)
-        }
-    }
-
-    headContainer := spec.HeadGroupSpec.Template.Spec.Containers[RayContainerIndex]
-    if spec.GcsFaultToleranceOptions != nil {
-        if redisPassword := spec.HeadGroupSpec.RayStartParams["redis-password"]; redisPassword != "" {
-            return fmt.Errorf("cannot set `redis-password` in rayStartParams when " +
-                "GcsFaultToleranceOptions is enabled - use GcsFaultToleranceOptions.RedisPassword instead")
-        }
-
-        if EnvVarExists(REDIS_PASSWORD, headContainer.Env) {
-            return fmt.Errorf("cannot set `REDIS_PASSWORD` env var in head Pod when " +
-                "GcsFaultToleranceOptions is enabled - use GcsFaultToleranceOptions.RedisPassword instead")
-        }
-
-        if EnvVarExists(RAY_REDIS_ADDRESS, headContainer.Env) {
-            return fmt.Errorf("cannot set `RAY_REDIS_ADDRESS` env var in head Pod when " +
-                "GcsFaultToleranceOptions is enabled - use GcsFaultToleranceOptions.RedisAddress instead")
-        }
-
-        if annotations[RayExternalStorageNSAnnotationKey] != "" {
-            return fmt.Errorf("cannot set `ray.io/external-storage-namespace` annotation when " +
-                "GcsFaultToleranceOptions is enabled - use GcsFaultToleranceOptions.ExternalStorageNamespace instead")
-        }
-    }
-    if spec.HeadGroupSpec.RayStartParams["redis-username"] != "" || EnvVarExists(REDIS_USERNAME, headContainer.Env) {
-        return fmt.Errorf("cannot set redis username in rayStartParams or environment variables" +
-            " - use GcsFaultToleranceOptions.RedisUsername instead")
-    }
-
-    return nil
-}
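With `ValidateGCSFaultTolerance` folded in, `ValidateRayClusterSpec` becomes the single validation entry point for both RayCluster and RayJob cluster specs. A minimal, self-contained sketch of calling it directly; the spec contents and Redis address are hypothetical, and the import paths assume the KubeRay module layout:

```go
package main

import (
	"fmt"

	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
	"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
	corev1 "k8s.io/api/core/v1"
)

func main() {
	spec := &rayv1.RayClusterSpec{
		GcsFaultToleranceOptions: &rayv1.GcsFaultToleranceOptions{RedisAddress: "redis:6379"},
		HeadGroupSpec: rayv1.HeadGroupSpec{
			Template: corev1.PodTemplateSpec{
				Spec: corev1.PodSpec{Containers: []corev1.Container{{Name: "ray-head"}}},
			},
		},
	}
	// Setting the legacy ray.io/ft-enabled annotation alongside
	// GcsFaultToleranceOptions is rejected by the merged validator.
	annotations := map[string]string{utils.RayFTEnabledAnnotationKey: "true"}
	if err := utils.ValidateRayClusterSpec(spec, annotations); err != nil {
		fmt.Println(err) // expected: the "both set" error shown in the diff above
	}
}
```

Passing the owner's annotations alongside the spec is what lets the merged validator keep enforcing the legacy annotation rules for RayJob as well as RayCluster.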
10 changes: 5 additions & 5 deletions ray-operator/controllers/ray/utils/validation_test.go
@@ -252,7 +252,7 @@ func TestValidateRayClusterSpecGcsFaultToleranceOptions(t *testing.T) {

     for _, tt := range tests {
         t.Run(tt.name, func(t *testing.T) {
-            err := ValidateGCSFaultTolerance(&rayv1.RayClusterSpec{
+            err := ValidateRayClusterSpec(&rayv1.RayClusterSpec{
                 GcsFaultToleranceOptions: tt.gcsFaultToleranceOptions,
                 HeadGroupSpec: rayv1.HeadGroupSpec{
                     RayStartParams: tt.rayStartParams,
@@ -334,7 +334,7 @@ func TestValidateRayClusterSpecRedisPassword(t *testing.T) {
                     },
                 },
             }
-            err := ValidateRayClusterSpec(rayCluster)
+            err := ValidateRayClusterSpec(&rayCluster.Spec, rayCluster.Annotations)
             if tt.expectError {
                 require.Error(t, err)
             } else {
@@ -404,7 +404,7 @@ func TestValidateRayClusterSpecRedisUsername(t *testing.T) {
                     },
                 },
             }
-            err := ValidateRayClusterSpec(rayCluster)
+            err := ValidateRayClusterSpec(&rayCluster.Spec, rayCluster.Annotations)
             if tt.expectError {
                 require.Error(t, err)
                 assert.EqualError(t, err, tt.errorMessage)
@@ -476,7 +476,7 @@ func TestValidateRayClusterSpecEmptyContainers(t *testing.T) {

     for _, tt := range tests {
         t.Run(tt.name, func(t *testing.T) {
-            err := ValidateRayClusterSpec(tt.rayCluster)
+            err := ValidateRayClusterSpec(&tt.rayCluster.Spec, tt.rayCluster.Annotations)
             if tt.expectError {
                 require.Error(t, err)
                 assert.EqualError(t, err, tt.errorMessage)
@@ -553,7 +553,7 @@ func TestValidateRayClusterSpecSuspendingWorkerGroup(t *testing.T) {
     for _, tt := range tests {
         t.Run(tt.name, func(t *testing.T) {
             defer features.SetFeatureGateDuringTest(t, features.RayJobDeletionPolicy, tt.featureGate)()
-            err := ValidateRayClusterSpec(tt.rayCluster)
+            err := ValidateRayClusterSpec(&tt.rayCluster.Spec, tt.rayCluster.Annotations)
             if tt.expectError {
                 require.Error(t, err)
                 assert.EqualError(t, err, tt.errorMessage)
