Registration paper readme (#551)
* readme and latent space comparison for registration models
lucidtronix authored Jan 17, 2024
1 parent 0064b83 commit e24768f
Showing 10 changed files with 672 additions and 10 deletions.
3 changes: 3 additions & 0 deletions .gitattributes
@@ -22,3 +22,6 @@ model_zoo/liver_fat_from_mri_ukb/liver_fat_from_echo_teacher_model.png filter=lfs diff=lfs merge=lfs -text
model_zoo/liver_fat_from_mri_ukb/liver_fat_from_ideal_student_model.png filter=lfs diff=lfs merge=lfs -text
model_zoo/ECG_PheWAS/ukb_phewas.png filter=lfs diff=lfs merge=lfs -text
model_zoo/dropfuse/overview.png filter=lfs diff=lfs merge=lfs -text
model_zoo/registration_reveals_genetics/registration.png filter=lfs diff=lfs merge=lfs -text
model_zoo/registration_reveals_genetics/table1.png filter=lfs diff=lfs merge=lfs -text
model_zoo/registration_reveals_genetics/table2.png filter=lfs diff=lfs merge=lfs -text
20 changes: 17 additions & 3 deletions ml4h/data_descriptions.py
@@ -381,15 +381,29 @@ def dataframe_data_description_from_tensor_map(
dataframe: pd.DataFrame,
is_input: bool = False,
) -> DataDescription:
if tensor_map.is_survival_curve():
if tensor_map.name == 'survival_curve_af':
event_age = 'af_age'
event_column = 'survival_curve_af'
else:
event_age = f'{tensor_map.name.replace("_event", "_age")}'
event_column = tensor_map.name
return SurvivalWideFile(
wide_df=dataframe,
name=tensor_map.output_name(),
intervals=tensor_map.shape[0]//2,
event_age=event_age,
event_column=event_column,
)
if tensor_map.is_categorical():
process_col = one_hot_sex
else:
process_col = make_zscore(dataframe[tensor_map.name].mean(), dataframe[tensor_map.name].std())
return DataFrameDataDescription(
dataframe,
col = tensor_map.name,
process_col = process_col,
name = tensor_map.input_name() if is_input else tensor_map.output_name(),
col=tensor_map.name,
process_col=process_col,
name=tensor_map.input_name() if is_input else tensor_map.output_name(),
)
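
For context, the survival branch above expects a "wide" dataframe with one row per patient: an event-indicator column plus a matching event-age column. Below is a minimal sketch of that layout with invented values (only the column names come from the code above):

```python
# Minimal sketch of the wide dataframe assumed by the survival branch above.
# 'survival_curve_af' mirrors the 'af_event' indicator; 'af_age' is the age at
# event or censoring. All values below are invented for illustration.
import pandas as pd

wide_df = pd.DataFrame({
    'MRN': [1, 2, 3],
    'survival_curve_af': [0, 1, 0],  # 0 = censored, 1 = AF event observed
    'af_age': [63.0, 58.5, 71.2],    # age in years at event or censoring
}).set_index('MRN')

# For a SURVIVAL_CURVE TensorMap with shape (2 * n_intervals,), the code
# recovers the interval count as intervals = tensor_map.shape[0] // 2.
```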


2 changes: 1 addition & 1 deletion ml4h/metrics.py
@@ -10,7 +10,7 @@
from tensorflow.keras.losses import binary_crossentropy, categorical_crossentropy, sparse_categorical_crossentropy
from tensorflow.keras.losses import logcosh, cosine_similarity, mean_squared_error, mean_absolute_error, mean_absolute_percentage_error

from neurite.tf.losses import Dice
#from neurite.tf.losses import Dice

STRING_METRICS = [
'categorical_crossentropy','binary_crossentropy','mean_absolute_error','mae',
158 changes: 152 additions & 6 deletions ml4h/recipes.py
@@ -25,19 +25,18 @@
from ml4h.tensorize.tensor_writer_mgb import write_tensors_mgb
from ml4h.models.model_factory import make_multimodal_multitask_model
from ml4h.ml4ht_integration.tensor_generator import TensorMapDataLoader2
from ml4h.explorations import test_labels_to_label_map, infer_with_pixels
from ml4h.tensor_generators import BATCH_INPUT_INDEX, BATCH_OUTPUT_INDEX, BATCH_PATHS_INDEX
from ml4h.explorations import test_labels_to_label_map, infer_with_pixels, latent_space_dataframe, infer_stats_from_segmented_regions
from ml4h.explorations import mri_dates, ecg_dates, predictions_to_pngs, sample_from_language_model
from ml4h.plots import plot_reconstruction, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp
from ml4h.plots import plot_roc, plot_precision_recall_per_class, plot_scatter
from ml4h.explorations import plot_while_learning, plot_histograms_of_tensors_in_pdf, explore, pca_on_tsv
from ml4h.models.legacy_models import get_model_inputs_outputs, make_shallow_model, make_hidden_layer_model
from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators, big_batch_from_minibatch_generator
from ml4h.data_descriptions import dataframe_data_description_from_tensor_map, ECGDataDescription, DataFrameDataDescription
from ml4h.metrics import get_roc_aucs, get_precision_recall_aucs, get_pearson_coefficients, log_aucs, log_pearson_coefficients
from ml4h.metrics import get_roc_aucs, get_precision_recall_aucs, get_pearson_coefficients, log_aucs, log_pearson_coefficients, concordance_index_censored
from ml4h.plots import plot_dice, plot_reconstruction, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp
from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival, plot_dice
from ml4h.plots import plot_reconstruction, plot_hit_to_miss_transforms, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp

from ml4h.plots import subplot_rocs, subplot_comparison_rocs, subplot_scatters, subplot_comparison_scatters, plot_prediction_calibrations
from ml4h.models.legacy_models import make_character_model_plus, embed_model_predict, make_siamese_model, legacy_multimodal_multitask_model
from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival
@@ -71,6 +70,8 @@ def run(args):
train_legacy(args)
elif 'train_xdl' == args.mode:
train_xdl(args)
elif 'train_xdl_af' == args.mode:
train_xdl_af(args)
elif 'test' == args.mode:
test_multimodal_multitask(args)
elif 'compare' == args.mode:
@@ -141,12 +142,13 @@ def run(args):
logging.exception(e)

if args.gcs_cloud_bucket is not None:
save_to_google_cloud(args)
save_to_google_cloud(args)

end_time = timer()
elapsed_time = end_time - start_time
logging.info("Executed the '{}' operation in {:.2f} seconds".format(args.mode, elapsed_time))


def save_to_google_cloud(args):

"""
@@ -378,6 +380,150 @@ def option_picker(sample_id, data_descriptions):
merger.save(f'{args.output_folder}{args.id}/merger.h5')


def train_xdl_af(args):
mrn_df = pd.read_csv(args.app_csv)
mrn_df = mrn_df.dropna(subset=['last_encounter'])
mrn_df.MRN = mrn_df.MRN.astype(int)
mrn_df['survival_curve_af'] = mrn_df.af_event
#mrn_df['start_date'] = mrn_df.start_fu_datetime

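    # Convert the follow-up start age (stored as a timedelta-like string) to integer days.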
if 'start_fu_age' in mrn_df:
mrn_df['age_in_days'] = pd.to_timedelta(mrn_df.start_fu_age).dt.days
elif 'start_fu' in mrn_df:
mrn_df['age_in_days'] = pd.to_timedelta(mrn_df.start_fu).dt.days
for ot in args.tensor_maps_out:
mrn_df = mrn_df[mrn_df[ot.name].notna()]

mrn_df = mrn_df.set_index('MRN')

output_dds = [dataframe_data_description_from_tensor_map(tmap, mrn_df) for tmap in args.tensor_maps_out]

# ecg_dd = ECGDataDescription(
# args.tensors,
# name=f'input_ecg_strip_I_continuous',
# ecg_len=5000, # all ECGs will be linearly interpolated to be this length
# transforms=[standardize_by_sample_ecg], # these will be applied in order
# leads={'I': 0},
# )
ecg_dd = ECGDataDescription(
args.tensors,
name=args.tensor_maps_in[0].input_name(),
ecg_len=5000, # all ECGs will be linearly interpolated to be this length
transforms=[standardize_by_sample_ecg], # these will be applied in order
# data will be automatically localized from s3
)

def option_picker(sample_id, data_descriptions):
ecg_dts = ecg_dd.get_loading_options(sample_id)
start_dt = output_dds[0].get_loading_options(sample_id)[0]['start_date']
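        # Only ECGs recorded within 1095 days (3 years) before the start of follow-up are eligible.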
min_ecg_dt = start_dt - pd.to_timedelta("1095d")
dates = []
for dt in ecg_dts:
if min_ecg_dt <= dt[DATE_OPTION_KEY] <= start_dt:
dates.append(dt)
if len(dates) == 0:
raise ValueError('No matching dates')
chosen_dt = np.random.choice(dates)
chosen_dt['start_date'] = start_dt
chosen_dt['day_delta'] = (start_dt - chosen_dt[DATE_OPTION_KEY]).days
return {dd: chosen_dt for dd in data_descriptions}

logging.info(f'output_dds[0].name {output_dds[0].name}')
logging.info(f'output_dds[0].get_loading_options(sample_id)[0] {output_dds[0].get_loading_options(3773)[0]}')
logging.info(f'option_picker {option_picker(3773, [ecg_dd])}')

sg = DataDescriptionSampleGetter(
input_data_descriptions=[ecg_dd], # what we want a model to use as input data
output_data_descriptions=output_dds, # what we want a model to predict from the input data
option_picker=option_picker,
)

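    # Smoke test: fetch one sample (MRN 3773) to verify the input and output pipelines before training.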
b = sg(3773)
logging.info(f'batch output: {b[1]}')

model, encoders, decoders, merger = make_multimodal_multitask_model(**args.__dict__)

train_ids = list(mrn_df[mrn_df.split == 'train'].index)
valid_ids = list(mrn_df[mrn_df.split == 'valid'].index)
test_ids = list(mrn_df[mrn_df.split == 'test'].index)

train_dataset = SampleGetterIterableDataset(sample_ids=list(train_ids), sample_getter=sg,
get_epoch=shuffle_get_epoch)
valid_dataset = SampleGetterIterableDataset(sample_ids=list(valid_ids), sample_getter=sg,
get_epoch=shuffle_get_epoch)

num_train_workers = int(args.training_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0)
num_valid_workers = int(args.validation_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0)
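    # Split DataLoader workers between train and validation in proportion to their step counts.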

generate_train = TensorMapDataLoader2(
batch_size=args.batch_size, input_maps=args.tensor_maps_in, output_maps=args.tensor_maps_out,
dataset=train_dataset,
num_workers=num_train_workers,
)
generate_valid = TensorMapDataLoader2(
batch_size=args.batch_size, input_maps=args.tensor_maps_in, output_maps=args.tensor_maps_out,
dataset=valid_dataset,
num_workers=num_valid_workers,
)

model = train_model_from_generators(
model, generate_train, generate_valid, args.training_steps, args.validation_steps, args.batch_size, args.epochs,
args.patience, args.output_folder, args.id, args.inspect_model, args.inspect_show_labels, args.tensor_maps_out,
save_last_model=args.save_last_model,
)
for tm in encoders:
encoders[tm].save(f'{args.output_folder}{args.id}/encoder_{tm.name}.h5')
for tm in decoders:
decoders[tm].save(f'{args.output_folder}{args.id}/decoder_{tm.name}.h5')
if merger:
merger.save(f'{args.output_folder}{args.id}/merger.h5')

test_sg = DataDescriptionSampleGetter(
input_data_descriptions=[ecg_dd], # what we want a model to use as input data
output_data_descriptions=output_dds, # what we want a model to predict from the input data
option_picker=option_picker,
)
test_dataset = SampleGetterIterableDataset(sample_ids=list(test_ids), sample_getter=test_sg,
get_epoch=shuffle_get_epoch)

generate_test = TensorMapDataLoader2(
batch_size=args.batch_size, input_maps=args.tensor_maps_in, output_maps=args.tensor_maps_out,
dataset=test_dataset,
num_workers=num_train_workers,
)

y_trues = defaultdict(list)
y_preds = defaultdict(list)
logging.info(f'Now start testing... output keys: {list(next(generate_test)[1].keys())}')
for X, y in generate_test:
preds = model.predict(X, verbose=0)
if len(model.output_names) == 1:
preds = [preds]
predictions_dict = {name: pred for name, pred in zip(model.output_names, preds)}
for otm in args.tensor_maps_out:
y_preds[otm.name].extend(predictions_dict[otm.output_name()])
y_trues[otm.name].extend(y[otm.output_name()])

if len(y_trues[otm.name]) % 400 == 0:
print(f'Predicted on {len(y_trues[otm.name])} test examples.')
if len(y_trues[otm.name]) > args.test_steps*args.batch_size:
break

for otm in args.tensor_maps_out:
y_preds[otm.name] = np.array(y_preds[otm.name])
y_trues[otm.name] = np.array(y_trues[otm.name])

for otm in args.tensor_maps_out:
if otm.is_survival_curve():
plot_survival(y_preds[otm.name], y_trues[otm.name], f'{otm.name.upper()} Model:{args.id}', otm.days_window)
elif otm.is_categorical():
plot_roc(y_preds[otm.name], y_trues[otm.name], otm.channel_map, f'{otm.name} ROC')
plot_precision_recall_per_class(y_preds[otm.name], y_trues[otm.name], otm.channel_map,
f'{otm.name} Precision Recall')
elif otm.is_continuous():
plot_scatter(y_preds[otm.name], y_trues[otm.name], f'{otm.name} Scatter')


def datetime_to_float(d):
return pd.to_datetime(d, utc=True).timestamp()

5 changes: 5 additions & 0 deletions ml4h/tensormap/mgb/xdl.py
@@ -5,6 +5,7 @@
from ml4h.TensorMap import TensorMap, Interpretation

ecg_5000_std = TensorMap('ecg_5000_std', Interpretation.CONTINUOUS, shape=(5000, 12))
ecg_single_lead_I = TensorMap(f'ecg_strip_I', Interpretation.CONTINUOUS, shape=(5000, 1))

hypertension_icd_only = TensorMap(name='hypertension_icd_only', interpretation=Interpretation.CATEGORICAL,
channel_map={'no_hypertension_icd_only': 0, 'hypertension_icd_only': 1})
@@ -25,6 +26,10 @@
hypercholesterolemia = TensorMap(name='hypercholesterolemia', interpretation=Interpretation.CATEGORICAL,
channel_map={'no_hypercholesterolemia': 0, 'hypercholesterolemia': 1})

n_intervals = 25
af_tmap = TensorMap('survival_curve_af', Interpretation.SURVIVAL_CURVE, shape=(n_intervals*2,),)
death_tmap = TensorMap('death_event', Interpretation.SURVIVAL_CURVE, shape=(n_intervals*2,),)
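# Assumed layout: each SURVIVAL_CURVE tensor packs 2 * n_intervals values, with the
# first half tracking survival over the follow-up intervals and the second half marking events.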


def ecg_median_biosppy(tm: TensorMap, hd5: h5py.File, dependents: Dict = {}) -> np.ndarray:
tensor = np.zeros(tm.shape, dtype=np.float32)
57 changes: 57 additions & 0 deletions model_zoo/registration_reveals_genetics/README.md
@@ -0,0 +1,57 @@
This folder contains the code and notebooks used in our paper: ["Genetic Architectures of Medical Images Revealed by Registration of Multiple Modalities"](https://www.biorxiv.org/content/10.1101/2023.07.27.550885v1)

In this paper we show the systematic importance of registration for finding genetic signals directly from medical imaging modalities.
This is demonstrated across a wide array of registration techniques.
Our multimodal autoencoder comparison framework allows us to learn representations of medical images before and after registration.
The learned registration methods considered are graphically summarized here:
![Learned Registration Methods](./registration.png)

For example, to train a uni-modal autoencoder for DXA 2 scans:
```bash
python /path/to/ml4h/ml4h/recipes.py \
--mode train \
--tensors /path/to/hd5_tensors/ \
--output_folder /path/to/output/ \
--tensormap_prefix ml4h.tensormap.ukb \
--input_tensors dxa.dxa_2 --output_tensors dxa.dxa_2 \
--encoder_blocks conv_encode --merge_blocks --decoder_blocks conv_decode \
--activation swish --conv_layers 32 --conv_width 31 --dense_blocks 32 32 32 32 32 --dense_layers 256 --block_size 3 \
--inspect_model --learning_rate 0.0001 \
--batch_size 4 --epochs 216 --training_steps 128 --validation_steps 36 --test_steps 4 --patience 36 \
--id dxa_2_autoencoder_256d
```

To train the cross-modal (DXA 2 <-> DXA 5) registration with the DropFuse model, the command line is:
```bash
python /path/to/ml4h/ml4h/recipes.py \
--mode train \
--tensors /path/to/hd5_tensors/ \
--output_folder /path/to/output/ \
--tensormap_prefix ml4h.tensormap.ukb \
--input_tensors dxa.dxa_2 dxa.dxa_5 --output_tensors dxa.dxa_2 dxa.dxa_5 \
--encoder_blocks conv_encode --merge_blocks pair --decoder_blocks conv_decode \
--pairs dxa.dxa_2 dxa.dxa_5 --pair_loss contrastive --pair_loss_weight 0.1 --pair_merge dropout \
--activation swish --conv_layers 32 --conv_width 31 --dense_blocks 32 32 32 32 32 --dense_layers 256 --block_size 3 \
--inspect_model --learning_rate 0.0001 \
--batch_size 4 --epochs 216 --training_steps 128 --validation_steps 36 --test_steps 4 --patience 36 \
--id dxa_2_5_dropfuse_256d
```
Similarly, autoencoders and cross-modal fusion models for all the modalities considered in the paper can be trained by changing the `--input_tensors` and `--output_tensors` arguments to point at the appropriate `TensorMap`, and, if necessary, updating the model architecture hyperparameters.
Table 1 lists all the modalities included in the paper.
![Table of modalities](./table1.png)

Then, by running latent space inference with models trained before and after registration, we can evaluate their learned representations.
```bash
python /path/to/ml4h/ml4h/recipes.py \
--mode infer_encoders \
--tensors /path/to/hd5_tensors/ \
--output_folder /path/to/output/ \
--tensormap_prefix ml4h.tensormap.ukb \
--input_tensors dxa.dxa_2 --output_tensors dxa.dxa_2 \
--id dxa_2_autoencoder_256d \
--model_file /path/to/output/dxa_2_autoencoder_256d/dxa_2_autoencoder_256d.h5
```

We compare the strength and number of biological signals found in each latent space using the [Latent Space Comparisons notebook](./latent_space_comparisons.ipynb).
This notebook is used to populate the data summarized in Table 2 of the paper.
![Table of results](./table2.png)
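
As a rough sketch of the kind of comparison the notebook performs, the latent space TSV written by `infer_encoders` can be joined to a phenotype table with `latent_space_dataframe` (imported from `ml4h.explorations` in `recipes.py` above). The file paths, the `latent_` column prefix, and the `bmi` phenotype below are hypothetical placeholders:

```python
# Hedged sketch, not the notebook itself: join inferred latent vectors with a
# phenotype table, then rank latent dimensions by phenotype correlation.
# Paths, the 'latent_' prefix, and the 'bmi' column are assumptions.
from ml4h.explorations import latent_space_dataframe

df = latent_space_dataframe(
    'hidden_inference_dxa_2_autoencoder_256d.tsv',  # written by --mode infer_encoders
    'ukb_phenotypes.csv',                           # phenotypes keyed by sample id
)
latent_cols = [c for c in df.columns if c.startswith('latent_')]
strongest = sorted(latent_cols, key=lambda c: -abs(df[c].corr(df['bmi'])))
print(strongest[:5])  # latent dimensions most correlated with the phenotype
```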