Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Task-specific pipeline init args #28439

Merged
merged 6 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/transformers/pipelines/audio_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import requests

from ..utils import add_end_docstrings, is_torch_available, is_torchaudio_available, logging
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args


if is_torch_available():
Expand Down Expand Up @@ -63,7 +63,7 @@ def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
return audio


@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_feature_extractor=True))
class AudioClassificationPipeline(Pipeline):
"""
Audio classification pipeline using any `AutoModelForAudioClassification`. This pipeline predicts the class of a
Expand Down
45 changes: 38 additions & 7 deletions src/transformers/pipelines/base.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to remove PIPELINE_INIT_ARGS to maintain a single version of the string? Or must it be supported for external use?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed them and added the following on L766 so that existing imports stay backwards compatible:

PIPELINE_INIT_ARGS = build_pipeline_init_args(
    has_tokenizer=True, has_feature_extractor=True, has_image_processor=True, supports_binary_output=True
)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great call!

Original file line number Diff line number Diff line change
Expand Up @@ -702,14 +702,33 @@ def predict(self, X):
raise NotImplementedError()


PIPELINE_INIT_ARGS = r"""
def build_pipeline_init_args(
has_tokenizer: bool = False,
has_feature_extractor: bool = False,
has_image_processor: bool = False,
supports_binary_output: bool = True,
) -> str:
docstring = r"""
Arguments:
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
[`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
[`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow."""
if has_tokenizer:
docstring += r"""
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
amyeroberts marked this conversation as resolved.
Show resolved Hide resolved
[`PreTrainedTokenizer`].
[`PreTrainedTokenizer`]."""
if has_feature_extractor:
docstring += r"""
feature_extractor ([`SequenceFeatureExtractor`]):
The feature extractor that will be used by the pipeline to encode data for the model. This object inherits from
[`SequenceFeatureExtractor`]."""
if has_image_processor:
docstring += r"""
image_processor ([`BaseImageProcessor`]):
The image processor that will be used by the pipeline to encode data for the model. This object inherits from
[`BaseImageProcessor`]."""
docstring += r"""
modelcard (`str` or [`ModelCard`], *optional*):
Model card attributed to the model for this pipeline.
framework (`str`, *optional*):
Expand All @@ -732,10 +751,22 @@ def predict(self, X):
Reference to the object in charge of parsing supplied pipeline parameters.
device (`int`, *optional*, defaults to -1):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
the associated CUDA device id. You can pass native `torch.device` or a `str` too.
the associated CUDA device id. You can pass native `torch.device` or a `str` too
torch_dtype (`str` or `torch.dtype`, *optional*):
Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
(`torch.float16`, `torch.bfloat16`, ... or `"auto"`)"""
if supports_binary_output:
docstring += r"""
binary_output (`bool`, *optional*, defaults to `False`):
Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text.
"""
Flag indicating if the output of the pipeline should happen in a serialized format (i.e., pickle) or as
the raw output data e.g. text."""
return docstring
amyeroberts marked this conversation as resolved.
Show resolved Hide resolved


PIPELINE_INIT_ARGS = build_pipeline_init_args(
has_tokenizer=True, has_feature_extractor=True, has_image_processor=True, supports_binary_output=True
)


if is_torch_available():
from transformers.pipelines.pt_utils import (
Expand All @@ -746,7 +777,7 @@ def predict(self, X):
)


@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True, has_feature_extractor=True, has_image_processor=True))
class Pipeline(_ScikitCompat):
"""
The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
Expand Down
7 changes: 3 additions & 4 deletions src/transformers/pipelines/conversational.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, List, Union

from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args


if is_tf_available():
Expand Down Expand Up @@ -192,13 +192,12 @@ def new_user_input(self):


