Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for combined iterator lineage #2949

Merged
merged 24 commits into main from iterators-take-3
Jun 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
b931580
rename to generators
joeyballentine Jun 9, 2024
edcb012
change node kind from newIterator to generator
joeyballentine Jun 10, 2024
6d839f9
change more terminology
joeyballentine Jun 10, 2024
e0352b3
we might be getting somewhere
joeyballentine Jun 10, 2024
2c10659
fix timing
joeyballentine Jun 10, 2024
534113c
comment out validity blocker
joeyballentine Jun 10, 2024
2737b71
some more fixes
joeyballentine Jun 10, 2024
a794288
Split by lineage
joeyballentine Jun 10, 2024
b434e67
clean up commented code
joeyballentine Jun 10, 2024
29cf65b
more performant function
joeyballentine Jun 11, 2024
433a3a2
Enforce that all connected iterators share the same expected length
joeyballentine Jun 11, 2024
b4891e2
Add migration
joeyballentine Jun 11, 2024
c62f96e
remove load image pairs
joeyballentine Jun 11, 2024
40d6954
finalize validity rules
joeyballentine Jun 11, 2024
c1e0eaa
update snapshot
joeyballentine Jun 11, 2024
050975f
fix type errors
joeyballentine Jun 11, 2024
2e01a4a
gen_supplier -> supplier
joeyballentine Jun 11, 2024
7bdd211
supplier can still use Iterable
joeyballentine Jun 11, 2024
89babdf
move function, add doc comment
joeyballentine Jun 12, 2024
3d38e87
use typing.Iterator instead of typing.Generator
joeyballentine Jun 12, 2024
ecba60e
move function again, add unit tests
joeyballentine Jun 15, 2024
5cc1189
Add identity functions, use frozen sets
joeyballentine Jun 15, 2024
210cc47
slight refactor
joeyballentine Jun 15, 2024
c56c750
Merge branch 'main' into iterators-take-3
joeyballentine Jun 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions backend/src/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import importlib
import os
import typing
from dataclasses import asdict, dataclass, field
from typing import (
Any,
Expand Down Expand Up @@ -149,10 +150,10 @@ def to_list(x: list[S] | S | None) -> list[S]:
iterator_inputs = to_list(iterator_inputs)
iterator_outputs = to_list(iterator_outputs)

if kind == "collector":
assert len(iterator_inputs) == 1 and len(iterator_outputs) == 0
elif kind == "newIterator":
if kind == "generator": # Generator
assert len(iterator_inputs) == 0 and len(iterator_outputs) == 1
elif kind == "collector":
assert len(iterator_inputs) == 1 and len(iterator_outputs) == 0
else:
assert len(iterator_inputs) == 0 and len(iterator_outputs) == 0

Expand Down Expand Up @@ -188,8 +189,8 @@ def inner_wrapper(wrapped_func: T) -> T:
inputs=p_inputs,
group_layout=group_layout,
outputs=p_outputs,
iterator_inputs=iterator_inputs,
iterator_outputs=iterator_outputs,
iterable_inputs=iterator_inputs,
iterable_outputs=iterator_outputs,
key_info=key_info,
suggestions=suggestions or [],
side_effects=side_effects,
Expand Down Expand Up @@ -511,25 +512,25 @@ def add_package(


@dataclass
class Iterator(Generic[I]):
iter_supplier: Callable[[], Iterable[I | Exception]]
class Generator(Generic[I]):
supplier: Callable[[], typing.Iterator[I | Exception]]
expected_length: int
fail_fast: bool = True

@staticmethod
def from_iter(
iter_supplier: Callable[[], Iterable[I | Exception]],
supplier: Callable[[], typing.Iterator[I | Exception]],
expected_length: int,
fail_fast: bool = True,
) -> Iterator[I]:
return Iterator(iter_supplier, expected_length, fail_fast=fail_fast)
) -> Generator[I]:
return Generator(supplier, expected_length, fail_fast=fail_fast)

@staticmethod
def from_list(
l: list[L], map_fn: Callable[[L, int], I], fail_fast: bool = True
) -> Iterator[I]:
) -> Generator[I]:
"""
Creates a new iterator from a list that is mapped using the given
Creates a new generator from a list that is mapped using the given
function. The iterable will be equivalent to `map(map_fn, l)`.
"""

Expand All @@ -540,14 +541,14 @@ def supplier():
except Exception as e:
yield e

return Iterator(supplier, len(l), fail_fast=fail_fast)
return Generator(supplier, len(l), fail_fast=fail_fast)

