Skip to content

Commit

Permalink
Fix batch augmentations for p=0 (#223)
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 authored Jan 14, 2025
1 parent d2840b2 commit 2fbec83
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
3 changes: 3 additions & 0 deletions luxonis_ml/data/augmentations/batch_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def __call__(
for batch in yield_batches(data_batch, transform.batch_size):
data = transform(**batch) # type: ignore

if isinstance(next(iter(data.values())), list):
data = {key: value[0] for key, value in batch.items()}

data = self.check_data_post_transform(data)
new_batch.append(data)
data_batch = new_batch
Expand Down
9 changes: 6 additions & 3 deletions luxonis_ml/data/augmentations/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Tuple, TypeVar

import numpy as np

Expand Down Expand Up @@ -66,9 +66,12 @@ def postprocess_keypoints(
return keypoints


T = TypeVar("T")


def yield_batches(
data_batch: List[Dict[str, Any]], batch_size: int
) -> Iterator[Dict[str, List[Any]]]:
data_batch: List[Dict[str, T]], batch_size: int
) -> Iterator[Dict[str, List[T]]]:
"""Yield batches of data.
@type data_batch: List[Dict[str, Any]]
Expand Down
14 changes: 14 additions & 0 deletions tests/test_data/test_augmentations/test_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,17 @@ def test_mixup(
config = [{"name": "MixUp", "params": {"p": 1.0}}]
augmentations = AlbumentationsEngine(256, 256, targets, config)
augmentations.apply([(image.copy(), deepcopy(labels)) for _ in range(2)])


def test_batched_p_0(
image: np.ndarray, labels: Labels, targets: Dict[str, TaskType]
):
config = [
{
"name": "Mosaic4",
"params": {"p": 0, "out_width": 640, "out_height": 640},
},
{"name": "MixUp", "params": {"p": 0}},
]
augmentations = AlbumentationsEngine(256, 256, targets, config)
augmentations.apply([(image.copy(), deepcopy(labels)) for _ in range(8)])

0 comments on commit 2fbec83

Please sign in to comment.