Skip to content

Commit

Permalink
splitter in init trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Oct 14, 2024
1 parent 264e884 commit 8b740d5
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,5 @@ def cli(**kwargs):
options = merge_cli_and_config_file_options(Task.CLASSIFICATION, **kwargs)
config = ClassificationConfig(**options)
trainer = Trainer(config)

trainer.train(split_list=config.split.split, overwrite=True)
5 changes: 4 additions & 1 deletion clinicadl/predictor/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,15 @@ def __init__(self, _config: Union[PredictConfig, InterpretConfig]) -> None:

self.maps_manager = MapsManager(_config.maps_manager.maps_dir)
self._config.adapt_with_maps_manager_info(self.maps_manager)

print(self._config.data.model_dump())
tmp = self._config.data.model_dump(
exclude=set(["preprocessing_dict", "mode", "caps_dict"])
)
print(tmp)
tmp.update(self._config.split.model_dump())
print(tmp)
tmp.update(self._config.validation.model_dump())
print(tmp)
self.splitter = Splitter(SplitterConfig(**tmp))

def predict(
Expand Down
21 changes: 21 additions & 0 deletions clinicadl/splitter/splitter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import abc
import shutil
from logging import getLogger
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import pandas as pd

from clinicadl.splitter.config import SplitterConfig
from clinicadl.utils import cluster
from clinicadl.utils.exceptions import MAPSError

logger = getLogger("clinicadl.split_manager")

Expand Down Expand Up @@ -214,3 +217,21 @@ def _check_item(self, item):
raise IndexError(
f"Split index {item} out of allowed splits {self.allowed_splits_list}."
)

def check_split_list(self, maps_path, overwrite):
    """Ensure the requested splits do not collide with previously trained ones.

    Walks every split produced by ``split_iterator`` and, for each
    ``split-<i>`` directory already present under *maps_path*:

    * with ``overwrite=True`` the directory is deleted — only by the cluster
      master process, so distributed workers do not race on the removal;
    * with ``overwrite=False`` the split is collected and reported.

    Parameters
    ----------
    maps_path : Path
        Root of the MAPS output folder containing ``split-*`` directories.
    overwrite : bool
        Whether previously trained split directories may be erased.

    Raises
    ------
    MAPSError
        If at least one split directory already exists and *overwrite* is False.
    """
    already_trained = []
    for split in self.split_iterator():
        candidate = maps_path / f"split-{split}"
        if not candidate.is_dir():
            continue
        if not overwrite:
            already_trained.append(split)
        elif cluster.master:
            shutil.rmtree(candidate)

    if already_trained:
        raise MAPSError(
            f"Splits {already_trained} already exist. Please "
            f"specify a list of splits not intersecting the previous list, "
            f"or use overwrite to erase previously trained splits."
        )
81 changes: 43 additions & 38 deletions clinicadl/trainer/trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations # noqa: I001

import shutil

from contextlib import nullcontext
from datetime import datetime
from logging import getLogger
Expand Down Expand Up @@ -71,6 +71,12 @@ def __init__(
predict_config = PredictConfig(**config.get_dict())
self.validator = Predictor(predict_config)
self._check_args()
### test
splitter_config = SplitterConfig(**self.config.get_dict())
self.splitter = Splitter(splitter_config)
self.splitter.check_split_list(
self.config.maps_manager.maps_dir, self.config.maps_manager.overwrite
)

def _init_maps_manager(self, config) -> MapsManager:
# temporary: to match CLI data. TODO : change CLI data
Expand Down Expand Up @@ -161,12 +167,12 @@ def resume(self, splits: List[int]) -> None:
"""
stopped_splits = set(find_stopped_splits(self.config.maps_manager.maps_dir))
finished_splits = set(find_finished_splits(self.maps_manager.maps_path))
# TODO : check these two lines. Why do we need a split_manager?
# TODO : check these two lines. Why do we need a self.splitter?

splitter_config = SplitterConfig(**self.config.get_dict())
split_manager = Splitter(splitter_config)
self.splitter = Splitter(splitter_config)

split_iterator = split_manager.split_iterator()
split_iterator = self.splitter.split_iterator()
###
absent_splits = set(split_iterator) - stopped_splits - finished_splits

Expand Down Expand Up @@ -214,24 +220,23 @@ def train(
If splits specified in input already exist and overwrite is False.
"""

self.check_split_list(split_list=split_list, overwrite=overwrite)

if self.config.ssda.ssda_network:
self._train_ssda(split_list, resume=False)

else:
splitter_config = SplitterConfig(**self.config.get_dict())
split_manager = Splitter(splitter_config)
# splitter_config = SplitterConfig(**self.config.get_dict())
# self.splitter = Splitter(splitter_config)
# self.splitter.check_split_list(self.config.maps_manager.maps_dir, self.config.maps_manager.overwrite)

for split in split_manager.split_iterator():
for split in self.splitter.split_iterator():
logger.info(f"Training split {split}")
seed_everything(
self.config.reproducibility.seed,
self.config.reproducibility.deterministic,
self.config.reproducibility.compensation,
)

split_df_dict = split_manager[split]
split_df_dict = self.splitter[split]

if self.config.model.multi_network:
resume, first_network = self.init_first_network(False, split)
Expand All @@ -242,25 +247,25 @@ def train(
else:
self._train_single(split, split_df_dict, resume=False)

def check_split_list(self, split_list, overwrite):
    """Verify that training can proceed without clobbering existing splits.

    Builds a ``Splitter`` from the trainer configuration and checks, for
    each split it yields, whether a ``split-<i>`` directory already exists
    under the MAPS output folder.  Existing folders are removed when
    *overwrite* is True (deletion is performed only by the cluster master,
    so distributed workers do not race); otherwise they are collected and
    reported.

    NOTE(review): *split_list* is not read here — the splits actually
    checked come from the splitter configuration; confirm with callers.

    Raises
    ------
    MAPSError
        If at least one split directory already exists and *overwrite* is False.
    """
    conflicting = []
    splitter = Splitter(SplitterConfig(**self.config.get_dict()))
    for split in splitter.split_iterator():
        target = self.maps_manager.maps_path / f"split-{split}"
        if not target.is_dir():
            continue
        if not overwrite:
            conflicting.append(split)
        elif cluster.master:
            shutil.rmtree(target)

    if len(conflicting) > 0:
        raise MAPSError(
            f"Splits {conflicting} already exist. Please "
            f"specify a list of splits not intersecting the previous list, "
            f"or use overwrite to erase previously trained splits."
        )
# def check_split_list(self, split_list, overwrite):
# existing_splits = []
# splitter_config = SplitterConfig(**self.config.get_dict())
# self.splitter = Splitter(splitter_config)
# for split in self.splitter.split_iterator():
# split_path = self.maps_manager.maps_path / f"split-{split}"
# if split_path.is_dir():
# if overwrite:
# if cluster.master:
# shutil.rmtree(split_path)
# else:
# existing_splits.append(split)

# if len(existing_splits) > 0:
# raise MAPSError(
# f"Splits {existing_splits} already exist. Please "
# f"specify a list of splits not intersecting the previous list, "
# f"or use overwrite to erase previously trained splits."
# )

def _resume(
self,
Expand All @@ -282,8 +287,8 @@ def _resume(
"""
missing_splits = []
splitter_config = SplitterConfig(**self.config.get_dict())
split_manager = Splitter(splitter_config)
for split in split_manager.split_iterator():
self.splitter = Splitter(splitter_config)
for split in self.splitter.split_iterator():
if not (self.maps_manager.maps_path / f"split-{split}" / "tmp").is_dir():
missing_splits.append(split)

Expand All @@ -296,15 +301,15 @@ def _resume(
if self.config.ssda.ssda_network:
self._train_ssda(split_list, resume=True)
else:
for split in split_manager.split_iterator():
for split in self.splitter.split_iterator():
logger.info(f"Training split {split}")
seed_everything(
self.config.reproducibility.seed,
self.config.reproducibility.deterministic,
self.config.reproducibility.compensation,
)

split_df_dict = split_manager[split]
split_df_dict = self.splitter[split]
if self.config.model.multi_network:
resume, first_network = self.init_first_network(True, split)
for network in range(first_network, self.maps_manager.num_networks):
Expand Down Expand Up @@ -474,19 +479,19 @@ def _train_ssda(

splitter_config = SplitterConfig(**self.config.get_dict())

split_manager = Splitter(splitter_config)
split_manager_target_lab = Splitter(splitter_config)
self.splitter = Splitter(splitter_config)
self.splitter_target_lab = Splitter(splitter_config)

for split in split_manager.split_iterator():
for split in self.splitter.split_iterator():
logger.info(f"Training split {split}")
seed_everything(
self.config.reproducibility.seed,
self.config.reproducibility.deterministic,
self.config.reproducibility.compensation,
)

split_df_dict = split_manager[split]
split_df_dict_target_lab = split_manager_target_lab[split]
split_df_dict = self.splitter[split]
split_df_dict_target_lab = self.splitter_target_lab[split]

logger.debug("Loading source training data...")
data_train_source = return_dataset(
Expand Down
11 changes: 4 additions & 7 deletions tests/test_interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ def test_interpret(cmdopt, tmp_path, test_name):
labels_dir_str = str(input_dir / "labels_list" / "2_fold")
maps_tmp_out_dir = str(tmp_out_dir / "maps")
if test_name == "classification":
caps_input = str(input_dir / "caps_image")
cnn_input = [
"train",
"classification",
caps_input,
str(input_dir / "caps_image"),
"t1-linear_mode-image.json",
labels_dir_str,
maps_tmp_out_dir,
Expand All @@ -43,11 +42,10 @@ def test_interpret(cmdopt, tmp_path, test_name):
]

elif test_name == "regression":
caps_input = str(input_dir / "caps_patch")
cnn_input = [
"train",
"regression",
caps_input,
str(input_dir / "caps_patch"),
"t1-linear_mode-patch.json",
labels_dir_str,
maps_tmp_out_dir,
Expand All @@ -65,10 +63,10 @@ def test_interpret(cmdopt, tmp_path, test_name):
if cmdopt["no-gpu"]:
cnn_input.append("--no-gpu")

run_interpret(cnn_input, tmp_out_dir, ref_dir, caps_input)
run_interpret(cnn_input, tmp_out_dir, ref_dir)


def run_interpret(cnn_input, tmp_out_dir, ref_dir, caps_input):
def run_interpret(cnn_input, tmp_out_dir, ref_dir):
from clinicadl.utils.enum import InterpretationMethod

maps_path = tmp_out_dir / "maps"
Expand All @@ -84,7 +82,6 @@ def run_interpret(cnn_input, tmp_out_dir, ref_dir, caps_input):
data_group="train",
name=f"test-{method}",
method_cls=method,
caps_directory=caps_input,
)
interpret_manager = Predictor(interpret_config)
interpret_manager.interpret()
Expand Down

0 comments on commit 8b740d5

Please sign in to comment.