@staticmethod
def from_range(
count: int, map_fn: Callable[[int], I], fail_fast: bool = True
) -> Iterator[I]:
) -> Generator[I]:
"""
Creates a new iterator the given number of items where each item is
Creates a new generator with the given number of items where each item is
lazily evaluated. The iterable will be equivalent to `map(map_fn, range(count))`.
"""
assert count >= 0
Expand All @@ -559,7 +560,7 @@ def supplier():
except Exception as e:
yield e

return Iterator(supplier, count, fail_fast=fail_fast)
return Generator(supplier, count, fail_fast=fail_fast)


N = TypeVar("N")
Expand Down
16 changes: 8 additions & 8 deletions backend/src/api/node_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ class NodeData:
outputs: list[BaseOutput]
group_layout: list[InputId | NestedIdGroup]

iterator_inputs: list[IteratorInputInfo]
iterator_outputs: list[IteratorOutputInfo]
iterable_inputs: list[IteratorInputInfo]
iterable_outputs: list[IteratorOutputInfo]

key_info: KeyInfo | None
suggestions: list[SpecialSuggestion]
Expand All @@ -150,11 +150,11 @@ class NodeData:
run: RunFn

@property
def single_iterator_input(self) -> IteratorInputInfo:
assert len(self.iterator_inputs) == 1
return self.iterator_inputs[0]
def single_iterable_input(self) -> IteratorInputInfo:
assert len(self.iterable_inputs) == 1
return self.iterable_inputs[0]

@property
def single_iterator_output(self) -> IteratorOutputInfo:
assert len(self.iterator_outputs) == 1
return self.iterator_outputs[0]
def single_iterable_output(self) -> IteratorOutputInfo:
assert len(self.iterable_outputs) == 1
return self.iterable_outputs[0]
2 changes: 1 addition & 1 deletion backend/src/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@

RunFn = Callable[..., Any]

NodeKind = Literal["regularNode", "newIterator", "collector"]
NodeKind = Literal["regularNode", "generator", "collector"]
6 changes: 3 additions & 3 deletions backend/src/chain/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from api import NodeId

from .chain import Chain, Edge, FunctionNode, NewIteratorNode
from .chain import Chain, Edge, FunctionNode, GeneratorNode


class CacheStrategy:
Expand Down Expand Up @@ -54,9 +54,9 @@ def any_are_iterated(out_edges: list[Edge]) -> bool:
else:
# the node is NOT implicitly iterated

if isinstance(node, NewIteratorNode):
if isinstance(node, GeneratorNode):
# we only care about non-iterator outputs
iterator_output = node.data.single_iterator_output
iterator_output = node.data.single_iterable_output
out_edges = [
out_edge
for out_edge in out_edges
Expand Down
18 changes: 9 additions & 9 deletions backend/src/chain/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ def has_side_effects(self) -> bool:
return self.data.side_effects


class NewIteratorNode:
class GeneratorNode:
def __init__(self, node_id: NodeId, schema_id: str):
self.id: NodeId = node_id
self.schema_id: str = schema_id
self.data: NodeData = registry.get_node(schema_id)
assert self.data.kind == "newIterator"
assert self.data.kind == "generator"

def has_side_effects(self) -> bool:
return self.data.side_effects
Expand All @@ -50,7 +50,7 @@ def has_side_effects(self) -> bool:
return self.data.side_effects


Node = Union[FunctionNode, NewIteratorNode, CollectorNode]
Node = Union[FunctionNode, GeneratorNode, CollectorNode]


@dataclass(frozen=True)
Expand Down Expand Up @@ -176,28 +176,28 @@ def visit(node_id: NodeId):

return result

def get_parent_iterator_map(self) -> dict[FunctionNode, NewIteratorNode | None]:
def get_parent_iterator_map(self) -> dict[FunctionNode, GeneratorNode | None]:
"""
Returns a map of all function nodes to their parent iterator node (if any).
"""
iterator_cache: dict[FunctionNode, NewIteratorNode | None] = {}
iterator_cache: dict[FunctionNode, GeneratorNode | None] = {}

def get_iterator(r: FunctionNode) -> NewIteratorNode | None:
def get_iterator(r: FunctionNode) -> GeneratorNode | None:
if r in iterator_cache:
return iterator_cache[r]

iterator: NewIteratorNode | None = None
iterator: GeneratorNode | None = None

for in_edge in self.edges_to(r.id):
source = self.nodes[in_edge.source.id]
if isinstance(source, FunctionNode):
iterator = get_iterator(source)
if iterator is not None:
break
elif isinstance(source, NewIteratorNode):
elif isinstance(source, GeneratorNode):
if (
in_edge.source.output_id
in source.data.single_iterator_output.outputs
in source.data.single_iterable_output.outputs
):
iterator = source
break
Expand Down
6 changes: 3 additions & 3 deletions backend/src/chain/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
EdgeSource,
EdgeTarget,
FunctionNode,
NewIteratorNode,
GeneratorNode,
)


