Skip to content

Commit

Permalink
refactor sync_from_cloud for improved clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
JSabadin committed Jan 7, 2025
1 parent f86d8d5 commit f92881d
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 22 deletions.
3 changes: 2 additions & 1 deletion luxonis_ml/data/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
load_annotation,
)
from .base_dataset import DATASETS_REGISTRY, BaseDataset, DatasetIterator
from .luxonis_dataset import LuxonisDataset
from .luxonis_dataset import LuxonisDataset, UpdateMode
from .source import LuxonisComponent, LuxonisSource

__all__ = [
Expand All @@ -25,4 +25,5 @@
"load_annotation",
"Detection",
"ArrayAnnotation",
"UpdateMode",
]
25 changes: 15 additions & 10 deletions luxonis_ml/data/datasets/luxonis_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tempfile
from collections import defaultdict
from contextlib import suppress
from enum import Enum
from functools import cached_property
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -68,6 +69,11 @@ class Metadata(TypedDict):
skeletons: Dict[str, Skeletons]


class UpdateMode(Enum):
    """Selects when ``LuxonisDataset.sync_from_cloud`` downloads remote data.

    The value is compared against the members below inside
    ``sync_from_cloud``; the string values are the serialized form of
    each mode.
    """

    # Sync from the cloud even if the dataset is already marked as synced.
    ALWAYS = "always"
    # Skip the download when the local dataset directory already exists
    # and none of its subfolders is empty.
    IF_EMPTY = "if_empty"


class LuxonisDataset(BaseDataset):
def __init__(
self,
Expand Down Expand Up @@ -427,11 +433,12 @@ def get_tasks(self) -> List[str]:
return self.metadata.get("tasks", [])

def sync_from_cloud(
self, force: bool = False, skip_redownload_dataset: bool = False
self, update_mode: UpdateMode = UpdateMode.IF_EMPTY
) -> None:
"""Downloads data from a remote cloud bucket."""

if not self.is_remote:
logger.warning("This is a local dataset! Cannot sync.")
logger.warning("This is a local dataset! Cannot sync from cloud.")
return

local_dir = self.base_path / "data" / self.team_id / "datasets"
Expand All @@ -443,23 +450,21 @@ def sync_from_cloud(
any_subfolder_empty = any(
subfolder.is_dir() and not any(subfolder.iterdir())
for subfolder in (local_dir / self.dataset_name).iterdir()
if subfolder.is_dir()
)
if (
not any_subfolder_empty
and skip_redownload_dataset
and not force
):
if update_mode == UpdateMode.IF_EMPTY and not any_subfolder_empty:
logger.info(
"Local dataset directory already exists. Skipping download."
)
return

if not self._is_synced or force:
if update_mode == UpdateMode.ALWAYS or not self._is_synced:
logger.info("Syncing from cloud...")
self.fs.get_dir(remote_paths="", local_dir=local_dir)
self._is_synced = True
else:
logger.warning("Already synced. Use force=True to resync.")
logger.warning(
"Already synced. Use update_mode=ALWAYS to resync."
)

@override
def delete_dataset(self, *, delete_remote: bool = False) -> None:
Expand Down
17 changes: 6 additions & 11 deletions luxonis_ml/data/loaders/luxonis_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from luxonis_ml.data.datasets import (
Annotation,
LuxonisDataset,
UpdateMode,
load_annotation,
)
from luxonis_ml.data.loaders.base_loader import BaseLoader
Expand Down Expand Up @@ -48,8 +49,7 @@ def __init__(
keep_aspect_ratio: bool = True,
out_image_format: Literal["RGB", "BGR"] = "RGB",
*,
force_resync: bool = False,
skip_redownload_dataset: bool = False,
update_mode: UpdateMode = UpdateMode.ALWAYS,
) -> None:
"""A loader class used for loading data from L{LuxonisDataset}.
Expand Down Expand Up @@ -85,11 +85,9 @@ def __init__(
@type width: Optional[int]
@param width: The width of the output images. Defaults to
C{None}.
@type force_resync: bool
@param force_resync: Flag to force resync from cloud. Defaults
to C{False}.
@param skip_redownload_dataset: If True, skip downloading when local dataset
already exists. If False, force redownload (unless force_resync is True).
@param update_mode: Enum that determines the sync mode:
- UpdateMode.ALWAYS: Force a fresh download
- UpdateMode.IF_EMPTY: Skip downloading if local data exists
"""

self.logger = logging.getLogger(__name__)
Expand All @@ -99,10 +97,7 @@ def __init__(
self.sync_mode = self.dataset.is_remote

if self.sync_mode:
self.dataset.sync_from_cloud(
force=force_resync,
skip_redownload_dataset=skip_redownload_dataset,
)
self.dataset.sync_from_cloud(update_mode=update_mode)

if isinstance(view, str):
view = [view]
Expand Down

0 comments on commit f92881d

Please sign in to comment.