Skip to content

Commit

Permalink
Introduce CellposeSegmentationChannel for channel inputs in `cellpo…
Browse files Browse the repository at this point in the history
…se_segmentation` (ref #412)
  • Loading branch information
tcompa committed Jun 12, 2023
1 parent 1fc0f69 commit cd1ad49
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 78 deletions.
44 changes: 25 additions & 19 deletions fractal_tasks_core/__FRACTAL_MANIFEST__.json
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,23 @@
{
"args_schema": {
"additionalProperties": false,
"definitions": {
"CellposeSegmentationChannel": {
"description": "TBD",
"properties": {
"label": {
"title": "Label",
"type": "string"
},
"wavelength_id": {
"title": "Wavelength Id",
"type": "string"
}
},
"title": "CellposeSegmentationChannel",
"type": "object"
}
},
"properties": {
"anisotropy": {
"description": "Ratio of the pixel sizes along Z and XY axis (ignored if the image is not three-dimensional). If `None`, it is inferred from the OME-NGFF metadata.",
Expand All @@ -339,15 +356,13 @@
"title": "Cellprob Threshold",
"type": "number"
},
"channel_label": {
"description": "Identifier of a channel based on its label (e.g. ``DAPI``). If not ``None``, then ``wavelength_id`` must be ``None``.",
"title": "Channel Label",
"type": "string"
"channel": {
"$ref": "#/definitions/CellposeSegmentationChannel",
"description": "TBD wavelength_id: Identifier of a channel based on the wavelength (e.g. ``A01_C01``). If not ``None``, then ``label` must be ``None``. Identifier of a channel based on its label (e.g. ``DAPI``). If not ``None``, then ``wavelength_id`` must be ``None``."
},
"channel_label_c2": {
"description": "Identifier of a second channel in the same format as the first wavelength_id. If specified, cellpose runs in dual channel mode. For dual channel segmentation of cells, the first channel should contain the membrane marker, the second channel should contain the nuclear marker.",
"title": "Channel Label C2",
"type": "string"
"channel2": {
"$ref": "#/definitions/CellposeSegmentationChannel",
"description": "TBD Identifier of a second channel in the same format as the first wavelength_id. If specified, cellpose runs in dual channel mode. For dual channel segmentation of cells, the first channel should contain the membrane marker, the second channel should contain the nuclear marker. Identifier of a second channel in the same format as the first wavelength_id. If specified, cellpose runs in dual channel mode. For dual channel segmentation of cells, the first channel should contain the membrane marker, the second channel should contain the nuclear marker."
},
"component": {
"description": "Path to the OME-Zarr image in the OME-Zarr plate that is processed. Example: \"some_plate.zarr/B/03/0\" (standard argument for Fractal tasks, managed by Fractal server)",
Expand Down Expand Up @@ -445,24 +460,15 @@
"description": "If ``True``, try to use masked loading and fall back to ``use_masks=False`` if the ROI table is not suitable. Masked loading is relevant when only a subset of the bounding box should actually be processed (e.g. running within organoid_ROI_table).",
"title": "Use Masks",
"type": "boolean"
},
"wavelength_id": {
"description": "Identifier of a channel based on the wavelength (e.g. ``A01_C01``). If not ``None``, then ``channel_label` must be ``None``.",
"title": "Wavelength Id",
"type": "string"
},
"wavelength_id_c2": {
"description": "Identifier of a second channel in the same format as the first wavelength_id. If specified, cellpose runs in dual channel mode. For dual channel segmentation of cells, the first channel should contain the membrane marker, the second channel should contain the nuclear marker.",
"title": "Wavelength Id C2",
"type": "string"
}
},
"required": [
"input_paths",
"output_path",
"component",
"metadata",
"level"
"level",
"channel"
],
"title": "CellposeSegmentation",
"type": "object"
Expand Down
65 changes: 65 additions & 0 deletions fractal_tasks_core/tasks/_input_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
Copyright 2022 (C)
Friedrich Miescher Institute for Biomedical Research and
University of Zurich
Original authors:
Tommaso Comparin <[email protected]>
This file is part of Fractal and was originally developed by eXact lab
S.r.l. <exact-lab.it> under contract with Liberali Lab from the Friedrich
Miescher Institute for Biomedical Research and Pelkmans Lab from the
University of Zurich.
Pydantic models for some task parameters
"""
from typing import Optional

from pydantic import BaseModel
from pydantic import validator


class BaseChannel(BaseModel):
"""
TBD
"""

wavelength_id: Optional[str] = None
label: Optional[str] = None

@validator("label", always=True)
def mutually_exclusive_channel_attributes(cls, v, values):
"""
If `label` is set, then `wavelength_id` must be `None`
"""
wavelength_id = values.get("wavelength_id")
label = v

if wavelength_id is not None and v is not None:
raise ValueError(
"`wavelength_id` and `label` cannot be both set "
f"(given {wavelength_id=} and {label=})."
)

if wavelength_id is None and v is None:
raise ValueError(
"`wavelength_id` and `label` cannot be both `None`"
)

return v


class CellposeSegmentationChannel(BaseChannel):
"""
TBD
"""

pass


class NapariWorkflowsChannel(BaseModel):
"""
TBD
"""

pass
82 changes: 36 additions & 46 deletions fractal_tasks_core/tasks/cellpose_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from pydantic.decorator import validate_arguments

import fractal_tasks_core
from fractal_tasks_core.lib_channels import Channel
from fractal_tasks_core.lib_channels import ChannelNotFoundError
from fractal_tasks_core.lib_channels import get_channel_from_image_zarr
from fractal_tasks_core.lib_masked_loading import masked_loading_wrapper
Expand All @@ -51,6 +52,7 @@
from fractal_tasks_core.lib_ROI_overlaps import get_overlapping_pairs_3D
from fractal_tasks_core.lib_zattrs_utils import extract_zyx_pixel_sizes
from fractal_tasks_core.lib_zattrs_utils import rescale_datasets
from fractal_tasks_core.tasks._input_models import CellposeSegmentationChannel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -151,10 +153,8 @@ def cellpose_segmentation(
metadata: Dict[str, Any],
# Task-specific arguments
level: int,
wavelength_id: Optional[str] = None,
channel_label: Optional[str] = None,
wavelength_id_c2: Optional[str] = None,
channel_label_c2: Optional[str] = None,
channel: CellposeSegmentationChannel,
channel2: Optional[CellposeSegmentationChannel] = None,
input_ROI_table: str = "FOV_ROI_table",
output_ROI_table: Optional[str] = None,
output_label_name: Optional[str] = None,
Expand Down Expand Up @@ -203,24 +203,23 @@ def cellpose_segmentation(
managed by Fractal server)
:param level: Pyramid level of the image to be segmented. Choose 0 to
process at full resolution.
:param wavelength_id: Identifier of a channel based on the
wavelength (e.g. ``A01_C01``). If not ``None``, then
``channel_label` must be ``None``.
:param channel_label: Identifier of a channel based on its label (e.g.
``DAPI``). If not ``None``, then ``wavelength_id``
must be ``None``.
:param wavelength_id_c2: Identifier of a second channel in the same format
as the first wavelength_id. If specified, cellpose
runs in dual channel mode. For dual channel
segmentation of cells, the first channel should
contain the membrane marker, the second channel
should contain the nuclear marker.
:param channel_label_c2: Identifier of a second channel in the same
format as the first wavelength_id. If specified,
cellpose runs in dual channel mode. For dual
channel segmentation of cells, the first channel
should contain the membrane marker, the second
channel should contain the nuclear marker.
:param channel: TBD
wavelength_id: Identifier of a channel based on the
wavelength (e.g. ``A01_C01``). If not ``None``, then
``label` must be ``None``. Identifier of a channel
based on its label (e.g. ``DAPI``). If not ``None``, then
``wavelength_id`` must be ``None``.
:param channel2: TBD
Identifier of a second channel in the same format as the
first wavelength_id. If specified, cellpose runs in dual
channel mode. For dual channel segmentation of cells, the
first channel should contain the membrane marker, the
second channel should contain the nuclear marker.
Identifier of a second channel in the same format as the
first wavelength_id. If specified, cellpose runs in dual
channel mode. For dual channel segmentation of cells, the
first channel should contain the membrane marker, the
second channel should contain the nuclear marker.
:param input_ROI_table: Name of the ROI table over which the task loops
to apply Cellpose segmentation.
Example: "FOV_ROI_table" => loop over the field of
Expand Down Expand Up @@ -294,15 +293,6 @@ def cellpose_segmentation(
zarrurl = (in_path.resolve() / component).as_posix()
logger.info(f"{zarrurl=}")

# Preliminary check
if (channel_label is None and wavelength_id is None) or (
channel_label and wavelength_id
):
raise ValueError(
f"One and only one of {channel_label=} and "
f"{wavelength_id=} arguments must be provided"
)

# Preliminary checks on Cellpose model
if pretrained_model is None:
if model_type not in models.MODEL_NAMES:
Expand All @@ -319,49 +309,49 @@ def cellpose_segmentation(

# Find channel index
try:
channel = get_channel_from_image_zarr(
tmp_channel: Channel = get_channel_from_image_zarr(
image_zarr_path=zarrurl,
wavelength_id=wavelength_id,
label=channel_label,
wavelength_id=channel.wavelength_id,
label=channel.label,
)
except ChannelNotFoundError as e:
logger.warning(
"Channel not found, exit from the task.\n"
f"Original error: {str(e)}"
)
return {}
ind_channel = channel.index
ind_channel = tmp_channel.index

# Find channel index for second channel, if one is provided
if wavelength_id_c2 or channel_label_c2:
if channel2:
try:
channel_c2 = get_channel_from_image_zarr(
tmp_channel_c2: Channel = get_channel_from_image_zarr(
image_zarr_path=zarrurl,
wavelength_id=wavelength_id_c2,
label=channel_label_c2,
wavelength_id=channel2.wavelength_id,
label=channel2.label,
)
except ChannelNotFoundError as e:
logger.warning(
f"Second channel with wavelength_id_c2:{wavelength_id_c2} and "
f"channel_label_c2: {channel_label_c2} not found, exit "
f"Second channel with wavelength_id: {channel2.wavelength_id} "
f"and channel_label: {channel2.channel_label} not found, exit "
"from the task.\n"
f"Original error: {str(e)}"
)
return {}
ind_channel_c2 = channel_c2.index
ind_channel_c2 = tmp_channel_c2.index

# Set channel label
if output_label_name is None:
try:
channel_label = channel.label
channel_label = tmp_channel.label
output_label_name = f"label_{channel_label}"
except (KeyError, IndexError):
output_label_name = f"label_{ind_channel}"

# Load ZYX data
data_zyx = da.from_zarr(f"{zarrurl}/{level}")[ind_channel]
logger.info(f"{data_zyx.shape=}")
if wavelength_id_c2 or channel_label_c2:
if channel2:
data_zyx_c2 = da.from_zarr(f"{zarrurl}/{level}")[ind_channel_c2]
logger.info(f"Second channel: {data_zyx_c2.shape=}")

Expand Down Expand Up @@ -528,7 +518,7 @@ def cellpose_segmentation(
logger.info("Total well shape/chunks:")
logger.info(f"{data_zyx.shape}")
logger.info(f"{data_zyx.chunks}")
if wavelength_id_c2 or channel_label_c2:
if channel2:
logger.info("Dual channel input for cellpose model")
logger.info(f"{data_zyx_c2.shape}")
logger.info(f"{data_zyx_c2.chunks}")
Expand All @@ -555,7 +545,7 @@ def cellpose_segmentation(
logger.info(f"Now processing ROI {i_ROI+1}/{num_ROIs}")

# Prepare single-channel or dual-channel input for cellpose
if wavelength_id_c2 or channel_label_c2:
if channel2:
# Dual channel mode, first channel is the membrane channel
img_np = np.zeros((2, *data_zyx[region].shape))
img_np[0, :, :, :] = data_zyx[region].compute()
Expand Down
Loading

0 comments on commit cd1ad49

Please sign in to comment.