Expand Down Expand Up @@ -53,8 +53,8 @@ def parse_json(json: list[JsonNode]) -> Chain:
index_edges: list[IndexEdge] = []

for json_node in json:
if json_node["nodeType"] == "newIterator":
node = NewIteratorNode(json_node["id"], json_node["schemaId"])
if json_node["nodeType"] == "generator":
node = GeneratorNode(json_node["id"], json_node["schemaId"])
elif json_node["nodeType"] == "collector":
node = CollectorNode(json_node["id"], json_node["schemaId"])
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sanic.log import logger

from api import Iterator, IteratorOutputInfo
from api import Generator, IteratorOutputInfo
from nodes.impl.ncnn.model import NcnnModelWrapper
from nodes.properties.inputs import BoolInput, DirectoryInput
from nodes.properties.outputs import (
Expand Down Expand Up @@ -46,12 +46,12 @@
),
],
iterator_outputs=IteratorOutputInfo(outputs=[0, 2, 3, 4]),
kind="newIterator",
kind="generator",
)
def load_models_node(
directory: Path,
fail_fast: bool,
) -> tuple[Iterator[tuple[NcnnModelWrapper, str, str, int]], Path]:
) -> tuple[Generator[tuple[NcnnModelWrapper, str, str, int]], Path]:
logger.debug(f"Iterating over models in directory: {directory}")

def load_model(filepath_pairs: tuple[Path, Path], index: int):
Expand Down Expand Up @@ -82,4 +82,4 @@ def load_model(filepath_pairs: tuple[Path, Path], index: int):

model_files = list(zip(param_files, bin_files))

return Iterator.from_list(model_files, load_model, fail_fast), directory
return Generator.from_list(model_files, load_model, fail_fast), directory
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sanic.log import logger

from api import Iterator, IteratorOutputInfo
from api import Generator, IteratorOutputInfo
from nodes.impl.onnx.model import OnnxModel
from nodes.properties.inputs import BoolInput, DirectoryInput
from nodes.properties.outputs import (
Expand Down Expand Up @@ -46,12 +46,12 @@
),
],
iterator_outputs=IteratorOutputInfo(outputs=[0, 2, 3, 4]),
kind="newIterator",
kind="generator",
)
def load_models_node(
directory: Path,
fail_fast: bool,
) -> tuple[Iterator[tuple[OnnxModel, str, str, int]], Path]:
) -> tuple[Generator[tuple[OnnxModel, str, str, int]], Path]:
logger.debug(f"Iterating over models in directory: {directory}")

def load_model(path: Path, index: int):
Expand All @@ -63,4 +63,4 @@ def load_model(path: Path, index: int):
supported_filetypes = [".onnx"]
model_files = list_all_files_sorted(directory, supported_filetypes)

return Iterator.from_list(model_files, load_model, fail_fast), directory
return Generator.from_list(model_files, load_model, fail_fast), directory
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sanic.log import logger
from spandrel import ModelDescriptor

from api import Iterator, IteratorOutputInfo, NodeContext
from api import Generator, IteratorOutputInfo, NodeContext
from nodes.properties.inputs import DirectoryInput
from nodes.properties.inputs.generic_inputs import BoolInput
from nodes.properties.outputs import DirectoryOutput, NumberOutput, TextOutput
Expand Down Expand Up @@ -43,14 +43,14 @@
),
],
iterator_outputs=IteratorOutputInfo(outputs=[0, 2, 3, 4]),
kind="newIterator",
kind="generator",
node_context=True,
)
def load_models_node(
context: NodeContext,
directory: Path,
fail_fast: bool,
) -> tuple[Iterator[tuple[ModelDescriptor, str, str, int]], Path]:
) -> tuple[Generator[tuple[ModelDescriptor, str, str, int]], Path]:
logger.debug(f"Iterating over models in directory: {directory}")

def load_model(path: Path, index: int):
Expand All @@ -62,4 +62,4 @@ def load_model(path: Path, index: int):
supported_filetypes = [".pt", ".pth", ".ckpt", ".safetensors"]
model_files: list[Path] = list_all_files_sorted(directory, supported_filetypes)

return Iterator.from_list(model_files, load_model, fail_fast), directory
return Generator.from_list(model_files, load_model, fail_fast), directory
Loading
Loading