Skip to content

Commit

Permalink
create task graph after environment is ready
Browse files Browse the repository at this point in the history
  • Loading branch information
1ntEgr8 committed Nov 22, 2024
1 parent 6264772 commit bf658fe
Showing 1 changed file with 79 additions and 49 deletions.
128 changes: 79 additions & 49 deletions rpc/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import time
import asyncio
from concurrent import futures
from collections import namedtuple
from urllib.parse import urlparse
from typing import Optional
from typing import Optional, Dict
from enum import Enum
from dataclasses import dataclass

# TODO: refactor out the need to import main to get common flags
import main
Expand Down Expand Up @@ -74,7 +74,25 @@ def get_next_workload(self, current_time: EventTime) -> Optional[Workload]:
return self._workload


RegisteredTaskGraph = namedtuple("RegisteredTaskGraph", ["graph", "stage_id_mapping"])
# TODO(elton): rename to RegisteredApplication
# TODO(elton): write documentation on how to use


@dataclass
class RegisteredTaskGraph:
gen: any # TODO(elton): proper type
task_graph: TaskGraph = None
stage_id_mapping: any = None # TODO(elton): proper type
last_gen: any = None # TODO(elton): proper type

def __init__(self, gen):
self.gen = gen

def generate_task_graph(self, release_time):
task_graph, stage_id_mapping = self.gen(release_time)
self.task_graph = task_graph
self.stage_id_mapping = stage_id_mapping
self.last_gen = release_time


class Servicer(erdos_scheduler_pb2_grpc.SchedulerServiceServicer):
Expand All @@ -84,7 +102,7 @@ def __init__(self) -> None:
log_dir=FLAGS.log_dir,
log_file=FLAGS.log_file_name,
log_level=FLAGS.log_level,
fmt='[%(asctime)s] {%(funcName)s:%(lineno)d} - %(message)s',
fmt="[%(asctime)s] {%(funcName)s:%(lineno)d} - %(message)s",
)
self._csv_logger = setup_csv_logging(
name=__name__,
Expand Down Expand Up @@ -185,10 +203,8 @@ async def DeregisterFramework(self, request, context):

async def RegisterDriver(self, request, context):
stime = self.__stime()

msg = (
f"[{stime}] Successfully registered driver for app id {request.id}"
)

msg = f"[{stime}] Successfully registered driver for app id {request.id}"
self._logger.info(msg)
return erdos_scheduler_pb2.RegisterDriverResponse(
success=True,
Expand All @@ -198,7 +214,7 @@ async def RegisterDriver(self, request, context):

async def DeregisterDriver(self, request, context):
stime = self.__stime()

if request.id not in self._registered_task_graphs:
msg = f"[{stime}] Task graph of id '{request.id}' is not registered or does not exist"
self._logger.error(msg)
Expand Down Expand Up @@ -261,34 +277,37 @@ async def RegisterTaskGraph(self, request, context):
}
)

# Construct the task graph
try:
task_graph, stage_id_mapping = self._data_loaders[
DataLoader.TPCH
].make_task_graph(
id=request.id,
query_num=query_num,
release_time=stime,
dependencies=dependencies,
dataset_size=dataset_size,
max_executors_per_job=max_executors_per_job,
runtime_unit=EventTime.Unit.S,
)
except Exception as e:
msg = f"[{stime}] Failed to load TPCH query {query_num}. Exception: {e}"
return erdos_scheduler_pb2.RegisterTaskGraphResponse(
success=False, message=msg, num_executors=0
)
def gen(release_time):
# Construct the task graph
try:
task_graph, stage_id_mapping = self._data_loaders[
DataLoader.TPCH
].make_task_graph(
id=request.id,
query_num=query_num,
release_time=release_time,
dependencies=dependencies,
dataset_size=dataset_size,
max_executors_per_job=max_executors_per_job,
runtime_unit=EventTime.Unit.S,
)
except Exception as e:
msg = f"[{stime}] Failed to load TPCH query {query_num}. Exception: {e}"
return erdos_scheduler_pb2.RegisterTaskGraphResponse(
success=False, message=msg, num_executors=0
)

return task_graph, stage_id_mapping

else:
msg = f"[{stime}] The service only supports TPCH queries"
return erdos_scheduler_pb2.RegisterTaskGraphResponse(
success=False, message=msg, num_executors=0
)

self._registered_task_graphs[request.id] = RegisteredTaskGraph(
task_graph, stage_id_mapping
)
msg = f"[{stime}] Registered task graph '{task_graph.name}' successfully"
self._registered_task_graphs[request.id] = RegisteredTaskGraph(gen=gen)

