Skip to content

Commit

Permalink
tests ok
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Oct 9, 2024
1 parent c0b424c commit 44926cd
Showing 1 changed file with 65 additions and 2 deletions.
67 changes: 65 additions & 2 deletions clinicadl/interpret/config.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,33 @@
from logging import getLogger
from pathlib import Path
from typing import Optional
from typing import Any, Dict, Optional

from pydantic import BaseModel, field_validator

from clinicadl.caps_dataset.data_config import DataConfig as DataBaseConfig
from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig
from clinicadl.interpret.gradients import GradCam, Gradients, VanillaBackProp
from clinicadl.maps_manager.config import MapsManagerConfig
from clinicadl.maps_manager.config import MapsManagerConfigBase
from clinicadl.maps_manager.maps_manager import MapsManager
from clinicadl.predictor.validation import ValidationConfig
from clinicadl.splitter.config import SplitConfig
from clinicadl.transforms.config import TransformsConfig
from clinicadl.utils.computational.computational import ComputationalConfig
from clinicadl.utils.enum import InterpretationMethod
from clinicadl.utils.exceptions import ClinicaDLArgumentError

logger = getLogger("clinicadl.interpret_config")


class MapsManagerConfig(MapsManagerConfigBase):
def check_output_saving_tensor(self, network_task: str) -> None:
# Check if task is reconstruction for "save_tensor" and "save_nifti"
if self.save_tensor and network_task != "reconstruction":
raise ClinicaDLArgumentError(
"Cannot save tensors if the network task is not reconstruction. Please remove --save_tensor option."
)


class DataConfig(DataBaseConfig):
caps_directory: Optional[Path] = None

Expand Down Expand Up @@ -54,3 +66,54 @@ class InterpretConfig(BaseModel):
dataloader: DataLoaderConfig
split: SplitConfig
interpret: InterpretBaseConfig

def __init__(self, **kwargs):
super().__init__(
maps_manager=kwargs,
computational=kwargs,
dataloader=kwargs,
data=kwargs,
split=kwargs,
validation=kwargs,
transforms=kwargs,
)

def _update(self, config_dict: Dict[str, Any]) -> None:
"""Updates the configs with a dict given by the user."""
self.data.__dict__.update(config_dict)
self.split.__dict__.update(config_dict)
self.validation.__dict__.update(config_dict)
self.maps_manager.__dict__.update(config_dict)
self.split.__dict__.update(config_dict)
self.computational.__dict__.update(config_dict)
self.dataloader.__dict__.update(config_dict)
self.transforms.__dict__.update(config_dict)

def adapt_with_maps_manager_info(self, maps_manager: MapsManager):
self.maps_manager.check_output_saving_nifti(maps_manager.network_task)
self.data.diagnoses = (
maps_manager.diagnoses
if self.data.diagnoses is None or len(self.data.diagnoses) == 0
else self.data.diagnoses
)

self.dataloader.batch_size = (
maps_manager.batch_size
if not self.dataloader.batch_size
else self.dataloader.batch_size
)
self.dataloader.n_proc = (
maps_manager.n_proc
if not self.dataloader.n_proc
else self.dataloader.n_proc
)

self.split.adapt_cross_val_with_maps_manager_info(maps_manager)
self.maps_manager.check_output_saving_tensor(maps_manager.network_task)

self.transforms = TransformsConfig(
normalize=maps_manager.normalize,
data_augmentation=maps_manager.data_augmentation,
size_reduction=maps_manager.size_reduction,
size_reduction_factor=maps_manager.size_reduction_factor,
)

0 comments on commit 44926cd

Please sign in to comment.