Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: skip dataset re-download and ensure safe dataset syncing #220

Merged
merged 10 commits on
Jan 8, 2025
2 changes: 2 additions & 0 deletions luxonis_ml/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
LuxonisComponent,
LuxonisDataset,
LuxonisSource,
UpdateMode,
)
from .loaders import LOADERS_REGISTRY, BaseLoader, LuxonisLoader
from .parsers import LuxonisParser
Expand Down Expand Up @@ -46,6 +47,7 @@ def load_loader_plugins() -> None: # pragma: no cover
"ImageType",
"LuxonisComponent",
"LuxonisDataset",
"UpdateMode",
"LuxonisLoader",
"LuxonisParser",
"LuxonisSource",
Expand Down
3 changes: 2 additions & 1 deletion luxonis_ml/data/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
load_annotation,
)
from .base_dataset import DATASETS_REGISTRY, BaseDataset, DatasetIterator
from .luxonis_dataset import LuxonisDataset
from .luxonis_dataset import LuxonisDataset, UpdateMode
from .source import LuxonisComponent, LuxonisSource

__all__ = [
Expand All @@ -25,4 +25,5 @@
"load_annotation",
"Detection",
"ArrayAnnotation",
"UpdateMode",
]
53 changes: 42 additions & 11 deletions luxonis_ml/data/datasets/luxonis_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tempfile
from collections import defaultdict
from contextlib import suppress
from enum import Enum
from functools import cached_property
from pathlib import Path
from typing import (
Expand All @@ -26,6 +27,7 @@
import numpy as np
import polars as pl
import pyarrow.parquet as pq
from filelock import FileLock
from ordered_set import OrderedSet
from semver.version import Version
from typing_extensions import Self, override
Expand Down Expand Up @@ -67,6 +69,11 @@
skeletons: Dict[str, Skeletons]


class UpdateMode(Enum):
    """Selects when a remote dataset is downloaded to local storage.

    Passed as C{update_mode} to C{sync_from_cloud}.
    """

    # Download from the cloud regardless of what is already present locally.
    ALWAYS = "always"
    # Download only when the local copy is considered empty.
    IF_EMPTY = "if_empty"


class LuxonisDataset(BaseDataset):
def __init__(
self,
Expand Down Expand Up @@ -226,9 +233,16 @@
def _load_df_offline(
self, lazy: bool = False
) -> Optional[Union[pl.DataFrame, pl.LazyFrame]]:
path = get_dir(self.fs, "annotations", self.local_path)
path = (
self.base_path
/ "data"
/ self.team_id
/ "datasets"
/ self.dataset_name
/ "annotations"
)

if path is None or not path.exists():
if not path.exists():
return None

if lazy:
Expand Down Expand Up @@ -418,22 +432,39 @@
def get_tasks(self) -> List[str]:
return self.metadata.get("tasks", [])

def sync_from_cloud(
    self, update_mode: UpdateMode = UpdateMode.IF_EMPTY
) -> None:
    """Downloads data from a remote cloud bucket.

    The emptiness check and the download both run while holding a
    file lock (C{.sync.lock} inside the local datasets directory).

    @type update_mode: UpdateMode
    @param update_mode: Controls when the download happens:
        - C{UpdateMode.ALWAYS}: download even if local data exists.
        - C{UpdateMode.IF_EMPTY}: skip the download unless the local
          dataset directory is missing, has no entries, or contains
          an empty subfolder.
      Defaults to C{UpdateMode.IF_EMPTY}.
    """

    if not self.is_remote:
        logger.warning("This is a local dataset! Cannot sync from cloud.")
        return

    local_dir = self.base_path / "data" / self.team_id / "datasets"
    local_dir.mkdir(exist_ok=True, parents=True)

    dataset_dir = local_dir / self.dataset_name
    lock_path = local_dir / ".sync.lock"

    with FileLock(str(lock_path)):
        # A missing or entirely empty dataset directory has nothing to
        # reuse. Iterating a missing directory would raise
        # FileNotFoundError, and an empty one would otherwise be
        # reported as "already exists" and never downloaded.
        if dataset_dir.is_dir():
            entries = list(dataset_dir.iterdir())
            any_subfolder_empty = not entries or any(
                entry.is_dir() and not any(entry.iterdir())
                for entry in entries
            )
        else:
            any_subfolder_empty = True

        if update_mode == UpdateMode.IF_EMPTY and not any_subfolder_empty:
            logger.info(
                "Local dataset directory already exists. Skipping download."
            )
            return

        if update_mode == UpdateMode.ALWAYS or not self._is_synced:
            logger.info("Syncing from cloud...")
            self.fs.get_dir(remote_paths="", local_dir=local_dir)
            self._is_synced = True
        else:
            logger.warning(
                "Already synced. Use update_mode=ALWAYS to resync."
            )

@override
def delete_dataset(self, *, delete_remote: bool = False) -> None:
Expand Down
11 changes: 6 additions & 5 deletions luxonis_ml/data/loaders/luxonis_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from luxonis_ml.data.datasets import (
Annotation,
LuxonisDataset,
UpdateMode,
load_annotation,
)
from luxonis_ml.data.loaders.base_loader import BaseLoader
Expand Down Expand Up @@ -48,7 +49,7 @@ def __init__(
keep_aspect_ratio: bool = True,
out_image_format: Literal["RGB", "BGR"] = "RGB",
*,
force_resync: bool = False,
update_mode: UpdateMode = UpdateMode.ALWAYS,
) -> None:
"""A loader class used for loading data from L{LuxonisDataset}.

Expand Down Expand Up @@ -84,9 +85,9 @@ def __init__(
@type width: Optional[int]
@param width: The width of the output images. Defaults to
C{None}.
@type force_resync: bool
@param force_resync: Flag to force resync from cloud. Defaults
to C{False}.
@param update_mode: Enum that determines the sync mode:
JSabadin marked this conversation as resolved.
Show resolved Hide resolved
- UpdateMode.ALWAYS: Force a fresh download
- UpdateMode.IF_EMPTY: Skip downloading if local data exists
"""

self.logger = logging.getLogger(__name__)
Expand All @@ -96,7 +97,7 @@ def __init__(
self.sync_mode = self.dataset.is_remote

if self.sync_mode:
self.dataset.sync_from_cloud(force=force_resync)
self.dataset.sync_from_cloud(update_mode=update_mode)

if isinstance(view, str):
view = [view]
Expand Down
2 changes: 2 additions & 0 deletions luxonis_ml/tracker/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
import os
import time
from functools import wraps
from importlib.util import find_spec
from pathlib import Path
Expand Down Expand Up @@ -139,6 +140,7 @@ def __init__(
if rank == 0:
self.run_name = self._get_run_name()
else:
time.sleep(1) # DDP hotfix
self.run_name = self._get_latest_run_name()

Path(f"{self.save_directory}/{self.run_name}").mkdir(
Expand Down
9 changes: 7 additions & 2 deletions tests/test_data/test_task_ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
import numpy as np
import pytest

from luxonis_ml.data import BucketStorage, LuxonisDataset, LuxonisLoader
from luxonis_ml.data import (
BucketStorage,
LuxonisDataset,
LuxonisLoader,
UpdateMode,
)
from luxonis_ml.data.utils import get_task_name, get_task_type

DATA_DIR = Path("tests/data/test_task_ingestion")
Expand Down Expand Up @@ -36,7 +41,7 @@ def make_image(i) -> Path:

def compute_histogram(dataset: LuxonisDataset) -> Dict[str, int]:
classes = defaultdict(int)
loader = LuxonisLoader(dataset, force_resync=True)
loader = LuxonisLoader(dataset, update_mode=UpdateMode.ALWAYS)
for _, record in loader:
for task, _ in record.items():
if get_task_type(task) != "classification":
Expand Down
Loading