Skip to content

Commit

Permalink
Fix overlapping masks in semantic segmentation targets (#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
sokovninn authored Dec 6, 2024
1 parent 2a04ad7 commit e3fee91
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 9 deletions.
7 changes: 4 additions & 3 deletions luxonis_ml/data/datasets/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,11 @@ def combine_to_numpy(
[class_mapping.get(ann.class_, 0) for ann in annotations]
)

assigned_pixels = np.zeros((height, width), dtype=bool)
for i, class_ in enumerate(classes):
seg[class_, ...] = np.maximum(
seg[class_, ...], masks[i].astype(np.uint8)
)
mask = masks[i] & (assigned_pixels == 0)
seg[class_, ...] = np.maximum(seg[class_, ...], mask)
assigned_pixels |= mask

return seg

Expand Down
55 changes: 49 additions & 6 deletions tests/test_data/test_ann_creation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pydantic
import pytest

Expand Down Expand Up @@ -73,10 +74,52 @@ def test_poly_no_auto_clip():
)


def test_poly_auto_clip():
poly_ann = PolylineSegmentationAnnotation(
**{"points": [(-0.1, 1.1), (-2, 2), (-0.1, -1.1)]}
def test_combine_to_numpy_handle_overlaps():
class_mapping = {"class1": 0, "class2": 1, "class3": 2}
height, width = 4, 4

polylines = [
np.array([(0, 0), (1, 0), (1, 1), (0, 1)], dtype=float)
/ [width, height],
np.array([(1, 0), (2, 0), (2, 1), (1, 1)], dtype=float)
/ [width, height],
np.array([(2, 2), (3, 2), (3, 3), (2, 3)], dtype=float)
/ [width, height],
]

annotations = [
PolylineSegmentationAnnotation(
**{"points": polylines[0], "class": "class1"}
),
PolylineSegmentationAnnotation(
**{"points": polylines[1], "class": "class2"}
),
PolylineSegmentationAnnotation(
**{"points": polylines[2], "class": "class3"}
),
]

combined = PolylineSegmentationAnnotation.combine_to_numpy(
annotations, class_mapping, height=height, width=width
)
assert 0 <= poly_ann.points[0][0] <= 1 and 0 <= poly_ann.points[0][1] <= 1
assert 0 <= poly_ann.points[1][0] <= 1 and 0 <= poly_ann.points[1][1] <= 1
assert 0 <= poly_ann.points[2][0] <= 1 and 0 <= poly_ann.points[2][1] <= 1

expected_combined = np.zeros(
(len(class_mapping), height, width), dtype=np.uint8
)
expected_combined[0, :, :] = np.array(
[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
dtype=np.uint8,
)
expected_combined[1, :, :] = np.array(
[[0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
dtype=np.uint8,
)
expected_combined[2, :, :] = np.array(
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]],
dtype=np.uint8,
)

assert combined.shape == (len(class_mapping), height, width)
assert np.array_equal(
combined, expected_combined
), f"Expected {expected_combined}, but got {combined}"

0 comments on commit e3fee91

Please sign in to comment.