diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py
index 1cc2ea1f..875de6f4 100644
--- a/polaris/dataset/_base.py
+++ b/polaris/dataset/_base.py
@@ -2,7 +2,7 @@
 import json
 from os import PathLike
 from pathlib import Path, PurePath
-from typing import Any, MutableMapping
+from typing import Any, Iterable, MutableMapping
 from uuid import uuid4
 
 import fsspec
@@ -232,7 +232,7 @@ def n_columns(self) -> int:
 
     @property
     @abc.abstractmethod
-    def rows(self) -> list[str | int]:
+    def rows(self) -> Iterable[str | int]:
         """Return all row indices for the dataset"""
         raise NotImplementedError
 
diff --git a/polaris/dataset/_dataset_v2.py b/polaris/dataset/_dataset_v2.py
index 5144f1f7..4e35f062 100644
--- a/polaris/dataset/_dataset_v2.py
+++ b/polaris/dataset/_dataset_v2.py
@@ -2,7 +2,7 @@
 import re
 from os import PathLike
 from pathlib import Path
-from typing import Any, ClassVar, Literal
+from typing import Any, ClassVar, Iterable, Literal
 
 import fsspec
 import numpy as np
@@ -82,19 +82,19 @@ def _validate_v2_dataset_model(self) -> Self:
     def n_rows(self) -> int:
         """Return all row indices for the dataset"""
         example = self.zarr_root[self.columns[0]]
-        if isinstance(example, zarr.Group):
-            return len(example[_INDEX_ARRAY_KEY])
-        return len(example)
+        match example:
+            case zarr.Group():
+                return len(example[_INDEX_ARRAY_KEY])
+            case _:
+                return len(example)
 
     @property
-    def rows(self) -> np.ndarray[int]:
-        """Return all row indices for the dataset
-
-        Warning: Memory consumption
-            This feature is added for completeness' sake, but it should be noted that large datasets could consume a lot of memory.
-            E.g. storing a billion indices with np.in64 would consume 8GB of memory. Use with caution.
+    def rows(self) -> Iterable[int]:
+        """
+        Return all row indices for the dataset.
+        This feature is added for completeness' sake, but does not provide any performance benefits.
         """
-        return np.arange(len(self), dtype=int)
+        return range(self.n_rows)
 
     @property
     def columns(self) -> list[str]:
diff --git a/polaris/hub/storage.py b/polaris/hub/storage.py
index 325c9534..4ef41558 100644
--- a/polaris/hub/storage.py
+++ b/polaris/hub/storage.py
@@ -2,6 +2,7 @@
 from base64 import b64encode
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import contextmanager
+from functools import lru_cache
 from hashlib import md5
 from itertools import islice
 from pathlib import PurePath
@@ -129,6 +130,14 @@ def _multipart_upload(self, key: str, value: bytes) -> None:
             Bucket=self.bucket_name, Key=full_key, UploadId=upload_id, MultipartUpload={"Parts": parts}
         )
 
+    @lru_cache()
+    def _get_object_body(self, full_key: str) -> bytes:
+        """
+        Basic caching of the object body, to avoid repeated reads from the remote bucket.
+        """
+        response = self.s3_client.get_object(Bucket=self.bucket_name, Key=full_key)
+        return response["Body"].read()
+
     ## Custom methods
 
     def copy_to_destination(
@@ -351,8 +360,7 @@ def __getitem__(self, key: str) -> bytes:
         with handle_s3_errors():
             try:
                 full_key = self._full_key(key)
-                response = self.s3_client.get_object(Bucket=self.bucket_name, Key=full_key)
-                return response["Body"].read()
+                return self._get_object_body(full_key=full_key)
             except self.s3_client.exceptions.NoSuchKey:
                 raise KeyError(key)
 
@@ -430,6 +438,12 @@ def __len__(self) -> int:
 
         return sum((page["KeyCount"] for page in page_iterator))
 
+    def __hash__(self) -> int:
+        """
+        Custom hash implementation, to enable the lru_cache decorator on methods.
+        """
+        return hash((self.bucket_name, self.prefix, self.s3_client))
+
 
 class StorageTokenAuth:
     token: HubStorageOAuth2Token | None