diff --git a/luxonis_ml/data/augmentations/batch_compose.py b/luxonis_ml/data/augmentations/batch_compose.py index bab427d2..b6239c42 100644 --- a/luxonis_ml/data/augmentations/batch_compose.py +++ b/luxonis_ml/data/augmentations/batch_compose.py @@ -47,6 +47,9 @@ def __call__( for batch in yield_batches(data_batch, transform.batch_size): data = transform(**batch) # type: ignore + if isinstance(next(iter(data.values())), list): + data = {key: value[0] for key, value in batch.items()} + data = self.check_data_post_transform(data) new_batch.append(data) data_batch = new_batch diff --git a/luxonis_ml/data/augmentations/utils.py b/luxonis_ml/data/augmentations/utils.py index b6532fa4..ea6e356b 100644 --- a/luxonis_ml/data/augmentations/utils.py +++ b/luxonis_ml/data/augmentations/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, TypeVar import numpy as np @@ -66,9 +66,12 @@ def postprocess_keypoints( return keypoints +T = TypeVar("T") + + def yield_batches( - data_batch: List[Dict[str, Any]], batch_size: int -) -> Iterator[Dict[str, List[Any]]]: + data_batch: List[Dict[str, T]], batch_size: int +) -> Iterator[Dict[str, List[T]]]: """Yield batches of data. @type data_batch: List[Dict[str, Any]] diff --git a/tests/test_data/test_augmentations/test_batched.py b/tests/test_data/test_augmentations/test_batched.py index 15baca4f..56c15446 100644 --- a/tests/test_data/test_augmentations/test_batched.py +++ b/tests/test_data/test_augmentations/test_batched.py @@ -61,3 +61,17 @@ def test_mixup( config = [{"name": "MixUp", "params": {"p": 1.0}}] augmentations = AlbumentationsEngine(256, 256, targets, config) augmentations.apply([(image.copy(), deepcopy(labels)) for _ in range(2)]) + + +def test_batched_p_0( + image: np.ndarray, labels: Labels, targets: Dict[str, TaskType] +): + config = [ + { + "name": "Mosaic4", + "params": {"p": 0, "out_width": 640, "out_height": 640}, + }, + {"name": "MixUp", "params": {"p": 0}}, + ] + augmentations = AlbumentationsEngine(256, 256, targets, config) + augmentations.apply([(image.copy(), deepcopy(labels)) for _ in range(8)])