diff --git a/src/spdl/pipeline/_builder.py b/src/spdl/pipeline/_builder.py index 11c8ae17..75cbe2e9 100644 --- a/src/spdl/pipeline/_builder.py +++ b/src/spdl/pipeline/_builder.py @@ -8,7 +8,6 @@ import inspect import logging -import warnings from collections.abc import ( AsyncIterable, Callable, @@ -83,7 +82,6 @@ def add_source( source: Iterable[T] | AsyncIterable[T], *, queue_class: type[AsyncQueue[T]] | None = None, - **_kwargs, # pyre-ignore: [2] ) -> "PipelineBuilder[T, U]": """Attach an iterator to the source buffer. @@ -113,15 +111,7 @@ def add_source( if not (hasattr(source, "__aiter__") or hasattr(source, "__iter__")): raise ValueError("Source must be either generator or async generator.") - # Note: Do not document this option. - # See `pipe` method for detail. - buffer_size = int(_kwargs.get("_buffer_size", 1)) - if buffer_size < 1: - raise ValueError( - f"buffer_size must be greater than 0. Found: {buffer_size}" - ) - - self._src = _SourceConfig(source, buffer_size, queue_class) + self._src = _SourceConfig(source, queue_class) return self def pipe( @@ -135,7 +125,6 @@ def pipe( hooks: list[PipelineHook] | None = None, output_order: str = "completion", queue_class: type[AsyncQueue[U_]] | None = None, - **_kwargs, # pyre-ignore: [2] ) -> "PipelineBuilder[T, U]": """Apply an operation to items in the pipeline. @@ -212,14 +201,6 @@ def pipe( ) name_ = name or _get_op_name(op) - if (op_kwargs := _kwargs.get("kwargs")) is not None: - warnings.warn( - "`kwargs` argument is deprecated. " - "Please use `functools.partial` to bind function arguments.", - stacklevel=2, - ) - op = partial(op, **op_kwargs) # pyre-ignore: [9] - type_ = _PType.Pipe if output_order == "completion" else _PType.OrderedPipe self._process_args.append( @@ -233,7 +214,6 @@ def pipe( hooks=hooks, ), queue_class=queue_class, # pyre-ignore: [6] - buffer_size=_kwargs.get("_buffer_size", 1), # Note: # `_buffer_size` option is intentionally not documented. # @@ -366,8 +346,6 @@ def _get_desc(self) -> list[str]: src = self._src src_repr = getattr(src.source, "__name__", type(src.source).__name__) parts.append(f" - src: {src_repr}") - if src.buffer_size != 1: - parts.append(f" Buffer: buffer_size={src.buffer_size}") else: parts.append(" - src: n/a") @@ -387,9 +365,6 @@ def _get_desc(self) -> list[str]: part = str(cfg.type_) parts.append(f" - {part}") - if cfg.buffer_size > 1: - parts.append(f" Buffer: buffer_size={cfg.buffer_size}") - if (sink := self._sink) is not None: parts.append(f" - sink: buffer_size={sink.buffer_size}") diff --git a/src/spdl/pipeline/_components/_build.py b/src/spdl/pipeline/_components/_build.py index f334ddbd..5eb10f08 100644 --- a/src/spdl/pipeline/_components/_build.py +++ b/src/spdl/pipeline/_components/_build.py @@ -42,7 +42,6 @@ @dataclass class _SourceConfig(Generic[T]): source: Iterable | AsyncIterable - buffer_size: int queue_class: type[AsyncQueue[T]] | None @@ -58,7 +57,6 @@ class _ProcessConfig(Generic[T, U]): type_: _PType args: _PipeArgs[T, U] queue_class: type[AsyncQueue[U]] | None - buffer_size: int = 1 @dataclass @@ -131,14 +129,14 @@ def _build_pipeline( # source queue_class = _get_queue(src.queue_class, report_stats_interval) - queues.append(queue_class("src_queue", src.buffer_size)) + queues.append(queue_class("src_queue", 1)) coros.append(("AsyncPipeline::0_source", _source(src.source, queues[0]))) # pipes for i, cfg in enumerate(process_args, start=1): queue_class = _get_queue(cfg.queue_class, report_stats_interval) queue_name = f"{cfg.args.name.split('(')[0]}_queue" - queues.append(queue_class(queue_name, cfg.buffer_size)) + queues.append(queue_class(queue_name, 1)) in_queue, out_queue = queues[i - 1 : i + 1] match cfg.type_: diff --git a/tests/spdl_unittest/dataloader/pipeline_test.py b/tests/spdl_unittest/dataloader/pipeline_test.py index 6fa867d3..f3a2e1d0 100644 --- a/tests/spdl_unittest/dataloader/pipeline_test.py +++ b/tests/spdl_unittest/dataloader/pipeline_test.py @@ -12,7 +12,7 @@ import random import threading import time -from collections.abc import AsyncIterable, Iterable, Iterator +from collections.abc import Iterator from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from contextlib import asynccontextmanager from functools import partial @@ -1605,84 +1605,6 @@ async def op(i: int) -> int: ).add_sink(1).build(num_threads=1) -def test_pipeline_pipe_kwargs(): - """pipe can pass kwargs to op""" - - def op(i: int, j: int) -> int: - return i + j - - pipeline = ( - PipelineBuilder() - .add_source(range(10)) - .pipe(op, kwargs={"j": 2}) - .add_sink(1) - .build(num_threads=1) - ) - - with pipeline.auto_stop(): - vals = list(pipeline.get_iterator(timeout=3)) - assert vals == [i + 2 for i in range(10)] - - -def test_pipeline_pipe_kwargs_async(): - """pipe can pass kwargs to op""" - - async def op(i: int, j: int) -> int: - return i + j - - pipeline = ( - PipelineBuilder() - .add_source(range(10)) - .pipe(op, kwargs={"j": 2}) - .add_sink(1) - .build(num_threads=1) - ) - - with pipeline.auto_stop(): - vals = list(pipeline.get_iterator(timeout=3)) - assert vals == [i + 2 for i in range(10)] - - -def test_pipeline_pipe_kwargs_iter(): - """pipe can pass kwargs to op""" - - def op(i: int, vals: list[int]) -> Iterable[int]: - for v in vals: - yield i + v - - pipeline = ( - PipelineBuilder() - .add_source([1]) - .pipe(op, kwargs={"vals": list(range(10))}) - .add_sink(1) - .build(num_threads=1) - ) - - with pipeline.auto_stop(): - vals = list(pipeline.get_iterator(timeout=3)) - assert vals == [1 + i for i in range(10)] - - -def test_pipeline_pipe_kwargs_async_iter(): - """pipe can pass kwargs to op""" - - async def op(i: int, vals: list[int]) -> AsyncIterable[int]: - for v in vals: - yield i + v - - pipeline = ( - PipelineBuilder() - .add_source([1]) - .pipe(op, kwargs={"vals": list(range(10))}) - .add_sink(1) - .build(num_threads=1) - ) - - with pipeline.auto_stop(): - vals = list(pipeline.get_iterator(timeout=3)) - assert vals == [1 + i for i in range(10)] - - class _PicklableSource: def __init__(self, n: int) -> None: self.n = n @@ -1709,9 +1631,8 @@ def test_pipelinebuilder_picklable(): aplus1, concurrency=3, ) - .pipe(plusN, kwargs={"N": 2}, hooks=[CountHook()]) .pipe(partial(plusN, N=3), hooks=[CountHook()]) - .pipe(passthrough, report_stats_interval=4) + .pipe(passthrough) .aggregate(3) .disaggregate() .add_sink(10) @@ -1720,6 +1641,6 @@ def test_pipelinebuilder_picklable(): results = list(run_pipeline_in_subprocess(builder, num_threads=5, buffer_size=-1)) def _ref(x: int) -> int: - return 2 * x + 1 + 2 + 3 + return 2 * x + 1 + 3 assert sorted(results) == [_ref(i) for i in range(10)]