Paired multimodal autoencoders (with tests!) (#393)
Adds recipe `train_paired` for multimodal models with losses to encourage the same embeddings of different modalities.
lucidtronix authored Nov 13, 2020
1 parent 0889e86 commit 198ee81
Showing 14 changed files with 642 additions and 116 deletions.
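
The core idea, before the diff: each modality gets its own encoder, and a pair loss penalizes distance between the embeddings of paired inputs from the same sample. Below is a minimal numpy sketch of the idea behind the new --pair_loss flag; it is not the recipe's actual implementation (most of the 14 changed files are not reproduced on this page):

import numpy as np

def pair_loss(embed_a, embed_b, metric='cosine'):
    # Penalize distance between two modality embeddings of the same sample.
    if metric == 'cosine':
        cos_sim = np.dot(embed_a, embed_b) / (np.linalg.norm(embed_a) * np.linalg.norm(embed_b))
        return 1.0 - cos_sim  # 0 when the embeddings are perfectly aligned
    if metric == 'euclid':
        return np.linalg.norm(embed_a - embed_b)
    raise ValueError(f'Unknown pair loss metric: {metric}')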
23 changes: 21 additions & 2 deletions ml4h/arguments.py
@@ -20,7 +20,7 @@
import importlib
import numpy as np
import multiprocessing
from typing import Set, Dict, List, Optional
from typing import Set, Dict, List, Optional, Tuple
from collections import defaultdict

from ml4h.logger import load_config
@@ -148,6 +148,7 @@ def parse_args():
    parser.add_argument('--dense_normalize', default=None, choices=list(NORMALIZATION_CLASSES), help='Type of normalization layer for dense layers.')
    parser.add_argument('--activation', default='relu', help='Activation function for hidden units in dense layers of neural nets.')
    parser.add_argument('--conv_layers', nargs='*', default=[32], type=int, help='List of number of kernels in convolutional layers.')
    parser.add_argument('--conv_width', default=[71], nargs='*', type=int, help='X dimension of convolutional kernel for 1D models. Filter sizes are specified per layer given by conv_layers and per block given by dense_blocks. Filter sizes are repeated if fewer are given than the number of layers/blocks.')
    parser.add_argument('--conv_x', default=[3], nargs='*', type=int, help='X dimension of convolutional kernel. Filter sizes are specified per layer given by conv_layers and per block given by dense_blocks. Filter sizes are repeated if fewer are given than the number of layers/blocks.')
    parser.add_argument('--conv_y', default=[3], nargs='*', type=int, help='Y dimension of convolutional kernel. Filter sizes are specified per layer given by conv_layers and per block given by dense_blocks. Filter sizes are repeated if fewer are given than the number of layers/blocks.')
    parser.add_argument('--conv_z', default=[2], nargs='*', type=int, help='Z dimension of convolutional kernel. Filter sizes are specified per layer given by conv_layers and per block given by dense_blocks. Filter sizes are repeated if fewer are given than the number of layers/blocks.')
@@ -168,7 +169,13 @@ def parse_args():
        '--u_connect', nargs=2, action='append',
        help='U-Net connect first TensorMap to second TensorMap. They must be the same shape except for number of channels. Can be provided multiple times.',
    )
    parser.add_argument('--aligned_dimension', default=16, type=int, help='Dimensionality of the aligned embedded space for multimodal alignment models.')
    parser.add_argument(
        '--pairs', nargs=2, action='append',
        help='TensorMap pairs for the paired autoencoder. The pair_loss metric encourages similar embeddings for each pair of input TensorMaps. Can be provided multiple times.',
    )
    parser.add_argument('--pair_loss', default='cosine', help='Distance metric between paired embeddings.', choices=['euclid', 'cosine'])
    parser.add_argument('--pair_loss_weight', type=float, default=1.0, help='Weight on the pair loss term relative to other losses.')
    parser.add_argument('--multimodal_merge', default='average', choices=['average', 'concatenate'], help='How to merge modality-specific encodings.')
    parser.add_argument(
        '--max_parameters', default=9000000, type=int,
        help='Maximum number of trainable parameters in a model during hyperparameter optimization.',
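
A hypothetical invocation of the new train_paired recipe using these flags (the entry point and TensorMap names here are assumptions for illustration, not taken from this diff):

python ml4h/recipes.py --mode train_paired \
    --pairs ecg.ecg_rest mri.lax_4ch \
    --pair_loss cosine \
    --pair_loss_weight 1.0 \
    --multimodal_merge average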
@@ -219,6 +226,9 @@ def parse_args():
    parser.add_argument('--anneal_rate', default=0., type=float, help='Annealing rate in epochs of loss terms during training.')
    parser.add_argument('--anneal_shift', default=0., type=float, help='Annealing offset in epochs of loss terms during training.')
    parser.add_argument('--anneal_max', default=2.0, type=float, help='Annealing maximum value.')
    parser.add_argument(
        '--save_last_model', default=False, action='store_true',
        help='If true, saves the model weights from the last training epoch; otherwise the model with the best validation loss is saved.',
    )

# Run specific and debugging arguments
    parser.add_argument('--id', default='no_id', help='Identifier for this run, a user-defined string to keep experiments organized.')
@@ -392,6 +402,14 @@ def _process_u_connect_args(u_connect: Optional[List[List]], tensormap_prefix) -
    return new_u_connect


def _process_pair_args(pairs: Optional[List[List]], tensormap_prefix) -> List[Tuple[TensorMap, TensorMap]]:
    pairs = pairs or []
    new_pairs = []
    for pair in pairs:
        new_pairs.append((tensormap_lookup(pair[0], tensormap_prefix), tensormap_lookup(pair[1], tensormap_prefix)))
    return new_pairs
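
For reference, with action='append' and nargs=2, argparse delivers --pairs as a list of two-element name lists; _process_pair_args resolves each name through tensormap_lookup into a tuple of TensorMaps. A sketch of the data flow (TensorMap names are placeholders):

# --pairs ecg.ecg_rest mri.lax_4ch --pairs ecg.ecg_rest ecg.ecg_rest
# argparse yields:            [['ecg.ecg_rest', 'mri.lax_4ch'], ['ecg.ecg_rest', 'ecg.ecg_rest']]
# _process_pair_args returns: [(TensorMap(ecg_rest), TensorMap(lax_4ch)), (TensorMap(ecg_rest), TensorMap(ecg_rest))]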


def _process_args(args):
    now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
    args_file = os.path.join(args.output_folder, args.id, 'arguments_' + now_string + '.txt')
@@ -404,6 +422,7 @@ def _process_args(args):
            f.write(k + ' = ' + str(v) + '\n')
    load_config(args.logging_level, os.path.join(args.output_folder, args.id), 'log_' + now_string, args.min_sample_id)
    args.u_connect = _process_u_connect_args(args.u_connect, args.tensormap_prefix)
    args.pairs = _process_pair_args(args.pairs, args.tensormap_prefix)

    args.tensor_maps_in = []
    args.tensor_maps_out = []
61 changes: 61 additions & 0 deletions ml4h/explorations.py
@@ -79,6 +79,21 @@ def predictions_to_pngs(
                    ax.add_patch(matplotlib.patches.Rectangle(y_corner, y_width, y_height, linewidth=1, edgecolor='y', facecolor='none'))
                    logging.info(f"True BBox: {corner}, {width}, {height} Predicted BBox: {y_corner}, {y_width}, {y_height} Vmin {vmin} Vmax {vmax}")
                plt.savefig(f"{folder}{sample_id}_bbox_batch_{i:02d}{IMAGE_EXT}")
        elif tm.axes() == 2:
            fig = plt.figure(figsize=(SUBPLOT_SIZE, SUBPLOT_SIZE * 3))
            for i in range(y.shape[0]):
                sample_id = os.path.basename(paths[i]).replace(TENSOR_EXT, '')
                title = f'{tm.name}_{sample_id}_reconstruction'
                for j in range(tm.shape[1]):
                    plt.subplot(tm.shape[1], 1, j + 1)
                    plt.plot(labels[tm.output_name()][i, :, j], c='k', linestyle='--', label='original')
                    plt.plot(y[i, :, j], c='b', label='reconstruction')
                    if j == 0:
                        plt.title(title)
                        plt.legend()
                plt.tight_layout()
                plt.savefig(os.path.join(folder, title + IMAGE_EXT))
                plt.clf()
        elif len(tm.shape) == 3:
            for i in range(y.shape[0]):
                sample_id = os.path.basename(paths[i]).replace(TENSOR_EXT, '')
@@ -1332,3 +1347,49 @@ def _get_occurrences(df, order, start, end):
    fpath = os.path.join(args.output_folder, args.id, 'summary_cohort_counts.csv')
    pd.DataFrame.from_dict(cohort_counts, orient='index', columns=['count']).rename_axis('description').to_csv(fpath)
    logging.info(f'Saved cohort counts to {fpath}')


def directions_in_latent_space(stratify_column, stratify_thresh, split_column, split_thresh, latent_cols, latent_df):
    hit = latent_df.loc[latent_df[stratify_column] >= stratify_thresh][latent_cols].to_numpy()
    miss = latent_df.loc[latent_df[stratify_column] < stratify_thresh][latent_cols].to_numpy()
    miss_mean_vector = np.mean(miss, axis=0)
    hit_mean_vector = np.mean(hit, axis=0)
    strat_vector = hit_mean_vector - miss_mean_vector

    hit1 = latent_df.loc[(latent_df[stratify_column] >= stratify_thresh)
                         & (latent_df[split_column] >= split_thresh)][latent_cols].to_numpy()
    miss1 = latent_df.loc[(latent_df[stratify_column] < stratify_thresh)
                          & (latent_df[split_column] >= split_thresh)][latent_cols].to_numpy()
    hit2 = latent_df.loc[(latent_df[stratify_column] >= stratify_thresh)
                         & (latent_df[split_column] < split_thresh)][latent_cols].to_numpy()
    miss2 = latent_df.loc[(latent_df[stratify_column] < stratify_thresh)
                          & (latent_df[split_column] < split_thresh)][latent_cols].to_numpy()
    miss_mean_vector1 = np.mean(miss1, axis=0)
    hit_mean_vector1 = np.mean(hit1, axis=0)
    angle1 = angle_between(miss_mean_vector1, hit_mean_vector1)
    miss_mean_vector2 = np.mean(miss2, axis=0)
    hit_mean_vector2 = np.mean(hit2, axis=0)
    angle2 = angle_between(miss_mean_vector2, hit_mean_vector2)
    h1_vector = hit_mean_vector1 - miss_mean_vector1
    h2_vector = hit_mean_vector2 - miss_mean_vector2
    angle3 = angle_between(h1_vector, h2_vector)
    print(
        f'\nBetween {stratify_column} and split: {split_column}\n'
        f'Angle between h1 and m1: {angle1:.2f}, h2 and m2: {angle2:.2f}, h1-m1 and h2-m2: {angle3:.2f} degrees.\n'
        f'Stratify threshold: {stratify_thresh}, split threshold: {split_thresh}\n'
        f'miss_mean_vector1 shape: {miss_mean_vector1.shape}, hit_mean_vector2 shape: {hit_mean_vector2.shape}\n'
        f'Hit1 shape: {hit1.shape}, miss1: {miss1.shape}\n'
        f'Hit2 shape: {hit2.shape}, miss2: {miss2.shape}\n',
    )

    return hit_mean_vector1, miss_mean_vector1, hit_mean_vector2, miss_mean_vector2
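
Note that angle_between is called above but is not defined in this diff. A minimal numpy sketch consistent with its use here (the angle in degrees between two vectors) could be:

def angle_between(v1, v2):
    # Cosine of the angle from the dot product, clipped for numerical safety.
    cos_angle = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    return np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0)))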


def latent_space_dataframe(infer_hidden_tsv, explore_csv):
    df = pd.read_csv(explore_csv)
    df['fpath'] = pd.to_numeric(df['fpath'], errors='coerce')
    df2 = pd.read_csv(infer_hidden_tsv, sep='\t')
    df2['sample_id'] = pd.to_numeric(df2['sample_id'], errors='coerce')
    latent_df = pd.merge(df, df2, left_on='fpath', right_on='sample_id', how='inner')
    latent_df.info()
    return latent_df
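
For context, a hypothetical call (file names are placeholders): the function inner-joins the explore CSV's fpath column with the inference TSV's sample_id column, so both must parse as numeric sample identifiers.

latent_df = latent_space_dataframe('hidden_inference.tsv', 'explore_output.csv')
latent_cols = [c for c in latent_df.columns if c.startswith('latent')]  # assumes latent columns share a common prefix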


