diff --git a/luxonis_ml/data/datasets/luxonis_dataset.py b/luxonis_ml/data/datasets/luxonis_dataset.py
index 383f602c..17b7917a 100644
--- a/luxonis_ml/data/datasets/luxonis_dataset.py
+++ b/luxonis_ml/data/datasets/luxonis_dataset.py
@@ -289,7 +289,8 @@ def clone(
         self, new_dataset_name: str, push_to_cloud: bool = True
     ) -> "LuxonisDataset":
         """Create a new LuxonisDataset that is a local copy of the
-        current dataset.
+        current dataset. The cloned dataset will overwrite an existing
+        dataset with the same name.
 
         @type new_dataset_name: str
         @param new_dataset_name: Name of the newly created dataset.
@@ -358,64 +359,67 @@ def merge_with(
         @param new_dataset_name: The name of the new dataset to create
             if inplace is False.
         """
-        if not inplace:
-            if not new_dataset_name:
-                raise ValueError(
-                    "You must specify a name for the new dataset when inplace is False."
-                )
-            new_dataset = self.clone(new_dataset_name, push_to_cloud=False)
-            new_dataset.merge_with(other, inplace=True)
-            return new_dataset
-
-        if other.is_remote != self.is_remote:
+        if not inplace and not new_dataset_name:
             raise ValueError(
-                "Merging is only supported for datasets with the same bucket storage."
+                "You must specify a name for the new dataset when inplace is False."
             )
 
+        target_dataset = (
+            self
+            if inplace
+            else self.clone(new_dataset_name, push_to_cloud=False)
+        )
+
         if self.is_remote:
             other.sync_from_cloud(update_mode=UpdateMode.ALWAYS)
-            self.sync_from_cloud(update_mode=UpdateMode.IF_EMPTY)
+            self.sync_from_cloud(
+                update_mode=UpdateMode.ALWAYS
+                if inplace
+                else UpdateMode.IF_EMPTY
+            )
 
         df_self = self._load_df_offline()
         df_other = other._load_df_offline()
         duplicate_uuids = set(df_self["uuid"]).intersection(df_other["uuid"])
-        if duplicate_uuids:  # skip duplicate uuids
+        if duplicate_uuids:
             df_other = df_other.filter(
                 ~df_other["uuid"].is_in(duplicate_uuids)
             )
+
         df_merged = pl.concat([df_self, df_other])
+        target_dataset._save_df_offline(df_merged)
 
         file_index_self = self._get_file_index()
         file_index_other = other._get_file_index()
         file_index_duplicates = set(file_index_self["uuid"]).intersection(
             file_index_other["uuid"]
         )
-        if file_index_duplicates:  # skip duplicate uuids
+        if file_index_duplicates:
             file_index_other = file_index_other.filter(
                 ~file_index_other["uuid"].is_in(file_index_duplicates)
             )
-
         merged_file_index = pl.concat([file_index_self, file_index_other])
 
         if merged_file_index is not None:
-            file_index_path = self.metadata_path / "file_index.parquet"
+            file_index_path = (
+                target_dataset.metadata_path / "file_index.parquet"
+            )
             merged_file_index.write_parquet(file_index_path)
 
         splits_self = self._load_splits(self.metadata_path)
         splits_other = self._load_splits(other.metadata_path)
         self._merge_splits(splits_self, splits_other)
-
-        self._save_df_offline(df_merged)
-        self._save_splits(splits_self)
+        target_dataset._save_splits(splits_self)
 
         if self.is_remote:
             shutil.copytree(
-                other.media_path, self.media_path, dirs_exist_ok=True
+                other.media_path, target_dataset.media_path, dirs_exist_ok=True
             )
-            self.sync_to_cloud()
+            target_dataset.sync_to_cloud()
 
-        self._merge_metadata_with(other)
+        target_dataset._merge_metadata_with(other)
 
-        return self
+        return target_dataset
 
     def _load_splits(self, path: Path) -> Dict[str, List[str]]:
         splits_path = path / "splits.json"
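For context on the duplicate-skipping step above, here is a minimal, self-contained sketch of how the polars calls used in `merge_with` (`pl.concat`, `filter`, `is_in`) combine two annotation frames while keeping only `self`'s copy of any shared `uuid`. The column names and values are illustrative stand-ins, not the dataset's real schema, and the set is converted to a list for `is_in` to stay friendly across polars versions.

```python
import polars as pl

# Toy stand-ins for the frames returned by _load_df_offline();
# only the "uuid" column matters for the dedup logic.
df_self = pl.DataFrame({"uuid": ["a", "b"], "class": ["person", "person"]})
df_other = pl.DataFrame({"uuid": ["b", "c"], "class": ["dog", "dog"]})

# Rows whose uuid already exists in df_self are dropped from df_other,
# so the merged frame keeps self's version of any shared sample.
duplicate_uuids = set(df_self["uuid"]).intersection(df_other["uuid"])
if duplicate_uuids:
    df_other = df_other.filter(~df_other["uuid"].is_in(list(duplicate_uuids)))

df_merged = pl.concat([df_self, df_other])
print(df_merged["uuid"].to_list())  # ['a', 'b', 'c'] -- 'b' kept from df_self only
```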
diff --git a/tests/test_data/test_dataset.py b/tests/test_data/test_dataset.py
index ac4ae519..ada15702 100644
--- a/tests/test_data/test_dataset.py
+++ b/tests/test_data/test_dataset.py
@@ -527,7 +527,7 @@ def generator1():
     assert df_cloned.equals(df_original)
 
 
-def test_merge_datasets_inplace(tempdir: Path, bucket_storage: BucketStorage):
+def test_merge_datasets(tempdir: Path, bucket_storage: BucketStorage):
     dataset1_name = "test_merge_1"
     dataset1 = LuxonisDataset(
         dataset1_name,
         bucket_storage=bucket_storage,
         delete_existing=True,
         delete_remote=True,
     )
@@ -572,6 +572,7 @@ def generator2():
     dataset2.add(generator2())
     dataset2.make_splits({"train": 0.6, "val": 0.4})
 
+    # Test in-place merge
     cloned_dataset1 = dataset1.clone(
         new_dataset_name=dataset1_name + "_cloned"
     )
@@ -579,69 +580,18 @@ def generator2():
         dataset2, inplace=True
     )
 
-    all_classes, _ = cloned_dataset1_merged_with_dataset2.get_classes()
-    assert set(all_classes) == {"person", "dog"}
-
-    df_cloned_merged = cloned_dataset1_merged_with_dataset2._load_df_offline()
-    df_merged = dataset1.merge_with(dataset2, inplace=False)._load_df_offline()
-    assert df_cloned_merged.equals(df_merged)
-
-
-def test_merge_datasets_out_of_place(
-    tempdir: Path, bucket_storage: BucketStorage
-):
-    dataset1_name = "test_merge_1"
-    dataset1 = LuxonisDataset(
-        dataset1_name,
-        bucket_storage=bucket_storage,
-        delete_existing=True,
-        delete_remote=True,
-    )
-
-    def generator1():
-        for i in range(3):
-            img = create_image(i, tempdir)
-            yield {
-                "file": img,
-                "annotation": {
-                    "class": "person",
-                    "boundingbox": {"x": 0.1, "y": 0.1, "w": 0.1, "h": 0.1},
-                },
-            }
-
-    dataset1.add(generator1())
-    dataset1.make_splits({"train": 0.6, "val": 0.4})
-
-    dataset2_name = "test_merge_2"
-    dataset2 = LuxonisDataset(
-        dataset2_name,
-        bucket_storage=bucket_storage,
-        delete_existing=True,
-        delete_remote=True,
-    )
-
-    def generator2():
-        for i in range(3, 6):
-            img = create_image(i, tempdir)
-            yield {
-                "file": img,
-                "annotation": {
-                    "class": "dog",
-                    "boundingbox": {"x": 0.2, "y": 0.2, "w": 0.2, "h": 0.2},
-                },
-            }
-
-    dataset2.add(generator2())
-    dataset2.make_splits({"train": 0.6, "val": 0.4})
+    all_classes_inplace, _ = cloned_dataset1_merged_with_dataset2.get_classes()
+    assert set(all_classes_inplace) == {"person", "dog"}
 
+    # Test out-of-place merge
     dataset1_merged_with_dataset2 = dataset1.merge_with(
         dataset2,
         inplace=False,
         new_dataset_name=dataset1_name + "_" + dataset2_name + "_merged",
    )
 
-    all_classes, _ = dataset1_merged_with_dataset2.get_classes()
-    assert set(all_classes) == {"person", "dog"}
+    all_classes_out_of_place, _ = dataset1_merged_with_dataset2.get_classes()
+    assert set(all_classes_out_of_place) == {"person", "dog"}
 
     df_merged = dataset1_merged_with_dataset2._load_df_offline()
     df_cloned_merged = dataset1.merge_with(
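Taken together, the two diffs route both merge paths through a single `target_dataset`, so the merged test above exercises the same code either way. A rough usage sketch under assumed conditions: local, already-populated datasets, hypothetical dataset names, and the usual `luxonis_ml.data` import path (not shown in this diff).

```python
from luxonis_ml.data import LuxonisDataset

# Hypothetical, already-populated local datasets with splits made.
dataset1 = LuxonisDataset("people_dataset")
dataset2 = LuxonisDataset("dogs_dataset")

# Out-of-place: dataset1 is cloned (without pushing to cloud) and the clone
# absorbs dataset2; dataset1 itself is left untouched.
merged = dataset1.merge_with(
    dataset2,
    inplace=False,
    new_dataset_name="people_dogs_merged",
)
all_classes, _ = merged.get_classes()
print(sorted(all_classes))  # classes from both source datasets

# In-place: dataset1 itself is the merge target and is returned.
dataset1.merge_with(dataset2, inplace=True)
```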