Skip to content

Commit

Permalink
moved to #358
Browse files Browse the repository at this point in the history
  • Loading branch information
malcolmgreaves committed Oct 23, 2024
1 parent 50318bb commit 5089bf5
Showing 1 changed file with 21 additions and 63 deletions.
84 changes: 21 additions & 63 deletions sub-packages/bionemo-testing/src/bionemo/testing/data/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import contextlib
import shutil
import sys
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Sequence
from typing import Literal

import boto3
import ngcsdk
Expand All @@ -32,9 +32,6 @@
from bionemo.testing.data.resource import Resource, get_all_resources


__all__: Sequence[str] = ("load",)


def default_pbss_client():
"""Create a default S3 client for PBSS."""
retry_config = Config(retries={"max_attempts": 10, "mode": "standard"})
Expand Down Expand Up @@ -198,20 +195,10 @@ def _get_processor(extension: str, unpack: bool | None, decompress: bool | None)
return None


def print_resources():
"""Prints all available downloadable resources & their sources to STDOUT."""
print("#resource_name\tsource_options")
for resource_name, resource in sorted(get_all_resources().items()):
sources = []
if resource.ngc is not None:
sources.append("ngc")
if resource.pbss is not None:
sources.append("pbss")
print(f"{resource_name}\t{','.join(sources)}")


def main_cli():
"""Allows a user to get a specific artifact from the command line."""
import argparse

parser = argparse.ArgumentParser(
description="Retrieve the local path to the requested artifact name or list resources."
)
Expand All @@ -236,56 +223,27 @@ def main_cli():
help='Backend to use, NVIDIA users should set this to "pbss".',
)

parser.add_argument(
"--all",
action="store_true",
default=False,
help="Download all resources. Ignores all other options.",
)

# Parse the command line arguments
args = parser.parse_args()

download_all = args.all
list_resources = args.list_resources
artifact_name = args.artifact_name
source = args.source

# main script logic
if download_all:
print("Downloading all resources:")
print_resources()
print("-" * 80)

resource_to_local: dict[str, Path] = {}
for resource_name in tqdm(
sorted(get_all_resources()),
desc="Downloading Resources",
):
local_path = load(resource_name, source=source)
resource_to_local[resource_name] = local_path

print("-" * 80)
print("All resources downloaded:")
for resource_name, local_path in sorted(resource_to_local.items()):
print(f" {resource_name}: {str(local_path.absolute())}")

if args.list_resources:
print("#resource_name\tsource_options")
for resource_name, resource in sorted(get_all_resources().items()):
sources = []
if resource.ngc is not None:
sources.append("ngc")
if resource.pbss is not None:
sources.append("pbss")
print(f"{resource_name}\t{','.join(sources)}")
sys.exit(0) # Successful exit
elif args.artifact_name:
# Redirect stdout from the subprocess calls to stderr
with contextlib.redirect_stdout(sys.stderr):
local_path = load(args.artifact_name, source=args.source)
# Print the result
print(str(local_path.absolute()))
else:
if list_resources:
print_resources()

elif artifact_name is not None and len(artifact_name) > 0:
# Get the local path for the provided artifact name
local_path = load(artifact_name, source=source)
# Print the result
print(str(local_path.absolute()))
else:
parser.error("You must provide an artifact name if --list-resources or --all is not set!")

# Any exceptions encountered above will cause the program to exist w/ a non-zero exit code.
# Additionally, the `parser.error` call will end the program with a non-zero exit code.

sys.exit(0) # Successful exit
parser.error("You must provide an artifact name if --list-resources is not set.")


if __name__ == "__main__":
Expand Down

0 comments on commit 5089bf5

Please sign in to comment.