Skip to content

Commit

Permalink
Keep-alive connection between requests (#1394)
Browse files Browse the repository at this point in the history
* Configurable requests Sessions + tests

* Use session everywhere

* fix python3.7

* fix telemetry tests

* FIX http_backoff tests

* FIX inference API mocked tests

* FIX mocked http offline test

* fix cached_download mocked test

* fix mocked space api test

* FIX mocked pagination tests

* FIX tests

* add documentation
  • Loading branch information
Wauplin authored Mar 23, 2023
1 parent f1247d8 commit 46d51b2
Show file tree
Hide file tree
Showing 19 changed files with 425 additions and 137 deletions.
18 changes: 18 additions & 0 deletions docs/source/package_reference/utilities.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,24 @@ Using these shouldn't be necessary if you use `huggingface_hub` and you don't mo

[[autodoc]] logging.get_logger

## Configure HTTP backend

In some environments, you might want to configure how HTTP calls are made, for example if you are using a proxy.
`huggingface_hub` lets you configure this globally using [`configure_http_backend`]. All requests made to the Hub will
then use your settings. Under the hood, `huggingface_hub` uses `requests.Session` so you might want to refer to the
[`requests` documentation](https://requests.readthedocs.io/en/latest/user/advanced) to learn more about the parameters
available.

Since `requests.Session` is not guaranteed to be thread-safe, `huggingface_hub` creates one session instance per thread.
Using sessions allows us to keep the connection open between HTTP calls and ultimately save time. If you are
integrating `huggingface_hub` in a third-party library and want to make a custom call to the Hub, use [`get_session`]
to get a Session configured by your users (i.e. replace any `requests.get(...)` call by `get_session().get(...)`).

[[autodoc]] configure_http_backend

[[autodoc]] get_session


## Handle HTTP errors

`huggingface_hub` defines its own HTTP errors to refine the `HTTPError` raised by
Expand Down
4 changes: 4 additions & 0 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@
"HFCacheInfo",
"HfFolder",
"cached_assets_path",
"configure_http_backend",
"dump_environment_info",
"get_session",
"logging",
"scan_cache_dir",
],
Expand Down Expand Up @@ -458,7 +460,9 @@ def __dir__():
HFCacheInfo, # noqa: F401
HfFolder, # noqa: F401
cached_assets_path, # noqa: F401
configure_http_backend, # noqa: F401
dump_environment_info, # noqa: F401
get_session, # noqa: F401
logging, # noqa: F401
scan_cache_dir, # noqa: F401
)
Expand Down
5 changes: 3 additions & 2 deletions src/huggingface_hub/_commit_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from pathlib import Path, PurePosixPath
from typing import Any, BinaryIO, Dict, Iterable, Iterator, List, Optional, Union

import requests
from tqdm.contrib.concurrent import thread_map

from huggingface_hub import get_session

from .constants import ENDPOINT
from .lfs import UploadInfo, _validate_batch_actions, lfs_upload, post_lfs_batch_info
from .utils import (
Expand Down Expand Up @@ -459,7 +460,7 @@ def fetch_upload_modes(
]
}

