diff --git a/.github/workflows/fluxus-release-pipeline.yaml b/.github/workflows/fluxus-release-pipeline.yaml index a3b2cbc..684492f 100644 --- a/.github/workflows/fluxus-release-pipeline.yaml +++ b/.github/workflows/fluxus-release-pipeline.yaml @@ -461,6 +461,8 @@ jobs: if: (startsWith(github.head_ref, 'dev/') && startsWith(github.base_ref, 'release/')) || startsWith(github.ref, 'refs/heads/release/') needs: - check_release + - veracode_check + - unit_tests - conda_tox_matrix steps: - name: Checkout code diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ff64df9..12fa43b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,7 +40,7 @@ repos: exclude: condabuild/meta.yaml - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.10.0 + rev: v1.11.0 hooks: - id: mypy entry: mypy src/ test/ diff --git a/RELEASE_NOTES.rst b/RELEASE_NOTES.rst index 253c68e..1542665 100644 --- a/RELEASE_NOTES.rst +++ b/RELEASE_NOTES.rst @@ -4,6 +4,21 @@ Release Notes *fluxus* 1.0 ------------ +*fluxus* 1.0.3 +~~~~~~~~~~~~~~ + +- API: Rename methods ``iter_concurrent_conduits`` to ``iter_concurrent_producers``, + and ``aiter_concurrent_conduits`` to ``aiter_concurrent_producers`` +- API: Return iterators not lists from :meth:`.SerialTransformer.process` and + :meth:`.SerialTransformer.aprocess` for greater flexibility in processing the results +- API: Removed functions ``iter()`` and ``aiter()`` from class + :class:`.SerialTransformer`, to further streamline the API and given they can be + easily replaced by repeated calls to :meth:`.SerialTransformer.process` and + :meth:`.SerialTransformer.aprocess` +- FIX: Updated logic for iterating over concurrent producers and transformers to ensure + shared conduits never run more than once + + *fluxus* 1.0.2 ~~~~~~~~~~~~~~ diff --git a/src/fluxus/_passthrough.py b/src/fluxus/_passthrough.py index ae9f3c3..d2916d4 100644 --- a/src/fluxus/_passthrough.py +++ b/src/fluxus/_passthrough.py @@ -70,6 +70,11 @@ def final_conduit(self) -> Never: """[see superclass]""" raise NotImplementedError("Final conduit is not defined for passthroughs") + @property + def chained_conduits(self) -> Never: + """[see superclass]""" + raise NotImplementedError("Chained conduits are not defined for passthroughs") + def get_final_conduits(self) -> Iterator[Never]: """ Returns an empty iterator since passthroughs do not define a final conduit. @@ -91,6 +96,16 @@ def get_connections(self, *, ingoing: Collection[SerialConduit[Any]]) -> Never: :param ingoing: the ingoing conduits (ignored) :return: nothing; passthroughs do not define connections - :raises NotImplementedError: passthroughs do not define connections + :raises NotImplementedError: connections are not defined for passthroughs """ raise NotImplementedError("Connections are not defined for passthroughs") + + def get_isolated_conduits(self) -> Never: + """ + Fails with a :class:`NotImplementedError` since passthroughs are transparent in + flows and therefore isolated conduits are not defined. + + :return: nothing; passthroughs do not define isolated conduits + :raises NotImplementedError: isolated conduits are not defined for passthroughs + """ + raise NotImplementedError("Isolated conduits are not defined for passthroughs") diff --git a/src/fluxus/_transformer.py b/src/fluxus/_transformer.py index f51f437..ae286d8 100644 --- a/src/fluxus/_transformer.py +++ b/src/fluxus/_transformer.py @@ -23,12 +23,10 @@ import logging from abc import ABCMeta, abstractmethod from collections.abc import AsyncIterator, Iterator -from typing import Any, Generic, TypeVar, final +from typing import Generic, TypeVar, final from pytools.asyncio import arun, iter_async_to_sync -from pytools.typing import issubclass_generic -from ._passthrough import Passthrough from .core import AtomicConduit from .core.transformer import SerialTransformer @@ -111,30 +109,3 @@ def atransform( :param source_product: the existing product to use as input :return: the new product """ - - -# -# Auxiliary functions -# - - -def _validate_concurrent_passthrough( - conduit: SerialTransformer[Any, Any] | Passthrough -) -> None: - """ - Validate that the given conduit is valid as a concurrent conduit with a passthrough. - - To be valid, its input type must be a subtype of its product type. - - :param conduit: the conduit to validate - """ - - if not ( - isinstance(conduit, Passthrough) - or issubclass_generic(conduit.input_type, conduit.product_type) - ): - raise TypeError( - "Conduit is not a valid concurrent conduit with a passthrough because its " - f"input type {conduit.input_type} is not a subtype of its product type " - f"{conduit.product_type}:\n{conduit}" - ) diff --git a/src/fluxus/core/_base.py b/src/fluxus/core/_base.py index 4757aac..a539e5d 100644 --- a/src/fluxus/core/_base.py +++ b/src/fluxus/core/_base.py @@ -25,13 +25,18 @@ import logging from abc import ABCMeta, abstractmethod -from collections.abc import AsyncIterable, Collection, Iterable, Iterator -from typing import Any, Generic, TypeVar, cast - -from typing_extensions import Self +from collections.abc import ( + AsyncIterable, + AsyncIterator, + Awaitable, + Collection, + Iterable, + Iterator, +) +from typing import Any, Generic, TypeVar, cast, final from pytools.api import inheritdoc -from pytools.typing import get_common_generic_base, issubclass_generic +from pytools.typing import issubclass_generic from ._conduit import Conduit, SerialConduit @@ -65,17 +70,11 @@ class Source(Conduit[T_Product_ret], Generic[T_Product_ret], metaclass=ABCMeta): """ @property + @abstractmethod def product_type(self) -> type[T_Product_ret]: """ The type of the products produced by this conduit. """ - from .. import Passthrough - - return get_common_generic_base( - cast(SerialSource[T_Product_ret], source).product_type - for source in self.iter_concurrent_conduits() - if not isinstance(source, Passthrough) - ) class Processor( @@ -100,7 +99,7 @@ def input_type(self) -> type[T_SourceProduct_arg]: @abstractmethod def process( self, input: Iterable[T_SourceProduct_arg] - ) -> list[T_Output_ret] | T_Output_ret: + ) -> Iterator[T_Output_ret] | T_Output_ret: """ Generate new products from the given input. @@ -109,9 +108,9 @@ def process( """ @abstractmethod - async def aprocess( + def aprocess( self, input: AsyncIterable[T_SourceProduct_arg] - ) -> list[T_Output_ret] | T_Output_ret: + ) -> AsyncIterator[T_Output_ret] | Awaitable[T_Output_ret]: """ Generate new products asynchronously from the given input. @@ -119,6 +118,7 @@ async def aprocess( :return: the generated output or outputs """ + @abstractmethod def is_valid_source( self, source: SerialConduit[T_SourceProduct_arg], @@ -133,20 +133,6 @@ def is_valid_source( :return: ``True`` if the given conduit is valid source for this conduit, ``False`` otherwise """ - from .. import Passthrough - - if not isinstance(source, SerialSource): - return False - - ingoing_product_type = source.product_type - return all( - issubclass_generic( - ingoing_product_type, - cast(Self, processor).input_type, - ) - for processor in self.iter_concurrent_conduits() - if not isinstance(processor, Passthrough) - ) class SerialSource( @@ -178,6 +164,17 @@ class SerialProcessor( A processor that processes products sequentially. """ + @final + def is_valid_source( + self, + source: SerialConduit[T_SourceProduct_arg], + ) -> bool: + """[see superclass]""" + if not isinstance(source, SerialSource): + return False + + return issubclass_generic(source.product_type, self.input_type) + def get_connections( self, *, ingoing: Collection[SerialConduit[Any]] ) -> Iterator[tuple[SerialConduit[Any], SerialConduit[Any]]]: diff --git a/src/fluxus/core/_chained_base_.py b/src/fluxus/core/_chained_base_.py index cb566fa..d2516b8 100644 --- a/src/fluxus/core/_chained_base_.py +++ b/src/fluxus/core/_chained_base_.py @@ -68,26 +68,26 @@ def is_chained(self) -> bool: @property @abstractmethod - def _source(self) -> Source[T_SourceProduct_ret]: + def source(self) -> Source[T_SourceProduct_ret]: """ The source producer of this conduit. """ @property @abstractmethod - def _processor(self) -> Processor[T_SourceProduct_ret, T_Output_ret]: + def processor(self) -> Processor[T_SourceProduct_ret, T_Output_ret]: """ The second conduit in this chained conduit, processing the output of the - :attr:`._source` conduit. + :attr:`.source` conduit. """ def get_final_conduits(self) -> Iterator[SerialConduit[T_Output_ret]]: """[see superclass]""" - if self._processor._has_passthrough: + if self.processor._has_passthrough: yield from cast( - Iterator[SerialConduit[T_Output_ret]], self._source.get_final_conduits() + Iterator[SerialConduit[T_Output_ret]], self.source.get_final_conduits() ) - yield from self._processor.get_final_conduits() + yield from self.processor.get_final_conduits() def get_connections( self, *, ingoing: Collection[SerialConduit[Any]] @@ -97,8 +97,8 @@ def get_connections( :return: an iterable of connections """ - source = self._source - processor = self._processor + source = self.source + processor = self.processor # We first yield all connections from within the source yield from source.get_connections(ingoing=ingoing) @@ -113,11 +113,16 @@ def get_connections( # Then we get all connections of the processor, including ingoing connections yield from processor.get_connections(ingoing=processor_ingoing) + def get_isolated_conduits(self) -> Iterator[SerialConduit[T_Output_ret]]: + """[see superclass]""" + # Chained conduits are never isolated + yield from () + def to_expression(self, *, compact: bool = False) -> Expression: """[see superclass]""" - return self._source.to_expression( + return self.source.to_expression( compact=compact - ) >> self._processor.to_expression(compact=compact) + ) >> self.processor.to_expression(compact=compact) class _SerialChainedConduit( @@ -132,7 +137,7 @@ class _SerialChainedConduit( @property @abstractmethod - def _source(self) -> SerialSource[T_SourceProduct_ret]: + def source(self) -> SerialSource[T_SourceProduct_ret]: """[see superclass]""" @property @@ -140,5 +145,5 @@ def chained_conduits(self) -> Iterator[SerialConduit[Any]]: """ The chained conduits in the flow leading up to this conduit. """ - yield from self._source.chained_conduits + yield from self.source.chained_conduits yield self.final_conduit diff --git a/src/fluxus/core/_concurrent.py b/src/fluxus/core/_concurrent.py index e6587af..47ef640 100644 --- a/src/fluxus/core/_concurrent.py +++ b/src/fluxus/core/_concurrent.py @@ -26,8 +26,6 @@ from typing_extensions import Self -from pytools.api import inheritdoc - from ._conduit import Conduit log = logging.getLogger(__name__) @@ -52,7 +50,6 @@ # -@inheritdoc(match="[see superclass]") class ConcurrentConduit( Conduit[T_Product_ret], Generic[T_Product_ret], metaclass=ABCMeta ): @@ -71,11 +68,6 @@ def is_concurrent(self) -> bool: """ return True - @property - def is_chained(self) -> bool: - """[see superclass]""" - return any(conduit.is_chained for conduit in self.iter_concurrent_conduits()) - @property def final_conduit(self) -> Self: """ diff --git a/src/fluxus/core/_conduit.py b/src/fluxus/core/_conduit.py index 043b982..8fad2d1 100644 --- a/src/fluxus/core/_conduit.py +++ b/src/fluxus/core/_conduit.py @@ -22,7 +22,7 @@ import logging from abc import ABCMeta, abstractmethod -from collections.abc import AsyncIterator, Collection, Iterator, Mapping +from collections.abc import Collection, Iterator, Mapping from typing import Any, Generic, TypeVar, final from typing_extensions import Self @@ -133,25 +133,6 @@ def n_concurrent_conduits(self) -> int: The number of concurrent conduits in this conduit. """ - @abstractmethod - def iter_concurrent_conduits(self) -> Iterator[SerialConduit[T_Output_ret]]: - """ - Iterate over the concurrent conduits that make up this conduit. - - :return: an iterator over the concurrent conduits - """ - - async def aiter_concurrent_conduits( - self, - ) -> AsyncIterator[SerialConduit[T_Output_ret]]: - """ - Asynchronously iterate over the concurrent conduits that make up this conduit. - - :return: an asynchronous iterator over the concurrent conduits - """ - for conduit in self.iter_concurrent_conduits(): - yield conduit - def draw(self, style: str = "graph") -> None: """ Draw the flow. @@ -181,6 +162,17 @@ def get_connections( :return: an iterator yielding connections between conduits """ + @abstractmethod + def get_isolated_conduits(self) -> Iterator[SerialConduit[T_Output_ret]]: + """ + Get an iterator yielding the isolated conduits in this conduit. + + An isolated conduit is a conduit that is not connected to any other conduit in + the flow. + + :return: an iterator yielding the isolated conduits + """ + def _repr_svg_(self) -> str: # pragma: no cover """ Get the SVG representation of the flow. @@ -279,25 +271,8 @@ def n_concurrent_conduits(self) -> int: """ return 1 - def iter_concurrent_conduits(self) -> Iterator[Self]: - """ - Yields ``self``, since this is a serial conduit and is not made up of concurrent - conduits. - - :return: an iterator with ``self`` as the only element - """ - yield self - - async def aiter_concurrent_conduits(self: Self) -> AsyncIterator[Self]: - """ - Yields ``self``, since this is a serial conduit and is not made up of concurrent - conduits. - - :return: an asynchronous iterator with ``self`` as the only element - """ - yield self - @property + @abstractmethod def chained_conduits(self) -> Iterator[SerialConduit[T_Product_ret]]: """ An iterator yielding the chained conduits that make up this conduit, starting @@ -305,7 +280,6 @@ def chained_conduits(self) -> Iterator[SerialConduit[T_Product_ret]]: For atomic conduit, yields the conduit itself. """ - yield self.final_conduit def get_repr_attributes(self) -> Mapping[str, Any]: """ @@ -350,7 +324,17 @@ def final_conduit(self) -> Self: """ return self + @property + @final + def chained_conduits(self) -> Iterator[SerialConduit[T_Product_ret]]: + """[see superclass]""" + yield self.final_conduit + @final def get_final_conduits(self) -> Iterator[Self]: """[see superclass]""" yield self + + def get_isolated_conduits(self) -> Iterator[SerialConduit[T_Product_ret]]: + """[see superclass]""" + yield self diff --git a/src/fluxus/core/producer/_chained_.py b/src/fluxus/core/producer/_chained_.py index 500ecc1..54aabef 100644 --- a/src/fluxus/core/producer/_chained_.py +++ b/src/fluxus/core/producer/_chained_.py @@ -21,11 +21,11 @@ from __future__ import annotations import logging -from collections.abc import AsyncIterator, Iterator +from collections.abc import AsyncIterator from typing import Generic, TypeVar, final from pytools.api import inheritdoc -from pytools.asyncio import aenumerate, async_flatten +from pytools.asyncio import async_flatten, iter_sync_to_async from ..._consumer import Consumer from ..._flow import Flow @@ -99,14 +99,14 @@ def final_conduit(self) -> Consumer[T_SourceProduct_ret, T_Output_ret]: return self._consumer @property - def _source(self) -> SerialProducer[T_SourceProduct_ret]: + def source(self) -> SerialProducer[T_SourceProduct_ret]: """ The source producer. """ return self._producer @property - def _processor(self) -> Consumer[T_SourceProduct_ret, T_Output_ret]: + def processor(self) -> Consumer[T_SourceProduct_ret, T_Output_ret]: """ The final processor of this flow. """ @@ -156,9 +156,9 @@ def __init__( """ super().__init__() invalid_producers = { - type(producer.final_conduit).__name__ - for producer in producer.iter_concurrent_conduits() - if not consumer.is_valid_source(producer.final_conduit) + type(final_conduit).__name__ + for final_conduit in producer.get_final_conduits() + if not consumer.is_valid_source(final_conduit) } if invalid_producers: raise TypeError( @@ -181,12 +181,12 @@ def final_conduit(self) -> Consumer[T_SourceProduct_ret, T_Output_ret]: return self._consumer @property - def _source(self) -> BaseProducer[T_SourceProduct_ret]: + def source(self) -> BaseProducer[T_SourceProduct_ret]: """[see superclass]""" return self._producer @property - def _processor(self) -> Consumer[T_SourceProduct_ret, T_Output_ret]: + def processor(self) -> Consumer[T_SourceProduct_ret, T_Output_ret]: """[see superclass]""" return self._consumer @@ -195,20 +195,6 @@ def n_concurrent_conduits(self) -> int: """[see superclass]""" return self._producer.n_concurrent_conduits - def iter_concurrent_conduits( - self, - ) -> Iterator[_ProducerFlow[T_SourceProduct_ret, T_Output_ret]]: - """[see superclass]""" - for producer in self._producer.iter_concurrent_conduits(): - yield _ProducerFlow(producer=producer, consumer=self._consumer) - - async def aiter_concurrent_conduits( - self, - ) -> AsyncIterator[_ProducerFlow[T_SourceProduct_ret, T_Output_ret]]: - """[see superclass]""" - async for producer in self._producer.aiter_concurrent_conduits(): - yield _ProducerFlow(producer=producer, consumer=self._consumer) - @final def run(self) -> T_Output_ret: """[see superclass]""" @@ -242,7 +228,7 @@ def _consume( """ return consumer.consume( (index, product) - for index, producer in enumerate(producer.iter_concurrent_conduits()) + for index, producer in enumerate(producer.iter_concurrent_producers()) for product in producer ) @@ -275,7 +261,7 @@ async def _annotate( async_flatten( _annotate(producer_index, producer) async for producer_index, producer in ( - aenumerate(producer.aiter_concurrent_conduits()) + iter_sync_to_async(enumerate(producer.iter_concurrent_producers())) ) ) ) diff --git a/src/fluxus/core/producer/_producer_base.py b/src/fluxus/core/producer/_producer_base.py index eb34e29..63b12a4 100644 --- a/src/fluxus/core/producer/_producer_base.py +++ b/src/fluxus/core/producer/_producer_base.py @@ -26,7 +26,7 @@ from typing import Generic, TypeVar, cast, final from pytools.api import inheritdoc -from pytools.asyncio import async_flatten +from pytools.asyncio import async_flatten, iter_sync_to_async from pytools.typing import get_common_generic_base from ..._consumer import Consumer @@ -57,7 +57,6 @@ # -@inheritdoc(match="[see superclass]") class BaseProducer(Source[T_Product_ret], Generic[T_Product_ret], metaclass=ABCMeta): """ A source that generates products from scratch – this is either a @@ -81,12 +80,13 @@ def aproduce(self) -> AsyncIterator[T_Product_ret]: """ @abstractmethod - def iter_concurrent_conduits(self) -> Iterator[SerialProducer[T_Product_ret]]: - """[see superclass]""" + def iter_concurrent_producers(self) -> Iterator[SerialProducer[T_Product_ret]]: + """ + Iterate over the concurrent producers that make up this (potentially) + composite producer. - @abstractmethod - def aiter_concurrent_conduits(self) -> AsyncIterator[SerialProducer[T_Product_ret]]: - """[see superclass]""" + :return: an iterator over the concurrent producers + """ @final def __iter__(self) -> Iterator[T_Product_ret]: @@ -107,7 +107,7 @@ def __and__( # indicate the type for static type checks return cast( ConcurrentProducer[T_Product_ret], - SimpleConcurrentProducer[ # type: ignore[misc] + SimpleConcurrentProducer[ # type: ignore[misc, operator] get_common_generic_base((self.product_type, other.product_type)) ](self, other), ) @@ -141,13 +141,7 @@ class SerialProducer( It can run synchronously or asynchronously. """ - def iter_concurrent_conduits(self) -> Iterator[SerialProducer[T_Product_ret]]: - """[see superclass]""" - yield self - - async def aiter_concurrent_conduits( - self, - ) -> AsyncIterator[SerialProducer[T_Product_ret]]: + def iter_concurrent_producers(self) -> Iterator[SerialProducer[T_Product_ret]]: """[see superclass]""" yield self @@ -191,7 +185,7 @@ def produce(self) -> Iterator[T_Product_ret]: :return: an iterator of the new products """ - for producer in self.iter_concurrent_conduits(): + for producer in self.iter_concurrent_producers(): yield from producer def aproduce(self) -> AsyncIterator[T_Product_ret]: @@ -200,10 +194,8 @@ def aproduce(self) -> AsyncIterator[T_Product_ret]: :return: an async iterator of the new products """ - # create tasks for each producer - these need to be coroutines that materialize - # the producers - # noinspection PyTypeChecker return async_flatten( - producer.aproduce() async for producer in self.aiter_concurrent_conduits() + producer.aproduce() + async for producer in iter_sync_to_async(self.iter_concurrent_producers()) ) diff --git a/src/fluxus/core/producer/_simple.py b/src/fluxus/core/producer/_simple.py index fba387f..7f4df23 100644 --- a/src/fluxus/core/producer/_simple.py +++ b/src/fluxus/core/producer/_simple.py @@ -24,12 +24,12 @@ import itertools import logging import operator -from collections.abc import AsyncIterator, Collection, Iterator +from collections.abc import Collection, Iterator from typing import Any, Generic, TypeVar, cast from pytools.api import as_tuple, inheritdoc -from pytools.asyncio import async_flatten, iter_sync_to_async from pytools.expression import Expression +from pytools.typing import get_common_generic_base from ... import Passthrough from .. import SerialConduit @@ -96,6 +96,16 @@ def __init__( arg_name="producers", ) + @property + def product_type(self) -> type[T_SourceProduct_ret]: + """[see superclass]""" + return get_common_generic_base(source.product_type for source in self.producers) + + @property + def is_chained(self) -> bool: + """[see superclass]""" + return any(producer.is_chained for producer in self.producers) + @property def n_concurrent_conduits(self) -> int: """[see superclass]""" @@ -114,23 +124,17 @@ def get_connections( for producer in self.producers: yield from producer.get_connections(ingoing=ingoing) - def iter_concurrent_conduits( - self, - ) -> Iterator[SerialProducer[T_SourceProduct_ret]]: + def get_isolated_conduits(self) -> Iterator[SerialConduit[T_SourceProduct_ret]]: """[see superclass]""" - for prod in self.producers: - yield from prod.iter_concurrent_conduits() + for producer in self.producers: + yield from producer.get_isolated_conduits() - def aiter_concurrent_conduits( + def iter_concurrent_producers( self, - ) -> AsyncIterator[SerialProducer[T_SourceProduct_ret]]: + ) -> Iterator[SerialProducer[T_SourceProduct_ret]]: """[see superclass]""" - - # noinspection PyTypeChecker - return async_flatten( - prod.aiter_concurrent_conduits() - async for prod in iter_sync_to_async(self.producers) - ) + for prod in self.producers: + yield from prod.iter_concurrent_producers() def to_expression(self, *, compact: bool = False) -> Expression: """[see superclass]""" diff --git a/src/fluxus/core/transformer/_chained_.py b/src/fluxus/core/transformer/_chained_.py index 628d190..d4c0330 100644 --- a/src/fluxus/core/transformer/_chained_.py +++ b/src/fluxus/core/transformer/_chained_.py @@ -20,19 +20,16 @@ from __future__ import annotations -import asyncio import logging -from abc import ABCMeta -from collections.abc import AsyncIterable, AsyncIterator, Collection, Iterator -from typing import Any, Generic, Literal, TypeVar, cast +from collections.abc import AsyncIterator, Iterator +from typing import Generic, TypeVar, final from pytools.api import inheritdoc from pytools.asyncio import async_flatten -from ..._passthrough import Passthrough from .._base import Processor, Source from .._chained_base_ import _ChainedConduit, _SerialChainedConduit -from .._conduit import AtomicConduit, SerialConduit +from .._conduit import SerialConduit from ..producer import BaseProducer, ConcurrentProducer, SerialProducer from ._transformer_base import BaseTransformer, ConcurrentTransformer, SerialTransformer @@ -46,7 +43,6 @@ # _ret for covariant type variables used in return positions # _arg for contravariant type variables used in argument positions -T_Output_ret = TypeVar("T_Output_ret", covariant=True) T_Product_ret = TypeVar("T_Product_ret", covariant=True) T_TransformedProduct_ret = TypeVar("T_TransformedProduct_ret", covariant=True) T_SourceProduct_ret = TypeVar("T_SourceProduct_ret", covariant=True) @@ -108,12 +104,12 @@ def product_type(self) -> type[T_TransformedProduct_ret]: return self.transformer.product_type @property - def _source(self) -> SerialProducer[T_SourceProduct_ret]: + def source(self) -> SerialProducer[T_SourceProduct_ret]: """[see superclass]""" return self._producer @property - def _processor( + def processor( self, ) -> SerialTransformer[T_SourceProduct_ret, T_TransformedProduct_ret]: """[see superclass]""" @@ -121,11 +117,12 @@ def _processor( def produce(self) -> Iterator[T_TransformedProduct_ret]: """[see superclass]""" - return self.transformer.iter(self._producer) + return self.transformer.process(input=self._producer) def aproduce(self) -> AsyncIterator[T_TransformedProduct_ret]: """[see superclass]""" - return self.transformer.aiter(self._producer) + # noinspection PyTypeChecker + return self.transformer.aprocess(input=self._producer) @inheritdoc(match="[see superclass]") @@ -169,12 +166,12 @@ def product_type(self) -> type[T_TransformedProduct_ret]: return self.second.product_type @property - def _source(self) -> SerialTransformer[T_SourceProduct_arg, T_SourceProduct_ret]: + def source(self) -> SerialTransformer[T_SourceProduct_arg, T_SourceProduct_ret]: """[see superclass]""" return self.first @property - def _processor( + def processor( self, ) -> SerialTransformer[T_SourceProduct_ret, T_TransformedProduct_ret]: """[see superclass]""" @@ -231,12 +228,18 @@ def __init__( self.transformer = transformer @property - def _source(self) -> BaseProducer[T_SourceProduct_ret]: + @final + def product_type(self) -> type[T_TransformedProduct_ret]: + """[see superclass]""" + return self.transformer.product_type + + @property + def source(self) -> BaseProducer[T_SourceProduct_ret]: """[see superclass]""" return self._producer @property - def _processor(self) -> Processor[T_SourceProduct_ret, T_TransformedProduct_ret]: + def processor(self) -> Processor[T_SourceProduct_ret, T_TransformedProduct_ret]: """[see superclass]""" return self.transformer @@ -248,56 +251,13 @@ def n_concurrent_conduits(self) -> int: * self.transformer.n_concurrent_conduits ) - def iter_concurrent_conduits( + def iter_concurrent_producers( self, ) -> Iterator[SerialProducer[T_TransformedProduct_ret]]: """[see superclass]""" - def _iter_chained_producers( - tx: ( - SerialTransformer[T_SourceProduct_ret, T_TransformedProduct_ret] - | Passthrough - ), - ) -> Iterator[SerialProducer[T_TransformedProduct_ret]]: - if isinstance(tx, Passthrough): - # Passthrough does not change the type, so we can cast the input type - # to the type of the transformed product - yield from cast( - Iterator[SerialProducer[T_TransformedProduct_ret]], - self._source.iter_concurrent_conduits(), - ) - else: - for source in self._producer.iter_concurrent_conduits(): - yield _ChainedProducer(producer=source, transformer=tx) - - for transformer in self.transformer.iter_concurrent_conduits(): - yield from _iter_chained_producers(transformer) - - def aiter_concurrent_conduits( - self, - ) -> AsyncIterator[SerialProducer[T_TransformedProduct_ret]]: - """[see superclass]""" - - async def _aiter_chained_producers( - tx: ( - SerialTransformer[T_SourceProduct_ret, T_TransformedProduct_ret] - | Passthrough - ), - ) -> AsyncIterator[SerialProducer[T_TransformedProduct_ret]]: - if isinstance(tx, Passthrough): - async for source in self._producer.aiter_concurrent_conduits(): - # Passthrough does not change the type, so we can cast the - # input type to the type of the transformed product - yield cast(SerialProducer[T_TransformedProduct_ret], source) - else: - async for source in self._producer.aiter_concurrent_conduits(): - yield _ChainedProducer(producer=source, transformer=tx) - - # noinspection PyTypeChecker - return async_flatten( - _aiter_chained_producers(transformer) - async for transformer in self.transformer.aiter_concurrent_conduits() - ) + for source in self._producer.iter_concurrent_producers(): + yield from self.transformer.iter_concurrent_producers(source=source) @inheritdoc(match="[see superclass]") @@ -327,71 +287,48 @@ class _ChainedConcurrentTransformedProducer( #: The source producer _producer: SerialProducer[T_SourceProduct_ret] - #: The transformer group to apply to the producer - transformer_group: BaseTransformer[T_SourceProduct_ret, T_Product_ret] + #: The transformer to apply to the producer + transformer: BaseTransformer[T_SourceProduct_ret, T_Product_ret] def __init__( self, *, source: SerialProducer[T_SourceProduct_ret], - transformer_group: BaseTransformer[T_SourceProduct_ret, T_Product_ret], + transformer: BaseTransformer[T_SourceProduct_ret, T_Product_ret], ) -> None: """ :param source: the producer to use as input - :param transformer_group: the transformer group to apply to the producer + :param transformer: the transformer to apply to the producer """ super().__init__() self._producer = source - self.transformer_group = transformer_group + self.transformer = transformer @property - def _source(self) -> SerialProducer[T_SourceProduct_ret]: + @final + def product_type(self) -> type[T_Product_ret]: """[see superclass]""" - return self._producer + return self.transformer.product_type @property - def _processor(self) -> BaseTransformer[T_SourceProduct_ret, T_Product_ret]: + def source(self) -> SerialProducer[T_SourceProduct_ret]: """[see superclass]""" - return self.transformer_group + return self._producer @property - def n_concurrent_conduits(self) -> int: + def processor(self) -> BaseTransformer[T_SourceProduct_ret, T_Product_ret]: """[see superclass]""" - return self.transformer_group.n_concurrent_conduits + return self.transformer - def iter_concurrent_conduits(self) -> Iterator[SerialProducer[T_Product_ret]]: + @property + def n_concurrent_conduits(self) -> int: """[see superclass]""" + return self.transformer.n_concurrent_conduits - # create one shared buffered producer for synchronous iteration - producer = _BufferedProducer(self._producer) - - # for synchronous iteration, we need to materialize the source products - for transformer in self.transformer_group.iter_concurrent_conduits(): - if isinstance(transformer, Passthrough): - # We cast to T_Product_ret, since the Passthrough does not change the - # type of the source - yield cast(SerialProducer[T_Product_ret], producer) - else: - yield producer >> transformer - - async def aiter_concurrent_conduits( - self, - ) -> AsyncIterator[SerialProducer[T_Product_ret]]: + def iter_concurrent_producers(self) -> Iterator[SerialProducer[T_Product_ret]]: """[see superclass]""" - # Create parallel synchronized iterators for the source products - concurrent_producers = _AsyncBufferedProducer.create( - source=self._producer, n=self.transformer_group.n_concurrent_conduits - ) - - async for transformer in self.transformer_group.aiter_concurrent_conduits(): - producer = next(concurrent_producers) - if isinstance(transformer, Passthrough): - # We cast to T_Product_ret, since the Passthrough does not change the - # type of the source - yield cast(SerialProducer[T_Product_ret], producer) - else: - yield producer >> transformer + yield from self.transformer.iter_concurrent_producers(source=self._producer) @inheritdoc(match="[see superclass]") @@ -425,243 +362,39 @@ def __init__( self.second = second @property - def _source(self) -> Source[T_SourceProduct_ret]: + def input_type(self) -> type[T_SourceProduct_arg]: """[see superclass]""" - return self.first + return self.first.input_type @property - def _processor(self) -> Processor[T_SourceProduct_ret, T_TransformedProduct_ret]: + @final + def product_type(self) -> type[T_TransformedProduct_ret]: """[see superclass]""" - return self.second + return self.second.product_type @property - def n_concurrent_conduits(self) -> int: + def source(self) -> Source[T_SourceProduct_ret]: """[see superclass]""" - return self.first.n_concurrent_conduits * self.second.n_concurrent_conduits - - def iter_concurrent_conduits( - self, - ) -> Iterator[ - SerialTransformer[T_SourceProduct_arg, T_TransformedProduct_ret] | Passthrough - ]: - """[see superclass]""" - for first in self.first.iter_concurrent_conduits(): - if isinstance(first, Passthrough): - yield from cast( - Iterator[ - SerialTransformer[T_SourceProduct_arg, T_TransformedProduct_ret] - ], - self.second.iter_concurrent_conduits(), - ) - else: - for second in self.second.iter_concurrent_conduits(): - if isinstance(second, Passthrough): - yield cast( - SerialTransformer[ - T_SourceProduct_arg, T_TransformedProduct_ret - ], - first, - ) - else: - yield first >> second - - def aiter_concurrent_conduits( - self, - ) -> AsyncIterator[ - SerialTransformer[T_SourceProduct_arg, T_TransformedProduct_ret] | Passthrough - ]: - """[see superclass]""" - - async def _aiter( - first: ( - SerialTransformer[T_SourceProduct_arg, T_SourceProduct_ret] - | Passthrough - ) - ) -> AsyncIterator[ - SerialTransformer[T_SourceProduct_arg, T_TransformedProduct_ret] - | Passthrough - ]: - if isinstance(first, Passthrough): - async for second in self.second.aiter_concurrent_conduits(): - yield cast( - SerialTransformer[T_SourceProduct_arg, T_TransformedProduct_ret] - | Passthrough, - second, - ) - else: - async for second in self.second.aiter_concurrent_conduits(): - if isinstance(second, Passthrough): - yield cast( - SerialTransformer[ - T_SourceProduct_arg, T_TransformedProduct_ret - ], - first, - ) - else: - yield first >> second - - # noinspection PyTypeChecker - return async_flatten( - _aiter(first) async for first in self.first.aiter_concurrent_conduits() - ) - - -@inheritdoc(match="[see superclass]") -class _BaseBufferedProducer( - SerialProducer[T_Output_ret], Generic[T_Output_ret], metaclass=ABCMeta -): - """ - A producer that materializes the products of another producer - to allow multiple iterations over the same products. - """ - - source: SerialProducer[T_Output_ret] - _products: list[T_Output_ret] | None - - def __init__(self, source: SerialProducer[T_Output_ret]) -> None: - """ - :param source: the producer from which to buffer the products - """ - self.source = source + return self.first @property - def product_type(self) -> type[T_Output_ret]: + def processor(self) -> Processor[T_SourceProduct_ret, T_TransformedProduct_ret]: """[see superclass]""" - return self.source.product_type + return self.second - def get_final_conduits(self) -> Iterator[SerialConduit[T_Output_ret]]: + @property + def n_concurrent_conduits(self) -> int: """[see superclass]""" - return self.source.get_final_conduits() + return self.first.n_concurrent_conduits * self.second.n_concurrent_conduits - def get_connections( - self, *, ingoing: Collection[SerialConduit[Any]] - ) -> Iterator[tuple[SerialConduit[Any], SerialConduit[Any]]]: + @final + def is_valid_source(self, source: SerialConduit[T_SourceProduct_arg]) -> bool: """[see superclass]""" - return self.source.get_connections(ingoing=ingoing) - + return self.first.is_valid_source(source=source) -@inheritdoc(match="[see superclass]") -class _BufferedProducer( - AtomicConduit[T_Output_ret], - _BaseBufferedProducer[T_Output_ret], - Generic[T_Output_ret], -): - """ - A producer that materializes the products of another producer - to allow multiple iterations over the same products. - """ - - source: SerialProducer[T_Output_ret] - _products: list[T_Output_ret] | None = None - - def produce(self) -> Iterator[T_Output_ret]: + def iter_concurrent_producers( + self, *, source: SerialProducer[T_SourceProduct_arg] + ) -> Iterator[SerialProducer[T_TransformedProduct_ret]]: """[see superclass]""" - if self._products is None: - self._products = list(self.source.produce()) - return iter(self._products) - - -class _AsyncBufferedProducer( - AtomicConduit[T_Output_ret], - _BaseBufferedProducer[T_Output_ret], - Generic[T_Output_ret], -): - """ - A producer that creates multiple synchronized iterators over the same products, - allowing multiple asynchronous iterations over the same products where each - iteration blocks until all iterators have processed the current item, ensuring - that the original producer is only iterated once. - """ - - products: AsyncIterator[T_Output_ret] - k: int - - def __init__( - self, - *, - source: SerialProducer[T_Output_ret], - products: AsyncIterator[T_Output_ret], - k: int, - ) -> None: - super().__init__(source=source) - self.products = products - self.k = k - - @classmethod - def create( - cls, source: SerialProducer[T_Output_ret], *, n: int - ) -> Iterator[_AsyncBufferedProducer[T_Output_ret]]: - """ - Create multiple synchronized asynchronous producers over the products of the - given source producer. - - :param source: the source producer - :param n: the number of synchronized producers to create - :return: the synchronized producers - """ - return ( - _AsyncBufferedProducer(source=source, products=products, k=k) - for k, products in enumerate(_async_iter_parallel(source.aproduce(), n)) - ) - - def produce(self) -> Iterator[T_Output_ret]: - raise NotImplementedError( - "Not implemented; use `aiter` to iterate asynchronously." - ) - - def aproduce(self) -> AsyncIterator[T_Output_ret]: - return self.products - - -# -# Auxiliary constants, functions and classes -# - -T = TypeVar("T") - -#: Tasks for the producer that need to be awaited before the producer is garbage -#: collected -_producer_tasks: set[asyncio.Task[Any]] = set() - -#: Sentinel to indicate the end of processing -_END: Literal["END"] = cast(Literal["END"], "END") - - -def _async_iter_parallel( - iterable: AsyncIterable[T], n: int -) -> Iterator[AsyncIterator[T]]: - # Create a given number of asynchronous iterators that share the same items. - - async def _shared_iterator( - queue: asyncio.Queue[T | Literal["END"]], - ) -> AsyncIterator[T]: - while True: - # Wait for the item to be available for this iterator - item = await queue.get() - if item is _END: - # The producer has finished - break - yield cast(T, item) - - async def _producer() -> None: - # Iterate over the items in the source iterable - async for item in iterable: - # Add the item to all queues - for queue in queues: - await queue.put(item) - # Notify all consumers that the producer has finished - for queue in queues: - await queue.put(_END) - - # Create a queue for each consumer - queues: list[asyncio.Queue[T | Literal["END"]]] = [ - asyncio.Queue() for _ in range(n) - ] - - # Start the producer task, and store a reference to it to prevent it from being - # garbage collected before it finishes - task = asyncio.create_task(_producer()) - _producer_tasks.add(task) - task.add_done_callback(_producer_tasks.remove) - - return (_shared_iterator(queue) for queue in queues) + for producer in self.first.iter_concurrent_producers(source=source): + yield from self.second.iter_concurrent_producers(source=producer) diff --git a/src/fluxus/core/transformer/_simple.py b/src/fluxus/core/transformer/_simple.py index 25729c2..69b535a 100644 --- a/src/fluxus/core/transformer/_simple.py +++ b/src/fluxus/core/transformer/_simple.py @@ -20,20 +20,24 @@ from __future__ import annotations +import asyncio import functools import itertools import logging import operator -from collections.abc import AsyncIterator, Collection, Iterator -from typing import Any, Generic, TypeVar, cast +from abc import ABCMeta +from collections import deque +from collections.abc import AsyncIterable, AsyncIterator, Collection, Iterator +from typing import Any, Generic, Literal, TypeVar, cast, final from pytools.api import as_tuple, inheritdoc -from pytools.asyncio import async_flatten, iter_sync_to_async from pytools.expression import Expression +from pytools.typing import get_common_generic_base, get_common_generic_subclass from ... import Passthrough -from .. import SerialConduit -from ._transformer_base import BaseTransformer, ConcurrentTransformer, SerialTransformer +from .. import AtomicConduit, SerialConduit +from ..producer import SerialProducer +from ._transformer_base import BaseTransformer, ConcurrentTransformer log = logging.getLogger(__name__) @@ -48,22 +52,18 @@ # _ret for covariant type variables used in return positions # _arg for contravariant type variables used in argument positions +T = TypeVar("T") +T_Output_ret = TypeVar("T_Output_ret", covariant=True) T_SourceProduct_arg = TypeVar("T_SourceProduct_arg", contravariant=True) T_TransformedProduct_ret = TypeVar("T_TransformedProduct_ret", covariant=True) -# -# Constants -# - -# The passthrough singleton instance. -_PASSTHROUGH = Passthrough() - # # Classes # +@final @inheritdoc(match="[see superclass]") class SimpleConcurrentTransformer( ConcurrentTransformer[T_SourceProduct_arg, T_TransformedProduct_ret], @@ -88,7 +88,7 @@ def __init__( """ :param transformers: the transformers in this group """ - self.transformers = as_tuple( + self.transformers = transformers = as_tuple( itertools.chain(*map(_flatten_concurrent_transformers, transformers)), element_type=cast( tuple[ @@ -102,6 +102,49 @@ def __init__( ), ) + input_types = { + transformer.input_type + for transformer in transformers + if not isinstance(transformer, Passthrough) + } + try: + self._input_type = get_common_generic_subclass(input_types) + except TypeError as e: + raise TypeError( + "Transformers have incompatible input types: " + + ", ".join(sorted(input_type.__name__ for input_type in input_types)) + ) from e + + product_types = { + transformer.product_type + for transformer in transformers + if not isinstance(transformer, Passthrough) + } + try: + self._product_type = get_common_generic_base(product_types) + except TypeError as e: + raise TypeError( + "Transformers have incompatible product types: " + + ", ".join( + sorted(product_type.__name__ for product_type in product_types) + ) + ) from e + + @property + def input_type(self) -> type[T_SourceProduct_arg]: + """[see superclass]""" + return self._input_type + + @property + def product_type(self) -> type[T_TransformedProduct_ret]: + """[see superclass]""" + return self._product_type + + @property + def is_chained(self) -> bool: + """[see superclass]""" + return any(transformer.is_chained for transformer in self.transformers) + @property def n_concurrent_conduits(self) -> int: """[see superclass]""" @@ -109,6 +152,14 @@ def n_concurrent_conduits(self) -> int: transformer.n_concurrent_conduits for transformer in self.transformers ) + def is_valid_source(self, source: SerialConduit[T_SourceProduct_arg]) -> bool: + """[see superclass]""" + return all( + transformer.is_valid_source(source=source) + for transformer in self.transformers + if not isinstance(transformer, Passthrough) + ) + def get_final_conduits(self) -> Iterator[SerialConduit[T_TransformedProduct_ret]]: """[see superclass]""" for transformer in self.transformers: @@ -124,29 +175,33 @@ def get_connections( ) -> Iterator[tuple[SerialConduit[Any], SerialConduit[Any]]]: """[see superclass]""" for transformer in self.transformers: - if transformer is not _PASSTHROUGH: + if not isinstance(transformer, Passthrough): yield from transformer.get_connections(ingoing=ingoing) - def iter_concurrent_conduits( + def get_isolated_conduits( self, - ) -> Iterator[ - SerialTransformer[T_SourceProduct_arg, T_TransformedProduct_ret] | Passthrough - ]: + ) -> Iterator[SerialConduit[T_TransformedProduct_ret]]: """[see superclass]""" for transformer in self.transformers: - yield from transformer.iter_concurrent_conduits() + yield from transformer.get_isolated_conduits() - def aiter_concurrent_conduits( - self, - ) -> AsyncIterator[ - SerialTransformer[T_SourceProduct_arg, T_TransformedProduct_ret] | Passthrough - ]: - """[see superclass]""" - # noinspection PyTypeChecker - return async_flatten( - transformer.aiter_concurrent_conduits() - async for transformer in iter_sync_to_async(self.transformers) - ) + def iter_concurrent_producers( + self, *, source: SerialProducer[T_SourceProduct_arg] + ) -> Iterator[SerialProducer[T_TransformedProduct_ret]]: + """[see superclass]""" + + n_transformers = len(self.transformers) + + buffered_source = _BufferedSource(source=source, n=n_transformers) + + for k, transformer in enumerate(self.transformers): + buffered_producer = _BufferedProducer(source=buffered_source, k=k) + if isinstance(transformer, Passthrough): + yield cast(SerialProducer[T_TransformedProduct_ret], buffered_producer) + else: + yield from transformer.iter_concurrent_producers( + source=buffered_producer + ) def to_expression(self, *, compact: bool = False) -> Expression: """[see superclass]""" @@ -160,7 +215,7 @@ def to_expression(self, *, compact: bool = False) -> Expression: # -# Auxiliary functions +# Auxiliary constants, functions and classes # @@ -183,3 +238,176 @@ def _flatten_concurrent_transformers( yield from _flatten_concurrent_transformers(transformer) else: yield transformer + + +@inheritdoc(match="[see superclass]") +class _BaseBufferedProducer( + SerialProducer[T_Output_ret], Generic[T_Output_ret], metaclass=ABCMeta +): + """ + A producer that materializes the products of another producer + to allow multiple iterations over the same products. + """ + + source: SerialProducer[T_Output_ret] + + def __init__(self, source: SerialProducer[T_Output_ret]) -> None: + """ + :param source: the producer from which to buffer the products + """ + self.source = source + + @property + def product_type(self) -> type[T_Output_ret]: + """[see superclass]""" + return self.source.product_type + + def get_final_conduits(self) -> Iterator[SerialConduit[T_Output_ret]]: + """[see superclass]""" + return self.source.get_final_conduits() + + def get_connections( + self, *, ingoing: Collection[SerialConduit[Any]] + ) -> Iterator[tuple[SerialConduit[Any], SerialConduit[Any]]]: + """[see superclass]""" + return self.source.get_connections(ingoing=ingoing) + + +class _BufferedSource(Generic[T_Output_ret]): + """ + A source shared by multiple buffered producers. + """ + + source: SerialProducer[T_Output_ret] + n: int + + _products: list[deque[T_Output_ret]] | None = None + _products_async: list[AsyncIterator[T_Output_ret]] | None = None + + def __init__(self, source: SerialProducer[T_Output_ret], n: int) -> None: + """ + :param source: the source producer + :param n: the number of buffered producers + """ + self.source = source + self.n = n + + def get_products(self, k: int) -> Iterator[T_Output_ret]: + """ + Get the products of the source producer. + + :return: the products + """ + if self._products is None: + products = list(self.source.produce()) + self._products = [deque(products) for _ in range(self.n)] + + my_deque = self._products[k] + while my_deque: + yield my_deque.popleft() + + def get_products_async(self, k: int) -> AsyncIterator[T_Output_ret]: + """ + Get the products of the source producer asynchronously. + + :param k: the index of the buffered producer + :return: the k-th async iterator over the products + """ + + if self._products_async is None: + self._products_async = list( + _async_iter_parallel(self.source.aproduce(), self.n) + ) + return self._products_async[k] + + +@inheritdoc(match="[see superclass]") +class _BufferedProducer( + AtomicConduit[T_Output_ret], + SerialProducer[T_Output_ret], + Generic[T_Output_ret], +): + """ + A producer that creates multiple iterators over the same products, allowing multiple + synchronous or asynchronous iterations, ensuring that the original producer is only + iterated once. + """ + + source: _BufferedSource[T_Output_ret] + k: int + + def __init__(self, source: _BufferedSource[T_Output_ret], k: int) -> None: + """ + :param source: the source shared by multiple buffered producers + :param k: the index of this buffered producer + """ + self.source = source + self.k = k + + @property + def product_type(self) -> type[T_Output_ret]: + """[see superclass]""" + return self.source.source.product_type + + def produce(self) -> Iterator[T_Output_ret]: + """[see superclass]""" + return self.source.get_products(self.k) + + def aproduce(self) -> AsyncIterator[T_Output_ret]: + """[see superclass]""" + return self.source.get_products_async(self.k) + + def get_connections( + self, *, ingoing: Collection[SerialConduit[Any]] + ) -> Iterator[tuple[SerialConduit[Any], SerialConduit[Any]]]: + """[see superclass]""" + return self.source.source.get_connections(ingoing=ingoing) + + +#: Tasks for the producer that need to be awaited before the producer is garbage +#: collected +_producer_tasks: set[asyncio.Task[Any]] = set() + +#: Sentinel to indicate the end of processing +_END: Literal["END"] = cast(Literal["END"], "END") + + +def _async_iter_parallel( + iterable: AsyncIterable[T], n: int +) -> Iterator[AsyncIterator[T]]: + # Create a given number of asynchronous iterators that share the same items + # from the given source iterable. + + async def _shared_iterator( + queue: asyncio.Queue[T | Literal["END"]], + ) -> AsyncIterator[T]: + while True: + # Wait for the item to be available for this iterator + item = await queue.get() + if item is _END: + # The producer has finished + break + yield cast(T, item) + + async def _producer() -> None: + # Iterate over the items in the source iterable + async for item in iterable: + # Add the item to all queues + for queue in queues: + await queue.put(item) + # Notify all consumers that the producer has finished + for queue in queues: + await queue.put(_END) + + # Create a queue for each consumer + queues: list[asyncio.Queue[T | Literal["END"]]] = [ + asyncio.Queue() for _ in range(n) + ] + + # Start the producer task, and store a reference to it to prevent it from being + # garbage collected before it finishes + task = asyncio.create_task(_producer()) + _producer_tasks.add(task) + task.add_done_callback(_producer_tasks.remove) + + return (_shared_iterator(queue) for queue in queues) diff --git a/src/fluxus/core/transformer/_transformer_base.py b/src/fluxus/core/transformer/_transformer_base.py index b7c0988..e160253 100644 --- a/src/fluxus/core/transformer/_transformer_base.py +++ b/src/fluxus/core/transformer/_transformer_base.py @@ -65,7 +65,6 @@ # -@inheritdoc(match="[see superclass]") class BaseTransformer( Processor[T_SourceProduct_arg, T_TransformedProduct_ret], Source[T_TransformedProduct_ret], @@ -78,58 +77,16 @@ class BaseTransformer( """ @abstractmethod - def iter_concurrent_conduits( - self, - ) -> Iterator[ - SerialTransformer[T_SourceProduct_arg, T_TransformedProduct_ret] | Passthrough - ]: - """[see superclass]""" - - @abstractmethod - def aiter_concurrent_conduits( - self, - ) -> AsyncIterator[ - SerialTransformer[T_SourceProduct_arg, T_TransformedProduct_ret] | Passthrough - ]: - """[see superclass]""" - - @final - def process( - self, input: Iterable[T_SourceProduct_arg] - ) -> list[T_TransformedProduct_ret]: + def iter_concurrent_producers( + self, *, source: SerialProducer[T_SourceProduct_arg] + ) -> Iterator[SerialProducer[T_TransformedProduct_ret]]: """ - Transform the given products. + Generate serial producers which, run concurrently, will produce all transformed + products. - :param input: the products to transform - :return: the transformed products + :param source: the source producer whose products to transform + :return: the concurrent producers for all concurrent paths of this transformer """ - from ...simple import SimpleProducer - - return list( - SimpleProducer[self.input_type](input) >> self # type: ignore[name-defined] - ) - - @final - async def aprocess( - self, input: AsyncIterable[T_SourceProduct_arg] - ) -> list[T_TransformedProduct_ret]: - """ - Transform the given products asynchronously. - - :param input: the products to transform - :return: the transformed products - """ - from ...simple import SimpleAsyncProducer - - return [ - product - async for product in ( - SimpleAsyncProducer[self.input_type]( # type: ignore[name-defined] - input - ) - >> self - ) - ] def __and__( self, @@ -141,8 +98,7 @@ def __and__( product_type: type[T_TransformedProduct_ret] if isinstance(other, Passthrough): - for transformer in self.iter_concurrent_conduits(): - _validate_concurrent_passthrough(transformer) + _validate_concurrent_passthrough(self) input_type = self.input_type product_type = self.product_type elif not isinstance(other, BaseTransformer): @@ -164,8 +120,7 @@ def __rand__( self, other: Passthrough ) -> BaseTransformer[T_SourceProduct_arg, T_TransformedProduct_ret]: if isinstance(other, Passthrough): - for transformer in self.iter_concurrent_conduits(): - _validate_concurrent_passthrough(transformer) + _validate_concurrent_passthrough(self) from . import SimpleConcurrentTransformer @@ -232,9 +187,7 @@ def __rrshift__( from ._chained_ import _ChainedConcurrentTransformedProducer # noinspection PyTypeChecker - return _ChainedConcurrentTransformedProducer( - source=other, transformer_group=self - ) + return _ChainedConcurrentTransformedProducer(source=other, transformer=self) elif isinstance(other, BaseProducer): from ._chained_ import _ChainedConcurrentProducer @@ -249,26 +202,32 @@ class SerialTransformer( SerialSource[T_TransformedProduct_ret], BaseTransformer[T_SourceProduct_arg, T_TransformedProduct_ret], Generic[T_SourceProduct_arg, T_TransformedProduct_ret], + metaclass=ABCMeta, ): """ A transformer that generates new products from the products of a producer. """ @final - def iter_concurrent_conduits( - self, - ) -> Iterator[SerialTransformer[T_SourceProduct_arg, T_TransformedProduct_ret]]: + def iter_concurrent_producers( + self, *, source: SerialProducer[T_SourceProduct_arg] + ) -> Iterator[SerialProducer[T_TransformedProduct_ret]]: """[see superclass]""" - yield self + yield source >> self - @final - async def aiter_concurrent_conduits( - self, - ) -> AsyncIterator[ - SerialTransformer[T_SourceProduct_arg, T_TransformedProduct_ret] - ]: + def process( + self, input: Iterable[T_SourceProduct_arg] + ) -> Iterator[T_TransformedProduct_ret]: + """[see superclass]""" + for product in input: + yield from self.transform(product) + + def aprocess( + self, input: AsyncIterable[T_SourceProduct_arg] + ) -> AsyncIterator[T_TransformedProduct_ret]: """[see superclass]""" - yield self + # noinspection PyTypeChecker + return async_flatten(self.atransform(product) async for product in input) @abstractmethod def transform( @@ -295,30 +254,6 @@ async def atransform( for tx in self.transform(source_product): yield tx - def iter( - self, source: Iterable[T_SourceProduct_arg] - ) -> Iterator[T_TransformedProduct_ret]: - """ - Generate new products, using an existing producer as input. - - :param source: an existing producer to use as input (optional) - :return: the new products - """ - for product in source: - yield from self.transform(product) - - def aiter( - self, source: AsyncIterable[T_SourceProduct_arg] - ) -> AsyncIterator[T_TransformedProduct_ret]: - """ - Generate new products asynchronously, using an existing producer as input. - - :param source: an existing producer to use as input (optional) - :return: the new products - """ - # noinspection PyTypeChecker - return async_flatten(self.atransform(product) async for product in source) - @overload def __rshift__( self, @@ -383,7 +318,7 @@ def __rrshift__( def _validate_concurrent_passthrough( - conduit: SerialTransformer[Any, Any] | Passthrough + conduit: BaseTransformer[Any, Any] | Passthrough ) -> None: """ Validate that the given conduit is valid as a concurrent conduit with a passthrough. @@ -404,7 +339,6 @@ def _validate_concurrent_passthrough( ) -@inheritdoc(match="[see superclass]") class ConcurrentTransformer( BaseTransformer[T_SourceProduct_arg, T_TransformedProduct_ret], ConcurrentConduit[T_TransformedProduct_ret], @@ -415,11 +349,33 @@ class ConcurrentTransformer( A collection of one or more transformers, operating in parallel. """ - @property - def input_type(self) -> type[T_SourceProduct_arg]: - """[see superclass]""" - return get_common_generic_subclass( - transformer.input_type - for transformer in self.iter_concurrent_conduits() - if not isinstance(transformer, Passthrough) + def process( + self, input: Iterable[T_SourceProduct_arg] + ) -> Iterator[T_TransformedProduct_ret]: + """ + Transform the given products. + + :param input: the products to transform + :return: the transformed products + """ + from ...simple import SimpleProducer + + return iter( + SimpleProducer[self.input_type](input) >> self # type: ignore[name-defined] + ) + + def aprocess( + self, input: AsyncIterable[T_SourceProduct_arg] + ) -> AsyncIterator[T_TransformedProduct_ret]: + """ + Transform the given products asynchronously. + + :param input: the products to transform + :return: the transformed products + """ + from ...simple import SimpleAsyncProducer + + return aiter( + SimpleAsyncProducer[self.input_type](input) # type: ignore[name-defined] + >> self ) diff --git a/src/fluxus/viz/_graph.py b/src/fluxus/viz/_graph.py index 4295046..8eb181f 100644 --- a/src/fluxus/viz/_graph.py +++ b/src/fluxus/viz/_graph.py @@ -98,11 +98,7 @@ def from_conduit(cls, conduit: Conduit[Any]) -> FlowGraph: conduit.get_connections(ingoing=[]) ) - single_conduits: set[SerialConduit[Any]] = { - conduit - for conduit in conduit.iter_concurrent_conduits() - if conduit.is_atomic - } + single_conduits: set[SerialConduit[Any]] = set(conduit.get_isolated_conduits()) return FlowGraph(connections=connections, single_conduits=single_conduits) diff --git a/test/fluxus_test/test_flow.py b/test/fluxus_test/test_flow.py index 69d4823..5e5e015 100644 --- a/test/fluxus_test/test_flow.py +++ b/test/fluxus_test/test_flow.py @@ -9,7 +9,7 @@ import pytest -from fluxus import AsyncConsumer, Consumer, Passthrough, Producer, Transformer +from fluxus import AsyncConsumer, Consumer, Flow, Passthrough, Producer, Transformer from fluxus.core import Conduit from fluxus.core.producer import ConcurrentProducer from fluxus.core.transformer import BaseTransformer, ConcurrentTransformer @@ -65,6 +65,23 @@ def transform(self, source_product: int) -> Iterator[int]: yield source_product + 1 +class Counter(NumberTransformer): + """ + Ignores the input and outputs a number that increments each time transform is + called. + """ + + counter: int + + def __init__(self, start: int) -> None: + self.counter = start + + def transform(self, source_product: int) -> Iterator[int]: + value = self.counter + self.counter += 1 + yield value + + class StringConsumer(Consumer[str, str]): """ Concatenates all strings as separate lines. @@ -129,7 +146,7 @@ def test_group_construction() -> None: assert isinstance(producer_group, ConcurrentProducer) assert producer_group.n_concurrent_conduits == 2 producers = cast( - tuple[NumberProducer, ...], tuple(producer_group.iter_concurrent_conduits()) + tuple[NumberProducer, ...], tuple(producer_group.iter_concurrent_producers()) ) assert len(producers) == 2 assert all(isinstance(prod, NumberProducer) for prod in producers) @@ -143,9 +160,17 @@ def test_group_construction() -> None: ) assert isinstance(transformer_group, ConcurrentTransformer) assert transformer_group.n_concurrent_conduits == 2 - transformers = tuple(transformer_group.iter_concurrent_conduits()) + transformers = tuple( + transformer_group.iter_concurrent_producers(source=NumberProducer(0, 4)) + ) assert len(transformers) == 2 - assert isinstance(transformers[0], DoublingTransformer) + + # noinspection PyProtectedMember + from fluxus.core.transformer._simple import _BufferedProducer + + assert tuple( + type(conduit).__name__ for conduit in transformers[0].chained_conduits + ) == (_BufferedProducer.__name__, DoublingTransformer.__name__) def test_producer_group_construction() -> None: @@ -196,25 +221,25 @@ def test_chain_of_groups() -> None: ) expected_result = [ - [2, 2, 3, 4], - [12, 22], - [1, 1, 2, 3], - [11, 21], - [1, 1, 1, 1, 2, 3, 3, 5], - [11, 21, 21, 41], [0, 0, 0, 0, 1, 2, 2, 4], - [10, 20, 20, 40], - [2, 3], - [12], + [0, 0, 1, 2], + [1, 1, 1, 1, 2, 3, 3, 5], + [1, 1, 2, 3], + [1, 1, 2, 3], [1, 2], + [2, 2, 3, 4], + [2, 3], + [10, 20], + [10, 20, 20, 40], [11], - [1, 1, 2, 3], [11, 21], - [0, 0, 1, 2], - [10, 20], + [11, 21], + [11, 21, 21, 41], + [12], + [12, 22], ] - assert flow.run() == expected_result - assert asyncio.run(flow.arun()) == expected_result + assert sorted(flow.run()) == expected_result + assert sorted(asyncio.run(flow.arun())) == expected_result def test_expression() -> None: @@ -493,6 +518,37 @@ def test_flow_construction() -> None: ) +@pytest.mark.asyncio +async def test_shared_conduits() -> None: + # This test ensures that shared conduits are not called more than once + + def make_flow() -> Flow[list[list[int]]]: + return ( + NumberProducer(0, 1) + >> Counter(start=10) + >> ( + (Counter(start=100) >> (DoublingTransformer() & Passthrough())) + & (DoublingTransformer() >> (DoublingTransformer() & Passthrough())) + & Passthrough() + ) + >> NumberConsumer() + ) + + # The expected result is a list of lists, where each list contains the values + # produced by a single path through the flow. The test is designed to confirm + # that each counter is evaluated only once. + expected_result = [ + [100, 200], + [100], + [10, 20, 20, 40], + [10, 20], + [10], + ] + + assert make_flow().run() == expected_result + assert await make_flow().arun() == expected_result + + def test_large_flows() -> None: """ Test a flow with 1000 parallel steps.