Unified prepare_request method + class-based providers (#2777)
* (first draft) Unified provider_helper.prepare_request method

* style

* token in Openai tests

* together provider

* raise later

* Replicate provider

* sambanova

* fal ai

* fal ai forgotten

* fix test
Wauplin authored Jan 24, 2025
1 parent a46d2ad commit f004bf2
Showing 25 changed files with 1,646 additions and 1,448 deletions.
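At a glance: each provider task is now a class implementing a single prepare_request method that returns a RequestParameters bundle (URL, headers, payload), which the client sends and then decodes via get_response. Below is a minimal sketch of the intended call flow; the explicit requests.post step is illustrative only (the real client handles the send internally), and the API key is a placeholder.

import requests

from huggingface_hub.inference._providers import PROVIDERS

helper = PROVIDERS["fal-ai"]["text-to-image"]

# Build the outgoing request without sending it.
request = helper.prepare_request(
    inputs="An astronaut riding a horse",
    parameters={},
    headers={},
    model="black-forest-labs/FLUX.1-schnell",
    api_key="my-fal-key",  # placeholder credential
)

# Send it and decode the provider-specific response into raw image bytes.
response = requests.post(request.url, json=request.json, headers=request.headers)
image_bytes = helper.get_response(response.content)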
10 changes: 8 additions & 2 deletions src/huggingface_hub/_inference_endpoints.py
@@ -152,7 +152,10 @@ def client(self) -> InferenceClient:
                 "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
                 "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
             )
-        return InferenceClient(model=self.url, token=self._token)
+        return InferenceClient(
+            model=self.url,
+            token=self._token,  # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok.
+        )
 
     @property
     def async_client(self) -> AsyncInferenceClient:
@@ -169,7 +172,10 @@ def async_client(self) -> AsyncInferenceClient:
                 "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
                 "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
             )
-        return AsyncInferenceClient(model=self.url, token=self._token)
+        return AsyncInferenceClient(
+            model=self.url,
+            token=self._token,  # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok.
+        )
 
     def wait(self, timeout: Optional[int] = None, refresh_every: int = 5) -> "InferenceEndpoint":
         """Wait for the Inference Endpoint to be deployed.
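For context, the hunks above only split the return statements across lines so the token argument can carry a type-ignore comment; client construction is unchanged. A short usage sketch, with a placeholder endpoint name:

from huggingface_hub import get_inference_endpoint

endpoint = get_inference_endpoint("my-endpoint-name")  # placeholder name
endpoint.wait()  # block until the endpoint is deployed
client = endpoint.client  # InferenceClient bound to endpoint.url
print(client.text_generation("The huggingface_hub library is "))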
915 changes: 477 additions & 438 deletions src/huggingface_hub/inference/_client.py

Large diffs are not rendered by default.

29 changes: 29 additions & 0 deletions src/huggingface_hub/inference/_common.py
@@ -18,6 +18,7 @@
 import io
 import json
 import logging
+from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
@@ -74,6 +75,34 @@
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class RequestParameters:
+    url: str
+    task: str
+    model: Optional[str]
+    json: Optional[Union[str, Dict, List]]
+    data: Optional[ContentT]
+    headers: Dict[str, Any]
+
+
+class TaskProviderHelper(ABC):
+    """Protocol defining the interface for task-specific provider helpers."""
+
+    @abstractmethod
+    def prepare_request(
+        self,
+        *,
+        inputs: Any,
+        parameters: Dict[str, Any],
+        headers: Dict,
+        model: Optional[str],
+        api_key: Optional[str],
+        extra_payload: Optional[Dict[str, Any]] = None,
+    ) -> RequestParameters: ...
+
+    @abstractmethod
+    def get_response(self, response: Union[bytes, Dict]) -> Any: ...
+
+
 # Add dataclass for ModelStatus. We use this dataclass in get_model_status function.
 @dataclass
 class ModelStatus:
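The ABC above is the entire provider contract: build a request, decode a response. A hypothetical provider written against it could look like the following; ExampleConversationalTask and its endpoint URL are invented for illustration and are not part of this commit.

import json
from typing import Any, Dict, Optional, Union

from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper


class ExampleConversationalTask(TaskProviderHelper):
    """Invented provider helper illustrating the contract; not a real provider."""

    def prepare_request(
        self,
        *,
        inputs: Any,
        parameters: Dict[str, Any],
        headers: Dict,
        model: Optional[str],
        api_key: Optional[str],
        extra_payload: Optional[Dict[str, Any]] = None,
    ) -> RequestParameters:
        if api_key is None:
            raise ValueError("An api_key is required for this provider.")
        return RequestParameters(
            url="https://api.example.com/v1/chat/completions",  # invented endpoint
            task="conversational",
            model=model,
            json={"model": model, "messages": inputs, **parameters, **(extra_payload or {})},
            data=None,
            headers={**headers, "authorization": f"Bearer {api_key}"},
        )

    def get_response(self, response: Union[bytes, Dict]) -> Any:
        # Decode raw bytes into a dict; pass dicts through unchanged.
        return json.loads(response) if isinstance(response, bytes) else response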
927 changes: 481 additions & 446 deletions src/huggingface_hub/inference/_generated/_async_client.py

Large diffs are not rendered by default.

45 changes: 19 additions & 26 deletions src/huggingface_hub/inference/_providers/__init__.py
@@ -1,35 +1,17 @@
 # mypy: disable-error-code="dict-item"
-from typing import Any, Dict, Optional, Protocol, Union
+from typing import Dict
 
-from . import fal_ai, replicate, sambanova, together
+from .._common import TaskProviderHelper
+from .fal_ai import FalAIAutomaticSpeechRecognitionTask, FalAITextToImageTask
 from .hf_inference import HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask
-
-
-class TaskProviderHelper(Protocol):
-    """Protocol defining the interface for task-specific provider helpers."""
-
-    def build_url(self, model: Optional[str] = None) -> str: ...
-    def map_model(self, model: Optional[str] = None) -> str: ...
-    def prepare_headers(self, headers: Dict, *, token: Optional[str] = None) -> Dict: ...
-    def prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]: ...
-    def get_response(self, response: Union[bytes, Dict]) -> Any: ...
+from .replicate import ReplicateTextToImageTask
+from .sambanova import SambanovaConversationalTask
+from .together import TogetherTextGenerationTask, TogetherTextToImageTask
 
 
 PROVIDERS: Dict[str, Dict[str, TaskProviderHelper]] = {
-    "replicate": {
-        "text-to-image": replicate.text_to_image,
-    },
     "fal-ai": {
-        "text-to-image": fal_ai.text_to_image,
-        "automatic-speech-recognition": fal_ai.automatic_speech_recognition,
-    },
-    "sambanova": {
-        "conversational": sambanova.conversational,
-    },
-    "together": {
-        "text-to-image": together.text_to_image,
-        "conversational": together.conversational,
-        "text-generation": together.text_generation,
+        "text-to-image": FalAITextToImageTask(),
+        "automatic-speech-recognition": FalAIAutomaticSpeechRecognitionTask(),
     },
     "hf-inference": {
         "text-to-image": HFInferenceTask("text-to-image"),
@@ -59,6 +41,17 @@ def get_response(self, response: Union[bytes, Dict]) -> Any: ...
         "summarization": HFInferenceTask("summarization"),
         "visual-question-answering": HFInferenceBinaryInputTask("visual-question-answering"),
     },
+    "replicate": {
+        "text-to-image": ReplicateTextToImageTask(),
+    },
+    "sambanova": {
+        "conversational": SambanovaConversationalTask(),
+    },
+    "together": {
+        "text-to-image": TogetherTextToImageTask(),
+        "conversational": TogetherTextGenerationTask("conversational"),
+        "text-generation": TogetherTextGenerationTask("text-generation"),
+    },
 }


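With every registered value now an instance, provider dispatch reduces to two dictionary lookups. A minimal lookup sketch under that assumption (the file defines its own lookup helper below the fold of this diff):

from huggingface_hub.inference._common import TaskProviderHelper
from huggingface_hub.inference._providers import PROVIDERS


def lookup_helper(provider: str, task: str) -> TaskProviderHelper:
    # Two-level lookup: provider name first, then task name.
    if provider not in PROVIDERS:
        raise ValueError(f"Provider '{provider}' not supported. Available: {sorted(PROVIDERS)}")
    if task not in PROVIDERS[provider]:
        raise ValueError(f"Task '{task}' not supported for provider '{provider}'.")
    return PROVIDERS[provider][task]


helper = lookup_helper("together", "text-generation")  # a TogetherTextGenerationTask instance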
123 changes: 123 additions & 0 deletions src/huggingface_hub/inference/_providers/fal_ai.py
@@ -0,0 +1,123 @@
+import base64
+import json
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional, Union
+
+from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper
+from huggingface_hub.utils import build_hf_headers, get_session
+
+
+BASE_URL = "https://fal.run"
+
+SUPPORTED_MODELS = {
+    "automatic-speech-recognition": {
+        "openai/whisper-large-v3": "fal-ai/whisper",
+    },
+    "text-to-image": {
+        "black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
+        "black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
+    },
+}
+
+
+class FalAITask(TaskProviderHelper, ABC):
+    """Base class for FalAI API tasks."""
+
+    def __init__(self, task: str):
+        self.task = task
+
+    def prepare_request(
+        self,
+        *,
+        inputs: Any,
+        parameters: Dict[str, Any],
+        headers: Dict,
+        model: Optional[str],
+        api_key: Optional[str],
+        extra_payload: Optional[Dict[str, Any]] = None,
+    ) -> RequestParameters:
+        mapped_model = self._map_model(model)
+
+        if api_key is None:
+            raise ValueError("You must provide an api_key to work with fal.ai API.")
+        headers = {
+            **build_hf_headers(token=api_key),
+            **headers,
+            "authorization": f"Key {api_key}",
+        }
+
+        payload = self._prepare_payload(inputs, parameters=parameters)
+
+        return RequestParameters(
+            url=f"{BASE_URL}/{mapped_model}",
+            task=self.task,
+            model=mapped_model,
+            json=payload,
+            data=None,
+            headers=headers,
+        )
+
+    def _map_model(self, model: Optional[str]) -> str:
+        if model is None:
+            raise ValueError("Please provide a model available on FalAI.")
+        if self.task not in SUPPORTED_MODELS:
+            raise ValueError(f"Task {self.task} not supported with FalAI.")
+        mapped_model = SUPPORTED_MODELS[self.task].get(model)
+        if mapped_model is None:
+            raise ValueError(f"Model {model} is not supported with FalAI for task {self.task}.")
+        return mapped_model
+
+    @abstractmethod
+    def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]: ...
+
+
+class FalAIAutomaticSpeechRecognitionTask(FalAITask):
+    def __init__(self):
+        super().__init__("automatic-speech-recognition")
+
+    def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
+        if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
+            # If input is a URL, pass it directly
+            audio_url = inputs
+        else:
+            # If input is a file path, read it first
+            if isinstance(inputs, str):
+                with open(inputs, "rb") as f:
+                    inputs = f.read()
+
+            audio_b64 = base64.b64encode(inputs).decode()
+            content_type = "audio/mpeg"
+            audio_url = f"data:{content_type};base64,{audio_b64}"
+
+        return {
+            "audio_url": audio_url,
+            **{k: v for k, v in parameters.items() if v is not None},
+        }
+
+    def get_response(self, response: Union[bytes, Dict]) -> Any:
+        text = _as_dict(response)["text"]
+        if not isinstance(text, str):
+            raise ValueError(f"Unexpected output format from FalAI API. Expected string, got {type(text)}.")
+        return text
+
+
+class FalAITextToImageTask(FalAITask):
+    def __init__(self):
+        super().__init__("text-to-image")
+
+    def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
+        parameters = {k: v for k, v in parameters.items() if v is not None}
+        if "image_size" not in parameters and "width" in parameters and "height" in parameters:
+            parameters["image_size"] = {
+                "width": parameters.pop("width"),
+                "height": parameters.pop("height"),
+            }
+        return {"prompt": inputs, **parameters}
+
+    def get_response(self, response: Union[bytes, Dict]) -> Any:
+        url = _as_dict(response)["images"][0]["url"]
+        return get_session().get(url).content
+
+
+def _as_dict(response: Union[bytes, Dict]) -> Dict:
+    return json.loads(response) if isinstance(response, bytes) else response
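A quick sketch of the new file in action, using the speech-recognition helper with a URL input so the base64 branch is skipped; the key and audio URL are placeholders, and the assertions restate what prepare_request produces:

from huggingface_hub.inference._providers.fal_ai import FalAIAutomaticSpeechRecognitionTask

helper = FalAIAutomaticSpeechRecognitionTask()
request = helper.prepare_request(
    inputs="https://example.com/sample.mp3",  # URL inputs are passed through as audio_url
    parameters={},
    headers={},
    model="openai/whisper-large-v3",  # mapped to "fal-ai/whisper" via SUPPORTED_MODELS
    api_key="my-fal-key",  # placeholder credential
)
assert request.url == "https://fal.run/fal-ai/whisper"
assert request.json == {"audio_url": "https://example.com/sample.mp3"}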
2 changes: 0 additions & 2 deletions src/huggingface_hub/inference/_providers/fal_ai/__init__.py

This file was deleted.

This file was deleted.

52 changes: 0 additions & 52 deletions src/huggingface_hub/inference/_providers/fal_ai/text_to_image.py

This file was deleted.

