From b1942ea491ea33d0810e9aa8c6536d97add8c39e Mon Sep 17 00:00:00 2001
From: Elton Leander Pinto
Date: Mon, 2 Dec 2024 22:59:36 -0500
Subject: [PATCH] add shutdown rpc method

---
 rpc/launch_tpch_queries.py           | 26 +++++++++++++++++++++-----
 rpc/protos/rpc/erdos_scheduler.proto |  4 ++++
 rpc/service.py                       | 15 +++++++++++++--
 3 files changed, 38 insertions(+), 7 deletions(-)

diff --git a/rpc/launch_tpch_queries.py b/rpc/launch_tpch_queries.py
index 09beae64..73fd3a44 100644
--- a/rpc/launch_tpch_queries.py
+++ b/rpc/launch_tpch_queries.py
@@ -11,6 +11,10 @@ from workload import JobGraph
 from utils import EventTime
 from data.tpch_loader import make_release_policy
 
+from rpc import erdos_scheduler_pb2
+from rpc import erdos_scheduler_pb2_grpc
+
+import grpc
 
 
 def map_dataset_to_deadline(dataset_size):
@@ -49,13 +53,14 @@ def launch_query(query_number, args):
     #     )
 
     try:
-        cmd = ' '.join(cmd)
+        cmd = " ".join(cmd)
         print("Launching:", cmd)
-        subprocess.Popen(
+        p = subprocess.Popen(
             cmd,
             shell=True,
         )
         print("Query launched successfully.")
+        return p
     except Exception as e:
         print(f"Error launching query: {e}")
 
@@ -187,7 +192,9 @@ def main():
         default=1234,
         help="RNG seed for generating inter-arrival periods and picking DAGs (default: 1234)",
     )
-    parser.add_argument("--queries", type=int, nargs='+', help="Launch specific queries")
+    parser.add_argument(
+        "--queries", type=int, nargs="+", help="Launch specific queries"
+    )
 
     args = parser.parse_args()
 
@@ -197,7 +204,7 @@ def main():
     os.environ["TPCH_INPUT_DATA_DIR"] = str(args.tpch_spark_path.resolve() / "dbgen")
 
     if args.queries:
-        assert(len(args.queries) == args.num_queries)
+        assert len(args.queries) == args.num_queries
 
     rng = random.Random(args.rng_seed)
 
@@ -206,6 +213,7 @@ def main():
     print("Release times:", release_times)
 
     # Launch queries
+    ps = []
     inter_arrival_times = [release_times[0].time]
     for i in range(len(release_times) - 1):
         inter_arrival_times.append(release_times[i + 1].time - release_times[i].time)
@@ -215,7 +223,7 @@ def main():
            query_number = args.queries[i]
         else:
            query_number = rng.randint(1, 22)
-        launch_query(query_number, args)
+        ps.append(launch_query(query_number, args))
         print(
             "Current time: ",
             time.strftime("%Y-%m-%d %H:%M:%S"),
@@ -223,6 +231,14 @@
             query_number,
         )
 
+    for p in ps:
+        p.wait()
+
+    channel = grpc.insecure_channel("localhost:50051")
+    stub = erdos_scheduler_pb2_grpc.SchedulerServiceStub(channel)
+    response = stub.Shutdown(erdos_scheduler_pb2.Empty())
+    channel.close()
+
 
 if __name__ == "__main__":
     main()
diff --git a/rpc/protos/rpc/erdos_scheduler.proto b/rpc/protos/rpc/erdos_scheduler.proto
index e49ec8c4..262254da 100644
--- a/rpc/protos/rpc/erdos_scheduler.proto
+++ b/rpc/protos/rpc/erdos_scheduler.proto
@@ -47,6 +47,8 @@ service SchedulerService {
 
   /// Notifies the Scheduler that a Task from a particular TaskGraph has completed.option
   rpc NotifyTaskCompletion(NotifyTaskCompletionRequest) returns (NotifyTaskCompletionResponse) {}
+
+  rpc Shutdown(Empty) returns (Empty) {}
 }
 
 
@@ -201,3 +203,5 @@ message GetPlacementsResponse {
   string message = 3;
   bool terminate = 4; // terminate the task graph
 }
+
+message Empty {}
diff --git a/rpc/service.py b/rpc/service.py
index 6ab9b143..6ebe5d48 100644
--- a/rpc/service.py
+++ b/rpc/service.py
@@ -136,7 +136,9 @@ def canonical_task_id(self, stage_id: int):
 
 
 class Servicer(erdos_scheduler_pb2_grpc.SchedulerServiceServicer):
-    def __init__(self) -> None:
+    def __init__(self, server) -> None:
+        self._server = server
+
         # Override some flags
 
         # Enable orchestrated mode
@@ -230,6 +232,7 @@ def __init__(self) -> None:
         self._registered_app_drivers = (
             {}
         )  # Spark driver id differs from taskgraph name (application id)
+        self._shutdown = False
 
         self._lock = threading.Lock()
         super().__init__()
@@ -357,6 +360,10 @@ async def DeregisterDriver(self, request, context):
         msg = f"[{stime}] Successfully de-registered driver with id {request.id} for task graph {task_graph_name}"
         self._logger.info(msg)
+
+        if len(self._registered_app_drivers) == 0 and self._shutdown:
+            await self._server.stop(0)
+
         return erdos_scheduler_pb2.DeregisterDriverResponse(
             success=True,
             message=msg,
         )
@@ -755,6 +762,10 @@ async def NotifyTaskCompletion(self, request, context):
             message=msg,
         )
 
+    async def Shutdown(self, request, context):
+        self._shutdown = True
+        return erdos_scheduler_pb2.Empty()
+
     async def _tick_simulator(self):
         while True:
             with self._lock:
@@ -819,7 +830,7 @@ def main(_argv):
     loop = asyncio.get_event_loop()
     server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=FLAGS.max_workers))
-    servicer = Servicer()
+    servicer = Servicer(server)
     erdos_scheduler_pb2_grpc.add_SchedulerServiceServicer_to_server(servicer, server)
 
     server.add_insecure_port(f"[::]:{FLAGS.port}")
 
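Note: because this patch adds the Empty message and the Shutdown RPC to erdos_scheduler.proto, the generated Python modules (erdos_scheduler_pb2 and erdos_scheduler_pb2_grpc) have to be regenerated before the new client call and Servicer.Shutdown handler will resolve erdos_scheduler_pb2.Empty. Below is a minimal sketch using grpcio-tools, run from the repository root; the include and output paths are assumptions inferred from the "from rpc import erdos_scheduler_pb2" import style, and the project may have its own build script that should be preferred.

# regenerate_stubs.py -- hypothetical helper, not part of this patch.
# Paths are assumptions; adjust -I/--python_out/--grpc_python_out to match
# the project's actual layout or invoke its existing build script instead.
from grpc_tools import protoc

ret = protoc.main(
    [
        "grpc_tools.protoc",
        "-Irpc/protos",
        "--python_out=.",
        "--grpc_python_out=.",
        "rpc/protos/rpc/erdos_scheduler.proto",
    ]
)
if ret != 0:
    raise RuntimeError(f"protoc failed with exit code {ret}")

Running this once after editing the .proto rewrites rpc/erdos_scheduler_pb2.py and rpc/erdos_scheduler_pb2_grpc.py (under the assumed paths), after which the shutdown flow works end to end: launch_tpch_queries.py waits on all query subprocesses, calls stub.Shutdown, and the service stops itself once the last registered driver deregisters.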