
Commit

(misc) Deprecate some hf-inference specific features (wait-for-model header, can't override model's task, get_model_status, list_deployed_models) (#2851)

* (draft) deprecate some hf-inference specific features

* remove hf-inference specific behavior (wait for model + handle 503)

* remove make sure sentence

* add back sentence-similarity task but use /models instead of /pipeline/tag

* async as well

* fix cassettes
Wauplin authored Feb 13, 2025
1 parent bf80d1c commit b19ab11
Showing 9 changed files with 44 additions and 80 deletions.
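
For context, a minimal sketch of the migration path suggested by the new deprecation messages in this diff. The `inference_provider` filter on `HfApi.list_models` and the use of `HfApi.model_info` come from those messages; the model ID and printed attributes are illustrative assumptions, and availability may depend on your `huggingface_hub` version.

```python
from huggingface_hub import HfApi

api = HfApi()

# Instead of InferenceClient.list_deployed_models(): list warm models per provider.
for model in api.list_models(inference_provider="hf-inference", limit=5):
    print(model.id)

# Instead of InferenceClient.get_model_status(): query model metadata directly.
info = api.model_info("sentence-transformers/all-MiniLM-L6-v2")
print(info.pipeline_tag)
```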
2 changes: 1 addition & 1 deletion docs/source/de/guides/inference.md
@@ -107,7 +107,7 @@ Das Ziel von [`InferenceClient`] ist es, die einfachste Schnittstelle zum Ausfü
| | [Feature Extraction](https://huggingface.co/tasks/feature-extraction) || [`~InferenceClient.feature_extraction`] |
| | [Fill Mask](https://huggingface.co/tasks/fill-mask) || [`~InferenceClient.fill_mask`] |
| | [Question Answering](https://huggingface.co/tasks/question-answering) || [`~InferenceClient.question_answering`] |
| | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) || [`~InferenceClient.sentence_similarity`] |
| | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) || [`~InferenceClient.sentence_similarity`] |
| | [Summarization](https://huggingface.co/tasks/summarization) || [`~InferenceClient.summarization`] |
| | [Table Question Answering](https://huggingface.co/tasks/table-question-answering) || [`~InferenceClient.table_question_answering`] |
| | [Text Classification](https://huggingface.co/tasks/text-classification) || [`~InferenceClient.text_classification`] |
38 changes: 14 additions & 24 deletions src/huggingface_hub/inference/_client.py
@@ -35,7 +35,6 @@
import base64
import logging
import re
import time
import warnings
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union, overload

@@ -301,8 +300,6 @@ def _inner_post(
if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
request_parameters.headers["Accept"] = "image/png"

t0 = time.time()
timeout = self.timeout
while True:
with _open_as_binary(request_parameters.data) as data_as_binary:
try:
@@ -326,30 +323,9 @@
except HTTPError as error:
if error.response.status_code == 422 and request_parameters.task != "unknown":
msg = str(error.args[0])
print(error.response.text)
if len(error.response.text) > 0:
msg += f"\n{error.response.text}\n"
msg += f"\nMake sure '{request_parameters.task}' task is supported by the model."
error.args = (msg,) + error.args[1:]
if error.response.status_code == 503:
# If Model is unavailable, either raise a TimeoutError...
if timeout is not None and time.time() - t0 > timeout:
raise InferenceTimeoutError(
f"Model not loaded on the server: {request_parameters.url}. Please retry with a higher timeout (current:"
f" {self.timeout}).",
request=error.request,
response=error.response,
) from error
# ...or wait 1s and retry
logger.info(f"Waiting for model to be loaded on the server: {error}")
time.sleep(1)
if "X-wait-for-model" not in request_parameters.headers and request_parameters.url.startswith(
INFERENCE_ENDPOINT
):
request_parameters.headers["X-wait-for-model"] = "1"
if timeout is not None:
timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore
continue
raise

def audio_classification(
@@ -3261,6 +3237,13 @@ def zero_shot_image_classification(
response = self._inner_post(request_parameters)
return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

@_deprecate_method(
version="0.33.0",
message=(
"HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
" Use `HfApi.list_models(..., inference_provider='...')` to list warm models per provider."
),
)
def list_deployed_models(
self, frameworks: Union[None, str, Literal["all"], List[str]] = None
) -> Dict[str, List[str]]:
@@ -3444,6 +3427,13 @@ def health_check(self, model: Optional[str] = None) -> bool:
response = get_session().get(url, headers=build_hf_headers(token=self.token))
return response.status_code == 200

@_deprecate_method(
version="0.33.0",
message=(
"HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
" Use `HfApi.model_info` to get the model status both with HF Inference API and external providers."
),
)
def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
"""
Get the status of a model hosted on the HF Inference API.
38 changes: 14 additions & 24 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -22,7 +22,6 @@
import base64
import logging
import re
import time
import warnings
from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Set, Union, overload

@@ -299,8 +298,6 @@ async def _inner_post(
if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
request_parameters.headers["Accept"] = "image/png"

t0 = time.time()
timeout = self.timeout
while True:
with _open_as_binary(request_parameters.data) as data_as_binary:
# Do not use context manager as we don't want to close the connection immediately when returning
@@ -331,27 +328,6 @@
except aiohttp.ClientResponseError as error:
error.response_error_payload = response_error_payload
await session.close()
if response.status == 422 and request_parameters.task != "unknown":
error.message += f". Make sure '{request_parameters.task}' task is supported by the model."
if response.status == 503:
# If Model is unavailable, either raise a TimeoutError...
if timeout is not None and time.time() - t0 > timeout:
raise InferenceTimeoutError(
f"Model not loaded on the server: {request_parameters.url}. Please retry with a higher timeout"
f" (current: {self.timeout}).",
request=error.request,
response=error.response,
) from error
# ...or wait 1s and retry
logger.info(f"Waiting for model to be loaded on the server: {error}")
if "X-wait-for-model" not in request_parameters.headers and request_parameters.url.startswith(
INFERENCE_ENDPOINT
):
request_parameters.headers["X-wait-for-model"] = "1"
await asyncio.sleep(1)
if timeout is not None:
timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore
continue
raise error
except Exception:
await session.close()
@@ -3325,6 +3301,13 @@ async def zero_shot_image_classification(
response = await self._inner_post(request_parameters)
return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

@_deprecate_method(
version="0.33.0",
message=(
"HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
" Use `HfApi.list_models(..., inference_provider='...')` to list warm models per provider."
),
)
async def list_deployed_models(
self, frameworks: Union[None, str, Literal["all"], List[str]] = None
) -> Dict[str, List[str]]:
@@ -3554,6 +3537,13 @@ async def health_check(self, model: Optional[str] = None) -> bool:
response = await client.get(url, proxy=self.proxies)
return response.status == 200

@_deprecate_method(
version="0.33.0",
message=(
"HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
" Use `HfApi.model_info` to get the model status both with HF Inference API and external providers."
),
)
async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
"""
Get the status of a model hosted on the HF Inference API.
9 changes: 1 addition & 8 deletions src/huggingface_hub/inference/_providers/hf_inference.py
@@ -38,14 +38,7 @@ def _prepare_url(self, api_key: str, mapped_model: str) -> str:
# hf-inference provider can handle URLs (e.g. Inference Endpoints or TGI deployment)
if mapped_model.startswith(("http://", "https://")):
return mapped_model

return (
# Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
f"{self.base_url}/pipeline/{self.task}/{mapped_model}"
if self.task in ("feature-extraction", "sentence-similarity")
# Otherwise, we use the default endpoint
else f"{self.base_url}/models/{mapped_model}"
)
return f"{self.base_url}/models/{mapped_model}"

def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
if isinstance(inputs, bytes):
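
The provider change above routes every hf-inference task, including sentence-similarity, through the generic `/models/` endpoint instead of `/pipeline/{task}/`. A minimal usage sketch, assuming the `provider="hf-inference"` constructor argument; the model ID matches the updated cassettes below and the input sentences are illustrative:

```python
from huggingface_hub import InferenceClient

client = InferenceClient(provider="hf-inference")

# After this change, the request is sent to
# https://router.huggingface.co/hf-inference/models/sentence-transformers/all-MiniLM-L6-v2
# rather than the old /pipeline/sentence-similarity/... route.
scores = client.sentence_similarity(
    "That is a happy person",
    other_sentences=["That is a happy dog", "Today is a sunny day"],
    model="sentence-transformers/all-MiniLM-L6-v2",
)
print(scores)  # a list of floats, one score per comparison sentence
```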
@@ -17,7 +17,7 @@ interactions:
X-Amzn-Trace-Id:
- 0434ff33-56fe-49db-9380-17b81e41f756
method: POST
uri: https://router.huggingface.co/hf-inference/pipeline/sentence-similarity/sentence-transformers/all-MiniLM-L6-v2
uri: https://router.huggingface.co/hf-inference/models/sentence-transformers/all-MiniLM-L6-v2
response:
body:
string: '[0.7785724997520447,0.4587624967098236,0.29062220454216003]'
2 changes: 1 addition & 1 deletion tests/cassettes/test_async_sentence_similarity.yaml
@@ -3,7 +3,7 @@ interactions:
body: null
headers: {}
method: POST
uri: https://router.huggingface.co/hf-inference/pipeline/sentence-similarity/sentence-transformers/all-MiniLM-L6-v2
uri: https://router.huggingface.co/hf-inference/models/sentence-transformers/all-MiniLM-L6-v2
response:
body:
string: '[0.7785724997520447,0.4587624967098236,0.29062220454216003]'
5 changes: 5 additions & 0 deletions tests/test_inference_async_client.py
@@ -300,6 +300,7 @@ def test_sync_vs_async_signatures() -> None:


@pytest.mark.asyncio
@pytest.mark.skip("Deprecated (get_model_status)")
async def test_get_status_too_big_model() -> None:
model_status = await AsyncInferenceClient(token=False).get_model_status("facebook/nllb-moe-54b")
assert model_status.loaded is False
@@ -309,6 +310,7 @@ async def test_get_status_too_big_model() -> None:


@pytest.mark.asyncio
@pytest.mark.skip("Deprecated (get_model_status)")
async def test_get_status_loaded_model() -> None:
model_status = await AsyncInferenceClient(token=False).get_model_status("bigscience/bloom")
assert model_status.loaded is True
@@ -318,18 +320,21 @@ async def test_get_status_loaded_model() -> None:


@pytest.mark.asyncio
@pytest.mark.skip("Deprecated (get_model_status)")
async def test_get_status_unknown_model() -> None:
with pytest.raises(ClientResponseError):
await AsyncInferenceClient(token=False).get_model_status("unknown/model")


@pytest.mark.asyncio
@pytest.mark.skip("Deprecated (get_model_status)")
async def test_get_status_model_as_url() -> None:
with pytest.raises(NotImplementedError):
await AsyncInferenceClient(token=False).get_model_status("https://unkown/model")


@pytest.mark.asyncio
@pytest.mark.skip("Deprecated (list_deployed_models)")
async def test_list_deployed_models_single_frameworks() -> None:
models_by_task = await AsyncInferenceClient().list_deployed_models("text-generation-inference")
assert isinstance(models_by_task, dict)
7 changes: 7 additions & 0 deletions tests/test_inference_client.py
@@ -869,6 +869,7 @@ def test_accept_header_image(self, get_session_mock: MagicMock, bytes_to_image_m


class TestModelStatus(TestBase):
@expect_deprecation("get_model_status")
def test_too_big_model(self) -> None:
client = InferenceClient(token=False)
model_status = client.get_model_status("facebook/nllb-moe-54b")
@@ -877,6 +878,7 @@ def test_too_big_model(self) -> None:
assert model_status.compute_type == "cpu"
assert model_status.framework == "transformers"

@expect_deprecation("get_model_status")
def test_loaded_model(self) -> None:
client = InferenceClient(token=False)
model_status = client.get_model_status("bigscience/bloom")
@@ -885,28 +887,33 @@ def test_loaded_model(self) -> None:
assert isinstance(model_status.compute_type, dict) # e.g. {'gpu': {'gpu': 'a100', 'count': 8}}
assert model_status.framework == "text-generation-inference"

@expect_deprecation("get_model_status")
def test_unknown_model(self) -> None:
client = InferenceClient()
with pytest.raises(HfHubHTTPError):
client.get_model_status("unknown/model")

@expect_deprecation("get_model_status")
def test_model_as_url(self) -> None:
client = InferenceClient()
with pytest.raises(NotImplementedError):
client.get_model_status("https://unkown/model")


class TestListDeployedModels(TestBase):
@expect_deprecation("list_deployed_models")
@patch("huggingface_hub.inference._client.get_session")
def test_list_deployed_models_main_frameworks_mock(self, get_session_mock: MagicMock) -> None:
InferenceClient().list_deployed_models()
assert len(get_session_mock.return_value.get.call_args_list) == len(MAIN_INFERENCE_API_FRAMEWORKS)

@expect_deprecation("list_deployed_models")
@patch("huggingface_hub.inference._client.get_session")
def test_list_deployed_models_all_frameworks_mock(self, get_session_mock: MagicMock) -> None:
InferenceClient().list_deployed_models("all")
assert len(get_session_mock.return_value.get.call_args_list) == len(ALL_INFERENCE_API_FRAMEWORKS)

@expect_deprecation("list_deployed_models")
def test_list_deployed_models_single_frameworks(self) -> None:
models_by_task = InferenceClient().list_deployed_models("text-generation-inference")
assert isinstance(models_by_task, dict)
21 changes: 0 additions & 21 deletions utils/generate_async_inference_client.py
@@ -175,8 +175,6 @@ def _rename_to_AsyncInferenceClient(code: str) -> str:
if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
request_parameters.headers["Accept"] = "image/png"
t0 = time.time()
timeout = self.timeout
while True:
with _open_as_binary(request_parameters.data) as data_as_binary:
# Do not use context manager as we don't want to close the connection immediately when returning
@@ -205,25 +203,6 @@ def _rename_to_AsyncInferenceClient(code: str) -> str:
except aiohttp.ClientResponseError as error:
error.response_error_payload = response_error_payload
await session.close()
if response.status == 422 and request_parameters.task != "unknown":
error.message += f". Make sure '{request_parameters.task}' task is supported by the model."
if response.status == 503:
# If Model is unavailable, either raise a TimeoutError...
if timeout is not None and time.time() - t0 > timeout:
raise InferenceTimeoutError(
f"Model not loaded on the server: {request_parameters.url}. Please retry with a higher timeout"
f" (current: {self.timeout}).",
request=error.request,
response=error.response,
) from error
# ...or wait 1s and retry
logger.info(f"Waiting for model to be loaded on the server: {error}")
if "X-wait-for-model" not in request_parameters.headers and request_parameters.url.startswith(INFERENCE_ENDPOINT):
request_parameters.headers["X-wait-for-model"] = "1"
await asyncio.sleep(1)
if timeout is not None:
timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore
continue
raise error
except Exception:
await session.close()
