diff --git a/cellulus/configs/inference_config.py b/cellulus/configs/inference_config.py
index 707333f..2719378 100644
--- a/cellulus/configs/inference_config.py
+++ b/cellulus/configs/inference_config.py
@@ -22,13 +22,13 @@ class InferenceConfig:
Configuration object produced by predict.py.
- segmentation_dataset_config:
+ detection_dataset_config:
- Configuration object produced by segment.py.
+ Configuration object produced by detect.py.
- post_processed_dataset_config:
+ segmentation_dataset_config:
- Configuration object produced by post_process.py.
+ Configuration object produced by segment.py.
evaluation_dataset_config:
@@ -72,6 +72,12 @@ class InferenceConfig:
How to cluster the embeddings?
Can be one of 'meanshift' or 'greedy'.
+ use_seeds (default = False):
+
+ If set to True, the local optima of the distance map from the
+ predicted object centers is used.
+ Else, seeds are determined by sklearn.cluster.MeanShift.
+
num_bandwidths (default = 1):
Number of bandwidths to obtain segmentations for.
@@ -118,11 +124,11 @@ class InferenceConfig:
default=None, converter=to_config(DatasetConfig)
)
- segmentation_dataset_config: DatasetConfig = attrs.field(
+ detection_dataset_config: DatasetConfig = attrs.field(
default=None, converter=to_config(DatasetConfig)
)
- post_processed_dataset_config: DatasetConfig = attrs.field(
+ segmentation_dataset_config: DatasetConfig = attrs.field(
default=None, converter=to_config(DatasetConfig)
)
@@ -139,6 +145,7 @@ class InferenceConfig:
clustering = attrs.field(
default="meanshift", validator=in_(["meanshift", "greedy"])
)
+ use_seeds = attrs.field(default=False, validator=instance_of(bool))
bandwidth = attrs.field(
default=None, validator=attrs.validators.optional(instance_of(float))
)
diff --git a/cellulus/datasets/zarr_dataset.py b/cellulus/datasets/zarr_dataset.py
index 1a6ad78..a864d89 100644
--- a/cellulus/datasets/zarr_dataset.py
+++ b/cellulus/datasets/zarr_dataset.py
@@ -137,17 +137,24 @@ def __yield_sample(self):
with gp.build(self.pipeline):
while True:
+ array_is_zero = True
# request one sample, all channels, plus crop dimensions
- request = gp.BatchRequest()
- request[self.raw] = gp.ArraySpec(
- roi=gp.Roi(
- (0,) * self.num_dims, (1, self.num_channels, *self.crop_size)
+ while array_is_zero:
+ request = gp.BatchRequest()
+ request[self.raw] = gp.ArraySpec(
+ roi=gp.Roi(
+ (0,) * self.num_dims,
+ (1, self.num_channels, *self.crop_size),
+ )
)
- )
- sample = self.pipeline.request_batch(request)
- sample_data = sample[self.raw].data[0]
- anchor_samples, reference_samples = self.sample_coordinates()
+ sample = self.pipeline.request_batch(request)
+ sample_data = sample[self.raw].data[0]
+ if np.max(sample_data) <= 0.0:
+ pass
+ else:
+ array_is_zero = False
+ anchor_samples, reference_samples = self.sample_coordinates()
yield sample_data, anchor_samples, reference_samples
def __read_meta_data(self):
diff --git a/cellulus/detect.py b/cellulus/detect.py
new file mode 100644
index 0000000..6bcac19
--- /dev/null
+++ b/cellulus/detect.py
@@ -0,0 +1,192 @@
+import numpy as np
+import zarr
+from scipy.ndimage import gaussian_filter
+from skimage.feature import peak_local_max
+from skimage.filters import threshold_otsu
+from tqdm import tqdm
+
+from cellulus.configs.inference_config import InferenceConfig
+from cellulus.datasets.meta_data import DatasetMetaData
+from cellulus.utils.greedy_cluster import Cluster2d, Cluster3d
+from cellulus.utils.mean_shift import mean_shift_segmentation
+
+
+def detect(inference_config: InferenceConfig) -> None:
+ dataset_config = inference_config.dataset_config
+ dataset_meta_data = DatasetMetaData.from_dataset_config(dataset_config)
+
+ f = zarr.open(inference_config.detection_dataset_config.container_path)
+ ds = f[inference_config.detection_dataset_config.secondary_dataset_name]
+
+ # prepare the zarr dataset to write to
+ f_detection = zarr.open(inference_config.detection_dataset_config.container_path)
+ ds_detection = f_detection.create_dataset(
+ inference_config.detection_dataset_config.dataset_name,
+ shape=(
+ dataset_meta_data.num_samples,
+ inference_config.num_bandwidths,
+ *dataset_meta_data.spatial_array,
+ ),
+ dtype=np.uint16,
+ )
+
+ ds_detection.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][
+ -dataset_meta_data.num_spatial_dims :
+ ]
+ ds_detection.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims
+ ds_detection.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims
+
+ # prepare the binary segmentation zarr dataset to write to
+ ds_binary_segmentation = f_detection.create_dataset(
+ "binary-segmentation",
+ shape=(
+ dataset_meta_data.num_samples,
+ 1,
+ *dataset_meta_data.spatial_array,
+ ),
+ dtype=np.uint16,
+ )
+
+ ds_binary_segmentation.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][
+ -dataset_meta_data.num_spatial_dims :
+ ]
+ ds_binary_segmentation.attrs["resolution"] = (
+ 1,
+ ) * dataset_meta_data.num_spatial_dims
+ ds_binary_segmentation.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims
+
+ # prepare the object centered embeddings zarr dataset to write to
+ ds_object_centered_embeddings = f_detection.create_dataset(
+ "centered-embeddings",
+ shape=(
+ dataset_meta_data.num_samples,
+ dataset_meta_data.num_spatial_dims + 1,
+ *dataset_meta_data.spatial_array,
+ ),
+ dtype=float,
+ )
+
+ ds_object_centered_embeddings.attrs["axis_names"] = ["s", "c"] + [
+ "t",
+ "z",
+ "y",
+ "x",
+ ][-dataset_meta_data.num_spatial_dims :]
+ ds_object_centered_embeddings.attrs["resolution"] = (
+ 1,
+ ) * dataset_meta_data.num_spatial_dims
+ ds_object_centered_embeddings.attrs["offset"] = (
+ 0,
+ ) * dataset_meta_data.num_spatial_dims
+
+ for sample in tqdm(range(dataset_meta_data.num_samples)):
+ embeddings = ds[sample]
+ embeddings_std = embeddings[-1, ...]
+ embeddings_mean = embeddings[
+ np.newaxis, : dataset_meta_data.num_spatial_dims, ...
+ ].copy()
+ if inference_config.threshold is None:
+ threshold = threshold_otsu(embeddings_std)
+ else:
+ threshold = inference_config.threshold
+
+ print(f"For sample {sample}, binary threshold {threshold} was used.")
+ binary_mask = embeddings_std < threshold
+ ds_binary_segmentation[sample, 0, ...] = binary_mask
+
+ # find mean of embeddings
+ embeddings_centered = embeddings.copy()
+ embeddings_mean_masked = (
+ binary_mask[np.newaxis, np.newaxis, ...] * embeddings_mean
+ )
+ if embeddings_centered.shape[0] == 3:
+ c_x = embeddings_mean_masked[0, 0]
+ c_y = embeddings_mean_masked[0, 1]
+ c_x = c_x[c_x != 0].mean()
+ c_y = c_y[c_y != 0].mean()
+ embeddings_centered[0] -= c_x
+ embeddings_centered[1] -= c_y
+ elif embeddings_centered.shape[0] == 4:
+ c_x = embeddings_mean_masked[0, 0]
+ c_y = embeddings_mean_masked[0, 1]
+ c_z = embeddings_mean_masked[0, 2]
+ c_x = c_x[c_x != 0].mean()
+ c_y = c_y[c_y != 0].mean()
+ c_z = c_z[c_z != 0].mean()
+ embeddings_centered[0] -= c_x
+ embeddings_centered[1] -= c_y
+ embeddings_centered[2] -= c_z
+ ds_object_centered_embeddings[sample] = embeddings_centered
+
+ embeddings_centered_mean = embeddings_centered[
+ np.newaxis, : dataset_meta_data.num_spatial_dims
+ ]
+ embeddings_centered_std = embeddings_centered[-1]
+
+ if inference_config.clustering == "meanshift":
+ for bandwidth_factor in range(inference_config.num_bandwidths):
+ if inference_config.use_seeds:
+ offset_magnitude = np.linalg.norm(embeddings_centered[:-1], axis=0)
+ offset_magnitude_smooth = gaussian_filter(offset_magnitude, sigma=2)
+ coordinates = peak_local_max(-offset_magnitude_smooth)
+ seeds = np.flip(coordinates, 1)
+ segmentation = mean_shift_segmentation(
+ embeddings_centered_mean,
+ embeddings_centered_std,
+ bandwidth=inference_config.bandwidth / (2**bandwidth_factor),
+ min_size=inference_config.min_size,
+ reduction_probability=inference_config.reduction_probability,
+ threshold=threshold,
+ seeds=seeds,
+ )
+ embeddings_centered_mean = embeddings_centered[
+ np.newaxis, : dataset_meta_data.num_spatial_dims, ...
+ ].copy()
+ else:
+ segmentation = mean_shift_segmentation(
+ embeddings_mean,
+ embeddings_std,
+ bandwidth=inference_config.bandwidth / (2**bandwidth_factor),
+ min_size=inference_config.min_size,
+ reduction_probability=inference_config.reduction_probability,
+ threshold=threshold,
+ seeds=None,
+ )
+ # Note that the line below is needed
+ # because the embeddings_mean is modified
+ # by mean_shift_segmentation
+ embeddings_mean = embeddings[
+ np.newaxis, : dataset_meta_data.num_spatial_dims, ...
+ ].copy()
+ ds_detection[sample, bandwidth_factor, ...] = segmentation
+ elif inference_config.clustering == "greedy":
+ if dataset_meta_data.num_spatial_dims == 3:
+ cluster3d = Cluster3d(
+ width=embeddings.shape[-1],
+ height=embeddings.shape[-2],
+ depth=embeddings.shape[-3],
+ fg_mask=binary_mask,
+ device=inference_config.device,
+ )
+ for bandwidth_factor in range(inference_config.num_bandwidths):
+ segmentation = cluster3d.cluster(
+ prediction=embeddings,
+ bandwidth=inference_config.bandwidth / (2**bandwidth_factor),
+ min_object_size=inference_config.min_size,
+ )
+ ds_detection[sample, bandwidth_factor, ...] = segmentation
+ elif dataset_meta_data.num_spatial_dims == 2:
+ cluster2d = Cluster2d(
+ width=embeddings.shape[-1],
+ height=embeddings.shape[-2],
+ fg_mask=binary_mask,
+ device=inference_config.device,
+ )
+ for bandwidth_factor in range(inference_config.num_bandwidths):
+ segmentation = cluster2d.cluster(
+ prediction=embeddings,
+ bandwidth=inference_config.bandwidth / (2**bandwidth_factor),
+ min_object_size=inference_config.min_size,
+ )
+
+ ds_detection[sample, bandwidth_factor, ...] = segmentation
diff --git a/cellulus/infer.py b/cellulus/infer.py
index 0ddff24..e15df55 100644
--- a/cellulus/infer.py
+++ b/cellulus/infer.py
@@ -4,9 +4,9 @@
import torch
from cellulus.datasets.meta_data import DatasetMetaData
+from cellulus.detect import detect
from cellulus.evaluate import evaluate
from cellulus.models import get_model
-from cellulus.post_process import post_process
from cellulus.predict import predict
from cellulus.segment import segment
@@ -35,7 +35,7 @@ def infer(experiment_config):
)
elif dataset_meta_data.num_spatial_dims == 3:
inference_config.min_size = int(
- 0.1 * 4.0 / 3.0 * np.pi * (experiment_config.object_size**3)
+ 0.1 * 4.0 / 3.0 * np.pi * (experiment_config.object_size**3) / 8
)
# set model
model = get_model(
@@ -69,12 +69,12 @@ def infer(experiment_config):
# get predicted embeddings...
if inference_config.prediction_dataset_config is not None:
predict(model, inference_config, normalization_factor)
- # ...turn them into a segmentation...
+ # ...turn them into a detection ...
+ if inference_config.detection_dataset_config is not None:
+ detect(inference_config)
+ # ...and post-process the detection to obtain an instance segmentation
if inference_config.segmentation_dataset_config is not None:
segment(inference_config)
- # ...and post-process the segmentation
- if inference_config.post_processed_dataset_config is not None:
- post_process(inference_config)
# ...and evaluate if ground-truth exists
if inference_config.evaluation_dataset_config is not None:
evaluate(inference_config)
diff --git a/cellulus/post_process.py b/cellulus/post_process.py
deleted file mode 100644
index 4c31994..0000000
--- a/cellulus/post_process.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import numpy as np
-import zarr
-from scipy.ndimage import binary_fill_holes, label
-from scipy.ndimage import distance_transform_edt as dtedt
-from skimage.filters import threshold_otsu
-from tqdm import tqdm
-
-from cellulus.configs.inference_config import InferenceConfig
-from cellulus.datasets.meta_data import DatasetMetaData
-from cellulus.utils.misc import size_filter
-
-
-def post_process(inference_config: InferenceConfig) -> None:
- # filter small objects, erosion, etc.
-
- dataset_config = inference_config.dataset_config
- dataset_meta_data = DatasetMetaData.from_dataset_config(dataset_config)
-
- f = zarr.open(inference_config.post_processed_dataset_config.container_path)
- ds = f[inference_config.post_processed_dataset_config.secondary_dataset_name]
-
- # prepare the zarr dataset to write to
- f_postprocessed = zarr.open(
- inference_config.post_processed_dataset_config.container_path
- )
- ds_postprocessed = f_postprocessed.create_dataset(
- inference_config.post_processed_dataset_config.dataset_name,
- shape=(
- dataset_meta_data.num_samples,
- inference_config.num_bandwidths,
- *dataset_meta_data.spatial_array,
- ),
- dtype=np.uint16,
- )
-
- ds_postprocessed.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][
- -dataset_meta_data.num_spatial_dims :
- ]
- ds_postprocessed.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims
- ds_postprocessed.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims
-
- # remove halo
- if inference_config.post_processing == "cell":
- for sample in tqdm(range(dataset_meta_data.num_samples)):
- # first instance label masks are expanded by `grow_distance`
- # next, expanded instance label masks are shrunk by `shrink_distance`
- for bandwidth_factor in range(inference_config.num_bandwidths):
- segmentation = ds[sample, bandwidth_factor]
- distance_foreground = dtedt(segmentation == 0)
- expanded_mask = distance_foreground < inference_config.grow_distance
- distance_background = dtedt(expanded_mask)
- segmentation[distance_background < inference_config.shrink_distance] = 0
- ds_postprocessed[sample, bandwidth_factor, ...] = segmentation
- elif inference_config.post_processing == "nucleus":
- ds_raw = f[inference_config.dataset_config.dataset_name]
- for sample in tqdm(range(dataset_meta_data.num_samples)):
- for bandwidth_factor in range(inference_config.num_bandwidths):
- segmentation = ds[sample, bandwidth_factor]
- raw_image = ds_raw[sample, 0]
- ids = np.unique(segmentation)
- ids = ids[ids != 0]
- for id_ in ids:
- raw_image_masked = raw_image[segmentation == id_]
- threshold = threshold_otsu(raw_image_masked)
- mask = (segmentation == id_) & (raw_image > threshold)
- mask = binary_fill_holes(mask)
- if dataset_meta_data.num_spatial_dims == 2:
- y, x = np.where(mask)
- ds_postprocessed[sample, bandwidth_factor, y, x] = id_
- elif dataset_meta_data.num_spatial_dims == 3:
- z, y, x = np.where(mask)
- ds_postprocessed[sample, bandwidth_factor, z, y, x] = id_
-
- # remove non-connected components
- for bandwidth_factor in range(inference_config.num_bandwidths):
- ids = np.unique(ds_postprocessed[sample, bandwidth_factor])
- ids = ids[ids != 0]
- counter = np.max(ids) + 1
- for id_ in ids:
- ma_id = ds_postprocessed[sample, bandwidth_factor] == id_
- array, num_features = label(ma_id)
- if num_features > 1:
- ids_array = np.unique(array)
- ids_array = ids_array[ids_array != 0]
- for id_array in ids_array:
- if dataset_meta_data.num_spatial_dims == 2:
- y, x = np.where(array == id_array)
- ds_postprocessed[
- sample, bandwidth_factor, y, x
- ] = counter
- elif dataset_meta_data.num_spatial_dims == 3:
- z, y, x = np.where(array == id_array)
- ds_postprocessed[
- sample, bandwidth_factor, z, y, x
- ] = counter
- counter += 1
-
- # size filter - remove small objects
- for sample in tqdm(range(dataset_meta_data.num_samples)):
- for bandwidth_factor in range(inference_config.num_bandwidths):
- ds_postprocessed[sample, bandwidth_factor, ...] = size_filter(
- ds_postprocessed[sample, bandwidth_factor], inference_config.min_size
- )
diff --git a/cellulus/segment.py b/cellulus/segment.py
index f1ba70c..a4fcaf0 100644
--- a/cellulus/segment.py
+++ b/cellulus/segment.py
@@ -1,26 +1,27 @@
import numpy as np
import zarr
+from scipy.ndimage import binary_fill_holes
+from scipy.ndimage import distance_transform_edt as dtedt
from skimage.filters import threshold_otsu
from tqdm import tqdm
from cellulus.configs.inference_config import InferenceConfig
from cellulus.datasets.meta_data import DatasetMetaData
-from cellulus.utils.greedy_cluster import Cluster2d, Cluster3d
-from cellulus.utils.mean_shift import mean_shift_segmentation
+from cellulus.utils.misc import size_filter
def segment(inference_config: InferenceConfig) -> None:
+ # filter small objects, erosion, etc.
+
dataset_config = inference_config.dataset_config
dataset_meta_data = DatasetMetaData.from_dataset_config(dataset_config)
f = zarr.open(inference_config.segmentation_dataset_config.container_path)
ds = f[inference_config.segmentation_dataset_config.secondary_dataset_name]
- # prepare the instance segmentation zarr dataset to write to
- f_segmentation = zarr.open(
- inference_config.segmentation_dataset_config.container_path
- )
- ds_segmentation = f_segmentation.create_dataset(
+ # prepare the zarr dataset to write to
+ f_segmented = zarr.open(inference_config.segmentation_dataset_config.container_path)
+ ds_segmented = f_segmented.create_dataset(
inference_config.segmentation_dataset_config.dataset_name,
shape=(
dataset_meta_data.num_samples,
@@ -30,140 +31,78 @@ def segment(inference_config: InferenceConfig) -> None:
dtype=np.uint16,
)
- ds_segmentation.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][
- -dataset_meta_data.num_spatial_dims :
- ]
- ds_segmentation.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims
- ds_segmentation.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims
-
- # prepare the binary segmentation zarr dataset to write to
- ds_binary_segmentation = f_segmentation.create_dataset(
- "binary_" + inference_config.segmentation_dataset_config.dataset_name,
- shape=(
- dataset_meta_data.num_samples,
- 1,
- *dataset_meta_data.spatial_array,
- ),
- dtype=np.uint16,
- )
-
- ds_binary_segmentation.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][
+ ds_segmented.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][
-dataset_meta_data.num_spatial_dims :
]
- ds_binary_segmentation.attrs["resolution"] = (
- 1,
- ) * dataset_meta_data.num_spatial_dims
- ds_binary_segmentation.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims
+ ds_segmented.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims
+ ds_segmented.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims
- # prepare the object centered embeddings zarr dataset to write to
- ds_object_centered_embeddings = f_segmentation.create_dataset(
- "centered_"
- + inference_config.segmentation_dataset_config.secondary_dataset_name,
- shape=(
- dataset_meta_data.num_samples,
- dataset_meta_data.num_spatial_dims + 1,
- *dataset_meta_data.spatial_array,
- ),
- dtype=float,
- )
+ # remove halo
+ if inference_config.post_processing == "cell":
+ for sample in tqdm(range(dataset_meta_data.num_samples)):
+ # first instance label masks are expanded by `grow_distance`
+ # next, expanded instance label masks are shrunk by `shrink_distance`
+ for bandwidth_factor in range(inference_config.num_bandwidths):
+ segmentation = ds[sample, bandwidth_factor]
+ distance_foreground = dtedt(segmentation == 0)
+ expanded_mask = distance_foreground < inference_config.grow_distance
+ distance_background = dtedt(expanded_mask)
+ segmentation[distance_background < inference_config.shrink_distance] = 0
+ ds_segmented[sample, bandwidth_factor, ...] = segmentation
+ elif inference_config.post_processing == "nucleus":
+ ds_raw = f[inference_config.dataset_config.dataset_name]
+ for sample in tqdm(range(dataset_meta_data.num_samples)):
+ for bandwidth_factor in range(inference_config.num_bandwidths):
+ segmentation = ds[sample, bandwidth_factor]
+ raw_image = ds_raw[sample, 0]
+ ids = np.unique(segmentation)
+ ids = ids[ids != 0]
+ for id_ in ids:
+ segmentation_id_mask = segmentation == id_
+ if dataset_meta_data.num_spatial_dims == 2:
+ y, x = np.where(segmentation_id_mask)
+ y_min, y_max, x_min, x_max = (
+ np.min(y),
+ np.max(y),
+ np.min(x),
+ np.max(x),
+ )
+ elif dataset_meta_data.num_spatial_dims == 3:
+ z, y, x = np.where(segmentation_id_mask)
+ z_min, z_max, y_min, y_max, x_min, x_max = (
+ np.min(z),
+ np.max(z),
+ np.min(y),
+ np.max(y),
+ np.min(x),
+ np.max(x),
+ )
+ raw_image_masked = raw_image[segmentation_id_mask]
+ threshold = threshold_otsu(raw_image_masked)
+ mask = segmentation_id_mask & (raw_image > threshold)
- ds_object_centered_embeddings.attrs["axis_names"] = ["s", "c"] + [
- "t",
- "z",
- "y",
- "x",
- ][-dataset_meta_data.num_spatial_dims :]
- ds_object_centered_embeddings.attrs["resolution"] = (
- 1,
- ) * dataset_meta_data.num_spatial_dims
- ds_object_centered_embeddings.attrs["offset"] = (
- 0,
- ) * dataset_meta_data.num_spatial_dims
+ if dataset_meta_data.num_spatial_dims == 2:
+ mask_small = binary_fill_holes(
+ mask[y_min : y_max + 1, x_min : x_max + 1]
+ )
+ mask[y_min : y_max + 1, x_min : x_max + 1] = mask_small
+ y, x = np.where(mask)
+ ds_segmented[sample, bandwidth_factor, y, x] = id_
+ elif dataset_meta_data.num_spatial_dims == 3:
+ mask_small = binary_fill_holes(
+ mask[
+ z_min : z_max + 1, y_min : y_max + 1, x_min : x_max + 1
+ ]
+ )
+ mask[
+ z_min : z_max + 1, y_min : y_max + 1, x_min : x_max + 1
+ ] = mask_small
+ z, y, x = np.where(mask)
+ ds_segmented[sample, bandwidth_factor, z, y, x] = id_
+ # size filter - remove small objects
for sample in tqdm(range(dataset_meta_data.num_samples)):
- embeddings = ds[sample]
- embeddings_std = embeddings[-1, ...]
- embeddings_mean = embeddings[
- np.newaxis, : dataset_meta_data.num_spatial_dims, ...
- ].copy()
- if inference_config.threshold is None:
- threshold = threshold_otsu(embeddings_std)
- else:
- threshold = inference_config.threshold
-
- print(f"For sample {sample}, binary threshold {threshold} was used.")
- binary_mask = embeddings_std < threshold
- ds_binary_segmentation[sample, 0, ...] = binary_mask
-
- # find mean of embeddings
- embeddings_centered = embeddings.copy()
- embeddings_mean_masked = (
- binary_mask[np.newaxis, np.newaxis, ...] * embeddings_mean
- )
- if embeddings_centered.shape[0] == 3:
- c_x = embeddings_mean_masked[0, 0]
- c_y = embeddings_mean_masked[0, 1]
- c_x = c_x[c_x != 0].mean()
- c_y = c_y[c_y != 0].mean()
- embeddings_centered[0] -= c_x
- embeddings_centered[1] -= c_y
- elif embeddings_centered.shape[0] == 4:
- c_x = embeddings_mean_masked[0, 0]
- c_y = embeddings_mean_masked[0, 1]
- c_z = embeddings_mean_masked[0, 2]
- c_x = c_x[c_x != 0].mean()
- c_y = c_y[c_y != 0].mean()
- c_z = c_z[c_z != 0].mean()
- embeddings_centered[0] -= c_x
- embeddings_centered[1] -= c_y
- embeddings_centered[2] -= c_z
- ds_object_centered_embeddings[sample] = embeddings_centered
-
- if inference_config.clustering == "meanshift":
- for bandwidth_factor in range(inference_config.num_bandwidths):
- segmentation = mean_shift_segmentation(
- embeddings_mean,
- embeddings_std,
- bandwidth=inference_config.bandwidth / (2**bandwidth_factor),
- min_size=inference_config.min_size,
- reduction_probability=inference_config.reduction_probability,
- threshold=threshold,
- )
- # Note that the line below is needed
- # because the embeddings_mean is modified
- # by mean_shift_segmentation
- embeddings_mean = embeddings[
- np.newaxis, : dataset_meta_data.num_spatial_dims, ...
- ].copy()
- ds_segmentation[sample, bandwidth_factor, ...] = segmentation
- elif inference_config.clustering == "greedy":
- if dataset_meta_data.num_spatial_dims == 3:
- cluster3d = Cluster3d(
- width=embeddings.shape[-1],
- height=embeddings.shape[-2],
- depth=embeddings.shape[-3],
- fg_mask=binary_mask,
- device=inference_config.device,
- )
- for bandwidth_factor in range(inference_config.num_bandwidths):
- segmentation = cluster3d.cluster(
- prediction=embeddings,
- bandwidth=inference_config.bandwidth / (2**bandwidth_factor),
- min_object_size=inference_config.min_size,
- )
- ds_segmentation[sample, bandwidth_factor, ...] = segmentation
- elif dataset_meta_data.num_spatial_dims == 2:
- cluster2d = Cluster2d(
- width=embeddings.shape[-1],
- height=embeddings.shape[-2],
- fg_mask=binary_mask,
- device=inference_config.device,
- )
- for bandwidth_factor in range(inference_config.num_bandwidths):
- segmentation = cluster2d.cluster(
- prediction=embeddings,
- bandwidth=inference_config.bandwidth / (2**bandwidth_factor),
- min_object_size=inference_config.min_size,
- )
-
- ds_segmentation[sample, bandwidth_factor, ...] = segmentation
+ for bandwidth_factor in range(inference_config.num_bandwidths):
+ ds_segmented[sample, bandwidth_factor, ...] = size_filter(
+ ds_segmented[sample, bandwidth_factor], inference_config.min_size
+ )
diff --git a/cellulus/utils/mean_shift.py b/cellulus/utils/mean_shift.py
index bdf6012..a51e488 100644
--- a/cellulus/utils/mean_shift.py
+++ b/cellulus/utils/mean_shift.py
@@ -4,7 +4,13 @@
def mean_shift_segmentation(
- embedding_mean, embedding_std, bandwidth, min_size, reduction_probability, threshold
+ embedding_mean,
+ embedding_std,
+ bandwidth,
+ min_size,
+ reduction_probability,
+ threshold,
+ seeds,
):
embedding_mean = torch.from_numpy(embedding_mean)
if embedding_mean.ndim == 4:
@@ -34,22 +40,28 @@ def mean_shift_segmentation(
mask=mask,
reduction_probability=reduction_probability,
cluster_all=False,
+ seeds=seeds,
)[0]
return segmentation
def segment_with_meanshift(
- embedding, bandwidth, mask, reduction_probability, cluster_all
+ embedding, bandwidth, mask, reduction_probability, cluster_all, seeds
):
anchor_mean_shift = AnchorMeanshift(
- bandwidth, reduction_probability=reduction_probability, cluster_all=cluster_all
+ bandwidth,
+ reduction_probability=reduction_probability,
+ cluster_all=cluster_all,
+ seeds=seeds,
)
return anchor_mean_shift(embedding, mask=mask) + 1
class AnchorMeanshift:
- def __init__(self, bandwidth, reduction_probability, cluster_all):
- self.mean_shift = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all)
+ def __init__(self, bandwidth, reduction_probability, cluster_all, seeds):
+ self.mean_shift = MeanShift(
+ bandwidth=bandwidth, cluster_all=cluster_all, seeds=seeds
+ )
self.reduction_probability = reduction_probability
def compute_mean_shift(self, X):
diff --git a/cellulus/utils/misc.py b/cellulus/utils/misc.py
index 2f77c72..f3edfbc 100644
--- a/cellulus/utils/misc.py
+++ b/cellulus/utils/misc.py
@@ -13,15 +13,16 @@ def size_filter(segmentation, min_size, filter_non_connected=True):
return segmentation
if filter_non_connected:
- filter_labels = measure.label(segmentation, background=0)
+ filter_labels = measure.label(segmentation)
else:
filter_labels = segmentation
+
ids, sizes = np.unique(filter_labels, return_counts=True)
filter_ids = ids[sizes < min_size]
mask = np.in1d(filter_labels, filter_ids).reshape(filter_labels.shape)
segmentation[mask] = 0
- return segmentation
+ return measure.label(segmentation)
def extract_data(zip_url, data_dir, project_name):
diff --git a/docs/examples/2d/03-infer.py b/docs/examples/2d/03-infer.py
index 1e9d20e..1128840 100644
--- a/docs/examples/2d/03-infer.py
+++ b/docs/examples/2d/03-infer.py
@@ -38,15 +38,15 @@
prediction_dataset_config = DatasetConfig(
container_path=name + ".zarr", dataset_name="embeddings"
)
-segmentation_dataset_config = DatasetConfig(
+detection_dataset_config = DatasetConfig(
container_path=name + ".zarr",
- dataset_name="segmentation",
+ dataset_name="detection",
secondary_dataset_name="embeddings",
)
-post_processed_dataset_config = DatasetConfig(
+segmentation_dataset_config = DatasetConfig(
container_path=name + ".zarr",
- dataset_name="post_processed_segmentation",
- secondary_dataset_name="segmentation",
+ dataset_name="segmentation",
+ secondary_dataset_name="detection",
)
# ## Specify config values for the model
@@ -86,7 +86,7 @@
# We initialize the `inference_config` which contains our
# `embeddings_dataset_config`, `segmentation_dataset_config` and
-# `post_processed_dataset_config`.
+# `post_processed_dataset_config`.
# We set post_processing to one of `cell` or `nucleus`, depending on if we
# would like the cell membrane to be segmented or the nucleus.
@@ -95,8 +95,8 @@
inference_config = InferenceConfig(
dataset_config=asdict(dataset_config),
prediction_dataset_config=asdict(prediction_dataset_config),
+ detection_dataset_config=asdict(detection_dataset_config),
segmentation_dataset_config=asdict(segmentation_dataset_config),
- post_processed_dataset_config=asdict(post_processed_dataset_config),
post_processing=post_processing,
device=device,
)
@@ -109,7 +109,7 @@
experiment_config = ExperimentConfig(
inference_config=asdict(inference_config),
model_config=asdict(model_config),
- normalization_factor=1.0,
+ normalization_factor=1.0, # since the data was already normalized.
)
# Now we are ready to start the inference!!
@@ -139,7 +139,7 @@
f = zarr.open(name + ".zarr")
ds = f["train/raw"]
-ds2 = f["centered_embeddings"]
+ds2 = f["centered-embeddings"]
image = ds[index, 0]
embedding = ds2[index]
@@ -163,8 +163,8 @@
# +
f = zarr.open(name + ".zarr")
ds = f["train/raw"]
-ds2 = f["segmentation"]
-ds3 = f["post_processed_segmentation"]
+ds2 = f["detection"]
+ds3 = f["segmentation"]
visualize_2d(
image,
@@ -172,9 +172,10 @@
bottom_left=ds2[index, 0],
bottom_right=ds3[index, 0],
top_right_label="THRESHOLDED F.G.",
- bottom_left_label="SEGMENTATION",
- bottom_right_label="POSTPROCESSED",
+ bottom_left_label="DETECTION",
+ bottom_right_label="SEGMENTATION",
top_right_cmap="gray",
bottom_left_cmap=new_cmp,
bottom_right_cmap=new_cmp,
)
+# -
diff --git a/docs/examples/3d/01-data.py b/docs/examples/3d/01-data.py
new file mode 100644
index 0000000..2a3825b
--- /dev/null
+++ b/docs/examples/3d/01-data.py
@@ -0,0 +1,62 @@
+# # Download Data
+
+# In this notebook, we will download data and convert it to a zarr dataset.
+
+# For demonstration, we will use one image from the `Platynereis-Nuclei-CBG` dataset
+# provided
+# with [this](https://www.sciencedirect.com/science/article/pii/S1361841522001700)
+# publication.
+
+# Firstly, the `tif` raw images are downloaded to a directory indicated by `data_dir`.
+
+from pathlib import Path
+
+# +
+import numpy as np
+import tifffile
+import zarr
+from cellulus.utils.misc import extract_data
+from csbdeep.utils import normalize
+from skimage.transform import rescale
+from tqdm import tqdm
+
+# +
+name = "3d-data-demo"
+data_dir = "./data"
+
+extract_data(
+ zip_url="https://github.com/funkelab/cellulus/releases/download/v0.0.1-tag/3d-data-demo.zip",
+ data_dir=data_dir,
+ project_name=name,
+)
+# -
+
+# Currently, `cellulus` expects that the images are isotropic (i.e. the
+# voxel size along z dimension (which is usually undersampled) is the same
+# as the voxel size alng the x and y dimensions).
+# This dataset has a step size of $2.031 \mu m$ in z and $0.406 \mu m$ along
+# x and y dimensions, thus, the upsampling factor (which we refer to as
+# `anisotropy` equals $2.031/0.406$.
+# These raw images are upsampled, intensity-normalized and appended in a list.
+# Here, we use the percentile normalization technique.
+
+anisotropy = 2.031 / 0.406
+
+container_path = zarr.open(name + ".zarr")
+subsets = ["train", "test"]
+for subset in subsets:
+ dataset_name = subset + "/raw"
+ image_filenames = sorted((Path(data_dir) / name / subset).glob("*.tif"))
+ print(f"Number of raw images in {subset} directory is {len(image_filenames)}")
+ image_list = []
+
+ for i in tqdm(range(len(image_filenames))):
+ im = tifffile.imread(image_filenames[i]).astype(np.float32)
+ im_normalized = normalize(im, 1, 99.8, axis=(0, 1, 2))
+ im_rescaled = rescale(im_normalized, (anisotropy, 1.0, 1.0))
+ image_list.append(im_rescaled[np.newaxis, ...])
+
+ image_list = np.asarray(image_list)
+ container_path[dataset_name] = image_list
+ container_path[dataset_name].attrs["resolution"] = (1, 1, 1)
+ container_path[dataset_name].attrs["axis_names"] = ("s", "c", "z", "y", "x")
diff --git a/docs/examples/3d/02-train.py b/docs/examples/3d/02-train.py
new file mode 100644
index 0000000..017525c
--- /dev/null
+++ b/docs/examples/3d/02-train.py
@@ -0,0 +1,80 @@
+# # Train Model
+
+# In this notebook, we will train a `cellulus` model.
+
+from attrs import asdict
+from cellulus.configs.dataset_config import DatasetConfig
+from cellulus.configs.experiment_config import ExperimentConfig
+from cellulus.configs.model_config import ModelConfig
+from cellulus.configs.train_config import TrainConfig
+
+# ## Specify config values for dataset
+
+# In the next cell, we specify the name of the zarr container and the dataset
+# within it from which data would be read.
+
+name = "3d-data-demo"
+dataset_name = "train/raw"
+
+train_data_config = DatasetConfig(
+ container_path=name + ".zarr", dataset_name=dataset_name
+)
+
+# ## Specify config values for model
+
+# In the next cell, we specify the number of feature maps (`num_fmaps`) in the
+# first layer in our model.
+# Additionally, we specify `fmap_inc_factor` and `downsampling_factors`, which
+# indicates by how much the number of feature maps increase between adjacent
+# layers, and how much the spatial extents of the image gets downsampled between
+# adjacent layers respectively.
+
+num_fmaps = 24
+fmap_inc_factor = 3
+downsampling_factors = [
+ [2, 2, 2],
+]
+
+model_config = ModelConfig(
+ num_fmaps=num_fmaps,
+ fmap_inc_factor=fmap_inc_factor,
+ downsampling_factors=downsampling_factors,
+)
+
+# ## Specify config values for the training process
+
+# Then, we specify training-specific parameters such as the `device`, which
+# indicates the actual device to run the training on.
+# We also specify the `crop_size`. Mini - batches of crops are shown to the model
+# during training.
+#
The device could be set equal to `cuda:n` (where `n` is the index of
+# the GPU, for e.g. `cuda:0`) or `cpu`.
+# We set the `max_iterations` equal to `5000` for demonstration purposes.
+
+device = "cuda:0"
+max_iterations = 5000
+crop_size = [80, 80, 80]
+
+train_config = TrainConfig(
+ train_data_config=asdict(train_data_config),
+ device=device,
+ max_iterations=max_iterations,
+ crop_size=crop_size,
+)
+
+# Next, we initialize the experiment config which puts together the config
+# objects (`train_config` and `model_config`) which we defined above.
+
+experiment_config = ExperimentConfig(
+ train_config=asdict(train_config),
+ model_config=asdict(model_config),
+ normalization_factor=1.0, # since we already normalized in previous notebook
+)
+
+# Now we can begin the training!
+# Uncomment the next two lines to train the model.
+
+# +
+# from cellulus.train import train
+# train(experiment_config)
+# -
diff --git a/docs/examples/3d/03-infer.py b/docs/examples/3d/03-infer.py
new file mode 100644
index 0000000..236934c
--- /dev/null
+++ b/docs/examples/3d/03-infer.py
@@ -0,0 +1,186 @@
+# # Infer using Trained Model
+
+# In this notebook, we will use the `cellulus` model trained in the previous
+# step to obtain instance segmentations.
+
+import urllib
+import zipfile
+
+import numpy as np
+import skimage
+import torch
+import zarr
+from attrs import asdict
+from cellulus.configs.dataset_config import DatasetConfig
+from cellulus.configs.experiment_config import ExperimentConfig
+from cellulus.configs.inference_config import InferenceConfig
+from cellulus.configs.model_config import ModelConfig
+from cellulus.infer import infer
+from cellulus.utils.misc import visualize_2d
+from matplotlib.colors import ListedColormap
+
+# ## Specify config values for datasets
+
+# We again specify `name` of the zarr container, and `dataset_name` which
+# identifies the path to the raw image data, which needs to be segmented.
+
+name = "3d-data-demo"
+dataset_name = "test/raw"
+
+# We initialize the `dataset_config` which relates to the raw image data,
+# `prediction_dataset_config` which relates to the per-pixel embeddings and the
+# uncertainty, the `segmentation_dataset_config` which relates to the segmentations
+# post the mean-shift clustering and the `post_processed_config` which relates
+# to the segmentations after some post-processing.
+
+dataset_config = DatasetConfig(container_path=name + ".zarr", dataset_name=dataset_name)
+prediction_dataset_config = DatasetConfig(
+ container_path=name + ".zarr", dataset_name="embeddings"
+)
+detection_dataset_config = DatasetConfig(
+ container_path=name + ".zarr",
+ dataset_name="detection",
+ secondary_dataset_name="embeddings",
+)
+segmentation_dataset_config = DatasetConfig(
+ container_path=name + ".zarr",
+ dataset_name="segmentation",
+ secondary_dataset_name="detection",
+)
+
+# ## Specify config values for the model
+
+# We must also specify the `num_fmaps`, `fmap_inc_factor` (use same values as
+# in the training step) and set `checkpoint` equal to `models/best_loss.pth`
+# (best in terms of the lowest loss obtained).
+
+# Here, we download a pretrained model trained by us for `5e3` iterations.
+# But please comment the next cell to use your own trained model, which
+# should be available in the `models` directory.
+
+torch.hub.download_url_to_file(
+ url="https://github.com/funkelab/cellulus/releases/download/v0.0.1-tag/2d-demo-model.zip",
+ dst="pretrained_model",
+ progress=True,
+)
+with zipfile.ZipFile("pretrained_model", "r") as zip_ref:
+ zip_ref.extractall("")
+
+num_fmaps = 24
+fmap_inc_factor = 3
+downsampling_factors = [
+ [2, 2, 2],
+]
+checkpoint = "models/best_loss.pth"
+
+model_config = ModelConfig(
+ num_fmaps=num_fmaps,
+ fmap_inc_factor=fmap_inc_factor,
+ downsampling_factors=downsampling_factors,
+ checkpoint=checkpoint,
+)
+
+# ## Initialize `inference_config`
+
+# Then, we specify inference-specific parameters such as the `device`, which
+# indicates the actual device to run the inference on.
+#
The device could be set equal to `cuda:n` (where `n` is the index of the
+# GPU, for e.g. `cuda:0`), `cpu` or `mps`.
+
+device = "cuda:0"
+
+# We initialize the `inference_config` which contains our `embeddings_dataset_config`,
+# `segmentation_dataset_config` and `post_processed_dataset_config`.
+
+inference_config = InferenceConfig(
+ dataset_config=asdict(dataset_config),
+ # prediction_dataset_config=asdict(prediction_dataset_config),
+ # detection_dataset_config=asdict(detection_dataset_config),
+ segmentation_dataset_config=asdict(segmentation_dataset_config),
+ crop_size=[120, 120, 120],
+ post_processing="nucleus",
+ device=device,
+ use_seeds=True,
+)
+
+# ## Initialize `experiment_config`
+
+# Lastly we initialize the `experiment_config` which contains the `inference_config`
+# and `model_config` initialized above.
+
+experiment_config = ExperimentConfig(
+ inference_config=asdict(inference_config),
+ model_config=asdict(model_config),
+ normalization_factor=1.0, # since the test image is already normalized
+)
+
+# Now we are ready to start the inference!!
+# (To see the output of the cell below, remove the first line `io.capture_output()`).
+
+# with io.capture_output() as captured:
+infer(experiment_config)
+
+# ## Inspect predictions
+
+# Let's look at some of the predicted embeddings.
+# We will first load a glasbey-like color map to show individual cells
+# with a unique color.
+
+urllib.request.urlretrieve(
+ "https://github.com/funkelab/cellulus/releases/download/v0.0.1-tag/cmap_60.npy",
+ "cmap_60.npy",
+)
+new_cmp = ListedColormap(np.load("cmap_60.npy"))
+
+# Change the value of `index` below to look at the raw image (left),
+# x-offset (bottom-left), y-offset (bottom-right) and uncertainty of the
+# embedding (top-right).
+
+# +
+index = 0
+
+f = zarr.open(name + ".zarr")
+ds = f["test/raw"]
+ds2 = f["centered_embeddings"]
+
+slice = ds.shape[2] // 2
+
+image = ds[index, 0, slice]
+embedding = ds2[index, :, slice]
+
+
+visualize_2d(
+ image,
+ top_right=embedding[-1],
+ bottom_left=embedding[0],
+ bottom_right=embedding[1],
+ top_right_label="UNCERTAINTY",
+ bottom_left_label="OFFSET_X",
+ bottom_right_label="OFFSET_Y",
+)
+# -
+
+# As you can see the magnitude of the uncertainty of the embedding (top-right)
+# is low for most of the foreground cells.
This enables extraction
+# of the foreground, which is eventually clustered into individual instances.
+# See bottom right figure for the final result.
+
+# +
+f = zarr.open(name + ".zarr")
+ds = f["test/raw"]
+ds2 = f["detection"]
+ds3 = f["segmentation"]
+
+visualize_2d(
+ image,
+ top_right=embedding[-1] < skimage.filters.threshold_otsu(embedding[-1]),
+ bottom_left=ds2[index, index, slice],
+ bottom_right=ds3[index, index, slice],
+ top_right_label="THRESHOLDED F.G.",
+ bottom_left_label="DETECTION",
+ bottom_right_label="SEGMENTATION",
+ top_right_cmap="gray",
+ bottom_left_cmap=new_cmp,
+ bottom_right_cmap=new_cmp,
+)
+# -
diff --git a/mkdocs.yml b/mkdocs.yml
index 607af7f..5ed2b1e 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -61,10 +61,15 @@ nav:
- Home: index.md
- Getting Started:
- Installation: getting-started.md
- - 2D Example:
- - 01: examples/2d/01-data.py
- - 02: examples/2d/02-train.py
- - 03: examples/2d/03-infer.py
+ - Examples:
+ - 2D:
+ - 01: examples/2d/01-data.py
+ - 02: examples/2d/02-train.py
+ - 03: examples/2d/03-infer.py
+ - 3D:
+ - 01: examples/3d/01-data.py
+ - 02: examples/3d/02-train.py
+ - 03: examples/3d/03-infer.py
- API Reference:
- Configs:
- DatasetConfig: api/dataset_config.md