
Commit

Merge pull request #1014 from roboflow/feature/caller_event_loop_in_workflows

Prepare a version of the Workflows EE in which the thread pool executor is injectable rather than created on each run
PawelPeczek-Roboflow authored Feb 11, 2025
2 parents c1493a6 + e1cfbaa commit 84f6b62
Showing 7 changed files with 152 additions and 95 deletions.
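The substance of the change: the caller can now create a single `ThreadPoolExecutor` and hand it to the Execution Engine, instead of the engine building and tearing down a fresh pool inside every run. Below is a minimal usage sketch, not a definitive recipe: it assumes the `inference` package is installed, the workflow definition and parameters are placeholders rather than values from this PR, and parameter names other than `executor` come from the pre-existing `ExecutionEngine.init` API.

```python
from concurrent.futures import ThreadPoolExecutor

from inference.core.workflows.execution_engine.core import ExecutionEngine

# Placeholder workflow specification -- substitute a real definition.
WORKFLOW_DEFINITION = {"version": "1.0", "inputs": [], "steps": [], "outputs": []}

# One pool, owned by the caller and reused across runs; previously the
# engine created a short-lived pool internally for every run.
with ThreadPoolExecutor(max_workers=8) as executor:
    engine = ExecutionEngine.init(
        workflow_definition=WORKFLOW_DEFINITION,
        init_parameters={},
        max_concurrent_steps=4,
        executor=executor,  # the new keyword added in this PR
    )
    for _ in range(3):
        _ = engine.run(runtime_parameters={})
```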
2 changes: 1 addition & 1 deletion inference/core/version.py
@@ -1,4 +1,4 @@
__version__ = "0.37.1"
__version__ = "0.38.0rc1"


if __name__ == "__main__":
3 changes: 3 additions & 0 deletions inference/core/workflows/execution_engine/core.py
@@ -1,3 +1,4 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Type

from packaging.specifiers import SpecifierSet
@@ -37,6 +38,7 @@ def init(
prevent_local_images_loading: bool = False,
workflow_id: Optional[str] = None,
profiler: Optional[WorkflowsProfiler] = None,
executor: Optional[ThreadPoolExecutor] = None,
) -> "ExecutionEngine":
requested_engine_version = retrieve_requested_execution_engine_version(
workflow_definition=workflow_definition,
@@ -51,6 +53,7 @@
prevent_local_images_loading=prevent_local_images_loading,
workflow_id=workflow_id,
profiler=profiler,
executor=executor,
)
return cls(engine=engine)

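The dispatcher itself does nothing with the pool: it resolves the requested Execution Engine version and forwards the optional `executor` untouched to the engine it instantiates. A standalone sketch of that pass-through shape follows; `select_engine_class`, `init_engine`, and `EngineV1` are hypothetical stand-ins, not names from the repository.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional, Type


class EngineV1:
    """Hypothetical stand-in for a versioned execution engine."""

    def __init__(self, executor: Optional[ThreadPoolExecutor] = None):
        self._executor = executor


REGISTERED_ENGINES: Dict[str, Type[EngineV1]] = {"1": EngineV1}


def select_engine_class(requested_version: str) -> Type[EngineV1]:
    return REGISTERED_ENGINES[requested_version.split(".")[0]]


def init_engine(
    requested_version: str,
    executor: Optional[ThreadPoolExecutor] = None,
) -> EngineV1:
    # The dispatcher never touches the pool -- it just forwards the
    # caller-owned executor to whichever engine version it selects.
    engine_class = select_engine_class(requested_version)
    return engine_class(executor=executor)


with ThreadPoolExecutor(max_workers=4) as shared_pool:
    engine = init_engine("1.0.0", executor=shared_pool)
```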
2 changes: 2 additions & 0 deletions inference/core/workflows/execution_engine/entities/engine.py
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional

from inference.core.workflows.execution_engine.profiling.core import WorkflowsProfiler
@@ -16,6 +17,7 @@ def init(
prevent_local_images_loading: bool = False,
workflow_id: Optional[str] = None,
profiler: Optional[WorkflowsProfiler] = None,
executor: Optional[ThreadPoolExecutor] = None,
) -> "BaseExecutionEngine":
pass

6 changes: 6 additions & 0 deletions inference/core/workflows/execution_engine/v1/core.py
@@ -1,3 +1,4 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional

from packaging.version import Version
@@ -36,6 +37,7 @@ def init(
prevent_local_images_loading: bool = False,
workflow_id: Optional[str] = None,
profiler: Optional[WorkflowsProfiler] = None,
executor: Optional[ThreadPoolExecutor] = None,
) -> "ExecutionEngineV1":
if init_parameters is None:
init_parameters = {}
@@ -54,6 +56,7 @@ def init(
profiler=profiler,
workflow_id=workflow_id,
internal_id=workflow_definition.get("id"),
executor=executor,
)

def __init__(
@@ -64,13 +67,15 @@ def __init__(
profiler: WorkflowsProfiler,
workflow_id: Optional[str] = None,
internal_id: Optional[str] = None,
executor: Optional[ThreadPoolExecutor] = None,
):
self._compiled_workflow = compiled_workflow
self._max_concurrent_steps = max_concurrent_steps
self._prevent_local_images_loading = prevent_local_images_loading
self._workflow_id = workflow_id
self._profiler = profiler
self._internal_id = internal_id
self._executor = executor

def run(
self,
@@ -109,6 +114,7 @@ def run(
kinds_serializers=self._compiled_workflow.kinds_serializers,
serialize_results=serialize_results,
profiler=self._profiler,
executor=self._executor,
)
self._profiler.end_workflow_run()
return result
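`ExecutionEngineV1` now stores the injected pool on `self._executor` and passes it to `run_workflow` on every call; the diff only stores and forwards the pool, so lifecycle ownership stays with the caller, and the no-executor path is unchanged. A small standalone sketch of that ownership pattern follows; the `StepRunner` class is illustrative and not part of the repository.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional


class StepRunner:
    """Illustrative holder mirroring how the engine keeps an injected pool."""

    def __init__(self, executor: Optional[ThreadPoolExecutor] = None):
        # Stored and reused on every run, never shut down here -- the caller
        # that created the pool remains responsible for closing it.
        self._executor = executor

    def run(self, steps: List[Callable[[], int]]) -> List[int]:
        if self._executor is None:
            # Previous behaviour: build a short-lived pool for this run only.
            with ThreadPoolExecutor(max_workers=max(len(steps), 1)) as pool:
                return list(pool.map(lambda step: step(), steps))
        return list(self._executor.map(lambda step: step(), steps))


with ThreadPoolExecutor(max_workers=4) as shared:
    runner = StepRunner(executor=shared)
    print(runner.run([lambda: 1, lambda: 2, lambda: 3]))  # [1, 2, 3]
    print(runner.run([lambda: 4, lambda: 5]))             # pool reused, not recreated
```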
inference/core/workflows/execution_engine/v1/executor/core.py
@@ -1,3 +1,4 @@
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Set
@@ -47,6 +48,7 @@ def run_workflow(
kinds_serializers: Optional[Dict[str, Callable[[Any], Any]]],
serialize_results: bool = False,
profiler: Optional[WorkflowsProfiler] = None,
executor: Optional[ThreadPoolExecutor] = None,
) -> List[Dict[str, Any]]:
execution_data_manager = ExecutionDataManager.init(
execution_graph=workflow.execution_graph,
@@ -63,6 +65,7 @@
execution_data_manager=execution_data_manager,
max_concurrent_steps=max_concurrent_steps,
profiler=profiler,
executor=executor,
)
next_steps = execution_coordinator.get_steps_to_execute_next(profiler=profiler)
with profiler.profile_execution_phase(
@@ -89,6 +92,7 @@ def execute_steps(
execution_data_manager: ExecutionDataManager,
max_concurrent_steps: int,
profiler: Optional[WorkflowsProfiler] = None,
executor: Optional[ThreadPoolExecutor] = None,
) -> None:
logger.info(f"Executing steps: {next_steps}.")
steps_functions = [
@@ -101,7 +105,9 @@
)
for step_selector in next_steps
]
_ = run_steps_in_parallel(steps=steps_functions, max_workers=max_concurrent_steps)
_ = run_steps_in_parallel(
steps=steps_functions, max_workers=max_concurrent_steps, executor=executor
)


@execution_phase(
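`execute_steps` keeps its existing shape: every step selected for this round becomes a zero-argument callable built with `functools.partial`, and the whole list, together with `max_concurrent_steps` and the (possibly `None`) executor, goes to `run_steps_in_parallel`. A tiny sketch of that callable-building step, with made-up step names and a stand-in `run_step` function:

```python
from functools import partial
from typing import Callable, List


def run_step(step_selector: str, scale: int) -> str:
    # Stand-in for the engine's per-step execution routine.
    return f"executed {step_selector} (scale={scale})"


# Mirror of how execute_steps turns the next batch of step selectors into
# zero-argument callables before scheduling them in parallel.
next_steps = ["$steps.detect", "$steps.crop", "$steps.classify"]
steps_functions: List[Callable[[], str]] = [
    partial(run_step, step_selector=step_selector, scale=2)
    for step_selector in next_steps
]

print([fn() for fn in steps_functions])
# ['executed $steps.detect (scale=2)', 'executed $steps.crop (scale=2)',
#  'executed $steps.classify (scale=2)']
```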
31 changes: 26 additions & 5 deletions inference/core/workflows/execution_engine/v1/executor/utils.py
@@ -1,15 +1,36 @@
import concurrent
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, TypeVar
from typing import Callable, Generator, Iterable, List, Optional, TypeVar

T = TypeVar("T")


def run_steps_in_parallel(
steps: List[Callable[[], T]], max_workers: int = 1
steps: List[Callable[[], T]],
max_workers: int = 1,
executor: Optional[ThreadPoolExecutor] = None,
) -> List[T]:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
return list(executor.map(_run, steps))
if executor is None:
with ThreadPoolExecutor(max_workers=max_workers) as inner_executor:
return list(inner_executor.map(_run, steps))
results = []
for batch in create_batches(sequence=steps, batch_size=max_workers):
batch_results = list(executor.map(_run, batch))
results.extend(batch_results)
return results


def create_batches(
sequence: Iterable[T], batch_size: int
) -> Generator[List[T], None, None]:
batch_size = max(batch_size, 1)
current_batch = []
for element in sequence:
if len(current_batch) == batch_size:
yield current_batch
current_batch = []
current_batch.append(element)
if len(current_batch) > 0:
yield current_batch


def _run(fun: Callable[[], T]) -> T:
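This is the core of the change. Without an injected executor, `run_steps_in_parallel` behaves exactly as before and opens a temporary pool of `max_workers`; with one, the steps are chunked into batches of `max_workers` and each batch is mapped onto the shared pool, so the `max_concurrent_steps` cap still holds even when the caller's pool has more threads. The runnable sketch below reuses the batching logic from this diff against a deliberately oversized pool.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Generator, Iterable, List, TypeVar

T = TypeVar("T")


def create_batches(
    sequence: Iterable[T], batch_size: int
) -> Generator[List[T], None, None]:
    # Same chunking logic as the helper added in this PR: a shared executor
    # never receives more than batch_size tasks at a time.
    batch_size = max(batch_size, 1)
    current_batch: List[T] = []
    for element in sequence:
        if len(current_batch) == batch_size:
            yield current_batch
            current_batch = []
        current_batch.append(element)
    if current_batch:
        yield current_batch


steps = [lambda i=i: i * i for i in range(7)]
results: List[int] = []
with ThreadPoolExecutor(max_workers=16) as shared:  # pool larger than the cap
    for batch in create_batches(sequence=steps, batch_size=3):
        # At most 3 steps are in flight at once, mirroring max_workers.
        results.extend(shared.map(lambda step: step(), batch))
print(results)  # [0, 1, 4, 9, 16, 25, 36]
```

A consequence of batching is that the next batch is only submitted once the previous one has completely finished, so one slow step can briefly leave the shared pool idle; in exchange, the per-run concurrency limit is enforced without the engine having to resize or shut down the caller's executor.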