chore: Caching for S3 store (#244)
jstlaurent authored Jan 14, 2025
1 parent ffca9da commit fe9da3a
Showing 3 changed files with 29 additions and 15 deletions.
4 changes: 2 additions & 2 deletions polaris/dataset/_base.py
@@ -2,7 +2,7 @@
 import json
 from os import PathLike
 from pathlib import Path, PurePath
-from typing import Any, MutableMapping
+from typing import Any, Iterable, MutableMapping
 from uuid import uuid4
 
 import fsspec
@@ -232,7 +232,7 @@ def n_columns(self) -> int:
 
     @property
     @abc.abstractmethod
-    def rows(self) -> list[str | int]:
+    def rows(self) -> Iterable[str | int]:
         """Return all row indices for the dataset"""
         raise NotImplementedError
 
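The widened annotation is the point of this change: the abstract `rows` contract now only promises an `Iterable`, so a concrete dataset may return either a materialized list of labels or a lazy `range`. A minimal sketch of two hypothetical implementations satisfying the new contract (class and attribute names are illustrative, not taken from the repository):

```python
from typing import Iterable


class SmallDataset:
    """Hypothetical subclass: a small dataset can still return a concrete list."""

    def __init__(self, labels: list[str]):
        self._labels = labels

    @property
    def rows(self) -> Iterable[str | int]:
        return self._labels  # a list already satisfies Iterable


class LargeDataset:
    """Hypothetical subclass: a large dataset can return a lazy range instead."""

    def __init__(self, n_rows: int):
        self._n_rows = n_rows

    @property
    def rows(self) -> Iterable[str | int]:
        return range(self._n_rows)  # constant memory, regardless of size
```

Callers that genuinely need a concrete list can still wrap the property in `list(...)`.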
22 changes: 11 additions & 11 deletions polaris/dataset/_dataset_v2.py
@@ -2,7 +2,7 @@
 import re
 from os import PathLike
 from pathlib import Path
-from typing import Any, ClassVar, Literal
+from typing import Any, ClassVar, Iterable, Literal
 
 import fsspec
 import numpy as np
@@ -82,19 +82,19 @@ def _validate_v2_dataset_model(self) -> Self:
     def n_rows(self) -> int:
         """Return all row indices for the dataset"""
         example = self.zarr_root[self.columns[0]]
-        if isinstance(example, zarr.Group):
-            return len(example[_INDEX_ARRAY_KEY])
-        return len(example)
+        match example:
+            case zarr.Group():
+                return len(example[_INDEX_ARRAY_KEY])
+            case _:
+                return len(example)
 
     @property
-    def rows(self) -> np.ndarray[int]:
-        """Return all row indices for the dataset
-        Warning: Memory consumption
-            This feature is added for completeness' sake, but it should be noted that large datasets could consume a lot of memory.
-            E.g. storing a billion indices with np.in64 would consume 8GB of memory. Use with caution.
+    def rows(self) -> Iterable[int]:
+        """
+        Return all row indices for the dataset
+        This feature is added for completeness' sake, but does not provide any performance benefits.
         """
-        return np.arange(len(self), dtype=int)
+        return range(self.n_rows)
 
     @property
     def columns(self) -> list[str]:
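The switch from `np.arange(len(self), dtype=int)` to `range(self.n_rows)` removes the memory cost the old docstring warned about: `np.arange` allocates every index up front (8 bytes each with a 64-bit integer dtype), while a `range` object stays constant-size and still supports `len()`, indexing, and iteration. A rough comparison, assuming a 64-bit platform:

```python
import sys

import numpy as np

n = 1_000_000_000  # a billion rows

# Old behaviour: materializing the indices would need ~8 GB, so only compute the size.
eager_bytes = n * np.dtype(int).itemsize
print(f"np.arange would allocate ~{eager_bytes / 1e9:.0f} GB")

# New behaviour: the range object is a few dozen bytes and still knows its length.
lazy_rows = range(n)
print(f"range object: {sys.getsizeof(lazy_rows)} bytes, len == {len(lazy_rows)}")
```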
18 changes: 16 additions & 2 deletions polaris/hub/storage.py
@@ -2,6 +2,7 @@
 from base64 import b64encode
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import contextmanager
+from functools import lru_cache
 from hashlib import md5
 from itertools import islice
 from pathlib import PurePath
@@ -129,6 +130,14 @@ def _multipart_upload(self, key: str, value: bytes) -> None:
             Bucket=self.bucket_name, Key=full_key, UploadId=upload_id, MultipartUpload={"Parts": parts}
         )
 
+    @lru_cache()
+    def _get_object_body(self, full_key: str) -> bytes:
+        """
+        Basic caching for the object body, to avoid multiple reads on remote bucket.
+        """
+        response = self.s3_client.get_object(Bucket=self.bucket_name, Key=full_key)
+        return response["Body"].read()
+
     ## Custom methods
 
     def copy_to_destination(
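The new helper memoizes whole-object reads with `functools.lru_cache` (default `maxsize` is 128 entries), keyed on `(self, full_key)`, so repeated reads of the same key are served from process memory instead of another round trip to the bucket. A stripped-down sketch of the pattern, with a counter standing in for the real S3 call (class and key names are placeholders, not the repository's API):

```python
from functools import lru_cache


class CachingStore:
    """Simplified sketch of the caching pattern; not the real S3Store."""

    def __init__(self, bucket_name: str):
        self.bucket_name = bucket_name
        self.fetch_count = 0  # stands in for actual S3 round trips

    @lru_cache()  # keys on (self, full_key); default maxsize is 128 entries
    def _get_object_body(self, full_key: str) -> bytes:
        self.fetch_count += 1
        return f"object at s3://{self.bucket_name}/{full_key}".encode()

    def __getitem__(self, key: str) -> bytes:
        return self._get_object_body(full_key=key)

    def __hash__(self):
        # lru_cache hashes every argument, including self, so the store must
        # be hashable (the commit adds __hash__ for this reason further down).
        return hash(self.bucket_name)


store = CachingStore("my-bucket")
for key in ("data/0.0", "data/0.0", "data/0.1"):
    store[key]
print(store.fetch_count)  # 2: the repeated key was served from the cache
```

Because the cache holds raw object bytes, memory use scales with entry count times object size; `maxsize` bounds the number of entries, not the total bytes retained.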
@@ -351,8 +360,7 @@ def __getitem__(self, key: str) -> bytes:
         with handle_s3_errors():
             try:
                 full_key = self._full_key(key)
-                response = self.s3_client.get_object(Bucket=self.bucket_name, Key=full_key)
-                return response["Body"].read()
+                return self._get_object_body(full_key=full_key)
             except self.s3_client.exceptions.NoSuchKey:
                 raise KeyError(key)
 
@@ -430,6 +438,12 @@ def __len__(self) -> int:
 
         return sum((page["KeyCount"] for page in page_iterator))
 
+    def __hash__(self):
+        """
+        Custom hash function, to enable lru_cache decorator on methods
+        """
+        return hash((self.bucket_name, self.prefix, self.s3_client))
+
 
 class StorageTokenAuth:
     token: HubStorageOAuth2Token | None
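`lru_cache` hashes every argument of a memoized call, including `self`, so the store instance itself must be hashable. Mapping-style stores typically inherit `__eq__` from `collections.abc.Mapping`, which sets `__hash__` to `None`; without the explicit `__hash__` above, every cached read would raise a `TypeError`. A simplified illustration of that constraint (not the real class hierarchy):

```python
from collections.abc import MutableMapping
from functools import lru_cache


class UnhashableStore(MutableMapping):
    """Mapping subclasses inherit __eq__, which sets __hash__ to None."""

    def __init__(self):
        self._data = {}

    # Minimal MutableMapping plumbing so the class can be instantiated.
    def __getitem__(self, key): return self._data[key]
    def __setitem__(self, key, value): self._data[key] = value
    def __delitem__(self, key): del self._data[key]
    def __iter__(self): return iter(self._data)
    def __len__(self): return len(self._data)

    @lru_cache()
    def cached_read(self, key: str) -> str:
        return f"value for {key}"


class HashableStore(UnhashableStore):
    def __hash__(self):
        # Restoring hashability, as the commit does, makes lru_cache usable again.
        return hash(id(self))


try:
    UnhashableStore().cached_read("a")
except TypeError as err:
    print(err)  # unhashable type: 'UnhashableStore'

print(HashableStore().cached_read("a"))  # works: self is now a valid cache key
```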
