Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrapping labels over array boundaries #344

Merged
merged 10 commits into from
Feb 23, 2024
16 changes: 13 additions & 3 deletions dask_image/ndmeasure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def histogram(image,
return result


def label(image, structure=None):
def label(image, structure=None, wrap_axes=None):
"""
Label features in an array.

Expand All @@ -323,6 +323,14 @@ def label(image, structure=None):
[1,1,1],
[0,1,0]]

wrap_axes : tuple of int, optional
Whether labels should be wrapped across array boundaries, and if so which axes.
This feature is not present in `ndimage.label`.
Examples:
- (0,) only wrap across the 0th axis.
- (0, 1) wrap across the 0th and 1st axis.
- (0, 1, 3) wrap across 0th, 1st and 3rd axis.

Returns
-------
label : ndarray or int
Expand Down Expand Up @@ -363,12 +371,14 @@ def label(image, structure=None):
# Now, build a label connectivity graph that groups labels across blocks.
# We use this graph to find connected components and then relabel each
# block according to those.
label_groups = _label.label_adjacency_graph(block_labeled, structure,
total)
label_groups = _label.label_adjacency_graph(
block_labeled, structure, total, wrap_axes=wrap_axes
)
new_labeling = _label.connected_components_delayed(label_groups)
relabeled = _label.relabel_blocks(block_labeled, new_labeling)
n = da.max(relabeled)


return (relabeled, n)


Expand Down
84 changes: 57 additions & 27 deletions dask_image/ndmeasure/_utils/_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _to_csr_matrix(i, j, n):
return mat.tocsr()


def label_adjacency_graph(labels, structure, nlabels):
def label_adjacency_graph(labels, structure, nlabels, wrap_axes=None):
"""
Adjacency graph of labels between chunks of ``labels``.

Expand All @@ -144,6 +144,11 @@ def label_adjacency_graph(labels, structure, nlabels):
nlabels : delayed int
The total number of labels in ``labels`` *before* correcting for
global consistency.
wrap_axes : tuple of int, optional
Should labels be wrapped across array boundaries, and if so which axes.
- (0,) only wrap over the 0th axis.
- (0, 1) wrap over the 0th and 1st axis.
- (0, 1, 3) wrap over 0th, 1st and 3rd axis.

Returns
-------
Expand All @@ -155,19 +160,24 @@ def label_adjacency_graph(labels, structure, nlabels):
if structure is None:
structure = scipy.ndimage.generate_binary_structure(labels.ndim, 1)

faces = _chunk_faces(labels.chunks, labels.shape, structure)
face_slices = _chunk_faces(
labels.chunks, labels.shape, structure, wrap_axes=wrap_axes
)
all_mappings = [da.empty((2, 0), dtype=LABEL_DTYPE, chunks=1)]
for face_slice in faces:

for face_slice in face_slices:
face = labels[face_slice]
mapped = _across_block_label_grouping_delayed(face, structure)
all_mappings.append(mapped)

all_mappings = da.concatenate(all_mappings, axis=1)
i, j = all_mappings
mat = _to_csr_matrix(i, j, nlabels + 1)

return mat


def _chunk_faces(chunks, shape, structure):
def _chunk_faces(chunks, shape, structure, wrap_axes=None):
"""
Return slices for two-pixel-wide boundaries between chunks.

Expand All @@ -179,19 +189,24 @@ def _chunk_faces(chunks, shape, structure):
The shape of the array.
structure: array of bool
Structuring element, shape (3,) * ndim.
wrap_axes : tuple of int, optional
Should labels be wrapped across array boundaries, and if so which axes.
- (0,) only wrap over the 0th axis.
- (0, 1) wrap over the 0th and 1st axis.
- (0, 1, 3) wrap over 0th, 1st and 3rd axis.

Returns
Yields
-------
faces : list of tuple of slices
Each element in this list indexes a face between two chunks.
tuple of slices
Each element indexes a face between two chunks.

Examples
--------
>>> import dask.array as da
>>> import scipy.ndimage as ndi
>>> a = da.arange(110, chunks=110).reshape((10, 11)).rechunk(5)
>>> structure = ndi.generate_binary_structure(2, 1)
>>> chunk_faces(a.chunks, a.shape, structure)
>>> list(chunk_faces(a.chunks, a.shape, structure))
[(slice(4, 6, None), slice(0, 5, None)),
(slice(4, 6, None), slice(5, 10, None)),
(slice(4, 6, None), slice(10, 11, None)),
Expand All @@ -202,44 +217,59 @@ def _chunk_faces(chunks, shape, structure):
"""

ndim = len(shape)
numblocks = tuple(list(len(c) for c in chunks))


slices = da.core.slices_from_chunks(chunks)

# arrange block/chunk indices on grid
block_summary = np.arange(len(slices)).reshape(numblocks)

faces = []
for ind_curr_block, curr_block in enumerate(np.ndindex(numblocks)):

block_summary = np.arange(len(slices)).reshape(
[len(c) for c in chunks])

# Iterate over all blocks and use the structuring element
# to determine which blocks should be connected.
# For wrappped axes, we need to consider the block
# before the current block with index -1 as well.
numblocks = [len(c) if wrap_axes is None or ax not in wrap_axes
else len(c) + 1 for ax, c in enumerate(chunks)]
for curr_block in np.ndindex(tuple(numblocks)):

curr_block = list(curr_block)

if wrap_axes is not None:
# start at -1 indices for wrapped axes
for wrap_axis in wrap_axes:
curr_block[wrap_axis] = curr_block[wrap_axis] - 1

# iterate over neighbors of the current block
for pos_structure_coord in np.array(np.where(structure)).T:

# only consider forward neighbors
if min(pos_structure_coord) < 1 or \
max(pos_structure_coord) < 2: continue

neigh_block = [curr_block[dim] + pos_structure_coord[dim] - 1
for dim in range(ndim)]

if max([neigh_block[dim] >= numblocks[dim] for dim in range(ndim)]): continue
if max([neigh_block[dim] >= block_summary.shape[dim]
for dim in range(ndim)]): continue

# get neighbor slice index
ind_neigh_block = block_summary[tuple(neigh_block)]
ind_curr_block = block_summary[tuple(curr_block)]
Comment on lines 255 to +256
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a small change to this comment in PR: #353


curr_slice = []
for dim in range(ndim):
# keep slice if not on boundary
if slices[ind_curr_block][dim] == slices[ind_neigh_block][dim]:
if neigh_block[dim] == curr_block[dim]:
curr_slice.append(slices[ind_curr_block][dim])
# otherwise, add two-pixel-wide boundary
else:
curr_slice.append(slice(
slices[ind_curr_block][dim].stop - 1,
slices[ind_curr_block][dim].stop + 1))

faces.append(tuple(curr_slice))

return faces
if slices[ind_curr_block][dim].stop == shape[dim]:
curr_slice.append(slice(None, None, shape[dim] - 1))
else:
curr_slice.append(slice(
slices[ind_curr_block][dim].stop - 1,
slices[ind_curr_block][dim].stop + 1))

yield tuple(curr_slice)


def block_ndi_label_delayed(block, structure):
Expand Down
Loading