diff --git a/cellulus/configs/inference_config.py b/cellulus/configs/inference_config.py index 707333f..2719378 100644 --- a/cellulus/configs/inference_config.py +++ b/cellulus/configs/inference_config.py @@ -22,13 +22,13 @@ class InferenceConfig: Configuration object produced by predict.py. - segmentation_dataset_config: + detection_dataset_config: - Configuration object produced by segment.py. + Configuration object produced by detect.py. - post_processed_dataset_config: + segmentation_dataset_config: - Configuration object produced by post_process.py. + Configuration object produced by segment.py. evaluation_dataset_config: @@ -72,6 +72,12 @@ class InferenceConfig: How to cluster the embeddings? Can be one of 'meanshift' or 'greedy'. + use_seeds (default = False): + + If set to True, the local optima of the distance map from the + predicted object centers are used as seeds. + Else, seeds are determined by sklearn.cluster.MeanShift. + num_bandwidths (default = 1): Number of bandwidths to obtain segmentations for. 
@@ -118,11 +124,11 @@ class InferenceConfig: default=None, converter=to_config(DatasetConfig) ) - segmentation_dataset_config: DatasetConfig = attrs.field( + detection_dataset_config: DatasetConfig = attrs.field( default=None, converter=to_config(DatasetConfig) ) - post_processed_dataset_config: DatasetConfig = attrs.field( + segmentation_dataset_config: DatasetConfig = attrs.field( default=None, converter=to_config(DatasetConfig) ) @@ -139,6 +145,7 @@ class InferenceConfig: clustering = attrs.field( default="meanshift", validator=in_(["meanshift", "greedy"]) ) + use_seeds = attrs.field(default=False, validator=instance_of(bool)) bandwidth = attrs.field( default=None, validator=attrs.validators.optional(instance_of(float)) ) diff --git a/cellulus/datasets/zarr_dataset.py b/cellulus/datasets/zarr_dataset.py index 1a6ad78..a864d89 100644 --- a/cellulus/datasets/zarr_dataset.py +++ b/cellulus/datasets/zarr_dataset.py @@ -137,17 +137,24 @@ def __yield_sample(self): with gp.build(self.pipeline): while True: + array_is_zero = True # request one sample, all channels, plus crop dimensions - request = gp.BatchRequest() - request[self.raw] = gp.ArraySpec( - roi=gp.Roi( - (0,) * self.num_dims, (1, self.num_channels, *self.crop_size) + while array_is_zero: + request = gp.BatchRequest() + request[self.raw] = gp.ArraySpec( + roi=gp.Roi( + (0,) * self.num_dims, + (1, self.num_channels, *self.crop_size), + ) ) - ) - sample = self.pipeline.request_batch(request) - sample_data = sample[self.raw].data[0] - anchor_samples, reference_samples = self.sample_coordinates() + sample = self.pipeline.request_batch(request) + sample_data = sample[self.raw].data[0] + if np.max(sample_data) <= 0.0: + pass + else: + array_is_zero = False + anchor_samples, reference_samples = self.sample_coordinates() yield sample_data, anchor_samples, reference_samples def __read_meta_data(self): diff --git a/cellulus/detect.py b/cellulus/detect.py new file mode 100644 index 0000000..6bcac19 --- /dev/null +++ 
b/cellulus/detect.py @@ -0,0 +1,192 @@ +import numpy as np +import zarr +from scipy.ndimage import gaussian_filter +from skimage.feature import peak_local_max +from skimage.filters import threshold_otsu +from tqdm import tqdm + +from cellulus.configs.inference_config import InferenceConfig +from cellulus.datasets.meta_data import DatasetMetaData +from cellulus.utils.greedy_cluster import Cluster2d, Cluster3d +from cellulus.utils.mean_shift import mean_shift_segmentation + + +def detect(inference_config: InferenceConfig) -> None: + dataset_config = inference_config.dataset_config + dataset_meta_data = DatasetMetaData.from_dataset_config(dataset_config) + + f = zarr.open(inference_config.detection_dataset_config.container_path) + ds = f[inference_config.detection_dataset_config.secondary_dataset_name] + + # prepare the zarr dataset to write to + f_detection = zarr.open(inference_config.detection_dataset_config.container_path) + ds_detection = f_detection.create_dataset( + inference_config.detection_dataset_config.dataset_name, + shape=( + dataset_meta_data.num_samples, + inference_config.num_bandwidths, + *dataset_meta_data.spatial_array, + ), + dtype=np.uint16, + ) + + ds_detection.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ + -dataset_meta_data.num_spatial_dims : + ] + ds_detection.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims + ds_detection.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims + + # prepare the binary segmentation zarr dataset to write to + ds_binary_segmentation = f_detection.create_dataset( + "binary-segmentation", + shape=( + dataset_meta_data.num_samples, + 1, + *dataset_meta_data.spatial_array, + ), + dtype=np.uint16, + ) + + ds_binary_segmentation.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ + -dataset_meta_data.num_spatial_dims : + ] + ds_binary_segmentation.attrs["resolution"] = ( + 1, + ) * dataset_meta_data.num_spatial_dims + ds_binary_segmentation.attrs["offset"] = (0,) * 
dataset_meta_data.num_spatial_dims + + # prepare the object centered embeddings zarr dataset to write to + ds_object_centered_embeddings = f_detection.create_dataset( + "centered-embeddings", + shape=( + dataset_meta_data.num_samples, + dataset_meta_data.num_spatial_dims + 1, + *dataset_meta_data.spatial_array, + ), + dtype=float, + ) + + ds_object_centered_embeddings.attrs["axis_names"] = ["s", "c"] + [ + "t", + "z", + "y", + "x", + ][-dataset_meta_data.num_spatial_dims :] + ds_object_centered_embeddings.attrs["resolution"] = ( + 1, + ) * dataset_meta_data.num_spatial_dims + ds_object_centered_embeddings.attrs["offset"] = ( + 0, + ) * dataset_meta_data.num_spatial_dims + + for sample in tqdm(range(dataset_meta_data.num_samples)): + embeddings = ds[sample] + embeddings_std = embeddings[-1, ...] + embeddings_mean = embeddings[ + np.newaxis, : dataset_meta_data.num_spatial_dims, ... + ].copy() + if inference_config.threshold is None: + threshold = threshold_otsu(embeddings_std) + else: + threshold = inference_config.threshold + + print(f"For sample {sample}, binary threshold {threshold} was used.") + binary_mask = embeddings_std < threshold + ds_binary_segmentation[sample, 0, ...] = binary_mask + + # find mean of embeddings + embeddings_centered = embeddings.copy() + embeddings_mean_masked = ( + binary_mask[np.newaxis, np.newaxis, ...] 
* embeddings_mean + ) + if embeddings_centered.shape[0] == 3: + c_x = embeddings_mean_masked[0, 0] + c_y = embeddings_mean_masked[0, 1] + c_x = c_x[c_x != 0].mean() + c_y = c_y[c_y != 0].mean() + embeddings_centered[0] -= c_x + embeddings_centered[1] -= c_y + elif embeddings_centered.shape[0] == 4: + c_x = embeddings_mean_masked[0, 0] + c_y = embeddings_mean_masked[0, 1] + c_z = embeddings_mean_masked[0, 2] + c_x = c_x[c_x != 0].mean() + c_y = c_y[c_y != 0].mean() + c_z = c_z[c_z != 0].mean() + embeddings_centered[0] -= c_x + embeddings_centered[1] -= c_y + embeddings_centered[2] -= c_z + ds_object_centered_embeddings[sample] = embeddings_centered + + embeddings_centered_mean = embeddings_centered[ + np.newaxis, : dataset_meta_data.num_spatial_dims + ] + embeddings_centered_std = embeddings_centered[-1] + + if inference_config.clustering == "meanshift": + for bandwidth_factor in range(inference_config.num_bandwidths): + if inference_config.use_seeds: + offset_magnitude = np.linalg.norm(embeddings_centered[:-1], axis=0) + offset_magnitude_smooth = gaussian_filter(offset_magnitude, sigma=2) + coordinates = peak_local_max(-offset_magnitude_smooth) + seeds = np.flip(coordinates, 1) + segmentation = mean_shift_segmentation( + embeddings_centered_mean, + embeddings_centered_std, + bandwidth=inference_config.bandwidth / (2**bandwidth_factor), + min_size=inference_config.min_size, + reduction_probability=inference_config.reduction_probability, + threshold=threshold, + seeds=seeds, + ) + embeddings_centered_mean = embeddings_centered[ + np.newaxis, : dataset_meta_data.num_spatial_dims, ... 
+ ].copy() + else: + segmentation = mean_shift_segmentation( + embeddings_mean, + embeddings_std, + bandwidth=inference_config.bandwidth / (2**bandwidth_factor), + min_size=inference_config.min_size, + reduction_probability=inference_config.reduction_probability, + threshold=threshold, + seeds=None, + ) + # Note that the line below is needed + # because the embeddings_mean is modified + # by mean_shift_segmentation + embeddings_mean = embeddings[ + np.newaxis, : dataset_meta_data.num_spatial_dims, ... + ].copy() + ds_detection[sample, bandwidth_factor, ...] = segmentation + elif inference_config.clustering == "greedy": + if dataset_meta_data.num_spatial_dims == 3: + cluster3d = Cluster3d( + width=embeddings.shape[-1], + height=embeddings.shape[-2], + depth=embeddings.shape[-3], + fg_mask=binary_mask, + device=inference_config.device, + ) + for bandwidth_factor in range(inference_config.num_bandwidths): + segmentation = cluster3d.cluster( + prediction=embeddings, + bandwidth=inference_config.bandwidth / (2**bandwidth_factor), + min_object_size=inference_config.min_size, + ) + ds_detection[sample, bandwidth_factor, ...] = segmentation + elif dataset_meta_data.num_spatial_dims == 2: + cluster2d = Cluster2d( + width=embeddings.shape[-1], + height=embeddings.shape[-2], + fg_mask=binary_mask, + device=inference_config.device, + ) + for bandwidth_factor in range(inference_config.num_bandwidths): + segmentation = cluster2d.cluster( + prediction=embeddings, + bandwidth=inference_config.bandwidth / (2**bandwidth_factor), + min_object_size=inference_config.min_size, + ) + + ds_detection[sample, bandwidth_factor, ...] 
= segmentation diff --git a/cellulus/infer.py b/cellulus/infer.py index 0ddff24..e15df55 100644 --- a/cellulus/infer.py +++ b/cellulus/infer.py @@ -4,9 +4,9 @@ import torch from cellulus.datasets.meta_data import DatasetMetaData +from cellulus.detect import detect from cellulus.evaluate import evaluate from cellulus.models import get_model -from cellulus.post_process import post_process from cellulus.predict import predict from cellulus.segment import segment @@ -35,7 +35,7 @@ def infer(experiment_config): ) elif dataset_meta_data.num_spatial_dims == 3: inference_config.min_size = int( - 0.1 * 4.0 / 3.0 * np.pi * (experiment_config.object_size**3) + 0.1 * 4.0 / 3.0 * np.pi * (experiment_config.object_size**3) / 8 ) # set model model = get_model( @@ -69,12 +69,12 @@ def infer(experiment_config): # get predicted embeddings... if inference_config.prediction_dataset_config is not None: predict(model, inference_config, normalization_factor) - # ...turn them into a segmentation... + # ...turn them into a detection ... 
+ if inference_config.detection_dataset_config is not None: + detect(inference_config) + # ...and post-process the detection to obtain an instance segmentation if inference_config.segmentation_dataset_config is not None: segment(inference_config) - # ...and post-process the segmentation - if inference_config.post_processed_dataset_config is not None: - post_process(inference_config) # ...and evaluate if ground-truth exists if inference_config.evaluation_dataset_config is not None: evaluate(inference_config) diff --git a/cellulus/post_process.py b/cellulus/post_process.py deleted file mode 100644 index 4c31994..0000000 --- a/cellulus/post_process.py +++ /dev/null @@ -1,103 +0,0 @@ -import numpy as np -import zarr -from scipy.ndimage import binary_fill_holes, label -from scipy.ndimage import distance_transform_edt as dtedt -from skimage.filters import threshold_otsu -from tqdm import tqdm - -from cellulus.configs.inference_config import InferenceConfig -from cellulus.datasets.meta_data import DatasetMetaData -from cellulus.utils.misc import size_filter - - -def post_process(inference_config: InferenceConfig) -> None: - # filter small objects, erosion, etc. 
- - dataset_config = inference_config.dataset_config - dataset_meta_data = DatasetMetaData.from_dataset_config(dataset_config) - - f = zarr.open(inference_config.post_processed_dataset_config.container_path) - ds = f[inference_config.post_processed_dataset_config.secondary_dataset_name] - - # prepare the zarr dataset to write to - f_postprocessed = zarr.open( - inference_config.post_processed_dataset_config.container_path - ) - ds_postprocessed = f_postprocessed.create_dataset( - inference_config.post_processed_dataset_config.dataset_name, - shape=( - dataset_meta_data.num_samples, - inference_config.num_bandwidths, - *dataset_meta_data.spatial_array, - ), - dtype=np.uint16, - ) - - ds_postprocessed.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ - -dataset_meta_data.num_spatial_dims : - ] - ds_postprocessed.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims - ds_postprocessed.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims - - # remove halo - if inference_config.post_processing == "cell": - for sample in tqdm(range(dataset_meta_data.num_samples)): - # first instance label masks are expanded by `grow_distance` - # next, expanded instance label masks are shrunk by `shrink_distance` - for bandwidth_factor in range(inference_config.num_bandwidths): - segmentation = ds[sample, bandwidth_factor] - distance_foreground = dtedt(segmentation == 0) - expanded_mask = distance_foreground < inference_config.grow_distance - distance_background = dtedt(expanded_mask) - segmentation[distance_background < inference_config.shrink_distance] = 0 - ds_postprocessed[sample, bandwidth_factor, ...] 
= segmentation - elif inference_config.post_processing == "nucleus": - ds_raw = f[inference_config.dataset_config.dataset_name] - for sample in tqdm(range(dataset_meta_data.num_samples)): - for bandwidth_factor in range(inference_config.num_bandwidths): - segmentation = ds[sample, bandwidth_factor] - raw_image = ds_raw[sample, 0] - ids = np.unique(segmentation) - ids = ids[ids != 0] - for id_ in ids: - raw_image_masked = raw_image[segmentation == id_] - threshold = threshold_otsu(raw_image_masked) - mask = (segmentation == id_) & (raw_image > threshold) - mask = binary_fill_holes(mask) - if dataset_meta_data.num_spatial_dims == 2: - y, x = np.where(mask) - ds_postprocessed[sample, bandwidth_factor, y, x] = id_ - elif dataset_meta_data.num_spatial_dims == 3: - z, y, x = np.where(mask) - ds_postprocessed[sample, bandwidth_factor, z, y, x] = id_ - - # remove non-connected components - for bandwidth_factor in range(inference_config.num_bandwidths): - ids = np.unique(ds_postprocessed[sample, bandwidth_factor]) - ids = ids[ids != 0] - counter = np.max(ids) + 1 - for id_ in ids: - ma_id = ds_postprocessed[sample, bandwidth_factor] == id_ - array, num_features = label(ma_id) - if num_features > 1: - ids_array = np.unique(array) - ids_array = ids_array[ids_array != 0] - for id_array in ids_array: - if dataset_meta_data.num_spatial_dims == 2: - y, x = np.where(array == id_array) - ds_postprocessed[ - sample, bandwidth_factor, y, x - ] = counter - elif dataset_meta_data.num_spatial_dims == 3: - z, y, x = np.where(array == id_array) - ds_postprocessed[ - sample, bandwidth_factor, z, y, x - ] = counter - counter += 1 - - # size filter - remove small objects - for sample in tqdm(range(dataset_meta_data.num_samples)): - for bandwidth_factor in range(inference_config.num_bandwidths): - ds_postprocessed[sample, bandwidth_factor, ...] 
= size_filter( - ds_postprocessed[sample, bandwidth_factor], inference_config.min_size - ) diff --git a/cellulus/segment.py b/cellulus/segment.py index f1ba70c..a4fcaf0 100644 --- a/cellulus/segment.py +++ b/cellulus/segment.py @@ -1,26 +1,27 @@ import numpy as np import zarr +from scipy.ndimage import binary_fill_holes +from scipy.ndimage import distance_transform_edt as dtedt from skimage.filters import threshold_otsu from tqdm import tqdm from cellulus.configs.inference_config import InferenceConfig from cellulus.datasets.meta_data import DatasetMetaData -from cellulus.utils.greedy_cluster import Cluster2d, Cluster3d -from cellulus.utils.mean_shift import mean_shift_segmentation +from cellulus.utils.misc import size_filter def segment(inference_config: InferenceConfig) -> None: + # filter small objects, erosion, etc. + dataset_config = inference_config.dataset_config dataset_meta_data = DatasetMetaData.from_dataset_config(dataset_config) f = zarr.open(inference_config.segmentation_dataset_config.container_path) ds = f[inference_config.segmentation_dataset_config.secondary_dataset_name] - # prepare the instance segmentation zarr dataset to write to - f_segmentation = zarr.open( - inference_config.segmentation_dataset_config.container_path - ) - ds_segmentation = f_segmentation.create_dataset( + # prepare the zarr dataset to write to + f_segmented = zarr.open(inference_config.segmentation_dataset_config.container_path) + ds_segmented = f_segmented.create_dataset( inference_config.segmentation_dataset_config.dataset_name, shape=( dataset_meta_data.num_samples, @@ -30,140 +31,78 @@ def segment(inference_config: InferenceConfig) -> None: dtype=np.uint16, ) - ds_segmentation.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ - -dataset_meta_data.num_spatial_dims : - ] - ds_segmentation.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims - ds_segmentation.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims - - # prepare the binary segmentation 
zarr dataset to write to - ds_binary_segmentation = f_segmentation.create_dataset( - "binary_" + inference_config.segmentation_dataset_config.dataset_name, - shape=( - dataset_meta_data.num_samples, - 1, - *dataset_meta_data.spatial_array, - ), - dtype=np.uint16, - ) - - ds_binary_segmentation.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ + ds_segmented.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][ -dataset_meta_data.num_spatial_dims : ] - ds_binary_segmentation.attrs["resolution"] = ( - 1, - ) * dataset_meta_data.num_spatial_dims - ds_binary_segmentation.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims + ds_segmented.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims + ds_segmented.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims - # prepare the object centered embeddings zarr dataset to write to - ds_object_centered_embeddings = f_segmentation.create_dataset( - "centered_" - + inference_config.segmentation_dataset_config.secondary_dataset_name, - shape=( - dataset_meta_data.num_samples, - dataset_meta_data.num_spatial_dims + 1, - *dataset_meta_data.spatial_array, - ), - dtype=float, - ) + # remove halo + if inference_config.post_processing == "cell": + for sample in tqdm(range(dataset_meta_data.num_samples)): + # first instance label masks are expanded by `grow_distance` + # next, expanded instance label masks are shrunk by `shrink_distance` + for bandwidth_factor in range(inference_config.num_bandwidths): + segmentation = ds[sample, bandwidth_factor] + distance_foreground = dtedt(segmentation == 0) + expanded_mask = distance_foreground < inference_config.grow_distance + distance_background = dtedt(expanded_mask) + segmentation[distance_background < inference_config.shrink_distance] = 0 + ds_segmented[sample, bandwidth_factor, ...] 
= segmentation + elif inference_config.post_processing == "nucleus": + ds_raw = f[inference_config.dataset_config.dataset_name] + for sample in tqdm(range(dataset_meta_data.num_samples)): + for bandwidth_factor in range(inference_config.num_bandwidths): + segmentation = ds[sample, bandwidth_factor] + raw_image = ds_raw[sample, 0] + ids = np.unique(segmentation) + ids = ids[ids != 0] + for id_ in ids: + segmentation_id_mask = segmentation == id_ + if dataset_meta_data.num_spatial_dims == 2: + y, x = np.where(segmentation_id_mask) + y_min, y_max, x_min, x_max = ( + np.min(y), + np.max(y), + np.min(x), + np.max(x), + ) + elif dataset_meta_data.num_spatial_dims == 3: + z, y, x = np.where(segmentation_id_mask) + z_min, z_max, y_min, y_max, x_min, x_max = ( + np.min(z), + np.max(z), + np.min(y), + np.max(y), + np.min(x), + np.max(x), + ) + raw_image_masked = raw_image[segmentation_id_mask] + threshold = threshold_otsu(raw_image_masked) + mask = segmentation_id_mask & (raw_image > threshold) - ds_object_centered_embeddings.attrs["axis_names"] = ["s", "c"] + [ - "t", - "z", - "y", - "x", - ][-dataset_meta_data.num_spatial_dims :] - ds_object_centered_embeddings.attrs["resolution"] = ( - 1, - ) * dataset_meta_data.num_spatial_dims - ds_object_centered_embeddings.attrs["offset"] = ( - 0, - ) * dataset_meta_data.num_spatial_dims + if dataset_meta_data.num_spatial_dims == 2: + mask_small = binary_fill_holes( + mask[y_min : y_max + 1, x_min : x_max + 1] + ) + mask[y_min : y_max + 1, x_min : x_max + 1] = mask_small + y, x = np.where(mask) + ds_segmented[sample, bandwidth_factor, y, x] = id_ + elif dataset_meta_data.num_spatial_dims == 3: + mask_small = binary_fill_holes( + mask[ + z_min : z_max + 1, y_min : y_max + 1, x_min : x_max + 1 + ] + ) + mask[ + z_min : z_max + 1, y_min : y_max + 1, x_min : x_max + 1 + ] = mask_small + z, y, x = np.where(mask) + ds_segmented[sample, bandwidth_factor, z, y, x] = id_ + # size filter - remove small objects for sample in 
tqdm(range(dataset_meta_data.num_samples)): - embeddings = ds[sample] - embeddings_std = embeddings[-1, ...] - embeddings_mean = embeddings[ - np.newaxis, : dataset_meta_data.num_spatial_dims, ... - ].copy() - if inference_config.threshold is None: - threshold = threshold_otsu(embeddings_std) - else: - threshold = inference_config.threshold - - print(f"For sample {sample}, binary threshold {threshold} was used.") - binary_mask = embeddings_std < threshold - ds_binary_segmentation[sample, 0, ...] = binary_mask - - # find mean of embeddings - embeddings_centered = embeddings.copy() - embeddings_mean_masked = ( - binary_mask[np.newaxis, np.newaxis, ...] * embeddings_mean - ) - if embeddings_centered.shape[0] == 3: - c_x = embeddings_mean_masked[0, 0] - c_y = embeddings_mean_masked[0, 1] - c_x = c_x[c_x != 0].mean() - c_y = c_y[c_y != 0].mean() - embeddings_centered[0] -= c_x - embeddings_centered[1] -= c_y - elif embeddings_centered.shape[0] == 4: - c_x = embeddings_mean_masked[0, 0] - c_y = embeddings_mean_masked[0, 1] - c_z = embeddings_mean_masked[0, 2] - c_x = c_x[c_x != 0].mean() - c_y = c_y[c_y != 0].mean() - c_z = c_z[c_z != 0].mean() - embeddings_centered[0] -= c_x - embeddings_centered[1] -= c_y - embeddings_centered[2] -= c_z - ds_object_centered_embeddings[sample] = embeddings_centered - - if inference_config.clustering == "meanshift": - for bandwidth_factor in range(inference_config.num_bandwidths): - segmentation = mean_shift_segmentation( - embeddings_mean, - embeddings_std, - bandwidth=inference_config.bandwidth / (2**bandwidth_factor), - min_size=inference_config.min_size, - reduction_probability=inference_config.reduction_probability, - threshold=threshold, - ) - # Note that the line below is needed - # because the embeddings_mean is modified - # by mean_shift_segmentation - embeddings_mean = embeddings[ - np.newaxis, : dataset_meta_data.num_spatial_dims, ... - ].copy() - ds_segmentation[sample, bandwidth_factor, ...] 
= segmentation - elif inference_config.clustering == "greedy": - if dataset_meta_data.num_spatial_dims == 3: - cluster3d = Cluster3d( - width=embeddings.shape[-1], - height=embeddings.shape[-2], - depth=embeddings.shape[-3], - fg_mask=binary_mask, - device=inference_config.device, - ) - for bandwidth_factor in range(inference_config.num_bandwidths): - segmentation = cluster3d.cluster( - prediction=embeddings, - bandwidth=inference_config.bandwidth / (2**bandwidth_factor), - min_object_size=inference_config.min_size, - ) - ds_segmentation[sample, bandwidth_factor, ...] = segmentation - elif dataset_meta_data.num_spatial_dims == 2: - cluster2d = Cluster2d( - width=embeddings.shape[-1], - height=embeddings.shape[-2], - fg_mask=binary_mask, - device=inference_config.device, - ) - for bandwidth_factor in range(inference_config.num_bandwidths): - segmentation = cluster2d.cluster( - prediction=embeddings, - bandwidth=inference_config.bandwidth / (2**bandwidth_factor), - min_object_size=inference_config.min_size, - ) - - ds_segmentation[sample, bandwidth_factor, ...] = segmentation + for bandwidth_factor in range(inference_config.num_bandwidths): + ds_segmented[sample, bandwidth_factor, ...] 
= size_filter( + ds_segmented[sample, bandwidth_factor], inference_config.min_size + ) diff --git a/cellulus/utils/mean_shift.py b/cellulus/utils/mean_shift.py index bdf6012..a51e488 100644 --- a/cellulus/utils/mean_shift.py +++ b/cellulus/utils/mean_shift.py @@ -4,7 +4,13 @@ def mean_shift_segmentation( - embedding_mean, embedding_std, bandwidth, min_size, reduction_probability, threshold + embedding_mean, + embedding_std, + bandwidth, + min_size, + reduction_probability, + threshold, + seeds, ): embedding_mean = torch.from_numpy(embedding_mean) if embedding_mean.ndim == 4: @@ -34,22 +40,28 @@ def mean_shift_segmentation( mask=mask, reduction_probability=reduction_probability, cluster_all=False, + seeds=seeds, )[0] return segmentation def segment_with_meanshift( - embedding, bandwidth, mask, reduction_probability, cluster_all + embedding, bandwidth, mask, reduction_probability, cluster_all, seeds ): anchor_mean_shift = AnchorMeanshift( - bandwidth, reduction_probability=reduction_probability, cluster_all=cluster_all + bandwidth, + reduction_probability=reduction_probability, + cluster_all=cluster_all, + seeds=seeds, ) return anchor_mean_shift(embedding, mask=mask) + 1 class AnchorMeanshift: - def __init__(self, bandwidth, reduction_probability, cluster_all): - self.mean_shift = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all) + def __init__(self, bandwidth, reduction_probability, cluster_all, seeds): + self.mean_shift = MeanShift( + bandwidth=bandwidth, cluster_all=cluster_all, seeds=seeds + ) self.reduction_probability = reduction_probability def compute_mean_shift(self, X): diff --git a/cellulus/utils/misc.py b/cellulus/utils/misc.py index 2f77c72..f3edfbc 100644 --- a/cellulus/utils/misc.py +++ b/cellulus/utils/misc.py @@ -13,15 +13,16 @@ def size_filter(segmentation, min_size, filter_non_connected=True): return segmentation if filter_non_connected: - filter_labels = measure.label(segmentation, background=0) + filter_labels = measure.label(segmentation) 
else: filter_labels = segmentation + ids, sizes = np.unique(filter_labels, return_counts=True) filter_ids = ids[sizes < min_size] mask = np.in1d(filter_labels, filter_ids).reshape(filter_labels.shape) segmentation[mask] = 0 - return segmentation + return measure.label(segmentation) def extract_data(zip_url, data_dir, project_name): diff --git a/docs/examples/2d/03-infer.py b/docs/examples/2d/03-infer.py index 1e9d20e..1128840 100644 --- a/docs/examples/2d/03-infer.py +++ b/docs/examples/2d/03-infer.py @@ -38,15 +38,15 @@ prediction_dataset_config = DatasetConfig( container_path=name + ".zarr", dataset_name="embeddings" ) -segmentation_dataset_config = DatasetConfig( +detection_dataset_config = DatasetConfig( container_path=name + ".zarr", - dataset_name="segmentation", + dataset_name="detection", secondary_dataset_name="embeddings", ) -post_processed_dataset_config = DatasetConfig( +segmentation_dataset_config = DatasetConfig( container_path=name + ".zarr", - dataset_name="post_processed_segmentation", - secondary_dataset_name="segmentation", + dataset_name="segmentation", + secondary_dataset_name="detection", ) # ## Specify config values for the model @@ -86,7 +86,7 @@ # We initialize the `inference_config` which contains our # `embeddings_dataset_config`, `segmentation_dataset_config` and -# `post_processed_dataset_config`. +# `post_processed_dataset_config`.
# We set post_processing to one of `cell` or `nucleus`, depending on if we # would like the cell membrane to be segmented or the nucleus. @@ -95,8 +95,8 @@ inference_config = InferenceConfig( dataset_config=asdict(dataset_config), prediction_dataset_config=asdict(prediction_dataset_config), + detection_dataset_config=asdict(detection_dataset_config), segmentation_dataset_config=asdict(segmentation_dataset_config), - post_processed_dataset_config=asdict(post_processed_dataset_config), post_processing=post_processing, device=device, ) @@ -109,7 +109,7 @@ experiment_config = ExperimentConfig( inference_config=asdict(inference_config), model_config=asdict(model_config), - normalization_factor=1.0, + normalization_factor=1.0, # since the data was already normalized. ) # Now we are ready to start the inference!!
@@ -139,7 +139,7 @@ f = zarr.open(name + ".zarr") ds = f["train/raw"] -ds2 = f["centered_embeddings"] +ds2 = f["centered-embeddings"] image = ds[index, 0] embedding = ds2[index] @@ -163,8 +163,8 @@ # + f = zarr.open(name + ".zarr") ds = f["train/raw"] -ds2 = f["segmentation"] -ds3 = f["post_processed_segmentation"] +ds2 = f["detection"] +ds3 = f["segmentation"] visualize_2d( image, @@ -172,9 +172,10 @@ bottom_left=ds2[index, 0], bottom_right=ds3[index, 0], top_right_label="THRESHOLDED F.G.", - bottom_left_label="SEGMENTATION", - bottom_right_label="POSTPROCESSED", + bottom_left_label="DETECTION", + bottom_right_label="SEGMENTATION", top_right_cmap="gray", bottom_left_cmap=new_cmp, bottom_right_cmap=new_cmp, ) +# - diff --git a/docs/examples/3d/01-data.py b/docs/examples/3d/01-data.py new file mode 100644 index 0000000..2a3825b --- /dev/null +++ b/docs/examples/3d/01-data.py @@ -0,0 +1,62 @@ +# # Download Data + +# In this notebook, we will download data and convert it to a zarr dataset.
+ +# For demonstration, we will use one image from the `Platynereis-Nuclei-CBG` dataset +# provided +# with [this](https://www.sciencedirect.com/science/article/pii/S1361841522001700) +# publication. + +# Firstly, the `tif` raw images are downloaded to a directory indicated by `data_dir`. + +from pathlib import Path + +# + +import numpy as np +import tifffile +import zarr +from cellulus.utils.misc import extract_data +from csbdeep.utils import normalize +from skimage.transform import rescale +from tqdm import tqdm + +# + +name = "3d-data-demo" +data_dir = "./data" + +extract_data( + zip_url="https://github.com/funkelab/cellulus/releases/download/v0.0.1-tag/3d-data-demo.zip", + data_dir=data_dir, + project_name=name, +) +# - + +# Currently, `cellulus` expects that the images are isotropic (i.e. the +# voxel size along z dimension (which is usually undersampled) is the same +# as the voxel size along the x and y dimensions).
+# This dataset has a step size of $2.031 \mu m$ in z and $0.406 \mu m$ along +# x and y dimensions, thus, the upsampling factor (which we refer to as +# `anisotropy`) equals $2.031/0.406$.
+# These raw images are upsampled, intensity-normalized and appended in a list. +# Here, we use the percentile normalization technique. + +anisotropy = 2.031 / 0.406 + +container_path = zarr.open(name + ".zarr") +subsets = ["train", "test"] +for subset in subsets: + dataset_name = subset + "/raw" + image_filenames = sorted((Path(data_dir) / name / subset).glob("*.tif")) + print(f"Number of raw images in {subset} directory is {len(image_filenames)}") + image_list = [] + + for i in tqdm(range(len(image_filenames))): + im = tifffile.imread(image_filenames[i]).astype(np.float32) + im_normalized = normalize(im, 1, 99.8, axis=(0, 1, 2)) + im_rescaled = rescale(im_normalized, (anisotropy, 1.0, 1.0)) + image_list.append(im_rescaled[np.newaxis, ...]) + + image_list = np.asarray(image_list) + container_path[dataset_name] = image_list + container_path[dataset_name].attrs["resolution"] = (1, 1, 1) + container_path[dataset_name].attrs["axis_names"] = ("s", "c", "z", "y", "x") diff --git a/docs/examples/3d/02-train.py b/docs/examples/3d/02-train.py new file mode 100644 index 0000000..017525c --- /dev/null +++ b/docs/examples/3d/02-train.py @@ -0,0 +1,80 @@ +# # Train Model + +# In this notebook, we will train a `cellulus` model. + +from attrs import asdict +from cellulus.configs.dataset_config import DatasetConfig +from cellulus.configs.experiment_config import ExperimentConfig +from cellulus.configs.model_config import ModelConfig +from cellulus.configs.train_config import TrainConfig + +# ## Specify config values for dataset + +# In the next cell, we specify the name of the zarr container and the dataset +# within it from which data would be read. + +name = "3d-data-demo" +dataset_name = "train/raw" + +train_data_config = DatasetConfig( + container_path=name + ".zarr", dataset_name=dataset_name +) + +# ## Specify config values for model + +# In the next cell, we specify the number of feature maps (`num_fmaps`) in the +# first layer in our model.
+# Additionally, we specify `fmap_inc_factor` and `downsampling_factors`, which +# indicates by how much the number of feature maps increase between adjacent +# layers, and how much the spatial extents of the image gets downsampled between +# adjacent layers respectively. + +num_fmaps = 24 +fmap_inc_factor = 3 +downsampling_factors = [ + [2, 2, 2], +] + +model_config = ModelConfig( + num_fmaps=num_fmaps, + fmap_inc_factor=fmap_inc_factor, + downsampling_factors=downsampling_factors, +) + +# ## Specify config values for the training process + +# Then, we specify training-specific parameters such as the `device`, which +# indicates the actual device to run the training on. +# We also specify the `crop_size`. Mini - batches of crops are shown to the model +# during training. +#
The device could be set equal to `cuda:n` (where `n` is the index of +# the GPU, e.g. `cuda:0`) or `cpu`.
+# We set the `max_iterations` equal to `5000` for demonstration purposes. + +device = "cuda:0" +max_iterations = 5000 +crop_size = [80, 80, 80] + +train_config = TrainConfig( + train_data_config=asdict(train_data_config), + device=device, + max_iterations=max_iterations, + crop_size=crop_size, +) + +# Next, we initialize the experiment config which puts together the config +# objects (`train_config` and `model_config`) which we defined above. + +experiment_config = ExperimentConfig( + train_config=asdict(train_config), + model_config=asdict(model_config), + normalization_factor=1.0, # since we already normalized in previous notebook +) + +# Now we can begin the training!
+# Uncomment the next two lines to train the model. + +# + +# from cellulus.train import train +# train(experiment_config) +# - diff --git a/docs/examples/3d/03-infer.py b/docs/examples/3d/03-infer.py new file mode 100644 index 0000000..236934c --- /dev/null +++ b/docs/examples/3d/03-infer.py @@ -0,0 +1,186 @@ +# # Infer using Trained Model + +# In this notebook, we will use the `cellulus` model trained in the previous +# step to obtain instance segmentations. + +import urllib +import zipfile + +import numpy as np +import skimage +import torch +import zarr +from attrs import asdict +from cellulus.configs.dataset_config import DatasetConfig +from cellulus.configs.experiment_config import ExperimentConfig +from cellulus.configs.inference_config import InferenceConfig +from cellulus.configs.model_config import ModelConfig +from cellulus.infer import infer +from cellulus.utils.misc import visualize_2d +from matplotlib.colors import ListedColormap + +# ## Specify config values for datasets + +# We again specify `name` of the zarr container, and `dataset_name` which +# identifies the path to the raw image data, which needs to be segmented. + +name = "3d-data-demo" +dataset_name = "test/raw" + +# We initialize the `dataset_config` which relates to the raw image data, +# `prediction_dataset_config` which relates to the per-pixel embeddings and the +# uncertainty, the `detection_dataset_config` which relates to the detections +# post the mean-shift clustering and the `segmentation_dataset_config` which relates +# to the segmentations after some post-processing.
+ +dataset_config = DatasetConfig(container_path=name + ".zarr", dataset_name=dataset_name) +prediction_dataset_config = DatasetConfig( + container_path=name + ".zarr", dataset_name="embeddings" +) +detection_dataset_config = DatasetConfig( + container_path=name + ".zarr", + dataset_name="detection", + secondary_dataset_name="embeddings", +) +segmentation_dataset_config = DatasetConfig( + container_path=name + ".zarr", + dataset_name="segmentation", + secondary_dataset_name="detection", +) + +# ## Specify config values for the model + +# We must also specify the `num_fmaps`, `fmap_inc_factor` (use same values as +# in the training step) and set `checkpoint` equal to `models/best_loss.pth` +# (best in terms of the lowest loss obtained). + +# Here, we download a pretrained model trained by us for `5e3` iterations.
+# But please comment the next cell to use your own trained model, which +# should be available in the `models` directory. + +torch.hub.download_url_to_file( + url="https://github.com/funkelab/cellulus/releases/download/v0.0.1-tag/2d-demo-model.zip", + dst="pretrained_model", + progress=True, +) +with zipfile.ZipFile("pretrained_model", "r") as zip_ref: + zip_ref.extractall("") + +num_fmaps = 24 +fmap_inc_factor = 3 +downsampling_factors = [ + [2, 2, 2], +] +checkpoint = "models/best_loss.pth" + +model_config = ModelConfig( + num_fmaps=num_fmaps, + fmap_inc_factor=fmap_inc_factor, + downsampling_factors=downsampling_factors, + checkpoint=checkpoint, +) + +# ## Initialize `inference_config` + +# Then, we specify inference-specific parameters such as the `device`, which +# indicates the actual device to run the inference on. +#
The device could be set equal to `cuda:n` (where `n` is the index of the +# GPU, e.g. `cuda:0`), `cpu` or `mps`. + +device = "cuda:0" + +# We initialize the `inference_config` which contains our `dataset_config` +# and `segmentation_dataset_config`. + +inference_config = InferenceConfig( + dataset_config=asdict(dataset_config), + # prediction_dataset_config=asdict(prediction_dataset_config), + # detection_dataset_config=asdict(detection_dataset_config), + segmentation_dataset_config=asdict(segmentation_dataset_config), + crop_size=[120, 120, 120], + post_processing="nucleus", + device=device, + use_seeds=True, +) + +# ## Initialize `experiment_config` + +# Lastly we initialize the `experiment_config` which contains the `inference_config` +# and `model_config` initialized above. + +experiment_config = ExperimentConfig( + inference_config=asdict(inference_config), + model_config=asdict(model_config), + normalization_factor=1.0, # since the test image is already normalized +) + +# Now we are ready to start the inference!!
+# (To see the output of the cell below, remove the first line `io.capture_output()`). + +# with io.capture_output() as captured: +infer(experiment_config) + +# ## Inspect predictions + +# Let's look at some of the predicted embeddings.
+# We will first load a glasbey-like color map to show individual cells +# with a unique color. + +urllib.request.urlretrieve( + "https://github.com/funkelab/cellulus/releases/download/v0.0.1-tag/cmap_60.npy", + "cmap_60.npy", +) +new_cmp = ListedColormap(np.load("cmap_60.npy")) + +# Change the value of `index` below to look at the raw image (left), +# x-offset (bottom-left), y-offset (bottom-right) and uncertainty of the +# embedding (top-right). + +# + +index = 0 + +f = zarr.open(name + ".zarr") +ds = f["test/raw"] +ds2 = f["centered_embeddings"] + +slice = ds.shape[2] // 2 + +image = ds[index, 0, slice] +embedding = ds2[index, :, slice] + + +visualize_2d( + image, + top_right=embedding[-1], + bottom_left=embedding[0], + bottom_right=embedding[1], + top_right_label="UNCERTAINTY", + bottom_left_label="OFFSET_X", + bottom_right_label="OFFSET_Y", +) +# - + +# As you can see the magnitude of the uncertainty of the embedding (top-right) +# is low for most of the foreground cells.
This enables extraction +# of the foreground, which is eventually clustered into individual instances.
+# See bottom right figure for the final result. + +# + +f = zarr.open(name + ".zarr") +ds = f["test/raw"] +ds2 = f["detection"] +ds3 = f["segmentation"] + +visualize_2d( + image, + top_right=embedding[-1] < skimage.filters.threshold_otsu(embedding[-1]), + bottom_left=ds2[index, index, slice], + bottom_right=ds3[index, index, slice], + top_right_label="THRESHOLDED F.G.", + bottom_left_label="DETECTION", + bottom_right_label="SEGMENTATION", + top_right_cmap="gray", + bottom_left_cmap=new_cmp, + bottom_right_cmap=new_cmp, +) +# - diff --git a/mkdocs.yml b/mkdocs.yml index 607af7f..5ed2b1e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -61,10 +61,15 @@ nav: - Home: index.md - Getting Started: - Installation: getting-started.md - - 2D Example: - - 01: examples/2d/01-data.py - - 02: examples/2d/02-train.py - - 03: examples/2d/03-infer.py + - Examples: + - 2D: + - 01: examples/2d/01-data.py + - 02: examples/2d/02-train.py + - 03: examples/2d/03-infer.py + - 3D: + - 01: examples/3d/01-data.py + - 02: examples/3d/02-train.py + - 03: examples/3d/03-infer.py - API Reference: - Configs: - DatasetConfig: api/dataset_config.md