Skip to content

Commit

Permalink
Tweak progress columns. Replace tqdm.
Browse files Browse the repository at this point in the history
  • Loading branch information
jstlaurent committed Jan 27, 2025
1 parent 54f3e21 commit 6ca38df
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 95 deletions.
1 change: 0 additions & 1 deletion env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ channels:
dependencies:
- python >=3.10
- pip
- tqdm
- typer
- pyyaml
- pydantic >=2
Expand Down
38 changes: 22 additions & 16 deletions polaris/dataset/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from polaris.dataset.zarr import MemoryMappedDirectoryStore
from polaris.dataset.zarr._utils import check_zarr_codecs, load_zarr_group_to_memory
from polaris.utils.constants import DEFAULT_CACHE_DIR
from polaris.utils.context import track_progress
from polaris.utils.dict2html import dict2html
from polaris.utils.errors import InvalidDatasetError
from polaris.utils.types import (
Expand Down Expand Up @@ -373,19 +374,24 @@ def _cache_zarr(self, destination: str | PathLike, if_exists: ZarrConflictResolu
# Copy over Zarr data to the destination
self._warn_about_remote_zarr = False

logger.info(f"Copying Zarr archive to {destination_zarr_root}. This may take a while.")
destination_store = zarr.open(str(destination_zarr_root), "w").store
source_store = self.zarr_root.store.store

if isinstance(source_store, S3Store):
source_store.copy_to_destination(destination_store, if_exists, logger.info)
else:
zarr.copy_store(
source=source_store,
dest=destination_store,
log=logger.info,
if_exists=if_exists,
)
self.zarr_root_path = str(destination_zarr_root)
self._zarr_root = None
self._zarr_data = None
with track_progress(description="Copying Zarr archive", total=1) as (
progress,
task,
):
progress.log(f"[green]Copying to destination {destination_zarr_root}")
progress.log("[yellow]For large Zarr archives, this may take a while.")
destination_store = zarr.open(str(destination_zarr_root), "w").store
source_store = self.zarr_root.store.store

if isinstance(source_store, S3Store):
source_store.copy_to_destination(destination_store, if_exists, logger.info)
else:
zarr.copy_store(
source=source_store,
dest=destination_store,
log=logger.info,
if_exists=if_exists,
)
self.zarr_root_path = str(destination_zarr_root)
self._zarr_root = None
self._zarr_data = None
2 changes: 2 additions & 0 deletions polaris/dataset/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ def cache(

if verify_checksum:
self.verify_checksum()
else:
self._md5sum = None

return str(destination)

Expand Down
71 changes: 39 additions & 32 deletions polaris/dataset/zarr/_checksum.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@
import zarr.errors
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel
from tqdm import tqdm

from polaris.utils.context import track_progress
from polaris.utils.errors import InvalidZarrChecksum

ZARR_DIGEST_PATTERN = "([0-9a-f]{32})-([0-9]+)-([0-9]+)"
Expand All @@ -56,8 +56,8 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple["_ZarrDirectoryDigest",
r"""
Implements an algorithm to compute the Zarr checksum.
Warning: This checksum is sensitive to Zarr configuration.
This checksum is sensitive to change in the Zarr structure. For example, if you change the chunk size,
Warning: This checksum is sensitive to Zarr configuration.
This checksum is sensitive to change in the Zarr structure. For example, if you change the chunk size,
the checksum will also change.
To understand how this works, consider the following directory structure:
Expand All @@ -67,17 +67,17 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple["_ZarrDirectoryDigest",
a c
/
b
Within zarr, this would for example be:
- `root`: A Zarr Group with a single Array.
- `a`: A Zarr Array
- `b`: A single chunk of the Zarr Array
- `c`: A metadata file (i.e. .zarray, .zattrs or .zgroup)
- `c`: A metadata file (i.e. .zarray, .zattrs or .zgroup)
To compute the checksum, we first find all the trees in the node, in this case b and c.
To compute the checksum, we first find all the trees in the node, in this case b and c.
We compute the hash of the content (the raw bytes) for each of these files.
We then work our way up the tree. For any node (directory), we find all children of that node.
In a sorted order, we then serialize a list with - for each of the children - the checksum, size, and number of children.
The hash of the directory is then equal to the hash of the serialized JSON.
Expand Down Expand Up @@ -116,33 +116,40 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple["_ZarrDirectoryDigest",
leaves = fs.find(zarr_root_path, detail=True)
zarr_md5sum_manifest = []

for file in tqdm(leaves.values(), desc="Finding all files in the Zarr archive"):
path = file["name"]

relpath = path.removeprefix(zarr_root_path)
relpath = relpath.lstrip("/")
relpath = Path(relpath)

size = file["size"]

# Compute md5sum of file
md5sum = hashlib.md5()
with fs.open(path, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
md5sum.update(chunk)
digest = md5sum.hexdigest()
files = leaves.values()
with track_progress(description="Finding all files in the Zarr archive", total=len(files)) as (
progress,
task,
):
for file in files:
path = file["name"]

relpath = path.removeprefix(zarr_root_path)
relpath = relpath.lstrip("/")
relpath = Path(relpath)

size = file["size"]

# Compute md5sum of file
md5sum = hashlib.md5()
with fs.open(path, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
md5sum.update(chunk)
digest = md5sum.hexdigest()

# Add a leaf to the tree
# (This actually adds the file's checksum to the parent directory's manifest)
tree.add_leaf(
path=relpath,
size=size,
digest=digest,
)

# Add a leaf to the tree
# (This actually adds the file's checksum to the parent directory's manifest)
tree.add_leaf(
path=relpath,
size=size,
digest=digest,
)
# We persist the checksums for leaf nodes separately,
# because this is what the Hub needs to verify data integrity.
zarr_md5sum_manifest.append(ZarrFileChecksum(path=str(relpath), md5sum=digest, size=size))

# We persist the checksums for leaf nodes separately,
# because this is what the Hub needs to verify data integrity.
zarr_md5sum_manifest.append(ZarrFileChecksum(path=str(relpath), md5sum=digest, size=size))
progress.update(task, advance=1, refresh=True)

# Compute digest
return tree.process(), zarr_md5sum_manifest
Expand Down
15 changes: 8 additions & 7 deletions polaris/hub/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,9 +641,7 @@ def _upload_v1_dataset(

# Step 3: Upload any associated Zarr archive
if dataset.uses_zarr:
with track_progress(description="Copying Zarr archive", total=1) as (progress, task):
progress.log("[yellow]This may take a while.")

with track_progress(description="Copying Zarr archive", total=1):
destination = storage.store("extension")

# Locally consolidate Zarr archive metadata. Future updates on handling consolidated
Expand Down Expand Up @@ -702,8 +700,11 @@ def _upload_v2_dataset(
storage.set_file("manifest", manifest_file.read())

# Step 3: Upload the Zarr archive
with track_progress(description="Copying Zarr archive", total=1) as (progress, task):
progress.log("[yellow]This may take a while.")
with track_progress(description="Copying Zarr archive", total=1) as (
progress_zarr,
task_zarr,
):
progress_zarr.log("[yellow]This may take a while.")

destination = storage.store("root")

Expand Down Expand Up @@ -830,11 +831,11 @@ def _upload_v2_benchmark(
# 2. Upload each index set bitmap
with track_progress(
description="Copying index sets", total=benchmark.split.n_test_sets + 1
) as (progress, task):
) as (progress_index_sets, task_index_sets):
for label, index_set in benchmark.split:
logger.info(f"Copying index set {label} to the Hub.")
storage.set_file(label, index_set.serialize())
progress.update(task, advance=1, refresh=True)
progress_index_sets.update(task_index_sets, advance=1, refresh=True)

benchmark_url = urljoin(self.settings.hub_url, response.headers.get("Content-Location"))
progress.log(
Expand Down
80 changes: 47 additions & 33 deletions polaris/hub/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from zarr.util import buffer_size

from polaris.hub.oauth import BenchmarkV2Paths, DatasetV1Paths, DatasetV2Paths, HubStorageOAuth2Token
from polaris.utils.context import track_progress
from polaris.utils.errors import PolarisHubError
from polaris.utils.types import ArtifactUrn, ZarrConflictResolution

Expand Down Expand Up @@ -168,25 +169,33 @@ def copy_to_destination(

number_source_keys = len(self)

batch_iter = iter(self)
while batch := tuple(islice(batch_iter, self._batch_size)):
to_put = batch if if_exists == "replace" else filter(lambda key: key not in destination, batch)
skipped = len(batch) - len(to_put)
with track_progress(description="Copying Zarr keys", total=number_source_keys) as (
progress_keys,
task_keys,
):
batch_iter = iter(self)
while batch := tuple(islice(batch_iter, self._batch_size)):
to_put = (
batch if if_exists == "replace" else filter(lambda key: key not in destination, batch)
)
skipped = len(batch) - len(to_put)

if skipped > 0 and if_exists == "raise":
raise CopyError(f"keys {to_put} exist in destination")
if skipped > 0:
if if_exists == "raise":
raise CopyError(f"keys {to_put} exist in destination")
else:
progress_keys.log(f"Skipped {skipped} keys that already exists")

items = self.getitems(to_put, contexts={})
for key, content in items.items():
destination[key] = content
total_bytes_copied += buffer_size(content)
items = self.getitems(to_put, contexts={})
for key, content in items.items():
destination[key] = content

total_copied += len(to_put)
total_skipped += skipped
size_copied = buffer_size(content)
total_bytes_copied += size_copied
total_copied += 1

log(
f"Copied {total_copied} ({total_bytes_copied / (1024**2):.2f} MiB), skipped {total_skipped}, of {number_source_keys} keys. {(total_copied + total_skipped) / number_source_keys * 100:.2f}% completed."
)
total_skipped += skipped
progress_keys.update(task_keys, advance=len(batch), refresh=True)

return total_copied, total_skipped, total_bytes_copied

Expand Down Expand Up @@ -259,24 +268,29 @@ def copy_key(key: str, source: Store, if_exists: ZarrConflictResolution) -> tupl

number_source_keys = len(source)

# Batch the keys, otherwise we end up with too many files open at the same time
batch_iter = iter(source.keys())
while batch := tuple(islice(batch_iter, self._batch_size)):
# Create a future for each key to copy
future_to_key = [
executor.submit(copy_key, source_key, source, if_exists) for source_key in batch
]

# As each future completes, collect the results
for future in as_completed(future_to_key):
result_copied, result_skipped, result_bytes_copied = future.result()
total_copied += result_copied
total_skipped += result_skipped
total_bytes_copied += result_bytes_copied

log(
f"Copied {total_copied} ({total_bytes_copied / (1024**2):.2f} MiB), skipped {total_skipped}, of {number_source_keys} keys. {(total_copied + total_skipped) / number_source_keys * 100:.2f}% completed."
)
with track_progress(description="Copying Zarr keys", total=number_source_keys) as (
progress_keys,
task_keys,
):
# Batch the keys, otherwise we end up with too many files open at the same time
batch_iter = iter(source.keys())
while batch := tuple(islice(batch_iter, self._batch_size)):
# Create a future for each key to copy
future_to_key = [
executor.submit(copy_key, source_key, source, if_exists) for source_key in batch
]

# As each future completes, collect the results
for future in as_completed(future_to_key):
result_copied, result_skipped, result_bytes_copied = future.result()
total_copied += result_copied
progress_keys.update(task_keys, advance=result_copied, refresh=True)

total_skipped += result_skipped
if result_skipped > 0:
progress_keys.log(f"Skipped {result_skipped} keys that already exists")

total_bytes_copied += result_bytes_copied

return total_copied, total_skipped, total_bytes_copied

Expand Down
6 changes: 3 additions & 3 deletions polaris/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
)
Expand All @@ -18,7 +18,7 @@
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
TimeElapsedColumn(),
),
)
Expand All @@ -33,7 +33,7 @@


@contextmanager
def track_progress(description: str, total: float | None = 100.0):
def track_progress(description: str, total: float | None = 1.0):
"""
Use the Progress instance to track a task's progress
"""
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ dependencies = [
"scikit-learn",
"scipy",
"seaborn",
"tqdm",
"typer",
"typing-extensions>=4.12.0",
"zarr >=2,<3",
Expand Down
2 changes: 0 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 6ca38df

Please sign in to comment.