diff --git a/fractal_tasks_core/__FRACTAL_MANIFEST__.json b/fractal_tasks_core/__FRACTAL_MANIFEST__.json index d1843bc4b..5fd5812a6 100644 --- a/fractal_tasks_core/__FRACTAL_MANIFEST__.json +++ b/fractal_tasks_core/__FRACTAL_MANIFEST__.json @@ -321,6 +321,23 @@ { "args_schema": { "additionalProperties": false, + "definitions": { + "CellposeSegmentationChannel": { + "description": "TBD", + "properties": { + "label": { + "title": "Label", + "type": "string" + }, + "wavelength_id": { + "title": "Wavelength Id", + "type": "string" + } + }, + "title": "CellposeSegmentationChannel", + "type": "object" + } + }, "properties": { "anisotropy": { "description": "Ratio of the pixel sizes along Z and XY axis (ignored if the image is not three-dimensional). If `None`, it is inferred from the OME-NGFF metadata.", @@ -339,15 +356,13 @@ "title": "Cellprob Threshold", "type": "number" }, - "channel_label": { - "description": "Identifier of a channel based on its label (e.g. ``DAPI``). If not ``None``, then ``wavelength_id`` must be ``None``.", - "title": "Channel Label", - "type": "string" + "channel": { + "$ref": "#/definitions/CellposeSegmentationChannel", + "description": "TBD wavelength_id: Identifier of a channel based on the wavelength (e.g. ``A01_C01``). If not ``None``, then ``label` must be ``None``. Identifier of a channel based on its label (e.g. ``DAPI``). If not ``None``, then ``wavelength_id`` must be ``None``." }, - "channel_label_c2": { - "description": "Identifier of a second channel in the same format as the first wavelength_id. If specified, cellpose runs in dual channel mode. For dual channel segmentation of cells, the first channel should contain the membrane marker, the second channel should contain the nuclear marker.", - "title": "Channel Label C2", - "type": "string" + "channel2": { + "$ref": "#/definitions/CellposeSegmentationChannel", + "description": "TBD Identifier of a second channel in the same format as the first wavelength_id. If specified, cellpose runs in dual channel mode. For dual channel segmentation of cells, the first channel should contain the membrane marker, the second channel should contain the nuclear marker. Identifier of a second channel in the same format as the first wavelength_id. If specified, cellpose runs in dual channel mode. For dual channel segmentation of cells, the first channel should contain the membrane marker, the second channel should contain the nuclear marker." }, "component": { "description": "Path to the OME-Zarr image in the OME-Zarr plate that is processed. Example: \"some_plate.zarr/B/03/0\" (standard argument for Fractal tasks, managed by Fractal server)", @@ -445,16 +460,6 @@ "description": "If ``True``, try to use masked loading and fall back to ``use_masks=False`` if the ROI table is not suitable. Masked loading is relevant when only a subset of the bounding box should actually be processed (e.g. running within organoid_ROI_table).", "title": "Use Masks", "type": "boolean" - }, - "wavelength_id": { - "description": "Identifier of a channel based on the wavelength (e.g. ``A01_C01``). If not ``None``, then ``channel_label` must be ``None``.", - "title": "Wavelength Id", - "type": "string" - }, - "wavelength_id_c2": { - "description": "Identifier of a second channel in the same format as the first wavelength_id. If specified, cellpose runs in dual channel mode. For dual channel segmentation of cells, the first channel should contain the membrane marker, the second channel should contain the nuclear marker.", - "title": "Wavelength Id C2", - "type": "string" } }, "required": [ @@ -462,7 +467,8 @@ "output_path", "component", "metadata", - "level" + "level", + "channel" ], "title": "CellposeSegmentation", "type": "object" diff --git a/fractal_tasks_core/tasks/_input_models.py b/fractal_tasks_core/tasks/_input_models.py new file mode 100644 index 000000000..004ac6fe1 --- /dev/null +++ b/fractal_tasks_core/tasks/_input_models.py @@ -0,0 +1,65 @@ +""" +Copyright 2022 (C) + Friedrich Miescher Institute for Biomedical Research and + University of Zurich + + Original authors: + Tommaso Comparin + + This file is part of Fractal and was originally developed by eXact lab + S.r.l. under contract with Liberali Lab from the Friedrich + Miescher Institute for Biomedical Research and Pelkmans Lab from the + University of Zurich. + +Pydantic models for some task parameters +""" +from typing import Optional + +from pydantic import BaseModel +from pydantic import validator + + +class BaseChannel(BaseModel): + """ + TBD + """ + + wavelength_id: Optional[str] = None + label: Optional[str] = None + + @validator("label", always=True) + def mutually_exclusive_channel_attributes(cls, v, values): + """ + If `label` is set, then `wavelength_id` must be `None` + """ + wavelength_id = values.get("wavelength_id") + label = v + + if wavelength_id is not None and v is not None: + raise ValueError( + "`wavelength_id` and `label` cannot be both set " + f"(given {wavelength_id=} and {label=})." + ) + + if wavelength_id is None and v is None: + raise ValueError( + "`wavelength_id` and `label` cannot be both `None`" + ) + + return v + + +class CellposeSegmentationChannel(BaseChannel): + """ + TBD + """ + + pass + + +class NapariWorkflowsChannel(BaseModel): + """ + TBD + """ + + pass diff --git a/fractal_tasks_core/tasks/cellpose_segmentation.py b/fractal_tasks_core/tasks/cellpose_segmentation.py index 59f9533cb..aa3bb54d2 100644 --- a/fractal_tasks_core/tasks/cellpose_segmentation.py +++ b/fractal_tasks_core/tasks/cellpose_segmentation.py @@ -36,6 +36,7 @@ from pydantic.decorator import validate_arguments import fractal_tasks_core +from fractal_tasks_core.lib_channels import Channel from fractal_tasks_core.lib_channels import ChannelNotFoundError from fractal_tasks_core.lib_channels import get_channel_from_image_zarr from fractal_tasks_core.lib_masked_loading import masked_loading_wrapper @@ -51,6 +52,7 @@ from fractal_tasks_core.lib_ROI_overlaps import get_overlapping_pairs_3D from fractal_tasks_core.lib_zattrs_utils import extract_zyx_pixel_sizes from fractal_tasks_core.lib_zattrs_utils import rescale_datasets +from fractal_tasks_core.tasks._input_models import CellposeSegmentationChannel logger = logging.getLogger(__name__) @@ -151,10 +153,8 @@ def cellpose_segmentation( metadata: Dict[str, Any], # Task-specific arguments level: int, - wavelength_id: Optional[str] = None, - channel_label: Optional[str] = None, - wavelength_id_c2: Optional[str] = None, - channel_label_c2: Optional[str] = None, + channel: CellposeSegmentationChannel, + channel2: Optional[CellposeSegmentationChannel] = None, input_ROI_table: str = "FOV_ROI_table", output_ROI_table: Optional[str] = None, output_label_name: Optional[str] = None, @@ -203,24 +203,23 @@ def cellpose_segmentation( managed by Fractal server) :param level: Pyramid level of the image to be segmented. Choose 0 to process at full resolution. - :param wavelength_id: Identifier of a channel based on the - wavelength (e.g. ``A01_C01``). If not ``None``, then - ``channel_label` must be ``None``. - :param channel_label: Identifier of a channel based on its label (e.g. - ``DAPI``). If not ``None``, then ``wavelength_id`` - must be ``None``. - :param wavelength_id_c2: Identifier of a second channel in the same format - as the first wavelength_id. If specified, cellpose - runs in dual channel mode. For dual channel - segmentation of cells, the first channel should - contain the membrane marker, the second channel - should contain the nuclear marker. - :param channel_label_c2: Identifier of a second channel in the same - format as the first wavelength_id. If specified, - cellpose runs in dual channel mode. For dual - channel segmentation of cells, the first channel - should contain the membrane marker, the second - channel should contain the nuclear marker. + :param channel: TBD + wavelength_id: Identifier of a channel based on the + wavelength (e.g. ``A01_C01``). If not ``None``, then + ``label` must be ``None``. Identifier of a channel + based on its label (e.g. ``DAPI``). If not ``None``, then + ``wavelength_id`` must be ``None``. + :param channel2: TBD + Identifier of a second channel in the same format as the + first wavelength_id. If specified, cellpose runs in dual + channel mode. For dual channel segmentation of cells, the + first channel should contain the membrane marker, the + second channel should contain the nuclear marker. + Identifier of a second channel in the same format as the + first wavelength_id. If specified, cellpose runs in dual + channel mode. For dual channel segmentation of cells, the + first channel should contain the membrane marker, the + second channel should contain the nuclear marker. :param input_ROI_table: Name of the ROI table over which the task loops to apply Cellpose segmentation. Example: "FOV_ROI_table" => loop over the field of @@ -294,15 +293,6 @@ def cellpose_segmentation( zarrurl = (in_path.resolve() / component).as_posix() logger.info(f"{zarrurl=}") - # Preliminary check - if (channel_label is None and wavelength_id is None) or ( - channel_label and wavelength_id - ): - raise ValueError( - f"One and only one of {channel_label=} and " - f"{wavelength_id=} arguments must be provided" - ) - # Preliminary checks on Cellpose model if pretrained_model is None: if model_type not in models.MODEL_NAMES: @@ -319,10 +309,10 @@ def cellpose_segmentation( # Find channel index try: - channel = get_channel_from_image_zarr( + tmp_channel: Channel = get_channel_from_image_zarr( image_zarr_path=zarrurl, - wavelength_id=wavelength_id, - label=channel_label, + wavelength_id=channel.wavelength_id, + label=channel.label, ) except ChannelNotFoundError as e: logger.warning( @@ -330,30 +320,30 @@ def cellpose_segmentation( f"Original error: {str(e)}" ) return {} - ind_channel = channel.index + ind_channel = tmp_channel.index # Find channel index for second channel, if one is provided - if wavelength_id_c2 or channel_label_c2: + if channel2: try: - channel_c2 = get_channel_from_image_zarr( + tmp_channel_c2: Channel = get_channel_from_image_zarr( image_zarr_path=zarrurl, - wavelength_id=wavelength_id_c2, - label=channel_label_c2, + wavelength_id=channel2.wavelength_id, + label=channel2.label, ) except ChannelNotFoundError as e: logger.warning( - f"Second channel with wavelength_id_c2:{wavelength_id_c2} and " - f"channel_label_c2: {channel_label_c2} not found, exit " + f"Second channel with wavelength_id: {channel2.wavelength_id} " + f"and channel_label: {channel2.channel_label} not found, exit " "from the task.\n" f"Original error: {str(e)}" ) return {} - ind_channel_c2 = channel_c2.index + ind_channel_c2 = tmp_channel_c2.index # Set channel label if output_label_name is None: try: - channel_label = channel.label + channel_label = tmp_channel.label output_label_name = f"label_{channel_label}" except (KeyError, IndexError): output_label_name = f"label_{ind_channel}" @@ -361,7 +351,7 @@ def cellpose_segmentation( # Load ZYX data data_zyx = da.from_zarr(f"{zarrurl}/{level}")[ind_channel] logger.info(f"{data_zyx.shape=}") - if wavelength_id_c2 or channel_label_c2: + if channel2: data_zyx_c2 = da.from_zarr(f"{zarrurl}/{level}")[ind_channel_c2] logger.info(f"Second channel: {data_zyx_c2.shape=}") @@ -528,7 +518,7 @@ def cellpose_segmentation( logger.info("Total well shape/chunks:") logger.info(f"{data_zyx.shape}") logger.info(f"{data_zyx.chunks}") - if wavelength_id_c2 or channel_label_c2: + if channel2: logger.info("Dual channel input for cellpose model") logger.info(f"{data_zyx_c2.shape}") logger.info(f"{data_zyx_c2.chunks}") @@ -555,7 +545,7 @@ def cellpose_segmentation( logger.info(f"Now processing ROI {i_ROI+1}/{num_ROIs}") # Prepare single-channel or dual-channel input for cellpose - if wavelength_id_c2 or channel_label_c2: + if channel2: # Dual channel mode, first channel is the membrane channel img_np = np.zeros((2, *data_zyx[region].shape)) img_np[0, :, :, :] = data_zyx[region].compute() diff --git a/tests/test_workflows_cellpose_segmentation.py b/tests/test_workflows_cellpose_segmentation.py index e0dd9cb16..b4ab9a7fb 100644 --- a/tests/test_workflows_cellpose_segmentation.py +++ b/tests/test_workflows_cellpose_segmentation.py @@ -203,14 +203,14 @@ def test_failures( # Attempt 1 cellpose_segmentation( **kwargs, - wavelength_id="invalid_wavelength_id", + channel=dict(wavelength_id="invalid_wavelength_id"), ) assert "ChannelNotFoundError" in caplog.records[0].msg # Attempt 2 cellpose_segmentation( **kwargs, - channel_label="invalid_channel_name", + channel=dict(label="invalid_channel_name"), ) assert "ChannelNotFoundError" in caplog.records[0].msg assert "ChannelNotFoundError" in caplog.records[1].msg @@ -219,8 +219,10 @@ def test_failures( with pytest.raises(ValueError): cellpose_segmentation( **kwargs, - wavelength_id="A01_C01", - channel_label="invalid_channel_name", + channel=dict( + wavelength_id="A01_C01", + label="invalid_channel_name", + ), ) @@ -262,7 +264,7 @@ def test_workflow_with_per_FOV_labeling( output_path=str(zarr_path), metadata=metadata, component=component, - wavelength_id="A01_C01", + channel=dict(wavelength_id="A01_C01"), level=3, relabeling=True, diameter_level0=80.0, @@ -322,8 +324,8 @@ def test_workflow_with_multi_channel_input( output_path=str(zarr_path), metadata=metadata, component=component, - wavelength_id="A01_C01", - wavelength_id_c2="A01_C01", + channel=dict(wavelength_id="A01_C01"), + channel2=dict(wavelength_id="A01_C01"), level=3, relabeling=True, diameter_level0=80.0, @@ -379,7 +381,7 @@ def test_workflow_with_per_FOV_labeling_2D( output_path=str(zarr_path_mip), metadata=metadata, component=component, - wavelength_id="A01_C01", + channel=dict(wavelength_id="A01_C01"), level=2, relabeling=True, diameter_level0=80.0, @@ -471,7 +473,7 @@ def test_workflow_with_per_well_labeling_2D( output_path=str(zarr_path_mip), metadata=metadata, component=component, - wavelength_id="A01_C01", + channel=dict(wavelength_id="A01_C01"), level=2, input_ROI_table="well_ROI_table", relabeling=True, @@ -529,7 +531,7 @@ def test_workflow_bounding_box( output_path=str(zarr_path), metadata=metadata, component=component, - wavelength_id="A01_C01", + channel=dict(wavelength_id="A01_C01"), level=3, relabeling=True, diameter_level0=80.0, @@ -585,7 +587,7 @@ def test_workflow_bounding_box_with_overlap( output_path=str(zarr_path), metadata=metadata, component=component, - wavelength_id="A01_C01", + channel=dict(wavelength_id="A01_C01"), level=3, relabeling=True, diameter_level0=80.0, @@ -627,7 +629,7 @@ def test_workflow_with_per_FOV_labeling_via_script( output_path=str(zarr_path), metadata=metadata, component=metadata["image"][0], - wavelength_id="A01_C01", + channel=dict(wavelength_id="A01_C01"), level=4, relabeling=True, diameter_level0=80.0, @@ -701,7 +703,7 @@ def test_workflow_with_per_FOV_labeling_with_empty_FOV_table( metadata=metadata, component=component, input_ROI_table=TABLE_NAME, - wavelength_id="A01_C01", + channel=dict(wavelength_id="A01_C01"), level=3, relabeling=True, diameter_level0=80.0,