@add_end_docstrings(
PIPELINE_INIT_ARGS,
build_pipeline_init_args(has_tokenizer=True),
r"""
min_length_for_response (`int`, *optional*, defaults to 32):
The minimum length (in number of tokens) for a response.
minimum_tokens (`int`, *optional*, defaults to 10):
The minimum length of tokens to leave for a response.
""",
The minimum length of tokens to leave for a response.""",
)
class ConversationalPipeline(Pipeline):
"""
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/pipelines/depth_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args


if is_vision_available():
Expand All @@ -19,7 +19,7 @@
logger = logging.get_logger(__name__)


@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class DepthEstimationPipeline(Pipeline):
"""
Depth estimation pipeline using any `AutoModelForDepthEstimation`. This pipeline predicts the depth of an image.
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/pipelines/document_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
is_vision_available,
logging,
)
from .base import PIPELINE_INIT_ARGS, ChunkPipeline
from .base import ChunkPipeline, build_pipeline_init_args
from .question_answering import select_starts_ends


Expand Down Expand Up @@ -98,7 +98,7 @@ class ModelType(ExplicitEnum):
VisionEncoderDecoder = "vision_encoder_decoder"


@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True, has_tokenizer=True))
class DocumentQuestionAnsweringPipeline(ChunkPipeline):
# TODO: Update task_summary docs to include an example with document QA and then update the first sentence
"""
Expand Down
40 changes: 10 additions & 30 deletions src/transformers/pipelines/feature_extraction.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from typing import Dict

from .base import GenericTensor, Pipeline
from ..utils import add_end_docstrings
from .base import GenericTensor, Pipeline, build_pipeline_init_args


# Can't use @add_end_docstrings(PIPELINE_INIT_ARGS) here because this one does not accept `binary_output`
@add_end_docstrings(
build_pipeline_init_args(has_tokenizer=True, supports_binary_output=False),
r"""
tokenize_kwargs (`dict`, *optional*):
Additional dictionary of keyword arguments passed along to the tokenizer.
return_tensors (`bool`, *optional*):
If `True`, returns a tensor according to the specified framework, otherwise returns a list.""",
)
class FeatureExtractionPipeline(Pipeline):
"""
Feature extraction pipeline using no model head. This pipeline extracts the hidden states from the base
Expand All @@ -27,34 +35,6 @@ class FeatureExtractionPipeline(Pipeline):

All models may be used for this pipeline. See a list of all models, including community-contributed models on
[huggingface.co/models](https://huggingface.co/models).

Arguments:
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
[`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
[`PreTrainedTokenizer`].
modelcard (`str` or [`ModelCard`], *optional*):
Model card attributed to the model for this pipeline.
framework (`str`, *optional*):
The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
installed.

If no framework is specified, will default to the one currently installed. If no framework is specified and
both frameworks are installed, will default to the framework of the `model`, or to PyTorch if no model is
provided.
return_tensors (`bool`, *optional*):
If `True`, returns a tensor according to the specified framework, otherwise returns a list.
task (`str`, defaults to `""`):
A task-identifier for the pipeline.
args_parser ([`~pipelines.ArgumentHandler`], *optional*):
Reference to the object in charge of parsing supplied pipeline parameters.
device (`int`, *optional*, defaults to -1):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
the associated CUDA device id.
tokenize_kwargs (`dict`, *optional*):
Additional dictionary of keyword arguments passed along to the tokenizer.
"""

def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, return_tensors=None, **kwargs):
Expand Down
8 changes: 4 additions & 4 deletions src/transformers/pipelines/fill_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
from .base import PIPELINE_INIT_ARGS, GenericTensor, Pipeline, PipelineException
from .base import GenericTensor, Pipeline, PipelineException, build_pipeline_init_args


if is_tf_available():
Expand All @@ -20,16 +20,16 @@


@add_end_docstrings(
PIPELINE_INIT_ARGS,
build_pipeline_init_args(has_tokenizer=True),
r"""
top_k (`int`, defaults to 5):
The number of predictions to return.
targets (`str` or `List[str]`, *optional*):
When passed, the model will limit the scores to the passed targets instead of looking up in the whole
vocab. If the provided targets are not in the model vocab, they will be tokenized and the first resulting
token will be used (with a warning, and that might be slower).

""",
tokenizer_kwargs (`dict`, *optional*):
Additional dictionary of keyword arguments passed along to the tokenizer.""",
)
class FillMaskPipeline(Pipeline):
"""
Expand Down
7 changes: 3 additions & 4 deletions src/transformers/pipelines/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
logging,
requires_backends,
)
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args


