Skip to content

Commit

Permalink
feat: skip dataset re-download and ensure safe dataset syncing (#220)
Browse files Browse the repository at this point in the history
  • Loading branch information
JSabadin authored Jan 8, 2025
1 parent 1cbc795 commit 11ee765
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 22 deletions.
2 changes: 2 additions & 0 deletions luxonis_ml/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
LuxonisComponent,
LuxonisDataset,
LuxonisSource,
UpdateMode,
)
from .loaders import LOADERS_REGISTRY, BaseLoader, LuxonisLoader
from .parsers import LuxonisParser
Expand Down Expand Up @@ -46,6 +47,7 @@ def load_loader_plugins() -> None: # pragma: no cover
"ImageType",
"LuxonisComponent",
"LuxonisDataset",
"UpdateMode",
"LuxonisLoader",
"LuxonisParser",
"LuxonisSource",
Expand Down
3 changes: 2 additions & 1 deletion luxonis_ml/data/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
load_annotation,
)
from .base_dataset import DATASETS_REGISTRY, BaseDataset, DatasetIterator
from .luxonis_dataset import LuxonisDataset
from .luxonis_dataset import LuxonisDataset, UpdateMode
from .source import LuxonisComponent, LuxonisSource

__all__ = [
Expand All @@ -25,4 +25,5 @@
"load_annotation",
"Detection",
"ArrayAnnotation",
"UpdateMode",
]
79 changes: 66 additions & 13 deletions luxonis_ml/data/datasets/luxonis_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import numpy as np
import polars as pl
import pyarrow.parquet as pq
from filelock import FileLock
from ordered_set import OrderedSet
from semver.version import Version
from typing_extensions import Self, override
Expand All @@ -34,6 +35,7 @@
BucketStorage,
BucketType,
ParquetFileManager,
UpdateMode,
infer_task,
warn_on_duplicates,
)
Expand Down Expand Up @@ -133,7 +135,13 @@ def __init__(
else:
self.fs = LuxonisFileSystem(self.path)

self.metadata = cast(Metadata, defaultdict(dict, self._get_metadata()))
_lock_metadata = self.base_path / ".metadata.lock"
with FileLock(
str(_lock_metadata)
): # DDP GCS training - multiple processes
self.metadata = cast(
Metadata, defaultdict(dict, self._get_metadata())
)

if self.version != LDF_VERSION:
logger.warning(
Expand Down Expand Up @@ -226,9 +234,18 @@ def _load_df_offline(
def _load_df_offline(
self, lazy: bool = False
) -> Optional[Union[pl.DataFrame, pl.LazyFrame]]:
path = get_dir(self.fs, "annotations", self.local_path)
"""Loads the dataset DataFrame **always** from the local
storage."""
path = (
self.base_path
/ "data"
/ self.team_id
/ "datasets"
/ self.dataset_name
/ "annotations"
)

if path is None or not path.exists():
if not path.exists():
return None

if lazy:
Expand Down Expand Up @@ -278,6 +295,11 @@ def _get_file_index(
def _get_file_index(
self, lazy: bool = False
) -> Optional[Union[pl.DataFrame, pl.LazyFrame]]:
"""Loads the file index DataFrame from the local storage or the
cloud.
If loading from the cloud, it always downloads the file before loading.
"""
path = get_file(
self.fs, "metadata/file_index.parquet", self.metadata_path
)
Expand Down Expand Up @@ -327,6 +349,11 @@ def _init_credentials(self) -> Dict[str, Any]:
return {}

def _get_metadata(self) -> Metadata:
"""Loads metadata from local storage or cloud, depending on the
BucketStorage type.
If loading from the cloud, it always downloads the file before loading.
"""
if self.fs.exists("metadata/metadata.json"):
path = get_file(
self.fs,
Expand Down Expand Up @@ -418,22 +445,48 @@ def get_skeletons(
def get_tasks(self) -> List[str]:
return self.metadata.get("tasks", [])

def sync_from_cloud(
    self, update_mode: UpdateMode = UpdateMode.IF_EMPTY
) -> None:
    """Synchronizes the dataset from a remote cloud bucket to the
    local directory.

    The download is performed only when the local copy is empty, or
    unconditionally, depending on the provided C{update_mode}.

    @type update_mode: UpdateMode
    @param update_mode: Specifies the update behavior.
        - UpdateMode.IF_EMPTY: Downloads data only if the local dataset is empty.
        - UpdateMode.ALWAYS: Always downloads and overwrites the local dataset.
    """
    if not self.is_remote:
        logger.warning("This is a local dataset! Cannot sync from cloud.")
        return

    local_dir = self.base_path / "data" / self.team_id / "datasets"
    local_dir.mkdir(exist_ok=True, parents=True)

    lock_path = local_dir / ".sync.lock"

    # DDP GCS training - multiple processes may sync concurrently, so the
    # emptiness check and the download are serialized under a file lock.
    with FileLock(str(lock_path)):
        dataset_dir = local_dir / self.dataset_name
        # A missing dataset directory, a directory without any subfolders,
        # or any empty subfolder all count as "empty", so a fresh machine
        # still triggers the initial download under IF_EMPTY. (Calling
        # iterdir() on a non-existent path would raise FileNotFoundError.)
        if dataset_dir.exists():
            subfolders = [d for d in dataset_dir.iterdir() if d.is_dir()]
            local_is_empty = not subfolders or any(
                not any(d.iterdir()) for d in subfolders
            )
        else:
            local_is_empty = True

        if update_mode == UpdateMode.IF_EMPTY and not local_is_empty:
            logger.info(
                "Local dataset directory already exists. Skipping download."
            )
            return
        if update_mode == UpdateMode.ALWAYS or not self._is_synced:
            logger.info("Syncing from cloud...")
            self.fs.get_dir(remote_paths="", local_dir=local_dir)
            self._is_synced = True
        else:
            logger.warning(
                "Already synced. Use update_mode=ALWAYS to resync."
            )

@override
def delete_dataset(self, *, delete_remote: bool = False) -> None:
Expand Down
12 changes: 7 additions & 5 deletions luxonis_ml/data/loaders/luxonis_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from luxonis_ml.data.datasets import (
Annotation,
LuxonisDataset,
UpdateMode,
load_annotation,
)
from luxonis_ml.data.loaders.base_loader import BaseLoader
Expand Down Expand Up @@ -48,7 +49,7 @@ def __init__(
keep_aspect_ratio: bool = True,
out_image_format: Literal["RGB", "BGR"] = "RGB",
*,
force_resync: bool = False,
update_mode: UpdateMode = UpdateMode.ALWAYS,
) -> None:
"""A loader class used for loading data from L{LuxonisDataset}.
Expand Down Expand Up @@ -84,9 +85,10 @@ def __init__(
@type width: Optional[int]
@param width: The width of the output images. Defaults to
C{None}.
@type force_resync: bool
@param force_resync: Flag to force resync from cloud. Defaults
to C{False}.
@type update_mode: UpdateMode
@param update_mode: Enum that determines the sync mode:
- UpdateMode.ALWAYS: Force a fresh download
- UpdateMode.IF_EMPTY: Skip downloading if local data exists
"""

self.logger = logging.getLogger(__name__)
Expand All @@ -96,7 +98,7 @@ def __init__(
self.sync_mode = self.dataset.is_remote

if self.sync_mode:
self.dataset.sync_from_cloud(force=force_resync)
self.dataset.sync_from_cloud(update_mode=update_mode)

if isinstance(view, str):
view = [view]
Expand Down
3 changes: 2 additions & 1 deletion luxonis_ml/data/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .data_utils import infer_task, rgb_to_bool_masks, warn_on_duplicates
from .enums import BucketStorage, BucketType, ImageType, MediaType
from .enums import BucketStorage, BucketType, ImageType, MediaType, UpdateMode
from .parquet import ParquetDetection, ParquetFileManager, ParquetRecord
from .task_utils import (
get_task_name,
Expand All @@ -24,6 +24,7 @@
"ImageType",
"BucketType",
"BucketStorage",
"UpdateMode",
"get_task_name",
"task_type_iterator",
"task_is_metadata",
Expand Down
7 changes: 7 additions & 0 deletions luxonis_ml/data/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,10 @@ class BucketStorage(Enum):
S3 = "s3"
GCS = "gcs"
AZURE_BLOB = "azure"


class UpdateMode(Enum):
    """Controls when a dataset sync re-downloads data from the cloud."""

    # Unconditionally download and overwrite the local copy.
    ALWAYS = "always"
    # Download only when the local dataset has no data yet.
    IF_EMPTY = "if_empty"
2 changes: 2 additions & 0 deletions luxonis_ml/tracker/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
import os
import time
from functools import wraps
from importlib.util import find_spec
from pathlib import Path
Expand Down Expand Up @@ -139,6 +140,7 @@ def __init__(
if rank == 0:
self.run_name = self._get_run_name()
else:
time.sleep(1) # DDP hotfix
self.run_name = self._get_latest_run_name()

Path(f"{self.save_directory}/{self.run_name}").mkdir(
Expand Down
9 changes: 7 additions & 2 deletions tests/test_data/test_task_ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
import numpy as np
import pytest

from luxonis_ml.data import BucketStorage, LuxonisDataset, LuxonisLoader
from luxonis_ml.data import (
BucketStorage,
LuxonisDataset,
LuxonisLoader,
UpdateMode,
)
from luxonis_ml.data.utils import get_task_name, get_task_type

DATA_DIR = Path("tests/data/test_task_ingestion")
Expand Down Expand Up @@ -36,7 +41,7 @@ def make_image(i) -> Path:

def compute_histogram(dataset: LuxonisDataset) -> Dict[str, int]:
classes = defaultdict(int)
loader = LuxonisLoader(dataset, force_resync=True)
loader = LuxonisLoader(dataset, update_mode=UpdateMode.ALWAYS)
for _, record in loader:
for task, _ in record.items():
if get_task_type(task) != "classification":
Expand Down

0 comments on commit 11ee765

Please sign in to comment.