Skip to content

Commit

Permalink
[RPC] End-to-End working Spark workflow w/ sequential scheduling.
Browse files Browse the repository at this point in the history
  • Loading branch information
sukritkalra committed Feb 21, 2024
1 parent c37ba23 commit 6b6bb4a
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 32 deletions.
169 changes: 137 additions & 32 deletions rpc/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import sys
import time
from collections import defaultdict
from concurrent import futures
from typing import Mapping, Sequence
from urllib.parse import urlparse
Expand All @@ -16,7 +17,7 @@
from absl import app, flags

from utils import EventTime, setup_logging
from workers import Worker, WorkerPool
from workers import Worker, WorkerPool, WorkerPools
from workload import (
ExecutionStrategies,
ExecutionStrategy,
Expand Down Expand Up @@ -57,16 +58,21 @@ def __init__(self) -> None:

# The simulator types maintained by the Servicer.
self._worker_pool = None
self._worker_pools = None
self._drivers: Mapping[str, Task] = {}
self._workload = None

# Scheduler information maintained by the servicer.
self._scheduler_running_lock = asyncio.Lock()
self._scheduler_running = False
self._rerun_scheduler = False

# Placement information maintained by the servicer.
# The placements map the application IDs to the Placement retrieved from the
# scheduler. The placements are automatically clipped at the time of informing
# the framework of applying them to the executors.
# NOTE (Sukrit): This must always be sorted by the Placement time.
self._placements: Sequence[Placement] = []
self._placements: Mapping[str, Sequence[Placement]] = defaultdict(list)

super().__init__()

Expand All @@ -80,42 +86,65 @@ async def schedule(self) -> None:
return
self._scheduler_running = True

current_time = EventTime(int(time.time()), EventTime.Unit.S)
self._logger.info(
"Starting a scheduling cycle with %s TaskGraphs and %s Workers.",
"Starting a scheduling cycle with %s TaskGraphs and %s Workers at %s.",
len(self._workload.task_graphs),
len(self._worker_pool.workers),
current_time,
)

# TODO (Sukrit): Change this to a better implementation.
# Let's do some simple scheduling for now, that gives a fixed number of
# executors to all the available applications in intervals of 10 seconds.
if len(self._workload.task_graphs) > 0:
task_graph = next(iter(self._workload.task_graphs.values()))
task = task_graph.get_source_tasks()[0]
strategy = task.available_execution_strategies.get_fastest_strategy()
for i in range(5, 10, 6):
self._placements.append(
Placement(
type=Placement.PlacementType.PLACE_TASK,
computation=task,
placement_time=EventTime(
int(time.time()) + i, EventTime.Unit.S
),
worker_pool_id=self._worker_pool.id,
worker_id=self._worker_pool.workers[0].name,
strategy=strategy,
)
tasks = self._workload.get_schedulable_tasks(
current_time, worker_pools=self._worker_pools
)
self._logger.info(
"Found %s tasks that can be scheduled at %s: %s",
len(tasks),
current_time,
[task.unique_name for task in tasks],
)
if len(tasks) > 0:
task = tasks[0]
strategy = task.available_execution_strategies.get_fastest_strategy()
placement = Placement(
type=Placement.PlacementType.PLACE_TASK,
computation=tasks[0],
placement_time=EventTime(int(time.time()) + 5, EventTime.Unit.S),
worker_pool_id=self._worker_pool.id,
worker_id=self._worker_pool.workers[0].name,
strategy=strategy,
)
self._placements[task.task_graph].append(placement)
task.schedule(
time=placement.placement_time,
placement=placement,
)

self._logger.info("Finished a scheduling cycle.")

# Check if another run of the Scheduler has been requested, and if so, create
# a task for it. Otherwise, mark the scheduler as not running.
async with self._scheduler_running_lock:
self._scheduler_running = False
if self._rerun_scheduler:
self._rerun_scheduler = False
asyncio.create_task(self.schedule())

async def run_scheduler(self) -> None:
"""Checks if the scheduler is running, and if not, starts it."""
"""Checks if the scheduler is running, and if not, starts it.
If the scheduler is already running, we queue up another execution of the
scheduler. This execution batches the scheduling requests, and runs the
scheduler only once for all the requests."""
async with self._scheduler_running_lock:
if not self._scheduler_running:
asyncio.create_task(self.schedule())
else:
self._rerun_scheduler = True

async def RegisterFramework(self, request, context):
"""Registers a new framework with the backend scheduler.
Expand Down Expand Up @@ -149,6 +178,7 @@ async def RegisterFramework(self, request, context):
# Setup the simulator types.
parsed_uri = urlparse(self._master_uri)
self._worker_pool = WorkerPool(name=f"WorkerPool_{parsed_uri.netloc}")
self._worker_pools = WorkerPools(worker_pools=[self._worker_pool])
self._workload = Workload.from_task_graphs({})

# Return the response.
Expand Down Expand Up @@ -345,6 +375,9 @@ async def RegisterTaskGraph(self, request, context):
# type so that we can correlate the Tasks with a particular invocation.
timestamp=1,
)
# NOTE (Sukrit): We maintain the StageID of the Task as a separate field
# that is not accessible / used by the Simulator.
task_ids_to_task[framework_task.id].stage_id = framework_task.id
self._logger.info(
"Constructed Task %s for the TaskGraph %s.",
framework_task.name,
Expand Down Expand Up @@ -373,9 +406,6 @@ async def RegisterTaskGraph(self, request, context):
str(task_graph),
)

# Run the scheduler since the Workload has changed.
await self.run_scheduler()

# Return the response.
return erdos_scheduler_pb2.RegisterTaskGraphResponse(
success=True,
Expand Down Expand Up @@ -430,8 +460,8 @@ async def RegisterEnvironmentReady(self, request, context):
for source_task in task_graph.get_source_tasks():
source_task.release(EventTime(request.timestamp, EventTime.Unit.S))

# TODO (Sukrit): A new application has arrived, we need to queue up the
# execution of the scheduler.
# Run the scheduler since the Workload has changed.
await self.run_scheduler()

return erdos_scheduler_pb2.RegisterEnvironmentReadyResponse(
success=True,
Expand Down Expand Up @@ -518,38 +548,113 @@ async def RegisterWorker(self, request, context):
success=True, message=f"Worker {request.name} registered successfully!"
)

async def NotifyTaskCompletion(self, request, context):
    """Handles a framework notification that a task has finished executing.

    The request is rejected (with `success=False`) if no framework has been
    registered, if the application's TaskGraph is unknown to the Workload, or
    if no Task in that TaskGraph carries the requested stage ID. Otherwise,
    the matching Task is marked as finished at the request's timestamp and a
    scheduler run is requested.

    Args:
        request: The notification; its `task_id`, `application_id` and
            `timestamp` fields are read.
        context: The RPC context (unused by this handler).

    Returns:
        A `NotifyTaskCompletionResponse` describing whether the completion
        was recorded.
    """
    task_id = request.task_id
    application_id = request.application_id

    # All the warnings emitted below share this opening; only the reason for
    # the rejection differs.
    warning_prefix = (
        "Trying to notify the backend scheduler that the task with ID %s "
        "from application %s has completed, "
    )

    # Guard: a framework must have been registered first.
    if not self._initialized:
        self._logger.warning(
            warning_prefix + "but no framework is registered yet.",
            task_id,
            application_id,
        )
        return erdos_scheduler_pb2.NotifyTaskCompletionResponse(
            success=False, message="Framework not registered yet."
        )

    # Guard: the application must have a TaskGraph in the Workload.
    task_graph = self._workload.get_task_graph(application_id)
    if task_graph is None:
        self._logger.warning(
            warning_prefix
            + "but the application was not registered with the backend yet.",
            task_id,
            application_id,
        )
        return erdos_scheduler_pb2.NotifyTaskCompletionResponse(
            success=False,
            message=f"Application with ID {application_id} not registered yet.",
        )

    # Locate the Task whose stage ID corresponds to the completed task. Every
    # node is inspected, and the last match is the one that is used.
    candidates = [
        node for node in task_graph.get_nodes() if node.stage_id == task_id
    ]
    if not candidates:
        self._logger.warning(
            warning_prefix + "but the task was not found in the TaskGraph.",
            task_id,
            application_id,
        )
        return erdos_scheduler_pb2.NotifyTaskCompletionResponse(
            success=False,
            message=f"Task with ID {task_id} "
            f"not found in TaskGraph {application_id}.",
        )
    completed_task = candidates[-1]

    # Zero out the remaining time before finishing the Task at the timestamp
    # reported by the framework.
    completed_task.update_remaining_time(EventTime.zero())
    completed_task.finish(EventTime(request.timestamp, EventTime.Unit.S))

    # The Workload has changed, so request another run of the scheduler.
    await self.run_scheduler()

    return erdos_scheduler_pb2.NotifyTaskCompletionResponse(
        success=True,
        message=f"Task with ID {task_id} completed successfully!",
    )

async def GetPlacements(self, request, context):
"""Retrieves the placements applicable at the specified time."""
request_timestamp = EventTime(request.timestamp, EventTime.Unit.S)
if not self._initialized:
self._logger.warning(
"Trying to get placements at time %s, "
"Trying to get placements for %s at time %s, "
"but no framework is registered yet.",
request.timestamp,
request.id,
request_timestamp,
)
return erdos_scheduler_pb2.GetPlacementsResponse(
success=False, message="Framework not registered yet."
)

# Construct and return the placements.
if request.id not in self._placements:
self._logger.warning(
"Trying to get placements for %s at time %s, but the application "
"was not registered with the backend yet.",
request.id,
request_timestamp,
)

# Construct and return the placements.
placements = []
clip_at = -1
request_timestamp = EventTime(request.timestamp, EventTime.Unit.S)
for index, placement in enumerate(self._placements):
for index, placement in enumerate(self._placements[request.id]):
if placement.placement_time <= request_timestamp:
clip_at = index
# Mark the Task as RUNNING.
placement.task.start(request_timestamp)

# resources = placement.execution_strategy.resources
placements.append(
erdos_scheduler_pb2.Placement(
worker_id=placement.worker_id,
application_id=placement.task.task_graph,
application_id=request.id,
task_id=placement.task.stage_id,
cores=1,
)
)
self._logger.info(
"Currently %s placements, clipping at %s.", len(placements), clip_at
)
self._placements = self._placements[clip_at + 1 :]
self._logger.info("Clipped placements length: %s", len(self._placements))
self._placements[request.id] = self._placements[request.id][clip_at + 1 :]
self._logger.info(
"Clipped placements length: %s", len(self._placements[request.id])
)
self._logger.info(
"Constructed %s placements at time %s.", len(placements), request.timestamp
)
Expand Down
22 changes: 22 additions & 0 deletions workload/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1589,6 +1589,28 @@ def get_remaining_time(
# Find the maximum remaining time across all the sink nodes.
return max([remaining_time[sink] for sink in self.get_sink_tasks()])

def get_task(self, name: str) -> Optional[Task]:
    """Retrieve the Task with the given name from the TaskGraph.

    Args:
        name (`str`): The name of the Task to retrieve.

    Returns:
        The `Task` with the given name, or `None` if no such Task exists in
        the TaskGraph.

    Raises:
        `ValueError` if multiple tasks with the same name are found.
    """
    matched_task = None
    for task in self.get_nodes():
        if task.name != name:
            continue
        # Compare against `None` explicitly instead of relying on the
        # truthiness of the Task, so that a Task that evaluates to False
        # cannot mask a duplicate name.
        if matched_task is not None:
            raise ValueError(f"Multiple tasks with the name {name} found.")
        matched_task = task
    return matched_task

@property
def deadline(self) -> EventTime:
"""Retrieve the deadline to which the TaskGraph is being subjected to.
Expand Down

0 comments on commit 6b6bb4a

Please sign in to comment.