Skip to content

Commit

Permalink
(feat): API parameter for custom callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps committed Jun 18, 2024
1 parent 07fb84e commit 705ede7
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 38 deletions.
82 changes: 44 additions & 38 deletions src/careamics/careamist.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""A class to train, predict and export models in CAREamics."""

from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, overload
from typing import Any, Callable, Literal, Optional, Union, overload

import numpy as np
from numpy.typing import NDArray
Expand Down Expand Up @@ -35,20 +35,20 @@
LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]


# TODO napari callbacks
# TODO: how to do AMP? How to continue training?
class CAREamist:
"""Main CAREamics class, allowing training and prediction using various algorithms.
Parameters
----------
source : Union[Path, str, Configuration]
source : pathlib.Path or str or CAREamics Configuration
Path to a configuration file or a trained model.
work_dir : Optional[str], optional
work_dir : str, optional
Path to working directory in which to save checkpoints and logs,
by default None.
experiment_name : str, optional
Experiment name used for checkpoints, by default "CAREamics".
experiment_name : str, by default "CAREamics"
Experiment name used for checkpoints.
callbacks : list of Callback, optional
List of callbacks to use during training and prediction, by default None.
Attributes
----------
Expand All @@ -58,8 +58,7 @@ class CAREamist:
CAREamics configuration.
trainer : Trainer
PyTorch Lightning trainer.
experiment_logger : pytorch_lightning.loggersTensorBoardLogger or
pytorch_lightning.loggersWandbLogger
experiment_logger : TensorBoardLogger or WandbLogger
Experiment logger, "wandb" or "tensorboard".
work_dir : pathlib.Path
Working directory.
Expand All @@ -75,6 +74,7 @@ def __init__( # numpydoc ignore=GL08
source: Union[Path, str],
work_dir: Optional[str] = None,
experiment_name: str = "CAREamics",
callbacks: Optional[list[Callback]] = None,
) -> None: ...

@overload
Expand All @@ -83,13 +83,15 @@ def __init__( # numpydoc ignore=GL08
source: Configuration,
work_dir: Optional[str] = None,
experiment_name: str = "CAREamics",
callbacks: Optional[list[Callback]] = None,
) -> None: ...

