From 2fbec8341d03c8c88a80f35b28b726e7e87818e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Kozlovsk=C3=BD?= Date: Tue, 14 Jan 2025 18:36:01 +0100 Subject: [PATCH] Fix batch augmentations for `p=0` (#223) --- luxonis_ml/data/augmentations/batch_compose.py | 3 +++ luxonis_ml/data/augmentations/utils.py | 9 ++++++--- tests/test_data/test_augmentations/test_batched.py | 14 ++++++++++++++ 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/luxonis_ml/data/augmentations/batch_compose.py b/luxonis_ml/data/augmentations/batch_compose.py index bab427d2..b6239c42 100644 --- a/luxonis_ml/data/augmentations/batch_compose.py +++ b/luxonis_ml/data/augmentations/batch_compose.py @@ -47,6 +47,9 @@ def __call__( for batch in yield_batches(data_batch, transform.batch_size): data = transform(**batch) # type: ignore + if isinstance(next(iter(data.values())), list): + data = {key: value[0] for key, value in batch.items()} + data = self.check_data_post_transform(data) new_batch.append(data) data_batch = new_batch diff --git a/luxonis_ml/data/augmentations/utils.py b/luxonis_ml/data/augmentations/utils.py index b6532fa4..ea6e356b 100644 --- a/luxonis_ml/data/augmentations/utils.py +++ b/luxonis_ml/data/augmentations/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, TypeVar import numpy as np @@ -66,9 +66,12 @@ def postprocess_keypoints( return keypoints +T = TypeVar("T") + + def yield_batches( - data_batch: List[Dict[str, Any]], batch_size: int -) -> Iterator[Dict[str, List[Any]]]: + data_batch: List[Dict[str, T]], batch_size: int +) -> Iterator[Dict[str, List[T]]]: """Yield batches of data. @type data_batch: List[Dict[str, Any]] diff --git a/tests/test_data/test_augmentations/test_batched.py b/tests/test_data/test_augmentations/test_batched.py index 15baca4f..56c15446 100644 --- a/tests/test_data/test_augmentations/test_batched.py +++ b/tests/test_data/test_augmentations/test_batched.py @@ -61,3 +61,17 @@ def test_mixup( config = [{"name": "MixUp", "params": {"p": 1.0}}] augmentations = AlbumentationsEngine(256, 256, targets, config) augmentations.apply([(image.copy(), deepcopy(labels)) for _ in range(2)]) + + +def test_batched_p_0( + image: np.ndarray, labels: Labels, targets: Dict[str, TaskType] +): + config = [ + { + "name": "Mosaic4", + "params": {"p": 0, "out_width": 640, "out_height": 640}, + }, + {"name": "MixUp", "params": {"p": 0}}, + ] + augmentations = AlbumentationsEngine(256, 256, targets, config) + augmentations.apply([(image.copy(), deepcopy(labels)) for _ in range(8)])