Skip to content

Commit

Permalink
adapt without preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Oct 24, 2024
1 parent cd3a4e7 commit 6e191ab
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 5 deletions.
13 changes: 13 additions & 0 deletions clinicadl/caps_dataset/extraction/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from logging import getLogger
from pathlib import Path
from time import time
from typing import List, Optional, Tuple, Union

Expand Down Expand Up @@ -31,6 +32,10 @@ class ExtractionConfig(BaseModel):

@field_validator("extract_json", mode="before")
def compute_extract_json(cls, v: str):
if isinstance(v, Path):
v = str(v)
elif isinstance(v, bool):
v = None
if v is None:
return f"extract_{int(time())}.json"
elif not v.endswith(".json"):
Expand Down Expand Up @@ -75,3 +80,11 @@ class ExtractionROIConfig(ExtractionConfig):
roi_custom_mask_pattern: str = ""
roi_background_value: int = 0
extract_method: ExtractionMethod = ExtractionMethod.ROI


# Type alias covering every concrete extraction configuration variant.
# Used as a pydantic field annotation (e.g. TrainConfig.extraction) so that
# any one of these config models is accepted for that field.
# NOTE(review): member order can affect pydantic's union coercion when inputs
# are ambiguous — confirm before reordering.
ALL_EXTRACTION_TYPES = Union[
    ExtractionImageConfig,
    ExtractionROIConfig,
    ExtractionSliceConfig,
    ExtractionPatchConfig,
]
12 changes: 11 additions & 1 deletion clinicadl/caps_dataset/preprocessing/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
from logging import getLogger
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

from pydantic import BaseModel, ConfigDict

Expand Down Expand Up @@ -207,3 +207,13 @@ def caps_nii(self) -> tuple:

def get_filetype(self) -> FileType:
return self.linear_nii()


# Type alias covering every concrete preprocessing configuration variant.
# Used as a pydantic field annotation (e.g. TrainConfig.preprocessing) so that
# any one of these config models is accepted for that field.
# NOTE(review): member order can affect pydantic's union coercion when inputs
# are ambiguous — confirm before reordering.
ALL_PREPROCESSING_TYPES = Union[
    T1PreprocessingConfig,
    T2PreprocessingConfig,
    FlairPreprocessingConfig,
    PETPreprocessingConfig,
    CustomPreprocessingConfig,
    DTIPreprocessingConfig,
]
3 changes: 2 additions & 1 deletion clinicadl/caps_dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ def get_preprocessing_and_mode_from_parameters(**kwargs):
if "preprocessing_dict" in kwargs:
kwargs = kwargs["preprocessing_dict"]

print(kwargs)
preprocessing = Preprocessing(kwargs["preprocessing"])
mode = ExtractionMethod(kwargs["mode"])
mode = ExtractionMethod(kwargs["extract_method"])
return get_preprocessing(preprocessing)(**kwargs), get_extraction(mode)(**kwargs)


Expand Down
1 change: 0 additions & 1 deletion clinicadl/maps_manager/maps_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def _check_args(self, parameters):
mandatory_arguments = [
"caps_directory",
"tsv_path",
"mode",
"network_task",
]
for arg in mandatory_arguments:
Expand Down
11 changes: 11 additions & 0 deletions clinicadl/trainer/config/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from clinicadl.callbacks.config import CallbacksConfig
from clinicadl.caps_dataset.data_config import DataConfig
from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig
from clinicadl.caps_dataset.extraction.config import ALL_EXTRACTION_TYPES
from clinicadl.caps_dataset.preprocessing.config import ALL_PREPROCESSING_TYPES
from clinicadl.config.config.lr_scheduler import LRschedulerConfig
from clinicadl.config.config.reproducibility import ReproducibilityConfig
from clinicadl.maps_manager.config import MapsManagerConfig
Expand Down Expand Up @@ -42,9 +44,11 @@ class TrainConfig(BaseModel, ABC):
data: DataConfig
dataloader: DataLoaderConfig
early_stopping: EarlyStoppingConfig
extraction: ALL_EXTRACTION_TYPES
lr_scheduler: LRschedulerConfig
maps_manager: MapsManagerConfig
model: NetworkConfig
preprocessing: ALL_PREPROCESSING_TYPES
optimization: OptimizationConfig
optimizer: OptimizerConfig
reproducibility: ReproducibilityConfig
Expand All @@ -55,6 +59,9 @@ class TrainConfig(BaseModel, ABC):
# pydantic config
model_config = ConfigDict(validate_assignment=True)

# @field_validator("preprocessing", mode="before")
# def check_preprocessing(cls, v: str):

@computed_field
@property
@abstractmethod
Expand All @@ -68,9 +75,11 @@ def __init__(self, **kwargs):
data=kwargs,
dataloader=kwargs,
early_stopping=kwargs,
extraction=kwargs,
lr_scheduler=kwargs,
maps_manager=kwargs,
model=kwargs,
preprocessing=kwargs,
optimization=kwargs,
optimizer=kwargs,
reproducibility=kwargs,
Expand All @@ -87,9 +96,11 @@ def _update(self, config_dict: Dict[str, Any]) -> None:
self.data.__dict__.update(config_dict)
self.dataloader.__dict__.update(config_dict)
self.early_stopping.__dict__.update(config_dict)
self.extraction.__dict__.update(config_dict)
self.lr_scheduler.__dict__.update(config_dict)
self.maps_manager.__dict__.update(config_dict)
self.model.__dict__.update(config_dict)
self.preprocessing.__dict__.update(config_dict)
self.optimization.__dict__.update(config_dict)
self.optimizer.__dict__.update(config_dict)
self.reproducibility.__dict__.update(config_dict)
Expand Down
5 changes: 4 additions & 1 deletion clinicadl/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,11 @@ def from_json(
config_dict = patch_to_read_json(read_json(config_file)) # TODO : remove patch
config_dict["maps_dir"] = maps_path
config_dict["split"] = split if split else ()

from clinicadl.utils.iotools.trainer_utils import read_multi_level_dict

config_object = create_training_config(config_dict["network_task"])(
**config_dict
**read_multi_level_dict(config_dict)
)
return cls(config_object)

Expand Down
11 changes: 10 additions & 1 deletion clinicadl/utils/iotools/trainer_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
from pathlib import Path


def read_multi_level_dict(dict_: dict) -> dict:
    """Flatten one level of nesting out of a dict.

    Top-level values that are themselves dicts are merged into the result
    (their items promoted to the top level, one level deep only); all other
    key/value pairs are copied through unchanged.

    Parameters
    ----------
    dict_ : dict
        Possibly nested mapping (e.g. a training-config dict whose sections
        are sub-dicts).

    Returns
    -------
    dict
        A new, single-level dict. When the same key appears in several
        places, the value encountered last (in ``dict_`` iteration order)
        wins, matching ``dict.update`` semantics.
    """
    parameters = {}
    # Iterate items() directly instead of re-looking-up dict_[key] per key.
    for key, value in dict_.items():
        if isinstance(value, dict):
            # Promote the sub-dict's entries to the top level (not recursive:
            # dicts nested two levels deep are kept as-is inside `value`).
            parameters.update(value)
        else:
            parameters[key] = value
    return parameters


def create_parameters_dict(config):
parameters = {}
config_dict = config.model_dump()
Expand All @@ -20,7 +30,6 @@ def create_parameters_dict(config):
if parameters["data_augmentation"] == ():
parameters["data_augmentation"] = False

del parameters["preprocessing_json"]
# if "tsv_path" in parameters:
# parameters["tsv_path"] = parameters["tsv_path"]
# del parameters["tsv_path"]
Expand Down

0 comments on commit 6e191ab

Please sign in to comment.