Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Predict and interpret adaptation to data class #586

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions clinicadl/config/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from .data import DataConfig
from .dataloader import DataLoaderConfig
from .early_stopping import EarlyStoppingConfig
from .interpret import InterpretConfig
from .lr_scheduler import LRschedulerConfig
from .maps_manager import MapsManagerConfig
from .modality import (
Expand All @@ -17,7 +16,6 @@
from .model import ModelConfig
from .optimization import OptimizationConfig
from .optimizer import OptimizerConfig
from .predict import PredictConfig
from .preprocessing import (
PreprocessingConfig,
PreprocessingImageConfig,
Expand Down
16 changes: 15 additions & 1 deletion clinicadl/config/config/computational.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from logging import getLogger

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, model_validator
from typing_extensions import Self

from clinicadl.utils.cmdline_utils import check_gpu
from clinicadl.utils.exceptions import ClinicaDLArgumentError

logger = getLogger("clinicadl.computational_config")

Expand All @@ -13,3 +17,13 @@ class ComputationalConfig(BaseModel):
gpu: bool = True
# pydantic config
model_config = ConfigDict(validate_assignment=True)

@model_validator(mode="after")
def validator_gpu(self) -> Self:
if self.gpu:
check_gpu()
elif self.amp:
raise ClinicaDLArgumentError(
"AMP is designed to work with modern GPUs. Please add the --gpu flag."
)
return self
10 changes: 9 additions & 1 deletion clinicadl/config/config/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic.types import NonNegativeInt

from clinicadl.utils.maps_manager.maps_manager import MapsManager

logger = getLogger("clinicadl.cross_validation_config")


Expand All @@ -19,7 +21,7 @@ class CrossValidationConfig(

n_splits: NonNegativeInt = 0
split: Optional[Tuple[NonNegativeInt, ...]] = None
tsv_directory: Path
tsv_directory: Optional[Path] = None # not needed in predict ?
# pydantic config
model_config = ConfigDict(validate_assignment=True)

Expand All @@ -28,3 +30,9 @@ def validator_split(cls, v):
if isinstance(v, list):
return tuple(v)
return v # TODO : check that split exists (and check coherence with n_splits)

def adapt_cross_val_with_maps_manager_info(self, maps_manager: MapsManager):
# TEMPORARY
if not self.split:
self.split = maps_manager._find_splits()
logger.debug(f"List of splits {self.split}")
8 changes: 7 additions & 1 deletion clinicadl/config/config/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from clinicadl.utils.caps_dataset.data import load_data_test
from clinicadl.utils.enum import Mode
from clinicadl.utils.maps_manager.maps_manager import MapsManager
from clinicadl.utils.preprocessing import read_preprocessing

logger = getLogger("clinicadl.data_config")
Expand All @@ -24,12 +25,17 @@ class DataConfig(BaseModel): # TODO : put in data module
label: Optional[str] = None
label_code: Dict[str, int] = {}
multi_cohort: bool = False
preprocessing_json: Path
preprocessing_json: Optional[Path] = None
data_tsv: Optional[Path] = None
n_subjects: int = 300
# pydantic config
model_config = ConfigDict(validate_assignment=True)

def adapt_data_with_maps_manager_info(self, maps_manager: MapsManager):
# TEMPORARY
if self.diagnoses is None or len(self.diagnoses) == 0:
self.diagnoses = maps_manager.diagnoses

def create_groupe_df(self):
group_df = None
if self.data_tsv is not None and self.data_tsv.is_file():
Expand Down
9 changes: 9 additions & 0 deletions clinicadl/config/config/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pydantic.types import PositiveInt

from clinicadl.utils.enum import Sampler
from clinicadl.utils.maps_manager.maps_manager import MapsManager

logger = getLogger("clinicadl.dataloader_config")

Expand All @@ -16,3 +17,11 @@ class DataLoaderConfig(BaseModel): # TODO : put in data/splitter module
sampler: Sampler = Sampler.RANDOM
# pydantic config
model_config = ConfigDict(validate_assignment=True)

def adapt_dataloader_with_maps_manager_info(self, maps_manager: MapsManager):
# TEMPORARY
if not self.batch_size:
self.batch_size = maps_manager.batch_size

if not self.n_proc:
self.n_proc = maps_manager.n_proc
14 changes: 0 additions & 14 deletions clinicadl/config/config/maps_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,3 @@ def check_output_saving_nifti(self, network_task: str) -> None:
raise ClinicaDLArgumentError(
"Cannot save nifti if the network task is not reconstruction. Please remove --save_nifti option."
)

def adapt_config_with_maps_manager_info(self, maps_manager: MapsManager):
if not self.split_list:
self.split_list = maps_manager._find_splits()
logger.debug(f"List of splits {self.split_list}")

if self.diagnoses is None or len(self.diagnoses) == 0:
self.diagnoses = maps_manager.diagnoses

if not self.batch_size:
self.batch_size = maps_manager.batch_size

if not self.n_proc:
self.n_proc = maps_manager.n_proc
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,28 @@

from pydantic import BaseModel, field_validator

from clinicadl.config.config import (
ComputationalConfig,
CrossValidationConfig,
DataLoaderConfig,
MapsManagerConfig,
ValidationConfig,
)
from clinicadl.config.config import DataConfig as DataBaseConfig
from clinicadl.interpret.gradients import GradCam, Gradients, VanillaBackProp
from clinicadl.utils.caps_dataset.data import (
load_data_test,
)
from clinicadl.utils.enum import InterpretationMethod
from clinicadl.utils.exceptions import ClinicaDLArgumentError # type: ignore
from clinicadl.utils.maps_manager.maps_manager import MapsManager # type: ignore

logger = getLogger("clinicadl.predict_config")
logger = getLogger("clinicadl.interpret_config")


class DataConfig(DataBaseConfig):
caps_directory: Optional[Path] = None


class InterpretConfig(BaseModel):
class InterpretBaseConfig(BaseModel):
name: str
method: InterpretationMethod = InterpretationMethod.GRADIENTS
target_node: int = 0
Expand All @@ -38,3 +48,15 @@ def get_method(self) -> Gradients:
return GradCam
else:
raise ValueError(f"The method {self.method.value} is not implemented")


class InterpretConfig(
MapsManagerConfig,
InterpretBaseConfig,
DataConfig,
ValidationConfig,
CrossValidationConfig,
ComputationalConfig,
DataLoaderConfig,
):
"""Config class to perform Transfer Learning."""
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,42 @@

from pydantic import BaseModel

from clinicadl.config.config.data import DataConfig as DataBaseConfig
from clinicadl.config.config.maps_manager import (
MapsManagerConfig as MapsManagerBaseConfig,
)
from clinicadl.utils.exceptions import ClinicaDLArgumentError # type: ignore

from ..computational import ComputationalConfig
from ..cross_validation import CrossValidationConfig
from ..dataloader import DataLoaderConfig
from ..validation import ValidationConfig

logger = getLogger("clinicadl.predict_config")


class PredictConfig(BaseModel):
class MapsManagerConfig(MapsManagerBaseConfig):
save_tensor: bool = False
save_latent_tensor: bool = False
use_labels: bool = True

def check_output_saving_tensor(self, network_task: str) -> None:
# Check if task is reconstruction for "save_tensor" and "save_nifti"
if self.save_tensor and network_task != "reconstruction":
raise ClinicaDLArgumentError(
"Cannot save tensors if the network task is not reconstruction. Please remove --save_tensor option."
)


class DataConfig(DataBaseConfig):
use_labels: bool = True


class PredictConfig(
MapsManagerConfig,
DataConfig,
ValidationConfig,
CrossValidationConfig,
ComputationalConfig,
DataLoaderConfig,
):
"""Config class to perform Transfer Learning."""
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from clinicadl.config.config import DataConfig as BaseDataConfig
from clinicadl.config.config import ModelConfig as BaseModelConfig
from clinicadl.config.config import ValidationConfig as BaseValidationConfig
from clinicadl.train.trainer.training_config import TrainingConfig
from clinicadl.config.config.pipelines.train import TrainConfig
from clinicadl.utils.enum import ClassificationLoss, ClassificationMetric, Task

logger = getLogger("clinicadl.classification_config")
Expand Down Expand Up @@ -57,7 +57,7 @@ def list_to_tuples(cls, v):
return v


class ClassificationConfig(TrainingConfig):
class ClassificationConfig(TrainConfig):
"""
Config class for the training of a classification model.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from clinicadl.config.config import ModelConfig as BaseModelConfig
from clinicadl.config.config import ValidationConfig as BaseValidationConfig
from clinicadl.train.trainer.training_config import TrainingConfig
from clinicadl.config.config.pipelines.train import TrainConfig
from clinicadl.utils.enum import (
Normalization,
ReconstructionLoss,
Expand Down Expand Up @@ -47,7 +47,7 @@ def list_to_tuples(cls, v):
return v


class ReconstructionConfig(TrainingConfig):
class ReconstructionConfig(TrainConfig):
"""
Config class for the training of a reconstruction model.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from clinicadl.config.config import DataConfig as BaseDataConfig
from clinicadl.config.config import ModelConfig as BaseModelConfig
from clinicadl.config.config import ValidationConfig as BaseValidationConfig
from clinicadl.train.trainer.training_config import TrainingConfig
from clinicadl.config.config.pipelines.train import TrainConfig
from clinicadl.utils.enum import RegressionLoss, RegressionMetric, Task

logger = getLogger("clinicadl.reconstruction_config")
Expand Down Expand Up @@ -47,7 +47,7 @@ def list_to_tuples(cls, v):
return v


class RegressionConfig(TrainingConfig):
class RegressionConfig(TrainConfig):
"""
Config class for the training of a regression model.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
logger = getLogger("clinicadl.training_config")


class TrainingConfig(BaseModel, ABC):
class TrainConfig(BaseModel, ABC):
"""

Abstract config class for the training pipeline.
Expand Down
2 changes: 1 addition & 1 deletion clinicadl/config/options/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .task import classification, reconstruction, regression
# from .task import classification, reconstruction, regression
9 changes: 4 additions & 5 deletions clinicadl/config/options/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import click

import clinicadl.train.trainer.training_config as config
from clinicadl.config import config
from clinicadl.config.config.callbacks import CallbacksConfig
from clinicadl.utils.config_utils import get_default_from_config_class as get_default
from clinicadl.utils.config_utils import get_type_from_config_class as get_type

emissions_calculator = click.option(
"--calculate_emissions/--dont_calculate_emissions",
default=get_default("emissions_calculator", config.CallbacksConfig),
default=get_default("emissions_calculator", CallbacksConfig),
help="Flag to allow calculate the carbon emissions during training.",
show_default=True,
)
track_exp = click.option(
"--track_exp",
"-te",
type=get_type("track_exp", config.CallbacksConfig),
default=get_default("track_exp", config.CallbacksConfig),
type=get_type("track_exp", CallbacksConfig),
default=get_default("track_exp", CallbacksConfig),
help="Use `--track_exp` to enable wandb/mlflow to track the metric (loss, accuracy, etc...) during the training.",
show_default=True,
)
6 changes: 3 additions & 3 deletions clinicadl/config/options/computational.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import click

from clinicadl.config import config
from clinicadl.config.config.computational import ComputationalConfig
from clinicadl.utils.config_utils import get_default_from_config_class as get_default
from clinicadl.utils.config_utils import get_type_from_config_class as get_type

# Computational
amp = click.option(
"--amp/--no-amp",
default=get_default("amp", config.ComputationalConfig),
default=get_default("amp", ComputationalConfig),
help="Enables automatic mixed precision during training and inference.",
show_default=True,
)
Expand All @@ -21,7 +21,7 @@
)
gpu = click.option(
"--gpu/--no-gpu",
default=get_default("gpu", config.ComputationalConfig),
default=get_default("gpu", ComputationalConfig),
help="Use GPU by default. Please specify `--no-gpu` to force using CPU.",
show_default=True,
)
9 changes: 4 additions & 5 deletions clinicadl/config/options/cross_validation.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import click

import clinicadl.train.trainer.training_config as config
from clinicadl.config import config
from clinicadl.config.config.cross_validation import CrossValidationConfig
from clinicadl.utils.config_utils import get_default_from_config_class as get_default
from clinicadl.utils.config_utils import get_type_from_config_class as get_type

# Cross Validation
n_splits = click.option(
"--n_splits",
type=get_type("n_splits", config.CrossValidationConfig),
default=get_default("n_splits", config.CrossValidationConfig),
type=get_type("n_splits", CrossValidationConfig),
default=get_default("n_splits", CrossValidationConfig),
help="If a value is given for k will load data of a k-fold CV. "
"Default value (0) will load a single split.",
show_default=True,
Expand All @@ -18,7 +17,7 @@
"--split",
"-s",
type=int, # get_type("split", config.CrossValidationConfig),
default=get_default("split", config.CrossValidationConfig),
default=get_default("split", CrossValidationConfig),
multiple=True,
help="Train the list of given splits. By default, all the splits are trained.",
show_default=True,
Expand Down
Loading