Skip to content

Commit

Permalink
feat: generate kinds concurrently
Browse files Browse the repository at this point in the history
  • Loading branch information
ahal committed Nov 30, 2023
1 parent 931ab6f commit fb7fa3a
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 31 deletions.
95 changes: 67 additions & 28 deletions src/taskgraph/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import asyncio
import copy
import logging
import os
Expand Down Expand Up @@ -44,7 +45,8 @@ def _get_loader(self):
loader = "taskgraph.loader.default:loader"
return find_object(loader)

def load_tasks(self, parameters, loaded_tasks, write_artifacts):
async def load_tasks(self, parameters, loaded_tasks, write_artifacts):
logger.debug(f"Loading tasks for kind {self.name}")
loader = self._get_loader()
config = copy.deepcopy(self.config)

Expand Down Expand Up @@ -73,20 +75,22 @@ def load_tasks(self, parameters, loaded_tasks, write_artifacts):
self.graph_config,
write_artifacts=write_artifacts,
)
tasks = await transforms(trans_config, inputs)
tasks = [
Task(
self.name,
label=task_dict["label"],
description=task_dict["description"],
attributes=task_dict["attributes"],
task=task_dict["task"],
optimization=task_dict.get("optimization"),
dependencies=task_dict.get("dependencies"),
soft_dependencies=task_dict.get("soft-dependencies"),
if_dependencies=task_dict.get("if-dependencies"),
label=t["label"],
description=t["description"],
attributes=t["attributes"],
task=t["task"],
optimization=t.get("optimization"),
dependencies=t.get("dependencies"),
soft_dependencies=t.get("soft-dependencies"),
if_dependencies=t.get("if-dependencies"),
)
for task_dict in transforms(trans_config, inputs)
async for t in tasks
]
logger.info(f"Generated {len(tasks)} tasks for kind {self.name}")
return tasks

@classmethod
Expand Down Expand Up @@ -249,6 +253,57 @@ def _load_kinds(self, graph_config, target_kinds=None):
except KindNotFound:
continue

async def _load_tasks(self, kinds, kind_graph, parameters):
all_tasks = {}
futures_to_kind = {}

def add_new_tasks(tasks):
for task in tasks:
if task.label in all_tasks:
raise Exception("duplicate tasks with label " + task.label)
all_tasks[task.label] = task

def create_futures(kinds, edges):
"""Create the next batch of tasks for kinds without dependencies."""
kinds_with_deps = {edge[0] for edge in edges}
ready_kinds = set(kinds) - kinds_with_deps
futures = set()
for name in ready_kinds:
task = asyncio.create_task(
kinds[name].load_tasks(
parameters,
list(all_tasks.values()),
self._write_artifacts,
)
)
futures.add(task)
futures_to_kind[task] = name
return futures

edges = set(kind_graph.edges)
futures = create_futures(kinds, edges)
while len(kinds) > 0:
done, futures = await asyncio.wait(
futures, return_when=asyncio.FIRST_COMPLETED
)

for future in done:
add_new_tasks(future.result())
name = futures_to_kind[future]

# Update state for next batch of futures.
del kinds[name]
edges = {e for e in edges if e[1] != name}

futures |= create_futures(kinds, edges)

if futures:
done, _ = await asyncio.wait(futures, return_when=asyncio.ALL_COMPLETED)
for future in done:
add_new_tasks(future.result())

return all_tasks

def _run(self):
logger.info("Loading graph configuration.")
graph_config = load_graph_config(self.root_dir)
Expand Down Expand Up @@ -303,24 +358,8 @@ def _run(self):
)

logger.info("Generating full task set")
all_tasks = {}
for kind_name in kind_graph.visit_postorder():
logger.debug(f"Loading tasks for kind {kind_name}")
kind = kinds[kind_name]
try:
new_tasks = kind.load_tasks(
parameters,
list(all_tasks.values()),
self._write_artifacts,
)
except Exception:
logger.exception(f"Error loading tasks for kind {kind_name}:")
raise
for task in new_tasks:
if task.label in all_tasks:
raise Exception("duplicate tasks with label " + task.label)
all_tasks[task.label] = task
logger.info(f"Generated {len(new_tasks)} tasks for kind {kind_name}")
all_tasks = asyncio.run(self._load_tasks(kinds, kind_graph, parameters))

full_task_set = TaskGraph(all_tasks, Graph(set(all_tasks), set()))
yield self.verify("full_task_set", full_task_set, graph_config, parameters)

Expand Down
31 changes: 28 additions & 3 deletions src/taskgraph/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.


import asyncio
import inspect
import re
from dataclasses import dataclass, field
from typing import Dict, List, Union
Expand Down Expand Up @@ -107,6 +108,12 @@ def repo_configs(self):
return repo_configs


async def convert_async(it):
"""Convert a synchronous iterator to an async one."""
for i in it:
yield i


@dataclass()
class TransformSequence:
"""
Expand All @@ -121,11 +128,29 @@ class TransformSequence:

_transforms: List = field(default_factory=list)

def __call__(self, config, items):
async def __call__(self, config, items):
for xform in self._transforms:
items = xform(config, items)
if isinstance(xform, TransformSequence):
items = await xform(config, items)
elif inspect.isasyncgenfunction(xform):
# Async generator transforms require async generator inputs.
# This can happen if a synchronous transform ran immediately
# prior.
if not inspect.isasyncgen(items):
items = convert_async(items)
items = xform(config, items)
else:
# Creating a synchronous generator from an asynchronous context
# doesn't appear possible, so unfortunately we need to convert
# to a list.
if inspect.isasyncgen(items):
items = [i async for i in items]
items = xform(config, items)
if items is None:
raise Exception(f"Transform {xform} is not a generator")

if not inspect.isasyncgen(items):
items = convert_async(items)
return items

def add(self, func):
Expand Down

0 comments on commit fb7fa3a

Please sign in to comment.