resp = requests.post(
resp = get_session().post(
f"{endpoint}/api/{repo_type}s/{repo_id}/preupload/{revision}",
json=payload,
headers=headers,
Expand Down
8 changes: 3 additions & 5 deletions src/huggingface_hub/commands/lfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@
from argparse import _SubParsersAction
from typing import Dict, List, Optional

import requests

from huggingface_hub.commands import BaseHuggingfaceCLICommand
from huggingface_hub.lfs import LFS_MULTIPART_UPLOAD_COMMAND, SliceFileObj

from ..utils import hf_raise_for_status, logging
from ..utils import get_session, hf_raise_for_status, logging


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -172,7 +170,7 @@ def run(self):
seek_from=i * chunk_size,
read_limit=chunk_size,
) as data:
r = requests.put(presigned_url, data=data)
r = get_session().put(presigned_url, data=data)
hf_raise_for_status(r)
parts.append(
{
Expand All @@ -192,7 +190,7 @@ def run(self):
)
# Not precise but that's ok.

r = requests.post(
r = get_session().post(
completion_url,
json={
"oid": oid,
Expand Down
73 changes: 38 additions & 35 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import requests
from requests.exceptions import HTTPError

from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, get_session

from ._commit_api import (
CommitOperation,
Expand Down Expand Up @@ -818,7 +818,7 @@ def whoami(self, token: Optional[str] = None) -> Dict:
Hugging Face token. Will default to the locally saved token if
not provided.
"""
r = requests.get(
r = get_session().get(
f"{self.endpoint}/api/whoami-v2",
headers=self._build_hf_headers(
# If `token` is provided and not `None`, it will be used by default.
Expand Down Expand Up @@ -856,7 +856,7 @@ def _is_valid_token(self, token: str) -> bool:
def get_model_tags(self) -> ModelTags:
"Gets all valid model tags as a nested namespace object"
path = f"{self.endpoint}/api/models-tags-by-type"
r = requests.get(path)
r = get_session().get(path)
hf_raise_for_status(r)
d = r.json()
return ModelTags(d)
Expand All @@ -866,7 +866,7 @@ def get_dataset_tags(self) -> DatasetTags:
Gets all valid dataset tags as a nested namespace object.
"""
path = f"{self.endpoint}/api/datasets-tags-by-type"
r = requests.get(path)
r = get_session().get(path)
hf_raise_for_status(r)
d = r.json()
return DatasetTags(d)
Expand Down Expand Up @@ -1256,7 +1256,7 @@ def list_metrics(self) -> List[MetricInfo]:
`List[MetricInfo]`: a list of [`MetricInfo`] objects which.
"""
path = f"{self.endpoint}/api/metrics"
r = requests.get(path)
r = get_session().get(path)
hf_raise_for_status(r)
d = r.json()
return [MetricInfo(**x) for x in d]
Expand Down Expand Up @@ -1390,7 +1390,7 @@ def like(
"""
if repo_type is None:
repo_type = REPO_TYPE_MODEL
response = requests.post(
response = get_session().post(
url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/like",
headers=self._build_hf_headers(token=token),
)
Expand Down Expand Up @@ -1438,10 +1438,8 @@ def unlike(
"""
if repo_type is None:
repo_type = REPO_TYPE_MODEL
# TODO: use requests.delete(".../like") instead when https://github.com/huggingface/moon-landing/pull/4813 is merged
response = requests.delete(
url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/like",
headers=self._build_hf_headers(token=token),
response = get_session().delete(
url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/like", headers=self._build_hf_headers(token=token)
)
hf_raise_for_status(response)

Expand Down Expand Up @@ -1583,7 +1581,7 @@ def model_info(
params["securityStatus"] = True
if files_metadata:
params["blobs"] = True
r = requests.get(path, headers=headers, timeout=timeout, params=params)
r = get_session().get(path, headers=headers, timeout=timeout, params=params)
hf_raise_for_status(r)
d = r.json()
return ModelInfo(**d)
Expand Down Expand Up @@ -1646,7 +1644,7 @@ def dataset_info(
if files_metadata:
params["blobs"] = True

r = requests.get(path, headers=headers, timeout=timeout, params=params)
r = get_session().get(path, headers=headers, timeout=timeout, params=params)
hf_raise_for_status(r)
d = r.json()
return DatasetInfo(**d)
Expand Down Expand Up @@ -1709,7 +1707,7 @@ def space_info(
if files_metadata:
params["blobs"] = True

r = requests.get(path, headers=headers, timeout=timeout, params=params)
r = get_session().get(path, headers=headers, timeout=timeout, params=params)
hf_raise_for_status(r)
d = r.json()
return SpaceInfo(**d)
Expand Down Expand Up @@ -1876,9 +1874,8 @@ def list_repo_refs(
repo on the Hub.
"""
repo_type = repo_type or REPO_TYPE_MODEL
response = requests.get(
f"{self.endpoint}/api/{repo_type}s/{repo_id}/refs",
headers=self._build_hf_headers(token=token),
response = get_session().get(
f"{self.endpoint}/api/{repo_type}s/{repo_id}/refs", headers=self._build_hf_headers(token=token)
)
hf_raise_for_status(response)
data = response.json()
Expand Down Expand Up @@ -2037,7 +2034,7 @@ def create_repo(
# See https://github.com/huggingface/huggingface_hub/pull/733/files#r820604472
json["lfsmultipartthresh"] = self._lfsmultipartthresh # type: ignore
headers = self._build_hf_headers(token=token, is_write_action=True)
r = requests.post(path, headers=headers, json=json)
r = get_session().post(path, headers=headers, json=json)

try:
hf_raise_for_status(r)
Expand Down Expand Up @@ -2103,7 +2100,7 @@ def delete_repo(
json["type"] = repo_type

headers = self._build_hf_headers(token=token, is_write_action=True)
r = requests.delete(path, headers=headers, json=json)
r = get_session().delete(path, headers=headers, json=json)
hf_raise_for_status(r)

@validate_hf_hub_args
Expand Down Expand Up @@ -2158,7 +2155,7 @@ def update_repo_visibility(
if repo_type is None:
repo_type = REPO_TYPE_MODEL # default repo type

r = requests.put(
r = get_session().put(
url=f"{self.endpoint}/api/{repo_type}s/{namespace}/{name}/settings",
headers=self._build_hf_headers(token=token, is_write_action=True),
json={"private": private},
Expand Down Expand Up @@ -2218,7 +2215,7 @@ def move_repo(

path = f"{self.endpoint}/api/repos/move"
headers = self._build_hf_headers(token=token, is_write_action=True)
r = requests.post(path, headers=headers, json=json)
r = get_session().post(path, headers=headers, json=json)
try:
hf_raise_for_status(r)
except HfHubHTTPError as e:
Expand Down Expand Up @@ -2402,7 +2399,7 @@ def _payload_as_ndjson() -> Iterable[bytes]:
params = {"create_pr": "1"} if create_pr else None

try:
commit_resp = requests.post(url=commit_url, headers=headers, data=data, params=params)
commit_resp = get_session().post(url=commit_url, headers=headers, data=data, params=params)
hf_raise_for_status(commit_resp, endpoint_name="commit")
except RepositoryNotFoundError as e:
e.append_to_message(_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE)
Expand Down Expand Up @@ -2983,7 +2980,7 @@ def create_branch(
payload["startingPoint"] = revision

# Create branch
response = requests.post(url=branch_url, headers=headers, json=payload)
response = get_session().post(url=branch_url, headers=headers, json=payload)
try:
hf_raise_for_status(response)
except HfHubHTTPError as e:
Expand Down Expand Up @@ -3036,7 +3033,7 @@ def delete_branch(
headers = self._build_hf_headers(token=token, is_write_action=True)

# Delete branch
response = requests.delete(url=branch_url, headers=headers)
response = get_session().delete(url=branch_url, headers=headers)
hf_raise_for_status(response)

@validate_hf_hub_args
Expand Down Expand Up @@ -3103,7 +3100,7 @@ def create_tag(
payload["message"] = tag_message

# Tag
response = requests.post(url=tag_url, headers=headers, json=payload)
response = get_session().post(url=tag_url, headers=headers, json=payload)
try:
hf_raise_for_status(response)
except HfHubHTTPError as e:
Expand Down Expand Up @@ -3153,7 +3150,7 @@ def delete_tag(
headers = self._build_hf_headers(token=token, is_write_action=True)

# Un-tag
response = requests.delete(url=tag_url, headers=headers)
response = get_session().delete(url=tag_url, headers=headers)
hf_raise_for_status(response)

@validate_hf_hub_args
Expand Down Expand Up @@ -3244,7 +3241,7 @@ def get_repo_discussions(

def _fetch_discussion_page(page_index: int):
path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?p={page_index}"
resp = requests.get(path, headers=headers)
resp = get_session().get(path, headers=headers)
hf_raise_for_status(resp)
paginated_discussions = resp.json()
total = paginated_discussions["count"]
Expand Down Expand Up @@ -3319,7 +3316,7 @@ def get_discussion_details(

path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions/{discussion_num}"
headers = self._build_hf_headers(token=token)
resp = requests.get(path, params={"diff": "1"}, headers=headers)
resp = get_session().get(path, params={"diff": "1"}, headers=headers)
hf_raise_for_status(resp)

discussion_details = resp.json()
Expand Down Expand Up @@ -3416,7 +3413,7 @@ def create_discussion(
)

headers = self._build_hf_headers(token=token, is_write_action=True)
resp = requests.post(
resp = get_session().post(
f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions",
json={
"title": title.strip(),
Expand Down Expand Up @@ -3924,7 +3921,7 @@ def add_space_secret(self, repo_id: str, key: str, value: str, *, token: Optiona
token (`str`, *optional*):
Hugging Face token. Will default to the locally saved token if not provided.
"""
r = requests.post(
r = get_session().post(
f"{self.endpoint}/api/spaces/{repo_id}/secrets",
headers=self._build_hf_headers(token=token),
json={"key": key, "value": value},
Expand All @@ -3946,7 +3943,7 @@ def delete_space_secret(self, repo_id: str, key: str, *, token: Optional[str] =
token (`str`, *optional*):
Hugging Face token. Will default to the locally saved token if not provided.
"""
r = requests.delete(
r = get_session().delete(
f"{self.endpoint}/api/spaces/{repo_id}/secrets",
headers=self._build_hf_headers(token=token),
json={"key": key},
Expand All @@ -3966,7 +3963,9 @@ def get_space_runtime(self, repo_id: str, *, token: Optional[str] = None) -> Spa
Returns:
[`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware.
"""
r = requests.get(f"{self.endpoint}/api/spaces/{repo_id}/runtime", headers=self._build_hf_headers(token=token))
r = get_session().get(
f"{self.endpoint}/api/spaces/{repo_id}/runtime", headers=self._build_hf_headers(token=token)
)
hf_raise_for_status(r)
return SpaceRuntime(r.json())

Expand All @@ -3988,7 +3987,7 @@ def request_space_hardware(self, repo_id: str, hardware: SpaceHardware, *, token
</Tip>
"""
r = requests.post(
r = get_session().post(
f"{self.endpoint}/api/spaces/{repo_id}/hardware",
headers=self._build_hf_headers(token=token),
json={"flavor": hardware},
Expand Down Expand Up @@ -4025,7 +4024,9 @@ def pause_space(self, repo_id: str, *, token: Optional[str] = None) -> SpaceRunt
If your Space is a static Space. Static Spaces are always running and never billed. If you want to hide
a static Space, you can set it to private.
"""
r = requests.post(f"{self.endpoint}/api/spaces/{repo_id}/pause", headers=self._build_hf_headers(token=token))
r = get_session().post(
f"{self.endpoint}/api/spaces/{repo_id}/pause", headers=self._build_hf_headers(token=token)
)
hf_raise_for_status(r)
return SpaceRuntime(r.json())

Expand Down Expand Up @@ -4059,7 +4060,9 @@ def restart_space(self, repo_id: str, *, token: Optional[str] = None) -> SpaceRu
If your Space is a static Space. Static Spaces are always running and never billed. If you want to hide
a static Space, you can set it to private.
"""
r = requests.post(f"{self.endpoint}/api/spaces/{repo_id}/restart", headers=self._build_hf_headers(token=token))
r = get_session().post(
f"{self.endpoint}/api/spaces/{repo_id}/restart", headers=self._build_hf_headers(token=token)
)
hf_raise_for_status(r)
return SpaceRuntime(r.json())

Expand Down Expand Up @@ -4133,7 +4136,7 @@ def duplicate_space(
if private is not None:
payload["private"] = private

r = requests.post(
r = get_session().post(
f"{self.endpoint}/api/spaces/{from_id}/duplicate",
headers=self._build_hf_headers(token=token, is_write_action=True),
json=payload,
Expand Down
Loading

0 comments on commit 46d51b2

Please sign in to comment.