Skip to content

Commit

Permalink
Update monaihosting download method (#8364)
Browse files Browse the repository at this point in the history
Related to Project-MONAI/model-zoo#723.

### Description

Currently, bundle download from the "monaihosting" source uses a fixed download
URL determined by the function `_get_monaihosting_bundle_url`.
A possible enhancement is to support bundles that are hosted in
different places.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Yiheng Wang <[email protected]>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
yiheng-wang-nv and KumoLiu authored Feb 25, 2025
1 parent ab07523 commit a09c1f0
Showing 1 changed file with 39 additions and 16 deletions.
55 changes: 39 additions & 16 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import json
import os
import re
import urllib
import warnings
import zipfile
from collections.abc import Mapping, Sequence
Expand Down Expand Up @@ -58,7 +59,7 @@
validate, _ = optional_import("jsonschema", name="validate")
ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
requests_get, has_requests = optional_import("requests", name="get")
requests, has_requests = optional_import("requests")
onnx, _ = optional_import("onnx")
huggingface_hub, _ = optional_import("huggingface_hub")

Expand Down Expand Up @@ -206,6 +207,16 @@ def _download_from_monaihosting(download_path: Path, filename: str, version: str
extractall(filepath=filepath, output_dir=download_path, has_base=True)


def _download_from_bundle_info(download_path: Path, filename: str, version: str, progress: bool) -> None:
    """
    Download a bundle archive from the URL recorded in its bundle info and unpack it.

    Args:
        download_path: directory that receives the zip archive and the extracted content.
        filename: bundle name, passed to `get_bundle_info` as `bundle_name`.
        version: bundle version to look up; also used in the archive file name.
        progress: forwarded to `download_url`.

    Raises:
        ValueError: when `get_bundle_info` returns an empty result for the given name and version.

    """
    info = get_bundle_info(bundle_name=filename, version=version)
    if info:
        # the bundle info entry carries its own download location
        source_url = info["browser_download_url"]
        archive_path = download_path / f"{filename}_v{version}.zip"
        download_url(url=source_url, filepath=archive_path, hash_val=None, progress=progress)
        extractall(filepath=archive_path, output_dir=download_path, has_base=True)
        return
    raise ValueError(f"Bundle info not found for {filename} v{version}.")


def _add_ngc_prefix(name: str, prefix: str = "monai_") -> str:
if name.startswith(prefix):
return name
Expand All @@ -222,7 +233,7 @@ def _get_all_download_files(request_url: str, headers: dict | None = None) -> li
if not has_requests:
raise ValueError("requests package is required, please install it.")
headers = {} if headers is None else headers
response = requests_get(request_url, headers=headers)
response = requests.get(request_url, headers=headers)
response.raise_for_status()
model_info = json.loads(response.text)

Expand Down Expand Up @@ -266,7 +277,7 @@ def _download_from_ngc_private(
request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)
if has_requests:
headers = {} if headers is None else headers
response = requests_get(request_url, headers=headers)
response = requests.get(request_url, headers=headers)
response.raise_for_status()
else:
raise ValueError("NGC API requires requests package. Please install it.")
Expand All @@ -289,7 +300,7 @@ def _get_ngc_token(api_key, retry=0):
url = "https://authn.nvidia.com/token?service=ngc"
headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key}
if has_requests:
response = requests_get(url, headers=headers)
response = requests.get(url, headers=headers)
if not response.ok:
# retry 3 times, if failed, raise an error.
if retry < 3:
Expand All @@ -303,14 +314,17 @@ def _get_ngc_token(api_key, retry=0):

def _get_latest_bundle_version_monaihosting(name):
full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}"
requests_get, has_requests = optional_import("requests", name="get")
if has_requests:
resp = requests_get(full_url)
resp.raise_for_status()
else:
raise ValueError("NGC API requires requests package. Please install it.")
model_info = json.loads(resp.text)
return model_info["model"]["latestVersionIdStr"]
resp = requests.get(full_url)
try:
resp.raise_for_status()
model_info = json.loads(resp.text)
return model_info["model"]["latestVersionIdStr"]
except requests.exceptions.HTTPError:
# for monaihosting bundles, if cannot find the version, get from model zoo model_info.json
return get_bundle_versions(name)["latest_version"]

raise ValueError("NGC API requires requests package. Please install it.")


def _examine_monai_version(monai_version: str) -> tuple[bool, str]:
Expand Down Expand Up @@ -388,14 +402,14 @@ def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers:
version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements
if headers:
version_header.update(headers)
resp = requests_get(version_endpoint, headers=version_header)
resp = requests.get(version_endpoint, headers=version_header)
resp.raise_for_status()
model_info = json.loads(resp.text)
latest_versions = _list_latest_versions(model_info)

for version in latest_versions:
file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json"
resp = requests_get(file_endpoint, headers=headers)
resp = requests.get(file_endpoint, headers=headers)
metadata = json.loads(resp.text)
resp.raise_for_status()
# if the package version is not available or the model is compatible with the package version
Expand Down Expand Up @@ -585,7 +599,16 @@ def download(
name_ver = "_v".join([name_, version_]) if version_ is not None else name_
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_)
elif source_ == "monaihosting":
_download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_)
try:
_download_from_monaihosting(
download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
)
except urllib.error.HTTPError:
# for monaihosting bundles, if cannot download from default host, download according to bundle_info
_download_from_bundle_info(
download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
)

elif source_ == "ngc":
_download_from_ngc(
download_path=bundle_dir_,
Expand Down Expand Up @@ -792,9 +815,9 @@ def _get_all_bundles_info(

if auth_token is not None:
headers = {"Authorization": f"Bearer {auth_token}"}
resp = requests_get(request_url, headers=headers)
resp = requests.get(request_url, headers=headers)
else:
resp = requests_get(request_url)
resp = requests.get(request_url)
resp.raise_for_status()
else:
raise ValueError("requests package is required, please install it.")
Expand Down

0 comments on commit a09c1f0

Please sign in to comment.