diff --git a/src/careamics/careamist.py b/src/careamics/careamist.py
index 256f0681..7f92e10b 100644
--- a/src/careamics/careamist.py
+++ b/src/careamics/careamist.py
@@ -52,7 +52,7 @@ class CAREamist:
         by default None.
     callbacks : list of Callback, optional
         List of callbacks to use during training and prediction, by default None.
-        Note: ModelCheckpoint configuration should be set through the Configuration 
+        Note: ModelCheckpoint configuration should be set through the Configuration
         object (config.training_config.model_checkpoint) rather than through callbacks.
 
     Attributes
@@ -221,9 +221,9 @@ def __init__(
     def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
         """Define the callbacks for the training loop.
 
-        ModelCheckpoint configuration should be provided through the Configuration 
+        ModelCheckpoint configuration should be provided through the Configuration
         object (config.training_config.model_checkpoint) rather than through callbacks.
-        If no ModelCheckpoint is specified in the configuration, default checkpoint 
+        If no ModelCheckpoint is specified in the configuration, default checkpoint
         settings will be used.
 
         Parameters
@@ -254,12 +254,14 @@ def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
         self.callbacks.extend(
             [
                 HyperParametersCallback(self.cfg),
-                self.cfg.training_config.model_checkpoint
-                if self.cfg.training_config.model_checkpoint is not None
-                else ModelCheckpoint(
-                    dirpath=self.work_dir / Path("checkpoints"),
-                    filename=self.cfg.experiment_name,
-                    **self.cfg.training_config.checkpoint_callback.model_dump(),
+                (
+                    self.cfg.training_config.model_checkpoint
+                    if self.cfg.training_config.model_checkpoint is not None
+                    else ModelCheckpoint(
+                        dirpath=self.work_dir / Path("checkpoints"),
+                        filename=self.cfg.experiment_name,
+                        **self.cfg.training_config.checkpoint_callback.model_dump(),
+                    )
                 ),
                 ProgressBarCallback(),
             ]
diff --git a/src/careamics/config/configuration_factories.py b/src/careamics/config/configuration_factories.py
index 6f62032f..6e4cbe79 100644
--- a/src/careamics/config/configuration_factories.py
+++ b/src/careamics/config/configuration_factories.py
@@ -304,7 +304,7 @@ def _create_configuration(
         num_epochs=num_epochs,
         batch_size=batch_size,
         logger=None if logger == "none" else logger,
-        model_checkpoint=model_checkpoint, 
+        model_checkpoint=model_checkpoint,
     )
 
     # create configuration
@@ -496,7 +496,7 @@ def create_care_configuration(
     dataloader_params : dict, optional
         Parameters for the dataloader, see PyTorch notes, by default None.
     modelcheckpoint : ModelCheckpoint, optional
-        PyTorch Lightning ModelCheckpoint configuration. If not provided, 
+        PyTorch Lightning ModelCheckpoint configuration. If not provided,
         default checkpoint settings will be used.
 
     Returns
diff --git a/src/careamics/config/training_model.py b/src/careamics/config/training_model.py
index e4772613..e6fc8a15 100644
--- a/src/careamics/config/training_model.py
+++ b/src/careamics/config/training_model.py
@@ -6,9 +6,10 @@
 from typing import Literal, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pytorch_lightning.callbacks import ModelCheckpoint
 
 from .callback_model import CheckpointModel, EarlyStoppingModel
-from pytorch_lightning.callbacks import ModelCheckpoint
+
 
 class TrainingConfig(BaseModel):
     """
@@ -28,7 +29,7 @@ class TrainingConfig(BaseModel):
     # Pydantic class configuration
     model_config = ConfigDict(
         validate_assignment=True,
-        arbitrary_types_allowed=True, #test - Diya 27.1.25
+        arbitrary_types_allowed=True,  # test - Diya 27.1.25
     )
 
     num_epochs: int = Field(default=20, ge=1)
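Not part of the diff: a minimal usage sketch of the model_checkpoint parameter that this change introduces. It assumes careamics is installed and that create_care_configuration is importable from careamics.config as in the CAREamics documentation; the experiment name, data type, axes, patch size, batch size, and epoch count below are purely illustrative.

# Usage sketch (assumptions flagged above; argument values are illustrative).
from pytorch_lightning.callbacks import ModelCheckpoint

from careamics import CAREamist
from careamics.config import create_care_configuration

# Custom checkpointing: keep the three best epochs by validation loss.
checkpoint = ModelCheckpoint(
    monitor="val_loss",
    save_top_k=3,
    filename="care-{epoch:02d}",
)

config = create_care_configuration(
    experiment_name="care_example",
    data_type="tiff",
    axes="YX",
    patch_size=[64, 64],
    batch_size=8,
    num_epochs=20,
    model_checkpoint=checkpoint,  # parameter added by this diff
)

# _define_callbacks uses config.training_config.model_checkpoint when it is
# set; when it is None, a default ModelCheckpoint is built from
# config.training_config.checkpoint_callback instead (see the careamist.py
# hunk above).
engine = CAREamist(source=config)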
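Also not part of the diff: a short sketch of why TrainingConfig sets arbitrary_types_allowed=True. ModelCheckpoint is a plain Lightning class rather than a pydantic model, and pydantic v2 rejects such field annotations unless arbitrary types are allowed. MiniTrainingConfig below is a hypothetical stand-in, not the real class.

# Sketch: pydantic v2 rejects non-pydantic field types by default.
# MiniTrainingConfig is a hypothetical stand-in for TrainingConfig.
from typing import Optional

from pydantic import BaseModel, ConfigDict
from pytorch_lightning.callbacks import ModelCheckpoint


class MiniTrainingConfig(BaseModel):
    # Without arbitrary_types_allowed=True, the ModelCheckpoint annotation
    # raises a PydanticSchemaGenerationError at class-definition time.
    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

    model_checkpoint: Optional[ModelCheckpoint] = None


cfg = MiniTrainingConfig(model_checkpoint=ModelCheckpoint(monitor="val_loss"))
assert cfg.model_checkpoint is not None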