Skip to content

Commit

Permalink
Merge pull request #93 from erdos-project/dg/add_queue_task_completion
Browse files Browse the repository at this point in the history
Enforcing correct task completion time using priority queue
  • Loading branch information
dhruvsgarg authored Mar 23, 2024
2 parents 1f2fd81 + c8b5608 commit 355e4e1
Showing 1 changed file with 106 additions and 9 deletions.
115 changes: 106 additions & 9 deletions rpc/service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import heapq
import os
import sys
import time
Expand Down Expand Up @@ -67,6 +68,29 @@
)


# Define an item containing completion timestamp and task
class TimedItem:
def __init__(self, timestamp, task):
self.timestamp = timestamp
self.task = task


# Define a priority queue based on heapq module
class PriorityQueue:
def __init__(self):
self._queue = []

def put(self, item):
heapq.heappush(self._queue, (item.timestamp, item))

def get(self):
_, item = heapq.heappop(self._queue)
return item

def empty(self):
return len(self._queue) == 0


# Implement the service.
class SchedulerServiceServicer(erdos_scheduler_pb2_grpc.SchedulerServiceServicer):
def __init__(self) -> None:
Expand Down Expand Up @@ -101,6 +125,12 @@ def __init__(self) -> None:
# NOTE (Sukrit): This must always be sorted by the Placement time.
self._placements: Mapping[str, Sequence[Placement]] = defaultdict(list)

# Additional task information maintained by the servicer
self._tasks_marked_for_completion = PriorityQueue()

# Start the asyncio loop for clearing out pending tasks for completion
asyncio.create_task(self.PopTasksBasedOnTime())

super().__init__()

async def schedule(self) -> None:
Expand Down Expand Up @@ -404,7 +434,7 @@ async def RegisterTaskGraph(self, request, context):
# Construct all the Tasks for the TaskGraph.
task_ids_to_task: Mapping[int, Task] = {}
default_resource = Resources(
resource_vector={Resource(name="Slot_CPU", _id="any"): 30}
resource_vector={Resource(name="Slot_CPU", _id="any"): 20}
)
default_runtime = EventTime(20, EventTime.Unit.US)

Expand Down Expand Up @@ -685,16 +715,42 @@ async def NotifyTaskCompletion(self, request, context):
f"not found in TaskGraph {request.application_id}.",
)

# Mark the Task as completed.
matched_task.update_remaining_time(EventTime.zero())
matched_task.finish(EventTime(request.timestamp, EventTime.Unit.S))
# Instead of completing & removing the task immediately, check
# if it is actually complete or will complete in the future

# Run the scheduler since the Workload has changed.
await self.run_scheduler()
# Get the actual task completion timestamp
actual_task_completion_time = (
matched_task.start_time.time + matched_task.remaining_time.time
)

current_time = time.time()
self._logger.info(
"Received task for completion at time: %s , task.start_time: %s ,"
"task.remaining_time (=runtime): %s , actual completion time: %s ",
round(current_time),
matched_task.start_time.time,
matched_task.remaining_time.time,
actual_task_completion_time,
)

# TODO DG: remaining_time assumes execution of the slowest strategy
# Should be updated to reflect correct remaining_time based on chosen strategy?

# Add all tasks to _tasks_marked_for_completion queue.
# If task has actually completed, it will be dequeued immediately
# Else it will be dequeued at its actual task completion time
self._tasks_marked_for_completion.put(
TimedItem(actual_task_completion_time, matched_task)
)

# NOTE: task.finish() and run_scheduler() invocations are postponed
# until it is time for the task to be actually marked as complete.

return erdos_scheduler_pb2.NotifyTaskCompletionResponse(
success=True,
message=f"Task with ID {request.task_id} completed successfully!",
message=f"Task with ID {request.task_id} marked for completion at "
f"{round(current_time)}! It will be removed on actual "
f"task completion time at {actual_task_completion_time}",
)

async def GetPlacements(self, request, context):
Expand Down Expand Up @@ -751,6 +807,42 @@ async def GetPlacements(self, request, context):
f"placements at time {request.timestamp}.",
)

# Function to pop tasks from queue based on actual completion time
async def PopTasksBasedOnTime(self):
while True:
if not self._tasks_marked_for_completion.empty():
# Get the top item from the priority queue
top_item = self._tasks_marked_for_completion._queue[0][1]

# Check if top item's timestamp is reached or passed by current time
current_time = time.time()
if top_item.timestamp <= current_time:
# Pop the top item
popped_item = self._tasks_marked_for_completion.get()
self._logger.info(
"Removing tasks from pending completion queue: %s at time: %s",
popped_item.task,
current_time,
)

# Mark the Task as completed.
# Also release the task from the scheduler service
popped_item.task.update_remaining_time(EventTime.zero())
popped_item.task.finish(
EventTime(round(current_time), EventTime.Unit.S)
)

# Run the scheduler since the Workload has changed.
await self.run_scheduler()

else:
# If the top item's timestamp hasn't been reached yet,
# sleep for a short duration
await asyncio.sleep(0.1) # TODO: Can adjust value, curr=0.1s
else:
# If the queue is empty, sleep for a short duration
await asyncio.sleep(0.1) # TODO: Can adjust value, curr=0.1s


async def serve():
"""Serves the ERDOS Scheduling RPC Server."""
Expand All @@ -768,9 +860,14 @@ async def serve():


def main(argv):
# Create an asyncio event loop
loop = asyncio.get_event_loop()
loop.run_until_complete(serve())
loop.close()

# Run the event loop until serve() completes
try:
loop.run_until_complete(serve())
finally:
loop.close()


if __name__ == "__main__":
Expand Down

0 comments on commit 355e4e1

Please sign in to comment.