Skip to content

Commit

Permalink
feat: skip dataset re-download
Browse files Browse the repository at this point in the history
  • Loading branch information
JSabadin committed Jan 6, 2025
1 parent 1cbc795 commit a414b7e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
13 changes: 10 additions & 3 deletions luxonis_ml/data/datasets/luxonis_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,15 +418,22 @@ def get_skeletons(
def get_tasks(self) -> List[str]:
    """Returns the task names recorded in the dataset metadata.

    @rtype: List[str]
    @return: The value stored under the C{"tasks"} metadata key, or an
        empty list when that key is absent.
    """
    task_names = self.metadata.get("tasks", [])
    return task_names

def sync_from_cloud(self, force: bool = False) -> None:
def sync_from_cloud(
self, force: bool = False, skip_redownload_dataset: bool = False
) -> None:
"""Downloads data from a remote cloud bucket."""

if not self.is_remote:
logger.warning("This is a local dataset! Cannot sync")
else:
local_dir = self.base_path / "data" / self.team_id / "datasets"
if local_dir.exists() and skip_redownload_dataset and not force:
logger.info(
"Local dataset directory already exists. Skipping download."
)
return

Check warning on line 433 in luxonis_ml/data/datasets/luxonis_dataset.py

View check run for this annotation

Codecov / codecov/patch

luxonis_ml/data/datasets/luxonis_dataset.py#L433

Added line #L433 was not covered by tests

if not self._is_synced or force:
logger.info("Syncing from cloud...")
local_dir = self.base_path / "data" / self.team_id / "datasets"
local_dir.mkdir(exist_ok=True, parents=True)

self.fs.get_dir(remote_paths="", local_dir=local_dir)
Expand Down
8 changes: 7 additions & 1 deletion luxonis_ml/data/loaders/luxonis_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
out_image_format: Literal["RGB", "BGR"] = "RGB",
*,
force_resync: bool = False,
skip_redownload_dataset: bool = False,
) -> None:
"""A loader class used for loading data from L{LuxonisDataset}.
Expand Down Expand Up @@ -87,6 +88,8 @@ def __init__(
@type force_resync: bool
@param force_resync: Flag to force resync from cloud. Defaults
to C{False}.
@type skip_redownload_dataset: bool
@param skip_redownload_dataset: If C{True}, skip the download when the local
dataset directory already exists, unless C{force_resync} is C{True}. If
C{False}, the regular sync check applies. Defaults to C{False}.
"""

self.logger = logging.getLogger(__name__)
Expand All @@ -96,7 +99,10 @@ def __init__(
self.sync_mode = self.dataset.is_remote

if self.sync_mode:
self.dataset.sync_from_cloud(force=force_resync)
self.dataset.sync_from_cloud(
force=force_resync,
skip_redownload_dataset=skip_redownload_dataset,
)

if isinstance(view, str):
view = [view]
Expand Down

0 comments on commit a414b7e

Please sign in to comment.