Wrapping labels over array boundaries #344

Merged on Feb 23, 2024 (10 commits)
13 changes: 10 additions & 3 deletions dask_image/ndmeasure/__init__.py
@@ -303,7 +303,7 @@ def histogram(image,
return result


def label(image, structure=None):
def label(image, structure=None, wrap_axes=None):
"""
Label features in an array.

@@ -322,6 +322,11 @@ def label(image, structure=None):
[[0,1,0],
[1,1,1],
[0,1,0]]
wrap_axes : tuple of int, optional
Whether labels should be wrapped across array boundaries, and if so along which axes.
- (0,) only wrap across the 0th axis.
- (0, 1) wrap across the 0th and 1st axes.
- (0, 1, 3) wrap across the 0th, 1st and 3rd axes.

Returns
-------
@@ -363,12 +368,14 @@ def label(image, structure=None):
# Now, build a label connectivity graph that groups labels across blocks.
# We use this graph to find connected components and then relabel each
# block according to those.
label_groups = _label.label_adjacency_graph(block_labeled, structure,
total)
label_groups = _label.label_adjacency_graph(
block_labeled, structure, total, wrap_axes=wrap_axes
)
new_labeling = _label.connected_components_delayed(label_groups)
relabeled = _label.relabel_blocks(block_labeled, new_labeling)
n = da.max(relabeled)


return (relabeled, n)
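
The grouping step described in the comments of this hunk (collect pairs of labels that touch, find connected components, relabel) can be illustrated with a small union-find sketch. This is not the dask-image implementation, which builds a sparse adjacency matrix and runs connected components on it; the function name `merge_labels` and the example pairs are hypothetical and only show why adding pairs from a wrapped boundary is enough to merge labels across it.

```python
def merge_labels(n_labels, touching_pairs):
    """Map each provisional label 1..n_labels to a globally consistent one.

    `touching_pairs` lists (a, b) pairs of provisional labels that were
    found next to each other on a chunk face or on a wrapped boundary.
    """
    parent = list(range(n_labels + 1))  # index 0 is the background

    def find(x):
        # Follow parent links to the root, shortening the path on the way.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for a, b in touching_pairs:
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[max(ra, rb)] = min(ra, rb)

    # Renumber the roots consecutively, keeping the background at 0.
    new_id = {0: 0}
    mapping = [0] * (n_labels + 1)
    for lab in range(1, n_labels + 1):
        root = find(lab)
        if root not in new_id:
            new_id[root] = len(new_id)
        mapping[lab] = new_id[root]
    return mapping


# Labels 1 and 3 touch across a chunk face, labels 2 and 5 touch across
# a wrapped boundary, and label 4 stands alone.
mapping = merge_labels(5, [(1, 3), (2, 5)])
```

Here `mapping[old_label]` gives the merged label, so relabeling a block amounts to indexing `mapping` with the block's provisional labels.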


47 changes: 36 additions & 11 deletions dask_image/ndmeasure/_utils/_label.py
@@ -123,7 +123,12 @@ def _to_csr_matrix(i, j, n):
return mat.tocsr()


def label_adjacency_graph(labels, structure, nlabels):
def set_tup_value(tup, idx, value):
"""Return a copy of `tup` with `value` at `idx`."""
return tuple((elem if i == idx else value) for i, elem in enumerate(tup))
Contributor Author

@Holmgren825 Holmgren825 Oct 12, 2023

Note that I changed this to ==. I guess this comes down to how one thinks about wrapping. As it is now, setting wrap_axes=(0) wraps labels across the boundary of the 0th axis, whereas != would wrap the array over the 0th axis and wrap labels across the boundary of the 1st axis. Could use either one really, depends on which is more intuitive I guess.

Collaborator

Hmm I'm a bit confused here. To me, indicating the axes over which dask_image.ndmeasure.label should consider the input array to wrap over is most intuitive :) (which is what I think this code does)

Contributor Author

Yes, I agree. Just wanted to comment it since it differs from the suggestion of @jni.

Contributor

The suggestion from @jni was pretty abstract and untested and he well could have gotten a sign wrong somewhere 😅 Thanks @Holmgren825 and @m-albert! Sorry the next few weeks are very busy for me so I may not be able to do an in-depth review, but I'll try. Please ping me if there is a rush/you specifically want an extra pair of eyes on something.

Collaborator

> To me, indicating the axes over which dask_image.ndmeasure.label should consider the input array to wrap over is most intuitive :) (which is what I think this code does)

I was wrong earlier: the code above should use != as jni originally suggested, because the idea is to replace only the element at index idx and leave the other elements (those with i != idx) unchanged.

> Sorry the next few weeks are very busy for me so I may not be able to do an in-depth review, but I'll try. Please ping me if there is a rush/you specifically want an extra pair of eyes on something.

Thanks for your availability @jni :)

Contributor Author

Yes, same here. I guess I got confused thinking about array axes and geographical axes for some reason. Fixed in 3aaeaca.
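
For reference, the behaviour the thread converges on (replace only the element at `idx`, keep every other element) can be sketched as follows. The exact code in commit 3aaeaca is not shown on this page, so this is a reconstruction from the discussion rather than a copy of the fix.

```python
def set_tup_value(tup, idx, value):
    """Return a copy of `tup` with `value` at `idx`."""
    # Keep every element except the one at position `idx`.
    return tuple(value if i == idx else elem for i, elem in enumerate(tup))


# Build an index that selects only the last position along axis 1 of a
# 3-D array and everything along the other two axes.
none_slice = (slice(None),) * 3
sl_back = set_tup_value(none_slice, 1, [-1])
```

With the `==`/`!=` mix-up discussed above, the same call would instead keep the slice at axis 1 and put `[-1]` on axes 0 and 2, which selects along the wrong axes.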



def label_adjacency_graph(labels, structure, nlabels, wrap_axes=None):
"""
Adjacency graph of labels between chunks of ``labels``.

@@ -144,6 +149,11 @@ def label_adjacency_graph(labels, structure, nlabels):
nlabels : delayed int
The total number of labels in ``labels`` *before* correcting for
global consistency.
wrap_axes : tuple of int, optional
Whether labels should be wrapped across array boundaries, and if so along which axes.
- (0,) only wrap over the 0th axis.
- (0, 1) wrap over the 0th and 1st axes.
- (0, 1, 3) wrap over the 0th, 1st and 3rd axes.

Returns
-------
@@ -155,15 +165,30 @@ def label_adjacency_graph(labels, structure, nlabels):
if structure is None:
structure = scipy.ndimage.generate_binary_structure(labels.ndim, 1)

faces = _chunk_faces(labels.chunks, labels.shape, structure)
face_slices = _chunk_faces(labels.chunks, labels.shape, structure)
all_mappings = [da.empty((2, 0), dtype=LABEL_DTYPE, chunks=1)]
for face_slice in faces:
face = labels[face_slice]
faces = []

for face_slice in face_slices:
faces.append(labels[face_slice])

if wrap_axes is not None:
for ax in wrap_axes:
none_slice = (slice(None),) * labels.ndim
sl_back = set_tup_value(none_slice, ax, [-1])
sl_front = set_tup_value(none_slice, ax, [0])
faces.append(
da.stack([labels[sl_back], labels[sl_front]], axis=ax).squeeze()
)

for face in faces:
Collaborator

I think this logic could live inside of _chunk_faces by extending the existing implementation (tried to explain what I mean here #344 (comment)). Also, in this way all code determining the chunk faces to consider would live together.

Contributor Author

Thanks @m-albert! I agree that it would be nicer to move this to _chunk_faces, although I've struggled a bit to understand what's going on in it. One idea I had was that, in the main loop over the blocks, you could stack the bottom block on top when neigh_block[dim] >= numblocks[dim], but these slices do not wrap. So I went with a simpler approach and just moved the loop over the wrap_axes to the end of _chunk_faces, and added a slice that covers the corners of the array. This makes it pass the corner feature test case that previously failed. Lowering the connectivity to one for this case returns two features despite wrapping both axes, which I think is correct.
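
The corner case mentioned in this comment can be checked independently of dask-image with a brute-force flood fill on a grid that wraps along both axes. This is only a reference sketch for small 2-D inputs; `count_features_periodic` is a hypothetical helper, not part of the PR.

```python
from itertools import product


def count_features_periodic(grid, full_connectivity):
    """Count connected features on a 2-D grid that wraps along both axes."""
    rows, cols = len(grid), len(grid[0])
    # Cross-shaped neighbourhood (connectivity 1) or all eight neighbours.
    offsets = [
        (dr, dc)
        for dr, dc in product((-1, 0, 1), repeat=2)
        if (dr, dc) != (0, 0) and (full_connectivity or dr == 0 or dc == 0)
    ]
    seen = set()
    n_features = 0
    for start in product(range(rows), range(cols)):
        if not grid[start[0]][start[1]] or start in seen:
            continue
        n_features += 1
        seen.add(start)
        stack = [start]
        while stack:
            r, c = stack.pop()
            for dr, dc in offsets:
                # The modulo makes neighbours wrap around the array edges.
                nxt = ((r + dr) % rows, (c + dc) % cols)
                if grid[nxt[0]][nxt[1]] and nxt not in seen:
                    seen.add(nxt)
                    stack.append(nxt)
    return n_features


# The corner test case from this PR: two pixels in opposite corners.
corners = [[1, 0, 0], [0, 0, 0], [0, 0, 1]]
n_cross = count_features_periodic(corners, full_connectivity=False)
n_full = count_features_periodic(corners, full_connectivity=True)
```

With connectivity 1 the two corner pixels are never direct neighbours, even on a wrapped grid, so they stay separate. With the full 3x3 structure they become diagonal neighbours across the wrapped corner and merge into one feature.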

mapped = _across_block_label_grouping_delayed(face, structure)
all_mappings.append(mapped)

all_mappings = da.concatenate(all_mappings, axis=1)
i, j = all_mappings
mat = _to_csr_matrix(i, j, nlabels + 1)

return mat
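
To make the wrapped face concrete, here is a NumPy-only sketch of the idea behind the `wrap_axes` loop above: take the last and the first slice along an axis and put them next to each other, so that they look like an ordinary two-pixel-thick face between neighbouring chunks. It uses `np.take` and `np.concatenate` instead of the `da.stack(...).squeeze()` call in the diff, and the helper name `wrapped_face` is hypothetical.

```python
import numpy as np


def wrapped_face(labels, axis):
    """Join the last and first slices of `labels` along `axis`.

    The result has length 2 along `axis`, like the faces taken between
    neighbouring chunks.
    """
    last = np.take(labels, [-1], axis=axis)
    first = np.take(labels, [0], axis=axis)
    return np.concatenate([last, first], axis=axis)


labels = np.array(
    [
        [1, 0, 2],
        [0, 0, 0],
        [3, 0, 0],
    ]
)
face = wrapped_face(labels, axis=0)

# Pairs of non-zero labels that sit directly opposite each other across
# the wrapped boundary (connectivity 1 only, for simplicity).
touching = {(int(a), int(b)) for a, b in zip(face[0], face[1]) if a and b}
```

In this example label 3 in the bottom row faces label 1 in the top row, so the pair `(3, 1)` would be added to the adjacency graph; label 2 has no counterpart across the boundary and is left alone.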


@@ -203,17 +228,17 @@ def _chunk_faces(chunks, shape, structure):

ndim = len(shape)
numblocks = tuple(list(len(c) for c in chunks))

slices = da.core.slices_from_chunks(chunks)

# arrange block/chunk indices on grid
block_summary = np.arange(len(slices)).reshape(numblocks)

faces = []
for ind_curr_block, curr_block in enumerate(np.ndindex(numblocks)):

for pos_structure_coord in np.array(np.where(structure)).T:

# only consider forward neighbors
if min(pos_structure_coord) < 1 or \
max(pos_structure_coord) < 2: continue
@@ -236,9 +261,9 @@ def _chunk_faces(chunks, shape, structure):
curr_slice.append(slice(
slices[ind_curr_block][dim].stop - 1,
slices[ind_curr_block][dim].stop + 1))

faces.append(tuple(curr_slice))

return faces


93 changes: 92 additions & 1 deletion tests/test_dask_image/test_ndmeasure/test_core.py
@@ -366,6 +366,97 @@ def test_label(seed, prob, shape, chunks, connectivity):
_assert_equivalent_labeling(a_l, d_l.compute())


a = np.array(
[
[0, 0, 1, 0, 0, 1, 1, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
[1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 0, 1, 1, 1, 0],
[0, 1, 0, 0, 1, 0, 1, 1, 1, 0],
[0, 0, 1, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 1, 1, 0, 0, 0],
]
)


@pytest.mark.parametrize(
"a, a_res, wrap_axes",
[
(
a,
np.array(
[
[0, 0, 1, 0, 0, 3, 3, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
[1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 2, 0, 1, 1, 1, 0],
[0, 1, 0, 0, 2, 0, 1, 1, 1, 0],
[0, 0, 1, 0, 2, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 4, 0, 0, 5, 5, 0, 0, 0],
]
),
(0,),
),
(
a,
np.array(
[
[0, 0, 1, 0, 0, 3, 3, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 4, 4, 4, 4],
[1, 1, 0, 0, 0, 0, 4, 4, 4, 4],
[1, 0, 0, 0, 2, 0, 4, 4, 4, 0],
[0, 1, 0, 0, 2, 0, 4, 4, 4, 0],
[0, 0, 1, 0, 2, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 3, 3, 0, 0, 0],
]
),
(1,),
),
(
a,
np.array(
[
[0, 0, 1, 0, 0, 3, 3, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
[1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 2, 0, 1, 1, 1, 0],
[0, 1, 0, 0, 2, 0, 1, 1, 1, 0],
[0, 0, 1, 0, 2, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 3, 3, 0, 0, 0],
]
),
(0, 1),
),
pytest.param(
np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]]),
np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]]),
(0, 1),
marks=pytest.mark.xfail(reason="Can't wrap corner labels"),
),
],
)
def test_label_wrap(a, a_res, wrap_axes):
d = da.from_array(a, chunks=(5, 5))

s = np.ones((3, 3))
Collaborator

As jni commented, using structuring elements with connectivity > 1 would lead to problems in the corners. scipy.ndimage.morphology.generate_binary_structure is a nice convenience function for creating structuring elements.
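
As a quick illustration of the convenience function mentioned here, the sketch below builds the two 2-D structuring elements relevant to this test. It uses the public `scipy.ndimage.generate_binary_structure` name rather than the `scipy.ndimage.morphology` path quoted in the comment, and assumes SciPy is installed.

```python
import scipy.ndimage

# Connectivity 1: only face neighbours (a cross shape in 2-D).
cross = scipy.ndimage.generate_binary_structure(2, 1)

# Connectivity 2: face and diagonal neighbours, equivalent to the
# np.ones((3, 3)) structure used in the test above.
full = scipy.ndimage.generate_binary_structure(2, 2)
```

Both results are boolean arrays of shape `(3, 3)`; passing `cross` instead of `np.ones((3, 3))` as `structure` would avoid diagonal connections, including those across wrapped corners.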


d_l, d_nl = dask_image.ndmeasure.label(d, s, wrap_axes=wrap_axes)

_assert_equivalent_labeling(a_res, d_l.compute())


@pytest.mark.parametrize(
"ndim", (2, 3, 4, 5)
)
@@ -388,7 +479,7 @@ def test_label_full_struct_element(ndim):
labels_ndi, N_ndi = scipy.ndimage.label(mask, structure=full_s)
labels_di_da, N_di_da = dask_image.ndmeasure.label(
mask_da, structure=full_s)

assert N_ndi == N_di_da.compute()

_assert_equivalent_labeling(