feat: check that user callbacks are not those used by CAREamics
jdeschamps committed Jun 19, 2024
1 parent d200884 commit 2531b68
Showing 2 changed files with 56 additions and 5 deletions.
17 changes: 17 additions & 0 deletions src/careamics/careamist.py
@@ -211,6 +211,23 @@ def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
"""
self.callbacks = [] if callbacks is None else callbacks

# check that user callbacks are not any of the CAREamics callbacks
for c in self.callbacks:
if isinstance(c, ModelCheckpoint) or isinstance(c, EarlyStopping):
raise ValueError(
"ModelCheckpoint and EarlyStopping callbacks are already defined "
"in CAREamics and should only be modified through the "
"training configuration (see TrainingConfig)."
)

if isinstance(c, HyperParametersCallback) or isinstance(
c, ProgressBarCallback
):
raise ValueError(
"HyperParameter and ProgressBar callbacks are defined internally "
"and should not be passed as callbacks."
)

# checkpoint callback saves checkpoints during training
self.callbacks.extend(
[
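With this change, passing one of the reserved callbacks to `CAREamist` fails at construction time. A minimal sketch of the new behavior, assuming `config` is a valid `Configuration` instance (the tests below exercise the same path):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

from careamics import CAREamist

# ModelCheckpoint is already registered internally by CAREamics, so
# passing our own instance raises a ValueError that points the user
# to the training configuration instead.
try:
    CAREamist(source=config, callbacks=[ModelCheckpoint()])
except ValueError as err:
    print(err)
```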
44 changes: 39 additions & 5 deletions tests/test_careamist.py
Expand Up @@ -4,9 +4,11 @@
import numpy as np
import pytest
import tifffile
from pytorch_lightning.callbacks import Callback
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint

from careamics import CAREamist, Configuration, save_configuration
from careamics.callbacks import HyperParametersCallback, ProgressBarCallback
from careamics.config.support import SupportedAlgorithm, SupportedData


@@ -740,7 +742,7 @@ def test_add_custom_callback(tmp_path, minimum_configuration):
"""Test that custom callback can be added to the CAREamist."""

# define a custom callback
class MyPrintingCallback(Callback):
class MyCallback(Callback):
def __init__(self):
super().__init__()

@@ -753,13 +755,12 @@ def on_train_start(self, trainer, pl_module):
         def on_train_end(self, trainer, pl_module):
             self.has_ended = True
 
-    my_callback = MyPrintingCallback()
+    my_callback = MyCallback()
     assert not my_callback.has_started
     assert not my_callback.has_ended
 
     # training data
     train_array = random_array((32, 32))
-    val_array = random_array((32, 32))
 
     # create configuration
     config = Configuration(**minimum_configuration)
@@ -775,8 +776,42 @@ def on_train_end(self, trainer, pl_module):
     assert not my_callback.has_ended
 
     # train CAREamist
-    careamist.train(train_source=train_array, val_source=val_array)
+    careamist.train(train_source=train_array)
 
     # check the state of the callback
     assert my_callback.has_started
     assert my_callback.has_ended
+
+
+def test_error_passing_careamics_callback(tmp_path, minimum_configuration):
+    """Test that an error is thrown if we pass known callbacks to CAREamist."""
+    # create configuration
+    config = Configuration(**minimum_configuration)
+    config.training_config.num_epochs = 1
+    config.data_config.axes = "YX"
+    config.data_config.batch_size = 2
+    config.data_config.data_type = SupportedData.ARRAY.value
+    config.data_config.patch_size = (8, 8)
+
+    # Lightning callbacks
+    model_ckp = ModelCheckpoint()
+
+    with pytest.raises(ValueError):
+        CAREamist(source=config, work_dir=tmp_path, callbacks=[model_ckp])
+
+    # any monitor name works here; CAREamist raises before training starts
+    early_stp = EarlyStopping(monitor="val_loss")
+
+    with pytest.raises(ValueError):
+        CAREamist(source=config, work_dir=tmp_path, callbacks=[early_stp])
+
+    # CAREamics callbacks
+    progress_bar = ProgressBarCallback()
+
+    with pytest.raises(ValueError):
+        CAREamist(source=config, work_dir=tmp_path, callbacks=[progress_bar])
+
+    hyper_params = HyperParametersCallback(config=config)
+
+    with pytest.raises(ValueError):
+        CAREamist(source=config, work_dir=tmp_path, callbacks=[hyper_params])

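For contrast, a genuinely custom callback still passes the new check, as `test_add_custom_callback` above verifies. A minimal sketch with a hypothetical `LogEpochEnd` callback, again assuming a valid `config`:

```python
from pytorch_lightning.callbacks import Callback

from careamics import CAREamist

class LogEpochEnd(Callback):
    """Hypothetical user callback; none of the reserved CAREamics types."""

    def on_train_epoch_end(self, trainer, pl_module):
        print(f"finished epoch {trainer.current_epoch}")

# accepted: the guard in _define_callbacks only rejects ModelCheckpoint,
# EarlyStopping, HyperParametersCallback and ProgressBarCallback
careamist = CAREamist(source=config, callbacks=[LogEpochEnd()])
```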