diff --git a/python/ray/dashboard/modules/job/common.py b/python/ray/dashboard/modules/job/common.py index a1ca1202b3c50..14dae823148c3 100644 --- a/python/ray/dashboard/modules/job/common.py +++ b/python/ray/dashboard/modules/job/common.py @@ -428,6 +428,8 @@ class JobSubmitRequest: entrypoint_resources: Optional[Dict[str, float]] = None # Optional virtual cluster ID for job. virtual_cluster_id: Optional[str] = None + # Optional replica sets for job + replica_sets: Optional[Dict[str, int]] = None def __post_init__(self): if not isinstance(self.entrypoint, str): @@ -521,6 +523,23 @@ def __post_init__(self): f"got {type(self.virtual_cluster_id)}" ) + if self.replica_sets is not None: + if not isinstance(self.replica_sets, dict): + raise TypeError( + "replica_sets must be a dict, " f"got {type(self.replica_sets)}" + ) + else: + for k in self.replica_sets.keys(): + if not isinstance(k, str): + raise TypeError( + "replica_sets keys must be strings, " f"got {type(k)}" + ) + for v in self.replica_sets.values(): + if not isinstance(v, int): + raise TypeError( + "replica_sets values must be integers, " f"got {type(v)}" + ) + @dataclass class JobSubmitResponse: diff --git a/python/ray/dashboard/modules/job/job_head.py b/python/ray/dashboard/modules/job/job_head.py index 2e5f7ae6f4a73..2af6279866d20 100644 --- a/python/ray/dashboard/modules/job/job_head.py +++ b/python/ray/dashboard/modules/job/job_head.py @@ -21,6 +21,8 @@ upload_package_to_gcs, ) from ray._private.utils import get_or_create_event_loop +from ray.core.generated import gcs_service_pb2_grpc +from ray.core.generated.gcs_service_pb2 import CreateJobClusterRequest from ray.dashboard.datacenter import DataOrganizer from ray.dashboard.modules.job.common import ( JobDeleteResponse, @@ -31,6 +33,7 @@ JobSubmitResponse, http_uri_components_to_uri, ) +from ray.dashboard.modules.job.job_manager import generate_job_id from ray.dashboard.modules.job.pydantic_models import JobDetails, JobType from 
ray.dashboard.modules.job.utils import ( find_job_by_ids, @@ -163,6 +166,12 @@ def __init__(self, dashboard_head): self._gcs_aio_client = dashboard_head.gcs_aio_client self._job_info_client = None + self._gcs_virtual_cluster_info_stub = ( + gcs_service_pb2_grpc.VirtualClusterInfoGcsServiceStub( + dashboard_head.aiogrpc_gcs_channel + ) + ) + # It contains all `JobAgentSubmissionClient` that # `JobHead` has ever used, and will not be deleted # from it unless `JobAgentSubmissionClient` is no @@ -340,6 +349,30 @@ async def submit_job(self, req: Request) -> Response: self.get_target_agent(), timeout=dashboard_consts.WAIT_AVAILABLE_AGENT_TIMEOUT, ) + + if ( + submit_request.virtual_cluster_id is not None + and submit_request.replica_sets is not None + and len(submit_request.replica_sets) > 0 + ): + # Use the submission ID or generate a new one + submission_id = submit_request.submission_id or submit_request.job_id + if submission_id is None: + submit_request.submission_id = generate_job_id() + job_cluster_id = await self._create_job_cluster( + submit_request.submission_id, + submit_request.virtual_cluster_id, + submit_request.replica_sets, + ) + # If cluster creation fails + if job_cluster_id is None: + return Response( + text="Create Job Cluster Failed.", + status=aiohttp.web.HTTPInternalServerError.status_code, + ) + # Overwrite the virtual cluster ID in submit request + submit_request.virtual_cluster_id = job_cluster_id + resp = await job_agent_client.submit_job_internal(submit_request) except asyncio.TimeoutError: return Response( @@ -580,6 +613,21 @@ def get_job_driver_agent_client( return self._agents[driver_node_id] + async def _create_job_cluster(self, job_id, virtual_cluster_id, replica_sets): + request = CreateJobClusterRequest( + job_id=job_id, + virtual_cluster_id=virtual_cluster_id, + replica_sets=replica_sets, + ) + reply = await (self._gcs_virtual_cluster_info_stub.CreateJobCluster(request)) + if reply.status.code != 0: + logger.warning( + f"failed to 
create job cluster for {job_id} in" + f" {virtual_cluster_id}, message: {reply.status.message}" + ) + return None + return reply.job_cluster_id + async def run(self, server): if not self._job_info_client: self._job_info_client = JobInfoStorageClient(self._gcs_aio_client) diff --git a/python/ray/dashboard/modules/job/sdk.py b/python/ray/dashboard/modules/job/sdk.py index 7fa8b73003ee7..ae70f0f90280f 100644 --- a/python/ray/dashboard/modules/job/sdk.py +++ b/python/ray/dashboard/modules/job/sdk.py @@ -131,6 +131,7 @@ def submit_job( metadata: Optional[Dict[str, str]] = None, submission_id: Optional[str] = None, virtual_cluster_id: Optional[str] = None, + replica_sets: Optional[Dict[str, int]] = None, entrypoint_num_cpus: Optional[Union[int, float]] = None, entrypoint_num_gpus: Optional[Union[int, float]] = None, entrypoint_memory: Optional[int] = None, @@ -231,6 +232,7 @@ def submit_job( entrypoint=entrypoint, submission_id=submission_id, virtual_cluster_id=virtual_cluster_id, + replica_sets=replica_sets, runtime_env=runtime_env, metadata=metadata, entrypoint_num_cpus=entrypoint_num_cpus, diff --git a/python/ray/dashboard/modules/job/tests/test_job_with_virtual_cluster.py b/python/ray/dashboard/modules/job/tests/test_job_with_virtual_cluster.py index d2447dd11cf12..b98d1b5673316 100644 --- a/python/ray/dashboard/modules/job/tests/test_job_with_virtual_cluster.py +++ b/python/ray/dashboard/modules/job/tests/test_job_with_virtual_cluster.py @@ -31,6 +31,7 @@ from ray.tests.conftest import get_default_fixture_ray_kwargs TEMPLATE_ID_PREFIX = "template_id_" +kPrimaryClusterID = "kPrimaryClusterID" logger = logging.getLogger(__name__) @@ -131,11 +132,11 @@ async def create_virtual_cluster( [ { "_system_config": {"gcs_actor_scheduling_enabled": False}, - "ntemplates": 5, + "ntemplates": 3, }, { "_system_config": {"gcs_actor_scheduling_enabled": True}, - "ntemplates": 5, + "ntemplates": 3, }, ], indirect=True, @@ -145,7 +146,7 @@ async def 
test_mixed_virtual_cluster(job_sdk_client): head_client, gcs_address, cluster = job_sdk_client virtual_cluster_id_prefix = "VIRTUAL_CLUSTER_" node_to_virtual_cluster = {} - ntemplates = 5 + ntemplates = 3 for i in range(ntemplates): virtual_cluster_id = virtual_cluster_id_prefix + str(i) nodes = await create_virtual_cluster( @@ -340,5 +341,229 @@ def _check_recover( head_client.stop_job(job_id) +@pytest.mark.parametrize( + "job_sdk_client", + [ + { + "_system_config": {"gcs_actor_scheduling_enabled": False}, + "ntemplates": 4, + }, + { + "_system_config": {"gcs_actor_scheduling_enabled": True}, + "ntemplates": 4, + }, + ], + indirect=True, +) +@pytest.mark.asyncio +async def test_exclusive_virtual_cluster(job_sdk_client): + head_client, gcs_address, cluster = job_sdk_client + virtual_cluster_id_prefix = "VIRTUAL_CLUSTER_" + node_to_virtual_cluster = {} + ntemplates = 3 + for i in range(ntemplates): + virtual_cluster_id = virtual_cluster_id_prefix + str(i) + nodes = await create_virtual_cluster( + gcs_address, + virtual_cluster_id, + {TEMPLATE_ID_PREFIX + str(i): 2}, + AllocationMode.EXCLUSIVE, + ) + for node_id in nodes: + assert node_id not in node_to_virtual_cluster + node_to_virtual_cluster[node_id] = virtual_cluster_id + + for node in cluster.worker_nodes: + if node.node_id not in node_to_virtual_cluster: + node_to_virtual_cluster[node.node_id] = kPrimaryClusterID + + @ray.remote + class ControlActor: + def __init__(self): + self._nodes = set() + self._ready = False + + def ready(self): + self._ready = True + + def is_ready(self): + return self._ready + + def add_node(self, node_id): + self._nodes.add(node_id) + + def nodes(self): + return self._nodes + + for i in range(ntemplates + 1): + actor_name = f"test_actors_{i}" + pg_name = f"test_pgs_{i}" + control_actor_name = f"control_{i}" + virtual_cluster_id = virtual_cluster_id_prefix + str(i) + if i == ntemplates: + virtual_cluster_id = kPrimaryClusterID + control_actor = ControlActor.options( + 
name=control_actor_name, namespace="control" + ).remote() + with tempfile.TemporaryDirectory() as tmp_dir: + path = Path(tmp_dir) + driver_script = """ +import ray +import time +import asyncio + +ray.init(address="auto") + +control = ray.get_actor(name="{control_actor_name}", namespace="control") + + +@ray.remote(max_restarts=10) +class Actor: + def __init__(self, control, pg): + node_id = ray.get_runtime_context().get_node_id() + ray.get(control.add_node.remote(node_id)) + self._pg = pg + + async def run(self, control): + node_id = ray.get_runtime_context().get_node_id() + await control.add_node.remote(node_id) + + while True: + node_id = ray.util.placement_group_table(self._pg)["bundles_to_node_id"][0] + if node_id == "": + await asyncio.sleep(1) + continue + break + + await control.add_node.remote(node_id) + + await control.ready.remote() + while True: + await asyncio.sleep(1) + + async def get_node_id(self): + while True: + node_id = ray.util.placement_group_table(pg)["bundles_to_node_id"][0] + if node_id == "": + await asyncio.sleep(1) + continue + break + return (ray.get_runtime_context().get_node_id(), node_id) + + +pg = ray.util.placement_group( + bundles=[{{"CPU": 1}}], name="{pg_name}", lifetime="detached" +) + + +@ray.remote +def hello(control): + node_id = ray.get_runtime_context().get_node_id() + ray.get(control.add_node.remote(node_id)) + + +ray.get(hello.remote(control)) +a = Actor.options(name="{actor_name}", + namespace="control", + num_cpus=1, + lifetime="detached").remote( + control, pg +) +ray.get(a.run.remote(control)) + """ + driver_script = driver_script.format( + actor_name=actor_name, + pg_name=pg_name, + control_actor_name=control_actor_name, + ) + test_script_file = path / "test_script.py" + with open(test_script_file, "w+") as file: + file.write(driver_script) + + runtime_env = {"working_dir": tmp_dir} + runtime_env = upload_working_dir_if_needed( + runtime_env, tmp_dir, logger=logger + ) + runtime_env = 
RuntimeEnv(**runtime_env).to_dict() + + job_id = head_client.submit_job( + entrypoint="python test_script.py", + entrypoint_memory=1, + runtime_env=runtime_env, + virtual_cluster_id=virtual_cluster_id, + replica_sets={TEMPLATE_ID_PREFIX + str(i): 2}, + ) + + def _check_ready(control_actor): + return ray.get(control_actor.is_ready.remote()) + + wait_for_condition(partial(_check_ready, control_actor), timeout=20) + + def _check_virtual_cluster( + control_actor, node_to_virtual_cluster, virtual_cluster_id + ): + nodes = ray.get(control_actor.nodes.remote()) + assert len(nodes) > 0 + for node in nodes: + assert node_to_virtual_cluster[node] == virtual_cluster_id + return True + + wait_for_condition( + partial( + _check_virtual_cluster, + control_actor, + node_to_virtual_cluster, + virtual_cluster_id, + ), + timeout=20, + ) + + supervisor_actor = ray.get_actor( + name=JOB_ACTOR_NAME_TEMPLATE.format(job_id=job_id), + namespace=SUPERVISOR_ACTOR_RAY_NAMESPACE, + ) + actor_info = ray.state.actors(supervisor_actor._actor_id.hex()) + driver_node_id = actor_info["Address"]["NodeID"] + assert node_to_virtual_cluster[driver_node_id] == virtual_cluster_id + + job_info = head_client.get_job_info(job_id) + assert ( + node_to_virtual_cluster[job_info.driver_node_id] == virtual_cluster_id + ) + + nodes_to_remove = ray.get(control_actor.nodes.remote()) + if driver_node_id in nodes_to_remove: + nodes_to_remove.remove(driver_node_id) + + to_remove = [] + for node in cluster.worker_nodes: + if node.node_id in nodes_to_remove: + to_remove.append(node) + for node in to_remove: + cluster.remove_node(node) + + def _check_recover( + nodes_to_remove, actor_name, node_to_virtual_cluster, virtual_cluster_id + ): + actor = ray.get_actor(actor_name, namespace="control") + nodes = ray.get(actor.get_node_id.remote()) + for node_id in nodes: + assert node_id not in nodes_to_remove + assert node_to_virtual_cluster[node_id] == virtual_cluster_id + return True + + wait_for_condition( + partial( + 
_check_recover, + nodes_to_remove, + actor_name, + node_to_virtual_cluster, + virtual_cluster_id, + ), + timeout=120, + ) + head_client.stop_job(job_id) + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/src/ray/common/virtual_cluster_id.h b/src/ray/common/virtual_cluster_id.h index 16f5ca315351d..7f1a0ecb28f89 100644 --- a/src/ray/common/virtual_cluster_id.h +++ b/src/ray/common/virtual_cluster_id.h @@ -25,8 +25,8 @@ class VirtualClusterID : public SimpleID { public: using SimpleID::SimpleID; - VirtualClusterID BuildJobClusterID(const std::string &job_name) const { - return VirtualClusterID::FromBinary(id_ + kJobClusterIDSeperator + job_name); + VirtualClusterID BuildJobClusterID(const std::string &job_id) const { + return VirtualClusterID::FromBinary(id_ + kJobClusterIDSeperator + job_id); } bool IsJobClusterID() const { diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 9ecb8034903a3..dccebe4dfd4c1 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -387,6 +387,9 @@ void GcsServer::InitClusterResourceScheduler() { auto node_instance_id = NodeID::FromBinary(node_id.Binary()).Hex(); auto virtual_cluster = gcs_virtual_cluster_manager_->GetVirtualCluster(context->virtual_cluster_id); + if (virtual_cluster == nullptr) { + return true; + } RAY_CHECK(virtual_cluster->GetMode() == rpc::AllocationMode::MIXED); // Check if the node is contained within the specified virtual cluster. return virtual_cluster->ContainsNodeInstance(node_instance_id); diff --git a/src/ray/gcs/gcs_server/gcs_virtual_cluster.h b/src/ray/gcs/gcs_server/gcs_virtual_cluster.h index f7b0188f12c66..b5d45c2338256 100644 --- a/src/ray/gcs/gcs_server/gcs_virtual_cluster.h +++ b/src/ray/gcs/gcs_server/gcs_virtual_cluster.h @@ -226,10 +226,10 @@ class ExclusiveCluster : public VirtualCluster { /// Build the job cluster id. /// - /// \param job_name The name of the job. 
+ /// \param job_id The id of the job. /// \return The job cluster id. - std::string BuildJobClusterID(const std::string &job_name) { - return VirtualClusterID::FromBinary(GetID()).BuildJobClusterID(job_name).Binary(); + std::string BuildJobClusterID(const std::string &job_id) { + return VirtualClusterID::FromBinary(GetID()).BuildJobClusterID(job_id).Binary(); } /// Create a job cluster. diff --git a/src/ray/gcs/gcs_server/gcs_virtual_cluster_manager.cc b/src/ray/gcs/gcs_server/gcs_virtual_cluster_manager.cc index 49fae62859df1..53c29e31a0d2a 100644 --- a/src/ray/gcs/gcs_server/gcs_virtual_cluster_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_virtual_cluster_manager.cc @@ -121,6 +121,54 @@ void GcsVirtualClusterManager::HandleGetVirtualClusters( GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); } +void GcsVirtualClusterManager::HandleCreateJobCluster( + rpc::CreateJobClusterRequest request, + rpc::CreateJobClusterReply *reply, + rpc::SendReplyCallback send_reply_callback) { + const auto &virtual_cluster_id = request.virtual_cluster_id(); + RAY_LOG(INFO) << "Start creating job cluster in virtual cluster: " + << virtual_cluster_id; + auto virtual_cluster = GetVirtualCluster(virtual_cluster_id); + if (virtual_cluster == nullptr) { + std::ostringstream ostr; + ostr << "Create job cluster for job " << request.job_id() + << " failed, virtual cluster not exists: " << virtual_cluster_id; + std::string message = ostr.str(); + RAY_LOG(ERROR) << message; + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::NotFound(message)); + return; + } + if (virtual_cluster->GetMode() != rpc::AllocationMode::EXCLUSIVE) { + std::ostringstream ostr; + ostr << "Create job cluster for job " << request.job_id() + << " failed, virtual cluster is not exclusive: " << virtual_cluster_id; + std::string message = ostr.str(); + RAY_LOG(ERROR) << message; + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::InvalidArgument(message)); + return; + } + ReplicaSets
replica_sets(request.replica_sets().begin(), request.replica_sets().end()); + + auto exclusive_cluster = dynamic_cast<ExclusiveCluster *>(virtual_cluster.get()); + const std::string &job_cluster_id = + exclusive_cluster->BuildJobClusterID(request.job_id()); + + exclusive_cluster->CreateJobCluster( + job_cluster_id, + std::move(replica_sets), + [reply, send_reply_callback, job_id = request.job_id()]( + const Status &status, std::shared_ptr<rpc::VirtualClusterTableData> data) { + if (status.ok()) { + reply->set_job_cluster_id(data->id()); + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); + } else { + RAY_LOG(ERROR) << "Create job cluster for job " << job_id << " failed, " + << status.message(); + GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); + } + }); +} + Status GcsVirtualClusterManager::VerifyRequest( const rpc::CreateOrUpdateVirtualClusterRequest &request) { const auto &virtual_cluster_id = request.virtual_cluster_id(); diff --git a/src/ray/gcs/gcs_server/gcs_virtual_cluster_manager.h b/src/ray/gcs/gcs_server/gcs_virtual_cluster_manager.h index de31a2c4d2e54..6759b6164a905 100644 --- a/src/ray/gcs/gcs_server/gcs_virtual_cluster_manager.h +++ b/src/ray/gcs/gcs_server/gcs_virtual_cluster_manager.h @@ -72,6 +72,10 @@ class GcsVirtualClusterManager : public rpc::VirtualClusterInfoHandler { rpc::GetVirtualClustersReply *reply, rpc::SendReplyCallback send_reply_callback) override; + void HandleCreateJobCluster(rpc::CreateJobClusterRequest request, + rpc::CreateJobClusterReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + Status VerifyRequest(const rpc::CreateOrUpdateVirtualClusterRequest &request); Status VerifyRequest(const rpc::RemoveVirtualClusterRequest &request); diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index 49f7c41831988..2e34ffe555ac2 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -861,6 +861,21 @@ message GetVirtualClustersReply { repeated VirtualClusterTableData
virtual_cluster_data_list = 2; } +message CreateJobClusterRequest { + // The job id. + string job_id = 1; + // ID of the virtual cluster that the job belongs to. + string virtual_cluster_id = 2; + // The replica set of the job cluster. + map<string, int32> replica_sets = 3; +} + +message CreateJobClusterReply { + GcsStatus status = 1; + // The job cluster id. + string job_cluster_id = 2; +} + service VirtualClusterInfoGcsService { // Create or update a virtual cluster. rpc CreateOrUpdateVirtualCluster(CreateOrUpdateVirtualClusterRequest) @@ -869,4 +884,6 @@ service VirtualClusterInfoGcsService { rpc RemoveVirtualCluster(RemoveVirtualClusterRequest) returns (RemoveVirtualClusterReply); // Get virtual clusters. rpc GetVirtualClusters(GetVirtualClustersRequest) returns (GetVirtualClustersReply); + // Create job cluster. + rpc CreateJobCluster(CreateJobClusterRequest) returns (CreateJobClusterReply); } diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index 8c1de4d67f4f5..5d7d1bbc9fa56 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -571,6 +571,12 @@ class GcsRpcClient { virtual_cluster_info_grpc_client_, /*method_timeout_ms*/ -1, ) + // Create job cluster.
+ VOID_GCS_RPC_CLIENT_METHOD(VirtualClusterInfoGcsService, + CreateJobCluster, + virtual_cluster_info_grpc_client_, + /*method_timeout_ms*/ -1, ) + std::pair<std::string, int64_t> GetAddress() const { return std::make_pair(gcs_address_, gcs_port_); } diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h index c8404c28882b8..4295fc66a17ef 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_server.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h @@ -733,6 +733,10 @@ class VirtualClusterInfoGcsServiceHandler { virtual void HandleGetVirtualClusters(GetVirtualClustersRequest request, GetVirtualClustersReply *reply, SendReplyCallback send_reply_callback) = 0; + + virtual void HandleCreateJobCluster(CreateJobClusterRequest request, + CreateJobClusterReply *reply, + SendReplyCallback send_reply_callback) = 0; }; class VirtualClusterInfoGrpcService : public GrpcService { @@ -754,6 +758,7 @@ class VirtualClusterInfoGrpcService : public GrpcService { VIRTUAL_CLUSTER_SERVICE_RPC_HANDLER(CreateOrUpdateVirtualCluster); VIRTUAL_CLUSTER_SERVICE_RPC_HANDLER(RemoveVirtualCluster); VIRTUAL_CLUSTER_SERVICE_RPC_HANDLER(GetVirtualClusters); + VIRTUAL_CLUSTER_SERVICE_RPC_HANDLER(CreateJobCluster); } private: