Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Add predict command to CLI #281

Merged
merged 10 commits into from
Dec 2, 2024
41 changes: 26 additions & 15 deletions src/careamics/careamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def __init__(

# path to configuration file or model
else:
# TODO: update this check so models can be downloaded directly from BMZ
source = check_path_exists(source)

# configuration file
Expand Down Expand Up @@ -734,7 +735,7 @@ def predict_to_disk(

Parameters
----------
source : PredictDataModule, pathlib.Path or str
source : PredictDataModule or pathlib.Path, str
Data to predict on.
batch_size : int, default=1
Batch size for prediction.
Expand Down Expand Up @@ -804,27 +805,36 @@ def predict_to_disk(
write_extension = SupportedData.get_extension(write_type)

# extract file names
source_path: Union[Path, str, NDArray]
source_data_type: Literal["array", "tiff", "custom"]
if isinstance(source, PredictDataModule):
# assert not isinstance(source.pred_data, )
source_file_paths = list_files(
source.pred_data, source.data_type, source.extension_filter
)
source_path = source.pred_data
source_data_type = source.data_type
extension_filter = source.extension_filter
elif isinstance(source, (str, Path)):
assert self.cfg.data_config.data_type != "array"
data_type = data_type or self.cfg.data_config.data_type
source_path = source
source_data_type = data_type or self.cfg.data_config.data_type
extension_filter = SupportedData.get_extension_pattern(
SupportedData(data_type)
SupportedData(source_data_type)
)
source_file_paths = list_files(source, data_type, extension_filter)
else:
raise ValueError(f"Unsupported source type: '{type(source)}'.")

if source_data_type == "array":
raise ValueError(
"Predicting to disk is not supported for input type 'array'."
)
assert isinstance(source_path, (Path, str)) # because data_type != "array"
source_path = Path(source_path)

file_paths = list_files(source_path, source_data_type, extension_filter)

# predict and write each file in turn
for source_path in source_file_paths:
for file_path in file_paths:
# source_path is relative to original source path...
# should mirror original directory structure
prediction = self.predict(
source=source_path,
source=file_path,
batch_size=batch_size,
tile_size=tile_size,
tile_overlap=tile_overlap,
Expand All @@ -840,11 +850,12 @@ def predict_to_disk(
write_data = np.concatenate(prediction)

# create directory structure and write path
file_write_dir = write_dir / source_path.parent.name
if not source_path.is_file():
file_write_dir = write_dir / file_path.parent.relative_to(source_path)
else:
file_write_dir = write_dir
file_write_dir.mkdir(parents=True, exist_ok=True)
write_path = (file_write_dir / source_path.name).with_suffix(
write_extension
)
write_path = (file_write_dir / file_path.name).with_suffix(write_extension)

# write data
write_func(file_path=write_path, img=write_data)
Expand Down
29 changes: 5 additions & 24 deletions src/careamics/cli/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional

import click
import typer
Expand All @@ -17,6 +17,7 @@
create_n2v_configuration,
save_configuration,
)
from .utils import handle_2D_3D_callback

WORK_DIR = Path.cwd()

Expand Down Expand Up @@ -92,26 +93,6 @@ def conf_options( # numpydoc ignore=PR01
ctx.obj = ConfOptions(dir, name, force, print)


def patch_size_callback(value: Tuple[int, int, int]) -> Tuple[int, ...]:
"""
Callback for --patch-size option.

Parameters
----------
value : (int, int, int)
Patch size value.

Returns
-------
(int, int, int) | (int, int)
If the last element in `value` is -1 the tuple is reduced to the first two
values.
"""
if value[2] == -1:
return value[:2]
return value


# TODO: Need to decide how to parse model kwargs
# - Could be json style string to be loaded as dict e.g. {"depth": 3}
# - Cons: Annoying to type, easily have syntax errors
Expand All @@ -132,7 +113,7 @@ def care( # numpydoc ignore=PR01
"is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
),
click_type=click.Tuple([int, int, int]),
callback=patch_size_callback,
callback=handle_2D_3D_callback,
),
],
batch_size: Annotated[int, typer.Option(help="Batch size.")],
Expand Down Expand Up @@ -219,7 +200,7 @@ def n2n( # numpydoc ignore=PR01
"is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
),
click_type=click.Tuple([int, int, int]),
callback=patch_size_callback,
callback=handle_2D_3D_callback,
),
],
batch_size: Annotated[int, typer.Option(help="Batch size.")],
Expand Down Expand Up @@ -303,7 +284,7 @@ def n2v( # numpydoc ignore=PR01
"is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
),
click_type=click.Tuple([int, int, int]),
callback=patch_size_callback,
callback=handle_2D_3D_callback,
),
],
batch_size: Annotated[int, typer.Option(help="Batch size.")],
Expand Down
121 changes: 111 additions & 10 deletions src/careamics/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,20 @@
from pathlib import Path
from typing import Optional

