correctly handle task cancellations
1ntEgr8 committed Nov 29, 2024
1 parent 9cf701a commit 25f4a3d
Showing 3 changed files with 126 additions and 64 deletions.
2 changes: 1 addition & 1 deletion rpc/protos/rpc/erdos_scheduler.proto
@@ -193,11 +193,11 @@ message Placement {
string application_id = 2;
uint32 task_id = 3;
uint32 cores = 4;
-  bool cancelled = 5; // If the task (and thereby the task graph) should be cancelled
}

message GetPlacementsResponse {
bool success = 1;
repeated Placement placements = 2;
string message = 3;
+  bool terminate = 4; // Whether the task graph's application should terminate
}
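The new `terminate` flag changes the client-side contract: callers of `GetPlacements` must stop polling once it is set. A minimal sketch of such a polling loop, assuming a generated gRPC stub and a `GetPlacementsRequest` message with an `id` field (the request message name and `poll_interval` are assumptions, not part of this commit):

```python
import time

import erdos_scheduler_pb2  # generated from erdos_scheduler.proto


def poll_placements(stub, app_id, apply_placement, poll_interval=1.0):
    """Poll GetPlacements until the service signals termination."""
    while True:
        response = stub.GetPlacements(
            erdos_scheduler_pb2.GetPlacementsRequest(id=app_id)
        )
        if response.terminate:
            # The service decided no further placements will ever arrive
            # (e.g. the task graph was cancelled); shut down cleanly.
            break
        for placement in response.placements:
            apply_placement(placement)
        time.sleep(poll_interval)
```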
180 changes: 121 additions & 59 deletions rpc/service.py
@@ -348,12 +348,11 @@ async def DeregisterDriver(self, request, context):
task_graph_name = self._registered_app_drivers[request.id]
del self._registered_app_drivers[request.id]

- # Log stats
- log_stats_event = Event(
-     event_type=EventType.LOG_STATS,
-     time=stime,
- )
with self._lock:
+ log_stats_event = Event(
+     event_type=EventType.LOG_STATS,
+     time=self.__stime(),
+ )
self._simulator._event_queue.add_event(log_stats_event)

msg = f"[{stime}] Successfully de-registered driver with id {request.id} for task graph {task_graph_name}"
@@ -469,27 +468,28 @@ async def RegisterEnvironmentReady(self, request, context):
)

r = self._registered_applications[request.id]

# Generate the task graph now
r.generate_task_graph(stime)

- self._workload_loader.add_task_graph(r.task_graph)

- update_workload_event = Event(
-     event_type=EventType.UPDATE_WORKLOAD,
-     time=stime,
- )
- scheduler_start_event = Event(
-     event_type=EventType.SCHEDULER_START,
-     time=stime.to(EventTime.Unit.US),
- )

with self._lock:
- self._simulator._event_queue.add_event(update_workload_event)
- self._simulator._event_queue.add_event(scheduler_start_event)
- self._logger.info(
-     f"[{stime}] Added event {update_workload_event} to the simulator's event queue"
- )
+ self._simulator._workload.add_task_graph(r.task_graph)
+ self._simulator._current_task_graph_placements[r.task_graph.name] = {}
+
+ for task in r.task_graph.get_releasable_tasks():
+     task_release_event = Event(
+         event_type=EventType.TASK_RELEASE,
+         time=self.__stime(),
+         task=task,
+     )
+     self._logger.info(
+         f"[{stime}] Added event {task_release_event} to the simulator's event queue",
+     )
+     self._simulator._event_queue.add_event(task_release_event)
+
+ scheduler_start_event = Event(
+     event_type=EventType.SCHEDULER_START,
+     time=self.__stime(),
+ )
+ self._simulator._event_queue.add_event(scheduler_start_event)
+ self._logger.info(
+     f"[{stime}] Added event {scheduler_start_event} to the simulator's event queue"
+ )
@@ -563,30 +563,98 @@ async def GetPlacements(self, request, context):
placements=[],
)

# Check if the task graph is active
if r.task_graph.is_complete():
msg = f"[{stime}] Task graph '{r.task_graph.name}' is complete. No more placements to provide."
self._logger.error(msg)
return erdos_scheduler_pb2.GetPlacementsResponse(
- success=False,
+ success=True,
message=msg,
)

+ # A task graph is considered complete if **all** of its **sink** tasks
+ # are complete. It is considered cancelled if **any** of its **sink**
+ # tasks are cancelled.
+
+ # If the task graph is complete, the Spark application will
+ # automatically shut down, because it knows that all of its stages
+ # have finished executing.
+
+ # Matters get interesting in the presence of task cancellations. The
+ # service is aware of which tasks are cancelled.
+
+ # First, even when a task graph is cancelled, the simulator (without
+ # orchestration) continues to schedule and execute any tasks that were
+ # already released into the system. The service, which runs the
+ # simulator in orchestrated mode, must emulate this behavior to
+ # maintain parity.
+
+ # Second, from Spark's point of view, those tasks are still awaiting
+ # placements. So, Spark will continue to periodically invoke
+ # `GetPlacements` in the hope of receiving placements for the
+ # cancelled tasks. Left unhandled, the Spark application will loop
+ # indefinitely, waiting for placements that will never arrive.
+
+ # We _could_ communicate these task cancellations to Spark. We could
+ # then modify the DAGScheduler to invoke `GetPlacements` until all of
+ # its stages have either finished executing or been cancelled, after
+ # which it could safely terminate the application.
+
+ # However, we run into an issue due to VIRTUAL tasks. When a task is
+ # cancelled, the simulator invokes `TaskGraph.cancel(task)`, which
+ # traverses the tree rooted at `task` depth-first, cancelling tasks
+ # along the way until it finds the first terminal task. As a
+ # consequence, the tree rooted at a cancelled task may still contain
+ # VIRTUAL tasks. These virtual tasks will never receive placements
+ # because they are not releasable, so the Spark application could
+ # stall in `GetPlacements` waiting on placements for them.
+
+ # Since the service knows the state of each task, it can easily
+ # determine when the Spark application should terminate in the
+ # presence of task cancellations.
+
+ # So, instead of communicating task cancellations, we communicate when
+ # the Spark application should terminate.
+ #
+ # The first check below makes sure all tasks are either CANCELLED,
+ # COMPLETED, or VIRTUAL. We check all tasks because the simulator may
+ # still be processing released and scheduled tasks. If we terminate
+ # early, we will never receive `NotifyTaskCompletion`s for those tasks
+ # (because the Spark application was terminated), and those tasks
+ # would then never be removed from the worker pool.
+ #
+ # The second check makes sure that the task graph is indeed cancelled.
+ # We need this additional guard because at the start all tasks are
+ # VIRTUAL, and we don't want to terminate the application then.
+
+ should_terminate = all(
+     task.state
+     in (
+         TaskState.CANCELLED,
+         TaskState.COMPLETED,
+         TaskState.VIRTUAL,
+     )
+     for task in r.task_graph
+ ) and r.task_graph.is_cancelled()
+ if should_terminate:
+     msg = f"[{stime}] Task graph '{r.task_graph.name}' was cancelled. No more placements to provide."
+     self._logger.error(msg)
+     return erdos_scheduler_pb2.GetPlacementsResponse(
+         success=True,
+         message=msg,
+         terminate=True,
+     )

with self._lock:
sim_placements = self._simulator.get_current_placements_for_task_graph(
r.task_graph.name
)

self._logger.info(
f"Received the following placements for '{r.task_graph.name}': {sim_placements}"
)

# Construct response. Notably, we apply stage-id mapping
placements = []
for placement in sim_placements:
- # Ignore virtual placements
- if placement.task.state < TaskState.RELEASED:
-     self._logger.debug("[{stime}] Skipping placement: {placement}")
+ if placement.task.state != TaskState.RUNNING:
+     self._logger.debug(f"[{stime}] Skipping placement: {placement}")
continue

worker_id = (
@@ -601,23 +669,17 @@ async def GetPlacements(self, request, context):
else 0
)

- if placement.placement_type not in (
-     Placement.PlacementType.PLACE_TASK,
-     Placement.PlacementType.CANCEL_TASK,
- ):
+ if placement.placement_type not in (Placement.PlacementType.PLACE_TASK,):
raise NotImplementedError

placements.append(
{
"worker_id": worker_id,
"application_id": request.id,
"task_id": int(task_id),
"task_id": task_id,
"cores": cores,
"cancelled": placement.placement_type
== Placement.PlacementType.CANCEL_TASK,
}
},
)
self._logger.info(f"Sending placements for '{r.task_graph.name}': {placements}")

return erdos_scheduler_pb2.GetPlacementsResponse(
success=True,
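The long comment block above reduces to a small predicate. Here is a standalone sketch of the same termination logic, using a toy `TaskState` enum purely for illustration (the member ordering is an assumption; the real enum lives in the simulator's codebase):

```python
from enum import IntEnum


class TaskState(IntEnum):
    # Toy stand-in; only the members the check needs are modeled.
    VIRTUAL = 0
    RELEASED = 1
    RUNNING = 2
    COMPLETED = 3
    CANCELLED = 4


def should_terminate(task_graph) -> bool:
    # First check: every task has settled, so no in-flight task will
    # still send a NotifyTaskCompletion after we terminate.
    all_settled = all(
        task.state in (TaskState.CANCELLED, TaskState.COMPLETED, TaskState.VIRTUAL)
        for task in task_graph
    )
    # Second check: guard against the startup case, where every task is
    # VIRTUAL but nothing has been cancelled.
    return all_settled and task_graph.is_cancelled()
```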
@@ -660,28 +722,28 @@ async def NotifyTaskCompletion(self, request, context):
task.start_time + task.slowest_execution_strategy.runtime
)

- # NOTE: Although the actual_task_completion_time works for task completion notifications that arrive early, it is
- # inaccurate for task completion notifications that occur past that time. Thus, a max of the current and actual completion time
- # is taken to ensure that the task is marked completed at the correct time.
- task_finished_event = Event(
-     event_type=EventType.TASK_FINISHED,
-     time=max(actual_task_completion_time, stime),
-     task=task,
- )
- scheduler_start_event = Event(
-     event_type=EventType.SCHEDULER_START,
-     time=max(
-         actual_task_completion_time.to(EventTime.Unit.US),
-         stime.to(EventTime.Unit.US),
-     ),
- )

with self._lock:
+ # NOTE: actual_task_completion_time is correct for completion notifications
+ # that arrive early, but inaccurate for notifications that arrive after that
+ # time. Taking the max of the current and actual completion times ensures
+ # that the task is marked completed at the correct time.
+ task_finished_event = Event(
+     event_type=EventType.TASK_FINISHED,
+     time=max(actual_task_completion_time, self.__stime()),
+     task=task,
+ )
self._simulator._event_queue.add_event(task_finished_event)
- self._simulator._event_queue.add_event(scheduler_start_event)
self._logger.info(
f"[{stime}] Adding event {task_finished_event} to the simulator's event queue"
)

+ scheduler_start_event = Event(
+     event_type=EventType.SCHEDULER_START,
+     time=max(
+         actual_task_completion_time.to(EventTime.Unit.US),
+         self.__stime(),
+     ),
+ )
+ self._simulator._event_queue.add_event(scheduler_start_event)
self._logger.info(
f"[{stime}] Added event {scheduler_start_event} to the simulator's event queue"
)
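The `max(...)` clamp above is easy to sanity-check in isolation. A toy example with integers standing in for `EventTime` values (an assumption made purely for brevity):

```python
def completion_time(actual, now):
    # Early notification (now < actual): keep the modeled finish time.
    # Late notification (now > actual): clamp forward, since a task
    # cannot be marked finished in the simulator's past.
    return max(actual, now)


assert completion_time(100, 80) == 100   # notification arrived early
assert completion_time(100, 130) == 130  # notification arrived late
```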
@@ -698,7 +760,7 @@ async def _tick_simulator(self):
with self._lock:
if self._simulator is not None:
stime = self.__stime()
self._logger.debug(f"[{stime}] Simulator tick")
# self._logger.debug(f"[{stime}] Simulator tick")
self._simulator.tick(until=stime)
# else:
# print("Simulator instance is None")
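For reference, `_tick_simulator` follows a common orchestration pattern: a background coroutine periodically advances the simulator to the service's current time while holding the lock shared with the RPC handlers. A minimal sketch of that pattern (the class name and tick interval are assumptions, not the service's actual code):

```python
import asyncio
import threading
import time


class TickingService:
    TICK_INTERVAL_S = 0.1  # assumed polling cadence

    def __init__(self, simulator):
        self._simulator = simulator
        self._lock = threading.Lock()  # shared with the RPC handlers
        self._start = time.time()

    def _stime(self):
        # Wall-clock seconds elapsed since the service started.
        return time.time() - self._start

    async def tick_forever(self):
        while True:
            with self._lock:
                if self._simulator is not None:
                    # Drain every simulator event due at or before "now".
                    self._simulator.tick(until=self._stime())
            await asyncio.sleep(self.TICK_INTERVAL_S)
```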
8 changes: 4 additions & 4 deletions simulator.py
@@ -525,8 +525,6 @@ def tick(self, until: EventTime) -> None:
"""Tick the simulator until the specified time"""

def f():
self._logger.debug(f"EQ: {self._event_queue}")

time_until_next_event = self.__time_until_next_event()

if (
@@ -1224,6 +1222,7 @@ def __handle_task_finished(self, event: Event) -> None:
task_placed_at_worker_pool = self._worker_pools.get_worker_pool(
event.task.worker_pool_id
)
+
task_placed_at_worker_pool.remove_task(current_time=event.time, task=event.task)

# Remove the task from its task graph's current placements
@@ -1626,7 +1625,9 @@ def is_source_task(task):
task_graph = self._workload.get_task_graph(task.task_graph)
return task_graph.is_source_task(task)

- releasable_tasks = [task for task in releasable_tasks if is_source_task(task)]
+ releasable_tasks = [
+     task for task in releasable_tasks if is_source_task(task)
+ ]

self._logger.info(
"[%s] The WorkloadLoader %s has %s TaskGraphs that released %s tasks.",
@@ -1676,7 +1677,6 @@ def is_source_task(task):

max_release_time = self._simulator_time
for task in releasable_tasks:
-
event = Event(
event_type=EventType.TASK_RELEASE, time=task.release_time, task=task
)
