diff --git a/main.py b/main.py index c501e4cc..fe3e49b9 100644 --- a/main.py +++ b/main.py @@ -414,6 +414,12 @@ "If True, the scheduler writes the status of each run to a seperate" "log file in a format unique to every scheduler.", ) +flags.DEFINE_bool( + "decompose_deadlines", + False, + "If True, the task deadline is decided by decomposing the TaskGraph's deadline " + "according to the critical path of the TaskGraph.", +) flags.DEFINE_list( "scheduler_log_times", [], diff --git a/workload/jobs.py b/workload/jobs.py index de469191..e0acb2e4 100644 --- a/workload/jobs.py +++ b/workload/jobs.py @@ -886,8 +886,57 @@ def _generate_task_graph( task_graph_deadline = release_time + weighted_task_graph_length.fuzz( deadline_variance, deadline_bounds ) - for task in task_graph.get_nodes(): - task.update_deadline(task_graph_deadline) + if _flags and _flags.decompose_deadlines: + stages_info = {} + stages = set([]) + for task in task_graph.topological_sort(): + stage = 0 + for previous_task in task_graph.get_parents(task): + stage = max(stage, stages_info.get(previous_task, 0) + 1) + stages_info[task] = stage + stages.add(stage) + + critical_path = task_graph.get_longest_path( + weights=lambda task: (task.slowest_execution_strategy.runtime.time) + ) + critical_path_time = ( + sum( + [t.slowest_execution_strategy.runtime for t in critical_path], + start=EventTime.zero(), + ) + .to(EventTime.Unit.US) + .time + ) + stage_wise_deadline = {} + for critical_task in critical_path: + stage_deadline = int( + task_graph_deadline.to(EventTime.Unit.US).time + * critical_task.slowest_execution_strategy.runtime.to( + EventTime.Unit.US + ).time + / critical_path_time + ) + stage_wise_deadline[stages_info[critical_task]] = stage_deadline + + for task in task_graph.get_nodes(): + # For the tasks that do not fall on the critical path of the + # computation, we heuristically find the closest stage and assign it + # to that stage's deadline. + task_stage = min( + stage_wise_deadline.keys(), + key=lambda s: abs( + s - stages_info[task] + if s >= stages_info[task] + else float("inf") + ), + ) + deadline = EventTime( + stage_wise_deadline[task_stage], unit=EventTime.Unit.US + ) + task.update_deadline(deadline) + else: + for task in task_graph.get_nodes(): + task.update_deadline(task_graph_deadline) return task_graph