From e24768f684f980614a7e67960955f19089dedaea Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 17 Jan 2024 12:41:35 -0500 Subject: [PATCH] Registration paper readme (#551) * readme and latent space comparison for registration models --- .gitattributes | 3 + ml4h/data_descriptions.py | 20 +- ml4h/metrics.py | 2 +- ml4h/recipes.py | 158 ++++++- ml4h/tensormap/mgb/xdl.py | 5 + .../registration_reveals_genetics/README.md | 57 +++ .../latent_space_comparisons.ipynb | 428 ++++++++++++++++++ .../registration.png | 3 + .../registration_reveals_genetics/table1.png | 3 + .../registration_reveals_genetics/table2.png | 3 + 10 files changed, 672 insertions(+), 10 deletions(-) create mode 100644 model_zoo/registration_reveals_genetics/README.md create mode 100644 model_zoo/registration_reveals_genetics/latent_space_comparisons.ipynb create mode 100644 model_zoo/registration_reveals_genetics/registration.png create mode 100644 model_zoo/registration_reveals_genetics/table1.png create mode 100644 model_zoo/registration_reveals_genetics/table2.png diff --git a/.gitattributes b/.gitattributes index 030ae10a4..e056953b3 100644 --- a/.gitattributes +++ b/.gitattributes @@ -22,3 +22,6 @@ model_zoo/liver_fat_from_mri_ukb/liver_fat_from_echo_teacher_model.png filter=lf model_zoo/liver_fat_from_mri_ukb/liver_fat_from_ideal_student_model.png filter=lfs diff=lfs merge=lfs -text model_zoo/ECG_PheWAS/ukb_phewas.png filter=lfs diff=lfs merge=lfs -text model_zoo/dropfuse/overview.png filter=lfs diff=lfs merge=lfs -text +model_zoo/registration_reveals_genetics/registration.png filter=lfs diff=lfs merge=lfs -text +model_zoo/registration_reveals_genetics/table1.png filter=lfs diff=lfs merge=lfs -text +model_zoo/registration_reveals_genetics/table2.png filter=lfs diff=lfs merge=lfs -text diff --git a/ml4h/data_descriptions.py b/ml4h/data_descriptions.py index d9b41e484..57b1c42f9 100644 --- a/ml4h/data_descriptions.py +++ b/ml4h/data_descriptions.py @@ -381,15 +381,29 @@ def dataframe_data_description_from_tensor_map( dataframe: pd.DataFrame, is_input: bool = False, ) -> DataDescription: + if tensor_map.is_survival_curve(): + if tensor_map.name == 'survival_curve_af': + event_age = 'af_age' + event_column = 'survival_curve_af' + else: + event_age = f'{tensor_map.name.replace("_event", "_age")}' + event_column = tensor_map.name + return SurvivalWideFile( + wide_df=dataframe, + name=tensor_map.output_name(), + intervals=tensor_map.shape[0]//2, + event_age=event_age, + event_column=event_column, + ) if tensor_map.is_categorical(): process_col = one_hot_sex else: process_col = make_zscore(dataframe[tensor_map.name].mean(), dataframe[tensor_map.name].std()) return DataFrameDataDescription( dataframe, - col = tensor_map.name, - process_col = process_col, - name = tensor_map.input_name() if is_input else tensor_map.output_name(), + col=tensor_map.name, + process_col=process_col, + name=tensor_map.input_name() if is_input else tensor_map.output_name(), ) diff --git a/ml4h/metrics.py b/ml4h/metrics.py index d4444f174..c5cda5aef 100755 --- a/ml4h/metrics.py +++ b/ml4h/metrics.py @@ -10,7 +10,7 @@ from tensorflow.keras.losses import binary_crossentropy, categorical_crossentropy, sparse_categorical_crossentropy from tensorflow.keras.losses import logcosh, cosine_similarity, mean_squared_error, mean_absolute_error, mean_absolute_percentage_error -from neurite.tf.losses import Dice +#from neurite.tf.losses import Dice STRING_METRICS = [ 'categorical_crossentropy','binary_crossentropy','mean_absolute_error','mae', diff --git a/ml4h/recipes.py b/ml4h/recipes.py index fc567b846..13473e195 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -25,19 +25,18 @@ from ml4h.tensorize.tensor_writer_mgb import write_tensors_mgb from ml4h.models.model_factory import make_multimodal_multitask_model from ml4h.ml4ht_integration.tensor_generator import TensorMapDataLoader2 -from ml4h.explorations import test_labels_to_label_map, infer_with_pixels from ml4h.tensor_generators import BATCH_INPUT_INDEX, BATCH_OUTPUT_INDEX, BATCH_PATHS_INDEX from ml4h.explorations import test_labels_to_label_map, infer_with_pixels, latent_space_dataframe, infer_stats_from_segmented_regions from ml4h.explorations import mri_dates, ecg_dates, predictions_to_pngs, sample_from_language_model -from ml4h.plots import plot_reconstruction, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp +from ml4h.plots import plot_roc, plot_precision_recall_per_class, plot_scatter from ml4h.explorations import plot_while_learning, plot_histograms_of_tensors_in_pdf, explore, pca_on_tsv from ml4h.models.legacy_models import get_model_inputs_outputs, make_shallow_model, make_hidden_layer_model from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators, big_batch_from_minibatch_generator from ml4h.data_descriptions import dataframe_data_description_from_tensor_map, ECGDataDescription, DataFrameDataDescription -from ml4h.metrics import get_roc_aucs, get_precision_recall_aucs, get_pearson_coefficients, log_aucs, log_pearson_coefficients +from ml4h.metrics import get_roc_aucs, get_precision_recall_aucs, get_pearson_coefficients, log_aucs, log_pearson_coefficients, concordance_index_censored +from ml4h.plots import plot_dice, plot_reconstruction, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival, plot_dice from ml4h.plots import plot_reconstruction, plot_hit_to_miss_transforms, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp - from ml4h.plots import subplot_rocs, subplot_comparison_rocs, subplot_scatters, subplot_comparison_scatters, plot_prediction_calibrations from ml4h.models.legacy_models import make_character_model_plus, embed_model_predict, make_siamese_model, legacy_multimodal_multitask_model from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival @@ -71,6 +70,8 @@ def run(args): train_legacy(args) elif 'train_xdl' == args.mode: train_xdl(args) + elif 'train_xdl_af' == args.mode: + train_xdl_af(args) elif 'test' == args.mode: test_multimodal_multitask(args) elif 'compare' == args.mode: @@ -141,12 +142,13 @@ def run(args): logging.exception(e) if args.gcs_cloud_bucket is not None: - save_to_google_cloud(args) - + save_to_google_cloud(args) + end_time = timer() elapsed_time = end_time - start_time logging.info("Executed the '{}' operation in {:.2f} seconds".format(args.mode, elapsed_time)) + def save_to_google_cloud(args): """ @@ -378,6 +380,150 @@ def option_picker(sample_id, data_descriptions): merger.save(f'{args.output_folder}{args.id}/merger.h5') +def train_xdl_af(args): + mrn_df = pd.read_csv(args.app_csv) + mrn_df = mrn_df.dropna(subset=['last_encounter']) + mrn_df.MRN = mrn_df.MRN.astype(int) + mrn_df['survival_curve_af'] = mrn_df.af_event + #mrn_df['start_date'] = mrn_df.start_fu_datetime + + if 'start_fu_age' in mrn_df: + mrn_df['age_in_days'] = pd.to_timedelta(mrn_df.start_fu_age).dt.days + elif 'start_fu' in mrn_df: + mrn_df['age_in_days'] = pd.to_timedelta(mrn_df.start_fu).dt.days + for ot in args.tensor_maps_out: + mrn_df = mrn_df[mrn_df[ot.name].notna()] + + mrn_df = mrn_df.set_index('MRN') + + output_dds = [dataframe_data_description_from_tensor_map(tmap, mrn_df) for tmap in args.tensor_maps_out] + + # ecg_dd = ECGDataDescription( + # args.tensors, + # name=f'input_ecg_strip_I_continuous', + # ecg_len=5000, # all ECGs will be linearly interpolated to be this length + # transforms=[standardize_by_sample_ecg], # these will be applied in order + # leads={'I': 0}, + # ) + ecg_dd = ECGDataDescription( + args.tensors, + name=args.tensor_maps_in[0].input_name(), + ecg_len=5000, # all ECGs will be linearly interpolated to be this length + transforms=[standardize_by_sample_ecg], # these will be applied in order + # data will be automatically localized from s3 + ) + + def option_picker(sample_id, data_descriptions): + ecg_dts = ecg_dd.get_loading_options(sample_id) + start_dt = output_dds[0].get_loading_options(sample_id)[0]['start_date'] + min_ecg_dt = start_dt - pd.to_timedelta("1095d") + dates = [] + for dt in ecg_dts: + if min_ecg_dt <= dt[DATE_OPTION_KEY] <= start_dt: + dates.append(dt) + if len(dates) == 0: + raise ValueError('No matching dates') + chosen_dt = np.random.choice(dates) + chosen_dt['start_date'] = start_dt + chosen_dt['day_delta'] = (start_dt - chosen_dt[DATE_OPTION_KEY]).days + return {dd: chosen_dt for dd in data_descriptions} + + logging.info(f'output_dds[0].name {output_dds[0].name}') + logging.info(f'output_dds[0].get_loading_options(sample_id)[0] {output_dds[0].get_loading_options(3773)[0]}') + logging.info(f'option_picker {option_picker(3773, [ecg_dd])}') + + sg = DataDescriptionSampleGetter( + input_data_descriptions=[ecg_dd], # what we want a model to use as input data + output_data_descriptions=output_dds, # what we want a model to predict from the input data + option_picker=option_picker, + ) + + b = sg(3773) + logging.info(f'batch output: {b[1]}') + + model, encoders, decoders, merger = make_multimodal_multitask_model(**args.__dict__) + + train_ids = list(mrn_df[mrn_df.split == 'train'].index) + valid_ids = list(mrn_df[mrn_df.split == 'valid'].index) + test_ids = list(mrn_df[mrn_df.split == 'test'].index) + + train_dataset = SampleGetterIterableDataset(sample_ids=list(train_ids), sample_getter=sg, + get_epoch=shuffle_get_epoch) + valid_dataset = SampleGetterIterableDataset(sample_ids=list(valid_ids), sample_getter=sg, + get_epoch=shuffle_get_epoch) + + num_train_workers = int(args.training_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0) + num_valid_workers = int(args.validation_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0) + + generate_train = TensorMapDataLoader2( + batch_size=args.batch_size, input_maps=args.tensor_maps_in, output_maps=args.tensor_maps_out, + dataset=train_dataset, + num_workers=num_train_workers, + ) + generate_valid = TensorMapDataLoader2( + batch_size=args.batch_size, input_maps=args.tensor_maps_in, output_maps=args.tensor_maps_out, + dataset=valid_dataset, + num_workers=num_valid_workers, + ) + + model = train_model_from_generators( + model, generate_train, generate_valid, args.training_steps, args.validation_steps, args.batch_size, args.epochs, + args.patience, args.output_folder, args.id, args.inspect_model, args.inspect_show_labels, args.tensor_maps_out, + save_last_model=args.save_last_model, + ) + for tm in encoders: + encoders[tm].save(f'{args.output_folder}{args.id}/encoder_{tm.name}.h5') + for tm in decoders: + decoders[tm].save(f'{args.output_folder}{args.id}/decoder_{tm.name}.h5') + if merger: + merger.save(f'{args.output_folder}{args.id}/merger.h5') + + test_sg = DataDescriptionSampleGetter( + input_data_descriptions=[ecg_dd], # what we want a model to use as input data + output_data_descriptions=output_dds, # what we want a model to predict from the input data + option_picker=option_picker, + ) + test_dataset = SampleGetterIterableDataset(sample_ids=list(test_ids), sample_getter=test_sg, + get_epoch=shuffle_get_epoch) + + generate_test = TensorMapDataLoader2( + batch_size=args.batch_size, input_maps=args.tensor_maps_in, output_maps=args.tensor_maps_out, + dataset=test_dataset, + num_workers=num_train_workers, + ) + + y_trues = defaultdict(list) + y_preds = defaultdict(list) + logging.info(f'Now start testing... output keys: {list(next(generate_test)[1].keys())}') + for X, y in generate_test: + preds = model.predict(X, verbose=0) + if len(model.output_names) == 1: + preds = [preds] + predictions_dict = {name: pred for name, pred in zip(model.output_names, preds)} + for otm in args.tensor_maps_out: + y_preds[otm.name].extend(predictions_dict[otm.output_name()]) + y_trues[otm.name].extend(y[otm.output_name()]) + + if len(y_trues[otm.name]) % 400 == 0: + print(f'Predicted on {len(y_trues[otm.name])} test examples.') + if len(y_trues[otm.name]) > args.test_steps*args.batch_size: + break + + for otm in args.tensor_maps_out: + y_preds[otm.name] = np.array(y_preds[otm.name]) + y_trues[otm.name] = np.array(y_trues[otm.name]) + + for otm in args.tensor_maps_out: + if otm.is_survival_curve(): + plot_survival(y_preds[otm.name], y_trues[otm.name], f'{otm.name.upper()} Model:{args.id}', otm.days_window) + elif otm.is_categorical(): + plot_roc(y_preds[otm.name], y_trues[otm.name], otm.channel_map, f'{otm.name} ROC') + plot_precision_recall_per_class(y_preds[otm.name], y_trues[otm.name], otm.channel_map, + f'{otm.name} Precision Recall') + elif otm.is_continuous(): + plot_scatter(y_preds[otm.name], y_trues[otm.name], f'{otm.name} Scatter') + + def datetime_to_float(d): return pd.to_datetime(d, utc=True).timestamp() diff --git a/ml4h/tensormap/mgb/xdl.py b/ml4h/tensormap/mgb/xdl.py index fe54db2cd..cdb808300 100644 --- a/ml4h/tensormap/mgb/xdl.py +++ b/ml4h/tensormap/mgb/xdl.py @@ -5,6 +5,7 @@ from ml4h.TensorMap import TensorMap, Interpretation ecg_5000_std = TensorMap('ecg_5000_std', Interpretation.CONTINUOUS, shape=(5000, 12)) +ecg_single_lead_I = TensorMap(f'ecg_strip_I', Interpretation.CONTINUOUS, shape=(5000, 1)) hypertension_icd_only = TensorMap(name='hypertension_icd_only', interpretation=Interpretation.CATEGORICAL, channel_map={'no_hypertension_icd_only': 0, 'hypertension_icd_only': 1}) @@ -25,6 +26,10 @@ hypercholesterolemia = TensorMap(name='hypercholesterolemia', interpretation=Interpretation.CATEGORICAL, channel_map={'no_hypercholesterolemia': 0, 'hypercholesterolemia': 1}) +n_intervals = 25 +af_tmap = TensorMap('survival_curve_af', Interpretation.SURVIVAL_CURVE, shape=(n_intervals*2,),) +death_tmap = TensorMap('death_event', Interpretation.SURVIVAL_CURVE, shape=(n_intervals*2,),) + def ecg_median_biosppy(tm: TensorMap, hd5: h5py.File, dependents: Dict = {}) -> np.ndarray: tensor = np.zeros(tm.shape, dtype=np.float32) diff --git a/model_zoo/registration_reveals_genetics/README.md b/model_zoo/registration_reveals_genetics/README.md new file mode 100644 index 000000000..f932410e2 --- /dev/null +++ b/model_zoo/registration_reveals_genetics/README.md @@ -0,0 +1,57 @@ +This folder contains the code and notebooks used in our paper: ["Genetic Architectures of Medical Images Revealed by Registration of Multiple Modalities"](https://www.biorxiv.org/content/10.1101/2023.07.27.550885v1) + +In this paper we show how the systematic importance of registration for finding genetic signals directly from medical imaging modalities. +This is demonstrated across a wide array of registration techniques. +Our multimodal autoencoder comparison framework allows us to learn representations of medical images before and after registration. +The learned registration methods considered are graphically summarized here: +![Learned Registration Methods](./registration.png) + +For example, to train a uni-modal autoencoder for DXA 2 scans: +```bash + python /path/to/ml4h/ml4h/recipes.py \ + --mode train \ + --tensors /path/to/hd5_tensors/ \ + --output_folder /path/to/output/ \ + --tensormap_prefix ml4h.tensormap.ukb \ + --input_tensors dxa.dxa_2 --output_tensors dxa.dxa_2 \ + --encoder_blocks conv_encode --merge_blocks --decoder_blocks conv_decode \ + --activation swish --conv_layers 32 --conv_width 31 --dense_blocks 32 32 32 32 32 --dense_layers 256 --block_size 3 \ + --inspect_model --learning_rate 0.0001 \ + --batch_size 4 --epochs 216 --training_steps 128 --validation_steps 36 --test_steps 4 --patience 36 \ + --id dxa_2_autoencoder_256d +``` + +To train the cross-modal (DXA 2 <-> DXA5) registration with the DropFuse model the command line is: +```bash + python /path/to/ml4h/ml4h/recipes.py \ + --mode train \ + --tensors /path/to/hd5_tensors/ \ + --output_folder /path/to/output/ \ + --tensormap_prefix ml4h.tensormap.ukb \ + --input_tensors dxa.dxa_2 dxa.dxa_5 --output_tensors dxa.dxa_2 dxa.dxa_5 \ + --encoder_blocks conv_encode --merge_blocks pair --decoder_blocks conv_decode \ + --pairs dxa.dxa_2 dxa.dxa_5 --pair_loss contrastive --pair_loss_weight 0.1 --pair_merge dropout \ + --activation swish --conv_layers 32 --conv_width 31 --dense_blocks 32 32 32 32 32 --dense_layers 256 --block_size 3 \ + --inspect_model --learning_rate 0.0001 \ + --batch_size 4 --epochs 216 --training_steps 128 --validation_steps 36 --test_steps 4 --patience 36 \ + --id dxa_2_5_dropfuse_256d +``` +Similiarly, autoencoders and cross modal fusion for all the modalities considered in the paper can be trained by changing the `--input_tensors` and `--output_tensors` arguments to point at the appropriate `TensorMap`, and if necessary updating the model architecture hyperparameters. +Table 1 lists all the modalities included in the paper. +![Table of modalities](./table1.png) + +Then with latent space inference with models before and after registration we can evaluate their learned representations. +```bash + python /home/sam/ml4h/ml4h/recipes.py \ + --mode infer_encoders \ + --tensors /path/to/hd5_tensors/ \ + --output_folder /path/to/output/ \ + --tensormap_prefix ml4h.tensormap.ukb \ + --input_tensors dxa.dxa_2 --output_tensors dxa.dxa_2 \ + --id dxa_2_autoencoder_256d \ + --model_file /path/to/output/dxa_2_autoencoder_256d/dxa_2_autoencoder_256d.h5 +``` + +We compare the strength and number of biological signals found with the [Latent Space Comparisons notebook](./latent_space_comparisons.ipynb). +This notebook is used to populate the data summarized in Table 2 of the paper. +![Table of results](./table2.png) \ No newline at end of file diff --git a/model_zoo/registration_reveals_genetics/latent_space_comparisons.ipynb b/model_zoo/registration_reveals_genetics/latent_space_comparisons.ipynb new file mode 100644 index 000000000..b37703724 --- /dev/null +++ b/model_zoo/registration_reveals_genetics/latent_space_comparisons.ipynb @@ -0,0 +1,428 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import time\n", + "from collections import defaultdict, Counter\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from scipy import stats\n", + "from sklearn.linear_model import LogisticRegression, LinearRegression, ElasticNet, Ridge\n", + "from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score\n", + "from sklearn.metrics import brier_score_loss, precision_score, recall_score, f1_score, roc_auc_score\n", + "from sklearn.pipeline import make_pipeline\n", + "from sklearn.preprocessing import StandardScaler\n", + "\n", + "from ml4h.explorations import latent_space_dataframe\n", + "\n", + "# IPython imports\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib import colors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "all_scores = {}\n", + "label_file = '/home/sam/trained_models/explore_phenotypes_3/tensors_all_union.csv'\n", + "labels = pd.read_csv(label_file)\n", + "phenotypes = [\n", + " 'Sex_Male_0_0',\n", + " 'sex',\n", + " 'diabetes_type_2',\n", + " 'hypercholesterolemia',\n", + " 'hypertension',\n", + " 'Atrial_fibrillation',\n", + " 'age', 'bmi', 'bmi_0', 'RRInterval', \n", + " 'LVM', 'LVEDV', 'LVESV', \n", + " 'PC1', 'PC2', 'PC5',\n", + "\n", + "]\n", + "\n", + "col_rename = {f'22009_Genetic-principal-components_0_{i}': f'PC{i}' for i in range(1,41)}\n", + "col_rename['Genetic-sex_Male_0_0'] = 'sex'\n", + "col_rename['21003_Age-when-attended-assessment-centre_2_0'] = 'age'\n", + "col_rename['21001_Body-mass-index-BMI_2_0'] = 'bmi'\n", + "col_rename['21001_Body-mass-index-BMI_0_0'] = 'bmi_0'\n", + "col_rename['2887_Number-of-cigarettes-previously-smoked-daily_0_0'] = 'smoking'\n", + "col_rename['30690_Cholesterol_0_0'] = 'cholesterol'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "latent_files = {\n", + "'Brain T1': '/home/sam/trained_models/brain_t1_slice_80_autoencoder_256d_v2022_06_07/hidden_axial_80_brain_t1_slice_80_autoencoder_256d_v2022_06_07.tsv',\n", + "'Brain MNI': '/home/sam/trained_models/t1_mni_slices_48_80_autoencoder_256d/hidden_axial_68_100_t1_mni_slices_48_80_autoencoder_256d.tsv',\n", + "}\n", + "latent_size = {\n", + "'Brain T1':256,\n", + "'Brain MNI':256,\n", + "}\n", + "\n", + "pairs = [\n", + " ('Brain MNI', 'Brain T1'),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "latent_files = {\n", + "'DXA 12': '/home/sam/trained_models/dxa_12_autoencoder_256d/hidden_dxa_1_12_dxa_12_autoencoder_256d.tsv',\n", + "'DXA 12 Homography': '/home/sam/trained_models/dxa_12_homography_autoencoder_512d/hidden_dxa_1_12_dxa_12_homography_autoencoder_512d.tsv',\n", + "}\n", + "latent_size = {\n", + "'DXA 12':256,\n", + "'DXA 12 Homography':512,\n", + "}\n", + "\n", + "pairs = [\n", + " ('DXA 12 Homography', 'DXA 12'),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "latent_files = {\n", + "'DXA 2 AE': '/home/sam/trained_models/dxa_2_autoencoder_256d/hidden_dxa_1_2_dxa_2_autoencoder_256d.tsv',\n", + "'DXA 2 DF': '/home/sam/trained_models/dxa_2_5_dropfuse_256d/hidden_dxa_1_2_dxa_2_5_dropfuse_256d.tsv',\n", + "'DXA 5 AE': '/home/sam/trained_models/dxa_5_autoencoder_256d/hidden_dxa_1_5_dxa_5_autoencoder_256d.tsv',\n", + "'DXA 5 DF': '/home/sam/trained_models/dxa_2_5_dropfuse_256d/hidden_dxa_1_5_dxa_2_5_dropfuse_256d.tsv',\n", + "}\n", + "latent_size = {\n", + "'DXA 2 AE':256,\n", + "'DXA 2 DF':256,\n", + "'DXA 5 AE':256,\n", + "'DXA 5 DF':256, \n", + "}\n", + "\n", + "pairs = [\n", + " ('DXA 2 DF','DXA 2 AE'),\n", + " ('DXA 5 DF','DXA 5 AE'),\n", + "]\n", + "\n", + "# latent_files = {\n", + "# 'cMRI Dense AE': '/home/sam/csvs/dense_autoencoder_1_sample_cmri_inferences.tsv',\n", + "# 'cMRI Circular AE': '/home/sam/csvs/circular_autoencoder_lax_4ch_v2023_04_28.tsv',\n", + "# 'ECG 10s': '/home/sam/trained_models/ecg_rest_autoencoder_256d_v2023_05_09/hidden_strip_ecg_rest_autoencoder_256d_v2023_05_09.tsv',\n", + "# 'ECG Median': '/home/sam/trained_models/hypertuned_48m_16e_ecg_median_raw_10_autoencoder_256d/hidden_embed_hypertuned_48m_16e_ecg_median_raw_10_autoencoder_256d.tsv',\n", + "# 'cMRI AE':'/home/sam/trained_models/hypertuned_32m_8e_lax_4ch_heart_center_autoencoder_256d/hidden_lax_4ch_heart_center_hypertuned_32m_8e_lax_4ch_heart_center_autoencoder_256d.tsv', \n", + "# 'cMRI ECG DropFuse': '/home/sam/trained_models/dropout_pair_contrastive_lax_4ch_cycle_ecg_median_10_pretrained_256d_v2020_06_07/hidden_lax_4ch_heart_center_dropout_pair_contrastive_lax_4ch_cycle_ecg_median_10_pretrained_256d_v2020_06_07.tsv',\n", + "# 'ECG cMRI DropFuse': '/home/sam/trained_models/dropout_pair_contrastive_lax_4ch_cycle_ecg_median_10_pretrained_256d_v2020_06_07/hidden_ecg_rest_median_raw_10_dropout_pair_contrastive_lax_4ch_cycle_ecg_median_10_pretrained_256d_v2020_06_07.tsv',\n", + " \n", + "# }\n", + "# latent_size = {\n", + "# 'cMRI Dense AE': 50,\n", + "# 'cMRI Circular AE': 50,\n", + "\n", + "# 'ECG 10s': 256,\n", + "# 'ECG Median': 256,\n", + "# 'cMRI AE':256,\n", + "# 'cMRI ECG DropFuse':256,\n", + "# 'ECG cMRI DropFuse':256,\n", + " \n", + "# }\n", + "# pairs = [\n", + "# ('cMRI Circular AE','cMRI Dense AE'),\n", + "# ('ECG Median','ECG 10s'),\n", + "# ('cMRI ECG DropFuse','cMRI AE'),\n", + "# ('ECG cMRI DropFuse','cMRI AE'),\n", + " \n", + "# ]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def fit_logistic(label_header, train, test, indexes, verbose=False):\n", + " if verbose:\n", + " print(f'{label_header} len train {len(train)} len test {len(test)}')\n", + " print(f'\\nTrain:\\n{train[label_header].value_counts()} \\n\\nTest:\\n{test[label_header].value_counts()}')\n", + " clf = LogisticRegression(penalty='elasticnet', solver='saga', class_weight='balanced', l1_ratio=0.5)\n", + " clf.fit(train[indexes], train[label_header])\n", + " \n", + " sparsity = np.mean(clf.coef_ == 0) * 100\n", + " score = clf.score(test[indexes], test[label_header])\n", + " train_score = clf.score(train[indexes], train[label_header])\n", + " auc_score = roc_auc_score(clf.predict(test[indexes]), test[label_header])\n", + " train_auc_score = roc_auc_score(clf.predict(train[indexes]), train[label_header])\n", + " if verbose:\n", + " print(f'{label_header} AUC:{auc_score:.3f} Train AUC:{train_auc_score:.3f}, Sparsity: {sparsity:.2f}\\n')\n", + " return auc_score\n", + "\n", + "def fit_linear(label_header, train, test, indexes, verbose=False):\n", + " if verbose:\n", + " print(f'{label_header} len train {len(train)} len test {len(test)}')\n", + " print(f'\\nTrain:\\n{len(train[label_header].value_counts())} \\n\\nTest:\\n{len(test[label_header].value_counts())}')\n", + " \n", + " clf = make_pipeline(StandardScaler(with_mean=True), Ridge(solver='lsqr', max_iter=250000))\n", + " clf.fit(train[indexes], train[label_header])\n", + "\n", + " score = clf.score(test[indexes], test[label_header])\n", + " train_score = clf.score(train[indexes], train[label_header])\n", + " if verbose:\n", + " print(f'{label_header} R^2:{score:.3f} Train R^2:{train_score:.3f}\\n')\n", + " return score\n", + "\n", + "def latent_space_regression(label_file, latent_file, num_features = 256, start_features=0, train_ratio = 0.6, folds=4, verbose=False):\n", + " labels = pd.read_csv(label_file)\n", + " if latent_file.split('.')[-1].lower() == 'csv':\n", + " indexes = [f'{i}' for i in range(start_features, num_features)]\n", + " latent = pd.read_csv(latent_file)\n", + " else:\n", + " indexes = [f'latent_{i}' for i in range(start_features, num_features)]\n", + " latent = pd.read_csv(latent_file, sep='\\t')\n", + " \n", + " df = pd.merge(labels, latent, left_on='fpath', right_on='sample_id', how='inner')\n", + " df = df.rename(columns=col_rename)\n", + " scores = {}\n", + " errors = {}\n", + " for label in phenotypes if len(phenotypes) else labels.columns:\n", + " try:\n", + " full = df[df[label].notna()]\n", + " if len(full[label].value_counts()) > 2:\n", + " s = []\n", + " for _ in range(folds):\n", + " train = full.sample(frac=train_ratio)\n", + " test = full.drop(train.index)\n", + " s.append(fit_linear(label, train, test, indexes, verbose))\n", + " scores[f'{label} R^2'] = np.mean(s)\n", + " errors[f'{label} R^2'] = 2*np.std(s)\n", + " else:\n", + " s = []\n", + " for _ in range(folds):\n", + " train = full.sample(frac=train_ratio)\n", + " test = full.drop(train.index)\n", + " s.append(fit_logistic(label, train, test, indexes, verbose))\n", + " scores[f'{label} AUC'] = np.mean(s)\n", + " errors[f'{label} AUC'] = 2*np.std(s) \n", + " except Exception as e:\n", + " print(f'Could not fit LR for {label} {e}')\n", + " \n", + " for k,v in sorted(scores.items(), key=lambda x: x[0].lower()):\n", + " print(f'{k} {v:.3f}')\n", + "\n", + " return scores, errors\n", + "\n", + "\n", + "def plot_nested_dictionary(all_scores):\n", + " n = 4\n", + " chack = ['tab:orange', 'tab:blue', 'tab:green']\n", + " for model in all_scores:\n", + " n = max(n, len(all_scores[model][0]))\n", + " cols = max(2, int(math.ceil(math.sqrt(n))))\n", + " rows = max(2, int(math.ceil(n / cols)))\n", + " fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 2.5), dpi=300)\n", + " renest = defaultdict(dict)\n", + " errors = defaultdict(dict)\n", + " for model in all_scores:\n", + " for metric in all_scores[model][0]:\n", + " renest[metric][model] = all_scores[model][0][metric]\n", + " errors[metric][model] = all_scores[model][1][metric]\n", + " for metric, ax in zip(renest, axes.ravel()):\n", + " models = [k for k,v in sorted(renest[metric].items(), key=lambda x: x[0].lower())]\n", + " values = [v for k,v in sorted(renest[metric].items(), key=lambda x: x[0].lower())]\n", + " err = [v for k,v in sorted(errors[metric].items(), key=lambda x: x[0].lower())]\n", + " y_pos = np.arange(len(models))\n", + " #print(f' {len(renest[metric])} len(models) : {len(models)} metric {renest[metric]}')\n", + " ax.barh(y_pos, values, xerr=err, align='center')\n", + " ax.set_yticks(y_pos)\n", + " ax.set_yticklabels(models)\n", + " ax.invert_yaxis() # labels read top-to-bottom\n", + " if 'AUC' in metric:\n", + " ax.set_xlabel('AUROC')\n", + " else:\n", + " ax.set_xlabel('$R^2$')\n", + " \n", + " ax.barh(y_pos, values, xerr=err, align='center', color=colors.TABLEAU_COLORS)\n", + "# if len(metric.split('_')) > 1:\n", + "# metric = metric.split('_')[1] + metric[-4:]\n", + "# ax.set_title(metric.replace('R^2', '$R^2$'))\n", + " if '21001_Body-mass-index-BMI_0_0' in metric:\n", + " ax.set_title('BMI')\n", + " elif '21003_Age-when-attended-assessment-centre_2_0' in metric:\n", + " ax.set_title('Age')\n", + " elif 'Sex_Male_0_0' in metric:\n", + " ax.set_title('Sex')\n", + " else:\n", + " ax.set_title(metric.split(' ')[0]) \n", + " \n", + " plt.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for name in latent_files:\n", + " all_scores[name] = latent_space_regression(label_file, latent_files[name], \n", + " num_features=latent_size[name],\n", + " folds=3, train_ratio=0.8,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def compare_pairs(scores, pairs):\n", + " for (p1,p2) in pairs:\n", + " s1 = scores[p1]\n", + " s2 = scores[p2]\n", + " stats = Counter()\n", + " print(f'comparing {p1} to {p2}' )\n", + " for k in s2[0]:\n", + " print(f'\\t{k} {s1[0][k]:0.3f}, {s2[0][k]:0.3f}, Diff: {(s1[0][k] - s2[0][k]):0.3f} ')\n", + " if 'R^2' in k:\n", + " stats[f'{p1} R^2 sum'] += s1[0][k]\n", + " stats[f'{p2} R^2 sum'] += s2[0][k]\n", + " stats[f'{p1} R^2 std'] += s1[1][k]\n", + " stats[f'{p2} R^2 std'] += s2[1][k]\n", + " stats['R^2 n'] += 1\n", + " elif 'AUC' in k:\n", + " stats[f'{p1} AUC sum'] += s1[0][k]\n", + " stats[f'{p2} AUC sum'] += s2[0][k]\n", + " stats[f'{p1} AUC std'] += s1[1][k]\n", + " stats[f'{p2} AUC std'] += s2[1][k]\n", + " stats['AUC n'] += 1\n", + " auc1 = stats[f'{p1} AUC sum']/stats['AUC n'] \n", + " auc_std1 = stats[f'{p1} AUC std']/(stats['AUC n'] )\n", + " auc2 = stats[f'{p2} AUC sum']/stats['AUC n']\n", + " auc_std2 = stats[f'{p2} AUC std']/(stats['AUC n'] ) \n", + " r21 = stats[f'{p1} R^2 sum']/stats['R^2 n']\n", + " r2_std1 = stats[f'{p1} R^2 std']/(stats['R^2 n'] )\n", + " r22 = stats[f'{p2} R^2 sum']/stats['R^2 n']\n", + " r2_std2 = stats[f'{p2} R^2 std']/(stats['R^2 n'] )\n", + " print(f\"\\n {p1} vs {p2} \")\n", + " print(f\"\\t\\t Mean AUCs {auc1:0.3f} ({auc1-auc_std1:0.3f}, {auc1+auc_std1:0.3f}), {auc2:0.3f} ({auc2-auc_std2:0.3f}, {auc2+auc_std2:0.3f}) \")\n", + " print(f\" \\t\\t Mean R^2 {r21:0.3f} ({r21-r2_std1:0.3f}, {r21+r2_std1:0.3f}), {r22:0.3f} ({r22-r2_std2:0.3f}, {r22+r2_std2:0.3f}) \\n\\n\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "compare_pairs(all_scores, pairs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lf='/home/sam/trained_models/dxa_5_autoencoder_256d/hidden_dxa_1_5_dxa_5_autoencoder_256d.tsv'\n", + "all_scores['DXA 5 AE'] = latent_space_regression(label_file, lf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lf='/home/sam/trained_models/dxa_2_5_dropfuse_256d/hidden_dxa_1_5_dxa_2_5_dropfuse_256d.tsv'\n", + "all_scores['DXA 5 DF'] = latent_space_regression(label_file, lf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lf='/home/sam/trained_models/dxa_11_autoencoder_256d/hidden_dxa_1_11_dxa_11_autoencoder_256d.tsv'\n", + "all_scores['DXA 11 AE'] = latent_space_regression(label_file, lf)\n", + "lf='/home/sam/trained_models/dxa_11_12_dropfuse_256d_v2023_04_17/hidden_dxa_1_11_dxa_11_12_dropfuse_256d_v2023_04_17.tsv'\n", + "all_scores['DXA 11 DF'] = latent_space_regression(label_file, lf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lf=f'/home/sam/trained_models/hypertuned_64m_18e_lax_4ch_heart_center_autoencoder_256d/hidden_embed_hypertuned_64m_18e_lax_4ch_heart_center_autoencoder_256d.tsv'\n", + "all_scores['MRI Autoencoder 256D'] = latent_space_regression(label_file, lf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lf = f'/home/sam/csvs/03-11-2021_simclr_320-320_ukb_embeddings.csv'\n", + "all_scores['ECG PCLR 320D'] = latent_space_regression(label_file, lf, num_features=320)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_nested_dictionary(all_scores)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/model_zoo/registration_reveals_genetics/registration.png b/model_zoo/registration_reveals_genetics/registration.png new file mode 100644 index 000000000..c3dfef211 --- /dev/null +++ b/model_zoo/registration_reveals_genetics/registration.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a61ad1186235a5c24958202a73279d4a9b0f08a13c82b85e92f40d38e4c0119c +size 529604 diff --git a/model_zoo/registration_reveals_genetics/table1.png b/model_zoo/registration_reveals_genetics/table1.png new file mode 100644 index 000000000..5c38c4414 --- /dev/null +++ b/model_zoo/registration_reveals_genetics/table1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa54d40f5143e9efad3ce5e01c53bc950bca9301a7ee6a7c8c0edcdcc3d7ea01 +size 259349 diff --git a/model_zoo/registration_reveals_genetics/table2.png b/model_zoo/registration_reveals_genetics/table2.png new file mode 100644 index 000000000..d15e00f43 --- /dev/null +++ b/model_zoo/registration_reveals_genetics/table2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c0a056860884eafcd450fa22df45ab19b4513fa3e35572354187a9ed1d6b89f1 +size 344842