Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove buffer_size config from source/pipe #396

Merged
merged 1 commit into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 1 addition & 26 deletions src/spdl/pipeline/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import inspect
import logging
import warnings
from collections.abc import (
AsyncIterable,
Callable,
Expand Down Expand Up @@ -83,7 +82,6 @@ def add_source(
source: Iterable[T] | AsyncIterable[T],
*,
queue_class: type[AsyncQueue[T]] | None = None,
**_kwargs, # pyre-ignore: [2]
) -> "PipelineBuilder[T, U]":
"""Attach an iterator to the source buffer.

Expand Down Expand Up @@ -113,15 +111,7 @@ def add_source(
if not (hasattr(source, "__aiter__") or hasattr(source, "__iter__")):
raise ValueError("Source must be either generator or async generator.")

# Note: Do not document this option.
# See `pipe` method for detail.
buffer_size = int(_kwargs.get("_buffer_size", 1))
if buffer_size < 1:
raise ValueError(
f"buffer_size must be greater than 0. Found: {buffer_size}"
)

self._src = _SourceConfig(source, buffer_size, queue_class)
self._src = _SourceConfig(source, queue_class)
return self

def pipe(
Expand All @@ -135,7 +125,6 @@ def pipe(
hooks: list[PipelineHook] | None = None,
output_order: str = "completion",
queue_class: type[AsyncQueue[U_]] | None = None,
**_kwargs, # pyre-ignore: [2]
) -> "PipelineBuilder[T, U]":
"""Apply an operation to items in the pipeline.

Expand Down Expand Up @@ -212,14 +201,6 @@ def pipe(
)
name_ = name or _get_op_name(op)

if (op_kwargs := _kwargs.get("kwargs")) is not None:
warnings.warn(
"`kwargs` argument is deprecated. "
"Please use `functools.partial` to bind function arguments.",
stacklevel=2,
)
op = partial(op, **op_kwargs) # pyre-ignore: [9]

type_ = _PType.Pipe if output_order == "completion" else _PType.OrderedPipe

self._process_args.append(
Expand All @@ -233,7 +214,6 @@ def pipe(
hooks=hooks,
),
queue_class=queue_class, # pyre-ignore: [6]
buffer_size=_kwargs.get("_buffer_size", 1),
# Note:
# `_buffer_size` option is intentionally not documented.
#
Expand Down Expand Up @@ -366,8 +346,6 @@ def _get_desc(self) -> list[str]:
src = self._src
src_repr = getattr(src.source, "__name__", type(src.source).__name__)
parts.append(f" - src: {src_repr}")
if src.buffer_size != 1:
parts.append(f" Buffer: buffer_size={src.buffer_size}")
else:
parts.append(" - src: n/a")

Expand All @@ -387,9 +365,6 @@ def _get_desc(self) -> list[str]:
part = str(cfg.type_)
parts.append(f" - {part}")

if cfg.buffer_size > 1:
parts.append(f" Buffer: buffer_size={cfg.buffer_size}")

if (sink := self._sink) is not None:
parts.append(f" - sink: buffer_size={sink.buffer_size}")

Expand Down
6 changes: 2 additions & 4 deletions src/spdl/pipeline/_components/_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
@dataclass
class _SourceConfig(Generic[T]):
source: Iterable | AsyncIterable
buffer_size: int
queue_class: type[AsyncQueue[T]] | None


Expand All @@ -58,7 +57,6 @@ class _ProcessConfig(Generic[T, U]):
type_: _PType
args: _PipeArgs[T, U]
queue_class: type[AsyncQueue[U]] | None
buffer_size: int = 1


@dataclass
Expand Down Expand Up @@ -131,14 +129,14 @@ def _build_pipeline(

# source
queue_class = _get_queue(src.queue_class, report_stats_interval)
queues.append(queue_class("src_queue", src.buffer_size))
queues.append(queue_class("src_queue", 1))
coros.append(("AsyncPipeline::0_source", _source(src.source, queues[0])))

# pipes
for i, cfg in enumerate(process_args, start=1):
queue_class = _get_queue(cfg.queue_class, report_stats_interval)
queue_name = f"{cfg.args.name.split('(')[0]}_queue"
queues.append(queue_class(queue_name, cfg.buffer_size))
queues.append(queue_class(queue_name, 1))
in_queue, out_queue = queues[i - 1 : i + 1]

match cfg.type_:
Expand Down
85 changes: 3 additions & 82 deletions tests/spdl_unittest/dataloader/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import random
import threading
import time
from collections.abc import AsyncIterable, Iterable, Iterator
from collections.abc import Iterator
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from contextlib import asynccontextmanager
from functools import partial
Expand Down Expand Up @@ -1605,84 +1605,6 @@ async def op(i: int) -> int:
).add_sink(1).build(num_threads=1)


def test_pipeline_pipe_kwargs():
    """The ``kwargs`` argument of ``pipe`` binds extra arguments to the op."""

    def op(i: int, j: int) -> int:
        return i + j

    # Build the pipeline step-by-step instead of one fluent chain.
    builder = PipelineBuilder()
    builder = builder.add_source(range(10))
    builder = builder.pipe(op, kwargs={"j": 2})
    builder = builder.add_sink(1)
    pipeline = builder.build(num_threads=1)

    with pipeline.auto_stop():
        vals = list(pipeline.get_iterator(timeout=3))
        expected = [i + 2 for i in range(10)]
        assert vals == expected


def test_pipeline_pipe_kwargs_async():
    """The ``kwargs`` argument of ``pipe`` binds extra arguments to an async op."""

    async def op(i: int, j: int) -> int:
        return i + j

    # Build the pipeline step-by-step instead of one fluent chain.
    builder = PipelineBuilder()
    builder = builder.add_source(range(10))
    builder = builder.pipe(op, kwargs={"j": 2})
    builder = builder.add_sink(1)
    pipeline = builder.build(num_threads=1)

    with pipeline.auto_stop():
        vals = list(pipeline.get_iterator(timeout=3))
        expected = [i + 2 for i in range(10)]
        assert vals == expected


def test_pipeline_pipe_kwargs_iter():
    """The ``kwargs`` argument of ``pipe`` binds extra arguments to a generator op."""

    def op(i: int, vals: list[int]) -> Iterable[int]:
        for v in vals:
            yield i + v

    # Build the pipeline step-by-step instead of one fluent chain.
    builder = PipelineBuilder()
    builder = builder.add_source([1])
    builder = builder.pipe(op, kwargs={"vals": list(range(10))})
    builder = builder.add_sink(1)
    pipeline = builder.build(num_threads=1)

    with pipeline.auto_stop():
        vals = list(pipeline.get_iterator(timeout=3))
        expected = [1 + i for i in range(10)]
        assert vals == expected


def test_pipeline_pipe_kwargs_async_iter():
    """The ``kwargs`` argument of ``pipe`` binds extra arguments to an async generator op."""

    async def op(i: int, vals: list[int]) -> AsyncIterable[int]:
        for v in vals:
            yield i + v

    # Build the pipeline step-by-step instead of one fluent chain.
    builder = PipelineBuilder()
    builder = builder.add_source([1])
    builder = builder.pipe(op, kwargs={"vals": list(range(10))})
    builder = builder.add_sink(1)
    pipeline = builder.build(num_threads=1)

    with pipeline.auto_stop():
        vals = list(pipeline.get_iterator(timeout=3))
        expected = [1 + i for i in range(10)]
        assert vals == expected


class _PicklableSource:
def __init__(self, n: int) -> None:
self.n = n
Expand All @@ -1709,9 +1631,8 @@ def test_pipelinebuilder_picklable():
aplus1,
concurrency=3,
)
.pipe(plusN, kwargs={"N": 2}, hooks=[CountHook()])
.pipe(partial(plusN, N=3), hooks=[CountHook()])
.pipe(passthrough, report_stats_interval=4)
.pipe(passthrough)
.aggregate(3)
.disaggregate()
.add_sink(10)
Expand All @@ -1720,6 +1641,6 @@ def test_pipelinebuilder_picklable():
results = list(run_pipeline_in_subprocess(builder, num_threads=5, buffer_size=-1))

def _ref(x: int) -> int:
return 2 * x + 1 + 2 + 3
return 2 * x + 1 + 3

assert sorted(results) == [_ref(i) for i in range(10)]
Loading