diff --git a/dask_image/ndmeasure/__init__.py b/dask_image/ndmeasure/__init__.py index 505603de..5f4eaf48 100644 --- a/dask_image/ndmeasure/__init__.py +++ b/dask_image/ndmeasure/__init__.py @@ -303,7 +303,7 @@ def histogram(image, return result -def label(image, structure=None): +def label(image, structure=None, wrap_axes=None): """ Label features in an array. @@ -323,6 +323,14 @@ def label(image, structure=None): [1,1,1], [0,1,0]] + wrap_axes : tuple of int, optional + Whether labels should be wrapped across array boundaries, and if so which axes. + This feature is not present in `ndimage.label`. + Examples: + - (0,) only wrap across the 0th axis. + - (0, 1) wrap across the 0th and 1st axis. + - (0, 1, 3) wrap across 0th, 1st and 3rd axis. + Returns ------- label : ndarray or int @@ -363,12 +371,14 @@ def label(image, structure=None): # Now, build a label connectivity graph that groups labels across blocks. # We use this graph to find connected components and then relabel each # block according to those. - label_groups = _label.label_adjacency_graph(block_labeled, structure, - total) + label_groups = _label.label_adjacency_graph( + block_labeled, structure, total, wrap_axes=wrap_axes + ) new_labeling = _label.connected_components_delayed(label_groups) relabeled = _label.relabel_blocks(block_labeled, new_labeling) n = da.max(relabeled) + return (relabeled, n) diff --git a/dask_image/ndmeasure/_utils/_label.py b/dask_image/ndmeasure/_utils/_label.py index f3fd0eb6..3a130bb2 100644 --- a/dask_image/ndmeasure/_utils/_label.py +++ b/dask_image/ndmeasure/_utils/_label.py @@ -123,7 +123,7 @@ def _to_csr_matrix(i, j, n): return mat.tocsr() -def label_adjacency_graph(labels, structure, nlabels): +def label_adjacency_graph(labels, structure, nlabels, wrap_axes=None): """ Adjacency graph of labels between chunks of ``labels``. @@ -144,6 +144,11 @@ def label_adjacency_graph(labels, structure, nlabels): nlabels : delayed int The total number of labels in ``labels`` *before* correcting for global consistency. + wrap_axes : tuple of int, optional + Should labels be wrapped across array boundaries, and if so which axes. + - (0,) only wrap over the 0th axis. + - (0, 1) wrap over the 0th and 1st axis. + - (0, 1, 3) wrap over 0th, 1st and 3rd axis. Returns ------- @@ -155,19 +160,24 @@ def label_adjacency_graph(labels, structure, nlabels): if structure is None: structure = scipy.ndimage.generate_binary_structure(labels.ndim, 1) - faces = _chunk_faces(labels.chunks, labels.shape, structure) + face_slices = _chunk_faces( + labels.chunks, labels.shape, structure, wrap_axes=wrap_axes + ) all_mappings = [da.empty((2, 0), dtype=LABEL_DTYPE, chunks=1)] - for face_slice in faces: + + for face_slice in face_slices: face = labels[face_slice] mapped = _across_block_label_grouping_delayed(face, structure) all_mappings.append(mapped) + all_mappings = da.concatenate(all_mappings, axis=1) i, j = all_mappings mat = _to_csr_matrix(i, j, nlabels + 1) + return mat -def _chunk_faces(chunks, shape, structure): +def _chunk_faces(chunks, shape, structure, wrap_axes=None): """ Return slices for two-pixel-wide boundaries between chunks. @@ -179,11 +189,16 @@ def _chunk_faces(chunks, shape, structure): The shape of the array. structure: array of bool Structuring element, shape (3,) * ndim. + wrap_axes : tuple of int, optional + Should labels be wrapped across array boundaries, and if so which axes. + - (0,) only wrap over the 0th axis. + - (0, 1) wrap over the 0th and 1st axis. + - (0, 1, 3) wrap over 0th, 1st and 3rd axis. - Returns + Yields ------- - faces : list of tuple of slices - Each element in this list indexes a face between two chunks. + tuple of slices + Each element indexes a face between two chunks. Examples -------- @@ -191,7 +206,7 @@ def _chunk_faces(chunks, shape, structure): >>> import scipy.ndimage as ndi >>> a = da.arange(110, chunks=110).reshape((10, 11)).rechunk(5) >>> structure = ndi.generate_binary_structure(2, 1) - >>> chunk_faces(a.chunks, a.shape, structure) + >>> list(chunk_faces(a.chunks, a.shape, structure)) [(slice(4, 6, None), slice(0, 5, None)), (slice(4, 6, None), slice(5, 10, None)), (slice(4, 6, None), slice(10, 11, None)), @@ -202,18 +217,31 @@ def _chunk_faces(chunks, shape, structure): """ ndim = len(shape) - numblocks = tuple(list(len(c) for c in chunks)) - + slices = da.core.slices_from_chunks(chunks) - + # arrange block/chunk indices on grid - block_summary = np.arange(len(slices)).reshape(numblocks) - - faces = [] - for ind_curr_block, curr_block in enumerate(np.ndindex(numblocks)): - + block_summary = np.arange(len(slices)).reshape( + [len(c) for c in chunks]) + + # Iterate over all blocks and use the structuring element + # to determine which blocks should be connected. + # For wrappped axes, we need to consider the block + # before the current block with index -1 as well. + numblocks = [len(c) if wrap_axes is None or ax not in wrap_axes + else len(c) + 1 for ax, c in enumerate(chunks)] + for curr_block in np.ndindex(tuple(numblocks)): + + curr_block = list(curr_block) + + if wrap_axes is not None: + # start at -1 indices for wrapped axes + for wrap_axis in wrap_axes: + curr_block[wrap_axis] = curr_block[wrap_axis] - 1 + + # iterate over neighbors of the current block for pos_structure_coord in np.array(np.where(structure)).T: - + # only consider forward neighbors if min(pos_structure_coord) < 1 or \ max(pos_structure_coord) < 2: continue @@ -221,25 +249,27 @@ def _chunk_faces(chunks, shape, structure): neigh_block = [curr_block[dim] + pos_structure_coord[dim] - 1 for dim in range(ndim)] - if max([neigh_block[dim] >= numblocks[dim] for dim in range(ndim)]): continue + if max([neigh_block[dim] >= block_summary.shape[dim] + for dim in range(ndim)]): continue # get neighbor slice index - ind_neigh_block = block_summary[tuple(neigh_block)] + ind_curr_block = block_summary[tuple(curr_block)] curr_slice = [] for dim in range(ndim): # keep slice if not on boundary - if slices[ind_curr_block][dim] == slices[ind_neigh_block][dim]: + if neigh_block[dim] == curr_block[dim]: curr_slice.append(slices[ind_curr_block][dim]) # otherwise, add two-pixel-wide boundary else: - curr_slice.append(slice( - slices[ind_curr_block][dim].stop - 1, - slices[ind_curr_block][dim].stop + 1)) - - faces.append(tuple(curr_slice)) - - return faces + if slices[ind_curr_block][dim].stop == shape[dim]: + curr_slice.append(slice(None, None, shape[dim] - 1)) + else: + curr_slice.append(slice( + slices[ind_curr_block][dim].stop - 1, + slices[ind_curr_block][dim].stop + 1)) + + yield tuple(curr_slice) def block_ndi_label_delayed(block, structure): diff --git a/tests/test_dask_image/test_ndmeasure/test_core.py b/tests/test_dask_image/test_ndmeasure/test_core.py index 57add8cf..d72bcef5 100644 --- a/tests/test_dask_image/test_ndmeasure/test_core.py +++ b/tests/test_dask_image/test_ndmeasure/test_core.py @@ -366,6 +366,294 @@ def test_label(seed, prob, shape, chunks, connectivity): _assert_equivalent_labeling(a_l, d_l.compute()) +a = np.array( + [ + [0, 0, 1, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 1, 0, 1, 1, 1, 0], + [0, 1, 0, 0, 1, 0, 1, 1, 1, 0], + [0, 0, 1, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 1, 1, 0, 0, 0], + ] +) + + +@pytest.mark.parametrize( + "a, a_res, wrap_axes, connectivity, chunks", + [ + pytest.param( + a, + np.array( + [ + [0, 0, 1, 0, 0, 3, 3, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 2, 0, 1, 1, 1, 0], + [0, 1, 0, 0, 2, 0, 1, 1, 1, 0], + [0, 0, 1, 0, 2, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 4, 0, 0, 5, 5, 0, 0, 0], + ] + ), + (1,), + 2, + (5, 5), + id="2d, wrapping 1st axis.", + ), + pytest.param( + a, + np.array( + [ + [0, 0, 1, 0, 0, 3, 3, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 4, 4, 4, 4], + [1, 1, 0, 0, 0, 0, 4, 4, 4, 4], + [1, 0, 0, 0, 2, 0, 4, 4, 4, 0], + [0, 1, 0, 0, 2, 0, 4, 4, 4, 0], + [0, 0, 1, 0, 2, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 3, 3, 0, 0, 0], + ] + ), + (0,), + 2, + (5, 5), + id="2d, wrapping 0th axes.", + ), + pytest.param( + a, + np.array( + [ + [0, 0, 1, 0, 0, 3, 3, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 2, 0, 1, 1, 1, 0], + [0, 1, 0, 0, 2, 0, 1, 1, 1, 0], + [0, 0, 1, 0, 2, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 3, 3, 0, 0, 0], + ] + ), + (0, 1), + 2, + (5, 5), + id="2d, wrapping both axes", + ), + pytest.param( + np.array([[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]]), + np.array([[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]]), + (0, 1), + 2, + "auto", + id="2d, full wrap, high connectivity (corners).", + ), + pytest.param( + np.array([[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]]), + # Corners should not be connected for lower connectivity. + np.array([[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 2]]), + (0, 1), + 1, + "auto", + id="2d, full wrap, low connectivity (no corners).", + ), + # 3d + pytest.param( + np.array( + [ + [[0, 0, 0, 0, 0], [1, 0, 0, 0, 1], [0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0], [1, 0, 0, 0, 1], [1, 0, 0, 0, 1]], + ] + ), + np.array( + [ + [[0, 0, 0, 0, 0], [1, 0, 0, 0, 2], [0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0], [3, 0, 0, 0, 4], [3, 0, 0, 0, 4]], + ] + ), + None, + 3, + "auto", + id="3d no wrap", + ), + pytest.param( + np.array( + [ + [[0, 0, 0, 0, 0], [1, 0, 0, 0, 1], [0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0], [1, 0, 0, 0, 1], [1, 0, 0, 0, 1]], + ] + ), + np.array( + [ + [[0, 0, 0, 0, 0], [1, 0, 0, 0, 1], [0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0], [2, 0, 0, 0, 2], [2, 0, 0, 0, 2]], + ] + ), + (2,), + 3, + "auto", + id="3d wrap 2nd axis", + ), + pytest.param( + np.array( + [ + [ + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 0, 1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 0, 1], + ], + ] + ), + np.array( + [ + [ + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [2, 0, 0, 0, 2], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [3, 0, 0, 0, 3], + ], + ] + ), + (1, 2), + 3, + "auto", + id="3d, wrap 1st and 2nd axis, with corners", + ), + pytest.param( + np.array( + [ + [ + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 0, 1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 0, 1], + ], + ] + ), + np.array( + [ + [ + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [2, 0, 0, 0, 2], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 0, 1], + ], + ] + ), + (1, 2), + 3, + "auto", + id="3d, with corners, connection through adjacent timesteps.", + ), + ], +) +def test_label_wrap(a, a_res, wrap_axes, connectivity, chunks): + d = da.from_array(a, chunks=chunks) + + s = scipy.ndimage.generate_binary_structure(a.ndim, connectivity) + + d_l, _ = dask_image.ndmeasure.label(d, s, wrap_axes=wrap_axes) + + _assert_equivalent_labeling(a_res, d_l.compute()) + + @pytest.mark.parametrize( "ndim", (2, 3, 4, 5) ) @@ -388,7 +676,7 @@ def test_label_full_struct_element(ndim): labels_ndi, N_ndi = scipy.ndimage.label(mask, structure=full_s) labels_di_da, N_di_da = dask_image.ndmeasure.label( mask_da, structure=full_s) - + assert N_ndi == N_di_da.compute() _assert_equivalent_labeling(