msg = f"[{stime}] Registered task graph '{request.id}' successfully"
self._logger.info(msg)
return erdos_scheduler_pb2.RegisterTaskGraphResponse(
success=True,
Expand All @@ -307,9 +326,12 @@ async def RegisterEnvironmentReady(self, request, context):
message=msg,
)

task_graph = self._registered_task_graphs[request.id].graph
r = self._registered_task_graphs[request.id]

# Generate the task graph now
r.generate_task_graph(stime)

self._workload_loader.add_task_graph(task_graph)
self._workload_loader.add_task_graph(r.task_graph)

update_workload_event = Event(
event_type=EventType.UPDATE_WORKLOAD,
Expand All @@ -324,7 +346,7 @@ async def RegisterEnvironmentReady(self, request, context):
self._simulator._event_queue.add_event(update_workload_event)
self._simulator._event_queue.add_event(scheduler_start_event)

msg = f"[{stime}] Successfully marked environment as ready for task graph '{task_graph.name}'"
msg = f"[{stime}] Successfully marked environment as ready for task graph '{r.task_graph.name}'"
self._logger.info(msg)
return erdos_scheduler_pb2.RegisterEnvironmentReadyResponse(
success=True,
Expand All @@ -347,9 +369,10 @@ async def RegisterWorker(self, request, context):
cpu_resource = Resource(name="Slot")
worker_resources = Resources(
resource_vector={
cpu_resource: request.cores if not FLAGS.override_worker_cpu_count
else 640
},
# TODO(elton): handle override worker cpu count?
# cpu_resource: request.cores,
cpu_resource: 640,
},
_logger=self._logger,
)
worker = Worker(
Expand Down Expand Up @@ -381,11 +404,20 @@ async def GetPlacements(self, request, context):
message=msg,
)

task_graph, stage_id_mapping = self._registered_task_graphs[request.id]
r = self._registered_task_graphs[request.id]

if r.task_graph is None:
msg = f"[{stime}] Task graph '{request.id}' is not ready"
self._logger.error(msg)
return erdos_scheduler_pb2.GetPlacementsResponse(
success=True,
message=msg,
placements=[],
)

# Check if the task graph is active
if task_graph.is_complete():
msg = f"[{stime}] Task graph '{task_graph.name}' is complete. No more placements to provide."
if r.task_graph.is_complete():
msg = f"[{stime}] Task graph '{r.task_graph.name}' is complete. No more placements to provide."
self._logger.error(msg)
return erdos_scheduler_pb2.GetPlacementsResponse(
success=False,
Expand All @@ -394,18 +426,18 @@ async def GetPlacements(self, request, context):

with self._lock:
sim_placements = self._simulator.get_current_placements_for_task_graph(
task_graph.name
r.task_graph.name
)

self._logger.info(
f"Received the following placements for '{task_graph.name}': {sim_placements}"
f"Received the following placements for '{r.task_graph.name}': {sim_placements}"
)

# Construct response. Notably, we apply stage-id mapping
placements = []
for placement in sim_placements:
worker_id = self.__get_worker_id()
task_id = stage_id_mapping[placement.task.name]
task_id = r.stage_id_mapping[placement.task.name]
cores = sum(x for _, x in placement.execution_strategy.resources.resources)

if placement.placement_type not in (
Expand Down Expand Up @@ -443,12 +475,10 @@ async def NotifyTaskCompletion(self, request, context):
message=msg,
)

task_graph, stage_id_mapping = self._registered_task_graphs[
request.application_id
]
task = task_graph.get_task(stage_id_mapping[request.task_id])
r = self._registered_task_graphs[request.application_id]
task = r.task_graph.get_task(r.stage_id_mapping[request.task_id])
if task is None:
msg = f"[{stime}] Task '{request.task_id}' does not exist in the task graph '{task_graph.name}'"
msg = f"[{stime}] Task '{request.task_id}' does not exist in the task graph '{r.task_graph.name}'"
self._logger.error(msg)
return erdos_scheduler_pb2.NotifyTaskCompletionResponse(
success=False,
Expand Down Expand Up @@ -482,7 +512,7 @@ async def NotifyTaskCompletion(self, request, context):
self._simulator._event_queue.add_event(task_finished_event)
self._simulator._event_queue.add_event(scheduler_start_event)

msg = f"[{stime}] Successfully processed completion of task '{request.task_id}' of task graph '{task_graph.name}'"
msg = f"[{stime}] Successfully processed completion of task '{request.task_id}' of task graph '{r.task_graph.name}'"
self._logger.info(msg)
return erdos_scheduler_pb2.NotifyTaskCompletionResponse(
success=True,
Expand Down

0 comments on commit bf658fe

Please sign in to comment.