faster merge_with, fixed tests
JSabadin committed Jan 9, 2025
1 parent f975549 commit bbde0d2
Showing 2 changed files with 35 additions and 81 deletions.
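
For orientation, here is a minimal usage sketch of the two merge modes this commit reworks. The dataset names and the new_dataset_name value are illustrative, the import path is assumed to be the usual one for luxonis-ml, and both datasets are assumed to already contain data and splits (the updated test below shows how they are populated):

    from luxonis_ml.data import LuxonisDataset

    # Assumed to exist already, each populated with images, annotations and splits.
    dataset1 = LuxonisDataset("test_merge_1")
    dataset2 = LuxonisDataset("test_merge_2")

    # In-place merge: dataset2's records are folded into dataset1 itself.
    merged_inplace = dataset1.merge_with(dataset2, inplace=True)

    # Out-of-place merge: dataset1 is cloned first and dataset2 is merged into
    # the clone, leaving dataset1 untouched.
    merged_copy = dataset1.merge_with(
        dataset2,
        inplace=False,
        new_dataset_name="test_merge_1_test_merge_2_merged",
    )
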
52 changes: 28 additions & 24 deletions luxonis_ml/data/datasets/luxonis_dataset.py
@@ -289,7 +289,8 @@ def clone(
self, new_dataset_name: str, push_to_cloud: bool = True
) -> "LuxonisDataset":
"""Create a new LuxonisDataset that is a local copy of the
current dataset.
current dataset. Cloned dataset will overwrite the existing
dataset with the same name.
@type new_dataset_name: str
@param new_dataset_name: Name of the newly created dataset.
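
A small sketch of the overwrite caveat the added docstring lines describe; the names are illustrative and the import path is assumed:

    from luxonis_ml.data import LuxonisDataset

    dataset1 = LuxonisDataset("test_merge_1")

    # Cloning under a name that is already taken replaces that dataset rather
    # than raising, per the docstring addition above.
    first_clone = dataset1.clone("test_merge_1_cloned", push_to_cloud=False)
    second_clone = dataset1.clone("test_merge_1_cloned", push_to_cloud=False)  # replaces the first clone
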
@@ -358,64 +359,67 @@ def merge_with(
@param new_dataset_name: The name of the new dataset to create
if inplace is False.
"""
if not inplace:
if not new_dataset_name:
raise ValueError(
"You must specify a name for the new dataset when inplace is False."
)
new_dataset = self.clone(new_dataset_name, push_to_cloud=False)
new_dataset.merge_with(other, inplace=True)
return new_dataset

if other.is_remote != self.is_remote:
if not inplace and not new_dataset_name:
raise ValueError(

"Merging is only supported for datasets with the same bucket storage."
"You must specify a name for the new dataset when inplace is False."
)

target_dataset = (
self
if inplace
else self.clone(new_dataset_name, push_to_cloud=False)
)

if self.is_remote:
other.sync_from_cloud(update_mode=UpdateMode.ALWAYS)
self.sync_from_cloud(update_mode=UpdateMode.IF_EMPTY)
self.sync_from_cloud(
update_mode=UpdateMode.ALWAYS
if inplace
else UpdateMode.IF_EMPTY
)

df_self = self._load_df_offline()
df_other = other._load_df_offline()
duplicate_uuids = set(df_self["uuid"]).intersection(df_other["uuid"])
if duplicate_uuids: # skip duplicate uuids
if duplicate_uuids:
df_other = df_other.filter(

~df_other["uuid"].is_in(duplicate_uuids)
)

df_merged = pl.concat([df_self, df_other])
target_dataset._save_df_offline(df_merged)

file_index_self = self._get_file_index()
file_index_other = other._get_file_index()
file_index_duplicates = set(file_index_self["uuid"]).intersection(
file_index_other["uuid"]
)
if file_index_duplicates: # skip duplicate uuids
if file_index_duplicates:
file_index_other = file_index_other.filter(

~file_index_other["uuid"].is_in(file_index_duplicates)
)
merged_file_index = pl.concat([file_index_self, file_index_other])

merged_file_index = pl.concat([file_index_self, file_index_other])
if merged_file_index is not None:
file_index_path = self.metadata_path / "file_index.parquet"
file_index_path = (
target_dataset.metadata_path / "file_index.parquet"
)
merged_file_index.write_parquet(file_index_path)

splits_self = self._load_splits(self.metadata_path)
splits_other = self._load_splits(other.metadata_path)
self._merge_splits(splits_self, splits_other)

self._save_df_offline(df_merged)
self._save_splits(splits_self)
target_dataset._save_splits(splits_self)

if self.is_remote:
shutil.copytree(
other.media_path, self.media_path, dirs_exist_ok=True
other.media_path, target_dataset.media_path, dirs_exist_ok=True
)
self.sync_to_cloud()
target_dataset.sync_to_cloud()

self._merge_metadata_with(other)
target_dataset._merge_metadata_with(other)

return self
return target_dataset

def _load_splits(self, path: Path) -> Dict[str, List[str]]:
splits_path = path / "splits.json"
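
The duplicate handling in the hunk above skips rows whose uuid is already present using a Polars boolean mask before concatenating. A self-contained sketch of the idiom with toy data (column names mirror the real frames, the values are invented):

    import polars as pl

    df_self = pl.DataFrame({"uuid": ["a", "b"], "file": ["img0.png", "img1.png"]})
    df_other = pl.DataFrame({"uuid": ["b", "c"], "file": ["img1.png", "img2.png"]})

    # Drop rows of df_other whose uuid already exists in df_self, then concatenate,
    # so each uuid appears exactly once in the merged frame.
    duplicate_uuids = set(df_self["uuid"]).intersection(df_other["uuid"])
    if duplicate_uuids:
        df_other = df_other.filter(~df_other["uuid"].is_in(list(duplicate_uuids)))

    df_merged = pl.concat([df_self, df_other])
    print(df_merged["uuid"].to_list())  # ['a', 'b', 'c']
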
64 changes: 7 additions & 57 deletions tests/test_data/test_dataset.py
@@ -527,7 +527,7 @@ def generator1():
assert df_cloned.equals(df_original)


def test_merge_datasets_inplace(tempdir: Path, bucket_storage: BucketStorage):
def test_merge_datasets(tempdir: Path, bucket_storage: BucketStorage):
dataset1_name = "test_merge_1"
dataset1 = LuxonisDataset(
dataset1_name,
@@ -572,76 +572,26 @@ def generator2():
dataset2.add(generator2())
dataset2.make_splits({"train": 0.6, "val": 0.4})

# Test in-place merge
cloned_dataset1 = dataset1.clone(
new_dataset_name=dataset1_name + "_cloned"
)
cloned_dataset1_merged_with_dataset2 = cloned_dataset1.merge_with(
dataset2, inplace=True
)

all_classes, _ = cloned_dataset1_merged_with_dataset2.get_classes()
assert set(all_classes) == {"person", "dog"}

df_cloned_merged = cloned_dataset1_merged_with_dataset2._load_df_offline()
df_merged = dataset1.merge_with(dataset2, inplace=False)._load_df_offline()
assert df_cloned_merged.equals(df_merged)


def test_merge_datasets_out_of_place(
tempdir: Path, bucket_storage: BucketStorage
):
dataset1_name = "test_merge_1"
dataset1 = LuxonisDataset(
dataset1_name,
bucket_storage=bucket_storage,
delete_existing=True,
delete_remote=True,
)

def generator1():
for i in range(3):
img = create_image(i, tempdir)
yield {
"file": img,
"annotation": {
"class": "person",
"boundingbox": {"x": 0.1, "y": 0.1, "w": 0.1, "h": 0.1},
},
}

dataset1.add(generator1())
dataset1.make_splits({"train": 0.6, "val": 0.4})

dataset2_name = "test_merge_2"
dataset2 = LuxonisDataset(
dataset2_name,
bucket_storage=bucket_storage,
delete_existing=True,
delete_remote=True,
)

def generator2():
for i in range(3, 6):
img = create_image(i, tempdir)
yield {
"file": img,
"annotation": {
"class": "dog",
"boundingbox": {"x": 0.2, "y": 0.2, "w": 0.2, "h": 0.2},
},
}

dataset2.add(generator2())
dataset2.make_splits({"train": 0.6, "val": 0.4})
all_classes_inplace, _ = cloned_dataset1_merged_with_dataset2.get_classes()
assert set(all_classes_inplace) == {"person", "dog"}

# Test out-of-place merge
dataset1_merged_with_dataset2 = dataset1.merge_with(
dataset2,
inplace=False,
new_dataset_name=dataset1_name + "_" + dataset2_name + "_merged",
)

all_classes, _ = dataset1_merged_with_dataset2.get_classes()
assert set(all_classes) == {"person", "dog"}
all_classes_out_of_place, _ = dataset1_merged_with_dataset2.get_classes()
assert set(all_classes_out_of_place) == {"person", "dog"}

df_merged = dataset1_merged_with_dataset2._load_df_offline()
df_cloned_merged = dataset1.merge_with(
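
The tail of the test is truncated above, but the comparison it sets up appears to rely on Polars frame equality between the in-place and out-of-place merge results. A toy sketch of that kind of check (data invented; in the test the frames come from _load_df_offline()):

    import polars as pl

    df_inplace = pl.DataFrame({"uuid": ["a", "b"], "class": ["person", "dog"]})
    df_out_of_place = pl.DataFrame({"uuid": ["a", "b"], "class": ["person", "dog"]})

    # DataFrame.equals compares schema and values, so the assertion holds only if
    # both merge paths produced identical offline data.
    assert df_inplace.equals(df_out_of_place)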