import click
import typer
from typing_extensions import Annotated

from ..careamist import CAREamist
from . import conf
from .utils import handle_2D_3D_callback

app = typer.Typer(
help="Run CAREamics algorithms from the command line, including Noise2Void "
"and its many variants and cousins"
)
app.add_typer(
conf.app,
name="conf",
# callback=conf.conf_options
"and its many variants and cousins",
pretty_exceptions_show_locals=False,
)
app.add_typer(conf.app, name="conf")


@app.command()
Expand Down Expand Up @@ -102,7 +101,7 @@ def train( # numpydoc ignore=PR01
typer.Option(
"--work-dir",
"-wd",
help=("Path to working directory in which to save checkpoints and " "logs"),
help=("Path to working directory in which to save checkpoints and logs"),
exists=True,
file_okay=False,
dir_okay=True,
Expand All @@ -123,10 +122,112 @@ def train( # numpydoc ignore=PR01


@app.command()
def predict(): # numpydoc ignore=PR01
def predict( # numpydoc ignore=PR01
model: Annotated[
Path,
typer.Argument(
help="Path to a configuration file or a trained model.",
exists=True,
file_okay=True,
dir_okay=False,
),
],
source: Annotated[
Path,
typer.Argument(
help="Path to the training data. Can be a directory or single file.",
exists=True,
file_okay=True,
dir_okay=True,
),
],
batch_size: Annotated[int, typer.Option(help="Batch size.")] = 1,
tile_size: Annotated[
Optional[click.Tuple],
typer.Option(
help=(
"Size of the tiles to use for prediction, (if the data "
"is not 3D pass the last value as -1 e.g. --tile_size 64 64 -1)."
),
click_type=click.Tuple([int, int, int]),
callback=handle_2D_3D_callback,
),
] = None,
tile_overlap: Annotated[
click.Tuple,
typer.Option(
help=(
"Overlap between tiles, (if the data is not 3D pass the last value as "
"-1 e.g. --tile_overlap 64 64 -1)."
),
click_type=click.Tuple([int, int, int]),
callback=handle_2D_3D_callback,
),
] = (48, 48, -1),
axes: Annotated[
Optional[str],
typer.Option(
help="Axes of the input data. If unused the data is assumed to have the "
"same axes as the original training data."
),
] = None,
data_type: Annotated[
click.Choice,
typer.Option(click_type=click.Choice(["tiff"]), help="Type of the input data."),
] = "tiff",
tta_transforms: Annotated[
bool,
typer.Option(
"--tta-transforms/--no-tta-transforms",
"-t/-T",
help="Whether to apply test-time augmentation.",
),
] = False,
write_type: Annotated[
click.Choice,
typer.Option(
click_type=click.Choice(["tiff"]), help="Type of the output data."
),
] = "tiff",
# TODO: could make dataloader_params as json, necessary?
work_dir: Annotated[
Optional[Path],
typer.Option(
"--work-dir",
"-wd",
help=("Path to working directory."),
exists=True,
file_okay=False,
dir_okay=True,
),
] = None,
prediction_dir: Annotated[
Path,
typer.Option(
"--prediction-dir",
"-pd",
help=(
"Directory to save predictions to. If not an abosulte path it will be "
"relative to the set working directory."
),
file_okay=False,
dir_okay=True,
),
] = Path("predictions"),
):
"""Create and save predictions from CAREamics models."""
# TODO: Need a save predict to workdir function
raise NotImplementedError
engine = CAREamist(source=model, work_dir=work_dir)
engine.predict_to_disk(
source=source,
batch_size=batch_size,
tile_size=tile_size,
tile_overlap=tile_overlap,
axes=axes,
data_type=data_type,
tta_transforms=tta_transforms,
write_type=write_type,
prediction_dir=prediction_dir,
)


def run():
Expand Down
29 changes: 29 additions & 0 deletions src/careamics/cli/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Utility functions for the CAREamics CLI."""

from typing import Optional, Tuple


def handle_2D_3D_callback(
value: Optional[Tuple[int, int, int]]
) -> Optional[Tuple[int, ...]]:
"""
Callback for options that require 2D or 3D inputs.

In the case of 2D, the 3rd element should be set to -1.

Parameters
----------
value : (int, int, int)
Tile size value.

Returns
-------
(int, int, int) | (int, int)
If the last element in `value` is -1 the tuple is reduced to the first two
values.
"""
if value is None:
return value
if value[2] == -1:
return value[:2]
return value
Loading