fix: support batch expression everywhere
percevalw committed Nov 4, 2024
1 parent 62a609a commit ca23fe1
Showing 2 changed files with 65 additions and 77 deletions.
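
For readers skimming the diff, here is a minimal usage sketch of what "batch expression everywhere" means. It is an illustration based on the docstring and test changes below, not code from the repository; the from_iterable reader and the toy texts are assumptions.

    import edsnlp

    nlp = edsnlp.blank("eds")
    # hypothetical toy corpus; any edsnlp reader is assumed to work the same way
    stream = edsnlp.data.from_iterable(["short note", "a longer clinical note"])
    stream = stream.map_pipeline(nlp)

    # batch_size accepts a single batching expression ("10 words", "32 docs",
    # "dataset", ...) in addition to an int plus batch_by; this commit makes the
    # expression form work everywhere, including set_processing.
    stream = stream.set_processing(num_cpu_workers=2, batch_size="10 words")
    docs = list(stream)
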
118 changes: 52 additions & 66 deletions edsnlp/core/stream.py
@@ -25,7 +25,7 @@
from typing_extensions import Literal

import edsnlp.data
-from edsnlp.utils.batching import BatchBy, BatchFn, BatchSizeArg, batchify_fns
+from edsnlp.utils.batching import BatchBy, BatchFn, BatchSizeArg, batchify, batchify_fns
from edsnlp.utils.collections import flatten, flatten_once, shuffle
from edsnlp.utils.stream_sentinels import StreamSentinel

@@ -47,25 +47,6 @@ def deep_isgeneratorfunction(x):
raise ValueError(f"{x} does not have a __call__ or batch_process method.")


-class _InferType:
-# Singleton is important since the INFER object may be passed to
-# other processes, i.e. pickled, depickled, while it should
-# always be the same object.
-instance = None
-
-def __repr__(self):
-return "INFER"
-
-def __new__(cls, *args, **kwargs):
-if cls.instance is None:
-cls.instance = super().__new__(cls)
-return cls.instance
-
-def __bool__(self):
-return False
-
-
-INFER = _InferType()
CONTEXT = [{}]

T = TypeVar("T")
@@ -125,8 +106,8 @@ def __init__(
):
if batch_fn is None:
if size is None:
-size = INFER
-batch_fn = INFER
+size = None
+batch_fn = None
else:
batch_fn = batchify_fns["docs"]
self.size = size
@@ -302,17 +283,13 @@ def validate_batching(cls, batch_size, batch_by):
"Cannot use both a batch_size expression and a batch_by function"
)
batch_size, batch_by = BatchSizeArg.validate(batch_size)
-if (
-batch_size is not None
-and batch_size is not INFER
-and not isinstance(batch_size, int)
-):
+if batch_size is not None and not isinstance(batch_size, int):
raise ValueError(
f"Invalid batch_size (must be an integer or None): {batch_size}"
)
if (
batch_by is not None
-and batch_by is not INFER
and batch_by is not None
and batch_by not in batchify_fns
and not callable(batch_by)
):
@@ -321,11 +298,11 @@ def validate_batching(cls, batch_size, batch_by):

@property
def batch_size(self):
return self.config.get("batch_size", 1)
return self.config.get("batch_size", None)

@property
def batch_by(self):
return self.config.get("batch_by", "docs")
return self.config.get("batch_by", None)

@property
def disable_implicit_parallelism(self):
@@ -372,39 +349,36 @@ def deterministic(self):
@with_non_default_args
def set_processing(
self,
-batch_size: int = INFER,
-batch_by: BatchBy = "docs",
-split_into_batches_after: str = INFER,
-num_cpu_workers: Optional[int] = INFER,
-num_gpu_workers: Optional[int] = INFER,
+batch_size: Optional[Union[int, str]] = None,
+batch_by: BatchBy = None,
+split_into_batches_after: str = None,
+num_cpu_workers: Optional[int] = None,
+num_gpu_workers: Optional[int] = None,
disable_implicit_parallelism: bool = True,
-backend: Optional[Literal["simple", "multiprocessing", "mp", "spark"]] = INFER,
-autocast: Union[bool, Any] = INFER,
+backend: Optional[Literal["simple", "multiprocessing", "mp", "spark"]] = None,
+autocast: Union[bool, Any] = None,
show_progress: bool = False,
-gpu_pipe_names: Optional[List[str]] = INFER,
-process_start_method: Optional[Literal["fork", "spawn"]] = INFER,
-gpu_worker_devices: Optional[List[str]] = INFER,
-cpu_worker_devices: Optional[List[str]] = INFER,
+gpu_pipe_names: Optional[List[str]] = None,
+process_start_method: Optional[Literal["fork", "spawn"]] = None,
+gpu_worker_devices: Optional[List[str]] = None,
+cpu_worker_devices: Optional[List[str]] = None,
deterministic: bool = True,
work_unit: Literal["record", "fragment"] = "record",
-chunk_size: int = INFER,
+chunk_size: int = None,
sort_chunks: bool = False,
_non_default_args: Iterable[str] = (),
) -> "Stream":
"""
Parameters
----------
-batch_size: int
-Number of documents to process at a time in a GPU worker (or in the
-main process if no workers are used). This is the global batch size
-that is used for batching methods that do not provide their own
-batching arguments.
+batch_size: Optional[Union[int, str]]
+The batch size. Can also be a batching expression like
+"32 docs", "1024 words", "dataset", "fragment", etc.
batch_by: BatchBy
Function to compute the batches. If set, it should take an iterable of
documents and return an iterable of batches. You can also set it to
"docs", "words" or "padded_words" to use predefined batching functions.
Defaults to "docs". Only used for operations that do not provide their
own batching arguments.
Defaults to "docs".
num_cpu_workers: int
Number of CPU workers. A CPU worker handles the non deep-learning components
and the preprocessing, collating and postprocessing of deep-learning
@@ -468,15 +442,15 @@ def set_processing(
"""
kwargs = {k: v for k, v in locals().items() if k in _non_default_args}
if (
kwargs.pop("chunk_size", INFER) is not INFER
or kwargs.pop("sort_chunks", INFER) is not INFER
kwargs.pop("chunk_size", None) is not None
or kwargs.pop("sort_chunks", None) is not None
):
warnings.warn(
"chunk_size and sort_chunks are deprecated, use "
"map_batched(sort_fn, batch_size=chunk_size) instead.",
VisibleDeprecationWarning,
)
if kwargs.pop("split_into_batches_after", INFER) is not INFER:
if kwargs.pop("split_into_batches_after", None) is not None:
warnings.warn(
"split_into_batches_after is deprecated.", VisibleDeprecationWarning
)
@@ -486,7 +460,7 @@ def set_processing(
ops=self.ops,
config={
**self.config,
-**{k: v for k, v in kwargs.items() if v is not INFER},
+**{k: v for k, v in kwargs.items() if v is not None},
},
)

@@ -690,8 +664,8 @@ def map_gpu(
def map_pipeline(
self,
model: Pipeline,
-batch_size: Optional[int] = INFER,
-batch_by: BatchBy = INFER,
+batch_size: Optional[Union[int, str]] = None,
+batch_by: BatchBy = None,
) -> "Stream":
"""
Maps a pipeline to the documents, i.e. adds each component of the pipeline to
@@ -974,16 +948,10 @@ def __getattr__(self, item):
def _make_stages(self, split_torch_pipes: bool) -> List[Stage]:
current_ops = []
stages = []
-self_batch_fn = batchify_fns.get(self.batch_by, self.batch_by)
-self_batch_size = self.batch_size
-assert self_batch_size is not None

ops = [copy(op) for op in self.ops]

for op in ops:
-if isinstance(op, BatchifyOp):
-op.batch_fn = self_batch_fn if op.batch_fn is INFER else op.batch_fn
-op.size = self_batch_size if op.size is INFER else op.size
if (
isinstance(op, MapBatchesOp)
and hasattr(op.pipe, "forward")
@@ -1005,23 +973,39 @@ def validate_ops(self, ops, update: bool = False):
# Check batchify requirements
requires_sentinels = set()

+self_batch_size, self_batch_by = self.validate_batching(
+self.batch_size, self.batch_by
+)
+if self_batch_by is None:
+self_batch_by = "docs"
+if self_batch_size is None:
+self_batch_size = 1
+self_batch_fn = batchify_fns.get(self_batch_by, self_batch_by)

if hasattr(self.writer, "batch_fn") and hasattr(
self.writer.batch_fn, "requires_sentinel"
):
requires_sentinels.add(self.writer.batch_fn.requires_sentinel)

-self_batch_fn = batchify_fns.get(self.batch_by, self.batch_by)
for op in reversed(ops):
if isinstance(op, BatchifyOp):
-batch_fn = op.batch_fn or self_batch_fn
+if op.batch_fn is None and op.size is None:
+batch_size = self_batch_size
+batch_fn = self_batch_fn
+elif op.batch_fn is None:
+batch_size = op.size
+batch_fn = batchify
+else:
+batch_size = op.size
+batch_fn = op.batch_fn
sentinel_mode = op.sentinel_mode or (
"auto"
if "sentinel_mode" in signature(batch_fn).parameters
else None
)
if sentinel_mode == "auto":
sentinel_mode = "split" if requires_sentinels else "drop"
if requires_sentinels and op.sentinel_mode == "drop":
if requires_sentinels and sentinel_mode == "drop":
raise ValueError(
f"Operation {op} drops the stream sentinel values "
f"(markers for the end of a dataset or a dataset "
@@ -1031,10 +1015,12 @@ def validate_ops(self, ops, update: bool = False):
f"any upstream batching operation."
)
if update:
+op.size = batch_size
+op.batch_fn = batch_fn
op.sentinel_mode = sentinel_mode

if hasattr(batch_fn, "requires_sentinel"):
requires_sentinels.add(batch_fn.requires_sentinel)
if hasattr(op.batch_fn, "requires_sentinel"):
requires_sentinels.add(op.batch_fn.requires_sentinel)

sentinel_str = ", ".join(requires_sentinels)
if requires_sentinels and self.backend == "spark":
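
The core behavioral change in validate_ops above is the fallback order used to resolve each batching op. Below is a rough paraphrase for illustration only; the helper name and signature are not part of the edsnlp API, while batchify and batchify_fns are the utilities imported at the top of this diff.

    from edsnlp.utils.batching import batchify, batchify_fns

    def resolve_batching(op_size, op_batch_fn, stream_size=None, stream_by=None):
        # Stream-level defaults: "docs" batching, 1 doc per batch.
        stream_by = "docs" if stream_by is None else stream_by
        stream_size = 1 if stream_size is None else stream_size
        stream_fn = batchify_fns.get(stream_by, stream_by)
        if op_batch_fn is None and op_size is None:
            # The op inherits the stream-level batching entirely.
            return stream_size, stream_fn
        if op_batch_fn is None:
            # Only a size was given: fall back to plain fixed-size batching.
            return op_size, batchify
        # The op fully specifies its own batching.
        return op_size, op_batch_fn
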
24 changes: 13 additions & 11 deletions tests/data/test_stream.py
@@ -58,19 +58,22 @@ def forward(batch):
assert set(res.tolist()) == {i * 2 for i in range(15)}


+# fmt: off
@pytest.mark.parametrize(
"sort,num_cpu_workers,batch_by,expected",
"sort,num_cpu_workers,batch_kwargs,expected",
[
(False, 1, "words", [3, 1, 3, 1, 3, 1]),
(False, 1, "padded_words", [2, 1, 1, 2, 1, 1, 2, 1, 1]),
(False, 1, "docs", [10, 2]),
(False, 2, "words", [2, 1, 2, 1, 2, 1, 1, 1, 1]),
(False, 2, "padded_words", [2, 1, 2, 1, 2, 1, 1, 1, 1]),
(False, 2, "docs", [6, 6]),
(True, 2, "padded_words", [3, 3, 2, 1, 1, 1, 1]),
(False, 1, {"batch_size": 10, "batch_by": "words"}, [3, 1, 3, 1, 3, 1]), # noqa: E501
(False, 1, {"batch_size": 10, "batch_by": "padded_words"}, [2, 1, 1, 2, 1, 1, 2, 1, 1]), # noqa: E501
(False, 1, {"batch_size": 10, "batch_by": "docs"}, [10, 2]), # noqa: E501
(False, 2, {"batch_size": 10, "batch_by": "words"}, [2, 1, 2, 1, 2, 1, 1, 1, 1]), # noqa: E501
(False, 2, {"batch_size": 10, "batch_by": "padded_words"}, [2, 1, 2, 1, 2, 1, 1, 1, 1]), # noqa: E501
(False, 2, {"batch_size": 10, "batch_by": "docs"}, [6, 6]), # noqa: E501
(True, 2, {"batch_size": 10, "batch_by": "padded_words"}, [3, 3, 2, 1, 1, 1, 1]), # noqa: E501
(False, 2, {"batch_size": "10 words"}, [2, 1, 2, 1, 2, 1, 1, 1, 1]), # noqa: E501
],
)
-def test_map_with_batching(sort, num_cpu_workers, batch_by, expected):
+# fmt: on
+def test_map_with_batching(sort, num_cpu_workers, batch_kwargs, expected):
nlp = edsnlp.blank("eds")
nlp.add_pipe(
"eds.matcher",
@@ -94,8 +97,7 @@ def test_map_with_batching(sort, num_cpu_workers, batch_by, expected):
stream = stream.map_batches(len)
stream = stream.set_processing(
num_cpu_workers=num_cpu_workers,
-batch_size=10,
-batch_by=batch_by,
+**batch_kwargs,
chunk_size=1000, # deprecated
split_into_batches_after="matcher",
show_progress=True,
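
The added parametrize case is the regression check for this fix: with two CPU workers, batch_size="10 words" is expected to produce the same batch sizes as batch_size=10 with batch_by="words" ([2, 1, 2, 1, 2, 1, 1, 1, 1]). Assuming a stream built like the test fixture, the two calls below should therefore be interchangeable (sketch, not test code):

    stream.set_processing(num_cpu_workers=2, batch_size=10, batch_by="words")
    stream.set_processing(num_cpu_workers=2, batch_size="10 words")
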