diff --git a/ml4h/arguments.py b/ml4h/arguments.py index f00c6f9f9..3a8e53516 100644 --- a/ml4h/arguments.py +++ b/ml4h/arguments.py @@ -20,7 +20,7 @@ import importlib import numpy as np import multiprocessing -from typing import Set, Dict, List, Optional +from typing import Set, Dict, List, Optional, Tuple from collections import defaultdict from ml4h.logger import load_config @@ -148,6 +148,7 @@ def parse_args(): parser.add_argument('--dense_normalize', default=None, choices=list(NORMALIZATION_CLASSES), help='Type of normalization layer for dense layers.') parser.add_argument('--activation', default='relu', help='Activation function for hidden units in neural nets dense layers.') parser.add_argument('--conv_layers', nargs='*', default=[32], type=int, help='List of number of kernels in convolutional layers.') + parser.add_argument('--conv_width', default=[71], nargs='*', type=int, help='X dimension of convolutional kernel for 1D models. Filter sizes are specified per layer given by conv_layers and per block given by dense_blocks. Filter sizes are repeated if there are less than the number of layers/blocks.') parser.add_argument('--conv_x', default=[3], nargs='*', type=int, help='X dimension of convolutional kernel. Filter sizes are specified per layer given by conv_layers and per block given by dense_blocks. Filter sizes are repeated if there are less than the number of layers/blocks.') parser.add_argument('--conv_y', default=[3], nargs='*', type=int, help='Y dimension of convolutional kernel. Filter sizes are specified per layer given by conv_layers and per block given by dense_blocks. Filter sizes are repeated if there are less than the number of layers/blocks.') parser.add_argument('--conv_z', default=[2], nargs='*', type=int, help='Z dimension of convolutional kernel. Filter sizes are specified per layer given by conv_layers and per block given by dense_blocks. Filter sizes are repeated if there are less than the number of layers/blocks.') @@ -168,7 +169,13 @@ def parse_args(): '--u_connect', nargs=2, action='append', help='U-Net connect first TensorMap to second TensorMap. They must be the same shape except for number of channels. Can be provided multiple times.', ) - parser.add_argument('--aligned_dimension', default=16, type=int, help='Dimensionality of aligned embedded space for multi-modal alignment models.') + parser.add_argument( + '--pairs', nargs=2, action='append', + help='TensorMap pairs for paired autoencoder. The pair_loss metric will encourage similar embeddings for each two input TensorMap pairs. Can be provided multiple times.', + ) + parser.add_argument('--pair_loss', default='cosine', help='Distance metric between paired embeddings', choices=['euclid', 'cosine']) + parser.add_argument('--pair_loss_weight', type=float, default=1.0, help='Weight on the pair loss term relative to other losses') + parser.add_argument('--multimodal_merge', default='average', choices=['average', 'concatenate'], help='How to merge modality specific encodings.') parser.add_argument( '--max_parameters', default=9000000, type=int, help='Maximum number of trainable parameters in a model during hyperparameter optimization.', @@ -219,6 +226,9 @@ def parse_args(): parser.add_argument('--anneal_rate', default=0., type=float, help='Annealing rate in epochs of loss terms during training') parser.add_argument('--anneal_shift', default=0., type=float, help='Annealing offset in epochs of loss terms during training') parser.add_argument('--anneal_max', default=2.0, type=float, help='Annealing maximum value') + parser.add_argument( + '--save_last_model', default=False, action='store_true', + help='If true saves the model weights from the last training epoch, otherwise the model with best validation loss is saved.') # Run specific and debugging arguments parser.add_argument('--id', default='no_id', help='Identifier for this run, user-defined string to keep experiments organized.') @@ -384,6 +394,14 @@ def _process_u_connect_args(u_connect: Optional[List[List]], tensormap_prefix) - return new_u_connect +def _process_pair_args(pairs: Optional[List[List]], tensormap_prefix) -> List[Tuple[TensorMap, TensorMap]]: + pairs = pairs or [] + new_pairs = [] + for pair in pairs: + new_pairs.append((tensormap_lookup(pair[0], tensormap_prefix), tensormap_lookup(pair[1], tensormap_prefix))) + return new_pairs + + def _process_args(args): now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') args_file = os.path.join(args.output_folder, args.id, 'arguments_' + now_string + '.txt') @@ -396,6 +414,7 @@ def _process_args(args): f.write(k + ' = ' + str(v) + '\n') load_config(args.logging_level, os.path.join(args.output_folder, args.id), 'log_' + now_string, args.min_sample_id) args.u_connect = _process_u_connect_args(args.u_connect, args.tensormap_prefix) + args.pairs = _process_pair_args(args.pairs, args.tensormap_prefix) args.tensor_maps_in = [] args.tensor_maps_out = [] diff --git a/ml4h/explorations.py b/ml4h/explorations.py index dc21246db..4e82968ab 100644 --- a/ml4h/explorations.py +++ b/ml4h/explorations.py @@ -79,6 +79,21 @@ def predictions_to_pngs( ax.add_patch(matplotlib.patches.Rectangle(y_corner, y_width, y_height, linewidth=1, edgecolor='y', facecolor='none')) logging.info(f"True BBox: {corner}, {width}, {height} Predicted BBox: {y_corner}, {y_width}, {y_height} Vmin {vmin} Vmax{vmax}") plt.savefig(f"{folder}{sample_id}_bbox_batch_{i:02d}{IMAGE_EXT}") + elif tm.axes() == 2: + fig = plt.figure(figsize=(SUBPLOT_SIZE, SUBPLOT_SIZE * 3)) + for i in range(y.shape[0]): + sample_id = os.path.basename(paths[i]).replace(TENSOR_EXT, '') + title = f'{tm.name}_{sample_id}_reconstruction' + for j in range(tm.shape[1]): + plt.subplot(tm.shape[1], 1, j + 1) + plt.plot(labels[tm.output_name()][i, :, j], c='k', linestyle='--', label='original') + plt.plot(y[i, :, j], c='b', label='reconstruction') + if j == 0: + plt.title(title) + plt.legend() + plt.tight_layout() + plt.savefig(os.path.join(folder, title + IMAGE_EXT)) + plt.clf() elif len(tm.shape) == 3: for i in range(y.shape[0]): sample_id = os.path.basename(paths[i]).replace(TENSOR_EXT, '') @@ -1332,3 +1347,49 @@ def _get_occurrences(df, order, start, end): fpath = os.path.join(args.output_folder, args.id, 'summary_cohort_counts.csv') pd.DataFrame.from_dict(cohort_counts, orient='index', columns=['count']).rename_axis('description').to_csv(fpath) logging.info(f'Saved cohort counts to {fpath}') + + +def directions_in_latent_space(stratify_column, stratify_thresh, split_column, split_thresh, latent_cols, latent_df): + hit = latent_df.loc[latent_df[stratify_column] >= stratify_thresh][latent_cols].to_numpy() + miss = latent_df.loc[latent_df[stratify_column] < stratify_thresh][latent_cols].to_numpy() + miss_mean_vector = np.mean(miss, axis=0) + hit_mean_vector = np.mean(hit, axis=0) + strat_vector = hit_mean_vector - miss_mean_vector + + hit1 = latent_df.loc[(latent_df[stratify_column] >= stratify_thresh) + & (latent_df[split_column] >= split_thresh)][latent_cols].to_numpy() + miss1 = latent_df.loc[(latent_df[stratify_column] < stratify_thresh) + & (latent_df[split_column] >= split_thresh)][latent_cols].to_numpy() + hit2 = latent_df.loc[(latent_df[stratify_column] >= stratify_thresh) + & (latent_df[split_column] < split_thresh)][latent_cols].to_numpy() + miss2 = latent_df.loc[(latent_df[stratify_column] < stratify_thresh) + & (latent_df[split_column] < split_thresh)][latent_cols].to_numpy() + miss_mean_vector1 = np.mean(miss1, axis=0) + hit_mean_vector1 = np.mean(hit1, axis=0) + angle1 = angle_between(miss_mean_vector1, hit_mean_vector1) + miss_mean_vector2 = np.mean(miss2, axis=0) + hit_mean_vector2 = np.mean(hit2, axis=0) + angle2 = angle_between(miss_mean_vector2, hit_mean_vector2) + h1_vector = hit_mean_vector1 - miss_mean_vector1 + h2_vector = hit_mean_vector2 - miss_mean_vector2 + angle3 = angle_between(h1_vector, h2_vector) + print(f'\n Between {stratify_column}, and splits: {split_column}\n', + f'Angles h1 and m1: {angle1:.2f}, h2 and m2 {angle2:.2f} h1-m1 and h2-m2 {angle3:.2f} degrees.\n' + f'stratify threshold: {stratify_thresh}, split thresh: {split_thresh}, \n' + f'hit_mean_vector2 shape {miss_mean_vector1.shape}, miss1:{hit_mean_vector2.shape} \n' + f'Hit1 shape {hit1.shape}, miss1:{miss1.shape} threshold:{stratify_thresh}\n' + f'Hit2 shape {hit2.shape}, miss2:{miss2.shape}\n') + + return hit_mean_vector1, miss_mean_vector1, hit_mean_vector2, miss_mean_vector2 + + +def latent_space_dataframe(infer_hidden_tsv, explore_csv): + df = pd.read_csv(explore_csv) + df['fpath'] = pd.to_numeric(df['fpath'], errors='coerce') + df2 = pd.read_csv(infer_hidden_tsv, sep='\t') + df2['sample_id'] = pd.to_numeric(df2['sample_id'], errors='coerce') + latent_df = pd.merge(df, df2, left_on='fpath', right_on='sample_id', how='inner') + latent_df.info() + return latent_df + + diff --git a/ml4h/models.py b/ml4h/models.py index 493870ad1..0fde8d284 100755 --- a/ml4h/models.py +++ b/ml4h/models.py @@ -25,7 +25,7 @@ from tensorflow.keras.layers import SpatialDropout1D, SpatialDropout2D, SpatialDropout3D, add, concatenate from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization, Activation, Flatten, LSTM, RepeatVector from tensorflow.keras.layers import Conv1D, Conv2D, Conv3D, UpSampling1D, UpSampling2D, UpSampling3D, MaxPooling1D -from tensorflow.keras.layers import MaxPooling2D, MaxPooling3D, AveragePooling1D, AveragePooling2D, AveragePooling3D, Layer +from tensorflow.keras.layers import MaxPooling2D, MaxPooling3D, Average, AveragePooling1D, AveragePooling2D, AveragePooling3D, Layer from tensorflow.keras.layers import SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Concatenate, Add from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalAveragePooling3D import tensorflow_probability as tfp @@ -525,30 +525,20 @@ def l2_norm(x, axis=None): """ takes an input tensor and returns the l2 norm along specified axis """ - square_sum = K.sum(K.square(x), axis=axis, keepdims=True) norm = K.sqrt(K.maximum(square_sum, K.epsilon())) - return norm def pairwise_cosine_difference(t1, t2): - """ - A [batch x n x d] tensor of n rows with d dimensions - B [batch x m x d] tensor of n rows with d dimensions - - returns: - D [batch x n x m] tensor of cosine similarity scores between each point i Model: + opt = get_optimizer( + kwargs['optimizer'], kwargs['learning_rate'], steps_per_epoch=kwargs['training_steps'], + learning_rate_schedule=kwargs['learning_rate_schedule'], optimizer_kwargs=kwargs.get('optimizer_kwargs'), + ) + if 'model_file' in kwargs and kwargs['model_file'] is not None: + custom_dict = _get_custom_objects(kwargs['tensor_maps_out']) + encoders = {} + for tm in kwargs['tensor_maps_in']: + encoders[tm] = load_model(f"{os.path.dirname(kwargs['model_file'])}/encoder_{tm.name}.h5", custom_objects=custom_dict, compile=False) + decoders = {} + for tm in kwargs['tensor_maps_out']: + decoders[tm] = load_model(f"{os.path.dirname(kwargs['model_file'])}/decoder_{tm.name}.h5", custom_objects=custom_dict, compile=False) + logging.info(f"Attempting to load model file from: {kwargs['model_file']}") + m = load_model(kwargs['model_file'], custom_objects=custom_dict, compile=False) + m.compile(optimizer=opt, loss=custom_dict['loss']) + m.summary() + logging.info(f"Loaded model file from: {kwargs['model_file']}") + return m, encoders, decoders + + inputs = {tm: Input(shape=tm.shape, name=tm.input_name()) for tm in kwargs['tensor_maps_in']} + real_serial_layers = kwargs['model_layers'] + kwargs['model_layers'] = None + multimodal_activations = [] + encoders = {} + decoders = {} + outputs = {} + losses = [] + for left, right in pairs: + if left in encoders: + encode_left = encoders[left] + else: + kwargs['tensor_maps_in'] = [left] + left_model = make_multimodal_multitask_model(**kwargs) + encode_left = make_hidden_layer_model(left_model, [left], kwargs['hidden_layer']) + h_left = encode_left(inputs[left]) + + if right in encoders: + encode_right = encoders[right] + else: + kwargs['tensor_maps_in'] = [right] + right_model = make_multimodal_multitask_model(**kwargs) + encode_right = make_hidden_layer_model(right_model, [right], kwargs['hidden_layer']) + h_right = encode_right(inputs[right]) + + if pair_loss == 'cosine': + loss_layer = CosineLossLayer(pair_loss_weight) + elif pair_loss == 'euclid': + loss_layer = L2LossLayer(pair_loss_weight) + + multimodal_activations.extend(loss_layer([h_left, h_right])) + encoders[left] = encode_left + encoders[right] = encode_right + + kwargs['tensor_maps_in'] = list(inputs.keys()) + if multimodal_merge == 'average': + multimodal_activation = Average()(multimodal_activations) + elif multimodal_merge == 'concatenate': + multimodal_activation = Concatenate()(multimodal_activations) + multimodal_activation = Dense(units=kwargs['dense_layers'][0], use_bias=False)(multimodal_activation) + multimodal_activation = _activation_layer(kwargs['activation'])(multimodal_activation) + else: + raise NotImplementedError(f'No merge architecture for method: {multimodal_merge}') + latent_inputs = Input(shape=(kwargs['dense_layers'][-1]), name='input_concept_space') + + # build decoder models + for tm in kwargs['tensor_maps_out']: + if tm.axes() > 1: + shape = _calc_start_shape(num_upsamples=len(kwargs['dense_blocks']), output_shape=tm.shape, + upsample_rates=[kwargs['pool_x'], kwargs['pool_y'], kwargs['pool_z']], + channels=kwargs['dense_blocks'][-1]) + + restructure = FlatToStructure(output_shape=shape, activation=kwargs['activation'], + normalization=kwargs['dense_normalize']) + + decode = ConvDecoder( + tensor_map_out=tm, + filters_per_dense_block=kwargs['dense_blocks'][::-1], + conv_layer_type=kwargs['conv_type'], + conv_x=kwargs['conv_x'] if tm.axes() > 2 else kwargs['conv_width'], + conv_y=kwargs['conv_y'], + conv_z=kwargs['conv_z'], + block_size=kwargs['block_size'], + activation=kwargs['activation'], + normalization=kwargs['conv_normalize'], + regularization=kwargs['conv_regularize'], + regularization_rate=kwargs['conv_regularize_rate'], + upsample_x=kwargs['pool_x'], + upsample_y=kwargs['pool_y'], + upsample_z=kwargs['pool_z'], + u_connect_parents=[tm_in for tm_in in kwargs['tensor_maps_in'] if tm in kwargs['u_connect'][tm_in]], + ) + reconstruction = decode(restructure(latent_inputs), {}, {}) + else: + dense_block = FullyConnectedBlock( + widths=kwargs['dense_layers'], + activation=kwargs['activation'], + normalization=kwargs['dense_normalize'], + regularization=kwargs['dense_regularize'], + regularization_rate=kwargs['dense_regularize_rate'], + is_encoder=False, + ) + decode = DenseDecoder(tensor_map_out=tm, parents=tm.parents, activation=kwargs['activation']) + reconstruction = decode(dense_block(latent_inputs), {}, {}) + + decoder = Model(latent_inputs, reconstruction, name=tm.output_name()) + decoders[tm] = decoder + outputs[tm.output_name()] = decoder(multimodal_activation) + losses.append(tm.loss) + + m = Model(inputs=list(inputs.values()), outputs=list(outputs.values())) + my_metrics = {tm.output_name(): tm.metrics for tm in kwargs['tensor_maps_out']} + m.compile(optimizer=opt, loss=losses, metrics=my_metrics) + m.summary() + + if real_serial_layers is not None: + m.load_weights(real_serial_layers, by_name=True) + logging.info(f"Loaded model weights from:{real_serial_layers}") + + return m, encoders, decoders + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~ Training ~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1173,13 +1290,13 @@ def train_model_from_generators( inspect_show_labels: bool, return_history: bool = False, plot: bool = True, + save_last_model: bool = False, ) -> Union[Model, Tuple[Model, History]]: """Train a model from tensor generators for validation and training data. Training data lives on disk, it will be loaded by generator functions. Plots the metric history after training. Creates a directory to save weights, if necessary. Measures runtime and plots architecture diagram if inspect_model is True. - :param model: The model to optimize :param generate_train: Generator function that yields mini-batches of training data. :param generate_valid: Generator function that yields mini-batches of validation data. @@ -1193,6 +1310,9 @@ def train_model_from_generators( :param inspect_model: If True, measure training and inference runtime of the model and generate architecture plot. :param inspect_show_labels: If True, show labels on the architecture plot. :param return_history: If true return history from training and don't plot the training history + :param plot: If true, plots the metrics for train and validation set at the end of each epoch + :param save_last_model: If true saves the model weights from last epoch otherwise saves model with best validation loss + :return: The optimized model. """ model_file = os.path.join(output_folder, run_id, run_id + MODEL_EXT) @@ -1206,7 +1326,7 @@ def train_model_from_generators( history = model.fit( generate_train, steps_per_epoch=training_steps, epochs=epochs, verbose=1, validation_steps=validation_steps, validation_data=generate_valid, - callbacks=_get_callbacks(patience, model_file), + callbacks=_get_callbacks(patience, model_file, save_last_model), ) generate_train.kill_workers() generate_valid.kill_workers() @@ -1221,10 +1341,10 @@ def train_model_from_generators( def _get_callbacks( - patience: int, model_file: str, + patience: int, model_file: str, save_last_model: bool ) -> List[Callback]: callbacks = [ - ModelCheckpoint(filepath=model_file, verbose=1, save_best_only=True), + ModelCheckpoint(filepath=model_file, verbose=1, save_best_only=not save_last_model), EarlyStopping(monitor='val_loss', patience=patience * 3, verbose=1), ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=patience, verbose=1), ] diff --git a/ml4h/plots.py b/ml4h/plots.py index c353d1b94..28ef0d5e8 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -22,6 +22,9 @@ from tensorflow.keras.optimizers.schedules import LearningRateSchedule import matplotlib + +from ml4h.tensor_generators import _sample_csv_to_set + matplotlib.use('Agg') # Need this to write images from the GSA servers. Order matters: import matplotlib.pyplot as plt # First import matplotlib, then use Agg, then import plt from matplotlib.ticker import NullFormatter @@ -29,6 +32,7 @@ from matplotlib.ticker import AutoMinorLocator, MultipleLocator from sklearn import manifold +from sklearn.decomposition import PCA from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score from sklearn.metrics import brier_score_loss, precision_score, recall_score, f1_score, roc_auc_score from sklearn.calibration import calibration_curve @@ -42,7 +46,6 @@ import ml4h.tensormap.ukb.ecg import ml4h.tensormap.mgb.ecg from ml4h.tensormap.mgb.dynamic import make_waveform_maps - from ml4h.TensorMap import TensorMap from ml4h.metrics import concordance_index, coefficient_of_determination from ml4h.defines import IMAGE_EXT, JOIN_CHAR, PDF_EXT, TENSOR_EXT, ECG_REST_LEADS, ECG_REST_MEDIAN_LEADS, PARTNERS_DATETIME_FORMAT, PARTNERS_DATE_FORMAT, HD5_GROUP_CHAR @@ -165,7 +168,7 @@ def evaluate_predictions( if tm.sentinel is not None: y_predictions = y_predictions[y_truth != tm.sentinel] y_truth = y_truth[y_truth != tm.sentinel] - _plot_reconstruction(tm, y_truth, y_predictions, folder, test_paths) + plot_reconstruction(tm, y_truth, y_predictions, folder, test_paths) if prediction_flat.shape[0] == truth_flat.shape[0]: performance_metrics.update(subplot_pearson_per_class(prediction_flat, truth_flat, tm.channel_map, protected, title, prefix=folder)) elif tm.is_continuous(): @@ -190,6 +193,7 @@ def plot_metric_history(history, training_steps: int, title: str, prefix='./figu cols = max(2, int(math.ceil(math.sqrt(total_plots)))) rows = max(2, int(math.ceil(total_plots / cols))) f, axes = plt.subplots(rows, cols, figsize=(int(cols*SUBPLOT_SIZE), int(rows*SUBPLOT_SIZE))) + logging.info(f'all keys {list(sorted(history.history.keys()))}') for k in sorted(history.history.keys()): if not k.startswith('val_'): if isinstance(history.history[k][0], LearningRateSchedule): @@ -464,7 +468,7 @@ def subplot_pearson_per_class( os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path, bbox_inches='tight') plt.clf() - logging.info(f"Saved Pearson correlations at: {figure_path} with {len(protected)} protected TensorMaps.") + logging.info(f"{label_text} saved at: {figure_path}{f' with {len(protected)} protected TensorMaps.' if len(protected) else '.'}") return labels_to_areas @@ -1226,11 +1230,11 @@ def _plot_partners_figure( def plot_partners_ecgs(args): plot_tensors = [ ml4h.tensormap.mgb.ecg.partners_ecg_patientid, ml4h.tensormap.mgb.ecg.partners_ecg_firstname, ml4h.tensormap.mgb.ecg.partners_ecg_lastname, - ml4h.tensormap.mgb.ecg.partners_ecg_sex, ml4h.tensormap.mgb.ecg.partners_ecg_dob, ml4h.tensormap.mgb.ecg.partners_ecg_age, - ml4h.tensormap.mgb.ecg.partners_ecg_datetime, ml4h.tensormap.mgb.ecg.partners_ecg_sitename, ml4h.tensormap.mgb.ecg.partners_ecg_location, - ml4h.tensormap.mgb.ecg.partners_ecg_read_md, ml4h.tensormap.mgb.ecg.partners_ecg_taxis_md, ml4h.tensormap.mgb.ecg.partners_ecg_rate_md, - ml4h.tensormap.mgb.ecg.partners_ecg_pr_md, ml4h.tensormap.mgb.ecg.partners_ecg_qrs_md, ml4h.tensormap.mgb.ecg.partners_ecg_qt_md, - ml4h.tensormap.mgb.ecg.partners_ecg_paxis_md, ml4h.tensormap.mgb.ecg.partners_ecg_raxis_md, ml4h.tensormap.mgb.ecg.partners_ecg_qtc_md, + ml4h.tensormap.mgb.ecg.partners_ecg_sex, ml4h.tensormap.mgb.ecg.partners_ecg_dob, ml4h.tensormap.mgb.ecg.partners_ecg_age, + ml4h.tensormap.mgb.ecg.partners_ecg_datetime, ml4h.tensormap.mgb.ecg.partners_ecg_sitename, ml4h.tensormap.mgb.ecg.partners_ecg_location, + ml4h.tensormap.mgb.ecg.partners_ecg_read_md, ml4h.tensormap.mgb.ecg.partners_ecg_taxis_md, ml4h.tensormap.mgb.ecg.partners_ecg_rate_md, + ml4h.tensormap.mgb.ecg.partners_ecg_pr_md, ml4h.tensormap.mgb.ecg.partners_ecg_qrs_md, ml4h.tensormap.mgb.ecg.partners_ecg_qt_md, + ml4h.tensormap.mgb.ecg.partners_ecg_paxis_md, ml4h.tensormap.mgb.ecg.partners_ecg_raxis_md, ml4h.tensormap.mgb.ecg.partners_ecg_qtc_md, ] voltage_tensor = make_waveform_maps('partners_ecg_2500_raw') tensor_maps_in = plot_tensors + [voltage_tensor] @@ -1500,12 +1504,13 @@ def plot_ecg_rest( :param is_blind: if True, the plot gets blinded (helpful for review and annotation) """ map_fields_to_tmaps = { - 'ramp': ml4h.tensormap.ukb.ecg.ecg_rest_ramplitude_raw, - 'samp': ml4h.tensormap.ukb.ecg.ecg_rest_samplitude_raw, - 'aVL': ml4h.tensormap.ukb.ecg.ecg_rest_lvh_avl, - 'Sokolow_Lyon': ml4h.tensormap.ukb.ecg.ecg_rest_lvh_sokolow_lyon, - 'Cornell': ml4h.tensormap.ukb.ecg.ecg_rest_lvh_cornell, - } + 'ramp': ml4h.tensormap.ukb.ecg.ecg_rest_ramplitude_raw, + 'samp': ml4h.tensormap.ukb.ecg.ecg_rest_samplitude_raw, + 'aVL': ml4h.tensormap.ukb.ecg.ecg_rest_lvh_avl, + 'Sokolow_Lyon': ml4h.tensormap.ukb.ecg.ecg_rest_lvh_sokolow_lyon, + 'Cornell': ml4h.tensormap.ukb.ecg.ecg_rest_lvh_cornell, + } + raw_scale = 0.005 # Conversion from raw to mV default_yrange = ECG_REST_PLOT_DEFAULT_YRANGE # mV time_interval = 2.5 # time-interval per plot in seconds. ts_Reference data is in s, voltage measurement is 5 uv per lsb @@ -2076,32 +2081,154 @@ def _text_on_plot(axes, x, y, text, alpha=0.8, background='white'): t.set_bbox({'facecolor': background, 'alpha': alpha, 'edgecolor': background}) -def _plot_reconstruction( +def plot_reconstruction( tm: TensorMap, y_true: np.ndarray, y_pred: np.ndarray, - folder: str, paths: List[str], + folder: str, paths: List[str], num_samples: int = 4, ): - num_samples = 3 logging.info(f'Plotting {num_samples} reconstructions of {tm}.') if None in tm.shape: # can't handle dynamic shapes return + os.makedirs(os.path.dirname(folder), exist_ok=True) for i in range(num_samples): - title = f'{tm.name}_{os.path.basename(paths[i]).replace(TENSOR_EXT, "")}_reconstruction' - y = y_true[i].reshape(tm.shape) - yp = y_pred[i].reshape(tm.shape) + sample_id = os.path.basename(paths[i]).replace(TENSOR_EXT, '') + title = f'{tm.name}_{sample_id}_reconstruction' + y = y_true[i] + yp = y_pred[i] if tm.axes() == 2: - fig = plt.figure(figsize=(SUBPLOT_SIZE, SUBPLOT_SIZE * num_samples)) + index2channel = {v: k for k, v in tm.channel_map.items()} + fig, axes = plt.subplots(tm.shape[1], 2, figsize=(2 * SUBPLOT_SIZE, 6*SUBPLOT_SIZE)) #, sharey=True) for j in range(tm.shape[1]): - plt.subplot(tm.shape[1], 1, j + 1) - plt.plot(y[:, j], c='k', linestyle='--', label='original') - plt.plot(yp[:, j], c='b', label='reconstruction') - if j == 0: - plt.title(title) - plt.legend() + axes[j, 0].plot(y[:, j], c='k', label='original') + axes[j, 1].plot(yp[:, j], c='b', label='reconstruction') + axes[j, 0].set_title(f'Lead: {index2channel[j]}') + axes[j, 0].legend() + axes[j, 1].legend() plt.tight_layout() - # TODO: implement 3d, 4d - plt.savefig(os.path.join(folder, title + IMAGE_EXT)) + plt.savefig(os.path.join(folder, title + IMAGE_EXT)) + elif tm.axes() == 3: + if tm.is_categorical(): + plt.imsave(f"{folder}{sample_id}_{tm.name}_truth_{i:02d}{IMAGE_EXT}", np.argmax(y, axis=-1), cmap='plasma') + plt.imsave(f"{folder}{sample_id}_{tm.name}_prediction_{i:02d}{IMAGE_EXT}", np.argmax(yp, axis=-1), cmap='plasma') + else: + plt.imsave(f'{folder}{sample_id}_{tm.name}_truth_{i:02d}{IMAGE_EXT}', y[:, :, 0], cmap='gray') + plt.imsave(f'{folder}{sample_id}_{tm.name}_prediction_{i:02d}{IMAGE_EXT}', yp[:, :, 0], cmap='gray') + elif tm.axes() == 4: + for j in range(y.shape[3]): + image_path_base = f'{folder}{sample_id}_{tm.name}_{i:03d}_{j:03d}' + if tm.is_categorical(): + truth = np.argmax(yp[:, :, j, :], axis=-1) + prediction = np.argmax(y[:, :, j, :], axis=-1) + plt.imsave(f'{image_path_base}_truth{IMAGE_EXT}', truth, cmap='plasma') + plt.imsave(f'{image_path_base}_prediction{IMAGE_EXT}', prediction, cmap='plasma') + else: + plt.imsave(f'{image_path_base}_truth{IMAGE_EXT}', y[:, :, j, 0], cmap='gray') + plt.imsave(f'{image_path_base}_prediction{IMAGE_EXT}', yp[:, :, j, 0], cmap='gray') plt.clf() -if __name__ == '__main__': - plot_noisy() +def pca_on_matrix(matrix, pca_components, prefix='./figures/'): + pca = PCA() + pca.fit(matrix) + print(f'PCA explains {100 * np.sum(pca.explained_variance_ratio_[:pca_components]):0.1f}% of variance with {pca_components} top PCA components.') + matrix_reduced = pca.transform(matrix)[:, :pca_components] + print(f'PCA reduces matrix shape:{matrix_reduced.shape} from matrix shape: {matrix.shape}') + plot_scree(pca_components, 100 * pca.explained_variance_ratio_, prefix) + return pca, matrix_reduced + + +def plot_scree(pca_components, percent_explained, prefix='./figures/'): + _ = plt.figure(figsize=(6, 4)) + plt.plot(range(len(percent_explained)), percent_explained, 'g.-', linewidth=1) + plt.axvline(x=pca_components, c='r', linewidth=3) + label = f'{np.sum(percent_explained[:pca_components]):0.1f}% of variance explained by top {pca_components} of {len(percent_explained)} components' + plt.text(pca_components + 0.02 * len(percent_explained), percent_explained[1], label) + plt.title('Scree Plot') + plt.xlabel('Principal Components') + plt.ylabel('% of Variance Explained by Each Component') + figure_path = f'{prefix}pca_{pca_components}_of_{len(percent_explained)}.png' + if not os.path.exists(os.path.dirname(figure_path)): + os.makedirs(os.path.dirname(figure_path)) + plt.savefig(figure_path) + + +def unit_vector(vector): + """ Returns the unit vector of the vector. """ + return vector / np.linalg.norm(vector) + + +def angle_between(v1, v2): + """ Returns the angle in radians between vectors 'v1' and 'v2':: + angle_between((1, 0, 0), (0, 1, 0)) + 90 + angle_between((1, 0, 0), (1, 0, 0)) + 0.0 + angle_between((1, 0, 0), (-1, 0, 0)) + 180 + """ + v1_u = unit_vector(v1) + v2_u = unit_vector(v2) + return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) * 180 / 3.141592 + + +def stratify_latent_space(stratify_column, stratify_thresh, latent_cols, latent_df): + hit = latent_df.loc[latent_df[stratify_column] >= stratify_thresh][latent_cols].to_numpy() + miss = latent_df.loc[latent_df[stratify_column] < stratify_thresh][latent_cols].to_numpy() + miss_mean_vector = np.mean(miss, axis=0) + hit_mean_vector = np.mean(hit, axis=0) + angle = angle_between(miss_mean_vector, hit_mean_vector) + print(f'Angle between {stratify_column} and all others: {angle}, \n' + f'Hit shape {hit.shape}, miss:{miss.shape} threshold:{stratify_thresh}\n' + f'Distance: {np.linalg.norm(hit_mean_vector - miss_mean_vector):.3f}, ' + f'Hit std {np.std(hit, axis=1).mean():.3f}, miss std:{np.std(miss, axis=1).mean():.3f}\n') + return hit_mean_vector, miss_mean_vector + + +def plot_hit_to_miss_transforms(latent_df, decoders, feature='Sex_Female_0_0', prefix='./figures/', + thresh=1.0, latent_dimension=256, samples=16, scalar=3.0, cmap='plasma', test_csv=None): + latent_cols = [f'latent_{i}' for i in range(latent_dimension)] + female, male = stratify_latent_space(feature, thresh, latent_cols, latent_df) + sex_vector = female - male + if test_csv is not None: + sample_ids = [int(s) for s in _sample_csv_to_set(test_csv) if len(s) > 4 and int(s) in latent_df.index] + latent_df = latent_df.loc[sample_ids] + latent_df.info() + logging.info(f'Subset to test set with {len(sample_ids)} samples') + + samples = min(len(latent_df.index), samples) + embeddings = latent_df.iloc[:samples][latent_cols].to_numpy() + sexes = latent_df.iloc[:samples][feature].to_numpy() + logging.info(f'Embedding shape: {embeddings.shape} sexes shape: {sexes.shape}') + + sex_vectors = np.tile(sex_vector, (samples, 1)) + male_to_female = embeddings + (scalar * sex_vectors) + female_to_male = embeddings - (scalar * sex_vectors) + for dtm in decoders: + predictions = decoders[dtm].predict(embeddings) + m2f = decoders[dtm].predict(male_to_female) + f2m = decoders[dtm].predict(female_to_male) + if dtm.axes() == 3: + fig, axes = plt.subplots(max(2, samples), 2, figsize=(18, samples * 4)) + for i in range(samples): + axes[i, 0].set_title(f"{feature}: {sexes[i]} ?>== thresh: + axes[i, 1].imshow(np.argmax(f2m[i, ...], axis=-1), cmap=cmap) + axes[i, 1].set_title(f'{feature} to less than {thresh}') + else: + axes[i, 1].imshow(np.argmax(m2f[i, ...], axis=-1), cmap=cmap) + axes[i, 1].set_title(f'{feature} to more than or equal to {thresh}') + else: + axes[i, 0].imshow(predictions[i, ..., 0], cmap='gray') + if sexes[i] >= thresh: + axes[i, 1].imshow(f2m[i, ..., 0], cmap='gray') + axes[i, 1].set_title(f'{feature} to less than {thresh}') + else: + axes[i, 1].imshow(m2f[i, ..., 0], cmap='gray') + axes[i, 1].set_title(f'{feature} to more than or equal to {thresh}') + figure_path = f'{prefix}/{dtm.name}_{feature}_transform_scalar_{scalar}.png' + if not os.path.exists(os.path.dirname(figure_path)): + os.makedirs(os.path.dirname(figure_path)) + plt.savefig(figure_path) diff --git a/ml4h/recipes.py b/ml4h/recipes.py index 80c22f93c..56ea2a00a 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -15,15 +15,16 @@ from ml4h.defines import TENSOR_EXT, MODEL_EXT from ml4h.tensormap.tensor_map_maker import write_tensor_maps from ml4h.tensorize.tensor_writer_mgb import write_tensors_mgb -from ml4h.explorations import test_labels_to_label_map, infer_with_pixels, explore +from ml4h.explorations import test_labels_to_label_map, infer_with_pixels, explore, latent_space_dataframe from ml4h.tensor_generators import BATCH_INPUT_INDEX, BATCH_OUTPUT_INDEX, BATCH_PATHS_INDEX from ml4h.explorations import mri_dates, ecg_dates, predictions_to_pngs, sample_from_language_model from ml4h.explorations import plot_while_learning, plot_histograms_of_tensors_in_pdf, cross_reference from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators, big_batch_from_minibatch_generator -from ml4h.models import make_character_model_plus, embed_model_predict, make_siamese_model, make_multimodal_multitask_model +from ml4h.models import make_character_model_plus, embed_model_predict, make_siamese_model, make_multimodal_multitask_model, make_paired_autoencoder_model from ml4h.metrics import get_roc_aucs, get_precision_recall_aucs, get_pearson_coefficients, log_aucs, log_pearson_coefficients from ml4h.models import train_model_from_generators, get_model_inputs_outputs, make_shallow_model, make_hidden_layer_model, saliency_map -from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_prediction_calibrations +from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_prediction_calibrations, \ + plot_reconstruction, plot_hit_to_miss_transforms from ml4h.tensorize.tensor_writer_ukbb import write_tensors, append_fields_from_csv, append_gene_csv, write_tensors_from_dicom_pngs, write_tensors_from_ecg_pngs from ml4h.plots import subplot_rocs, subplot_comparison_rocs, subplot_scatters, subplot_comparison_scatters, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp @@ -74,20 +75,20 @@ def run(args): ecg_dates(args.tensors, args.output_folder, args.id) elif 'plot_histograms' == args.mode: plot_histograms_of_tensors_in_pdf(args.id, args.tensors, args.output_folder, args.max_samples) - elif 'plot_heatmap' == args.mode: - plot_heatmap_of_tensors(args.id, args.tensors, args.output_folder, args.min_samples, args.max_samples) elif 'plot_resting_ecgs' == args.mode: plot_ecg_rest_mp(args.tensors, args.min_sample_id, args.max_sample_id, args.output_folder, args.num_workers) elif 'plot_partners_ecgs' == args.mode: plot_partners_ecgs(args) - elif 'tabulate_correlations' == args.mode: - tabulate_correlations_of_tensors(args.id, args.tensors, args.output_folder, args.min_samples, args.max_samples) elif 'train_shallow' == args.mode: train_shallow_model(args) elif 'train_char' == args.mode: train_char_model(args) elif 'train_siamese' == args.mode: train_siamese_model(args) + elif 'train_paired' == args.mode: + train_paired_model(args) + elif 'inspect_paired' == args.mode: + inspect_paired_model(args) elif 'write_tensor_maps' == args.mode: write_tensor_maps(args) elif 'append_continuous_csv' == args.mode: @@ -135,8 +136,8 @@ def train_multimodal_multitask(args): generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__) model = make_multimodal_multitask_model(**args.__dict__) model = train_model_from_generators( - model, generate_train, generate_valid, args.training_steps, args.validation_steps, args.batch_size, - args.epochs, args.patience, args.output_folder, args.id, args.inspect_model, args.inspect_show_labels, + model, generate_train, generate_valid, args.training_steps, args.validation_steps, args.batch_size, args.epochs, + args.patience, args.output_folder, args.id, args.inspect_model, args.inspect_show_labels, save_last_model=args.save_last_model ) out_path = os.path.join(args.output_folder, args.id + '/') @@ -285,14 +286,14 @@ def infer_multimodal_multitask(args): logging.info(f"Wrote:{stats['count']} rows of inference. Last tensor:{tensor_paths[0]}") -def hidden_inference_file_name(output_folder: str, id_: str) -> str: - return os.path.join(output_folder, id_, 'hidden_inference_' + id_ + '.tsv') +def _hidden_file_name(output_folder: str, prefix_: str, id_: str, extension_: str) -> str: + return os.path.join(output_folder, id_, prefix_ + id_ + extension_) def infer_hidden_layer_multimodal_multitask(args): stats = Counter() args.num_workers = 0 - inference_tsv = hidden_inference_file_name(args.output_folder, args.id) + inference_tsv = _hidden_file_name(args.output_folder, 'hidden_inference_', args.id, '.tsv') tsv_style_is_genetics = 'genetics' in args.tsv_style tensor_paths = [os.path.join(args.tensors, tp) for tp in sorted(os.listdir(args.tensors)) if os.path.splitext(tp)[-1].lower() == TENSOR_EXT] # hard code batch size to 1 so we can iterate over file names and generated tensors together in the tensor_paths for loop @@ -303,6 +304,7 @@ def infer_hidden_layer_multimodal_multitask(args): generate_test.set_worker_paths(tensor_paths) full_model = make_multimodal_multitask_model(**args.__dict__) embed_model = make_hidden_layer_model(full_model, args.tensor_maps_in, args.hidden_layer) + embed_model.save(_hidden_file_name(args.output_folder, f'{args.hidden_layer}_encoder_', args.id, '.h5')) dummy_input = {tm.input_name(): np.zeros((1,) + full_model.get_layer(tm.input_name()).input_shape[0][1:]) for tm in args.tensor_maps_in} dummy_out = embed_model.predict(dummy_input) latent_dimensions = int(np.prod(dummy_out.shape[1:])) @@ -392,6 +394,62 @@ def train_siamese_model(args): ) +def train_paired_model(args): + full_model, encoders, decoders = make_paired_autoencoder_model(**args.__dict__) + generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__) + train_model_from_generators( + full_model, generate_train, generate_valid, args.training_steps, args.validation_steps, args.batch_size, + args.epochs, args.patience, args.output_folder, args.id, args.inspect_model, args.inspect_show_labels, + save_last_model=True + ) + for tm in encoders: + encoders[tm].save(f'{args.output_folder}{args.id}/encoder_{tm.name}.h5') + for tm in decoders: + decoders[tm].save(f'{args.output_folder}{args.id}/decoder_{tm.name}.h5') + out_path = os.path.join(args.output_folder, args.id, 'reconstructions/') + test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps) + samples = min(args.test_steps * args.batch_size, 12) + predictions_list = full_model.predict(test_data) + predictions_dict = {name: pred for name, pred in zip(full_model.output_names, predictions_list)} + logging.info(f'Predictions and shapes are: {[(p, predictions_dict[p].shape) for p in predictions_dict]}') + performance_metrics = {} + for tm in args.tensor_maps_out: + if tm.axes() == 1: + y = predictions_dict[tm.output_name()] + y_truth = np.array(test_labels[tm.output_name()]) + metrics = evaluate_predictions(tm, y, y_truth, {}, tm.name, os.path.join(args.output_folder, args.id), test_paths) + performance_metrics.update(metrics) + for i, etm in enumerate(encoders): + embed = encoders[etm].predict(test_data[etm.input_name()]) + plot_reconstruction(etm, test_data[etm.input_name()], predictions_dict[etm.output_name()], out_path, test_paths, samples) + for dtm in decoders: + reconstruction = decoders[dtm].predict(embed) + logging.info(f'{dtm.name} has prediction shape: {reconstruction.shape} from embed shape: {embed.shape}') + my_out_path = os.path.join(out_path, f'decoding_{dtm.name}_from_{etm.name}/') + if not os.path.exists(os.path.dirname(my_out_path)): + os.makedirs(os.path.dirname(my_out_path)) + if dtm.axes() > 1: + plot_reconstruction(dtm, test_data[dtm.input_name()], reconstruction, my_out_path, test_paths, samples) + else: + evaluate_predictions(dtm, reconstruction, test_labels[dtm.output_name()], {}, dtm.name, my_out_path, test_paths) + return performance_metrics + + +def inspect_paired_model(args): + full_model, encoders, decoders = make_paired_autoencoder_model(**args.__dict__) + infer_hidden_tsv = _hidden_file_name(args.output_folder, 'hidden_inference_', args.id, '.tsv') + latent_df = latent_space_dataframe(infer_hidden_tsv, args.app_csv) + out_folder = os.path.join(args.output_folder, args.id, 'latent_transformations/') + for tm in args.tensor_maps_protected: + index2channel = {v: k for k, v in tm.channel_map.items()} + thresh = 1 if tm.is_categorical() else tm.normalization.mean + plot_hit_to_miss_transforms(latent_df, decoders, + feature=index2channel[0], + thresh=thresh, + latent_dimension=args.dense_layers[0], + prefix=out_folder) + + def plot_predictions(args): _, _, generate_test = test_train_valid_tensor_generators(**args.__dict__) model = make_multimodal_multitask_model(**args.__dict__) diff --git a/ml4h/tensormap/mgb/dynamic.py b/ml4h/tensormap/mgb/dynamic.py index 2708a8f85..8e858a654 100644 --- a/ml4h/tensormap/mgb/dynamic.py +++ b/ml4h/tensormap/mgb/dynamic.py @@ -23,7 +23,8 @@ INCIDENCE_CSV = '/media/erisone_snf13/lc_outcomes.csv' CARDIAC_SURGERY_OUTCOMES_CSV = '/data/sts-data/mgh-preop-ecg-outcome-labels.csv' PARTNERS_PREFIX = 'partners_ecg_rest' -WIDE_FILE = '/home/sam/ml/hf-wide-2020-08-18-with-lvh-and-lbbb.tsv' +WIDE_FILE = '/home/sam/ml/hf-wide-2020-09-15-with-lvh-and-lbbb.tsv' +#WIDE_FILE = '/home/sam/ml/mgh-wide-2020-06-25-with-mrn.tsv' def make_mgb_dynamic_tensor_maps(desired_map_name: str) -> TensorMap: @@ -511,6 +512,7 @@ def _days_to_years_float(s: str): except ValueError: return None + def _time_to_event_tensor_from_days(tm: TensorMap, has_disease: int, follow_up_days: int): tensor = np.zeros(tm.shape, dtype=np.float32) if follow_up_days > tm.days_window: @@ -535,7 +537,7 @@ def _survival_curve_tensor_from_dates(tm: TensorMap, has_disease: int, assessmen def tensor_from_wide( - file_name: str, patient_column: str = 'fpath', age_column: str = 'age', bmi_column: str = 'bmi', + file_name: str, patient_column: str = 'Mrn', age_column: str = 'age', bmi_column: str = 'bmi', sex_column: str = 'sex', hf_column: str = 'any_hf_age', start_column: str = 'start_fu', end_column: str = 'last_encounter', delimiter: str = '\t', population_normalize: int = 2000, target: str = 'ecg', skip_prevalent: bool = True, @@ -562,8 +564,11 @@ def tensor_from_wide( try: patient_key = int(float(row[patient_index])) patient_data[patient_key] = { - 'age': _days_to_years_float(row[age_index]), 'bmi': _to_float_or_none(row[bmi_index]), 'sex': row[sex_index], - 'hf_age': _days_to_years_float(row[hf_index]), 'end_age': _days_to_years_float(row[end_index]), + 'age': _days_to_years_float(row[age_index]), + 'bmi': _to_float_or_none(row[bmi_index]), + 'sex': row[sex_index], + 'hf_age': _days_to_years_float(row[hf_index]), + 'end_age': _days_to_years_float(row[end_index]), 'start_date': datetime.datetime.strptime(row[start_index], CARDIAC_SURGERY_DATE_FORMAT), } @@ -574,7 +579,7 @@ def tensor_from_wide( def tensor_from_file(tm: TensorMap, hd5: h5py.File, dependents=None): mrn_int = _hd5_filename_to_mrn_int(hd5.filename) if mrn_int not in patient_data: - raise KeyError(f'{tm.name} mrn not in legacy csv.') + raise KeyError(f'{tm.name} mrn not in csv.') if patient_data[mrn_int]['end_age'] is None or patient_data[mrn_int]['age'] is None: raise ValueError(f'{tm.name} could not find ages.') if patient_data[mrn_int]['end_age'] - patient_data[mrn_int]['age'] < 0: diff --git a/ml4h/tensormap/ukb/by_script.py b/ml4h/tensormap/ukb/by_script.py index 8f4591894..260f883e9 100644 --- a/ml4h/tensormap/ukb/by_script.py +++ b/ml4h/tensormap/ukb/by_script.py @@ -4275,8 +4275,10 @@ ukb_23098_0 = TensorMap('23098_Weight_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 78.03287707120329, 'std': 15.89979106473436}, annotation_units=1, channel_map={'23098_Weight_0_0': 0, }) ukb_23099_0 = TensorMap('23099_Body-fat-percentage_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.451786835511985, 'std': 8.547574421386416}, annotation_units=1, channel_map={'23099_Body-fat-percentage_0_0': 0, }) ukb_23099_1 = TensorMap('23099_Body-fat-percentage_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.264657151412308, 'std': 8.292322694561557}, annotation_units=1, channel_map={'23099_Body-fat-percentage_1_0': 0, }) +ukb_23099_2 = TensorMap('23099_Body-fat-percentage_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.264657151412308, 'std': 8.292322694561557}, annotation_units=1, channel_map={'23099_Body-fat-percentage_2_0': 0, }) ukb_23100_1 = TensorMap('23100_Whole-body-fat-mass_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 24.313408909308837, 'std': 9.148610869548651}, annotation_units=1, channel_map={'23100_Whole-body-fat-mass_1_0': 0, }) -ukb_23100_0 = TensorMap('23100_Whole-body-fat-mass_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 24.85684055488536, 'std': 9.565426323411884}, annotation_units=1, channel_map={'23100_Whole-body-fat-mass_0_0': 0, }) +ukb_23100_2 = TensorMap('23100_Whole-body-fat-mass_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 24.85684055488536, 'std': 9.565426323411884}, annotation_units=1, channel_map={'23100_Whole-body-fat-mass_2_0': 0, }) +ukb_23101_2 = TensorMap('23101_Whole-body-fatfree-mass_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 52.606156455797255, 'std': 11.126052649633138}, annotation_units=1, channel_map={'23101_Whole-body-fatfree-mass_2_0': 0, }) ukb_23101_1 = TensorMap('23101_Whole-body-fatfree-mass_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 52.606156455797255, 'std': 11.126052649633138}, annotation_units=1, channel_map={'23101_Whole-body-fatfree-mass_1_0': 0, }) ukb_23101_0 = TensorMap('23101_Whole-body-fatfree-mass_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 53.21773908875434, 'std': 11.496421819167042}, annotation_units=1, channel_map={'23101_Whole-body-fatfree-mass_0_0': 0, }) ukb_23102_0 = TensorMap('23102_Whole-body-water-mass_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 38.947460814537706, 'std': 8.414129150401765}, annotation_units=1, channel_map={'23102_Whole-body-water-mass_0_0': 0, }) @@ -4300,44 +4302,59 @@ ukb_23110_1 = TensorMap('23110_Impedance-of-arm-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 337.08795210775764, 'std': 57.39185814366964}, annotation_units=1, channel_map={'23110_Impedance-of-arm-left_1_0': 0, }) ukb_23110_2 = TensorMap('23110_Impedance-of-arm-left_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 326.5779425175772, 'std': 54.696785486803726}, annotation_units=1, channel_map={'23110_Impedance-of-arm-left_2_0': 0, }) ukb_23110_0 = TensorMap('23110_Impedance-of-arm-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 332.26580591961624, 'std': 56.83465434478377}, annotation_units=1, channel_map={'23110_Impedance-of-arm-left_0_0': 0, }) +ukb_23111_2 = TensorMap('23111_Leg-fat-percentage-right_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.525313045647284, 'std': 10.627681452664147}, annotation_units=1, channel_map={'23111_Leg-fat-percentage-right_2_0': 0, }) ukb_23111_1 = TensorMap('23111_Leg-fat-percentage-right_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.525313045647284, 'std': 10.627681452664147}, annotation_units=1, channel_map={'23111_Leg-fat-percentage-right_1_0': 0, }) ukb_23111_0 = TensorMap('23111_Leg-fat-percentage-right_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 32.04635516524227, 'std': 10.694715982950983}, annotation_units=1, channel_map={'23111_Leg-fat-percentage-right_0_0': 0, }) +ukb_23112_2 = TensorMap('23112_Leg-fat-mass-right_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 4.126630082314794, 'std': 1.8301929159686312}, annotation_units=1, channel_map={'23112_Leg-fat-mass-right_2_0': 0, }) ukb_23112_1 = TensorMap('23112_Leg-fat-mass-right_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 4.126630082314794, 'std': 1.8301929159686312}, annotation_units=1, channel_map={'23112_Leg-fat-mass-right_1_0': 0, }) ukb_23112_0 = TensorMap('23112_Leg-fat-mass-right_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 4.308648770510948, 'std': 1.898437368686469}, annotation_units=1, channel_map={'23112_Leg-fat-mass-right_0_0': 0, }) +ukb_23113_2 = TensorMap('23113_Leg-fatfree-mass-right_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.792466949363929, 'std': 1.9185583701790712}, annotation_units=1, channel_map={'23113_Leg-fatfree-mass-right_2_0': 0, }) ukb_23113_1 = TensorMap('23113_Leg-fatfree-mass-right_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.792466949363929, 'std': 1.9185583701790712}, annotation_units=1, channel_map={'23113_Leg-fatfree-mass-right_1_0': 0, }) ukb_23113_0 = TensorMap('23113_Leg-fatfree-mass-right_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.96966001620808, 'std': 2.0246336838262913}, annotation_units=1, channel_map={'23113_Leg-fatfree-mass-right_0_0': 0, }) ukb_23114_0 = TensorMap('23114_Leg-predicted-mass-right_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.48647434933969, 'std': 1.9273038036404626}, annotation_units=1, channel_map={'23114_Leg-predicted-mass-right_0_0': 0, }) ukb_23114_1 = TensorMap('23114_Leg-predicted-mass-right_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.322803691693691, 'std': 1.828108029602454}, annotation_units=1, channel_map={'23114_Leg-predicted-mass-right_1_0': 0, }) +ukb_23115_2 = TensorMap('23115_Leg-fat-percentage-left_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.41810426540285, 'std': 10.597272488375483}, annotation_units=1, channel_map={'23115_Leg-fat-percentage-left_2_0': 0, }) ukb_23115_1 = TensorMap('23115_Leg-fat-percentage-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.41810426540285, 'std': 10.597272488375483}, annotation_units=1, channel_map={'23115_Leg-fat-percentage-left_1_0': 0, }) ukb_23115_0 = TensorMap('23115_Leg-fat-percentage-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.951824770075714, 'std': 10.648351373024937}, annotation_units=1, channel_map={'23115_Leg-fat-percentage-left_0_0': 0, }) ukb_23116_0 = TensorMap('23116_Leg-fat-mass-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 4.242730601973062, 'std': 1.871902668963224}, annotation_units=1, channel_map={'23116_Leg-fat-mass-left_0_0': 0, }) ukb_23116_1 = TensorMap('23116_Leg-fat-mass-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 4.064025941631329, 'std': 1.8044984851208248}, annotation_units=1, channel_map={'23116_Leg-fat-mass-left_1_0': 0, }) +ukb_23116_2 = TensorMap('23116_Leg-fat-mass-left_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 4.064025941631329, 'std': 1.8044984851208248}, annotation_units=1, channel_map={'23116_Leg-fat-mass-left_2_0': 0, }) +ukb_23117_2 = TensorMap('23117_Leg-fatfree-mass-left_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.697755051134946, 'std': 1.9100971996154277}, annotation_units=1, channel_map={'23117_Leg-fatfree-mass-left_2_0': 0, }) ukb_23117_1 = TensorMap('23117_Leg-fatfree-mass-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.697755051134946, 'std': 1.9100971996154277}, annotation_units=1, channel_map={'23117_Leg-fatfree-mass-left_1_0': 0, }) ukb_23117_0 = TensorMap('23117_Leg-fatfree-mass-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.866400981491417, 'std': 2.01067768245527}, annotation_units=1, channel_map={'23117_Leg-fatfree-mass-left_0_0': 0, }) ukb_23118_1 = TensorMap('23118_Leg-predicted-mass-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.23430281865802, 'std': 1.819746251807764}, annotation_units=1, channel_map={'23118_Leg-predicted-mass-left_1_0': 0, }) ukb_23118_0 = TensorMap('23118_Leg-predicted-mass-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.389605651769429, 'std': 1.914070754426116}, annotation_units=1, channel_map={'23118_Leg-predicted-mass-left_0_0': 0, }) ukb_23119_0 = TensorMap('23119_Arm-fat-percentage-right_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 29.531533189516836, 'std': 10.173278005428738}, annotation_units=1, channel_map={'23119_Arm-fat-percentage-right_0_0': 0, }) ukb_23119_1 = TensorMap('23119_Arm-fat-percentage-right_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 29.08498378648042, 'std': 9.660940777028525}, annotation_units=1, channel_map={'23119_Arm-fat-percentage-right_1_0': 0, }) +ukb_23119_2 = TensorMap('23119_Arm-fat-percentage-right_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 29.08498378648042, 'std': 9.660940777028525}, annotation_units=1, channel_map={'23119_Arm-fat-percentage-right_2_0': 0, }) +ukb_23120_2 = TensorMap('23120_Arm-fat-mass-right_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 1.190551259665752, 'std': 0.5962184768943481}, annotation_units=1, channel_map={'23120_Arm-fat-mass-right_2_0': 0, }) ukb_23120_1 = TensorMap('23120_Arm-fat-mass-right_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 1.190551259665752, 'std': 0.5962184768943481}, annotation_units=1, channel_map={'23120_Arm-fat-mass-right_1_0': 0, }) ukb_23120_0 = TensorMap('23120_Arm-fat-mass-right_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 1.239808434567098, 'std': 0.6384610930020305}, annotation_units=1, channel_map={'23120_Arm-fat-mass-right_0_0': 0, }) +ukb_23121_2 = TensorMap('23121_Arm-fatfree-mass-right_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.838897480668496, 'std': 0.7716166268521579}, annotation_units=1, channel_map={'23121_Arm-fatfree-mass-right_2_0': 0, }) ukb_23121_1 = TensorMap('23121_Arm-fatfree-mass-right_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.838897480668496, 'std': 0.7716166268521579}, annotation_units=1, channel_map={'23121_Arm-fatfree-mass-right_1_0': 0, }) ukb_23121_0 = TensorMap('23121_Arm-fatfree-mass-right_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.8931570158616644, 'std': 0.8215411754827264}, annotation_units=1, channel_map={'23121_Arm-fatfree-mass-right_0_0': 0, }) ukb_23122_0 = TensorMap('23122_Arm-predicted-mass-right_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.7094604752076186, 'std': 0.7833267524137338}, annotation_units=1, channel_map={'23122_Arm-predicted-mass-right_0_0': 0, }) ukb_23122_1 = TensorMap('23122_Arm-predicted-mass-right_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.6591219755550015, 'std': 0.7367273170117283}, annotation_units=1, channel_map={'23122_Arm-predicted-mass-right_1_0': 0, }) +ukb_23123_2 = TensorMap('23123_Arm-fat-percentage-left_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 29.945781990521326, 'std': 9.82011062734792}, annotation_units=1, channel_map={'23123_Arm-fat-percentage-left_2_0': 0, }) ukb_23123_1 = TensorMap('23123_Arm-fat-percentage-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 29.945781990521326, 'std': 9.82011062734792}, annotation_units=1, channel_map={'23123_Arm-fat-percentage-left_1_0': 0, }) ukb_23123_0 = TensorMap('23123_Arm-fat-percentage-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 30.42513586266444, 'std': 10.267361725229156}, annotation_units=1, channel_map={'23123_Arm-fat-percentage-left_0_0': 0, }) ukb_23124_0 = TensorMap('23124_Arm-fat-mass-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 1.3197648029289917, 'std': 0.7131871500892241}, annotation_units=1, channel_map={'23124_Arm-fat-mass-left_0_0': 0, }) ukb_23124_1 = TensorMap('23124_Arm-fat-mass-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 1.2664488126122535, 'std': 0.6681665072338635}, annotation_units=1, channel_map={'23124_Arm-fat-mass-left_1_0': 0, }) +ukb_23124_2 = TensorMap('23124_Arm-fat-mass-left_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 1.2664488126122535, 'std': 0.6681665072338635}, annotation_units=1, channel_map={'23124_Arm-fat-mass-left_2_0': 0, }) +ukb_23125_2 = TensorMap('23125_Arm-fatfree-mass-left_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.8760987777500615, 'std': 0.7893732349099633}, annotation_units=1, channel_map={'23125_Arm-fatfree-mass-left_2_0': 0, }) ukb_23125_1 = TensorMap('23125_Arm-fatfree-mass-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.8760987777500615, 'std': 0.7893732349099633}, annotation_units=1, channel_map={'23125_Arm-fatfree-mass-left_1_0': 0, }) ukb_23125_0 = TensorMap('23125_Arm-fatfree-mass-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.9256461729453624, 'std': 0.8372290130379668}, annotation_units=1, channel_map={'23125_Arm-fatfree-mass-left_0_0': 0, }) ukb_23126_0 = TensorMap('23126_Arm-predicted-mass-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.7400753431894462, 'std': 0.7977945107270243}, annotation_units=1, channel_map={'23126_Arm-predicted-mass-left_0_0': 0, }) ukb_23126_1 = TensorMap('23126_Arm-predicted-mass-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.6941082564230485, 'std': 0.7532883471284786}, annotation_units=1, channel_map={'23126_Arm-predicted-mass-left_1_0': 0, }) +ukb_23127_2 = TensorMap('23127_Trunk-fat-percentage_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.227587927163885, 'std': 7.733697629204932}, annotation_units=1, channel_map={'23127_Trunk-fat-percentage_2_0': 0, }) ukb_23127_1 = TensorMap('23127_Trunk-fat-percentage_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.227587927163885, 'std': 7.733697629204932}, annotation_units=1, channel_map={'23127_Trunk-fat-percentage_1_0': 0, }) ukb_23127_0 = TensorMap('23127_Trunk-fat-percentage_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.17342071305921, 'std': 8.008510881672985}, annotation_units=1, channel_map={'23127_Trunk-fat-percentage_0_0': 0, }) +ukb_23128_2 = TensorMap('23128_Trunk-fat-mass_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 13.669508605637319, 'std': 5.019106497864643}, annotation_units=1, channel_map={'23128_Trunk-fat-mass_2_0': 0, }) ukb_23128_1 = TensorMap('23128_Trunk-fat-mass_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 13.669508605637319, 'std': 5.019106497864643}, annotation_units=1, channel_map={'23128_Trunk-fat-mass_1_0': 0, }) ukb_23128_0 = TensorMap('23128_Trunk-fat-mass_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 13.735808957116548, 'std': 5.1711708664366345}, annotation_units=1, channel_map={'23128_Trunk-fat-mass_0_0': 0, }) ukb_23129_0 = TensorMap('23129_Trunk-fatfree-mass_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 29.585953690492165, 'std': 5.981176227799396}, annotation_units=1, channel_map={'23129_Trunk-fatfree-mass_0_0': 0, }) ukb_23129_1 = TensorMap('23129_Trunk-fatfree-mass_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 29.42263906211025, 'std': 5.885068008414232}, annotation_units=1, channel_map={'23129_Trunk-fatfree-mass_1_0': 0, }) +ukb_23129_2 = TensorMap('23129_Trunk-fatfree-mass_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 29.42263906211025, 'std': 5.885068008414232}, annotation_units=1, channel_map={'23129_Trunk-fatfree-mass_2_0': 0, }) ukb_23130_0 = TensorMap('23130_Trunk-predicted-mass_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 28.36869182530567, 'std': 5.804851285767162}, annotation_units=1, channel_map={'23130_Trunk-predicted-mass_0_0': 0, }) ukb_23130_1 = TensorMap('23130_Trunk-predicted-mass_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 28.2279920179596, 'std': 5.708378220103142}, annotation_units=1, channel_map={'23130_Trunk-predicted-mass_1_0': 0, }) ukb_23200_2 = TensorMap('23200_L1L4-area_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 59.829805862939246, 'std': 8.280381408666097}, annotation_units=1, channel_map={'23200_L1L4-area_2_0': 0, }) diff --git a/ml4h/tensormap/ukb/genetics.py b/ml4h/tensormap/ukb/genetics.py index be785c8f5..4f2b6d86d 100644 --- a/ml4h/tensormap/ukb/genetics.py +++ b/ml4h/tensormap/ukb/genetics.py @@ -1,4 +1,5 @@ from ml4h.TensorMap import TensorMap, Interpretation +from ml4h.defines import StorageType from ml4h.metrics import weighted_crossentropy @@ -77,8 +78,11 @@ def _ttn_tensor_from_file(tm, hd5, dependents={}): }, ) -genetic_caucasian = TensorMap('Genetic-ethnic-grouping_Caucasian_0_0', Interpretation.CATEGORICAL, path_prefix='categorical', channel_map={'no_caucasian': 0, 'caucasian': 1}) +genetic_caucasian = TensorMap( + 'Genetic-ethnic-grouping_Caucasian_0_0', Interpretation.CATEGORICAL, path_prefix='categorical', storage_type=StorageType.CATEGORICAL_FLAG, + channel_map={'no_caucasian': 0, 'Genetic-ethnic-grouping_Caucasian_0_0': 1}) + genetic_caucasian_weighted = TensorMap( - 'Genetic-ethnic-grouping_Caucasian_0_0', Interpretation.CATEGORICAL, path_prefix='categorical', - channel_map={'no_caucasian': 0, 'caucasian': 1}, loss=weighted_crossentropy([10.0, 1.0], 'caucasian_loss'), + 'Genetic-ethnic-grouping_Caucasian_0_0', Interpretation.CATEGORICAL, path_prefix='categorical', storage_type=StorageType.CATEGORICAL_FLAG, + channel_map={'no_caucasian': 0, 'Genetic-ethnic-grouping_Caucasian_0_0': 1}, loss=weighted_crossentropy([10.0, 1.0], 'caucasian_loss'), ) diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index d81ff9c90..c8a6bb323 100644 --- a/ml4h/tensormap/ukb/mri.py +++ b/ml4h/tensormap/ukb/mri.py @@ -215,7 +215,7 @@ def _slice_tensor_from_file(tm, hd5, dependents={}): def _segmented_dicom_slices(dicom_key_prefix, path_prefix='ukb_cardiac_mri', step=1, total_slices=50): def _segmented_dicom_tensor_from_file(tm, hd5, dependents={}): tensor = np.zeros(tm.shape, dtype=np.float32) - if path_prefix == 'ukb_liver_mri': + if tm.axes() == 3 or path_prefix == 'ukb_liver_mri': categorical_index_slice = get_tensor_at_first_date(hd5, path_prefix, f'{dicom_key_prefix}1') categorical_one_hot = to_categorical(categorical_index_slice, len(tm.channel_map)) tensor[..., :] = pad_or_crop_array_to_shape(tensor[..., :].shape, categorical_one_hot) @@ -896,8 +896,9 @@ def _mri_slice_blackout_tensor_from_file(tm, hd5, dependents={}): ), ) lax_4ch_diastole_slice0_224_3d = TensorMap( - 'lax_4ch_diastole_slice0_224_3d', Interpretation.CONTINUOUS, shape=(160, 224, 1), - normalization=ZeroMeanStd1(), tensor_from_file=_slice_tensor('ukb_cardiac_mri/cine_segmented_lax_4ch/instance_0', 0), + 'lax_4ch_diastole_slice0_224_3d', Interpretation.CONTINUOUS, shape=(160, 224, 1), loss='logcosh', + normalization=ZeroMeanStd1(), + tensor_from_file=_slice_tensor('ukb_cardiac_mri/cine_segmented_lax_4ch/instance_0', 0), ) lax_4ch_diastole_slice0_256_3d = TensorMap( 'lax_4ch_diastole_slice0_256_3d', Interpretation.CONTINUOUS, shape=(192, 256, 1), @@ -953,8 +954,8 @@ def _mri_slice_blackout_tensor_from_file(tm, hd5, dependents={}): 'ukb_cardiac_mri/cine_segmented_ao_dist/instance_0', 0, ), ) -cine_segmented_ao_dist_slice0_3d = TensorMap( - 'cine_segmented_ao_dist_slice0_3d', Interpretation.CONTINUOUS, shape=(256, 256, 1), loss='logcosh', +aorta_diastole_slice0_3d = TensorMap( + 'aorta_diastole_slice0_3d', Interpretation.CONTINUOUS, shape=(192, 256, 1), loss='logcosh', normalization=ZeroMeanStd1(), tensor_from_file=_slice_tensor('ukb_cardiac_mri/cine_segmented_ao_dist/instance_0', 0), ) cine_segmented_lvot_slice0_3d = TensorMap( @@ -1179,6 +1180,10 @@ def _pad_crop_tensor(tm, hd5, dependents={}): channel_map=MRI_SAX_SEGMENTED_CHANNEL_MAP, ) +segmented_aorta_diastole = TensorMap( + 'segmented_aorta_diastole', Interpretation.CATEGORICAL, shape=(192, 256, len(MRI_AO_SEGMENTED_CHANNEL_MAP)), + tensor_from_file=_segmented_dicom_slices('cine_segmented_ao_dist_annotated_'), channel_map=MRI_AO_SEGMENTED_CHANNEL_MAP, +) cine_segmented_ao_dist = TensorMap( 'cine_segmented_ao_dist', Interpretation.CATEGORICAL, shape=(160, 192, 100, len(MRI_AO_SEGMENTED_CHANNEL_MAP)), tensor_from_file=_segmented_dicom_slices('cine_segmented_ao_dist_annotated_'), channel_map=MRI_AO_SEGMENTED_CHANNEL_MAP, @@ -1251,14 +1256,30 @@ def sax_tensor_from_file(tm, hd5, dependents={}): [1.0, 40.0, 40.0], 'sax_all_diastole_segmented', ), ) +sax_all_diastole_192_segmented_weighted = TensorMap( + 'sax_all_diastole_segmented', Interpretation.CATEGORICAL, shape=(192, 192, 13, 3), + channel_map=MRI_SEGMENTED_CHANNEL_MAP, + loss=weighted_crossentropy( + [1.0, 40.0, 40.0], 'sax_all_diastole_segmented', + ), +) sax_all_diastole = TensorMap( 'sax_all_diastole', shape=(256, 256, 13, 1), tensor_from_file=sax_tensor('diastole'), - dependent_map=sax_all_diastole_segmented, + path_prefix='ukb_cardiac_mri', ) sax_all_diastole_weighted = TensorMap( 'sax_all_diastole', shape=(256, 256, 13, 1), tensor_from_file=sax_tensor('diastole'), - dependent_map=sax_all_diastole_segmented_weighted, + dependent_map=sax_all_diastole_segmented_weighted, path_prefix='ukb_cardiac_mri', +) + +sax_all_diastole_192 = TensorMap( + 'sax_all_diastole', shape=(192, 192, 13, 1), tensor_from_file=sax_tensor('diastole'), + dependent_map=sax_all_diastole_segmented, path_prefix='ukb_cardiac_mri', +) +sax_all_diastole_192_weighted = TensorMap( + 'sax_all_diastole', shape=(192, 192, 13, 1), tensor_from_file=sax_tensor('diastole'), + dependent_map=sax_all_diastole_segmented_weighted, path_prefix='ukb_cardiac_mri', ) sax_all_systole_segmented = TensorMap( @@ -1271,6 +1292,7 @@ def sax_tensor_from_file(tm, hd5, dependents={}): loss=weighted_crossentropy([1.0, 40.0, 40.0], 'sax_all_systole_segmented'), ) + sax_all_systole = TensorMap( 'sax_all_systole', shape=(256, 256, 13, 1), tensor_from_file=sax_tensor('systole'), dependent_map=sax_all_systole_segmented, @@ -1351,6 +1373,14 @@ def _slice_tensor_from_file(tm, hd5, dependents={}): 'aorta_slice_nekoui', shape=(200, 240, 1), normalization=ZeroMeanStd1(), tensor_from_file=_slice_tensor_with_segmentation('cine_segmented_ao_dist/instance_0', 'cine_segmented_ao_dist_nekoui_annotated_'), ) +lvot_slice_jamesp = TensorMap( + 'lvot_slice_jamesp', shape=(200, 240, 1), normalization=ZeroMeanStd1(), + tensor_from_file=_slice_tensor_with_segmentation('cine_segmented_lvot/instance_0', 'cine_segmented_lvot_jamesp_annotated_'), +) +lvot_slice_nekoui = TensorMap( + 'lvot_slice_nekoui', shape=(200, 240, 1), normalization=ZeroMeanStd1(), + tensor_from_file=_slice_tensor_with_segmentation('cine_segmented_lvot/instance_0', 'cine_segmented_lvot_nekoui_annotated_'), +) lax_2ch_slice_jamesp = TensorMap( 'lax_2ch_slice_jamesp', shape=(192, 160, 1), normalization=ZeroMeanStd1(), tensor_from_file=_slice_tensor_with_segmentation('cine_segmented_lax_2ch/instance_0', 'cine_segmented_lax_2ch_jamesp_annotated_'), @@ -1387,6 +1417,14 @@ def _segmented_dicom_tensor_from_file(tm, hd5, dependents={}): 'cine_segmented_ao_dist', Interpretation.CATEGORICAL, shape=(200, 240, len(MRI_AO_SEGMENTED_CHANNEL_MAP)), tensor_from_file=_segmented_dicom_slice('cine_segmented_ao_dist_nekoui_annotated_'), channel_map=MRI_AO_SEGMENTED_CHANNEL_MAP, ) +cine_segmented_lvot_jamesp = TensorMap( + 'cine_segmented_lvot', Interpretation.CATEGORICAL, shape=(200, 240, len(MRI_LVOT_SEGMENTED_CHANNEL_MAP)), + tensor_from_file=_segmented_dicom_slice('cine_segmented_lvot_jamesp_annotated_'), channel_map=MRI_LVOT_SEGMENTED_CHANNEL_MAP, +) +cine_segmented_lvot_nekoui = TensorMap( + 'cine_segmented_lvot', Interpretation.CATEGORICAL, shape=(200, 240, len(MRI_LVOT_SEGMENTED_CHANNEL_MAP)), + tensor_from_file=_segmented_dicom_slice('cine_segmented_lvot_nekoui_annotated_'), channel_map=MRI_LVOT_SEGMENTED_CHANNEL_MAP, +) cine_segmented_lax_2ch_jamesp = TensorMap( 'cine_segmented_lax_2ch_slice', Interpretation.CATEGORICAL, shape=(192, 160, len(MRI_LAX_2CH_SEGMENTED_CHANNEL_MAP)), tensor_from_file=_segmented_dicom_slice('cine_segmented_lax_2ch_jamesp_annotated_'), channel_map=MRI_LAX_2CH_SEGMENTED_CHANNEL_MAP, diff --git a/ml4h/test_utils.py b/ml4h/test_utils.py index e8d53c7fe..d729f1c23 100644 --- a/ml4h/test_utils.py +++ b/ml4h/test_utils.py @@ -35,6 +35,7 @@ ), ] +TENSOR_MAP_PAIRS = [[(CONTINUOUS_TMAPS[2], CONTINUOUS_TMAPS[3])], [(CONTINUOUS_TMAPS[2], CONTINUOUS_TMAPS[3]), (CONTINUOUS_TMAPS[2], CATEGORICAL_TMAPS[4])]] TMAPS_UP_TO_4D = CONTINUOUS_TMAPS[:-1] + CATEGORICAL_TMAPS[:-1] TMAPS_5D = CONTINUOUS_TMAPS[-1:] + CATEGORICAL_TMAPS[-1:] MULTIMODAL_UP_TO_4D = [list(x) for x in product(CONTINUOUS_TMAPS[:-1], CATEGORICAL_TMAPS[:-1])] diff --git a/scripts/jupyter.sh b/scripts/jupyter.sh index 32edfb1f3..a83fd582d 100755 --- a/scripts/jupyter.sh +++ b/scripts/jupyter.sh @@ -54,7 +54,7 @@ while getopts ":ip:ch" opt ; do ;; c) DOCKER_IMAGE=${DOCKER_IMAGE_NO_GPU} - GPU_DEVICE="" + GPU_DEVICE="" ;; :) echo "ERROR: Option -${OPTARG} requires an argument." 1>&2 @@ -99,7 +99,6 @@ ${DOCKER_COMMAND} run -it \ ${GPU_DEVICE} \ --rm \ --ipc=host \ ---hostname=$(hostname) \ -v /home/${USER}/:/home/${USER}/ \ -v /mnt/:/mnt/ \ -p 0.0.0.0:${PORT}:${PORT} \ diff --git a/scripts/tf.sh b/scripts/tf.sh index 779a0daca..d3adb6b10 100755 --- a/scripts/tf.sh +++ b/scripts/tf.sh @@ -35,19 +35,7 @@ done export GROUP_NAMES GROUP_IDS # Create string to be called in Docker's bash shell via eval; -# this creates a user, adds groups, adds user to groups, then calls the Python script -CALL_DOCKER_AS_USER=" - apt-get -y install sudo; - useradd -u $(id -u) ${USER}; - GROUP_NAMES_ARR=( \${GROUP_NAMES} ); - GROUP_IDS_ARR=( \${GROUP_IDS} ); - for (( i=0; i<\${#GROUP_NAMES_ARR[@]}; ++i )); do - echo \"Creating group\" \${GROUP_NAMES_ARR[i]} \"with gid\" \${GROUP_IDS_ARR[i]}; - groupadd -f -g \${GROUP_IDS_ARR[i]} \${GROUP_NAMES_ARR[i]}; - echo \"Adding user ${USER} to group\" \${GROUP_NAMES_ARR[i]} - usermod -aG \${GROUP_NAMES_ARR[i]} ${USER} - done; - sudo -u ${USER}" +CALL_DOCKER_AS_USER= ################### HELP TEXT ############################################ @@ -85,7 +73,7 @@ USAGE_MESSAGE ################### OPTION PARSING ####################################### -while getopts ":i:d:m:ctjrhT" opt ; do +while getopts ":i:d:m:ctjuhT" opt ; do case ${opt} in h) usage @@ -113,8 +101,19 @@ while getopts ":i:d:m:ctjrhT" opt ; do mkdir -p /home/${USER}/jupyter/root/ mkdir -p /mnt/ml4cvd/projects/${USER}/projects/jupyter/auto/ ;; - r) # Output owned by root - CALL_DOCKER_AS_USER="" + u) # this creates a user, adds groups, adds user to groups, then calls the Python script + CALL_DOCKER_AS_USER=" + apt-get -y install sudo; + useradd -u $(id -u) ${USER}; + GROUP_NAMES_ARR=( \${GROUP_NAMES} ); + GROUP_IDS_ARR=( \${GROUP_IDS} ); + for (( i=0; i<\${#GROUP_NAMES_ARR[@]}; ++i )); do + echo \"Creating group\" \${GROUP_NAMES_ARR[i]} \"with gid\" \${GROUP_IDS_ARR[i]}; + groupadd -f -g \${GROUP_IDS_ARR[i]} \${GROUP_NAMES_ARR[i]}; + echo \"Adding user ${USER} to group\" \${GROUP_NAMES_ARR[i]} + usermod -aG \${GROUP_NAMES_ARR[i]} ${USER} + done; + sudo -u ${USER}" ;; T) PYTHON_COMMAND=${TEST_COMMAND} diff --git a/tests/conftest.py b/tests/conftest.py index b553cf848..1b5cfde2e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,8 +7,9 @@ from ml4h.test_utils import build_hdf5s -def pytest_configure(): +def pytest_configure(config): pytest.N_TENSORS = 50 + config.addinivalue_line("markers", "slow: mark tests as slow") #@mock.patch.dict(TMAPS, MOCK_TMAPS) diff --git a/tests/test_models.py b/tests/test_models.py index 61df950cb..5c19868b1 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -7,7 +7,8 @@ from typing import List, Optional, Dict, Tuple, Iterator from ml4h.TensorMap import TensorMap -from ml4h.models import make_multimodal_multitask_model, parent_sort, BottleneckType, ACTIVATION_FUNCTIONS, MODEL_EXT, train_model_from_generators, check_no_bottleneck +from ml4h.models import make_multimodal_multitask_model, parent_sort, BottleneckType, ACTIVATION_FUNCTIONS, MODEL_EXT, train_model_from_generators, \ + check_no_bottleneck, make_paired_autoencoder_model from ml4h.test_utils import TMAPS_UP_TO_4D, MULTIMODAL_UP_TO_4D, CATEGORICAL_TMAPS, CONTINUOUS_TMAPS, SEGMENT_IN, SEGMENT_OUT, PARENT_TMAPS, CYCLE_PARENTS from ml4h.test_utils import LANGUAGE_TMAP_1HOT_WINDOW, LANGUAGE_TMAP_1HOT_SOFTMAX @@ -18,14 +19,14 @@ 'dense_layers': [4, 2], 'dense_blocks': [5, 3], 'block_size': 3, - 'conv_width': 3, 'learning_rate': 1e-3, 'optimizer': 'adam', 'conv_type': 'conv', 'conv_layers': [6, 5, 3], - 'conv_x': [3], - 'conv_y': [3], - 'conv_z': [2], + 'conv_width': [71]*5, + 'conv_x': [3]*5, + 'conv_y': [3]*5, + 'conv_z': [2]*5, 'padding': 'same', 'max_pools': [], 'pool_type': 'max', @@ -39,6 +40,16 @@ 'dense_regularize_rate': .1, 'dense_normalize': 'batch_norm', 'bottleneck_type': BottleneckType.FlattenRestructure, + 'pair_loss': 'cosine', + 'training_steps': 12, + 'learning_rate': 0.00001, + 'epochs': 6, + 'optimizer': 'adam', + 'learning_rate_schedule': None, + 'model_layers': None, + 'model_file': None, + 'hidden_layer': 'embed', + 'u_connect': defaultdict(dict), } @@ -54,19 +65,20 @@ def make_training_data(input_tmaps: List[TensorMap], output_tmaps: List[TensorMa ), ]) -def assert_model_trains(input_tmaps: List[TensorMap], output_tmaps: List[TensorMap], m: Optional[tf.keras.Model] = None): +def assert_model_trains(input_tmaps: List[TensorMap], output_tmaps: List[TensorMap], m: Optional[tf.keras.Model] = None, skip_shape_check: bool = False): if m is None: m = make_multimodal_multitask_model( input_tmaps, output_tmaps, **DEFAULT_PARAMS, ) - for tmap, tensor in zip(input_tmaps, m.inputs): - assert tensor.shape[1:] == tmap.shape - assert tensor.shape[1:] == tmap.shape - for tmap, tensor in zip(parent_sort(output_tmaps), m.outputs): - assert tensor.shape[1:] == tmap.shape - assert tensor.shape[1:] == tmap.shape + if not skip_shape_check: + for tmap, tensor in zip(input_tmaps, m.inputs): + assert tensor.shape[1:] == tmap.shape + assert tensor.shape[1:] == tmap.shape + for tmap, tensor in zip(parent_sort(output_tmaps), m.outputs): + assert tensor.shape[1:] == tmap.shape + assert tensor.shape[1:] == tmap.shape data = make_training_data(input_tmaps, output_tmaps) history = m.fit(data, steps_per_epoch=2, epochs=2, validation_data=data, validation_steps=2) for tmap in output_tmaps: @@ -294,8 +306,8 @@ def test_parents(self, output_tmaps): def test_language_models(self, input_output_tmaps, tmpdir): params = DEFAULT_PARAMS.copy() m = make_multimodal_multitask_model( - input_output_tmaps[0], - input_output_tmaps[1], + tensor_maps_in=input_output_tmaps[0], + tensor_maps_out=input_output_tmaps[1], **params ) assert_model_trains(input_output_tmaps[0], input_output_tmaps[1], m) @@ -309,6 +321,72 @@ def test_language_models(self, input_output_tmaps, tmpdir): **DEFAULT_PARAMS, ) + @pytest.mark.parametrize( + 'pairs', + [ + [(CONTINUOUS_TMAPS[2], CONTINUOUS_TMAPS[1])], + [(CATEGORICAL_TMAPS[2], CATEGORICAL_TMAPS[1])], + [(CONTINUOUS_TMAPS[2], CONTINUOUS_TMAPS[1]), (CONTINUOUS_TMAPS[2], CATEGORICAL_TMAPS[3])] + ], + ) + def test_paired_models(self, pairs, tmpdir): + params = DEFAULT_PARAMS.copy() + pair_list = list(set([p[0] for p in pairs] + [p[1] for p in pairs])) + params['u_connect'] = {tm: [] for tm in pair_list} + m, encoders, decoders = make_paired_autoencoder_model( + pairs=pairs, + tensor_maps_in=pair_list, + tensor_maps_out=pair_list, + **params + ) + assert_model_trains(pair_list, pair_list, m, skip_shape_check=True) + m.save(os.path.join(tmpdir, 'paired_ae.h5')) + path = os.path.join(tmpdir, f'm{MODEL_EXT}') + m.save(path) + make_paired_autoencoder_model( + pairs=pairs, + tensor_maps_in=pair_list, + tensor_maps_out=pair_list, + **params + ) + + @pytest.mark.parametrize( + 'pairs', + [ + [(CONTINUOUS_TMAPS[2], CONTINUOUS_TMAPS[1])], + [(CATEGORICAL_TMAPS[2], CATEGORICAL_TMAPS[1])], + [(CONTINUOUS_TMAPS[2], CONTINUOUS_TMAPS[1]), (CONTINUOUS_TMAPS[2], CATEGORICAL_TMAPS[3])] + ], + ) + @pytest.mark.parametrize( + 'output_tmaps', + [ + [CONTINUOUS_TMAPS[0]], + [CATEGORICAL_TMAPS[0]], + [CONTINUOUS_TMAPS[0], CATEGORICAL_TMAPS[0]], + ], + ) + def test_semi_supervised_paired_models(self, pairs, output_tmaps, tmpdir): + params = DEFAULT_PARAMS.copy() + pair_list = list(set([p[0] for p in pairs] + [p[1] for p in pairs])) + params['u_connect'] = {tm: [] for tm in pair_list} + m, encoders, decoders = make_paired_autoencoder_model( + pairs=pairs, + tensor_maps_in=pair_list, + tensor_maps_out=pair_list+output_tmaps, + **params + ) + assert_model_trains(pair_list, pair_list+output_tmaps, m, skip_shape_check=True) + m.save(os.path.join(tmpdir, 'paired_ae.h5')) + path = os.path.join(tmpdir, f'm{MODEL_EXT}') + m.save(path) + make_paired_autoencoder_model( + pairs=pairs, + tensor_maps_in=pair_list, + tensor_maps_out=pair_list+output_tmaps, + **params + ) + @pytest.mark.parametrize( 'tmaps', [_rotate(PARENT_TMAPS, i) for i in range(len(PARENT_TMAPS))], diff --git a/tests/test_recipes.py b/tests/test_recipes.py index 468dc44c6..df8f4a1e3 100644 --- a/tests/test_recipes.py +++ b/tests/test_recipes.py @@ -3,7 +3,7 @@ import pandas as pd import numpy as np -from ml4h.recipes import inference_file_name, hidden_inference_file_name +from ml4h.recipes import inference_file_name, _hidden_file_name from ml4h.recipes import train_multimodal_multitask, compare_multimodal_multitask_models from ml4h.recipes import infer_multimodal_multitask, infer_hidden_layer_multimodal_multitask from ml4h.recipes import compare_multimodal_scalar_task_models, _find_learning_rate @@ -42,7 +42,7 @@ def test_infer_genetics(self, default_arguments): def test_infer_hidden(self, default_arguments): infer_hidden_layer_multimodal_multitask(default_arguments) - tsv = hidden_inference_file_name(default_arguments.output_folder, default_arguments.id) + tsv = _hidden_file_name(default_arguments.output_folder, default_arguments.id) inferred = pd.read_csv(tsv, sep='\t') assert len(set(inferred['sample_id'])) == pytest.N_TENSORS @@ -50,7 +50,7 @@ def test_infer_hidden_genetics(self, default_arguments): default_arguments.tsv_style = 'genetics' infer_hidden_layer_multimodal_multitask(default_arguments) default_arguments.tsv_style = 'standard' - tsv = hidden_inference_file_name(default_arguments.output_folder, default_arguments.id) + tsv = _hidden_file_name(default_arguments.output_folder, default_arguments.id) inferred = pd.read_csv(tsv, sep='\t') assert len(set(inferred['FID'])) == pytest.N_TENSORS