Add to feature-extraction + update types
Wauplin committed Jul 1, 2024
1 parent 1ffbb07 commit c3511d6
Showing 3 changed files with 35 additions and 2 deletions.
9 changes: 9 additions & 0 deletions src/huggingface_hub/inference/_client.py
@@ -940,6 +940,7 @@ def feature_extraction(
text: str,
*,
normalize: Optional[bool] = None,
prompt_name: Optional[str] = None,
truncate: Optional[bool] = None,
model: Optional[str] = None,
) -> "np.ndarray":
@@ -956,6 +957,12 @@ def feature_extraction(
normalize (`bool`, *optional*):
Whether to normalize the embeddings or not. Defaults to None.
Only available on server powered by Text-Embedding-Inference.
prompt_name (`str`, *optional*):
The name of the prompt that should be used for encoding. If not set, no prompt will be applied.
Must be a key in the `Sentence Transformers` configuration `prompts` dictionary.
For example, if ``prompt_name`` is "query" and ``prompts`` is {"query": "query: ", ...},
then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?",
because the prompt text will be prepended to any text to encode.
truncate (`bool`, *optional*):
Whether to truncate the embeddings or not. Defaults to None.
Only available on server powered by Text-Embedding-Inference.
@@ -983,6 +990,8 @@ def feature_extraction(
payload: Dict = {"inputs": text}
if normalize is not None:
payload["normalize"] = normalize
if prompt_name is not None:
payload["prompt_name"] = prompt_name
if truncate is not None:
payload["truncate"] = truncate
response = self.post(json=payload, model=model, task="feature-extraction")
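Not part of the diff: a minimal usage sketch of the new parameter on the synchronous client. The model id is an assumption; `prompt_name` only has an effect when the endpoint runs Text-Embedding-Inference with a matching "query" entry in its Sentence Transformers `prompts` dictionary.

from huggingface_hub import InferenceClient

# Hypothetical deployment: any Text-Embedding-Inference endpoint whose
# Sentence Transformers config defines a "query" prompt, e.g. {"query": "query: ", ...}.
client = InferenceClient(model="intfloat/multilingual-e5-large")

embedding = client.feature_extraction(
    "What is the capital of France?",
    prompt_name="query",  # server prepends "query: " before encoding
    normalize=True,
)
print(embedding.shape)  # numpy array; dimensions depend on the model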
9 changes: 9 additions & 0 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -944,6 +944,7 @@ async def feature_extraction(
text: str,
*,
normalize: Optional[bool] = None,
prompt_name: Optional[str] = None,
truncate: Optional[bool] = None,
model: Optional[str] = None,
) -> "np.ndarray":
@@ -960,6 +961,12 @@ async def feature_extraction(
normalize (`bool`, *optional*):
Whether to normalize the embeddings or not. Defaults to None.
Only available on server powered by Text-Embedding-Inference.
prompt_name (`str`, *optional*):
The name of the prompt that should be used for encoding. If not set, no prompt will be applied.
Must be a key in the `Sentence Transformers` configuration `prompts` dictionary.
For example, if ``prompt_name`` is "query" and ``prompts`` is {"query": "query: ", ...},
then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?",
because the prompt text will be prepended to any text to encode.
truncate (`bool`, *optional*):
Whether to truncate the embeddings or not. Defaults to None.
Only available on server powered by Text-Embedding-Inference.
@@ -988,6 +995,8 @@ async def feature_extraction(
payload: Dict = {"inputs": text}
if normalize is not None:
payload["normalize"] = normalize
if prompt_name is not None:
payload["prompt_name"] = prompt_name
if truncate is not None:
payload["truncate"] = truncate
response = await self.post(json=payload, model=model, task="feature-extraction")
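The async client mirrors the same signature. A brief sketch, under the same assumptions about the deployed model and its "query" prompt:

import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    # prompt_name is forwarded in the JSON payload and only honored by servers
    # that expose a matching entry in their prompts configuration.
    client = AsyncInferenceClient(model="intfloat/multilingual-e5-large")
    embedding = await client.feature_extraction(
        "What is the capital of France?",
        prompt_name="query",
    )
    print(embedding.shape)


asyncio.run(main())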
17 changes: 17 additions & 2 deletions src/huggingface_hub/inference/_generated/types/feature_extraction.py
@@ -4,11 +4,14 @@
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import Literal, Optional

from .base import BaseInferenceType


FeatureExtractionInputTruncationDirection = Literal["Left", "Right"]


@dataclass
class FeatureExtractionInput(BaseInferenceType):
"""Feature Extraction Input.
@@ -17,6 +20,18 @@ class FeatureExtractionInput(BaseInferenceType):
https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tei-import.ts.
"""

inputs: Union[List[str], str]
inputs: str
"""The text to embed."""
normalize: Optional[bool] = None
prompt_name: Optional[str] = None
"""The name of the prompt that should be used by for encoding. If not set, no prompt
will be applied.
Must be a key in the `Sentence Transformers` configuration `prompts` dictionary.
For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ",
...},
then the sentence "What is the capital of France?" will be encoded as
"query: What is the capital of France?" because the prompt text will be prepended before
any text to encode.
"""
truncate: Optional[bool] = None
truncation_direction: Optional["FeatureExtractionInputTruncationDirection"] = None

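For reference, a sketch of how the regenerated input type could be instantiated directly. The import path is an assumption about where the generated types are exposed; in practice the clients build this payload for you.

from huggingface_hub.inference._generated.types import FeatureExtractionInput

payload = FeatureExtractionInput(
    inputs="What is the capital of France?",
    normalize=True,
    prompt_name="query",           # must match a key in the model's `prompts` dictionary
    truncate=True,
    truncation_direction="Right",  # Literal["Left", "Right"]
)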