Skip to content

Commit

Permalink
Update 3D tests to use label as node ID
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed Nov 26, 2024
1 parent 9fcbf3c commit f14da39
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 96 deletions.
90 changes: 11 additions & 79 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,12 @@ def segmentation_3d():
segmentation[0][mask] = 1

# make frame with two cells
# first cell centered at (20, 50, 80) with label 1
# second cell centered at (60, 50, 45) with label 2
# first cell centered at (20, 50, 80) with label 2
# second cell centered at (60, 50, 45) with label 3
mask = sphere(center=(20, 50, 80), radius=10, shape=frame_shape)
segmentation[1][mask] = 1
mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape)
segmentation[1][mask] = 2
mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape)
segmentation[1][mask] = 3

return segmentation

Expand Down Expand Up @@ -261,7 +261,7 @@ def graph_3d():
graph = nx.DiGraph()
nodes = [
(
"0_1",
1,
{
NodeAttr.POS.value: (50, 50, 50),
NodeAttr.TIME.value: 0,
Expand All @@ -270,97 +270,29 @@ def graph_3d():
},
),
(
"1_1",
2,
{
NodeAttr.POS.value: (20, 50, 80),
NodeAttr.TIME.value: 1,
NodeAttr.SEG_ID.value: 1,
NodeAttr.SEG_ID.value: 2,
NodeAttr.AREA.value: 4169,
},
),
(
"1_2",
3,
{
NodeAttr.POS.value: (60, 50, 45),
NodeAttr.TIME.value: 1,
NodeAttr.SEG_ID.value: 2,
NodeAttr.SEG_ID.value: 3,
NodeAttr.AREA.value: 14147,
},
),
]
edges = [
# math.dist([50, 50], [20, 80])
("0_1", "1_1"),
(1, 2),
# math.dist([50, 50], [60, 45])
("0_1", "1_2"),
]
graph.add_nodes_from(nodes)
graph.add_edges_from(edges)
return graph


@pytest.fixture
def multi_hypothesis_graph_3d():
graph = nx.DiGraph()
nodes = [
(
"0_0_1",
{
NodeAttr.POS.value: (50, 50, 50),
NodeAttr.TIME.value: 0,
NodeAttr.SEG_HYPO.value: 0,
NodeAttr.SEG_ID.value: 1,
NodeAttr.AREA.value: 305,
},
),
(
"0_1_1",
{
NodeAttr.POS.value: (45, 50, 55),
NodeAttr.TIME.value: 1,
NodeAttr.SEG_HYPO.value: 1,
NodeAttr.SEG_ID.value: 1,
NodeAttr.AREA.value: 305,
},
),
(
"1_0_1",
{
NodeAttr.POS.value: (20, 50, 80),
NodeAttr.TIME.value: 1,
NodeAttr.SEG_HYPO.value: 0,
NodeAttr.SEG_ID.value: 1,
NodeAttr.AREA.value: 305,
},
),
(
"1_0_2",
{
NodeAttr.POS.value: (60, 50, 45),
NodeAttr.TIME.value: 1,
NodeAttr.SEG_HYPO.value: 0,
NodeAttr.SEG_ID.value: 2,
NodeAttr.AREA.value: 305,
},
),
(
"1_1_1",
{
NodeAttr.POS.value: (15, 50, 70),
NodeAttr.TIME.value: 1,
NodeAttr.SEG_HYPO.value: 1,
NodeAttr.SEG_ID.value: 1,
NodeAttr.AREA.value: 305,
},
),
]
edges = [
("0_0_1", "1_0_1"),
("0_0_1", "1_0_2"),
("0_1_1", "1_0_1"),
("0_1_1", "1_0_2"),
("0_0_1", "1_1_1"),
("0_1_1", "1_1_1"),
(1, 3),
]
graph.add_nodes_from(nodes)
graph.add_edges_from(edges)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_candidate_graph/test_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ def test_compute_ious_2d(segmentation_2d):

def test_compute_ious_3d(segmentation_3d):
ious = _compute_ious(segmentation_3d[0], segmentation_3d[1])
expected = [(1, 2, 0.30)]
expected = [(1, 3, 0.30)]
for iou, expected_iou in zip(ious, expected, strict=False):
assert iou == pytest.approx(expected_iou, abs=0.01)

ious = _compute_ious(segmentation_3d[1], segmentation_3d[1])
expected = [(1, 1, 1.0), (2, 2, 1.0)]
expected = [(2, 2, 1.0), (3, 3, 1.0)]
for iou, expected_iou in zip(ious, expected, strict=False):
assert iou == pytest.approx(expected_iou, abs=0.01)

Expand Down
30 changes: 15 additions & 15 deletions tests/test_candidate_graph/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,27 @@ def test_nodes_from_segmentation_3d(segmentation_3d):
node_graph, node_frame_dict = nodes_from_segmentation(
segmentation=segmentation_3d,
)
assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"])
assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1
assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1
assert node_graph.nodes["1_1"][NodeAttr.AREA.value] == 4169
assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20, 50, 80)
assert Counter(list(node_graph.nodes)) == Counter([1, 2, 3])
assert node_graph.nodes[2][NodeAttr.SEG_ID.value] == 2
assert node_graph.nodes[2][NodeAttr.TIME.value] == 1
assert node_graph.nodes[2][NodeAttr.AREA.value] == 4169
assert node_graph.nodes[2][NodeAttr.POS.value] == (20, 50, 80)

assert node_frame_dict[0] == ["0_1"]
assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"])
assert node_frame_dict[0] == [1]
assert Counter(node_frame_dict[1]) == Counter([2, 3])

# test with scaling
node_graph, node_frame_dict = nodes_from_segmentation(
segmentation=segmentation_3d, scale=[1, 1, 4.5, 1]
)
assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"])
assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1
assert node_graph.nodes["1_1"][NodeAttr.AREA.value] == 4169 * 4.5
assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1
assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20.0, 225.0, 80.0)
assert Counter(list(node_graph.nodes)) == Counter([1, 2, 3])
assert node_graph.nodes[2][NodeAttr.SEG_ID.value] == 2
assert node_graph.nodes[2][NodeAttr.AREA.value] == 4169 * 4.5
assert node_graph.nodes[2][NodeAttr.TIME.value] == 1
assert node_graph.nodes[2][NodeAttr.POS.value] == (20.0, 225.0, 80.0)

assert node_frame_dict[0] == ["0_1"]
assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"])
assert node_frame_dict[0] == [1]
assert Counter(node_frame_dict[1]) == Counter([2, 3])


# add_cand_edges
Expand All @@ -89,7 +89,7 @@ def test_add_cand_edges_2d(graph_2d):
def test_add_cand_edges_3d(graph_3d):
cand_graph = nx.create_empty_copy(graph_3d)
add_cand_edges(cand_graph, max_edge_distance=15)
graph_3d.remove_edge("0_1", "1_1")
graph_3d.remove_edge(1, 2)
assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges))


Expand Down

0 comments on commit f14da39

Please sign in to comment.