From 767b049c864368825afc9adfd35eac558cba8748 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 14 Jan 2025 11:08:05 -0600 Subject: [PATCH 1/3] improved types --- luxonis_ml/data/augmentations/utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/luxonis_ml/data/augmentations/utils.py b/luxonis_ml/data/augmentations/utils.py index b6532fa4..ea6e356b 100644 --- a/luxonis_ml/data/augmentations/utils.py +++ b/luxonis_ml/data/augmentations/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, TypeVar import numpy as np @@ -66,9 +66,12 @@ def postprocess_keypoints( return keypoints +T = TypeVar("T") + + def yield_batches( - data_batch: List[Dict[str, Any]], batch_size: int -) -> Iterator[Dict[str, List[Any]]]: + data_batch: List[Dict[str, T]], batch_size: int +) -> Iterator[Dict[str, List[T]]]: """Yield batches of data. @type data_batch: List[Dict[str, Any]] From 728f1c9dc29463e8ec485ab331a3c7bdbc509fdd Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 14 Jan 2025 11:08:17 -0600 Subject: [PATCH 2/3] fixed batch augmentation bug --- luxonis_ml/data/augmentations/batch_compose.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/luxonis_ml/data/augmentations/batch_compose.py b/luxonis_ml/data/augmentations/batch_compose.py index bab427d2..b6239c42 100644 --- a/luxonis_ml/data/augmentations/batch_compose.py +++ b/luxonis_ml/data/augmentations/batch_compose.py @@ -47,6 +47,9 @@ def __call__( for batch in yield_batches(data_batch, transform.batch_size): data = transform(**batch) # type: ignore + if isinstance(next(iter(data.values())), list): + data = {key: value[0] for key, value in batch.items()} + data = self.check_data_post_transform(data) new_batch.append(data) data_batch = new_batch From ac1ecabc23213311da937280032912d3fbbff7e1 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 14 Jan 2025 11:08:24 -0600 Subject: [PATCH 3/3] added test case --- tests/test_data/test_augmentations/test_batched.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_data/test_augmentations/test_batched.py b/tests/test_data/test_augmentations/test_batched.py index 15baca4f..56c15446 100644 --- a/tests/test_data/test_augmentations/test_batched.py +++ b/tests/test_data/test_augmentations/test_batched.py @@ -61,3 +61,17 @@ def test_mixup( config = [{"name": "MixUp", "params": {"p": 1.0}}] augmentations = AlbumentationsEngine(256, 256, targets, config) augmentations.apply([(image.copy(), deepcopy(labels)) for _ in range(2)]) + + +def test_batched_p_0( + image: np.ndarray, labels: Labels, targets: Dict[str, TaskType] +): + config = [ + { + "name": "Mosaic4", + "params": {"p": 0, "out_width": 640, "out_height": 640}, + }, + {"name": "MixUp", "params": {"p": 0}}, + ] + augmentations = AlbumentationsEngine(256, 256, targets, config) + augmentations.apply([(image.copy(), deepcopy(labels)) for _ in range(8)])