From fb7fa3a7283007060ffb53ec97cda0a754ffc075 Mon Sep 17 00:00:00 2001 From: Andrew Halberstadt Date: Mon, 7 Nov 2022 16:40:44 -0500 Subject: [PATCH] feat: generate kinds concurrently --- src/taskgraph/generator.py | 95 ++++++++++++++++++++++---------- src/taskgraph/transforms/base.py | 31 ++++++++++- 2 files changed, 95 insertions(+), 31 deletions(-) diff --git a/src/taskgraph/generator.py b/src/taskgraph/generator.py index 4ed2a4152..b5eca4b20 100644 --- a/src/taskgraph/generator.py +++ b/src/taskgraph/generator.py @@ -2,6 +2,7 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. +import asyncio import copy import logging import os @@ -44,7 +45,8 @@ def _get_loader(self): loader = "taskgraph.loader.default:loader" return find_object(loader) - def load_tasks(self, parameters, loaded_tasks, write_artifacts): + async def load_tasks(self, parameters, loaded_tasks, write_artifacts): + logger.debug(f"Loading tasks for kind {self.name}") loader = self._get_loader() config = copy.deepcopy(self.config) @@ -73,20 +75,22 @@ def load_tasks(self, parameters, loaded_tasks, write_artifacts): self.graph_config, write_artifacts=write_artifacts, ) + tasks = await transforms(trans_config, inputs) tasks = [ Task( self.name, - label=task_dict["label"], - description=task_dict["description"], - attributes=task_dict["attributes"], - task=task_dict["task"], - optimization=task_dict.get("optimization"), - dependencies=task_dict.get("dependencies"), - soft_dependencies=task_dict.get("soft-dependencies"), - if_dependencies=task_dict.get("if-dependencies"), + label=t["label"], + description=t["description"], + attributes=t["attributes"], + task=t["task"], + optimization=t.get("optimization"), + dependencies=t.get("dependencies"), + soft_dependencies=t.get("soft-dependencies"), + if_dependencies=t.get("if-dependencies"), ) - for task_dict in transforms(trans_config, inputs) + async for t in tasks ] + logger.info(f"Generated {len(tasks)} tasks for kind {self.name}") return tasks @classmethod @@ -249,6 +253,57 @@ def _load_kinds(self, graph_config, target_kinds=None): except KindNotFound: continue + async def _load_tasks(self, kinds, kind_graph, parameters): + all_tasks = {} + futures_to_kind = {} + + def add_new_tasks(tasks): + for task in tasks: + if task.label in all_tasks: + raise Exception("duplicate tasks with label " + task.label) + all_tasks[task.label] = task + + def create_futures(kinds, edges): + """Create the next batch of tasks for kinds without dependencies.""" + kinds_with_deps = {edge[0] for edge in edges} + ready_kinds = set(kinds) - kinds_with_deps + futures = set() + for name in ready_kinds: + task = asyncio.create_task( + kinds[name].load_tasks( + parameters, + list(all_tasks.values()), + self._write_artifacts, + ) + ) + futures.add(task) + futures_to_kind[task] = name + return futures + + edges = set(kind_graph.edges) + futures = create_futures(kinds, edges) + while len(kinds) > 0: + done, futures = await asyncio.wait( + futures, return_when=asyncio.FIRST_COMPLETED + ) + + for future in done: + add_new_tasks(future.result()) + name = futures_to_kind[future] + + # Update state for next batch of futures. + del kinds[name] + edges = {e for e in edges if e[1] != name} + + futures |= create_futures(kinds, edges) + + if futures: + done, _ = await asyncio.wait(futures, return_when=asyncio.ALL_COMPLETED) + for future in done: + add_new_tasks(future.result()) + + return all_tasks + def _run(self): logger.info("Loading graph configuration.") graph_config = load_graph_config(self.root_dir) @@ -303,24 +358,8 @@ def _run(self): ) logger.info("Generating full task set") - all_tasks = {} - for kind_name in kind_graph.visit_postorder(): - logger.debug(f"Loading tasks for kind {kind_name}") - kind = kinds[kind_name] - try: - new_tasks = kind.load_tasks( - parameters, - list(all_tasks.values()), - self._write_artifacts, - ) - except Exception: - logger.exception(f"Error loading tasks for kind {kind_name}:") - raise - for task in new_tasks: - if task.label in all_tasks: - raise Exception("duplicate tasks with label " + task.label) - all_tasks[task.label] = task - logger.info(f"Generated {len(new_tasks)} tasks for kind {kind_name}") + all_tasks = asyncio.run(self._load_tasks(kinds, kind_graph, parameters)) + full_task_set = TaskGraph(all_tasks, Graph(set(all_tasks), set())) yield self.verify("full_task_set", full_task_set, graph_config, parameters) diff --git a/src/taskgraph/transforms/base.py b/src/taskgraph/transforms/base.py index e6fcd2400..c633fd3f2 100644 --- a/src/taskgraph/transforms/base.py +++ b/src/taskgraph/transforms/base.py @@ -2,7 +2,8 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. - +import asyncio +import inspect import re from dataclasses import dataclass, field from typing import Dict, List, Union @@ -107,6 +108,12 @@ def repo_configs(self): return repo_configs +async def convert_async(it): + """Convert a synchronous iterator to an async one.""" + for i in it: + yield i + + @dataclass() class TransformSequence: """ @@ -121,11 +128,29 @@ class TransformSequence: _transforms: List = field(default_factory=list) - def __call__(self, config, items): + async def __call__(self, config, items): for xform in self._transforms: - items = xform(config, items) + if isinstance(xform, TransformSequence): + items = await xform(config, items) + elif inspect.isasyncgenfunction(xform): + # Async generator transforms require async generator inputs. + # This can happen if a synchronous transform ran immediately + # prior. + if not inspect.isasyncgen(items): + items = convert_async(items) + items = xform(config, items) + else: + # Creating a synchronous generator from an asynchronous context + # doesn't appear possible, so unfortunately we need to convert + # to a list. + if inspect.isasyncgen(items): + items = [i async for i in items] + items = xform(config, items) if items is None: raise Exception(f"Transform {xform} is not a generator") + + if not inspect.isasyncgen(items): + items = convert_async(items) return items def add(self, func):