This repository contains the implementation of a machine learning medical image registration method that was submitted to the Learn2Reg 2020 Challenge. The method is based on a 3D downsampled CNN pyramid in which displacement fields are estimated and refined at each level.
We use TensorFlow 2.0.2. To install all necessary libraries, run:
conda env create --file environment.yml
conda activate tf2
Our method is inspired by [PWC-Net], a 2D optical flow method popular in computer vision. Below is an overview of the architecture and a detailed graph of the operations at each feature level.
Description of the components:
- Pyramid: Downsamples the moving and fixed images into several feature-map levels using CNN layers. The same pyramid is used for the moving and the fixed images.
- Warp (W): Warps the feature maps from the moving image with the estimated displacement field.
- Affine (A): A dense neural network that estimates the 12 parameters of an affine transformation.
- Cost volume (CV): Correlation between the warped feature maps from the moving image and the feature maps from the fixed image. For computational reasons the cost volume is restricted to a small neighbourhood around each voxel (see the sketch after this list).
- Deform (D): A CNN that estimates the displacement field based on the affine displacement field, the cost volume and the feature maps from the fixed image.
- Upsample (U): Upsamples the estimated displacement field from one level to the next.
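To make the Warp (W) and Cost volume (CV) blocks more concrete, here is a minimal TensorFlow sketch of both operations. It is not the exact code from this repository: the warp uses nearest-neighbour sampling instead of trilinear interpolation for brevity, the displacement field is assumed to be in voxel units and ordered (z, y, x), `search_range` is a hypothetical parameter controlling the correlation neighbourhood, and static spatial shapes are assumed.

```python
import tensorflow as tf

def warp(moving_feat, flow):
    """Warp moving feature maps [B, D, H, W, C] with a displacement field
    flow [B, D, H, W, 3] given in voxels, ordered (z, y, x) (assumption).
    Nearest-neighbour sampling is used here for brevity; a real warp block
    would typically use trilinear interpolation."""
    shape = tf.shape(moving_feat)
    d, h, w = shape[1], shape[2], shape[3]
    # Identity sampling grid: one (z, y, x) coordinate per voxel.
    gz, gy, gx = tf.meshgrid(tf.range(d), tf.range(h), tf.range(w), indexing="ij")
    grid = tf.cast(tf.stack([gz, gy, gx], axis=-1), flow.dtype)    # [D, H, W, 3]
    coords = tf.cast(tf.round(grid[tf.newaxis] + flow), tf.int32)  # [B, D, H, W, 3]
    # Clamp to the volume bounds so out-of-volume samples reuse border voxels.
    coords = tf.maximum(coords, 0)
    coords = tf.minimum(coords, tf.stack([d, h, w]) - 1)
    return tf.gather_nd(moving_feat, coords, batch_dims=1)

def cost_volume(fixed_feat, warped_moving_feat, search_range=1):
    """Local correlation between fixed and warped moving feature maps.
    Returns a tensor of shape [B, D, H, W, (2 * search_range + 1) ** 3]."""
    r = search_range
    # Pad so every shifted crop below stays in bounds (zero padding assumed).
    padded = tf.pad(warped_moving_feat, [[0, 0], [r, r], [r, r], [r, r], [0, 0]])
    d, h, w = fixed_feat.shape[1], fixed_feat.shape[2], fixed_feat.shape[3]
    costs = []
    for dz in range(2 * r + 1):
        for dy in range(2 * r + 1):
            for dx in range(2 * r + 1):
                shifted = padded[:, dz:dz + d, dy:dy + h, dx:dx + w, :]
                # Mean over channels of the element-wise product = correlation.
                costs.append(tf.reduce_mean(fixed_feat * shifted, axis=-1))
    return tf.stack(costs, axis=-1)
```

With `search_range=1` the cost volume in this sketch has 27 channels per voxel; in the architecture above, this output would be fed, together with the fixed feature maps and the affine/upsampled displacement field, into the Deform (D) CNN.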
TBD
See https://learn2reg.grand-challenge.org/Datasets/ for instructions.
To run the training and testing scripts, we assume the datasets are organized as follows:
+-- task_01
| +-- pairs_val.csv
| +-- NIFTI
+-- task_02
| +-- pairs_val.csv
| +-- training
+-- task_03
| +-- pairs_val.csv
| +-- Training
+-- task_04
| +-- pairs_val.csv
| +-- Training
+-- Test
| +-- task_01
| | +-- pairs_val.csv
| | +-- NIFTI
| +-- task_02
| | +-- Training
| +-- task_03
| | +-- Training
| +-- task_04
| | +-- Training
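Optionally, you can sanity-check the training layout with a small helper like the one below. This is a hypothetical script, not part of the repository; the folder names are taken from the tree above and may need adjusting to your local copy.

```python
import os

# Expected training-side layout (from the tree above); adjust if your copy differs.
EXPECTED = {
    "task_01": ["pairs_val.csv", "NIFTI"],
    "task_02": ["pairs_val.csv", "training"],
    "task_03": ["pairs_val.csv", "Training"],
    "task_04": ["pairs_val.csv", "Training"],
}

def check_layout(root):
    """Print which expected files/folders exist under the dataset root."""
    for task, entries in EXPECTED.items():
        for entry in entries:
            path = os.path.join(root, task, entry)
            status = "ok" if os.path.exists(path) else "MISSING"
            print(f"[{status}] {path}")

if __name__ == "__main__":
    check_layout("/data/Learn2Reg")  # same root as passed to -ds below
```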
To train the model using images (and segmentations) from Tasks 2, 3 and 4, run:
python train_model.py -ds {path to dataset root} -gpus {gpu numbers}
For example:
python train_model.py -ds /data/Learn2Reg/ -gpus 0,1,2
To fine-tune the model on a specific task, run:
python train_tf_task{TASK #}.py -ds {path to dataset root} -gpus {gpu numbers}
For example:
python train_tf_task2.py -ds /data/Learn2Reg/ -gpus 0,1,2
or feel free to modify the scripts and create your own training procedure.
Create submission
TBD