if is_vision_available():
Expand Down Expand Up @@ -48,7 +48,7 @@ class ClassificationFunction(ExplicitEnum):


@add_end_docstrings(
PIPELINE_INIT_ARGS,
build_pipeline_init_args(has_image_processor=True),
r"""
function_to_apply (`str`, *optional*, defaults to `"default"`):
The function to apply to the model outputs in order to retrieve the scores. Accepts four different values:
Expand All @@ -57,8 +57,7 @@ class ClassificationFunction(ExplicitEnum):
has several labels, will apply the softmax function on the output.
- `"sigmoid"`: Applies the sigmoid function on the output.
- `"softmax"`: Applies the softmax function on the output.
- `"none"`: Does not apply any function on the output.
""",
- `"none"`: Does not apply any function on the output.""",
)
class ImageClassificationPipeline(Pipeline):
"""
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/pipelines/image_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args


if is_vision_available():
Expand All @@ -27,7 +27,7 @@
Predictions = List[Prediction]


@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ImageSegmentationPipeline(Pipeline):
"""
Image segmentation pipeline using any `AutoModelForXXXSegmentation`. This pipeline predicts masks of objects and
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
logging,
requires_backends,
)
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args


if is_vision_available():
Expand All @@ -36,7 +36,7 @@
logger = logging.get_logger(__name__)


@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ImageToImagePipeline(Pipeline):
"""
Image to Image pipeline using any `AutoModelForImageToImage`. This pipeline generates an image based on a previous
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/pipelines/image_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
logging,
requires_backends,
)
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args


if is_vision_available():
Expand All @@ -27,7 +27,7 @@
logger = logging.get_logger(__name__)


@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True, has_image_processor=True))
class ImageToTextPipeline(Pipeline):
"""
Image To Text pipeline using a `AutoModelForVision2Seq`. This pipeline predicts a caption for a given image.
Expand Down
31 changes: 12 additions & 19 deletions src/transformers/pipelines/mask_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
logging,
requires_backends,
)
from .base import PIPELINE_INIT_ARGS, ChunkPipeline
from .base import ChunkPipeline, build_pipeline_init_args


if is_torch_available():
Expand All @@ -19,7 +19,17 @@
logger = logging.get_logger(__name__)


@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(
build_pipeline_init_args(has_image_processor=True),
r"""
points_per_batch (`int`, *optional*, defaults to 64):
Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU
memory.
output_bboxes_mask (`bool`, *optional*, defaults to `False`):
Whether or not to output the bounding box predictions.
output_rle_masks (`bool`, *optional*, defaults to `False`):
Whether or not to output the masks in `RLE` format.""",
)
class MaskGenerationPipeline(ChunkPipeline):
"""
Automatic mask generation for images using `SamForMaskGeneration`. This pipeline predicts binary masks for an
Expand Down Expand Up @@ -48,23 +58,6 @@ class MaskGenerationPipeline(ChunkPipeline):
applies a variety of filters based on non maximum suppression to remove bad masks.
- image_processor.postprocess_masks_for_amg applies the NMS on the mask to only keep relevant ones.

Arguments:
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
[`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
[`PreTrainedTokenizer`].
feature_extractor ([`SequenceFeatureExtractor`]):
The feature extractor that will be used by the pipeline to encode the input.
points_per_batch (*optional*, int, default to 64):
Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU
memory.
output_bboxes_mask (`bool`, *optional*, default to `False`):
Whether or not to output the bounding box predictions.
output_rle_masks (`bool`, *optional*, default to `False`):
Whether or not to output the masks in `RLE` format

Example:

```python
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/pipelines/object_detection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict, List, Union

from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
from .base import PIPELINE_INIT_ARGS, Pipeline
from .base import Pipeline, build_pipeline_init_args


if is_vision_available():
Expand All @@ -23,7 +23,7 @@
Predictions = List[Prediction]


@add_end_docstrings(PIPELINE_INIT_ARGS)
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ObjectDetectionPipeline(Pipeline):
"""
Object detection pipeline using any `AutoModelForObjectDetection`. This pipeline predicts bounding boxes of objects
Expand Down
Loading
Loading