From 5508d0ff5fb98ca231f8c129250110591aed1990 Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 28 Feb 2024 11:31:15 -0500 Subject: [PATCH 01/18] Add use_seeds config parameter --- cellulus/configs/inference_config.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/cellulus/configs/inference_config.py b/cellulus/configs/inference_config.py index 707333f..eaa9d1e 100644 --- a/cellulus/configs/inference_config.py +++ b/cellulus/configs/inference_config.py @@ -72,6 +72,12 @@ class InferenceConfig: How to cluster the embeddings? Can be one of 'meanshift' or 'greedy'. + use_seeds (default = False): + + If set to True, the local optima of the distance map from the + predicted object centers is used. + Else, seeds are determined by sklearn.cluster.MeanShift. + num_bandwidths (default = 1): Number of bandwidths to obtain segmentations for. @@ -139,6 +145,7 @@ class InferenceConfig: clustering = attrs.field( default="meanshift", validator=in_(["meanshift", "greedy"]) ) + use_seeds = attrs.field(default=False, validator=instance_of(bool)) bandwidth = attrs.field( default=None, validator=attrs.validators.optional(instance_of(float)) ) From 2ef68fb0d327e01db5e60eb0bfd5152057137b37 Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 28 Feb 2024 11:32:06 -0500 Subject: [PATCH 02/18] Ignore samples which are completely background --- cellulus/datasets/zarr_dataset.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/cellulus/datasets/zarr_dataset.py b/cellulus/datasets/zarr_dataset.py index 1a6ad78..a864d89 100644 --- a/cellulus/datasets/zarr_dataset.py +++ b/cellulus/datasets/zarr_dataset.py @@ -137,17 +137,24 @@ def __yield_sample(self): with gp.build(self.pipeline): while True: + array_is_zero = True # request one sample, all channels, plus crop dimensions - request = gp.BatchRequest() - request[self.raw] = gp.ArraySpec( - roi=gp.Roi( - (0,) * self.num_dims, (1, self.num_channels, *self.crop_size) + while array_is_zero: + request = gp.BatchRequest() + request[self.raw] = gp.ArraySpec( + roi=gp.Roi( + (0,) * self.num_dims, + (1, self.num_channels, *self.crop_size), + ) ) - ) - sample = self.pipeline.request_batch(request) - sample_data = sample[self.raw].data[0] - anchor_samples, reference_samples = self.sample_coordinates() + sample = self.pipeline.request_batch(request) + sample_data = sample[self.raw].data[0] + if np.max(sample_data) <= 0.0: + pass + else: + array_is_zero = False + anchor_samples, reference_samples = self.sample_coordinates() yield sample_data, anchor_samples, reference_samples def __read_meta_data(self): From 1081a703e348c0dbb19ca22d16f9fd79c8f39b41 Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 28 Feb 2024 11:34:25 -0500 Subject: [PATCH 03/18] Add if-else cases for handling provided seeds --- cellulus/segment.py | 54 +++++++++++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/cellulus/segment.py b/cellulus/segment.py index f1ba70c..2872015 100644 --- a/cellulus/segment.py +++ b/cellulus/segment.py @@ -1,5 +1,7 @@ import numpy as np import zarr +from scipy.ndimage import gaussian_filter +from skimage.feature import peak_local_max from skimage.filters import threshold_otsu from tqdm import tqdm @@ -119,22 +121,46 @@ def segment(inference_config: InferenceConfig) -> None: embeddings_centered[2] -= c_z ds_object_centered_embeddings[sample] = embeddings_centered + embeddings_centered_mean = embeddings_centered[ + np.newaxis, : dataset_meta_data.num_spatial_dims + ] + embeddings_centered_std 
= embeddings_centered[-1] + if inference_config.clustering == "meanshift": for bandwidth_factor in range(inference_config.num_bandwidths): - segmentation = mean_shift_segmentation( - embeddings_mean, - embeddings_std, - bandwidth=inference_config.bandwidth / (2**bandwidth_factor), - min_size=inference_config.min_size, - reduction_probability=inference_config.reduction_probability, - threshold=threshold, - ) - # Note that the line below is needed - # because the embeddings_mean is modified - # by mean_shift_segmentation - embeddings_mean = embeddings[ - np.newaxis, : dataset_meta_data.num_spatial_dims, ... - ].copy() + if inference_config.use_seeds: + offset_magnitude = np.linalg.norm(embeddings_centered[:-1], axis=0) + offset_magnitude_smooth = gaussian_filter(offset_magnitude, sigma=2) + coordinates = peak_local_max(-offset_magnitude_smooth) + seeds = np.flip(coordinates, 1) + segmentation = mean_shift_segmentation( + embeddings_centered_mean, + embeddings_centered_std, + bandwidth=inference_config.bandwidth / (2**bandwidth_factor), + min_size=inference_config.min_size, + reduction_probability=inference_config.reduction_probability, + threshold=threshold, + seeds=seeds, + ) + embeddings_centered_mean = embeddings_centered[ + np.newaxis, : dataset_meta_data.num_spatial_dims, ... + ].copy() + else: + segmentation = mean_shift_segmentation( + embeddings_mean, + embeddings_std, + bandwidth=inference_config.bandwidth / (2**bandwidth_factor), + min_size=inference_config.min_size, + reduction_probability=inference_config.reduction_probability, + threshold=threshold, + seeds=None, + ) + # Note that the line below is needed + # because the embeddings_mean is modified + # by mean_shift_segmentation + embeddings_mean = embeddings[ + np.newaxis, : dataset_meta_data.num_spatial_dims, ... + ].copy() ds_segmentation[sample, bandwidth_factor, ...] 
= segmentation elif inference_config.clustering == "greedy": if dataset_meta_data.num_spatial_dims == 3: From 41d92356f488093cb520df7cd8025a5bdffc73f6 Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 28 Feb 2024 11:37:37 -0500 Subject: [PATCH 04/18] Provide seeds from segment.py --- cellulus/utils/mean_shift.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/cellulus/utils/mean_shift.py b/cellulus/utils/mean_shift.py index bdf6012..a51e488 100644 --- a/cellulus/utils/mean_shift.py +++ b/cellulus/utils/mean_shift.py @@ -4,7 +4,13 @@ def mean_shift_segmentation( - embedding_mean, embedding_std, bandwidth, min_size, reduction_probability, threshold + embedding_mean, + embedding_std, + bandwidth, + min_size, + reduction_probability, + threshold, + seeds, ): embedding_mean = torch.from_numpy(embedding_mean) if embedding_mean.ndim == 4: @@ -34,22 +40,28 @@ def mean_shift_segmentation( mask=mask, reduction_probability=reduction_probability, cluster_all=False, + seeds=seeds, )[0] return segmentation def segment_with_meanshift( - embedding, bandwidth, mask, reduction_probability, cluster_all + embedding, bandwidth, mask, reduction_probability, cluster_all, seeds ): anchor_mean_shift = AnchorMeanshift( - bandwidth, reduction_probability=reduction_probability, cluster_all=cluster_all + bandwidth, + reduction_probability=reduction_probability, + cluster_all=cluster_all, + seeds=seeds, ) return anchor_mean_shift(embedding, mask=mask) + 1 class AnchorMeanshift: - def __init__(self, bandwidth, reduction_probability, cluster_all): - self.mean_shift = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all) + def __init__(self, bandwidth, reduction_probability, cluster_all, seeds): + self.mean_shift = MeanShift( + bandwidth=bandwidth, cluster_all=cluster_all, seeds=seeds + ) self.reduction_probability = reduction_probability def compute_mean_shift(self, X): From 6e55f448364e8010de08b460287a134fbc76c173 Mon Sep 17 00:00:00 2001 From: lmanan Date: Sat, 2 Mar 2024 23:19:40 -0500 Subject: [PATCH 05/18] Update default min size in 3d --- cellulus/infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cellulus/infer.py b/cellulus/infer.py index 0ddff24..3a48818 100644 --- a/cellulus/infer.py +++ b/cellulus/infer.py @@ -35,7 +35,7 @@ def infer(experiment_config): ) elif dataset_meta_data.num_spatial_dims == 3: inference_config.min_size = int( - 0.1 * 4.0 / 3.0 * np.pi * (experiment_config.object_size**3) + 0.1 * 4.0 / 3.0 * np.pi * (experiment_config.object_size**3) / 8 ) # set model model = get_model( From b85d305f9f751dafec0cfc61f022018c243db183 Mon Sep 17 00:00:00 2001 From: lmanan Date: Sat, 2 Mar 2024 23:20:43 -0500 Subject: [PATCH 06/18] Return re-ided labels --- cellulus/utils/misc.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cellulus/utils/misc.py b/cellulus/utils/misc.py index 2f77c72..f3edfbc 100644 --- a/cellulus/utils/misc.py +++ b/cellulus/utils/misc.py @@ -13,15 +13,16 @@ def size_filter(segmentation, min_size, filter_non_connected=True): return segmentation if filter_non_connected: - filter_labels = measure.label(segmentation, background=0) + filter_labels = measure.label(segmentation) else: filter_labels = segmentation + ids, sizes = np.unique(filter_labels, return_counts=True) filter_ids = ids[sizes < min_size] mask = np.in1d(filter_labels, filter_ids).reshape(filter_labels.shape) segmentation[mask] = 0 - return segmentation + return measure.label(segmentation) def extract_data(zip_url, data_dir, 
project_name): From 682c6a53844c7b5e76915ccfe115cd62f567df25 Mon Sep 17 00:00:00 2001 From: lmanan Date: Sat, 2 Mar 2024 23:21:30 -0500 Subject: [PATCH 07/18] Update post_process code --- cellulus/post_process.py | 63 ++++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/cellulus/post_process.py b/cellulus/post_process.py index 4c31994..f12630b 100644 --- a/cellulus/post_process.py +++ b/cellulus/post_process.py @@ -1,6 +1,6 @@ import numpy as np import zarr -from scipy.ndimage import binary_fill_holes, label +from scipy.ndimage import binary_fill_holes from scipy.ndimage import distance_transform_edt as dtedt from skimage.filters import threshold_otsu from tqdm import tqdm @@ -60,41 +60,48 @@ def post_process(inference_config: InferenceConfig) -> None: ids = np.unique(segmentation) ids = ids[ids != 0] for id_ in ids: - raw_image_masked = raw_image[segmentation == id_] + segmentation_id_mask = segmentation == id_ + if dataset_meta_data.num_spatial_dims == 2: + y, x = np.where(segmentation_id_mask) + y_min, y_max, x_min, x_max = ( + np.min(y), + np.max(y), + np.min(x), + np.max(x), + ) + elif dataset_meta_data.num_spatial_dims == 3: + z, y, x = np.where(segmentation_id_mask) + z_min, z_max, y_min, y_max, x_min, x_max = ( + np.min(z), + np.max(z), + np.min(y), + np.max(y), + np.min(x), + np.max(x), + ) + raw_image_masked = raw_image[segmentation_id_mask] threshold = threshold_otsu(raw_image_masked) - mask = (segmentation == id_) & (raw_image > threshold) - mask = binary_fill_holes(mask) + mask = segmentation_id_mask & (raw_image > threshold) + if dataset_meta_data.num_spatial_dims == 2: + mask_small = binary_fill_holes( + mask[y_min : y_max + 1, x_min : x_max + 1] + ) + mask[y_min : y_max + 1, x_min : x_max + 1] = mask_small y, x = np.where(mask) ds_postprocessed[sample, bandwidth_factor, y, x] = id_ elif dataset_meta_data.num_spatial_dims == 3: + mask_small = binary_fill_holes( + mask[ + z_min : z_max + 1, y_min : y_max + 1, x_min : x_max + 1 + ] + ) + mask[ + z_min : z_max + 1, y_min : y_max + 1, x_min : x_max + 1 + ] = mask_small z, y, x = np.where(mask) ds_postprocessed[sample, bandwidth_factor, z, y, x] = id_ - # remove non-connected components - for bandwidth_factor in range(inference_config.num_bandwidths): - ids = np.unique(ds_postprocessed[sample, bandwidth_factor]) - ids = ids[ids != 0] - counter = np.max(ids) + 1 - for id_ in ids: - ma_id = ds_postprocessed[sample, bandwidth_factor] == id_ - array, num_features = label(ma_id) - if num_features > 1: - ids_array = np.unique(array) - ids_array = ids_array[ids_array != 0] - for id_array in ids_array: - if dataset_meta_data.num_spatial_dims == 2: - y, x = np.where(array == id_array) - ds_postprocessed[ - sample, bandwidth_factor, y, x - ] = counter - elif dataset_meta_data.num_spatial_dims == 3: - z, y, x = np.where(array == id_array) - ds_postprocessed[ - sample, bandwidth_factor, z, y, x - ] = counter - counter += 1 - # size filter - remove small objects for sample in tqdm(range(dataset_meta_data.num_samples)): for bandwidth_factor in range(inference_config.num_bandwidths): From 955d56ecb422b16c42cd1e8a714e1d2e0caa48a4 Mon Sep 17 00:00:00 2001 From: lmanan Date: Sat, 2 Mar 2024 23:26:48 -0500 Subject: [PATCH 08/18] Update segment.py and postprocess.py to detect.py and segment.py --- cellulus/infer.py | 10 +- cellulus/post_process.py | 110 ----------------- cellulus/segment.py | 255 +++++++++++++-------------------------- 3 files changed, 90 insertions(+), 285 deletions(-) delete mode 
100644 cellulus/post_process.py diff --git a/cellulus/infer.py b/cellulus/infer.py index 3a48818..00f320d 100644 --- a/cellulus/infer.py +++ b/cellulus/infer.py @@ -4,9 +4,9 @@ import torch from cellulus.datasets.meta_data import DatasetMetaData +from cellulus.detect import detect from cellulus.evaluate import evaluate from cellulus.models import get_model -from cellulus.post_process import post_process from cellulus.predict import predict from cellulus.segment import segment @@ -69,12 +69,12 @@ def infer(experiment_config): # get predicted embeddings... if inference_config.prediction_dataset_config is not None: predict(model, inference_config, normalization_factor) - # ...turn them into a segmentation... + # ...turn them into a detection ... if inference_config.segmentation_dataset_config is not None: - segment(inference_config) - # ...and post-process the segmentation + detect(inference_config) + # ...and post-process the detection to obtain an instance segmentation if inference_config.post_processed_dataset_config is not None: - post_process(inference_config) + segment(inference_config) # ...and evaluate if ground-truth exists if inference_config.evaluation_dataset_config is not None: evaluate(inference_config) diff --git a/cellulus/post_process.py b/cellulus/post_process.py deleted file mode 100644 index f12630b..0000000 --- a/cellulus/post_process.py +++ /dev/null @@ -1,110 +0,0 @@ -import numpy as np -import zarr -from scipy.ndimage import binary_fill_holes -from scipy.ndimage import distance_transform_edt as dtedt -from skimage.filters import threshold_otsu -from tqdm import tqdm - -from cellulus.configs.inference_config import InferenceConfig -from cellulus.datasets.meta_data import DatasetMetaData -from cellulus.utils.misc import size_filter - - -def post_process(inference_config: InferenceConfig) -> None: - # filter small objects, erosion, etc. - - dataset_config = inference_config.dataset_config - dataset_meta_data = DatasetMetaData.from_dataset_config(dataset_config) - - f = zarr.open(inference_config.post_processed_dataset_config.container_path) - ds = f[inference_config.post_processed_dataset_config.secondary_dataset_name] - - # prepare the zarr dataset to write to - f_postprocessed = zarr.open( - inference_config.post_processed_dataset_config.container_path - ) - ds_postprocessed = f_postprocessed.create_dataset( - inference_config.post_processed_dataset_config.dataset_name, - shape=( - dataset_meta_data.num_samples, - inference_config.num_bandwidths, - *dataset_meta_data.spatial_array, - ), - dtype=np.uint16, - ) - - ds_postprocessed.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ - -dataset_meta_data.num_spatial_dims : - ] - ds_postprocessed.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims - ds_postprocessed.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims - - # remove halo - if inference_config.post_processing == "cell": - for sample in tqdm(range(dataset_meta_data.num_samples)): - # first instance label masks are expanded by `grow_distance` - # next, expanded instance label masks are shrunk by `shrink_distance` - for bandwidth_factor in range(inference_config.num_bandwidths): - segmentation = ds[sample, bandwidth_factor] - distance_foreground = dtedt(segmentation == 0) - expanded_mask = distance_foreground < inference_config.grow_distance - distance_background = dtedt(expanded_mask) - segmentation[distance_background < inference_config.shrink_distance] = 0 - ds_postprocessed[sample, bandwidth_factor, ...] 
= segmentation - elif inference_config.post_processing == "nucleus": - ds_raw = f[inference_config.dataset_config.dataset_name] - for sample in tqdm(range(dataset_meta_data.num_samples)): - for bandwidth_factor in range(inference_config.num_bandwidths): - segmentation = ds[sample, bandwidth_factor] - raw_image = ds_raw[sample, 0] - ids = np.unique(segmentation) - ids = ids[ids != 0] - for id_ in ids: - segmentation_id_mask = segmentation == id_ - if dataset_meta_data.num_spatial_dims == 2: - y, x = np.where(segmentation_id_mask) - y_min, y_max, x_min, x_max = ( - np.min(y), - np.max(y), - np.min(x), - np.max(x), - ) - elif dataset_meta_data.num_spatial_dims == 3: - z, y, x = np.where(segmentation_id_mask) - z_min, z_max, y_min, y_max, x_min, x_max = ( - np.min(z), - np.max(z), - np.min(y), - np.max(y), - np.min(x), - np.max(x), - ) - raw_image_masked = raw_image[segmentation_id_mask] - threshold = threshold_otsu(raw_image_masked) - mask = segmentation_id_mask & (raw_image > threshold) - - if dataset_meta_data.num_spatial_dims == 2: - mask_small = binary_fill_holes( - mask[y_min : y_max + 1, x_min : x_max + 1] - ) - mask[y_min : y_max + 1, x_min : x_max + 1] = mask_small - y, x = np.where(mask) - ds_postprocessed[sample, bandwidth_factor, y, x] = id_ - elif dataset_meta_data.num_spatial_dims == 3: - mask_small = binary_fill_holes( - mask[ - z_min : z_max + 1, y_min : y_max + 1, x_min : x_max + 1 - ] - ) - mask[ - z_min : z_max + 1, y_min : y_max + 1, x_min : x_max + 1 - ] = mask_small - z, y, x = np.where(mask) - ds_postprocessed[sample, bandwidth_factor, z, y, x] = id_ - - # size filter - remove small objects - for sample in tqdm(range(dataset_meta_data.num_samples)): - for bandwidth_factor in range(inference_config.num_bandwidths): - ds_postprocessed[sample, bandwidth_factor, ...] = size_filter( - ds_postprocessed[sample, bandwidth_factor], inference_config.min_size - ) diff --git a/cellulus/segment.py b/cellulus/segment.py index 2872015..f12630b 100644 --- a/cellulus/segment.py +++ b/cellulus/segment.py @@ -1,29 +1,30 @@ import numpy as np import zarr -from scipy.ndimage import gaussian_filter -from skimage.feature import peak_local_max +from scipy.ndimage import binary_fill_holes +from scipy.ndimage import distance_transform_edt as dtedt from skimage.filters import threshold_otsu from tqdm import tqdm from cellulus.configs.inference_config import InferenceConfig from cellulus.datasets.meta_data import DatasetMetaData -from cellulus.utils.greedy_cluster import Cluster2d, Cluster3d -from cellulus.utils.mean_shift import mean_shift_segmentation +from cellulus.utils.misc import size_filter -def segment(inference_config: InferenceConfig) -> None: +def post_process(inference_config: InferenceConfig) -> None: + # filter small objects, erosion, etc. 
+ dataset_config = inference_config.dataset_config dataset_meta_data = DatasetMetaData.from_dataset_config(dataset_config) - f = zarr.open(inference_config.segmentation_dataset_config.container_path) - ds = f[inference_config.segmentation_dataset_config.secondary_dataset_name] + f = zarr.open(inference_config.post_processed_dataset_config.container_path) + ds = f[inference_config.post_processed_dataset_config.secondary_dataset_name] - # prepare the instance segmentation zarr dataset to write to - f_segmentation = zarr.open( - inference_config.segmentation_dataset_config.container_path + # prepare the zarr dataset to write to + f_postprocessed = zarr.open( + inference_config.post_processed_dataset_config.container_path ) - ds_segmentation = f_segmentation.create_dataset( - inference_config.segmentation_dataset_config.dataset_name, + ds_postprocessed = f_postprocessed.create_dataset( + inference_config.post_processed_dataset_config.dataset_name, shape=( dataset_meta_data.num_samples, inference_config.num_bandwidths, @@ -32,164 +33,78 @@ def segment(inference_config: InferenceConfig) -> None: dtype=np.uint16, ) - ds_segmentation.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ - -dataset_meta_data.num_spatial_dims : - ] - ds_segmentation.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims - ds_segmentation.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims - - # prepare the binary segmentation zarr dataset to write to - ds_binary_segmentation = f_segmentation.create_dataset( - "binary_" + inference_config.segmentation_dataset_config.dataset_name, - shape=( - dataset_meta_data.num_samples, - 1, - *dataset_meta_data.spatial_array, - ), - dtype=np.uint16, - ) - - ds_binary_segmentation.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ + ds_postprocessed.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ -dataset_meta_data.num_spatial_dims : ] - ds_binary_segmentation.attrs["resolution"] = ( - 1, - ) * dataset_meta_data.num_spatial_dims - ds_binary_segmentation.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims - - # prepare the object centered embeddings zarr dataset to write to - ds_object_centered_embeddings = f_segmentation.create_dataset( - "centered_" - + inference_config.segmentation_dataset_config.secondary_dataset_name, - shape=( - dataset_meta_data.num_samples, - dataset_meta_data.num_spatial_dims + 1, - *dataset_meta_data.spatial_array, - ), - dtype=float, - ) - - ds_object_centered_embeddings.attrs["axis_names"] = ["s", "c"] + [ - "t", - "z", - "y", - "x", - ][-dataset_meta_data.num_spatial_dims :] - ds_object_centered_embeddings.attrs["resolution"] = ( - 1, - ) * dataset_meta_data.num_spatial_dims - ds_object_centered_embeddings.attrs["offset"] = ( - 0, - ) * dataset_meta_data.num_spatial_dims - - for sample in tqdm(range(dataset_meta_data.num_samples)): - embeddings = ds[sample] - embeddings_std = embeddings[-1, ...] - embeddings_mean = embeddings[ - np.newaxis, : dataset_meta_data.num_spatial_dims, ... - ].copy() - if inference_config.threshold is None: - threshold = threshold_otsu(embeddings_std) - else: - threshold = inference_config.threshold - - print(f"For sample {sample}, binary threshold {threshold} was used.") - binary_mask = embeddings_std < threshold - ds_binary_segmentation[sample, 0, ...] = binary_mask - - # find mean of embeddings - embeddings_centered = embeddings.copy() - embeddings_mean_masked = ( - binary_mask[np.newaxis, np.newaxis, ...] 
* embeddings_mean - ) - if embeddings_centered.shape[0] == 3: - c_x = embeddings_mean_masked[0, 0] - c_y = embeddings_mean_masked[0, 1] - c_x = c_x[c_x != 0].mean() - c_y = c_y[c_y != 0].mean() - embeddings_centered[0] -= c_x - embeddings_centered[1] -= c_y - elif embeddings_centered.shape[0] == 4: - c_x = embeddings_mean_masked[0, 0] - c_y = embeddings_mean_masked[0, 1] - c_z = embeddings_mean_masked[0, 2] - c_x = c_x[c_x != 0].mean() - c_y = c_y[c_y != 0].mean() - c_z = c_z[c_z != 0].mean() - embeddings_centered[0] -= c_x - embeddings_centered[1] -= c_y - embeddings_centered[2] -= c_z - ds_object_centered_embeddings[sample] = embeddings_centered - - embeddings_centered_mean = embeddings_centered[ - np.newaxis, : dataset_meta_data.num_spatial_dims - ] - embeddings_centered_std = embeddings_centered[-1] - - if inference_config.clustering == "meanshift": + ds_postprocessed.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims + ds_postprocessed.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims + + # remove halo + if inference_config.post_processing == "cell": + for sample in tqdm(range(dataset_meta_data.num_samples)): + # first instance label masks are expanded by `grow_distance` + # next, expanded instance label masks are shrunk by `shrink_distance` for bandwidth_factor in range(inference_config.num_bandwidths): - if inference_config.use_seeds: - offset_magnitude = np.linalg.norm(embeddings_centered[:-1], axis=0) - offset_magnitude_smooth = gaussian_filter(offset_magnitude, sigma=2) - coordinates = peak_local_max(-offset_magnitude_smooth) - seeds = np.flip(coordinates, 1) - segmentation = mean_shift_segmentation( - embeddings_centered_mean, - embeddings_centered_std, - bandwidth=inference_config.bandwidth / (2**bandwidth_factor), - min_size=inference_config.min_size, - reduction_probability=inference_config.reduction_probability, - threshold=threshold, - seeds=seeds, - ) - embeddings_centered_mean = embeddings_centered[ - np.newaxis, : dataset_meta_data.num_spatial_dims, ... - ].copy() - else: - segmentation = mean_shift_segmentation( - embeddings_mean, - embeddings_std, - bandwidth=inference_config.bandwidth / (2**bandwidth_factor), - min_size=inference_config.min_size, - reduction_probability=inference_config.reduction_probability, - threshold=threshold, - seeds=None, - ) - # Note that the line below is needed - # because the embeddings_mean is modified - # by mean_shift_segmentation - embeddings_mean = embeddings[ - np.newaxis, : dataset_meta_data.num_spatial_dims, ... - ].copy() - ds_segmentation[sample, bandwidth_factor, ...] = segmentation - elif inference_config.clustering == "greedy": - if dataset_meta_data.num_spatial_dims == 3: - cluster3d = Cluster3d( - width=embeddings.shape[-1], - height=embeddings.shape[-2], - depth=embeddings.shape[-3], - fg_mask=binary_mask, - device=inference_config.device, - ) - for bandwidth_factor in range(inference_config.num_bandwidths): - segmentation = cluster3d.cluster( - prediction=embeddings, - bandwidth=inference_config.bandwidth / (2**bandwidth_factor), - min_object_size=inference_config.min_size, - ) - ds_segmentation[sample, bandwidth_factor, ...] 
= segmentation - elif dataset_meta_data.num_spatial_dims == 2: - cluster2d = Cluster2d( - width=embeddings.shape[-1], - height=embeddings.shape[-2], - fg_mask=binary_mask, - device=inference_config.device, - ) - for bandwidth_factor in range(inference_config.num_bandwidths): - segmentation = cluster2d.cluster( - prediction=embeddings, - bandwidth=inference_config.bandwidth / (2**bandwidth_factor), - min_object_size=inference_config.min_size, - ) - - ds_segmentation[sample, bandwidth_factor, ...] = segmentation + segmentation = ds[sample, bandwidth_factor] + distance_foreground = dtedt(segmentation == 0) + expanded_mask = distance_foreground < inference_config.grow_distance + distance_background = dtedt(expanded_mask) + segmentation[distance_background < inference_config.shrink_distance] = 0 + ds_postprocessed[sample, bandwidth_factor, ...] = segmentation + elif inference_config.post_processing == "nucleus": + ds_raw = f[inference_config.dataset_config.dataset_name] + for sample in tqdm(range(dataset_meta_data.num_samples)): + for bandwidth_factor in range(inference_config.num_bandwidths): + segmentation = ds[sample, bandwidth_factor] + raw_image = ds_raw[sample, 0] + ids = np.unique(segmentation) + ids = ids[ids != 0] + for id_ in ids: + segmentation_id_mask = segmentation == id_ + if dataset_meta_data.num_spatial_dims == 2: + y, x = np.where(segmentation_id_mask) + y_min, y_max, x_min, x_max = ( + np.min(y), + np.max(y), + np.min(x), + np.max(x), + ) + elif dataset_meta_data.num_spatial_dims == 3: + z, y, x = np.where(segmentation_id_mask) + z_min, z_max, y_min, y_max, x_min, x_max = ( + np.min(z), + np.max(z), + np.min(y), + np.max(y), + np.min(x), + np.max(x), + ) + raw_image_masked = raw_image[segmentation_id_mask] + threshold = threshold_otsu(raw_image_masked) + mask = segmentation_id_mask & (raw_image > threshold) + + if dataset_meta_data.num_spatial_dims == 2: + mask_small = binary_fill_holes( + mask[y_min : y_max + 1, x_min : x_max + 1] + ) + mask[y_min : y_max + 1, x_min : x_max + 1] = mask_small + y, x = np.where(mask) + ds_postprocessed[sample, bandwidth_factor, y, x] = id_ + elif dataset_meta_data.num_spatial_dims == 3: + mask_small = binary_fill_holes( + mask[ + z_min : z_max + 1, y_min : y_max + 1, x_min : x_max + 1 + ] + ) + mask[ + z_min : z_max + 1, y_min : y_max + 1, x_min : x_max + 1 + ] = mask_small + z, y, x = np.where(mask) + ds_postprocessed[sample, bandwidth_factor, z, y, x] = id_ + + # size filter - remove small objects + for sample in tqdm(range(dataset_meta_data.num_samples)): + for bandwidth_factor in range(inference_config.num_bandwidths): + ds_postprocessed[sample, bandwidth_factor, ...] = size_filter( + ds_postprocessed[sample, bandwidth_factor], inference_config.min_size + ) From ee99d384aa64cb40e0867ac30ea4ebb060146412 Mon Sep 17 00:00:00 2001 From: lmanan Date: Sat, 2 Mar 2024 23:44:32 -0500 Subject: [PATCH 09/18] Rename segmentation_dataset_config to detection_dataset_config --- cellulus/configs/inference_config.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cellulus/configs/inference_config.py b/cellulus/configs/inference_config.py index eaa9d1e..2719378 100644 --- a/cellulus/configs/inference_config.py +++ b/cellulus/configs/inference_config.py @@ -22,13 +22,13 @@ class InferenceConfig: Configuration object produced by predict.py. - segmentation_dataset_config: + detection_dataset_config: - Configuration object produced by segment.py. + Configuration object produced by detect.py. 
- post_processed_dataset_config: + segmentation_dataset_config: - Configuration object produced by post_process.py. + Configuration object produced by segment.py. evaluation_dataset_config: @@ -124,11 +124,11 @@ class InferenceConfig: default=None, converter=to_config(DatasetConfig) ) - segmentation_dataset_config: DatasetConfig = attrs.field( + detection_dataset_config: DatasetConfig = attrs.field( default=None, converter=to_config(DatasetConfig) ) - post_processed_dataset_config: DatasetConfig = attrs.field( + segmentation_dataset_config: DatasetConfig = attrs.field( default=None, converter=to_config(DatasetConfig) ) From 48212248393277a40cde24088c53b35b9d2b7bd8 Mon Sep 17 00:00:00 2001 From: lmanan Date: Sat, 2 Mar 2024 23:44:58 -0500 Subject: [PATCH 10/18] Rename function to segment --- cellulus/segment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cellulus/segment.py b/cellulus/segment.py index f12630b..566c950 100644 --- a/cellulus/segment.py +++ b/cellulus/segment.py @@ -10,7 +10,7 @@ from cellulus.utils.misc import size_filter -def post_process(inference_config: InferenceConfig) -> None: +def segment(inference_config: InferenceConfig) -> None: # filter small objects, erosion, etc. dataset_config = inference_config.dataset_config From 6fdb417cdfe1b3dec5387451adca888c4b1979ca Mon Sep 17 00:00:00 2001 From: lmanan Date: Sat, 2 Mar 2024 23:45:17 -0500 Subject: [PATCH 11/18] Add detect.py --- cellulus/detect.py | 195 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 195 insertions(+) create mode 100644 cellulus/detect.py diff --git a/cellulus/detect.py b/cellulus/detect.py new file mode 100644 index 0000000..5208075 --- /dev/null +++ b/cellulus/detect.py @@ -0,0 +1,195 @@ +import numpy as np +import zarr +from scipy.ndimage import gaussian_filter +from skimage.feature import peak_local_max +from skimage.filters import threshold_otsu +from tqdm import tqdm + +from cellulus.configs.inference_config import InferenceConfig +from cellulus.datasets.meta_data import DatasetMetaData +from cellulus.utils.greedy_cluster import Cluster2d, Cluster3d +from cellulus.utils.mean_shift import mean_shift_segmentation + + +def detect(inference_config: InferenceConfig) -> None: + dataset_config = inference_config.dataset_config + dataset_meta_data = DatasetMetaData.from_dataset_config(dataset_config) + + f = zarr.open(inference_config.segmentation_dataset_config.container_path) + ds = f[inference_config.segmentation_dataset_config.secondary_dataset_name] + + # prepare the instance segmentation zarr dataset to write to + f_segmentation = zarr.open( + inference_config.segmentation_dataset_config.container_path + ) + ds_segmentation = f_segmentation.create_dataset( + inference_config.segmentation_dataset_config.dataset_name, + shape=( + dataset_meta_data.num_samples, + inference_config.num_bandwidths, + *dataset_meta_data.spatial_array, + ), + dtype=np.uint16, + ) + + ds_segmentation.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ + -dataset_meta_data.num_spatial_dims : + ] + ds_segmentation.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims + ds_segmentation.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims + + # prepare the binary segmentation zarr dataset to write to + ds_binary_segmentation = f_segmentation.create_dataset( + "binary_" + inference_config.segmentation_dataset_config.dataset_name, + shape=( + dataset_meta_data.num_samples, + 1, + *dataset_meta_data.spatial_array, + ), + dtype=np.uint16, + ) + + 
ds_binary_segmentation.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ + -dataset_meta_data.num_spatial_dims : + ] + ds_binary_segmentation.attrs["resolution"] = ( + 1, + ) * dataset_meta_data.num_spatial_dims + ds_binary_segmentation.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims + + # prepare the object centered embeddings zarr dataset to write to + ds_object_centered_embeddings = f_segmentation.create_dataset( + "centered_" + + inference_config.segmentation_dataset_config.secondary_dataset_name, + shape=( + dataset_meta_data.num_samples, + dataset_meta_data.num_spatial_dims + 1, + *dataset_meta_data.spatial_array, + ), + dtype=float, + ) + + ds_object_centered_embeddings.attrs["axis_names"] = ["s", "c"] + [ + "t", + "z", + "y", + "x", + ][-dataset_meta_data.num_spatial_dims :] + ds_object_centered_embeddings.attrs["resolution"] = ( + 1, + ) * dataset_meta_data.num_spatial_dims + ds_object_centered_embeddings.attrs["offset"] = ( + 0, + ) * dataset_meta_data.num_spatial_dims + + for sample in tqdm(range(dataset_meta_data.num_samples)): + embeddings = ds[sample] + embeddings_std = embeddings[-1, ...] + embeddings_mean = embeddings[ + np.newaxis, : dataset_meta_data.num_spatial_dims, ... + ].copy() + if inference_config.threshold is None: + threshold = threshold_otsu(embeddings_std) + else: + threshold = inference_config.threshold + + print(f"For sample {sample}, binary threshold {threshold} was used.") + binary_mask = embeddings_std < threshold + ds_binary_segmentation[sample, 0, ...] = binary_mask + + # find mean of embeddings + embeddings_centered = embeddings.copy() + embeddings_mean_masked = ( + binary_mask[np.newaxis, np.newaxis, ...] * embeddings_mean + ) + if embeddings_centered.shape[0] == 3: + c_x = embeddings_mean_masked[0, 0] + c_y = embeddings_mean_masked[0, 1] + c_x = c_x[c_x != 0].mean() + c_y = c_y[c_y != 0].mean() + embeddings_centered[0] -= c_x + embeddings_centered[1] -= c_y + elif embeddings_centered.shape[0] == 4: + c_x = embeddings_mean_masked[0, 0] + c_y = embeddings_mean_masked[0, 1] + c_z = embeddings_mean_masked[0, 2] + c_x = c_x[c_x != 0].mean() + c_y = c_y[c_y != 0].mean() + c_z = c_z[c_z != 0].mean() + embeddings_centered[0] -= c_x + embeddings_centered[1] -= c_y + embeddings_centered[2] -= c_z + ds_object_centered_embeddings[sample] = embeddings_centered + + embeddings_centered_mean = embeddings_centered[ + np.newaxis, : dataset_meta_data.num_spatial_dims + ] + embeddings_centered_std = embeddings_centered[-1] + + if inference_config.clustering == "meanshift": + for bandwidth_factor in range(inference_config.num_bandwidths): + if inference_config.use_seeds: + offset_magnitude = np.linalg.norm(embeddings_centered[:-1], axis=0) + offset_magnitude_smooth = gaussian_filter(offset_magnitude, sigma=2) + coordinates = peak_local_max(-offset_magnitude_smooth) + seeds = np.flip(coordinates, 1) + segmentation = mean_shift_segmentation( + embeddings_centered_mean, + embeddings_centered_std, + bandwidth=inference_config.bandwidth / (2**bandwidth_factor), + min_size=inference_config.min_size, + reduction_probability=inference_config.reduction_probability, + threshold=threshold, + seeds=seeds, + ) + embeddings_centered_mean = embeddings_centered[ + np.newaxis, : dataset_meta_data.num_spatial_dims, ... 
+ ].copy() + else: + segmentation = mean_shift_segmentation( + embeddings_mean, + embeddings_std, + bandwidth=inference_config.bandwidth / (2**bandwidth_factor), + min_size=inference_config.min_size, + reduction_probability=inference_config.reduction_probability, + threshold=threshold, + seeds=None, + ) + # Note that the line below is needed + # because the embeddings_mean is modified + # by mean_shift_segmentation + embeddings_mean = embeddings[ + np.newaxis, : dataset_meta_data.num_spatial_dims, ... + ].copy() + ds_segmentation[sample, bandwidth_factor, ...] = segmentation + elif inference_config.clustering == "greedy": + if dataset_meta_data.num_spatial_dims == 3: + cluster3d = Cluster3d( + width=embeddings.shape[-1], + height=embeddings.shape[-2], + depth=embeddings.shape[-3], + fg_mask=binary_mask, + device=inference_config.device, + ) + for bandwidth_factor in range(inference_config.num_bandwidths): + segmentation = cluster3d.cluster( + prediction=embeddings, + bandwidth=inference_config.bandwidth / (2**bandwidth_factor), + min_object_size=inference_config.min_size, + ) + ds_segmentation[sample, bandwidth_factor, ...] = segmentation + elif dataset_meta_data.num_spatial_dims == 2: + cluster2d = Cluster2d( + width=embeddings.shape[-1], + height=embeddings.shape[-2], + fg_mask=binary_mask, + device=inference_config.device, + ) + for bandwidth_factor in range(inference_config.num_bandwidths): + segmentation = cluster2d.cluster( + prediction=embeddings, + bandwidth=inference_config.bandwidth / (2**bandwidth_factor), + min_object_size=inference_config.min_size, + ) + + ds_segmentation[sample, bandwidth_factor, ...] = segmentation From e99c98aa718d91e19e6fd4644df8a05ab161ead9 Mon Sep 17 00:00:00 2001 From: lmanan Date: Sun, 3 Mar 2024 00:42:41 -0500 Subject: [PATCH 12/18] Rename binary segmentation zarr dataset and centered embeddings zarr dataset --- cellulus/detect.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/cellulus/detect.py b/cellulus/detect.py index 5208075..6576c4c 100644 --- a/cellulus/detect.py +++ b/cellulus/detect.py @@ -15,15 +15,15 @@ def detect(inference_config: InferenceConfig) -> None: dataset_config = inference_config.dataset_config dataset_meta_data = DatasetMetaData.from_dataset_config(dataset_config) - f = zarr.open(inference_config.segmentation_dataset_config.container_path) - ds = f[inference_config.segmentation_dataset_config.secondary_dataset_name] + f = zarr.open(inference_config.detection_dataset_config.container_path) + ds = f[inference_config.detection_dataset_config.secondary_dataset_name] - # prepare the instance segmentation zarr dataset to write to + # prepare the zarr dataset to write to f_segmentation = zarr.open( - inference_config.segmentation_dataset_config.container_path + inference_config.detection_dataset_config.container_path ) ds_segmentation = f_segmentation.create_dataset( - inference_config.segmentation_dataset_config.dataset_name, + inference_config.detection_dataset_config.dataset_name, shape=( dataset_meta_data.num_samples, inference_config.num_bandwidths, @@ -40,7 +40,7 @@ def detect(inference_config: InferenceConfig) -> None: # prepare the binary segmentation zarr dataset to write to ds_binary_segmentation = f_segmentation.create_dataset( - "binary_" + inference_config.segmentation_dataset_config.dataset_name, + "binary-segmentation", shape=( dataset_meta_data.num_samples, 1, @@ -59,8 +59,7 @@ def detect(inference_config: InferenceConfig) -> None: # prepare the object centered embeddings zarr dataset 
to write to ds_object_centered_embeddings = f_segmentation.create_dataset( - "centered_" - + inference_config.segmentation_dataset_config.secondary_dataset_name, + "centered-embeddings", shape=( dataset_meta_data.num_samples, dataset_meta_data.num_spatial_dims + 1, From 28a14837da678918380ef57af06a65b517e10ca6 Mon Sep 17 00:00:00 2001 From: lmanan Date: Sun, 3 Mar 2024 00:44:34 -0500 Subject: [PATCH 13/18] Update segment.py and infer.py --- cellulus/infer.py | 4 ++-- cellulus/segment.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cellulus/infer.py b/cellulus/infer.py index 00f320d..e15df55 100644 --- a/cellulus/infer.py +++ b/cellulus/infer.py @@ -70,10 +70,10 @@ def infer(experiment_config): if inference_config.prediction_dataset_config is not None: predict(model, inference_config, normalization_factor) # ...turn them into a detection ... - if inference_config.segmentation_dataset_config is not None: + if inference_config.detection_dataset_config is not None: detect(inference_config) # ...and post-process the detection to obtain an instance segmentation - if inference_config.post_processed_dataset_config is not None: + if inference_config.segmentation_dataset_config is not None: segment(inference_config) # ...and evaluate if ground-truth exists if inference_config.evaluation_dataset_config is not None: diff --git a/cellulus/segment.py b/cellulus/segment.py index 566c950..79dfb16 100644 --- a/cellulus/segment.py +++ b/cellulus/segment.py @@ -16,15 +16,15 @@ def segment(inference_config: InferenceConfig) -> None: dataset_config = inference_config.dataset_config dataset_meta_data = DatasetMetaData.from_dataset_config(dataset_config) - f = zarr.open(inference_config.post_processed_dataset_config.container_path) - ds = f[inference_config.post_processed_dataset_config.secondary_dataset_name] + f = zarr.open(inference_config.segmentation_dataset_config.container_path) + ds = f[inference_config.segmentation_dataset_config.secondary_dataset_name] # prepare the zarr dataset to write to f_postprocessed = zarr.open( - inference_config.post_processed_dataset_config.container_path + inference_config.segmentation_dataset_config.container_path ) ds_postprocessed = f_postprocessed.create_dataset( - inference_config.post_processed_dataset_config.dataset_name, + inference_config.segmentation_dataset_config.dataset_name, shape=( dataset_meta_data.num_samples, inference_config.num_bandwidths, From a859c158d7691951f6103363a42a8cb8d383fac2 Mon Sep 17 00:00:00 2001 From: lmanan Date: Sun, 3 Mar 2024 01:05:53 -0500 Subject: [PATCH 14/18] Rename labels for figures --- docs/examples/2d/03-infer.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/docs/examples/2d/03-infer.py b/docs/examples/2d/03-infer.py index 1e9d20e..bfc9b70 100644 --- a/docs/examples/2d/03-infer.py +++ b/docs/examples/2d/03-infer.py @@ -38,15 +38,15 @@ prediction_dataset_config = DatasetConfig( container_path=name + ".zarr", dataset_name="embeddings" ) -segmentation_dataset_config = DatasetConfig( +detection_dataset_config = DatasetConfig( container_path=name + ".zarr", - dataset_name="segmentation", + dataset_name="detection", secondary_dataset_name="embeddings", ) -post_processed_dataset_config = DatasetConfig( +segmentation_dataset_config = DatasetConfig( container_path=name + ".zarr", - dataset_name="post_processed_segmentation", - secondary_dataset_name="segmentation", + dataset_name="segmentation", + secondary_dataset_name="detection", ) # ## Specify config 
values for the model @@ -86,7 +86,7 @@ # We initialize the `inference_config` which contains our # `embeddings_dataset_config`, `segmentation_dataset_config` and -# `post_processed_dataset_config`. +# `post_processed_dataset_config`.
# We set post_processing to one of `cell` or `nucleus`, depending on if we # would like the cell membrane to be segmented or the nucleus. @@ -95,8 +95,8 @@ inference_config = InferenceConfig( dataset_config=asdict(dataset_config), prediction_dataset_config=asdict(prediction_dataset_config), + detection_dataset_config=asdict(detection_dataset_config), segmentation_dataset_config=asdict(segmentation_dataset_config), - post_processed_dataset_config=asdict(post_processed_dataset_config), post_processing=post_processing, device=device, ) @@ -109,7 +109,7 @@ experiment_config = ExperimentConfig( inference_config=asdict(inference_config), model_config=asdict(model_config), - normalization_factor=1.0, + normalization_factor=1.0, # since the data was already normalized. ) # Now we are ready to start the inference!!
@@ -139,7 +139,7 @@ f = zarr.open(name + ".zarr") ds = f["train/raw"] -ds2 = f["centered_embeddings"] +ds2 = f["centered-embeddings"] image = ds[index, 0] embedding = ds2[index] @@ -163,8 +163,8 @@ # + f = zarr.open(name + ".zarr") ds = f["train/raw"] -ds2 = f["segmentation"] -ds3 = f["post_processed_segmentation"] +ds2 = f["detection"] +ds3 = f["segmentation"] visualize_2d( image, @@ -172,9 +172,12 @@ bottom_left=ds2[index, 0], bottom_right=ds3[index, 0], top_right_label="THRESHOLDED F.G.", - bottom_left_label="SEGMENTATION", - bottom_right_label="POSTPROCESSED", + bottom_left_label="DETECTION", + bottom_right_label="SEGMENTATION", top_right_cmap="gray", bottom_left_cmap=new_cmp, bottom_right_cmap=new_cmp, ) +# - + + From 47c4c061ae229c4a2cab2b03057c15687d9bd8cc Mon Sep 17 00:00:00 2001 From: lmanan Date: Sun, 3 Mar 2024 01:06:47 -0500 Subject: [PATCH 15/18] Add example python scripts for 3d example --- docs/examples/3d/01-data.py | 61 ++++++++++++ docs/examples/3d/02-train.py | 80 +++++++++++++++ docs/examples/3d/03-infer.py | 186 +++++++++++++++++++++++++++++++++++ 3 files changed, 327 insertions(+) create mode 100644 docs/examples/3d/01-data.py create mode 100644 docs/examples/3d/02-train.py create mode 100644 docs/examples/3d/03-infer.py diff --git a/docs/examples/3d/01-data.py b/docs/examples/3d/01-data.py new file mode 100644 index 0000000..76ace06 --- /dev/null +++ b/docs/examples/3d/01-data.py @@ -0,0 +1,61 @@ +# # Download Data + +# In this notebook, we will download data and convert it to a zarr dataset.
# For demonstration, we will use one image from the `Platynereis-Nuclei-CBG` dataset # provided with [this](https://www.sciencedirect.com/science/article/pii/S1361841522001700) # publication. # Firstly, the `tif` raw images are downloaded to a directory indicated by `data_dir`. from pathlib import Path # + import numpy as np import tifffile import zarr from cellulus.utils.misc import extract_data from csbdeep.utils import normalize from skimage.transform import rescale from tqdm import tqdm # + name = "3d-data-demo" data_dir = "./data" extract_data( zip_url="https://github.com/funkelab/cellulus/releases/download/v0.0.1-tag/3d-data-demo.zip", data_dir=data_dir, project_name=name, ) # - # Currently, `cellulus` expects that the images are isotropic (i.e. the # voxel size along the z dimension (which is usually undersampled) is the same # as the voxel size along the x and y dimensions).
# This dataset has a step size of $2.031 \mu m$ in z and $0.406 \mu m$ along # the x and y dimensions; thus, the upsampling factor (which we refer to as # `anisotropy`) equals $2.031/0.406$.
+# These raw images are upsampled, intensity-normalized and appended in a list. +# Here, we use the percentile normalization technique. + +anisotropy = 2.031 / 0.406 + +container_path = zarr.open(name + ".zarr") +subsets = ["train", "test"] +for subset in subsets: + dataset_name = subset + "/raw" + image_filenames = sorted((Path(data_dir) / name / subset).glob("*.tif")) + print(f"Number of raw images in {subset} directory is {len(image_filenames)}") + image_list = [] + + for i in tqdm(range(len(image_filenames))): + im = tifffile.imread(image_filenames[i]).astype(np.float32) + im_normalized = normalize(im, 1, 99.8, axis=(0, 1, 2)) + im_rescaled = rescale(im_normalized, (anisotropy, 1.0, 1.0)) + image_list.append(im_rescaled[np.newaxis, ...]) + + image_list = np.asarray(image_list) + container_path[dataset_name] = image_list + container_path[dataset_name].attrs["resolution"] = (1, 1, 1) + container_path[dataset_name].attrs["axis_names"] = ("s", "c", "z", "y", "x") diff --git a/docs/examples/3d/02-train.py b/docs/examples/3d/02-train.py new file mode 100644 index 0000000..017525c --- /dev/null +++ b/docs/examples/3d/02-train.py @@ -0,0 +1,80 @@ +# # Train Model + +# In this notebook, we will train a `cellulus` model. + +from attrs import asdict +from cellulus.configs.dataset_config import DatasetConfig +from cellulus.configs.experiment_config import ExperimentConfig +from cellulus.configs.model_config import ModelConfig +from cellulus.configs.train_config import TrainConfig + +# ## Specify config values for dataset + +# In the next cell, we specify the name of the zarr container and the dataset +# within it from which data would be read. + +name = "3d-data-demo" +dataset_name = "train/raw" + +train_data_config = DatasetConfig( + container_path=name + ".zarr", dataset_name=dataset_name +) + +# ## Specify config values for model + +# In the next cell, we specify the number of feature maps (`num_fmaps`) in the +# first layer in our model.
# Additionally, we specify `fmap_inc_factor` and `downsampling_factors`, which # indicate by how much the number of feature maps increases between adjacent # layers, and by how much the spatial extent of the image gets downsampled between # adjacent layers, respectively. num_fmaps = 24 fmap_inc_factor = 3 downsampling_factors = [ [2, 2, 2], ] model_config = ModelConfig( num_fmaps=num_fmaps, fmap_inc_factor=fmap_inc_factor, downsampling_factors=downsampling_factors, ) # ## Specify config values for the training process # Then, we specify training-specific parameters such as the `device`, which # indicates the actual device to run the training on. # We also specify the `crop_size`. Mini-batches of crops are shown to the model # during training. #
The device could be set equal to `cuda:n` (where `n` is the index of # the GPU, e.g. `cuda:0`) or `cpu`.
+# We set the `max_iterations` equal to `5000` for demonstration purposes. + +device = "cuda:0" +max_iterations = 5000 +crop_size = [80, 80, 80] + +train_config = TrainConfig( + train_data_config=asdict(train_data_config), + device=device, + max_iterations=max_iterations, + crop_size=crop_size, +) + +# Next, we initialize the experiment config which puts together the config +# objects (`train_config` and `model_config`) which we defined above. + +experiment_config = ExperimentConfig( + train_config=asdict(train_config), + model_config=asdict(model_config), + normalization_factor=1.0, # since we already normalized in previous notebook +) + +# Now we can begin the training!
+# Uncomment the next two lines to train the model. + +# + +# from cellulus.train import train +# train(experiment_config) +# - diff --git a/docs/examples/3d/03-infer.py b/docs/examples/3d/03-infer.py new file mode 100644 index 0000000..236934c --- /dev/null +++ b/docs/examples/3d/03-infer.py @@ -0,0 +1,186 @@ +# # Infer using Trained Model + +# In this notebook, we will use the `cellulus` model trained in the previous +# step to obtain instance segmentations. + +import urllib +import zipfile + +import numpy as np +import skimage +import torch +import zarr +from attrs import asdict +from cellulus.configs.dataset_config import DatasetConfig +from cellulus.configs.experiment_config import ExperimentConfig +from cellulus.configs.inference_config import InferenceConfig +from cellulus.configs.model_config import ModelConfig +from cellulus.infer import infer +from cellulus.utils.misc import visualize_2d +from matplotlib.colors import ListedColormap + +# ## Specify config values for datasets + +# We again specify `name` of the zarr container, and `dataset_name` which +# identifies the path to the raw image data, which needs to be segmented. + +name = "3d-data-demo" +dataset_name = "test/raw" + +# We initialize the `dataset_config` which relates to the raw image data, +# `prediction_dataset_config` which relates to the per-pixel embeddings and the +# uncertainty, the `segmentation_dataset_config` which relates to the segmentations +# post the mean-shift clustering and the `post_processed_config` which relates +# to the segmentations after some post-processing. + +dataset_config = DatasetConfig(container_path=name + ".zarr", dataset_name=dataset_name) +prediction_dataset_config = DatasetConfig( + container_path=name + ".zarr", dataset_name="embeddings" +) +detection_dataset_config = DatasetConfig( + container_path=name + ".zarr", + dataset_name="detection", + secondary_dataset_name="embeddings", +) +segmentation_dataset_config = DatasetConfig( + container_path=name + ".zarr", + dataset_name="segmentation", + secondary_dataset_name="detection", +) + +# ## Specify config values for the model + +# We must also specify the `num_fmaps`, `fmap_inc_factor` (use same values as +# in the training step) and set `checkpoint` equal to `models/best_loss.pth` +# (best in terms of the lowest loss obtained). + +# Here, we download a pretrained model trained by us for `5e3` iterations.
# But please comment out the next cell if you wish to use your own trained model, which # should be available in the `models` directory. torch.hub.download_url_to_file( url="https://github.com/funkelab/cellulus/releases/download/v0.0.1-tag/2d-demo-model.zip", dst="pretrained_model", progress=True, ) with zipfile.ZipFile("pretrained_model", "r") as zip_ref: zip_ref.extractall("") num_fmaps = 24 fmap_inc_factor = 3 downsampling_factors = [ [2, 2, 2], ] checkpoint = "models/best_loss.pth" model_config = ModelConfig( num_fmaps=num_fmaps, fmap_inc_factor=fmap_inc_factor, downsampling_factors=downsampling_factors, checkpoint=checkpoint, ) # ## Initialize `inference_config` # Then, we specify inference-specific parameters such as the `device`, which # indicates the actual device to run the inference on. #
The device could be set equal to `cuda:n` (where `n` is the index of the # GPU, e.g. `cuda:0`), `cpu` or `mps`. device = "cuda:0" # We initialize the `inference_config` which contains our `prediction_dataset_config`, # `detection_dataset_config` and `segmentation_dataset_config` defined above. inference_config = InferenceConfig( dataset_config=asdict(dataset_config), # prediction_dataset_config=asdict(prediction_dataset_config), # detection_dataset_config=asdict(detection_dataset_config), segmentation_dataset_config=asdict(segmentation_dataset_config), crop_size=[120, 120, 120], post_processing="nucleus", device=device, use_seeds=True, ) # ## Initialize `experiment_config` # Lastly, we initialize the `experiment_config` which contains the `inference_config` # and `model_config` initialized above. experiment_config = ExperimentConfig( inference_config=asdict(inference_config), model_config=asdict(model_config), normalization_factor=1.0, # since the test image is already normalized ) # Now we are ready to start the inference!!
+# (To see the output of the cell below, remove the first line `io.capture_output()`). + +# with io.capture_output() as captured: +infer(experiment_config) + +# ## Inspect predictions + +# Let's look at some of the predicted embeddings.
# We will first load a glasbey-like color map to show individual cells # with a unique color. urllib.request.urlretrieve( "https://github.com/funkelab/cellulus/releases/download/v0.0.1-tag/cmap_60.npy", "cmap_60.npy", ) new_cmp = ListedColormap(np.load("cmap_60.npy")) # Change the value of `index` below to look at the raw image (left), # x-offset (bottom-left), y-offset (bottom-right) and uncertainty of the # embedding (top-right). # + index = 0 f = zarr.open(name + ".zarr") ds = f["test/raw"] ds2 = f["centered-embeddings"] slice = ds.shape[2] // 2 image = ds[index, 0, slice] embedding = ds2[index, :, slice] visualize_2d( image, top_right=embedding[-1], bottom_left=embedding[0], bottom_right=embedding[1], top_right_label="UNCERTAINTY", bottom_left_label="OFFSET_X", bottom_right_label="OFFSET_Y", ) # - # As you can see the magnitude of the uncertainty of the embedding (top-right) # is low for most of the foreground cells.
This enables extraction +# of the foreground, which is eventually clustered into individual instances.
+# See bottom right figure for the final result. + +# + +f = zarr.open(name + ".zarr") +ds = f["test/raw"] +ds2 = f["detection"] +ds3 = f["segmentation"] + +visualize_2d( + image, + top_right=embedding[-1] < skimage.filters.threshold_otsu(embedding[-1]), + bottom_left=ds2[index, index, slice], + bottom_right=ds3[index, index, slice], + top_right_label="THRESHOLDED F.G.", + bottom_left_label="DETECTION", + bottom_right_label="SEGMENTATION", + top_right_cmap="gray", + bottom_left_cmap=new_cmp, + bottom_right_cmap=new_cmp, +) +# - From cb3bbff49a86cebaa7f846f0bebc2e45469bd0d5 Mon Sep 17 00:00:00 2001 From: lmanan Date: Sun, 3 Mar 2024 01:14:11 -0500 Subject: [PATCH 16/18] Rename f and ds variable names --- cellulus/detect.py | 22 ++++++++++------------ cellulus/segment.py | 22 ++++++++++------------ 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/cellulus/detect.py b/cellulus/detect.py index 6576c4c..6bcac19 100644 --- a/cellulus/detect.py +++ b/cellulus/detect.py @@ -19,10 +19,8 @@ def detect(inference_config: InferenceConfig) -> None: ds = f[inference_config.detection_dataset_config.secondary_dataset_name] # prepare the zarr dataset to write to - f_segmentation = zarr.open( - inference_config.detection_dataset_config.container_path - ) - ds_segmentation = f_segmentation.create_dataset( + f_detection = zarr.open(inference_config.detection_dataset_config.container_path) + ds_detection = f_detection.create_dataset( inference_config.detection_dataset_config.dataset_name, shape=( dataset_meta_data.num_samples, @@ -32,14 +30,14 @@ def detect(inference_config: InferenceConfig) -> None: dtype=np.uint16, ) - ds_segmentation.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ + ds_detection.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ -dataset_meta_data.num_spatial_dims : ] - ds_segmentation.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims - ds_segmentation.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims + ds_detection.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims + ds_detection.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims # prepare the binary segmentation zarr dataset to write to - ds_binary_segmentation = f_segmentation.create_dataset( + ds_binary_segmentation = f_detection.create_dataset( "binary-segmentation", shape=( dataset_meta_data.num_samples, @@ -58,7 +56,7 @@ def detect(inference_config: InferenceConfig) -> None: ds_binary_segmentation.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims # prepare the object centered embeddings zarr dataset to write to - ds_object_centered_embeddings = f_segmentation.create_dataset( + ds_object_centered_embeddings = f_detection.create_dataset( "centered-embeddings", shape=( dataset_meta_data.num_samples, @@ -160,7 +158,7 @@ def detect(inference_config: InferenceConfig) -> None: embeddings_mean = embeddings[ np.newaxis, : dataset_meta_data.num_spatial_dims, ... ].copy() - ds_segmentation[sample, bandwidth_factor, ...] = segmentation + ds_detection[sample, bandwidth_factor, ...] = segmentation elif inference_config.clustering == "greedy": if dataset_meta_data.num_spatial_dims == 3: cluster3d = Cluster3d( @@ -176,7 +174,7 @@ def detect(inference_config: InferenceConfig) -> None: bandwidth=inference_config.bandwidth / (2**bandwidth_factor), min_object_size=inference_config.min_size, ) - ds_segmentation[sample, bandwidth_factor, ...] = segmentation + ds_detection[sample, bandwidth_factor, ...] 
= segmentation elif dataset_meta_data.num_spatial_dims == 2: cluster2d = Cluster2d( width=embeddings.shape[-1], @@ -191,4 +189,4 @@ def detect(inference_config: InferenceConfig) -> None: min_object_size=inference_config.min_size, ) - ds_segmentation[sample, bandwidth_factor, ...] = segmentation + ds_detection[sample, bandwidth_factor, ...] = segmentation diff --git a/cellulus/segment.py b/cellulus/segment.py index 79dfb16..a4fcaf0 100644 --- a/cellulus/segment.py +++ b/cellulus/segment.py @@ -20,10 +20,8 @@ def segment(inference_config: InferenceConfig) -> None: ds = f[inference_config.segmentation_dataset_config.secondary_dataset_name] # prepare the zarr dataset to write to - f_postprocessed = zarr.open( - inference_config.segmentation_dataset_config.container_path - ) - ds_postprocessed = f_postprocessed.create_dataset( + f_segmented = zarr.open(inference_config.segmentation_dataset_config.container_path) + ds_segmented = f_segmented.create_dataset( inference_config.segmentation_dataset_config.dataset_name, shape=( dataset_meta_data.num_samples, @@ -33,11 +31,11 @@ def segment(inference_config: InferenceConfig) -> None: dtype=np.uint16, ) - ds_postprocessed.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ + ds_segmented.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ -dataset_meta_data.num_spatial_dims : ] - ds_postprocessed.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims - ds_postprocessed.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims + ds_segmented.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims + ds_segmented.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims # remove halo if inference_config.post_processing == "cell": @@ -50,7 +48,7 @@ def segment(inference_config: InferenceConfig) -> None: expanded_mask = distance_foreground < inference_config.grow_distance distance_background = dtedt(expanded_mask) segmentation[distance_background < inference_config.shrink_distance] = 0 - ds_postprocessed[sample, bandwidth_factor, ...] = segmentation + ds_segmented[sample, bandwidth_factor, ...] = segmentation elif inference_config.post_processing == "nucleus": ds_raw = f[inference_config.dataset_config.dataset_name] for sample in tqdm(range(dataset_meta_data.num_samples)): @@ -89,7 +87,7 @@ def segment(inference_config: InferenceConfig) -> None: ) mask[y_min : y_max + 1, x_min : x_max + 1] = mask_small y, x = np.where(mask) - ds_postprocessed[sample, bandwidth_factor, y, x] = id_ + ds_segmented[sample, bandwidth_factor, y, x] = id_ elif dataset_meta_data.num_spatial_dims == 3: mask_small = binary_fill_holes( mask[ @@ -100,11 +98,11 @@ def segment(inference_config: InferenceConfig) -> None: z_min : z_max + 1, y_min : y_max + 1, x_min : x_max + 1 ] = mask_small z, y, x = np.where(mask) - ds_postprocessed[sample, bandwidth_factor, z, y, x] = id_ + ds_segmented[sample, bandwidth_factor, z, y, x] = id_ # size filter - remove small objects for sample in tqdm(range(dataset_meta_data.num_samples)): for bandwidth_factor in range(inference_config.num_bandwidths): - ds_postprocessed[sample, bandwidth_factor, ...] = size_filter( - ds_postprocessed[sample, bandwidth_factor], inference_config.min_size + ds_segmented[sample, bandwidth_factor, ...] 
= size_filter( + ds_segmented[sample, bandwidth_factor], inference_config.min_size ) From 10a70d0de456c6f4d2d9fd006b687acf93429054 Mon Sep 17 00:00:00 2001 From: lmanan Date: Sun, 3 Mar 2024 01:15:26 -0500 Subject: [PATCH 17/18] Refactor example scripts --- docs/examples/2d/03-infer.py | 4 +--- docs/examples/3d/01-data.py | 3 ++- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/examples/2d/03-infer.py b/docs/examples/2d/03-infer.py index bfc9b70..1128840 100644 --- a/docs/examples/2d/03-infer.py +++ b/docs/examples/2d/03-infer.py @@ -109,7 +109,7 @@ experiment_config = ExperimentConfig( inference_config=asdict(inference_config), model_config=asdict(model_config), - normalization_factor=1.0, # since the data was already normalized. + normalization_factor=1.0, # since the data was already normalized. ) # Now we are ready to start the inference!!
@@ -179,5 +179,3 @@ bottom_right_cmap=new_cmp, ) # - - - diff --git a/docs/examples/3d/01-data.py b/docs/examples/3d/01-data.py index 76ace06..2a3825b 100644 --- a/docs/examples/3d/01-data.py +++ b/docs/examples/3d/01-data.py @@ -3,7 +3,8 @@ # In this notebook, we will download data and convert it to a zarr dataset.
# For demonstration, we will use one image from the `Platynereis-Nuclei-CBG` dataset -# provided with [this](https://www.sciencedirect.com/science/article/pii/S1361841522001700) +# provided +# with [this](https://www.sciencedirect.com/science/article/pii/S1361841522001700) # publication. # Firstly, the `tif` raw images are downloaded to a directory indicated by `data_dir`. From b2f891f520cbbc95f5e9b2b3cbe8616626182a2c Mon Sep 17 00:00:00 2001 From: lmanan Date: Sun, 3 Mar 2024 01:20:14 -0500 Subject: [PATCH 18/18] Update mkdocs.yml --- mkdocs.yml | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mkdocs.yml b/mkdocs.yml index 607af7f..5ed2b1e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -61,10 +61,15 @@ nav: - Home: index.md - Getting Started: - Installation: getting-started.md - - 2D Example: - - 01: examples/2d/01-data.py - - 02: examples/2d/02-train.py - - 03: examples/2d/03-infer.py + - Examples: + - 2D: + - 01: examples/2d/01-data.py + - 02: examples/2d/02-train.py + - 03: examples/2d/03-infer.py + - 3D: + - 01: examples/3d/01-data.py + - 02: examples/3d/02-train.py + - 03: examples/3d/03-infer.py - API Reference: - Configs: - DatasetConfig: api/dataset_config.md
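
The detect() and segment() changes in PATCH 16 write label datasets shaped (num_samples, num_bandwidths, *spatial_dims) with axis_names, resolution and offset attributes. Below is a minimal sketch, not part of the patches, of how such outputs could be inspected after inference; the container path "inference.zarr" is a hypothetical stand-in for container_path, and the dataset names are taken from the example script and the dataset names created in detect.py/segment.py.

import zarr

f = zarr.open("inference.zarr", mode="r")  # hypothetical container_path
for name in ["detection", "binary-segmentation", "centered-embeddings", "segmentation"]:
    if name in f:
        ds = f[name]
        # label datasets are (num_samples, num_bandwidths, *spatial_dims), dtype uint16
        print(name, ds.shape, ds.dtype, dict(ds.attrs))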
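
Similarly, for the data-preparation step described in docs/examples/3d/01-data.py (download the tif raw images into data_dir, then convert them to a zarr dataset), a rough sketch under assumed file and dataset names could look like the following; the actual example script defines its own paths and dataset names, and "train/raw" and "platynereis.zarr" are illustrative assumptions only.

import numpy as np
import tifffile
import zarr

image = tifffile.imread("data_dir/sample.tif")  # hypothetical file name; e.g. a (z, y, x) volume
f = zarr.open("platynereis.zarr", mode="w")     # hypothetical container name
# add sample and channel axes so the layout matches ["s", "c", "z", "y", "x"]
ds = f.create_dataset("train/raw", data=image[np.newaxis, np.newaxis].astype(np.float32))
ds.attrs["axis_names"] = ["s", "c", "z", "y", "x"]
ds.attrs["resolution"] = (1, 1, 1)
ds.attrs["offset"] = (0, 0, 0)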