def __init__(
self,
source: Union[Path, str, Configuration],
work_dir: Optional[Union[Path, str]] = None,
experiment_name: str = "CAREamics",
callbacks: Optional[list[Callback]] = None,
) -> None:
"""
Initialize CAREamist with a configuration object or a path.
Expand All @@ -113,6 +115,8 @@ def __init__(
by default None.
experiment_name : str, optional
Experiment name used for checkpoints, by default "CAREamics".
callbacks : list of Callback, optional
List of callbacks to use during training and prediction, by default None.
Raises
------
Expand Down Expand Up @@ -165,7 +169,7 @@ def __init__(
self.model, self.cfg = load_pretrained(source)

# define the checkpoint saving callback
self.callbacks = self._define_callbacks()
self._define_callbacks(callbacks)

# instantiate logger
if self.cfg.training_config.has_logger():
Expand Down Expand Up @@ -196,34 +200,36 @@ def __init__(
self.train_datamodule: Optional[CAREamicsTrainData] = None
self.pred_datamodule: Optional[CAREamicsPredictData] = None

def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
    """
    Define the callbacks for the training loop.

    Sets ``self.callbacks`` to the user-provided callbacks (if any) followed
    by the default CAREamics callbacks: hyper-parameter logging, model
    checkpointing and a progress bar, plus early stopping when it is enabled
    in the training configuration.

    Parameters
    ----------
    callbacks : list of Callback, optional
        List of callbacks to use during training and prediction, by default
        None.
    """
    # copy the user's list so extending it below does not mutate the
    # caller's argument in place
    self.callbacks = [] if callbacks is None else list(callbacks)

    # checkpoint callback saves checkpoints during training
    self.callbacks.extend(
        [
            HyperParametersCallback(self.cfg),
            ModelCheckpoint(
                dirpath=self.work_dir / Path("checkpoints"),
                filename=self.cfg.experiment_name,
                **self.cfg.training_config.checkpoint_callback.model_dump(),
            ),
            ProgressBarCallback(),
        ]
    )

    # early stopping callback
    if self.cfg.training_config.early_stopping_callback is not None:
        self.callbacks.append(
            EarlyStopping(self.cfg.training_config.early_stopping_callback)
        )

def train(
self,
*,
Expand Down Expand Up @@ -486,12 +492,12 @@ def predict( # numpydoc ignore=GL08
source: Union[Path, str],
*,
batch_size: int = 1,
tile_size: Optional[Tuple[int, ...]] = None,
tile_overlap: Tuple[int, ...] = (48, 48),
tile_size: Optional[tuple[int, ...]] = None,
tile_overlap: tuple[int, ...] = (48, 48),
axes: Optional[str] = None,
data_type: Optional[Literal["tiff", "custom"]] = None,
tta_transforms: bool = True,
dataloader_params: Optional[Dict] = None,
dataloader_params: Optional[dict] = None,
read_source_func: Optional[Callable] = None,
extension_filter: str = "",
checkpoint: Optional[Literal["best", "last"]] = None,
Expand All @@ -503,12 +509,12 @@ def predict( # numpydoc ignore=GL08
source: NDArray,
*,
batch_size: int = 1,
tile_size: Optional[Tuple[int, ...]] = None,
tile_overlap: Tuple[int, ...] = (48, 48),
tile_size: Optional[tuple[int, ...]] = None,
tile_overlap: tuple[int, ...] = (48, 48),
axes: Optional[str] = None,
data_type: Optional[Literal["array"]] = None,
tta_transforms: bool = True,
dataloader_params: Optional[Dict] = None,
dataloader_params: Optional[dict] = None,
checkpoint: Optional[Literal["best", "last"]] = None,
) -> Union[list[NDArray], NDArray]: ...

Expand All @@ -517,17 +523,17 @@ def predict(
source: Union[CAREamicsPredictData, Path, str, NDArray],
*,
batch_size: Optional[int] = None,
tile_size: Optional[Tuple[int, ...]] = None,
tile_overlap: Tuple[int, ...] = (48, 48),
tile_size: Optional[tuple[int, ...]] = None,
tile_overlap: tuple[int, ...] = (48, 48),
axes: Optional[str] = None,
data_type: Optional[Literal["array", "tiff", "custom"]] = None,
tta_transforms: bool = True,
dataloader_params: Optional[Dict] = None,
dataloader_params: Optional[dict] = None,
read_source_func: Optional[Callable] = None,
extension_filter: str = "",
checkpoint: Optional[Literal["best", "last"]] = None,
**kwargs: Any,
) -> Union[List[NDArray], NDArray]:
) -> Union[list[NDArray], NDArray]:
"""
Make predictions on the provided data.
Expand Down Expand Up @@ -669,9 +675,9 @@ def export_to_bmz(
path: Union[Path, str],
name: str,
input_array: NDArray,
authors: List[dict],
authors: list[dict],
general_description: str = "",
channel_names: Optional[List[str]] = None,
channel_names: Optional[list[str]] = None,
data_description: Optional[str] = None,
) -> None:
"""Export the model to the BioImage Model Zoo format.
Expand Down
47 changes: 47 additions & 0 deletions tests/test_careamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pytest
import tifffile
from pytorch_lightning.callbacks import Callback

from careamics import CAREamist, Configuration, save_configuration
from careamics.config.support import SupportedAlgorithm, SupportedData
Expand Down Expand Up @@ -733,3 +734,49 @@ def test_export_bmz_pretrained_with_array(tmp_path: Path, pre_trained: Path):
general_description="A model that just walked in.",
)
assert (tmp_path / "model2.zip").exists()


def test_add_custom_callback(tmp_path, minimum_configuration):
    """Test that custom callback can be added to the CAREamist."""

    # minimal callback that records whether training actually ran
    class TrainingRecorder(Callback):
        """Flags the start and end of a training run."""

        def __init__(self):
            super().__init__()
            self.has_started = False
            self.has_ended = False

        def on_train_start(self, trainer, pl_module):
            self.has_started = True

        def on_train_end(self, trainer, pl_module):
            self.has_ended = True

    recorder = TrainingRecorder()
    assert not recorder.has_started
    assert not recorder.has_ended

    # training data
    train_array, val_array = random_array((32, 32)), random_array((32, 32))

    # create configuration
    config = Configuration(**minimum_configuration)
    config.data_config.data_type = SupportedData.ARRAY.value
    config.data_config.axes = "YX"
    config.data_config.patch_size = (8, 8)
    config.data_config.batch_size = 2
    config.training_config.num_epochs = 1

    # instantiate CAREamist with the custom callback
    careamist = CAREamist(source=config, work_dir=tmp_path, callbacks=[recorder])
    assert not recorder.has_started
    assert not recorder.has_ended

    # train CAREamist
    careamist.train(train_source=train_array, val_source=val_array)

    # check the state of the callback
    assert recorder.has_started
    assert recorder.has_ended

0 comments on commit 705ede7

Please sign in to comment.