diff --git a/rpc/service.py b/rpc/service.py
index 8934d80c..2aaa2dc9 100644
--- a/rpc/service.py
+++ b/rpc/service.py
@@ -1,4 +1,5 @@
 import asyncio
+import heapq
 import os
 import sys
 import time
@@ -67,6 +68,39 @@
 )
 
 
+# Define an item containing completion timestamp and task
+class TimedItem:
+    """Pairs a task with the timestamp at which it should be completed."""
+
+    def __init__(self, timestamp, task):
+        self.timestamp = timestamp
+        self.task = task
+
+    def __lt__(self, other):
+        # PriorityQueue stores (timestamp, item) tuples. When two timestamps
+        # are equal, tuple comparison falls through to the items themselves,
+        # which raises TypeError unless the items are orderable.
+        return self.timestamp < other.timestamp
+
+
+# Define a priority queue based on heapq module
+class PriorityQueue:
+    """Min-heap of TimedItems ordered by their timestamp."""
+
+    def __init__(self):
+        self._queue = []
+
+    def put(self, item):
+        heapq.heappush(self._queue, (item.timestamp, item))
+
+    def get(self):
+        _, item = heapq.heappop(self._queue)
+        return item
+
+    def empty(self):
+        return len(self._queue) == 0
+
+
 # Implement the service.
 class SchedulerServiceServicer(erdos_scheduler_pb2_grpc.SchedulerServiceServicer):
     def __init__(self) -> None:
@@ -101,6 +135,14 @@ def __init__(self) -> None:
         # NOTE (Sukrit): This must always be sorted by the Placement time.
         self._placements: Mapping[str, Sequence[Placement]] = defaultdict(list)
 
+        # Additional task information maintained by the servicer
+        self._tasks_marked_for_completion = PriorityQueue()
+
+        # Start the asyncio loop for clearing out pending tasks for completion.
+        # Keep a reference to the task: the event loop only holds a weak
+        # reference, so an unreferenced task may be garbage collected mid-run.
+        self._completion_poller = asyncio.create_task(self.PopTasksBasedOnTime())
+
         super().__init__()
 
     async def schedule(self) -> None:
@@ -404,7 +446,7 @@ async def RegisterTaskGraph(self, request, context):
         # Construct all the Tasks for the TaskGraph.
         task_ids_to_task: Mapping[int, Task] = {}
         default_resource = Resources(
-            resource_vector={Resource(name="Slot_CPU", _id="any"): 30}
+            resource_vector={Resource(name="Slot_CPU", _id="any"): 20}
         )
         default_runtime = EventTime(20, EventTime.Unit.US)
 
@@ -685,16 +727,44 @@ async def NotifyTaskCompletion(self, request, context):
                 f"not found in TaskGraph {request.application_id}.",
             )
 
-        # Mark the Task as completed.
-        matched_task.update_remaining_time(EventTime.zero())
-        matched_task.finish(EventTime(request.timestamp, EventTime.Unit.S))
+        # Instead of completing & removing the task immediately, check
+        # if it is actually complete or will complete in the future
 
-        # Run the scheduler since the Workload has changed.
-        await self.run_scheduler()
+        # Get the actual task completion timestamp
+        # NOTE(review): this sums the raw .time values and is later compared
+        # against time.time(); confirm both EventTimes are expressed in seconds.
+        actual_task_completion_time = (
+            matched_task.start_time.time + matched_task.remaining_time.time
+        )
+
+        current_time = time.time()
+        self._logger.info(
+            "Received task for completion at time: %s , task.start_time: %s ,"
+            "task.remaining_time (=runtime): %s , actual completion time: %s ",
+            round(current_time),
+            matched_task.start_time.time,
+            matched_task.remaining_time.time,
+            actual_task_completion_time,
+        )
+
+        # TODO DG: remaining_time assumes execution of the slowest strategy
+        # Should be updated to reflect correct remaining_time based on chosen strategy?
+
+        # Add all tasks to _tasks_marked_for_completion queue.
+        # If task has actually completed, it will be dequeued immediately
+        # Else it will be dequeued at its actual task completion time
+        self._tasks_marked_for_completion.put(
+            TimedItem(actual_task_completion_time, matched_task)
+        )
+
+        # NOTE: task.finish() and run_scheduler() invocations are postponed
+        # until it is time for the task to be actually marked as complete.
 
         return erdos_scheduler_pb2.NotifyTaskCompletionResponse(
             success=True,
-            message=f"Task with ID {request.task_id} completed successfully!",
+            message=f"Task with ID {request.task_id} marked for completion at "
+            f"{round(current_time)}! It will be removed on actual "
+            f"task completion time at {actual_task_completion_time}",
         )
 
     async def GetPlacements(self, request, context):
@@ -751,6 +821,43 @@ async def GetPlacements(self, request, context):
                 f"placements at time {request.timestamp}.",
             )
 
+    # Function to pop tasks from queue based on actual completion time
+    async def PopTasksBasedOnTime(self):
+        """Polls the pending queue and finishes tasks whose time has passed."""
+        while True:
+            if not self._tasks_marked_for_completion.empty():
+                # Get the top item from the priority queue
+                top_item = self._tasks_marked_for_completion._queue[0][1]
+
+                # Check if top item's timestamp is reached or passed by current time
+                current_time = time.time()
+                if top_item.timestamp <= current_time:
+                    # Pop the top item
+                    popped_item = self._tasks_marked_for_completion.get()
+                    self._logger.info(
+                        "Removing tasks from pending completion queue: %s at time: %s",
+                        popped_item.task,
+                        current_time,
+                    )
+
+                    # Mark the Task as completed.
+                    # Also release the task from the scheduler service
+                    popped_item.task.update_remaining_time(EventTime.zero())
+                    popped_item.task.finish(
+                        EventTime(round(current_time), EventTime.Unit.S)
+                    )
+
+                    # Run the scheduler since the Workload has changed.
+                    await self.run_scheduler()
+
+                else:
+                    # If the top item's timestamp hasn't been reached yet,
+                    # sleep for a short duration
+                    await asyncio.sleep(0.1)  # TODO: Can adjust value, curr=0.1s
+            else:
+                # If the queue is empty, sleep for a short duration
+                await asyncio.sleep(0.1)  # TODO: Can adjust value, curr=0.1s
+
 
 async def serve():
     """Serves the ERDOS Scheduling RPC Server."""
@@ -768,9 +875,14 @@ async def serve():
 
 
 def main(argv):
+    # Create an asyncio event loop
     loop = asyncio.get_event_loop()
-    loop.run_until_complete(serve())
-    loop.close()
+
+    # Run the event loop until serve() completes
+    try:
+        loop.run_until_complete(serve())
+    finally:
+        loop.close()
 
 
 if __name__ == "__main__":