Skip to content

Commit

Permalink
add shutdown rpc method
Browse files Browse the repository at this point in the history
  • Loading branch information
1ntEgr8 committed Dec 3, 2024
1 parent 48a9711 commit b1942ea
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 7 deletions.
26 changes: 21 additions & 5 deletions rpc/launch_tpch_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from workload import JobGraph
from utils import EventTime
from data.tpch_loader import make_release_policy
from rpc import erdos_scheduler_pb2
from rpc import erdos_scheduler_pb2_grpc

import grpc


def map_dataset_to_deadline(dataset_size):
Expand Down Expand Up @@ -49,13 +53,14 @@ def launch_query(query_number, args):
# )

try:
cmd = ' '.join(cmd)
cmd = " ".join(cmd)
print("Launching:", cmd)
subprocess.Popen(
p = subprocess.Popen(
cmd,
shell=True,
)
print("Query launched successfully.")
return p
except Exception as e:
print(f"Error launching query: {e}")

Expand Down Expand Up @@ -187,7 +192,9 @@ def main():
default=1234,
help="RNG seed for generating inter-arrival periods and picking DAGs (default: 1234)",
)
parser.add_argument("--queries", type=int, nargs='+', help="Launch specific queries")
parser.add_argument(
"--queries", type=int, nargs="+", help="Launch specific queries"
)

args = parser.parse_args()

Expand All @@ -197,7 +204,7 @@ def main():
os.environ["TPCH_INPUT_DATA_DIR"] = str(args.tpch_spark_path.resolve() / "dbgen")

if args.queries:
assert(len(args.queries) == args.num_queries)
assert len(args.queries) == args.num_queries

rng = random.Random(args.rng_seed)

Expand All @@ -206,6 +213,7 @@ def main():
print("Release times:", release_times)

# Launch queries
ps = []
inter_arrival_times = [release_times[0].time]
for i in range(len(release_times) - 1):
inter_arrival_times.append(release_times[i + 1].time - release_times[i].time)
Expand All @@ -215,14 +223,22 @@ def main():
query_number = args.queries[i]
else:
query_number = rng.randint(1, 22)
launch_query(query_number, args)
ps.append(launch_query(query_number, args))
print(
"Current time: ",
time.strftime("%Y-%m-%d %H:%M:%S"),
" launching query: ",
query_number,
)

for p in ps:
p.wait()

channel = grpc.insecure_channel("localhost:50051")
stub = erdos_scheduler_pb2_grpc.SchedulerServiceStub(channel)
response = stub.Shutdown(erdos_scheduler_pb2.Empty())
channel.close()


if __name__ == "__main__":
main()
4 changes: 4 additions & 0 deletions rpc/protos/rpc/erdos_scheduler.proto
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ service SchedulerService {

/// Notifies the Scheduler that a Task from a particular TaskGraph has completed.
rpc NotifyTaskCompletion(NotifyTaskCompletionRequest) returns (NotifyTaskCompletionResponse) {}

rpc Shutdown(Empty) returns (Empty) {}
}


Expand Down Expand Up @@ -201,3 +203,5 @@ message GetPlacementsResponse {
string message = 3;
bool terminate = 4; // terminate the task graph
}

message Empty {}
15 changes: 13 additions & 2 deletions rpc/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def canonical_task_id(self, stage_id: int):


class Servicer(erdos_scheduler_pb2_grpc.SchedulerServiceServicer):
def __init__(self) -> None:
def __init__(self, server) -> None:
self._server = server

# Override some flags

# Enable orchestrated mode
Expand Down Expand Up @@ -230,6 +232,7 @@ def __init__(self) -> None:
self._registered_app_drivers = (
{}
) # Spark driver id differs from taskgraph name (application id)
self._shutdown = False
self._lock = threading.Lock()

super().__init__()
Expand Down Expand Up @@ -357,6 +360,10 @@ async def DeregisterDriver(self, request, context):

msg = f"[{stime}] Successfully de-registered driver with id {request.id} for task graph {task_graph_name}"
self._logger.info(msg)

if len(self._registered_app_drivers) == 0 and self._shutdown:
await self._server.stop(0)

return erdos_scheduler_pb2.DeregisterDriverResponse(
success=True,
message=msg,
Expand Down Expand Up @@ -755,6 +762,10 @@ async def NotifyTaskCompletion(self, request, context):
message=msg,
)

async def Shutdown(self, request, context):
    """RPC handler that requests a graceful shutdown of the scheduler service.

    Only records the intent by setting ``self._shutdown``; the gRPC server
    itself is stopped later, in ``DeregisterDriver``, once the last
    registered app driver has de-registered (i.e. when
    ``self._registered_app_drivers`` is empty and this flag is set).

    Args:
        request: An ``erdos_scheduler_pb2.Empty`` message (unused).
        context: The gRPC servicer context (unused).

    Returns:
        An ``erdos_scheduler_pb2.Empty`` message as acknowledgement.
    """
    # NOTE(review): the flag is written without holding self._lock —
    # presumably acceptable as a single attribute store, but confirm the
    # readers (e.g. DeregisterDriver) tolerate an un-synchronized update.
    self._shutdown = True
    return erdos_scheduler_pb2.Empty()

async def _tick_simulator(self):
while True:
with self._lock:
Expand Down Expand Up @@ -819,7 +830,7 @@ def main(_argv):
loop = asyncio.get_event_loop()

server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=FLAGS.max_workers))
servicer = Servicer()
servicer = Servicer(server)
erdos_scheduler_pb2_grpc.add_SchedulerServiceServicer_to_server(servicer, server)
server.add_insecure_port(f"[::]:{FLAGS.port}")

Expand Down

0 comments on commit b1942ea

Please sign in to comment.