Commit bc39633

style(pre-commit.ci): auto fixes [...]
pre-commit-ci[bot] committed Jan 29, 2025
1 parent ece522d commit bc39633
Showing 3 changed files with 16 additions and 13 deletions. (In the diffs below, ␣ marks a trailing space removed by the hooks; the visible text of those lines is otherwise unchanged.)
src/careamics/careamist.py: 11 additions & 9 deletions
@@ -52,7 +52,7 @@ class CAREamist:
         by default None.
     callbacks : list of Callback, optional
         List of callbacks to use during training and prediction, by default None.
-        Note: ModelCheckpoint configuration should be set through the Configuration␣
+        Note: ModelCheckpoint configuration should be set through the Configuration
         object (config.training_config.model_checkpoint) rather than through callbacks.

     Attributes
@@ -221,9 +221,9 @@ def __init__(
     def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
         """Define the callbacks for the training loop.

-        ModelCheckpoint configuration should be provided through the Configuration␣
+        ModelCheckpoint configuration should be provided through the Configuration
         object (config.training_config.model_checkpoint) rather than through callbacks.
-        If no ModelCheckpoint is specified in the configuration, default checkpoint␣
+        If no ModelCheckpoint is specified in the configuration, default checkpoint
         settings will be used.

         Parameters
@@ -254,12 +254,14 @@ def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
         self.callbacks.extend(
             [
                 HyperParametersCallback(self.cfg),
-                self.cfg.training_config.model_checkpoint
-                if self.cfg.training_config.model_checkpoint is not None
-                else ModelCheckpoint(
-                    dirpath=self.work_dir / Path("checkpoints"),
-                    filename=self.cfg.experiment_name,
-                    **self.cfg.training_config.checkpoint_callback.model_dump(),
+                (
+                    self.cfg.training_config.model_checkpoint
+                    if self.cfg.training_config.model_checkpoint is not None
+                    else ModelCheckpoint(
+                        dirpath=self.work_dir / Path("checkpoints"),
+                        filename=self.cfg.experiment_name,
+                        **self.cfg.training_config.checkpoint_callback.model_dump(),
+                    )
                 ),
                 ProgressBarCallback(),
             ]
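The last hunk is black's required parenthesization of a multi-line conditional expression; behavior is unchanged. A minimal, self-contained sketch of the fallback logic it wraps, with plain variables standing in for the config fields (not CAREamics API):

    from pathlib import Path

    from pytorch_lightning.callbacks import ModelCheckpoint

    # Stand-ins for self.cfg.training_config.model_checkpoint and friends.
    user_checkpoint = None   # a user-supplied ModelCheckpoint, or None
    work_dir = Path(".")
    experiment_name = "demo"

    # Use the user's checkpoint callback if one was configured,
    # otherwise fall back to a default ModelCheckpoint.
    callback = (
        user_checkpoint
        if user_checkpoint is not None
        else ModelCheckpoint(
            dirpath=work_dir / "checkpoints",
            filename=experiment_name,
        )
    )
    print(type(callback).__name__)  # ModelCheckpoint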
src/careamics/config/configuration_factories.py: 2 additions & 2 deletions
@@ -304,7 +304,7 @@ def _create_configuration(
         num_epochs=num_epochs,
         batch_size=batch_size,
         logger=None if logger == "none" else logger,
-        model_checkpoint=model_checkpoint,␣
+        model_checkpoint=model_checkpoint,
     )

     # create configuration
@@ -496,7 +496,7 @@ def create_care_configuration(
     dataloader_params : dict, optional
         Parameters for the dataloader, see PyTorch notes, by default None.
     modelcheckpoint : ModelCheckpoint, optional
-        PyTorch Lightning ModelCheckpoint configuration. If not provided,␣
+        PyTorch Lightning ModelCheckpoint configuration. If not provided,
        default checkpoint settings will be used.

     Returns
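Per the docstring above, the factory accepts an optional ModelCheckpoint and falls back to default checkpoint settings when it is omitted. A hedged usage sketch; only model_checkpoint is confirmed by this diff, and the other keyword arguments are assumptions based on the CARE factory's usual signature:

    from pytorch_lightning.callbacks import ModelCheckpoint

    from careamics.config import create_care_configuration

    # Arguments other than model_checkpoint are assumed, not verified
    # against this commit.
    config = create_care_configuration(
        experiment_name="care_demo",
        data_type="array",
        axes="YX",
        patch_size=[64, 64],
        batch_size=8,
        num_epochs=20,
        model_checkpoint=ModelCheckpoint(save_top_k=3, monitor="val_loss"),
    )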
src/careamics/config/training_model.py: 3 additions & 2 deletions
@@ -6,9 +6,10 @@
 from typing import Literal, Optional, Union

 from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pytorch_lightning.callbacks import ModelCheckpoint

 from .callback_model import CheckpointModel, EarlyStoppingModel
-from pytorch_lightning.callbacks import ModelCheckpoint


 class TrainingConfig(BaseModel):
     """
@@ -28,7 +29,7 @@ class TrainingConfig(BaseModel):
     # Pydantic class configuration
     model_config = ConfigDict(
         validate_assignment=True,
-        arbitrary_types_allowed=True, #test - Diya 27.1.25
+        arbitrary_types_allowed=True,  # test - Diya 27.1.25
     )

     num_epochs: int = Field(default=20, ge=1)
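The import move follows standard isort grouping (third-party imports before local ones). The arbitrary_types_allowed=True flag is what lets a pydantic model hold a raw ModelCheckpoint instance in the first place; a minimal sketch of that mechanism, with an illustrative class name rather than CAREamics code:

    from typing import Optional

    from pydantic import BaseModel, ConfigDict
    from pytorch_lightning.callbacks import ModelCheckpoint

    class MiniTrainingConfig(BaseModel):
        # Without arbitrary_types_allowed=True, pydantic raises
        # PydanticSchemaGenerationError for ModelCheckpoint, since it cannot
        # build a schema for an arbitrary class; with the flag it falls back
        # to a plain isinstance() check. protected_namespaces=() silences
        # the warning about field names starting with "model_".
        model_config = ConfigDict(
            arbitrary_types_allowed=True,
            protected_namespaces=(),
        )

        model_checkpoint: Optional[ModelCheckpoint] = None

    cfg = MiniTrainingConfig(model_checkpoint=ModelCheckpoint(save_top_k=1))
    print(cfg.model_checkpoint is not None)  # True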
