Add 3d example #7

Merged
merged 18 commits on Mar 3, 2024
19 changes: 13 additions & 6 deletions cellulus/configs/inference_config.py
@@ -22,13 +22,13 @@ class InferenceConfig:
 
     Configuration object produced by predict.py.
 
-    segmentation_dataset_config:
+    detection_dataset_config:
 
-        Configuration object produced by segment.py.
+        Configuration object produced by detect.py.
 
-    post_processed_dataset_config:
+    segmentation_dataset_config:
 
-        Configuration object produced by post_process.py.
+        Configuration object produced by segment.py.
 
     evaluation_dataset_config:

@@ -72,6 +72,12 @@ class InferenceConfig:
     How to cluster the embeddings?
     Can be one of 'meanshift' or 'greedy'.
 
+    use_seeds (default = False):
+
+        If set to True, the local optima of the distance map from the
+        predicted object centers is used.
+        Else, seeds are determined by sklearn.cluster.MeanShift.
+
     num_bandwidths (default = 1):
 
         Number of bandwidths to obtain segmentations for.
@@ -118,11 +124,11 @@ class InferenceConfig:
         default=None, converter=to_config(DatasetConfig)
     )
 
-    segmentation_dataset_config: DatasetConfig = attrs.field(
+    detection_dataset_config: DatasetConfig = attrs.field(
         default=None, converter=to_config(DatasetConfig)
     )
 
-    post_processed_dataset_config: DatasetConfig = attrs.field(
+    segmentation_dataset_config: DatasetConfig = attrs.field(
         default=None, converter=to_config(DatasetConfig)
     )
@@ -139,6 +145,7 @@ class InferenceConfig:
     clustering = attrs.field(
         default="meanshift", validator=in_(["meanshift", "greedy"])
     )
+    use_seeds = attrs.field(default=False, validator=instance_of(bool))
    bandwidth = attrs.field(
         default=None, validator=attrs.validators.optional(instance_of(float))
     )
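Note: the new `use_seeds` option is opt-in and, per detect.py below, only affects the mean-shift path. A minimal sketch of enabling it on an already-built config (assuming `inference_config` was constructed elsewhere, e.g. parsed from the experiment's config file, and that the attrs class is not frozen; all values here are illustrative, not from this PR):

    # Sketch: enable seeded mean-shift on an existing InferenceConfig.
    inference_config.clustering = "meanshift"  # seeds only affect the mean-shift path
    inference_config.use_seeds = True          # validator: instance_of(bool)
    inference_config.bandwidth = 7.0           # must be a float; 7.0 is a placeholder
    inference_config.num_bandwidths = 2        # also segments at bandwidth / 2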
23 changes: 15 additions & 8 deletions cellulus/datasets/zarr_dataset.py
@@ -137,17 +137,24 @@ def __yield_sample(self):
 
         with gp.build(self.pipeline):
             while True:
+                array_is_zero = True
                 # request one sample, all channels, plus crop dimensions
-                request = gp.BatchRequest()
-                request[self.raw] = gp.ArraySpec(
-                    roi=gp.Roi(
-                        (0,) * self.num_dims, (1, self.num_channels, *self.crop_size)
+                while array_is_zero:
+                    request = gp.BatchRequest()
+                    request[self.raw] = gp.ArraySpec(
+                        roi=gp.Roi(
+                            (0,) * self.num_dims,
+                            (1, self.num_channels, *self.crop_size),
+                        )
                     )
-                )
 
-                sample = self.pipeline.request_batch(request)
-                sample_data = sample[self.raw].data[0]
-                anchor_samples, reference_samples = self.sample_coordinates()
+                    sample = self.pipeline.request_batch(request)
+                    sample_data = sample[self.raw].data[0]
+                    if np.max(sample_data) <= 0.0:
+                        pass
+                    else:
+                        array_is_zero = False
+                        anchor_samples, reference_samples = self.sample_coordinates()
                 yield sample_data, anchor_samples, reference_samples

def __read_meta_data(self):
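Note: the change above wraps the batch request in a rejection loop — crops whose raw data is entirely zero are discarded and a new batch is requested, so training never yields empty samples. Stripped of the gunpowder specifics, the pattern is just the following (a standalone sketch; `sample_crop` is a stand-in for the pipeline request, not code from this PR):

    import numpy as np

    def next_nonzero_crop(sample_crop):
        """Keep drawing random crops until one contains foreground (max > 0)."""
        while True:
            crop = sample_crop()
            if np.max(crop) > 0.0:  # reject all-zero crops, mirroring the loop above
                return crop

    # Toy usage: most 8x8 crops of this 64x64 array are empty.
    rng = np.random.default_rng(0)
    data = np.zeros((64, 64))
    data[40:44, 40:44] = 1.0

    def random_crop():
        y, x = rng.integers(0, 56, size=2)
        return data[y : y + 8, x : x + 8]

    crop = next_nonzero_crop(random_crop)
    assert crop.max() > 0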
192 changes: 192 additions & 0 deletions cellulus/detect.py
@@ -0,0 +1,192 @@
import numpy as np
import zarr
from scipy.ndimage import gaussian_filter
from skimage.feature import peak_local_max
from skimage.filters import threshold_otsu
from tqdm import tqdm

from cellulus.configs.inference_config import InferenceConfig
from cellulus.datasets.meta_data import DatasetMetaData
from cellulus.utils.greedy_cluster import Cluster2d, Cluster3d
from cellulus.utils.mean_shift import mean_shift_segmentation


def detect(inference_config: InferenceConfig) -> None:
    dataset_config = inference_config.dataset_config
    dataset_meta_data = DatasetMetaData.from_dataset_config(dataset_config)

    f = zarr.open(inference_config.detection_dataset_config.container_path)
    ds = f[inference_config.detection_dataset_config.secondary_dataset_name]

    # prepare the zarr dataset to write to
    f_detection = zarr.open(inference_config.detection_dataset_config.container_path)
    ds_detection = f_detection.create_dataset(
        inference_config.detection_dataset_config.dataset_name,
        shape=(
            dataset_meta_data.num_samples,
            inference_config.num_bandwidths,
            *dataset_meta_data.spatial_array,
        ),
        dtype=np.uint16,
    )

    ds_detection.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][
        -dataset_meta_data.num_spatial_dims :
    ]
    ds_detection.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims
    ds_detection.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims

    # prepare the binary segmentation zarr dataset to write to
    ds_binary_segmentation = f_detection.create_dataset(
        "binary-segmentation",
        shape=(
            dataset_meta_data.num_samples,
            1,
            *dataset_meta_data.spatial_array,
        ),
        dtype=np.uint16,
    )

    ds_binary_segmentation.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][
        -dataset_meta_data.num_spatial_dims :
    ]
    ds_binary_segmentation.attrs["resolution"] = (
        1,
    ) * dataset_meta_data.num_spatial_dims
    ds_binary_segmentation.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims

    # prepare the object centered embeddings zarr dataset to write to
    ds_object_centered_embeddings = f_detection.create_dataset(
        "centered-embeddings",
        shape=(
            dataset_meta_data.num_samples,
            dataset_meta_data.num_spatial_dims + 1,
            *dataset_meta_data.spatial_array,
        ),
        dtype=float,
    )

    ds_object_centered_embeddings.attrs["axis_names"] = ["s", "c"] + [
        "t",
        "z",
        "y",
        "x",
    ][-dataset_meta_data.num_spatial_dims :]
    ds_object_centered_embeddings.attrs["resolution"] = (
        1,
    ) * dataset_meta_data.num_spatial_dims
    ds_object_centered_embeddings.attrs["offset"] = (
        0,
    ) * dataset_meta_data.num_spatial_dims

    for sample in tqdm(range(dataset_meta_data.num_samples)):
        embeddings = ds[sample]
        embeddings_std = embeddings[-1, ...]
        embeddings_mean = embeddings[
            np.newaxis, : dataset_meta_data.num_spatial_dims, ...
        ].copy()
        if inference_config.threshold is None:
            threshold = threshold_otsu(embeddings_std)
        else:
            threshold = inference_config.threshold

        print(f"For sample {sample}, binary threshold {threshold} was used.")
        binary_mask = embeddings_std < threshold
        ds_binary_segmentation[sample, 0, ...] = binary_mask

        # find mean of embeddings
        embeddings_centered = embeddings.copy()
        embeddings_mean_masked = (
            binary_mask[np.newaxis, np.newaxis, ...] * embeddings_mean
        )
        if embeddings_centered.shape[0] == 3:
            c_x = embeddings_mean_masked[0, 0]
            c_y = embeddings_mean_masked[0, 1]
            c_x = c_x[c_x != 0].mean()
            c_y = c_y[c_y != 0].mean()
            embeddings_centered[0] -= c_x
            embeddings_centered[1] -= c_y
        elif embeddings_centered.shape[0] == 4:
            c_x = embeddings_mean_masked[0, 0]
            c_y = embeddings_mean_masked[0, 1]
            c_z = embeddings_mean_masked[0, 2]
            c_x = c_x[c_x != 0].mean()
            c_y = c_y[c_y != 0].mean()
            c_z = c_z[c_z != 0].mean()
            embeddings_centered[0] -= c_x
            embeddings_centered[1] -= c_y
            embeddings_centered[2] -= c_z
        ds_object_centered_embeddings[sample] = embeddings_centered

        embeddings_centered_mean = embeddings_centered[
            np.newaxis, : dataset_meta_data.num_spatial_dims
        ]
        embeddings_centered_std = embeddings_centered[-1]

        if inference_config.clustering == "meanshift":
            for bandwidth_factor in range(inference_config.num_bandwidths):
                if inference_config.use_seeds:
                    offset_magnitude = np.linalg.norm(embeddings_centered[:-1], axis=0)
                    offset_magnitude_smooth = gaussian_filter(offset_magnitude, sigma=2)
                    coordinates = peak_local_max(-offset_magnitude_smooth)
                    seeds = np.flip(coordinates, 1)
                    segmentation = mean_shift_segmentation(
                        embeddings_centered_mean,
                        embeddings_centered_std,
                        bandwidth=inference_config.bandwidth / (2**bandwidth_factor),
                        min_size=inference_config.min_size,
                        reduction_probability=inference_config.reduction_probability,
                        threshold=threshold,
                        seeds=seeds,
                    )
                    embeddings_centered_mean = embeddings_centered[
                        np.newaxis, : dataset_meta_data.num_spatial_dims, ...
                    ].copy()
                else:
                    segmentation = mean_shift_segmentation(
                        embeddings_mean,
                        embeddings_std,
                        bandwidth=inference_config.bandwidth / (2**bandwidth_factor),
                        min_size=inference_config.min_size,
                        reduction_probability=inference_config.reduction_probability,
                        threshold=threshold,
                        seeds=None,
                    )
                    # Note that the line below is needed
                    # because the embeddings_mean is modified
                    # by mean_shift_segmentation
                    embeddings_mean = embeddings[
                        np.newaxis, : dataset_meta_data.num_spatial_dims, ...
                    ].copy()
                ds_detection[sample, bandwidth_factor, ...] = segmentation
        elif inference_config.clustering == "greedy":
            if dataset_meta_data.num_spatial_dims == 3:
                cluster3d = Cluster3d(
                    width=embeddings.shape[-1],
                    height=embeddings.shape[-2],
                    depth=embeddings.shape[-3],
                    fg_mask=binary_mask,
                    device=inference_config.device,
                )
                for bandwidth_factor in range(inference_config.num_bandwidths):
                    segmentation = cluster3d.cluster(
                        prediction=embeddings,
                        bandwidth=inference_config.bandwidth / (2**bandwidth_factor),
                        min_object_size=inference_config.min_size,
                    )
                    ds_detection[sample, bandwidth_factor, ...] = segmentation
            elif dataset_meta_data.num_spatial_dims == 2:
                cluster2d = Cluster2d(
                    width=embeddings.shape[-1],
                    height=embeddings.shape[-2],
                    fg_mask=binary_mask,
                    device=inference_config.device,
                )
                for bandwidth_factor in range(inference_config.num_bandwidths):
                    segmentation = cluster2d.cluster(
                        prediction=embeddings,
                        bandwidth=inference_config.bandwidth / (2**bandwidth_factor),
                        min_object_size=inference_config.min_size,
                    )

                    ds_detection[sample, bandwidth_factor, ...] = segmentation
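Note, two details of detect() are easy to miss. First, `bandwidth / (2**bandwidth_factor)` halves the bandwidth at each step, so `num_bandwidths=3` with `bandwidth=8.0` produces segmentations at 8.0, 4.0, and 2.0. Second, with `use_seeds=True` the seeds are the local minima of the smoothed offset magnitude, i.e. pixels whose predicted spatial embedding barely displaces them — likely object centers. The seed-finding step in isolation, on a toy offset field (a sketch, not part of the PR):

    import numpy as np
    from scipy.ndimage import gaussian_filter
    from skimage.feature import peak_local_max

    # Toy 2D offset field (dy, dx) whose vectors vanish at a center at (10, 20).
    yy, xx = np.mgrid[0:32, 0:48]
    offsets = np.stack([yy - 10.0, xx - 20.0])

    # The same steps detect() applies to embeddings_centered[:-1]:
    offset_magnitude = np.linalg.norm(offsets, axis=0)
    offset_magnitude_smooth = gaussian_filter(offset_magnitude, sigma=2)
    coordinates = peak_local_max(-offset_magnitude_smooth)  # maxima of -|v| = minima of |v|
    seeds = np.flip(coordinates, 1)  # (row, col) -> (col, row), matching the flip in detect()
    print(seeds)  # expected: approximately [[20 10]]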
12 changes: 6 additions & 6 deletions cellulus/infer.py
@@ -4,9 +4,9 @@
 import torch
 
 from cellulus.datasets.meta_data import DatasetMetaData
+from cellulus.detect import detect
 from cellulus.evaluate import evaluate
 from cellulus.models import get_model
-from cellulus.post_process import post_process
 from cellulus.predict import predict
 from cellulus.segment import segment

@@ -35,7 +35,7 @@ def infer(experiment_config):
         )
     elif dataset_meta_data.num_spatial_dims == 3:
         inference_config.min_size = int(
-            0.1 * 4.0 / 3.0 * np.pi * (experiment_config.object_size**3)
+            0.1 * 4.0 / 3.0 * np.pi * (experiment_config.object_size**3) / 8
         )
     # set model
     model = get_model(
@@ -69,12 +69,12 @@ def infer(experiment_config):
     # get predicted embeddings...
     if inference_config.prediction_dataset_config is not None:
         predict(model, inference_config, normalization_factor)
-    # ...turn them into a segmentation...
+    # ...turn them into a detection ...
+    if inference_config.detection_dataset_config is not None:
+        detect(inference_config)
+    # ...and post-process the detection to obtain an instance segmentation
     if inference_config.segmentation_dataset_config is not None:
         segment(inference_config)
-    # ...and post-process the segmentation
-    if inference_config.post_processed_dataset_config is not None:
-        post_process(inference_config)
     # ...and evaluate if ground-truth exists
     if inference_config.evaluation_dataset_config is not None:
         evaluate(inference_config)
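Note: the added `/ 8` in the 3D branch makes `min_size` consistent with reading `object_size` as a diameter: since `(object_size**3) / 8 == (object_size / 2)**3`, the expression is 10% of the volume of a ball of radius `object_size / 2` rather than radius `object_size`. A quick arithmetic check (`object_size = 30` is an illustrative value only):

    import numpy as np

    object_size = 30.0  # hypothetical object diameter in pixels
    radius = object_size / 2

    # 10% of the volume of a ball with that diameter...
    min_size = int(0.1 * 4.0 / 3.0 * np.pi * radius**3)

    # ...which is exactly what the patched line computes:
    assert min_size == int(0.1 * 4.0 / 3.0 * np.pi * (object_size**3) / 8)
    print(min_size)  # 1413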