Commit

Nebius AI Studio provider added (#2866)
* Nebius provider added

* nebius occurance sorted alphabetically

* nebius text-to-image task fixed; tests for nebius provider added

* upload cassettes and update docs

* maintain alphabetical order

* fix merging

* height and width are not required

* Update docs/source/en/guides/inference.md

---------

Co-authored-by: Akim Tsvigun <[email protected]>
Co-authored-by: Celina Hanouti <[email protected]>
Co-authored-by: Lucain <[email protected]>
4 people authored Feb 17, 2025
1 parent 604b9ca commit ace6668
Showing 11 changed files with 533 additions and 32 deletions.
60 changes: 30 additions & 30 deletions docs/source/en/guides/inference.md
@@ -248,36 +248,36 @@ You might wonder why use [`InferenceClient`] instead of OpenAI's client? There

[`InferenceClient`]'s goal is to provide the easiest interface to run inference on Hugging Face models, on any provider. It has a simple API that supports the most common tasks. Here is a table showing which providers support which tasks:

| Domain | Task | HF Inference | fal-ai | Fireworks AI | Hyperbolic | Novita AI | Replicate | Sambanova | Together |
| ------------------- | --------------------------------------------------- | ------------ | ------ | ------------ | ---------- | ------ | --------- | --------- | -------- |
| **Audio** | [`~InferenceClient.audio_classification`] ||| ||||||
| | [`~InferenceClient.audio_to_audio`] ||| ||||||
| | [`~InferenceClient.automatic_speech_recognition`] ||| ||||||
| | [`~InferenceClient.text_to_speech`] ||| ||||||
| **Computer Vision** | [`~InferenceClient.image_classification`] ||| ||||||
| | [`~InferenceClient.image_segmentation`] ||| ||||||
| | [`~InferenceClient.image_to_image`] ||| ||||||
| | [`~InferenceClient.image_to_text`] ||| ||||||
| | [`~InferenceClient.object_detection`] ||| ||||||
| | [`~InferenceClient.text_to_image`] ||| ||||||
| | [`~InferenceClient.text_to_video`] ||| ||||||
| | [`~InferenceClient.zero_shot_image_classification`] ||| ||||||
| **Multimodal** | [`~InferenceClient.document_question_answering`] ||| ||||||
| | [`~InferenceClient.visual_question_answering`] ||| ||||||
| **NLP** | [`~InferenceClient.chat_completion`] ||| ||||||
| | [`~InferenceClient.feature_extraction`] ||| ||||||
| | [`~InferenceClient.fill_mask`] ||| ||||||
| | [`~InferenceClient.question_answering`] ||| ||||||
| | [`~InferenceClient.sentence_similarity`] ||| ||||||
| | [`~InferenceClient.summarization`] ||| ||||||
| | [`~InferenceClient.table_question_answering`] ||| ||||||
| | [`~InferenceClient.text_classification`] ||| ||||||
| | [`~InferenceClient.text_generation`] ||| ||||||
| | [`~InferenceClient.token_classification`] ||| ||||||
| | [`~InferenceClient.translation`] ||| ||||||
| | [`~InferenceClient.zero_shot_classification`] ||| ||||||
| **Tabular** | [`~InferenceClient.tabular_classification`] ||| ||||||
| | [`~InferenceClient.tabular_regression`] ||| ||||||
| Domain | Task | HF Inference | fal-ai | Fireworks AI | Hyperbolic | Nebius AI Studio | Novita AI | Replicate | Sambanova | Together |
| ------------------- | --------------------------------------------------- | ------------ | ------ | --------- | ---------- | ---------------- | ------ | --------- | --------- | ----------- |
| **Audio** | [`~InferenceClient.audio_classification`] ||||| |||| |
| | [`~InferenceClient.audio_to_audio`] ||||| |||| |
| | [`~InferenceClient.automatic_speech_recognition`] ||||| |||| |
| | [`~InferenceClient.text_to_speech`] ||||| |||| |
| **Computer Vision** | [`~InferenceClient.image_classification`] ||||| |||| |
| | [`~InferenceClient.image_segmentation`] ||||| |||| |
| | [`~InferenceClient.image_to_image`] ||||| |||| |
| | [`~InferenceClient.image_to_text`] ||||| |||| |
| | [`~InferenceClient.object_detection`] ||||| |||| |
| | [`~InferenceClient.text_to_image`] ||||| |||| |
| | [`~InferenceClient.text_to_video`] ||||| |||| |
| | [`~InferenceClient.zero_shot_image_classification`] ||||| |||| |
| **Multimodal** | [`~InferenceClient.document_question_answering`] ||||| |||| |
| | [`~InferenceClient.visual_question_answering`] ||||| |||| |
| **NLP** | [`~InferenceClient.chat_completion`] ||||| |||| |
| | [`~InferenceClient.feature_extraction`] ||||| |||| |
| | [`~InferenceClient.fill_mask`] ||||| |||| |
| | [`~InferenceClient.question_answering`] ||||| |||| |
| | [`~InferenceClient.sentence_similarity`] ||||| |||| |
| | [`~InferenceClient.summarization`] ||||| |||| |
| | [`~InferenceClient.table_question_answering`] ||||| |||| |
| | [`~InferenceClient.text_classification`] ||||| |||| |
| | [`~InferenceClient.text_generation`] ||||| |||| |
| | [`~InferenceClient.token_classification`] ||||| |||| |
| | [`~InferenceClient.translation`] ||||| |||| |
| | [`~InferenceClient.zero_shot_classification`] ||||| |||| |
| **Tabular** | [`~InferenceClient.tabular_classification`] ||||| |||| |
| | [`~InferenceClient.tabular_regression`] ||||| |||| |
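
The table above is, in effect, a capability map from provider to supported tasks. A minimal sketch of how such a map could be used to fail fast before dispatching a request — the data below is illustrative rather than a copy of the table, and `check_support` is a hypothetical helper, not part of `huggingface_hub`:

```python
# Illustrative capability map: provider name -> set of supported task names.
# These entries are examples only, not an authoritative copy of the docs table.
SUPPORTED_TASKS = {
    "hf-inference": {"text-generation", "text-to-image", "conversational"},
    "nebius": {"text-generation", "text-to-image", "conversational"},
    "fal-ai": {"text-to-image", "text-to-speech"},
}

def check_support(provider: str, task: str) -> None:
    """Raise early with a clear error if the provider/task pair is unsupported."""
    tasks = SUPPORTED_TASKS.get(provider)
    if tasks is None:
        raise ValueError(f"Unknown provider: {provider!r}")
    if task not in tasks:
        raise ValueError(f"Provider {provider!r} does not support task {task!r}")

check_support("nebius", "text-to-image")  # ok: no exception raised
```

Validating the pair up front gives a targeted error message instead of an opaque HTTP failure from the provider.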

<Tip>

2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_client.py
@@ -132,7 +132,7 @@ class InferenceClient:
path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
Defaults to `"hf-inference"` (Hugging Face Serverless Inference API).
If model is a URL or `base_url` is passed, then `provider` is not used.
token (`str` or `bool`, *optional*):
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_generated/_async_client.py
@@ -120,7 +120,7 @@ class AsyncInferenceClient:
path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
Defaults to `"hf-inference"` (Hugging Face Serverless Inference API).
If model is a URL or `base_url` is passed, then `provider` is not used.
token (`str` or `bool`, *optional*):
7 changes: 7 additions & 0 deletions src/huggingface_hub/inference/_providers/__init__.py
@@ -10,6 +10,7 @@
from .fireworks_ai import FireworksAIConversationalTask
from .hf_inference import HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask
from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
from .nebius import NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask
from .novita import NovitaConversationalTask, NovitaTextGenerationTask
from .replicate import ReplicateTask, ReplicateTextToSpeechTask
from .sambanova import SambanovaConversationalTask
@@ -21,6 +22,7 @@
    "fireworks-ai",
    "hf-inference",
    "hyperbolic",
    "nebius",
    "novita",
    "replicate",
    "sambanova",
@@ -70,6 +72,11 @@
        "conversational": HyperbolicTextGenerationTask("conversational"),
        "text-generation": HyperbolicTextGenerationTask("text-generation"),
    },
    "nebius": {
        "text-to-image": NebiusTextToImageTask(),
        "conversational": NebiusConversationalTask(),
        "text-generation": NebiusTextGenerationTask(),
    },
    "novita": {
        "text-generation": NovitaTextGenerationTask(),
        "conversational": NovitaConversationalTask(),
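
The nested mapping above follows a simple registry pattern: provider name → task name → helper instance. A standalone sketch of that lookup — the helper class and function here are stand-ins for illustration, not the library's real classes:

```python
class FakeTextToImageHelper:
    """Stand-in for a provider task helper (not a huggingface_hub class)."""

    def route(self) -> str:
        return "/v1/images/generations"

# provider -> task -> helper instance, mirroring the registry pattern above
PROVIDERS = {
    "nebius": {"text-to-image": FakeTextToImageHelper()},
}

def get_provider_helper(provider: str, task: str):
    """Look up the helper for a provider/task pair, with a clear error otherwise."""
    try:
        return PROVIDERS[provider][task]
    except KeyError:
        raise ValueError(
            f"No helper registered for provider={provider!r}, task={task!r}"
        ) from None
```

Registering a new provider then amounts to adding one entry to the mapping, which is essentially what this commit does for `"nebius"`.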
1 change: 1 addition & 0 deletions src/huggingface_hub/inference/_providers/_common.py
@@ -21,6 +21,7 @@
    "fireworks-ai": {},
    "hf-inference": {},
    "hyperbolic": {},
    "nebius": {},
    "replicate": {},
    "sambanova": {},
    "together": {},
41 changes: 41 additions & 0 deletions src/huggingface_hub/inference/_providers/nebius.py
@@ -0,0 +1,41 @@
import base64
from typing import Any, Dict, Optional, Union

from huggingface_hub.inference._common import _as_dict
from huggingface_hub.inference._providers._common import (
    BaseConversationalTask,
    BaseTextGenerationTask,
    TaskProviderHelper,
    filter_none,
)


class NebiusTextGenerationTask(BaseTextGenerationTask):
    def __init__(self):
        super().__init__(provider="nebius", base_url="https://api.studio.nebius.ai")


class NebiusConversationalTask(BaseConversationalTask):
    def __init__(self):
        super().__init__(provider="nebius", base_url="https://api.studio.nebius.ai")


class NebiusTextToImageTask(TaskProviderHelper):
    def __init__(self):
        super().__init__(task="text-to-image", provider="nebius", base_url="https://api.studio.nebius.ai")

    def _prepare_route(self, mapped_model: str) -> str:
        return "/v1/images/generations"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        parameters = filter_none(parameters)
        # Nebius does not support `guidance_scale`; drop it if present.
        parameters.pop("guidance_scale", None)
        # Default to base64 output unless the caller explicitly requested a URL.
        if parameters.get("response_format") not in ("b64_json", "url"):
            parameters["response_format"] = "b64_json"
        return {"prompt": inputs, **parameters, "model": mapped_model}

    def get_response(self, response: Union[bytes, Dict]) -> Any:
        # The generated image is returned base64-encoded under `data[0].b64_json`.
        response_dict = _as_dict(response)
        return base64.b64decode(response_dict["data"][0]["b64_json"])
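
The payload and response handling above can be exercised without touching the network. The following standalone sketch mirrors the same logic with local stand-ins for `filter_none` and the response dict — it is a simplified illustration, not the library code:

```python
import base64
from typing import Any, Dict

def filter_none(d: Dict[str, Any]) -> Dict[str, Any]:
    # Drop parameters that were left unset (None), as the helper above does.
    return {k: v for k, v in d.items() if v is not None}

def prepare_payload(inputs: str, parameters: Dict[str, Any], mapped_model: str) -> Dict[str, Any]:
    parameters = filter_none(parameters)
    # The Nebius image endpoint does not accept `guidance_scale`, so drop it.
    parameters.pop("guidance_scale", None)
    # Force base64 output unless the caller explicitly asked for a URL.
    if parameters.get("response_format") not in ("b64_json", "url"):
        parameters["response_format"] = "b64_json"
    return {"prompt": inputs, **parameters, "model": mapped_model}

def decode_image(response_dict: Dict[str, Any]) -> bytes:
    # The first generated image arrives base64-encoded under data[0].b64_json.
    return base64.b64decode(response_dict["data"][0]["b64_json"])

payload = prepare_payload("a cat on a mat", {"guidance_scale": 7.0, "width": None}, "some/model")
# -> {"prompt": "a cat on a mat", "response_format": "b64_json", "model": "some/model"}
```

Note how `width=None` is filtered out entirely, which is why the commit message says height and width are not required.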