The goal of this repository is to showcase the use of Pytorch-Lightning on medical image segmentation problems solved with a U-Net.
The segmentation tasks are taken from the Medical Segmentation Decathlon. Specifically, the repository contains code for the Hippocampus and the Heart dataset, but you can extend it to other datasets as well.
A 2D U-Net is trained which receives single slices of the 3D medical images as input.
General knowledge about Pytorch Modules & DataLoaders is advantageous, as is familiarity with BatchGenerators.
- Clone this repository
- Go into the folder
- Run `pip install .`
- Change the variables in unet_defaults.yml or unet_defaults_heart.yml to your desired directories (see the sketch after this list):
  - save_dir --> logs are saved here
  - data_input_dir --> data is saved to and loaded from here
- Run light_seg/main.py.
- If you want to use the unet_defaults_heart.yml file as config file, change `hparams, nested_dict = main_cli(config_file="./unet_defaults.yml")` at the bottom of main.py to `hparams, nested_dict = main_cli(config_file="./unet_defaults_heart.yml")`.
- Experiment with settings and code
- Point test_unet_defaults.yml to your trained checkpoint.
- Run test.py
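For orientation, the two directory variables could be set like this (a hedged sketch: only `save_dir` and `data_input_dir` are described in this README, the real config files contain further settings):

```yaml
# Sketch of the two directory entries in unet_defaults.yml /
# unet_defaults_heart.yml -- adapt both paths to your machine.
save_dir: /path/to/logs          # training logs and checkpoints are saved here
data_input_dir: /path/to/data    # data is saved to and loaded from here
```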
- More information about pytorch-lightning can be found in its official documentation.
In the following, the functionality of the individual files will be briefly described.
In general, there are scripts that are important for training and scripts that are important for testing your trained network.
The scripts that are important for training are:
- Config files: These files contain the configuration used during training. Specifically, you can set the dataset location and training parameters such as the fold or the number of epochs.
  - unet_defaults.yml: This is a default configuration which is mainly for training on the Hippocampus dataset (adapt the dataset location to your local paths).
  - unet_defaults_heart.yml: This is a default configuration which is mainly for training on the Heart dataset (adapt the dataset location to your local paths).
- main.py: This is the entry point for the training. Here, the parameters are set, the datamodule, trainer and model are defined and the training loop is started.
- Datamodules: The datamodules download the dataset (currently this is only possible for the Hippocampus dataset; the Heart dataset needs to be downloaded manually), preprocess it, and load and augment the batches during training:
  - msd_datamodule.py
  - hippocampus_datamodule.py
  - heart_datamodule.py
- unet_lightning.py: This file contains all the necessary parts of the training loop. The U-Net model itself is also instantiated there.
- unet_module.py: This contains the architecture definition of the U-Net model that is trained.
- loss_modules.py: Here, the soft dice loss is defined, which is used as part of the loss function during training (a generic sketch follows this list).
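As a rough orientation for the loss in loss_modules.py, a soft dice loss typically looks like the following (a minimal sketch assuming one-hot targets and 2D inputs; class and argument names are illustrative, not the repository's exact implementation):

```python
import torch
import torch.nn as nn


class SoftDiceLoss(nn.Module):
    """Minimal soft dice loss sketch (illustrative, not the repo's exact code)."""

    def __init__(self, smooth: float = 1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, target_onehot: torch.Tensor) -> torch.Tensor:
        # logits:        (N, C, H, W) raw network outputs
        # target_onehot: (N, C, H, W) one-hot encoded ground truth
        probs = torch.softmax(logits, dim=1)
        dims = (0, 2, 3)  # sum over batch and spatial dimensions, keep classes
        intersection = torch.sum(probs * target_onehot, dims)
        cardinality = torch.sum(probs + target_onehot, dims)
        dice_per_class = (2.0 * intersection + self.smooth) / (cardinality + self.smooth)
        return 1.0 - dice_per_class.mean()  # loss decreases as the dice increases
```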
The scripts that are important for testing are:
- test_unet_defaults.yml: This contains the default configuration for testing. Most importantly, you need to specify the checkpoint from your training that you want to use for your predictions (a minimal loading sketch follows this list). You can also specify the location of the data and where the test results are saved, in case you train and test on different machines. If you train and test on the same machine, this information can also be inferred from the checkpoint.
- test.py: This is the main test loop. The test data will be loaded and processed as 2D slices. In the end, the results are saved in the original 3D format.
- data_carrier.py: This class is responsible for handling the individual 2D slices so that they are correctly saved as 3D images/segmentations in the end. The results are saved in a specific folder structure (see the section below).
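In essence, test.py has to restore the model from the configured checkpoint before predicting (a minimal Lightning sketch; the paths and constructor arguments are placeholders, while `load_from_checkpoint` and `Trainer.test` are standard Lightning API):

```python
import pytorch_lightning as pl

from hippocampus_datamodule import HippocampusDataModule
from unet_lightning import UNetExperiment

# Restore the trained weights from the checkpoint configured in
# test_unet_defaults.yml (the path below is a placeholder).
model = UNetExperiment.load_from_checkpoint("/path/to/save_dir/checkpoints/last.ckpt")

# Run Lightning's test loop on the datamodule's test split; test.py
# additionally reassembles the 2D slice predictions into 3D volumes
# via data_carrier.py.
datamodule = HippocampusDataModule(data_input_dir="/path/to/data")  # args illustrative
trainer = pl.Trainer(gpus=1)
trainer.test(model, datamodule=datamodule)
```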
The results of the test run are stored as 3D nifti files in a specific folder structure, containing the original input to the network, the ground truth segmentation, the predicted segmentation, the softmax predictions and the calculated metrics for the test. Normally, you should end up with a folder structure like this (either in the save_dir specified in test_unet_defaults.yml or, if none is given, in the save_dir of your training):
```
test_results
├── <version>
│   ├── gt_seg
│   ├── input
│   ├── pred_prob
│   ├── pred_seg
│   └── metrics.json
└── ...
```
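The saved artifacts can then be inspected with standard tools, for example (a sketch assuming nibabel is installed; the file names inside the folders depend on the dataset and are placeholders here):

```python
import json
from pathlib import Path

import nibabel as nib  # common reader for nifti files

version_dir = Path("test_results") / "version_0"  # stands in for <version>

# Aggregated test metrics written by the test run.
with open(version_dir / "metrics.json") as f:
    print(json.load(f))

# Load the predicted 3D segmentations (file names depend on the dataset).
for seg_file in sorted((version_dir / "pred_seg").glob("*.nii.gz")):
    seg = nib.load(str(seg_file))
    print(seg_file.name, seg.shape)
```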
The relevant code for pytorch-lightning is situated in:
- unet_lightning.py / UNetExperiment (pl.LightningModule)
- hippocampus_datamodule.py / HippocampusDataModule (pl.LightningDataModule)
- main.py / main_cli (efficient parsing and hyperparameter handling)
- main.py / main (use of pl.Trainer)
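Put together, main follows the standard Lightning pattern (a minimal sketch; in the repository the constructor arguments come from the parsed yml config, so the ones below are placeholders):

```python
import pytorch_lightning as pl

from hippocampus_datamodule import HippocampusDataModule
from unet_lightning import UNetExperiment

# Placeholder for the hparams that main_cli() parses from the yml config.
hparams = {"max_epochs": 200}

model = UNetExperiment(hparams)              # pl.LightningModule
datamodule = HippocampusDataModule(hparams)  # pl.LightningDataModule

trainer = pl.Trainer(max_epochs=hparams["max_epochs"], gpus=1)
trainer.fit(model, datamodule=datamodule)
```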
For more complex examples where scores need to be computed over complete datasets, it is advised to use the following methods in your pl.LightningModule:
- training_epoch_end
- validation_epoch_end
- test_epoch_end
These methods override the normal behaviour of pl.EvalResult & pl.TrainResult when they are used.
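For example, a dataset-wide dice score can be aggregated like this (a minimal sketch in the dict-returning style of the Lightning versions that still ship pl.EvalResult & pl.TrainResult; the metric computation itself is a placeholder):

```python
import torch
import pytorch_lightning as pl


class DiceAggregationSketch(pl.LightningModule):
    """Illustrative excerpt: compute a score over the complete validation set."""

    def validation_step(self, batch, batch_idx):
        # ... forward pass and per-batch dice computation would go here ...
        dice = torch.tensor(0.0)  # placeholder for the real per-batch score
        return {"dice": dice}

    def validation_epoch_end(self, outputs):
        # `outputs` collects the dicts returned by every validation_step,
        # so this mean is taken over the whole validation set, not per batch.
        mean_dice = torch.stack([x["dice"] for x in outputs]).mean()
        return {"log": {"val_dice": mean_dice}}
```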
In the following, the test results on the two implemented datasets are shown. The networks were trained for 50, 100 and 200 epochs.
As the results did not differ much between the different numbers of training epochs, only the results for 200 epochs are shown here.
Furthermore, they are compared with the results of a 2D nnU-Net which was trained and tested on the same data.
**Hippocampus**

| Fold | 0 | 1 | 2 | 3 | 4 | Mean |
| --- | --- | --- | --- | --- | --- | --- |
| Dice Pytorch-Lightning Example | 0.8822 | 0.8809 | 0.8798 | 0.878 | 0.8805 | 0.8803 |
| Dice nnU-Net | 0.8748 | 0.8764 | 0.8768 | 0.875 | 0.876 | 0.8758 |
**Heart**

| Fold | 0 | 1 | 2 | 3 | 4 | Mean |
| --- | --- | --- | --- | --- | --- | --- |
| Dice Pytorch-Lightning Example | 0.8735 | 0.8895 | 0.8905 | 0.8989 | 0.8942 | 0.8794 |
| Dice nnU-Net | 0.9227 | 0.9205 | 0.9197 | 0.9052 | 0.9183 | 0.9173 |
Of course, you can use this repository as a basis to run experiments with your own datasets.
If you specifically want to adapt it to another dataset from the Medical Segmentation Decathlon, you can derive from msd_datamodule.py as done for the Hippocampus and the Heart dataset (a rough skeleton follows).
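Such an extension could look like this (hedged: the base-class name and the required overrides are assumptions derived from the file names, and the task name below is only an example; check msd_datamodule.py for the actual interface):

```python
from msd_datamodule import MsdDataModule  # base-class name is an assumption


class ProstateDataModule(MsdDataModule):
    """Hypothetical datamodule for another MSD task, e.g. Task05_Prostate."""

    def prepare_data(self):
        # Download/unpack and preprocess the dataset, analogous to the
        # Hippocampus and Heart datamodules.
        ...

    def setup(self, stage=None):
        # Build the train/val/test splits (e.g. for the configured fold).
        ...
```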
This library is developed and maintained by the Medical Image Computing Group of the DKFZ and the Interactive Machine Learning Group of Helmholtz Imaging and the DKFZ.