diff --git a/env.yml b/env.yml
index 63d47d90..df6dc398 100644
--- a/env.yml
+++ b/env.yml
@@ -4,7 +4,6 @@ channels:
 dependencies:
   - python >=3.10
   - pip
-  - tqdm
   - typer
   - pyyaml
   - pydantic >=2
diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py
index 5f0e33ee..92f17761 100644
--- a/polaris/dataset/_base.py
+++ b/polaris/dataset/_base.py
@@ -25,6 +25,7 @@
 from polaris.dataset.zarr import MemoryMappedDirectoryStore
 from polaris.dataset.zarr._utils import check_zarr_codecs, load_zarr_group_to_memory
 from polaris.utils.constants import DEFAULT_CACHE_DIR
+from polaris.utils.context import track_progress
 from polaris.utils.dict2html import dict2html
 from polaris.utils.errors import InvalidDatasetError
 from polaris.utils.types import (
@@ -373,19 +374,24 @@ def _cache_zarr(self, destination: str | PathLike, if_exists: ZarrConflictResolu

         # Copy over Zarr data to the destination
         self._warn_about_remote_zarr = False
-        logger.info(f"Copying Zarr archive to {destination_zarr_root}. This may take a while.")
-        destination_store = zarr.open(str(destination_zarr_root), "w").store
-        source_store = self.zarr_root.store.store
-
-        if isinstance(source_store, S3Store):
-            source_store.copy_to_destination(destination_store, if_exists, logger.info)
-        else:
-            zarr.copy_store(
-                source=source_store,
-                dest=destination_store,
-                log=logger.info,
-                if_exists=if_exists,
-            )
-        self.zarr_root_path = str(destination_zarr_root)
-        self._zarr_root = None
-        self._zarr_data = None
+        with track_progress(description="Copying Zarr archive", total=1) as (
+            progress,
+            task,
+        ):
+            progress.log(f"[green]Copying to destination {destination_zarr_root}")
+            progress.log("[yellow]For large Zarr archives, this may take a while.")
+            destination_store = zarr.open(str(destination_zarr_root), "w").store
+            source_store = self.zarr_root.store.store
+
+            if isinstance(source_store, S3Store):
+                source_store.copy_to_destination(destination_store, if_exists, logger.info)
+            else:
+                zarr.copy_store(
+                    source=source_store,
+                    dest=destination_store,
+                    log=logger.info,
+                    if_exists=if_exists,
+                )
+            self.zarr_root_path = str(destination_zarr_root)
+            self._zarr_root = None
+            self._zarr_data = None
diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py
index 9aaa5deb..5829b837 100644
--- a/polaris/dataset/_dataset.py
+++ b/polaris/dataset/_dataset.py
@@ -332,6 +332,8 @@ def cache(

         if verify_checksum:
             self.verify_checksum()
+        else:
+            self._md5sum = None

         return str(destination)
diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py
index 8a1a8348..ccf5d9fd 100644
--- a/polaris/dataset/zarr/_checksum.py
+++ b/polaris/dataset/zarr/_checksum.py
@@ -45,8 +45,8 @@
 import zarr.errors
 from pydantic import BaseModel, ConfigDict
 from pydantic.alias_generators import to_camel
-from tqdm import tqdm

+from polaris.utils.context import track_progress
 from polaris.utils.errors import InvalidZarrChecksum

 ZARR_DIGEST_PATTERN = "([0-9a-f]{32})-([0-9]+)-([0-9]+)"
@@ -56,8 +56,8 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple["_ZarrDirectoryDigest",
     r"""
     Implements an algorithm to compute the Zarr checksum.

-    Warning: This checksum is sensitive to Zarr configuration. 
-        This checksum is sensitive to change in the Zarr structure. For example, if you change the chunk size, 
+    Warning: This checksum is sensitive to Zarr configuration.
+        This checksum is sensitive to change in the Zarr structure. For example, if you change the chunk size,
         the checksum will also change.

     To understand how this works, consider the following directory structure:
@@ -67,17 +67,17 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple["_ZarrDirectoryDigest",
            a   c
           /
          b
- 
+
     Within zarr, this would for example be:

     - `root`: A Zarr Group with a single Array.
     - `a`: A Zarr Array
     - `b`: A single chunk of the Zarr Array
-    - `c`: A metadata file (i.e. .zarray, .zattrs or .zgroup) 
+    - `c`: A metadata file (i.e. .zarray, .zattrs or .zgroup)

-    To compute the checksum, we first find all the trees in the node, in this case b and c. 
+    To compute the checksum, we first find all the trees in the node, in this case b and c.
     We compute the hash of the content (the raw bytes) for each of these files.
- 
+
     We then work our way up the tree. For any node (directory), we find all children of that node.
     In an sorted order, we then serialize a list with - for each of the children - the checksum, size, and number of children.
     The hash of the directory is then equal to the hash of the serialized JSON.
@@ -116,33 +116,40 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple["_ZarrDirectoryDigest",
     leaves = fs.find(zarr_root_path, detail=True)
     zarr_md5sum_manifest = []

-    for file in tqdm(leaves.values(), desc="Finding all files in the Zarr archive"):
-        path = file["name"]
-
-        relpath = path.removeprefix(zarr_root_path)
-        relpath = relpath.lstrip("/")
-        relpath = Path(relpath)
-
-        size = file["size"]
-
-        # Compute md5sum of file
-        md5sum = hashlib.md5()
-        with fs.open(path, "rb") as f:
-            for chunk in iter(lambda: f.read(8192), b""):
-                md5sum.update(chunk)
-        digest = md5sum.hexdigest()
+    files = leaves.values()
+    with track_progress(description="Finding all files in the Zarr archive", total=len(files)) as (
+        progress,
+        task,
+    ):
+        for file in files:
+            path = file["name"]
+
+            relpath = path.removeprefix(zarr_root_path)
+            relpath = relpath.lstrip("/")
+            relpath = Path(relpath)
+
+            size = file["size"]
+
+            # Compute md5sum of file
+            md5sum = hashlib.md5()
+            with fs.open(path, "rb") as f:
+                for chunk in iter(lambda: f.read(8192), b""):
+                    md5sum.update(chunk)
+            digest = md5sum.hexdigest()
+
+            # Add a leaf to the tree
+            # (This actually adds the file's checksum to the parent directory's manifest)
+            tree.add_leaf(
+                path=relpath,
+                size=size,
+                digest=digest,
+            )

-        # Add a leaf to the tree
-        # (This actually adds the file's checksum to the parent directory's manifest)
-        tree.add_leaf(
-            path=relpath,
-            size=size,
-            digest=digest,
-        )
+            # We persist the checksums for leaf nodes separately,
+            # because this is what the Hub needs to verify data integrity.
+            zarr_md5sum_manifest.append(ZarrFileChecksum(path=str(relpath), md5sum=digest, size=size))

-    # We persist the checksums for leaf nodes separately,
-    # because this is what the Hub needs to verify data integrity.
-    zarr_md5sum_manifest.append(ZarrFileChecksum(path=str(relpath), md5sum=digest, size=size))
+            progress.update(task, advance=1, refresh=True)

     # Compute digest
     return tree.process(), zarr_md5sum_manifest
diff --git a/polaris/hub/client.py b/polaris/hub/client.py
index c9e1cc5d..52f2e6a5 100644
--- a/polaris/hub/client.py
+++ b/polaris/hub/client.py
@@ -641,9 +641,7 @@ def _upload_v1_dataset(

         # Step 3: Upload any associated Zarr archive
         if dataset.uses_zarr:
-            with track_progress(description="Copying Zarr archive", total=1) as (progress, task):
-                progress.log("[yellow]This may take a while.")
-
+            with track_progress(description="Copying Zarr archive", total=1):
                 destination = storage.store("extension")

                 # Locally consolidate Zarr archive metadata. Future updates on handling consolidated
@@ -702,8 +700,11 @@ def _upload_v2_dataset(
             storage.set_file("manifest", manifest_file.read())

         # Step 3: Upload the Zarr archive
-        with track_progress(description="Copying Zarr archive", total=1) as (progress, task):
-            progress.log("[yellow]This may take a while.")
+        with track_progress(description="Copying Zarr archive", total=1) as (
+            progress_zarr,
+            task_zarr,
+        ):
+            progress_zarr.log("[yellow]This may take a while.")

             destination = storage.store("root")

@@ -830,11 +831,11 @@ def _upload_v2_benchmark(
             # 2. Upload each index set bitmap
             with track_progress(
                 description="Copying index sets", total=benchmark.split.n_test_sets + 1
-            ) as (progress, task):
+            ) as (progress_index_sets, task_index_sets):
                 for label, index_set in benchmark.split:
                     logger.info(f"Copying index set {label} to the Hub.")
                     storage.set_file(label, index_set.serialize())
-                    progress.update(task, advance=1, refresh=True)
+                    progress_index_sets.update(task_index_sets, advance=1, refresh=True)

             benchmark_url = urljoin(self.settings.hub_url, response.headers.get("Content-Location"))
             progress.log(
diff --git a/polaris/hub/storage.py b/polaris/hub/storage.py
index 2704875a..4e225e52 100644
--- a/polaris/hub/storage.py
+++ b/polaris/hub/storage.py
@@ -20,6 +20,7 @@
 from zarr.util import buffer_size

 from polaris.hub.oauth import BenchmarkV2Paths, DatasetV1Paths, DatasetV2Paths, HubStorageOAuth2Token
+from polaris.utils.context import track_progress
 from polaris.utils.errors import PolarisHubError
 from polaris.utils.types import ArtifactUrn, ZarrConflictResolution

@@ -168,25 +169,33 @@ def copy_to_destination(

         number_source_keys = len(self)

-        batch_iter = iter(self)
-        while batch := tuple(islice(batch_iter, self._batch_size)):
-            to_put = batch if if_exists == "replace" else filter(lambda key: key not in destination, batch)
-            skipped = len(batch) - len(to_put)
+        with track_progress(description="Copying Zarr keys", total=number_source_keys) as (
+            progress_keys,
+            task_keys,
+        ):
+            batch_iter = iter(self)
+            while batch := tuple(islice(batch_iter, self._batch_size)):
+                to_put = (
+                    batch if if_exists == "replace" else filter(lambda key: key not in destination, batch)
+                )
+                skipped = len(batch) - len(to_put)

-            if skipped > 0 and if_exists == "raise":
-                raise CopyError(f"keys {to_put} exist in destination")
+                if skipped > 0:
+                    if if_exists == "raise":
+                        raise CopyError(f"keys {to_put} exist in destination")
+                    else:
+                        progress_keys.log(f"Skipped {skipped} keys that already exists")

-            items = self.getitems(to_put, contexts={})
-            for key, content in items.items():
-                destination[key] = content
-                total_bytes_copied += buffer_size(content)
+                items = self.getitems(to_put, contexts={})
+                for key, content in items.items():
+                    destination[key] = content

-            total_copied += len(to_put)
-            total_skipped += skipped
+                    size_copied = buffer_size(content)
+                    total_bytes_copied += size_copied
+                    total_copied += 1

-            log(
-                f"Copied {total_copied} ({total_bytes_copied / (1024**2):.2f} MiB), skipped {total_skipped}, of {number_source_keys} keys. {(total_copied + total_skipped) / number_source_keys * 100:.2f}% completed."
-            )
+                total_skipped += skipped
+                progress_keys.update(task_keys, advance=len(batch), refresh=True)

         return total_copied, total_skipped, total_bytes_copied

@@ -259,24 +268,29 @@ def copy_key(key: str, source: Store, if_exists: ZarrConflictResolution) -> tupl

             number_source_keys = len(source)

-            # Batch the keys, otherwise we end up with too many files open at the same time
-            batch_iter = iter(source.keys())
-            while batch := tuple(islice(batch_iter, self._batch_size)):
-                # Create a future for each key to copy
-                future_to_key = [
-                    executor.submit(copy_key, source_key, source, if_exists) for source_key in batch
-                ]
-
-                # As each future completes, collect the results
-                for future in as_completed(future_to_key):
-                    result_copied, result_skipped, result_bytes_copied = future.result()
-                    total_copied += result_copied
-                    total_skipped += result_skipped
-                    total_bytes_copied += result_bytes_copied
-
-                    log(
-                        f"Copied {total_copied} ({total_bytes_copied / (1024**2):.2f} MiB), skipped {total_skipped}, of {number_source_keys} keys. {(total_copied + total_skipped) / number_source_keys * 100:.2f}% completed."
-                    )
+            with track_progress(description="Copying Zarr keys", total=number_source_keys) as (
+                progress_keys,
+                task_keys,
+            ):
+                # Batch the keys, otherwise we end up with too many files open at the same time
+                batch_iter = iter(source.keys())
+                while batch := tuple(islice(batch_iter, self._batch_size)):
+                    # Create a future for each key to copy
+                    future_to_key = [
+                        executor.submit(copy_key, source_key, source, if_exists) for source_key in batch
+                    ]
+
+                    # As each future completes, collect the results
+                    for future in as_completed(future_to_key):
+                        result_copied, result_skipped, result_bytes_copied = future.result()
+                        total_copied += result_copied
+                        progress_keys.update(task_keys, advance=result_copied, refresh=True)
+
+                        total_skipped += result_skipped
+                        if result_skipped > 0:
+                            progress_keys.log(f"Skipped {result_skipped} keys that already exists")
+
+                        total_bytes_copied += result_bytes_copied

         return total_copied, total_skipped, total_bytes_copied
diff --git a/polaris/utils/context.py b/polaris/utils/context.py
index fccf66e4..edabd0cb 100644
--- a/polaris/utils/context.py
+++ b/polaris/utils/context.py
@@ -4,9 +4,9 @@

 from rich.progress import (
     BarColumn,
+    MofNCompleteColumn,
     Progress,
     SpinnerColumn,
-    TaskProgressColumn,
     TextColumn,
     TimeElapsedColumn,
 )
@@ -18,7 +18,7 @@
         SpinnerColumn(),
         TextColumn("[progress.description]{task.description}"),
         BarColumn(),
-        TaskProgressColumn(),
+        MofNCompleteColumn(),
         TimeElapsedColumn(),
     ),
 )
@@ -33,7 +33,7 @@

 @contextmanager
-def track_progress(description: str, total: float | None = 100.0):
+def track_progress(description: str, total: float | None = 1.0):
     """
     Use the Progress instance to track a task's progress
     """
diff --git a/pyproject.toml b/pyproject.toml
index 966bd772..4d5ec5ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,7 +52,6 @@ dependencies = [
     "scikit-learn",
     "scipy",
     "seaborn",
-    "tqdm",
    "typer",
     "typing-extensions>=4.12.0",
     "zarr >=2,<3",
diff --git a/uv.lock b/uv.lock
index acc82aaa..13664469 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2233,7 +2233,6 @@ dependencies = [
     { name = "scikit-learn" },
     { name = "scipy" },
     { name = "seaborn" },
-    { name = "tqdm" },
     { name = "typer" },
     { name = "typing-extensions" },
     { name = "zarr" },
@@ -2284,7 +2283,6 @@ requires-dist = [
     { name = "scikit-learn" },
     { name = "scipy" },
     { name = "seaborn" },
-    { name = "tqdm" },
     { name = "typer" },
     { name = "typing-extensions", specifier = ">=4.12.0" },
=2,<3">
     { name = "zarr", specifier = ">=2,<3" },
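=2,<3">

Note: the patch only shows the parts of polaris/utils/context.py that change (the column set and the default total), not the full body of track_progress. The sketch below is a minimal, self-contained approximation of how such a rich-based context manager could be structured and used, assuming a module-level _progress_instance and simple completion handling; it is not the actual Polaris implementation (which appears to yield an object with its own .log method), and the file names in the usage example are made up.

from contextlib import contextmanager

from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
)

# Shared Progress instance with the same columns as the patch configures.
_progress_instance = Progress(
    SpinnerColumn(),
    TextColumn("[progress.description]{task.description}"),
    BarColumn(),
    MofNCompleteColumn(),
    TimeElapsedColumn(),
)


@contextmanager
def track_progress(description: str, total: float | None = 1.0):
    """Yield a (progress, task) pair and mark the task as finished on exit."""
    with _progress_instance:
        task_id = _progress_instance.add_task(description, total=total)
        try:
            yield _progress_instance, task_id
        finally:
            # Mark the task complete so the bar does not linger in a partial state.
            _progress_instance.update(task_id, completed=total, refresh=True)


if __name__ == "__main__":
    # Usage mirroring the checksum loop in the patch: advance once per file.
    files = ["data/.zgroup", "data/a/.zarray", "data/a/0.0"]
    with track_progress(description="Hashing files", total=len(files)) as (progress, task):
        for name in files:
            # The patch calls progress.log(...) on the object Polaris yields;
            # with a plain rich Progress, logging goes through progress.console.log(...).
            progress.console.log(f"Hashing {name}")
            progress.update(task, advance=1, refresh=True)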