From dc444415d997df69210bb991c77fa22c6410a99f Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Tue, 29 Sep 2020 17:14:20 -0400 Subject: [PATCH 01/21] paired --- CONTRIBUTING.md | 192 ---------- README.md | 32 +- docker/terra_image/Dockerfile | 8 +- docker/terra_image/README.md | 8 +- docker/vm_boot_images/Dockerfile | 5 +- .../config/tensorflow-requirements.txt | 2 - ml4h/arguments.py | 19 +- ml4h/explorations.py | 61 ++++ ml4h/models.py | 133 ++++++- ml4h/plots.py | 194 ++++++++-- ml4h/recipes.py | 77 +++- ml4h/tensorize/tensor_writer_ukbb.py | 181 ++++------ ml4h/tensormap/mgb/dynamic.py | 15 +- ml4h/tensormap/ukb/by_script.py | 19 +- ml4h/tensormap/ukb/genetics.py | 10 +- ml4h/tensormap/ukb/mri.py | 52 ++- ml4h/test_utils.py | 1 + .../visualization_tools/annotation_storage.py | 36 +- ml4h/visualization_tools/annotations.py | 55 ++- .../batch_image_annotations.py | 236 ------------ .../dicom_interactive_plots.py | 74 ++-- ml4h/visualization_tools/dicom_plots.py | 122 +++---- .../ecg_interactive_plots.py | 22 +- ml4h/visualization_tools/ecg_reshape.py | 58 ++- ml4h/visualization_tools/ecg_static_plots.py | 11 +- ml4h/visualization_tools/facets.py | 13 +- ml4h/visualization_tools/hd5_mri_plots.py | 181 +++++----- .../paired_multimodal_autoencoder.ipynb | 113 +++++- .../paired_multimodal_segmenter_mri_ecg.ipynb | 210 +++++++---- notebooks/autoencoders/vae_mri_slice.ipynb | 2 +- notebooks/mnist_demo.ipynb | 90 ++++- .../mri/mri_cardiac_long_axis_sketch.ipynb | 8 +- .../mri/mri_cardiac_short_axis_sketch.ipynb | 14 +- .../identify_a_sample_to_review.ipynb | 6 +- .../review_results/image_annotations.ipynb | 268 -------------- .../review_results/review_one_sample.ipynb | 8 +- ...handling_for_notebook_visualizations.ipynb | 40 ++- ...ntify_a_sample_to_review_interactive.ipynb | 4 +- .../image_annotations_demo.ipynb | 259 -------------- .../review_one_sample_interactive.ipynb | 4 +- pylintrc | 337 ------------------ scripts/jupyter.sh | 3 +- tests/test_models.py | 69 
+++- tests/test_recipes.py | 6 +- 44 files changed, 1289 insertions(+), 1969 deletions(-) delete mode 100644 CONTRIBUTING.md delete mode 100644 ml4h/visualization_tools/batch_image_annotations.py delete mode 100644 notebooks/review_results/image_annotations.ipynb delete mode 100644 notebooks/terra_featured_workspace/image_annotations_demo.ipynb delete mode 100644 pylintrc diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md deleted file mode 100644 index 2e2006c07..000000000 --- a/CONTRIBUTING.md +++ /dev/null @@ -1,192 +0,0 @@ -# Contributing - -1. Before making a substantial pull request, consider first [filing an issue](https://github.com/broadinstitute/ml/issues) describing the feature addition or change you wish to make. -1. [Get setup](#setup-for-code-contributions) -1. [Follow the coding style](#python-coding-style) -1. [Test your code](#testing) -1. Send a [pull request](https://github.com/broadinstitute/ml/pulls) - -## Setup for code contributions - -### Get setup for GitHub - -Small typos in code or documentation may be edited directly using the GitHub web interface. Otherwise: - -1. If you are new to GitHub, don't start here. Instead, work through a GitHub tutorial such as https://guides.github.com/activities/hello-world/. -1. Create a fork of https://github.com/broadinstitute/ml -1. Clone your fork. -1. Work from a feature branch. See the [Appendix](#appendix) for detailed `git` commands. - -### Install precommit - -[`pre-commit`](https://pre-commit.com/) is a framework for managing and maintaining multi-language pre-commit hooks. - -``` -# Install pre-commit -pip3 install pre-commit -# Install the git hook scripts by running this within the git clone directory -cd ${HOME}/ml -pre-commit install -``` - -See [.pre-commit-config.yaml](https://github.com/broadinstitute/ml/blob/master/.pre-commit-config.yaml) for the currently configured pre-commit hooks for ml4cvd. - -### Install git-secrets - -```git-secrets``` helps us avoid committing secrets (e.g. 
private keys) and other critical data (e.g. PHI) to our -repositories. ```git-secrets``` can be obtained via [github](https://github.com/awslabs/git-secrets) or on MacOS can be -installed with Homebrew by running ```brew install git-secrets```. - -To add hooks to all repositories that you initialize or clone in the future: - -```git secrets --install --global``` - -To add hooks to all local repositories: - -``` -git secrets --install ~/.git-templates/git-secrets -git config --global init.templateDir ~/.git-templates/git-secrets -``` - -We maintain our own custom "provider" to cover any private keys or other critical data that we would like to avoid -committing to our repositories. Feel free to add ```egrep```-compatible regular expressions to -```git_secrets_provider_ml4cvd.txt``` to match types of critical data that are not currently covered by the patterns in that -file. To register the patterns in this file with ```git-secrets```: - -``` -git secrets --add-provider -- cat ${HOME}/ml/git_secrets_provider_ml4cvd.txt -``` - -### Install pylint - -[`pylint`](https://www.pylint.org/) is a Python static code analysis tool which looks for programming errors, helps enforcing a coding standard, sniffs for code smells and offers simple refactoring suggestions. - -``` -# Install pylint -pip3 install pylint -``` - -See [pylintrc](https://github.com/broadinstitute/ml/blob/master/pylintrc) for the current lint configuration for ml4cvd. - -# Python coding style - -Changes to ml4cvd should conform to [PEP 8 -- Style Guide for Python Code](https://www.python.org/dev/peps/pep-0008/). See also [Google Python Style Guide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md) as another decription of this coding style. - -Use `pylint` to check your Python changes: - -```bash -pylint --rcfile=${HOME}/ml/pylintrc myfile.py -``` - -Any messages returned by `pylint` are intended to be self-explanatory, but that isn't always the case. 
- -* Search for `pylint ` or `pylint ` for more details on the recommended code change to resolve the lint issue. -* Or add comment `# pylint: disable=` to the end of the line of code. - -# Testing - -## Testing of `recipes` - -Unit tests can be run in Docker with -``` -${HOME}/ml/scripts/tf.sh -T ${HOME}/ml/tests -``` -Unit tests can be run locally in a conda environment with -``` -python -m pytest ${HOME}/ml/tests -``` -Some of the unit tests are slow due to creating, saving and loading `tensorflow` models. -To skip those tests to move quickly, run -``` -python -m pytest ${HOME}/ml/tests -m "not slow" -``` -pytest can also run specific tests using `::`. For example - -``` -python -m pytest ${HOME}/ml/tests/test_models.py::TestMakeMultimodalMultitaskModel::test_u_connect_segment -``` - -For more pytest usage information, checkout the [usage guide](https://docs.pytest.org/en/latest/usage.html). - -## Testing of `visualization_tools` - -The code in [ml4cvd/visualization_tools](https://github.com/broadinstitute/ml/tree/master/ml4cvd/visualization_tools) is primarily interactive so we add test cases to notebook [test_error_handling_for_notebook_visualizations.ipynb](https://github.com/broadinstitute/ml/blob/master/notebooks/review_results/test_error_handling_for_notebook_visualizations.ipynb) and visually inspect the output of `Cells -> Run all`. - -# Appendix - -For the ml4cvd GitHub repository, we are doing ‘merge and squash’ of pull requests. So that means your fork does not match upstream after your pull request has been merged. The easiest way to manage this is to always work in a feature branch, instead of checking changes into your fork’s master branch. - - -## How to work on a new feature - -(1) Get the latest version of the upstream repo - -``` -git fetch upstream -``` - -Note: If you get an error saying that upstream is unknown, run the following remote add command and then re-run the fetch command. You only need to do this once per git clone. 
- -``` -git remote add upstream https://github.com/broadinstitute/ml.git -``` - -(2) Make sure your master branch is “even” with upstream. - -``` -git checkout master -git merge --ff-only upstream/master -git push -``` - -Now the master branch of your fork on GitHub should say *"This branch is even with broadinstitute:master."*. - - -(3) Create a feature branch for your change. - -``` -git checkout -b my-feature-branch-name -``` - -Because you created this feature branch from your master branch that was up to date with upstream (step 2), your feature branch is also up to date with upstream. Commit your changes to this branch until you are happy with them. - -(4) Push your changes to GitHub and send a pull request. - -``` -git push --set-upstream origin my-feature-branch-name -``` - -After your pull request is merged, its safe to delete your branch! - -## I accidentally checked a new change to my master branch instead of a feature branch. How to fix this? - -(1) Soft undo your change(s). This leaves the changes in the files on disk but undoes the commit. - -``` -git checkout master -# Moves pointer back to previous HEAD -git reset --soft HEAD@{1} -``` - -Or if you need to move back several commits to the most recent one in common with upstream, you can change ‘1’ to be however many commits back you need to go. - -(2) “stash” your now-unchecked-in changes so that you can get them back later. - -``` -git stash -``` - -(3) Now do the [How to work on a new feature](#how-to-work-on-a-new-feature) step to bring master up to date and create your new feature branch that is “even” with upstream. Here are those commands again: - -``` -git fetch upstream -git merge --ff-only upstream/master -git checkout -b my-feature-branch-name -``` - -(4) “unstash” your changes. - -``` -git stash pop -``` -Now you can proceed with your work! 
diff --git a/README.md b/README.md index 8996bfe39..0335a4885 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # ml4h `ml4h` is a project aimed at using machine learning to model multi-modal cardiovascular time series and imaging data. `ml4h` began as a set of tools to make it easy to work -with the UK Biobank on Google Cloud Platform and has since expanded to include other data sources +with the UK Biobank on the Google Cloud and has since expanded to include other data sources and functionality. @@ -9,7 +9,6 @@ Getting Started * [Setting up your local environment](#setting-up-your-local-environment) * [Setting up a remote VM](#setting-up-a-remote-vm) * Modeling/Data Sources/Tests [(`ml4h/DATA_MODELING_TESTS.md`)](ml4h/DATA_MODELING_TESTS.md) -* [Contributing Code](#contributing-code) Advanced Topics: * Tensorizing Data (going from raw data to arrays suitable for modeling, in `ml4h/tensorize/README.md, TENSORIZE.md` ) @@ -20,7 +19,7 @@ Clone the repo ``` git clone git@github.com:broadinstitute/ml.git ``` -Make sure you have installed the [Google Cloud SDK (gcloud)](https://cloud.google.com/sdk/docs/downloads-interactive). With [Homebrew](https://brew.sh/), you can use +Make sure you have installed the [google cloud tools (gcloud)](https://cloud.google.com/storage/docs/gsutil_install). With [Homebrew](https://brew.sh/), you can use ``` brew cask install google-cloud-sdk ``` @@ -146,6 +145,29 @@ If you get a public key error run: `gcloud compute config-ssh` Now open a browser on your laptop and go to the URL `http://localhost:8888` -## Contributing code -Want to contribute code to this project? Please see [CONTRIBUTING](./CONTRIBUTING.md) for developer setup and other details. +### Installing git-secrets + +```git-secrets``` helps us avoid committing secrets (e.g. private keys) and other critical data (e.g. PHI) to our +repositories. 
```git-secrets``` can be obtained via [github](https://github.com/awslabs/git-secrets) or on MacOS can be +installed with Homebrew by running ```brew install git-secrets```. + +To add hooks to all repositories that you initialize or clone in the future: + +```git secrets --install --global``` + +To add hooks to all local repositories: + +``` +git secrets --install ~/.git-templates/git-secrets +git config --global init.templateDir ~/.git-templates/git-secrets +``` + +We maintain our own custom "provider" to cover any private keys or other critical data that we would like to avoid +committing to our repositories. Feel free to add ```egrep```-compatible regular expressions to +```git_secrets_provider_ml4h.txt``` to match types of critical data that are not currently covered by the patterns in that +file. To register the patterns in this file with ```git-secrets```: + +``` +git secrets --add-provider -- cat ${HOME}/ml/git_secrets_provider_ml4h.txt +``` diff --git a/docker/terra_image/Dockerfile b/docker/terra_image/Dockerfile index a59ecd6ae..721f94500 100644 --- a/docker/terra_image/Dockerfile +++ b/docker/terra_image/Dockerfile @@ -1,4 +1,4 @@ -FROM us.gcr.io/broad-dsp-gcr-public/terra-jupyter-gatk:1.0.6 +FROM us.gcr.io/broad-dsp-gcr-public/terra-jupyter-gatk:1.0.0 # https://github.com/DataBiosphere/terra-docker/blob/master/terra-jupyter-gatk/CHANGELOG.md USER root @@ -19,10 +19,6 @@ RUN pip3 install --user -r $HOME/ml4h_pkg/config/tensorflow-requirements.txt \ # first few rows of the downloaded dataframe of query results. # Pin version due to https://github.com/googleapis/google-cloud-python/issues/9965 && pip3 install --upgrade --user google-cloud-bigquery[pandas]==1.22.0 \ - # Upgrade to a newer version. The one on the base Terra image was a bit too old. - && pip3 install --upgrade --user numpy \ # Configure notebook extensions. 
&& jupyter nbextension install --user --py vega \ - && jupyter nbextension enable --user --py vega \ - && jupyter nbextension install --user --py ipycanvas \ - && jupyter nbextension enable --user --py ipycanvas + && jupyter nbextension enable --user --py vega diff --git a/docker/terra_image/README.md b/docker/terra_image/README.md index 9a81dc74a..71284c0bd 100644 --- a/docker/terra_image/README.md +++ b/docker/terra_image/README.md @@ -2,13 +2,13 @@ To build and push: ``` -mv ml4h ml4hBAK_$(date +"%Y%m%d_%H%M%S") \ +mv ml4cvd ml4cvdBAK_$(date +"%Y%m%d_%H%M%S") \ && mv config configBAK_$(date +"%Y%m%d_%H%M%S") \ - && cp -r ../../ml4h . \ + && cp -r ../../ml4cvd . \ && cp -r ../vm_boot_images/config . \ && gcloud --project uk-biobank-sek-data builds submit \ --timeout 20m \ - --tag gcr.io/uk-biobank-sek-data/ml4h_terra:`date +"%Y%m%d_%H%M%S"` . + --tag gcr.io/uk-biobank-sek-data/ml4cvd_terra:`date +"%Y%m%d_%H%M%S"` . ``` Notes: @@ -20,5 +20,5 @@ available to docker. cd notebooks find . -name "*.ipynb" -type f -print0 | \ xargs -0 perl -i -pe \ - 's/gcr.io\/uk-biobank-sek-data\/ml4h_terra:\d{8}_\d{6}/gcr.io\/uk-biobank-sek-data\/ml4h_terra:20200623_145127/g' + 's/gcr.io\/uk-biobank-sek-data\/ml4cvd_terra:\d{8}_\d{6}/gcr.io\/uk-biobank-sek-data\/ml4cvd_terra:20200623_145127/g' ``` diff --git a/docker/vm_boot_images/Dockerfile b/docker/vm_boot_images/Dockerfile index 59e5b32be..a62694ca0 100644 --- a/docker/vm_boot_images/Dockerfile +++ b/docker/vm_boot_images/Dockerfile @@ -34,7 +34,4 @@ RUN apt-get install python3-tk libgl1-mesa-glx libxt-dev -y # Requirements for the tensorflow project RUN pip3 install --upgrade pip RUN pip3 install -r pre_requirements.txt -RUN pip3 install -r tensorflow-requirements.txt \ - # Configure notebook extensions. 
- && jupyter nbextension install --user --py ipycanvas \ - && jupyter nbextension enable --user --py ipycanvas +RUN pip3 install -r tensorflow-requirements.txt diff --git a/docker/vm_boot_images/config/tensorflow-requirements.txt b/docker/vm_boot_images/config/tensorflow-requirements.txt index d782967af..bb6a1e777 100644 --- a/docker/vm_boot_images/config/tensorflow-requirements.txt +++ b/docker/vm_boot_images/config/tensorflow-requirements.txt @@ -28,5 +28,3 @@ altair facets-overview plotnine vega -ipycanvas==0.4.1 -ipyannotations==0.2.0 diff --git a/ml4h/arguments.py b/ml4h/arguments.py index f00c6f9f9..6692bcc74 100644 --- a/ml4h/arguments.py +++ b/ml4h/arguments.py @@ -20,7 +20,7 @@ import importlib import numpy as np import multiprocessing -from typing import Set, Dict, List, Optional +from typing import Set, Dict, List, Optional, Tuple from collections import defaultdict from ml4h.logger import load_config @@ -168,6 +168,10 @@ def parse_args(): '--u_connect', nargs=2, action='append', help='U-Net connect first TensorMap to second TensorMap. They must be the same shape except for number of channels. Can be provided multiple times.', ) + parser.add_argument( + '--pairs', nargs=2, action='append', + help='TensorMap pairs for paired autoencoder. The pair_loss metric will encourage similar embeddings for each two input TensorMap pairs. 
Can be provided multiple times.', + ) parser.add_argument('--aligned_dimension', default=16, type=int, help='Dimensionality of aligned embedded space for multi-modal alignment models.') parser.add_argument( '--max_parameters', default=9000000, type=int, @@ -201,6 +205,7 @@ def parse_args(): parser.add_argument('--validation_steps', default=18, type=int, help='Number of validation batches to examine in an epoch validation.') parser.add_argument('--learning_rate', default=0.0002, type=float, help='Learning rate during training.') parser.add_argument('--mixup_alpha', default=0, type=float, help='If positive apply mixup and sample from a Beta with this value as shape parameter alpha.') + parser.add_argument('--pair_loss', default='cosine', help='Distance metric between paired embeddings', choices=['euclid', 'cosine']) parser.add_argument( '--label_weights', nargs='*', type=float, help='List of per-label weights for weighted categorical cross entropy. If provided, must map 1:1 to number of labels.', @@ -219,6 +224,9 @@ def parse_args(): parser.add_argument('--anneal_rate', default=0., type=float, help='Annealing rate in epochs of loss terms during training') parser.add_argument('--anneal_shift', default=0., type=float, help='Annealing offset in epochs of loss terms during training') parser.add_argument('--anneal_max', default=2.0, type=float, help='Annealing maximum value') + parser.add_argument( + '--save_last_model', default=False, action='store_true', + help='If true saves the model weights from the last training epoch, otherwise the model with best validation loss is saved.') # Run specific and debugging arguments parser.add_argument('--id', default='no_id', help='Identifier for this run, user-defined string to keep experiments organized.') @@ -384,6 +392,14 @@ def _process_u_connect_args(u_connect: Optional[List[List]], tensormap_prefix) - return new_u_connect +def _process_pair_args(pairs: Optional[List[List]], tensormap_prefix) -> List[Tuple[TensorMap, 
TensorMap]]: + pairs = pairs or [] + new_pairs = [] + for pair in pairs: + new_pairs.append((tensormap_lookup(pair[0], tensormap_prefix), tensormap_lookup(pair[1], tensormap_prefix))) + return new_pairs + + def _process_args(args): now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') args_file = os.path.join(args.output_folder, args.id, 'arguments_' + now_string + '.txt') @@ -396,6 +412,7 @@ def _process_args(args): f.write(k + ' = ' + str(v) + '\n') load_config(args.logging_level, os.path.join(args.output_folder, args.id), 'log_' + now_string, args.min_sample_id) args.u_connect = _process_u_connect_args(args.u_connect, args.tensormap_prefix) + args.pairs = _process_pair_args(args.pairs, args.tensormap_prefix) args.tensor_maps_in = [] args.tensor_maps_out = [] diff --git a/ml4h/explorations.py b/ml4h/explorations.py index dc21246db..4e82968ab 100644 --- a/ml4h/explorations.py +++ b/ml4h/explorations.py @@ -79,6 +79,21 @@ def predictions_to_pngs( ax.add_patch(matplotlib.patches.Rectangle(y_corner, y_width, y_height, linewidth=1, edgecolor='y', facecolor='none')) logging.info(f"True BBox: {corner}, {width}, {height} Predicted BBox: {y_corner}, {y_width}, {y_height} Vmin {vmin} Vmax{vmax}") plt.savefig(f"{folder}{sample_id}_bbox_batch_{i:02d}{IMAGE_EXT}") + elif tm.axes() == 2: + fig = plt.figure(figsize=(SUBPLOT_SIZE, SUBPLOT_SIZE * 3)) + for i in range(y.shape[0]): + sample_id = os.path.basename(paths[i]).replace(TENSOR_EXT, '') + title = f'{tm.name}_{sample_id}_reconstruction' + for j in range(tm.shape[1]): + plt.subplot(tm.shape[1], 1, j + 1) + plt.plot(labels[tm.output_name()][i, :, j], c='k', linestyle='--', label='original') + plt.plot(y[i, :, j], c='b', label='reconstruction') + if j == 0: + plt.title(title) + plt.legend() + plt.tight_layout() + plt.savefig(os.path.join(folder, title + IMAGE_EXT)) + plt.clf() elif len(tm.shape) == 3: for i in range(y.shape[0]): sample_id = os.path.basename(paths[i]).replace(TENSOR_EXT, '') @@ -1332,3 +1347,49 @@ def 
_get_occurrences(df, order, start, end): fpath = os.path.join(args.output_folder, args.id, 'summary_cohort_counts.csv') pd.DataFrame.from_dict(cohort_counts, orient='index', columns=['count']).rename_axis('description').to_csv(fpath) logging.info(f'Saved cohort counts to {fpath}') + + +def directions_in_latent_space(stratify_column, stratify_thresh, split_column, split_thresh, latent_cols, latent_df): + hit = latent_df.loc[latent_df[stratify_column] >= stratify_thresh][latent_cols].to_numpy() + miss = latent_df.loc[latent_df[stratify_column] < stratify_thresh][latent_cols].to_numpy() + miss_mean_vector = np.mean(miss, axis=0) + hit_mean_vector = np.mean(hit, axis=0) + strat_vector = hit_mean_vector - miss_mean_vector + + hit1 = latent_df.loc[(latent_df[stratify_column] >= stratify_thresh) + & (latent_df[split_column] >= split_thresh)][latent_cols].to_numpy() + miss1 = latent_df.loc[(latent_df[stratify_column] < stratify_thresh) + & (latent_df[split_column] >= split_thresh)][latent_cols].to_numpy() + hit2 = latent_df.loc[(latent_df[stratify_column] >= stratify_thresh) + & (latent_df[split_column] < split_thresh)][latent_cols].to_numpy() + miss2 = latent_df.loc[(latent_df[stratify_column] < stratify_thresh) + & (latent_df[split_column] < split_thresh)][latent_cols].to_numpy() + miss_mean_vector1 = np.mean(miss1, axis=0) + hit_mean_vector1 = np.mean(hit1, axis=0) + angle1 = angle_between(miss_mean_vector1, hit_mean_vector1) + miss_mean_vector2 = np.mean(miss2, axis=0) + hit_mean_vector2 = np.mean(hit2, axis=0) + angle2 = angle_between(miss_mean_vector2, hit_mean_vector2) + h1_vector = hit_mean_vector1 - miss_mean_vector1 + h2_vector = hit_mean_vector2 - miss_mean_vector2 + angle3 = angle_between(h1_vector, h2_vector) + print(f'\n Between {stratify_column}, and splits: {split_column}\n', + f'Angles h1 and m1: {angle1:.2f}, h2 and m2 {angle2:.2f} h1-m1 and h2-m2 {angle3:.2f} degrees.\n' + f'stratify threshold: {stratify_thresh}, split thresh: {split_thresh}, \n' + 
f'hit_mean_vector2 shape {miss_mean_vector1.shape}, miss1:{hit_mean_vector2.shape} \n' + f'Hit1 shape {hit1.shape}, miss1:{miss1.shape} threshold:{stratify_thresh}\n' + f'Hit2 shape {hit2.shape}, miss2:{miss2.shape}\n') + + return hit_mean_vector1, miss_mean_vector1, hit_mean_vector2, miss_mean_vector2 + + +def latent_space_dataframe(infer_hidden_tsv, explore_csv): + df = pd.read_csv(explore_csv) + df['fpath'] = pd.to_numeric(df['fpath'], errors='coerce') + df2 = pd.read_csv(infer_hidden_tsv, sep='\t') + df2['sample_id'] = pd.to_numeric(df2['sample_id'], errors='coerce') + latent_df = pd.merge(df, df2, left_on='fpath', right_on='sample_id', how='inner') + latent_df.info() + return latent_df + + diff --git a/ml4h/models.py b/ml4h/models.py index 493870ad1..3eb53ccbd 100755 --- a/ml4h/models.py +++ b/ml4h/models.py @@ -533,16 +533,9 @@ def l2_norm(x, axis=None): def pairwise_cosine_difference(t1, t2): - """ - A [batch x n x d] tensor of n rows with d dimensions - B [batch x m x d] tensor of n rows with d dimensions - - returns: - D [batch x n x m] tensor of cosine similarity scores between each point i Model: + opt = get_optimizer( + kwargs['optimizer'], kwargs['learning_rate'], steps_per_epoch=kwargs['training_steps'], + learning_rate_schedule=kwargs['learning_rate_schedule'], optimizer_kwargs=kwargs.get('optimizer_kwargs'), + ) + if 'model_file' in kwargs and kwargs['model_file'] is not None: + custom_dict = _get_custom_objects(kwargs['tensor_maps_out']) + encoders = {} + for tm in kwargs['tensor_maps_in']: + encoders[tm] = load_model(f"{os.path.dirname(kwargs['model_file'])}/encoder_{tm.name}.h5", custom_objects=custom_dict, compile=False) + decoders = {} + for tm in kwargs['tensor_maps_out']: + decoders[tm] = load_model(f"{os.path.dirname(kwargs['model_file'])}/decoder_{tm.name}.h5", custom_objects=custom_dict, compile=False) + logging.info(f"Attempting to load model file from: {kwargs['model_file']}") + m = load_model(kwargs['model_file'], 
custom_objects=custom_dict, compile=False) + m.compile(optimizer=opt, loss=custom_dict['loss']) + m.summary() + logging.info(f"Loaded model file from: {kwargs['model_file']}") + return m, encoders, decoders + + inputs = {tm: Input(shape=tm.shape, name=tm.input_name()) for tm in kwargs['tensor_maps_in']} + original_outputs = {tm: 1 for tm in kwargs['tensor_maps_out']} + real_serial_layers = kwargs['model_layers'] + kwargs['model_layers'] = None + multimodal_activations = [] + encoders = {} + decoders = {} + outputs = {} + losses = [] + for left, right in pairs: + kwargs['tensor_maps_in'] = [left] + left_model = make_multimodal_multitask_model(**kwargs) + encode_left = make_hidden_layer_model(left_model, [left], kwargs['hidden_layer']) + h_left = encode_left(inputs[left]) + + kwargs['tensor_maps_in'] = [right] + right_model = make_multimodal_multitask_model(**kwargs) + encode_right = make_hidden_layer_model(right_model, [right], kwargs['hidden_layer']) + h_right = encode_right(inputs[right]) + + if pair_loss == 'cosine': + loss_layer = CosineLossLayer(1.0) + elif pair_loss == 'euclid': + loss_layer = L2LossLayer(1.0) + + paired_embeddings = loss_layer([h_left, h_right]) + multimodal_activations.extend(paired_embeddings) + if left not in encoders: + encoders[left] = encode_left + if right not in encoders: + encoders[right] = encode_right + + multimodal_activation = Concatenate()(multimodal_activations) + multimodal_activation = Dense(units=kwargs['dense_layers'][0])(multimodal_activation) + #multimodal_activation = _activation_layer(kwargs['activation'])(multimodal_activation) + latent_inputs = Input(shape=(kwargs['dense_layers'][0]), name='input_concept_space') + + # build decoder models + for tm in kwargs['tensor_maps_out']: + if tm.axes() > 1: + shape = _calc_start_shape(num_upsamples=len(kwargs['dense_blocks']), output_shape=tm.shape, + upsample_rates=[kwargs['pool_x'], kwargs['pool_y'], kwargs['pool_z']], + channels=kwargs['dense_blocks'][-1]) + + restructure = 
FlatToStructure(output_shape=shape, activation=kwargs['activation'], + normalization=kwargs['dense_normalize']) + + decode = ConvDecoder( + tensor_map_out=tm, + filters_per_dense_block=kwargs['dense_blocks'][::-1], + conv_layer_type=kwargs['conv_type'], + conv_x=kwargs['conv_x'], + conv_y=kwargs['conv_y'], + conv_z=kwargs['conv_z'], + block_size=kwargs['block_size'], + activation=kwargs['activation'], + normalization=kwargs['conv_normalize'], + regularization=kwargs['conv_regularize'], + regularization_rate=kwargs['conv_regularize_rate'], + upsample_x=kwargs['pool_x'], + upsample_y=kwargs['pool_y'], + upsample_z=kwargs['pool_z'], + u_connect_parents=[tm_in for tm_in in kwargs['tensor_maps_in'] if tm in kwargs['u_connect'][tm_in]], + ) + reconstruction = decode(restructure(latent_inputs), {}, {}) + else: + decode = DenseDecoder(tensor_map_out=tm, parents=tm.parents, activation=kwargs['activation']) + reconstruction = decode(latent_inputs, {}, {}) + + decoder = Model(latent_inputs, reconstruction, name=tm.output_name()) + decoders[tm] = decoder + outputs[tm.output_name()] = decoder(multimodal_activation) + losses.append(tm.loss) + + kwargs['tensor_maps_out'] = list(original_outputs.keys()) + kwargs['tensor_maps_in'] = list(inputs.keys()) + + m = Model(inputs=list(inputs.values()), outputs=outputs) + my_metrics = {tm.output_name(): tm.metrics for tm in kwargs['tensor_maps_out']} + m.compile(optimizer=opt, loss=losses, metrics=my_metrics) + m.summary() + + if real_serial_layers is not None: + m.load_weights(kwargs['model_layers'], by_name=True) + logging.info(f"Loaded model weights from:{kwargs['model_layers']}") + + return m, encoders, decoders + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~ Training ~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1173,13 +1279,13 @@ def train_model_from_generators( inspect_show_labels: bool, return_history: bool = False, plot: bool = True, + save_last_model: bool = False, ) -> Union[Model, Tuple[Model, History]]: """Train 
a model from tensor generators for validation and training data. Training data lives on disk, it will be loaded by generator functions. Plots the metric history after training. Creates a directory to save weights, if necessary. Measures runtime and plots architecture diagram if inspect_model is True. - :param model: The model to optimize :param generate_train: Generator function that yields mini-batches of training data. :param generate_valid: Generator function that yields mini-batches of validation data. @@ -1193,6 +1299,9 @@ def train_model_from_generators( :param inspect_model: If True, measure training and inference runtime of the model and generate architecture plot. :param inspect_show_labels: If True, show labels on the architecture plot. :param return_history: If true return history from training and don't plot the training history + :param plot: If true, plots the metrics for train and validation set at the end of each epoch + :param save_last_model: If true saves the model weights from last epoch otherwise saves model with best validation loss + :return: The optimized model. 
""" model_file = os.path.join(output_folder, run_id, run_id + MODEL_EXT) @@ -1206,7 +1315,7 @@ def train_model_from_generators( history = model.fit( generate_train, steps_per_epoch=training_steps, epochs=epochs, verbose=1, validation_steps=validation_steps, validation_data=generate_valid, - callbacks=_get_callbacks(patience, model_file), + callbacks=_get_callbacks(patience, model_file, save_last_model), ) generate_train.kill_workers() generate_valid.kill_workers() @@ -1221,10 +1330,10 @@ def train_model_from_generators( def _get_callbacks( - patience: int, model_file: str, + patience: int, model_file: str, save_last_model: bool ) -> List[Callback]: callbacks = [ - ModelCheckpoint(filepath=model_file, verbose=1, save_best_only=True), + ModelCheckpoint(filepath=model_file, verbose=1, save_best_only=not save_last_model), EarlyStopping(monitor='val_loss', patience=patience * 3, verbose=1), ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=patience, verbose=1), ] diff --git a/ml4h/plots.py b/ml4h/plots.py index c353d1b94..c60282803 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -29,6 +29,7 @@ from matplotlib.ticker import AutoMinorLocator, MultipleLocator from sklearn import manifold +from sklearn.decomposition import PCA from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score from sklearn.metrics import brier_score_loss, precision_score, recall_score, f1_score, roc_auc_score from sklearn.calibration import calibration_curve @@ -39,10 +40,6 @@ from scipy.ndimage.filters import gaussian_filter from scipy import stats -import ml4h.tensormap.ukb.ecg -import ml4h.tensormap.mgb.ecg -from ml4h.tensormap.mgb.dynamic import make_waveform_maps - from ml4h.TensorMap import TensorMap from ml4h.metrics import concordance_index, coefficient_of_determination from ml4h.defines import IMAGE_EXT, JOIN_CHAR, PDF_EXT, TENSOR_EXT, ECG_REST_LEADS, ECG_REST_MEDIAN_LEADS, PARTNERS_DATETIME_FORMAT, PARTNERS_DATE_FORMAT, HD5_GROUP_CHAR @@ -165,7 
+162,7 @@ def evaluate_predictions( if tm.sentinel is not None: y_predictions = y_predictions[y_truth != tm.sentinel] y_truth = y_truth[y_truth != tm.sentinel] - _plot_reconstruction(tm, y_truth, y_predictions, folder, test_paths) + plot_reconstruction(tm, y_truth, y_predictions, folder, test_paths) if prediction_flat.shape[0] == truth_flat.shape[0]: performance_metrics.update(subplot_pearson_per_class(prediction_flat, truth_flat, tm.channel_map, protected, title, prefix=folder)) elif tm.is_continuous(): @@ -190,10 +187,15 @@ def plot_metric_history(history, training_steps: int, title: str, prefix='./figu cols = max(2, int(math.ceil(math.sqrt(total_plots)))) rows = max(2, int(math.ceil(total_plots / cols))) f, axes = plt.subplots(rows, cols, figsize=(int(cols*SUBPLOT_SIZE), int(rows*SUBPLOT_SIZE))) + logging.info(f'all keys {list(sorted(history.history.keys()))}') for k in sorted(history.history.keys()): if not k.startswith('val_'): if isinstance(history.history[k][0], LearningRateSchedule): history.history[k] = [history.history[k][0](i * training_steps) for i in range(len(history.history[k]))] + if len(np.array(history.history[k]).shape) > 1: # Hack for models with paired loss + history.history[k] = np.array(history.history[k])[:, 0, 0] + if 'val_' + k in history.history: + history.history['val_' + k] = np.array(history.history['val_' + k])[:, 0, 0] axes[row, col].plot(history.history[k]) k_split = str(k).replace('output_', '').split('_') k_title = " ".join(OrderedDict.fromkeys(k_split)) @@ -464,7 +466,7 @@ def subplot_pearson_per_class( os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path, bbox_inches='tight') plt.clf() - logging.info(f"Saved Pearson correlations at: {figure_path} with {len(protected)} protected TensorMaps.") + logging.info(f"{label_text} saved at: {figure_path}{f' with {len(protected)} protected TensorMaps.' 
if len(protected) else '.'}") return labels_to_areas @@ -1225,15 +1227,16 @@ def _plot_partners_figure( def plot_partners_ecgs(args): plot_tensors = [ - ml4h.tensormap.mgb.ecg.partners_ecg_patientid, ml4h.tensormap.mgb.ecg.partners_ecg_firstname, ml4h.tensormap.mgb.ecg.partners_ecg_lastname, - ml4h.tensormap.mgb.ecg.partners_ecg_sex, ml4h.tensormap.mgb.ecg.partners_ecg_dob, ml4h.tensormap.mgb.ecg.partners_ecg_age, - ml4h.tensormap.mgb.ecg.partners_ecg_datetime, ml4h.tensormap.mgb.ecg.partners_ecg_sitename, ml4h.tensormap.mgb.ecg.partners_ecg_location, - ml4h.tensormap.mgb.ecg.partners_ecg_read_md, ml4h.tensormap.mgb.ecg.partners_ecg_taxis_md, ml4h.tensormap.mgb.ecg.partners_ecg_rate_md, - ml4h.tensormap.mgb.ecg.partners_ecg_pr_md, ml4h.tensormap.mgb.ecg.partners_ecg_qrs_md, ml4h.tensormap.mgb.ecg.partners_ecg_qt_md, - ml4h.tensormap.mgb.ecg.partners_ecg_paxis_md, ml4h.tensormap.mgb.ecg.partners_ecg_raxis_md, ml4h.tensormap.mgb.ecg.partners_ecg_qtc_md, + 'partners_ecg_patientid', 'partners_ecg_firstname', 'partners_ecg_lastname', + 'partners_ecg_sex', 'partners_ecg_dob', 'partners_ecg_age', + 'partners_ecg_datetime', 'partners_ecg_sitename', 'partners_ecg_location', + 'partners_ecg_read_md', 'partners_ecg_taxis_md', 'partners_ecg_rate_md', + 'partners_ecg_pr_md', 'partners_ecg_qrs_md', 'partners_ecg_qt_md', + 'partners_ecg_paxis_md', 'partners_ecg_raxis_md', 'partners_ecg_qtc_md', ] - voltage_tensor = make_waveform_maps('partners_ecg_2500_raw') - tensor_maps_in = plot_tensors + [voltage_tensor] + voltage_tensor = 'partners_ecg_2500_raw' + from ml4h.tensor_maps_partners_ecg_labels import TMAPS + tensor_maps_in = [TMAPS[it] for it in plot_tensors + [voltage_tensor]] tensor_paths = [os.path.join(args.tensors, tp) for tp in os.listdir(args.tensors) if os.path.splitext(tp)[-1].lower()==TENSOR_EXT] if 'clinical' == args.plot_mode: @@ -1500,12 +1503,13 @@ def plot_ecg_rest( :param is_blind: if True, the plot gets blinded (helpful for review and annotation) """ 
map_fields_to_tmaps = { - 'ramp': ml4h.tensormap.ukb.ecg.ecg_rest_ramplitude_raw, - 'samp': ml4h.tensormap.ukb.ecg.ecg_rest_samplitude_raw, - 'aVL': ml4h.tensormap.ukb.ecg.ecg_rest_lvh_avl, - 'Sokolow_Lyon': ml4h.tensormap.ukb.ecg.ecg_rest_lvh_sokolow_lyon, - 'Cornell': ml4h.tensormap.ukb.ecg.ecg_rest_lvh_cornell, - } + 'ramp': 'ecg_rest_ramplitude_raw', + 'samp': 'ecg_rest_samplitude_raw', + 'aVL': 'ecg_rest_lvh_avl', + 'Sokolow_Lyon': 'ecg_rest_lvh_sokolow_lyon', + 'Cornell': 'ecg_rest_lvh_cornell', + } + from ml4h.tensor_from_file import TMAPS raw_scale = 0.005 # Conversion from raw to mV default_yrange = ECG_REST_PLOT_DEFAULT_YRANGE # mV time_interval = 2.5 # time-interval per plot in seconds. ts_Reference data is in s, voltage measurement is 5 uv per lsb @@ -1517,7 +1521,7 @@ def plot_ecg_rest( with h5py.File(tensor_path, 'r') as hd5: traces, text = _ecg_rest_traces_and_text(hd5) for field in map_fields_to_tmaps: - tm = map_fields_to_tmaps[field] + tm = TMAPS[map_fields_to_tmaps[field]] patient_dic[field] = np.zeros(tm.shape) try: patient_dic[field][:] = tm.tensor_from_file(tm, hd5) @@ -2076,32 +2080,150 @@ def _text_on_plot(axes, x, y, text, alpha=0.8, background='white'): t.set_bbox({'facecolor': background, 'alpha': alpha, 'edgecolor': background}) -def _plot_reconstruction( +def plot_reconstruction( tm: TensorMap, y_true: np.ndarray, y_pred: np.ndarray, - folder: str, paths: List[str], + folder: str, paths: List[str], num_samples: int = 4, ): - num_samples = 3 logging.info(f'Plotting {num_samples} reconstructions of {tm}.') if None in tm.shape: # can't handle dynamic shapes return for i in range(num_samples): - title = f'{tm.name}_{os.path.basename(paths[i]).replace(TENSOR_EXT, "")}_reconstruction' + sample_id = os.path.basename(paths[i]).replace(TENSOR_EXT, '') + title = f'{tm.name}_{sample_id}_reconstruction' y = y_true[i].reshape(tm.shape) yp = y_pred[i].reshape(tm.shape) if tm.axes() == 2: - fig = plt.figure(figsize=(SUBPLOT_SIZE, SUBPLOT_SIZE * 
num_samples)) + index2channel = {v: k for k, v in tm.channel_map.items()} + fig, axes = plt.subplots(tm.shape[1], 2, figsize=(2 * SUBPLOT_SIZE, 6*SUBPLOT_SIZE)) for j in range(tm.shape[1]): - plt.subplot(tm.shape[1], 1, j + 1) - plt.plot(y[:, j], c='k', linestyle='--', label='original') - plt.plot(yp[:, j], c='b', label='reconstruction') - if j == 0: - plt.title(title) - plt.legend() + axes[j, 0].plot(y[:, j], c='k', linestyle='--', label='original') + axes[j, 1].plot(yp[:, j], c='b', label='reconstruction') + axes[j, 0].set_title(f'Lead: {index2channel[j]}') + axes[j, 0].legend() + axes[j, 1].legend() plt.tight_layout() - # TODO: implement 3d, 4d - plt.savefig(os.path.join(folder, title + IMAGE_EXT)) + plt.savefig(os.path.join(folder, title + IMAGE_EXT)) + elif tm.axes() == 3: + if tm.is_categorical(): + plt.imsave(f"{folder}{sample_id}_{tm.name}_truth_{i:02d}{IMAGE_EXT}", np.argmax(y, axis=-1), cmap='plasma') + plt.imsave(f"{folder}{sample_id}_{tm.name}_prediction_{i:02d}{IMAGE_EXT}", np.argmax(yp, axis=-1), cmap='plasma') + else: + plt.imsave(f'{folder}{sample_id}_{tm.name}_truth_{i:02d}{IMAGE_EXT}', y[:, :, 0], cmap='gray') + plt.imsave(f'{folder}{sample_id}_{tm.name}_prediction_{i:02d}{IMAGE_EXT}', yp[:, :, 0], cmap='gray') + elif tm.axes() == 4: + for j in range(y.shape[3]): + image_path_base = f'{folder}{sample_id}_{tm.name}_{i:03d}_{j:03d}' + if tm.is_categorical(): + truth = np.argmax(yp[tm.output_name()][:, :, j, :], axis=-1) + prediction = np.argmax(y[:, :, j, :], axis=-1) + plt.imsave(f'{image_path_base}_truth{IMAGE_EXT}', truth, cmap='plasma') + plt.imsave(f'{image_path_base}_prediction{IMAGE_EXT}', prediction, cmap='plasma') + else: + plt.imsave(f'{image_path_base}_truth{IMAGE_EXT}', y[:, :, j, 0], cmap='gray') + plt.imsave(f'{image_path_base}_prediction{IMAGE_EXT}', yp[:, :, j, :], cmap='gray') plt.clf() -if __name__ == '__main__': - plot_noisy() +def pca_on_matrix(matrix, pca_components, prefix='./figures/'): + pca = PCA() + pca.fit(matrix) + 
print(f'PCA explains {100 * np.sum(pca.explained_variance_ratio_[:pca_components]):0.1f}% of variance with {pca_components} top PCA components.') + matrix_reduced = pca.transform(matrix)[:, :pca_components] + print(f'PCA reduces matrix shape:{matrix_reduced.shape} from matrix shape: {matrix.shape}') + plot_scree(pca_components, 100 * pca.explained_variance_ratio_, prefix) + return pca, matrix_reduced + + +def plot_scree(pca_components, percent_explained, prefix='./figures/'): + _ = plt.figure(figsize=(6, 4)) + plt.plot(range(len(percent_explained)), percent_explained, 'g.-', linewidth=1) + plt.axvline(x=pca_components, c='r', linewidth=3) + label = f'{np.sum(percent_explained[:pca_components]):0.1f}% of variance explained by top {pca_components} of {len(percent_explained)} components' + plt.text(pca_components + 0.02 * len(percent_explained), percent_explained[1], label) + plt.title('Scree Plot') + plt.xlabel('Principal Components') + plt.ylabel('% of Variance Explained by Each Component') + figure_path = f'{prefix}pca_{pca_components}_of_{len(percent_explained)}.png' + if not os.path.exists(os.path.dirname(figure_path)): + os.makedirs(os.path.dirname(figure_path)) + plt.savefig(figure_path) + + +def unit_vector(vector): + """ Returns the unit vector of the vector. 
""" + return vector / np.linalg.norm(vector) + + +def angle_between(v1, v2): + """ Returns the angle in degrees between vectors 'v1' and 'v2':: + angle_between((1, 0, 0), (0, 1, 0)) + 90 + angle_between((1, 0, 0), (1, 0, 0)) + 0.0 + angle_between((1, 0, 0), (-1, 0, 0)) + 180 + """ + v1_u = unit_vector(v1) + v2_u = unit_vector(v2) + return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) * 180 / 3.141592 + + +def stratify_latent_space(stratify_column, stratify_thresh, latent_cols, latent_df): + hit = latent_df.loc[latent_df[stratify_column] >= stratify_thresh][latent_cols].to_numpy() + miss = latent_df.loc[latent_df[stratify_column] < stratify_thresh][latent_cols].to_numpy() + miss_mean_vector = np.mean(miss, axis=0) + hit_mean_vector = np.mean(hit, axis=0) + angle = angle_between(miss_mean_vector, hit_mean_vector) + print(f'Angle between {stratify_column} and all others: {angle}, \n' + f'Hit shape {hit.shape}, miss:{miss.shape} threshold:{stratify_thresh}\n' + f'Distance: {np.linalg.norm(hit_mean_vector - miss_mean_vector):.3f}, ' + f'Hit std {np.std(hit, axis=1).mean():.3f}, miss std:{np.std(miss, axis=1).mean():.3f}\n') + return hit_mean_vector, miss_mean_vector + + +def plot_hit_to_miss_transforms(latent_df, decoders, feature='Sex_Female_0_0', prefix='./figures/', + thresh=1.0, latent_dimension=256, samples=16, scalar=3.0, cmap='plasma'): + latent_cols = [f'latent_{i}' for i in range(latent_dimension)] + female, male = stratify_latent_space(feature, thresh, latent_cols, latent_df) + sex_vector = female - male + embeddings = latent_df.iloc[:samples][latent_cols].to_numpy() + sexes = latent_df.iloc[:samples][feature].to_numpy() + print(f'Embedding shape: {embeddings.shape} sexes shape: {sexes.shape}') + + sex_vectors = np.tile(sex_vector, (samples, 1)) + male_to_female = embeddings + (scalar * sex_vectors) + female_to_male = embeddings - (scalar * sex_vectors) + print(f'embeddings shape: {embeddings.shape} features vectors shape: {sex_vectors.shape}') + for dtm
in decoders: + predictions = decoders[dtm].predict(embeddings) + m2f = decoders[dtm].predict(male_to_female) + f2m = decoders[dtm].predict(female_to_male) + print(f'prediction shape: {predictions.shape}') + if dtm.axes() == 3: + fig, axes = plt.subplots(samples, 2, figsize=(18, samples * 4)) + for i in range(samples): + axes[i, 0].set_title(f"{feature}: {sexes[i]} ?>=<? {thresh}") + if dtm.is_categorical(): + axes[i, 0].imshow(np.argmax(predictions[i, ...], axis=-1), cmap=cmap) + if sexes[i] >= thresh: + axes[i, 1].imshow(np.argmax(f2m[i, ...], axis=-1), cmap=cmap) + axes[i, 1].set_title(f'{feature} to less than {thresh}') + else: + axes[i, 1].imshow(np.argmax(m2f[i, ...], axis=-1), cmap=cmap) + axes[i, 1].set_title(f'{feature} to more than or equal to {thresh}') + else: + axes[i, 0].imshow(predictions[i, ..., 0], cmap='gray') + if sexes[i] >= thresh: + axes[i, 1].imshow(f2m[i, ..., 0], cmap='gray') + axes[i, 1].set_title(f'{feature} to less than {thresh}') + else: + axes[i, 1].imshow(m2f[i, ..., 0], cmap='gray') + axes[i, 1].set_title(f'{feature} to more than or equal to {thresh}') + figure_path = f'{prefix}/{dtm.name}_{feature}_transform_scalar_{scalar}.png' + if not os.path.exists(os.path.dirname(figure_path)): + os.makedirs(os.path.dirname(figure_path)) + plt.savefig(figure_path) \ No newline at end of file diff --git a/ml4h/recipes.py b/ml4h/recipes.py index 80c22f93c..148c2c517 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -15,15 +15,16 @@ from ml4h.defines import TENSOR_EXT, MODEL_EXT from ml4h.tensormap.tensor_map_maker import write_tensor_maps from ml4h.tensorize.tensor_writer_mgb import write_tensors_mgb -from ml4h.explorations import test_labels_to_label_map, infer_with_pixels, explore +from ml4h.explorations import test_labels_to_label_map, infer_with_pixels, explore, latent_space_dataframe from ml4h.tensor_generators import BATCH_INPUT_INDEX, BATCH_OUTPUT_INDEX, BATCH_PATHS_INDEX from ml4h.explorations import mri_dates, ecg_dates, predictions_to_pngs, sample_from_language_model from ml4h.explorations import plot_while_learning, plot_histograms_of_tensors_in_pdf,
cross_reference from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators, big_batch_from_minibatch_generator -from ml4h.models import make_character_model_plus, embed_model_predict, make_siamese_model, make_multimodal_multitask_model +from ml4h.models import make_character_model_plus, embed_model_predict, make_siamese_model, make_multimodal_multitask_model, make_paired_autoencoder_model from ml4h.metrics import get_roc_aucs, get_precision_recall_aucs, get_pearson_coefficients, log_aucs, log_pearson_coefficients from ml4h.models import train_model_from_generators, get_model_inputs_outputs, make_shallow_model, make_hidden_layer_model, saliency_map -from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_prediction_calibrations +from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_prediction_calibrations, \ + plot_reconstruction, plot_hit_to_miss_transforms from ml4h.tensorize.tensor_writer_ukbb import write_tensors, append_fields_from_csv, append_gene_csv, write_tensors_from_dicom_pngs, write_tensors_from_ecg_pngs from ml4h.plots import subplot_rocs, subplot_comparison_rocs, subplot_scatters, subplot_comparison_scatters, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp @@ -74,20 +75,20 @@ def run(args): ecg_dates(args.tensors, args.output_folder, args.id) elif 'plot_histograms' == args.mode: plot_histograms_of_tensors_in_pdf(args.id, args.tensors, args.output_folder, args.max_samples) - elif 'plot_heatmap' == args.mode: - plot_heatmap_of_tensors(args.id, args.tensors, args.output_folder, args.min_samples, args.max_samples) elif 'plot_resting_ecgs' == args.mode: plot_ecg_rest_mp(args.tensors, args.min_sample_id, args.max_sample_id, args.output_folder, args.num_workers) elif 'plot_partners_ecgs' == args.mode: plot_partners_ecgs(args) - elif 'tabulate_correlations' == 
args.mode: - tabulate_correlations_of_tensors(args.id, args.tensors, args.output_folder, args.min_samples, args.max_samples) elif 'train_shallow' == args.mode: train_shallow_model(args) elif 'train_char' == args.mode: train_char_model(args) elif 'train_siamese' == args.mode: train_siamese_model(args) + elif 'train_paired' == args.mode: + train_paired_model(args) + elif 'inspect_paired' == args.mode: + inspect_paired_model(args) elif 'write_tensor_maps' == args.mode: write_tensor_maps(args) elif 'append_continuous_csv' == args.mode: @@ -135,8 +136,8 @@ def train_multimodal_multitask(args): generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__) model = make_multimodal_multitask_model(**args.__dict__) model = train_model_from_generators( - model, generate_train, generate_valid, args.training_steps, args.validation_steps, args.batch_size, - args.epochs, args.patience, args.output_folder, args.id, args.inspect_model, args.inspect_show_labels, + model, generate_train, generate_valid, args.training_steps, args.validation_steps, args.batch_size, args.epochs, + args.patience, args.output_folder, args.id, args.inspect_model, args.inspect_show_labels, save_last_model=args.save_last_model ) out_path = os.path.join(args.output_folder, args.id + '/') @@ -285,14 +286,14 @@ def infer_multimodal_multitask(args): logging.info(f"Wrote:{stats['count']} rows of inference. 
Last tensor:{tensor_paths[0]}") -def hidden_inference_file_name(output_folder: str, id_: str) -> str: - return os.path.join(output_folder, id_, 'hidden_inference_' + id_ + '.tsv') +def _hidden_file_name(output_folder: str, prefix_: str, id_: str, extension_: str) -> str: + return os.path.join(output_folder, id_, prefix_ + id_ + extension_) def infer_hidden_layer_multimodal_multitask(args): stats = Counter() args.num_workers = 0 - inference_tsv = hidden_inference_file_name(args.output_folder, args.id) + inference_tsv = _hidden_file_name(args.output_folder, 'hidden_inference_', args.id, '.tsv') tsv_style_is_genetics = 'genetics' in args.tsv_style tensor_paths = [os.path.join(args.tensors, tp) for tp in sorted(os.listdir(args.tensors)) if os.path.splitext(tp)[-1].lower() == TENSOR_EXT] # hard code batch size to 1 so we can iterate over file names and generated tensors together in the tensor_paths for loop @@ -303,6 +304,7 @@ def infer_hidden_layer_multimodal_multitask(args): generate_test.set_worker_paths(tensor_paths) full_model = make_multimodal_multitask_model(**args.__dict__) embed_model = make_hidden_layer_model(full_model, args.tensor_maps_in, args.hidden_layer) + embed_model.save(_hidden_file_name(args.output_folder, f'{args.hidden_layer}_encoder_', args.id, '.h5')) dummy_input = {tm.input_name(): np.zeros((1,) + full_model.get_layer(tm.input_name()).input_shape[0][1:]) for tm in args.tensor_maps_in} dummy_out = embed_model.predict(dummy_input) latent_dimensions = int(np.prod(dummy_out.shape[1:])) @@ -392,6 +394,57 @@ def train_siamese_model(args): ) +def train_paired_model(args): + full_model, encoders, decoders = make_paired_autoencoder_model(**args.__dict__) + generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__) + train_model_from_generators( + full_model, generate_train, generate_valid, args.training_steps, args.validation_steps, args.batch_size, + args.epochs, args.patience, args.output_folder, args.id, 
args.inspect_model, args.inspect_show_labels, + save_last_model=True + ) + for tm in encoders: + encoders[tm].save(f'{args.output_folder}{args.id}/encoder_{tm.name}.h5') + for tm in decoders: + decoders[tm].save(f'{args.output_folder}{args.id}/decoder_{tm.name}.h5') + out_path = os.path.join(args.output_folder, args.id, 'reconstructions/') + test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps) + samples = args.test_steps * args.batch_size + predictions_list = full_model.predict(test_data) + predictions_dict = {name: pred for name, pred in zip(full_model.output_names, predictions_list)} + logging.info(f'Predictions and shapes are: {[(p, predictions_dict[p].shape) for p in predictions_dict]}') + performance_metrics = {} + for i, etm in enumerate(encoders): + embed = encoders[etm].predict(test_data[etm.input_name()]) + plot_reconstruction(etm, test_data[etm.input_name()], predictions_dict[etm.output_name()], out_path, test_paths, samples) + for dtm in decoders: + if dtm.axes() > 1: + reconstruction = decoders[dtm].predict(embed) + logging.info(f'{dtm.name} has prediction shape: {reconstruction.shape} from embed shape: {embed.shape}') + my_out_path = os.path.join(out_path, f'decoding_{dtm.name}_from_{etm.name}/') + if not os.path.exists(os.path.dirname(my_out_path)): + os.makedirs(os.path.dirname(my_out_path)) + plot_reconstruction(dtm, test_data[dtm.input_name()], reconstruction, my_out_path, test_paths, samples) + else: + y_truth = np.array(test_labels[dtm.output_name()]) + performance_metrics.update(evaluate_predictions(dtm, decoders[dtm].predict(embed), y_truth, {}, dtm.name, os.path.join(args.output_folder, args.id), test_paths)) + return performance_metrics + + +def inspect_paired_model(args): + full_model, encoders, decoders = make_paired_autoencoder_model(**args.__dict__) + infer_hidden_tsv = _hidden_file_name(args.output_folder, 'hidden_inference_', args.id, '.tsv') + latent_df = 
latent_space_dataframe(infer_hidden_tsv, args.app_csv) + out_folder = os.path.join(args.output_folder, args.id, 'latent_transformations/') + for tm in args.tensor_maps_protected: + index2channel = {v: k for k, v in tm.channel_map.items()} + thresh = 1 if tm.is_categorical() else tm.normalization.mean + plot_hit_to_miss_transforms(latent_df, decoders, + feature=index2channel[0], + thresh=thresh, + latent_dimension=args.dense_layers[0], + prefix=out_folder) + + def plot_predictions(args): _, _, generate_test = test_train_valid_tensor_generators(**args.__dict__) model = make_multimodal_multitask_model(**args.__dict__) diff --git a/ml4h/tensorize/tensor_writer_ukbb.py b/ml4h/tensorize/tensor_writer_ukbb.py index ceb4389e8..64242bcbe 100644 --- a/ml4h/tensorize/tensor_writer_ukbb.py +++ b/ml4h/tensorize/tensor_writer_ukbb.py @@ -87,10 +87,6 @@ def write_tensors( mri_unzip: str, mri_field_ids: List[int], xml_field_ids: List[int], - zoom_x: int, - zoom_y: int, - zoom_width: int, - zoom_height: int, write_pngs: bool, min_sample_id: int, max_sample_id: int, @@ -109,13 +105,6 @@ def write_tensors( :param mri_unzip: Folder where zipped DICOM will be decompressed :param mri_field_ids: List of MRI field IDs from UKBB :param xml_field_ids: List of ECG field IDs from UKBB - :param x: Maximum x dimension of MRIs - :param y: Maximum y dimension of MRIs - :param z: Maximum z dimension of MRIs - :param zoom_x: x coordinate of the zoom - :param zoom_y: y coordinate of the zoom - :param zoom_width: width of the zoom - :param zoom_height: height of the zoom :param write_pngs: write MRIs as PNG images for debugging :param min_sample_id: Minimum sample id to generate, for parallelization :param max_sample_id: Maximum sample id to generate, for parallelization @@ -137,7 +126,7 @@ def write_tensors( continue try: with h5py.File(tp, 'w') as hd5: - _write_tensors_from_zipped_dicoms(zoom_x, zoom_y, zoom_width, zoom_height, write_pngs, tensors, mri_unzip, mri_field_ids, zip_folder, hd5, 
sample_id, stats) + _write_tensors_from_zipped_dicoms(write_pngs, tensors, mri_unzip, mri_field_ids, zip_folder, hd5, sample_id, stats) _write_tensors_from_zipped_niftis(zip_folder, mri_field_ids, hd5, sample_id, stats) _write_tensors_from_xml(xml_field_ids, xml_folder, hd5, sample_id, write_pngs, stats, continuous_stats) stats['Tensors written'] += 1 @@ -188,19 +177,26 @@ def write_tensors_from_dicom_pngs( continue stats[sample_header + '_' + sample_id] += 1 dicom_file = row[dicom_index] + try: png = imageio.imread(os.path.join(png_path, dicom_file + png_postfix)) + if len(png.shape) == 3 and png.mean() == png[:, :, 0].mean(): + png = png[:, :, 0] + elif len(png.shape) == 3: + raise ValueError(f'PNG has color information but no method to tensorize it {png.mean()}, 0ch :{png[:, :, 0].mean()}, 1ch :{png[:, :, 1].mean()}, 2ch :{png[:, :, 2].mean()}.') full_tensor = np.zeros((x, y), dtype=np.float32) full_tensor[:png.shape[0], :png.shape[1]] = png tensor_file = os.path.join(tensors, str(sample_id) + TENSOR_EXT) if not os.path.exists(os.path.dirname(tensor_file)): os.makedirs(os.path.dirname(tensor_file)) with h5py.File(tensor_file, 'a') as hd5: - tensor_name = series + '_annotated_' + row[instance_index] + tensor_name = series.lower() + '_annotated_' + row[instance_index] tp = tensor_path(path_prefix, tensor_name) if tp in hd5: tensor = first_dataset_at_path(hd5, tp) - tensor[:] = full_tensor + min_x = min(png.shape[0], tensor.shape[0]) + min_y = min(png.shape[1], tensor.shape[1]) + tensor[:min_x, :min_y] = full_tensor[:min_x, :min_y] stats['updated'] += 1 else: create_tensor_in_hd5(hd5, path_prefix, tensor_name, full_tensor, stats) @@ -328,7 +324,7 @@ def _dicts_and_plots_from_tensorization( continuous = {} value_counter = Counter() for k in sorted(list(stats.keys())): - logging.info("{} has {}".format(k, stats[k])) + #logging.info("{} has {}".format(k, stats[k])) if 'categorical' not in k and 'continuous' not in k: continue @@ -346,10 +342,10 @@ def 
_dicts_and_plots_from_tensorization( plot_value_counter(list(categories.keys()), value_counter, a_id + '_v_count', os.path.join(output_folder, a_id)) plot_histograms(continuous_stats, a_id, os.path.join(output_folder, a_id)) - logging.info("Continuous tensor map: {}".format(continuous)) - logging.info("Continuous Columns: {}".format(len(continuous))) - logging.info("Category tensor map: {}".format(categories)) - logging.info("Categories Columns: {}".format(len(categories))) + # logging.info("Continuous tensor map: {}".format(continuous)) + # logging.info("Continuous Columns: {}".format(len(continuous))) + # logging.info("Category tensor map: {}".format(categories)) + # logging.info("Categories Columns: {}".format(len(categories))) def _to_float_or_false(s): @@ -367,10 +363,6 @@ def _to_float_or_nan(s): def _write_tensors_from_zipped_dicoms( - zoom_x: int, - zoom_y: int, - zoom_width: int, - zoom_height: int, write_pngs: bool, tensors: str, dicoms: str, @@ -390,10 +382,8 @@ def _write_tensors_from_zipped_dicoms( os.makedirs(dicom_folder) with zipfile.ZipFile(zipped, "r") as zip_ref: zip_ref.extractall(dicom_folder) - _write_tensors_from_dicoms( - zoom_x, zoom_y, zoom_width, zoom_height, write_pngs, tensors, dicom_folder, - hd5, sample_str, stats, - ) + ukb_instance = zipped.split('_')[2] + _write_tensors_from_dicoms(write_pngs, tensors, dicom_folder, hd5, sample_str, ukb_instance, stats) stats['MRI fields written'] += 1 shutil.rmtree(dicom_folder) @@ -410,36 +400,31 @@ def _write_tensors_from_zipped_niftis(zip_folder: str, mri_field_ids: List[str], def _write_tensors_from_dicoms( - zoom_x: int, zoom_y: int, zoom_width: int, zoom_height: int, write_pngs: bool, tensors: str, - dicom_folder: str, hd5: h5py.File, sample_str: str, stats: Dict[str, int], + write_pngs: bool, tensors: str, dicom_folder: str, hd5: h5py.File, sample_str: str, ukb_instance: str, stats: Dict[str, int], ) -> None: """Convert a folder of DICOMs from a sample into tensors for each series Segmented 
dicoms require special processing and are written to tensor per-slice Arguments - :param x: Width of the tensors (actual MRI width will be padded with 0s or cropped to this number) - :param y: Height of the tensors (actual MRI width will be padded with 0s or cropped to this number) - :param z: Minimum number of slices to include in the each tensor if more slices are found they will be kept - :param zoom_x: x coordinate of the zoom - :param zoom_y: y coordinate of the zoom - :param zoom_width: width of the zoom - :param zoom_height: height of the zoom :param write_pngs: write MRIs as PNG images for debugging :param tensors: Folder where hd5 tensor files are being written :param dicom_folder: Folder with all dicoms associated with one sample. :param hd5: Tensor file in which to create datasets for each series and each segmented slice :param sample_str: The current sample ID as a string + :param ukb_instance: The UK Biobank assessment visit instance number :param stats: Counter to keep track of summary statistics """ views = defaultdict(list) + series_to_numbers = defaultdict(set) min_ideal_series = 9e9 for dicom in os.listdir(dicom_folder): if os.path.splitext(dicom)[-1] != DICOM_EXT: continue d = pydicom.read_file(os.path.join(dicom_folder, dicom)) series = d.SeriesDescription.lower().replace(' ', '_') + series_to_numbers[series].add(int(d.SeriesNumber)) if series + '_12bit' in MRI_LIVER_SERIES_12BIT and d.LargestImagePixelValue > 2048: views[series + '_12bit'].append(d) stats[series + '_12bit'] += 1 @@ -462,99 +447,61 @@ def _write_tensors_from_dicoms( else: mri_group = 'ukb_mri' + if len(series_to_numbers[v]) > 1 and v not in MRI_BRAIN_SERIES: + max_series = max(series_to_numbers[v]) + single_series = [dicom for dicom in views[v] if int(dicom.SeriesNumber) == max_series] + # for d in views[v]: + # logging.warning(f'{d.SeriesNumber} with Date: {_datetime_from_dicom(d)} Time {d.AcquisitionTime}') + logging.warning(f'{v} has {len(views[v])} 
series:{series_to_numbers[v]} Using only max series: {max_series} with {len(single_series)}') + views[v] = single_series if v == MRI_TO_SEGMENT: - _tensorize_short_and_long_axis_segmented_cardiac_mri(views[v], v, zoom_x, zoom_y, zoom_width, zoom_height, write_pngs, tensors, hd5, mri_date, mri_group, stats) + _tensorize_short_and_long_axis_segmented_cardiac_mri(views[v], v, ukb_instance, hd5, mri_date, mri_group, stats) elif v in MRI_BRAIN_SERIES: _tensorize_brain_mri(views[v], v, mri_date, mri_group, hd5) else: - mri_data = np.zeros((views[v][0].Rows, views[v][0].Columns, len(views[v])), dtype=np.float32) - for slicer in views[v]: - _save_pixel_dimensions_if_missing(slicer, v, hd5) - _save_slice_thickness_if_missing(slicer, v, hd5) - _save_series_orientation_and_position_if_missing(slicer, v, hd5) - slice_index = slicer.InstanceNumber - 1 - if v in MRI_LIVER_IDEAL_PROTOCOL: - slice_index = _slice_index_from_ideal_protocol(slicer, min_ideal_series) - mri_data[..., slice_index] = slicer.pixel_array.astype(np.float32) - create_tensor_in_hd5(hd5, mri_group, v, mri_data, stats, mri_date) + pass + # mri_data = np.zeros((views[v][0].Rows, views[v][0].Columns, len(views[v])), dtype=np.float32) + # for slicer in views[v]: + # _save_pixel_dimensions_if_missing(slicer, v, hd5) + # _save_slice_thickness_if_missing(slicer, v, hd5) + # _save_series_orientation_and_position_if_missing(slicer, v, hd5) + # slice_index = slicer.InstanceNumber - 1 + # if v in MRI_LIVER_IDEAL_PROTOCOL: + # slice_index = _slice_index_from_ideal_protocol(slicer, min_ideal_series) + # mri_data[..., slice_index] = slicer.pixel_array.astype(np.float32) + # create_tensor_in_hd5(hd5, mri_group, f'{v}/{ukb_instance}', mri_data, stats, mri_date) def _tensorize_short_and_long_axis_segmented_cardiac_mri( - slices: List[pydicom.Dataset], series: str, zoom_x: int, zoom_y: int, - zoom_width: int, zoom_height: int, write_pngs: bool, tensors: str, - hd5: h5py.File, mri_date: datetime.datetime, mri_group: str, - 
stats: Dict[str, int], + slices: List[pydicom.Dataset], series: str, instance: str, + hd5: h5py.File, mri_date: datetime.datetime, mri_group: str, stats: Dict[str, int], ) -> None: - systoles = {} - diastoles = {} - systoles_pix = {} - systoles_masks = {} - diastoles_masks = {} - for slicer in slices: - full_mask = np.zeros((slicer.Rows, slicer.Columns), dtype=np.float32) - full_slice = np.zeros((slicer.Rows, slicer.Columns), dtype=np.float32) - + #full_slice = np.zeros((slicer.Rows, slicer.Columns), dtype=np.float32) if _has_overlay(slicer): if _is_mitral_valve_segmentation(slicer): series = series.replace('sax', 'lax') else: series = series.replace('lax', 'sax') - series_segmented = f'{series}_segmented' - series_zoom = f'{series}_zoom' - series_zoom_segmented = f'{series}_zoom_segmented' + series_segmented = f'{series}_segmented' try: overlay, mask, ventricle_pixels, _ = _get_overlay_from_dicom(slicer) except KeyError: logging.exception(f'Got key error trying to make anatomical mask, skipping.') continue - _save_pixel_dimensions_if_missing(slicer, series, hd5) - _save_slice_thickness_if_missing(slicer, series, hd5) - _save_series_orientation_and_position_if_missing(slicer, series, hd5, str(slicer.InstanceNumber)) + # _save_pixel_dimensions_if_missing(slicer, series, hd5) + # _save_slice_thickness_if_missing(slicer, series, hd5) + # _save_series_orientation_and_position_if_missing(slicer, series, hd5, str(slicer.InstanceNumber)) _save_pixel_dimensions_if_missing(slicer, series_segmented, hd5) _save_slice_thickness_if_missing(slicer, series_segmented, hd5) _save_series_orientation_and_position_if_missing(slicer, series_segmented, hd5, str(slicer.InstanceNumber)) - - cur_angle = (slicer.InstanceNumber - 1) // MRI_FRAMES # dicom InstanceNumber is 1-based - full_slice[:] = slicer.pixel_array.astype(np.float32) - create_tensor_in_hd5(hd5, mri_group, f'{series}{HD5_GROUP_CHAR}{slicer.InstanceNumber}', full_slice, stats, mri_date) - create_tensor_in_hd5(hd5, mri_group, 
f'{series_zoom_segmented}{HD5_GROUP_CHAR}{slicer.InstanceNumber}', mask, stats, mri_date) - - zoom_slice = full_slice[zoom_x: zoom_x + zoom_width, zoom_y: zoom_y + zoom_height] - zoom_mask = mask[zoom_x: zoom_x + zoom_width, zoom_y: zoom_y + zoom_height] - create_tensor_in_hd5(hd5, mri_group, f'{series_zoom}{HD5_GROUP_CHAR}{slicer.InstanceNumber}', zoom_slice, stats, mri_date) - create_tensor_in_hd5(hd5, mri_group, f'{series_zoom_segmented}{HD5_GROUP_CHAR}{slicer.InstanceNumber}', zoom_mask, stats, mri_date) - - if (slicer.InstanceNumber - 1) % MRI_FRAMES == 0: # Diastole frame is always the first - diastoles[cur_angle] = slicer - diastoles_masks[cur_angle] = mask - if cur_angle not in systoles: - systoles[cur_angle] = slicer - systoles_pix[cur_angle] = ventricle_pixels - systoles_masks[cur_angle] = mask - else: - if ventricle_pixels < systoles_pix[cur_angle]: - systoles[cur_angle] = slicer - systoles_pix[cur_angle] = ventricle_pixels - systoles_masks[cur_angle] = mask - - for angle in diastoles: - logging.info(f'Found systole, instance:{systoles[angle].InstanceNumber} ventricle pixels:{systoles_pix[angle]}') - full_slice = diastoles[angle].pixel_array.astype(np.float32) - create_tensor_in_hd5(hd5, mri_group, f'diastole_frame_b{angle}', full_slice, stats, mri_date) - create_tensor_in_hd5(hd5, mri_group, f'diastole_mask_b{angle}', diastoles_masks[angle], stats, mri_date) - if write_pngs: - plt.imsave(tensors + 'diastole_frame_b' + str(angle) + IMAGE_EXT, full_slice) - plt.imsave(tensors + 'diastole_mask_b' + str(angle) + IMAGE_EXT, full_mask) - - full_slice = systoles[angle].pixel_array.astype(np.float32) - create_tensor_in_hd5(hd5, mri_group, f'systole_frame_b{angle}', full_slice, stats, mri_date) - create_tensor_in_hd5(hd5, mri_group, f'systole_mask_b{angle}', systoles_masks[angle], stats, mri_date) - if write_pngs: - plt.imsave(tensors + 'systole_frame_b' + str(angle) + IMAGE_EXT, full_slice) - plt.imsave(tensors + 'systole_mask_b' + str(angle) + IMAGE_EXT, 
full_mask) + # + # cur_angle = (slicer.InstanceNumber - 1) // MRI_FRAMES # dicom InstanceNumber is 1-based + #full_slice[:] = slicer.pixel_array.astype(np.float32) + #create_tensor_in_hd5(hd5, mri_group, f'{series}{HD5_GROUP_CHAR}{instance}', full_slice, stats, mri_date, slicer.InstanceNumber) + create_tensor_in_hd5(hd5, mri_group, f'{series_segmented}{HD5_GROUP_CHAR}{instance}', mask, stats, mri_date, slicer.InstanceNumber) def _tensorize_brain_mri(slices: List[pydicom.Dataset], series: str, mri_date: datetime.datetime, mri_group: str, hd5: h5py.File) -> None: @@ -588,13 +535,16 @@ def _save_slice_thickness_if_missing(slicer, series, hd5): def _save_series_orientation_and_position_if_missing(slicer, series, hd5, instance=None): orientation_ds_name = MRI_PATIENT_ORIENTATION + '_' + series position_ds_name = MRI_PATIENT_POSITION + '_' + series - if instance: - orientation_ds_name += HD5_GROUP_CHAR + instance - position_ds_name += HD5_GROUP_CHAR + instance - if orientation_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT: - hd5.create_dataset(orientation_ds_name, data=[float(x) for x in slicer.ImageOrientationPatient]) - if position_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT: - hd5.create_dataset(position_ds_name, data=[float(x) for x in slicer.ImagePositionPatient]) + if instance is not None: + orientation_ds_name = f'{orientation_ds_name}_{instance}' + position_ds_name = f'{position_ds_name}_{instance}' + try: + if orientation_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT: + hd5.create_dataset(orientation_ds_name, data=[float(x) for x in slicer.ImageOrientationPatient]) + if position_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + 
MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT: + hd5.create_dataset(position_ds_name, data=[float(x) for x in slicer.ImagePositionPatient]) + except RuntimeError as e: + logging.warning(f' got error {e} \n orientation : {orientation_ds_name} {slicer.ImageOrientationPatient} and pos: {position_ds_name} {slicer.ImagePositionPatient}') def _has_overlay(d) -> bool: @@ -745,13 +695,16 @@ def _write_ecg_rest_tensors(ecgs, xml_field, hd5, sample_id, write_pngs, stats, def create_tensor_in_hd5( hd5: h5py.File, path_prefix: str, name: str, value, stats: Counter = None, date: datetime.datetime = None, - storage_type: StorageType = None, attributes: Dict[str, Any] = None, + instance: str = None, storage_type: StorageType = None, attributes: Dict[str, Any] = None, ): hd5_path = tensor_path(path_prefix, name) + if instance is not None: + hd5_path = f'{hd5_path}instance_{instance}/' if hd5_path in hd5: hd5_path = f'{hd5_path}instance_{len(hd5[hd5_path])}' - else: + elif instance is None: hd5_path = f'{hd5_path}instance_0' + if stats is not None: stats[hd5_path] += 1 if storage_type == StorageType.STRING: diff --git a/ml4h/tensormap/mgb/dynamic.py b/ml4h/tensormap/mgb/dynamic.py index 2708a8f85..8e858a654 100644 --- a/ml4h/tensormap/mgb/dynamic.py +++ b/ml4h/tensormap/mgb/dynamic.py @@ -23,7 +23,8 @@ INCIDENCE_CSV = '/media/erisone_snf13/lc_outcomes.csv' CARDIAC_SURGERY_OUTCOMES_CSV = '/data/sts-data/mgh-preop-ecg-outcome-labels.csv' PARTNERS_PREFIX = 'partners_ecg_rest' -WIDE_FILE = '/home/sam/ml/hf-wide-2020-08-18-with-lvh-and-lbbb.tsv' +WIDE_FILE = '/home/sam/ml/hf-wide-2020-09-15-with-lvh-and-lbbb.tsv' +#WIDE_FILE = '/home/sam/ml/mgh-wide-2020-06-25-with-mrn.tsv' def make_mgb_dynamic_tensor_maps(desired_map_name: str) -> TensorMap: @@ -511,6 +512,7 @@ def _days_to_years_float(s: str): except ValueError: return None + def _time_to_event_tensor_from_days(tm: TensorMap, has_disease: int, follow_up_days: int): tensor = np.zeros(tm.shape, 
dtype=np.float32) if follow_up_days > tm.days_window: @@ -535,7 +537,7 @@ def _survival_curve_tensor_from_dates(tm: TensorMap, has_disease: int, assessmen def tensor_from_wide( - file_name: str, patient_column: str = 'fpath', age_column: str = 'age', bmi_column: str = 'bmi', + file_name: str, patient_column: str = 'Mrn', age_column: str = 'age', bmi_column: str = 'bmi', sex_column: str = 'sex', hf_column: str = 'any_hf_age', start_column: str = 'start_fu', end_column: str = 'last_encounter', delimiter: str = '\t', population_normalize: int = 2000, target: str = 'ecg', skip_prevalent: bool = True, @@ -562,8 +564,11 @@ def tensor_from_wide( try: patient_key = int(float(row[patient_index])) patient_data[patient_key] = { - 'age': _days_to_years_float(row[age_index]), 'bmi': _to_float_or_none(row[bmi_index]), 'sex': row[sex_index], - 'hf_age': _days_to_years_float(row[hf_index]), 'end_age': _days_to_years_float(row[end_index]), + 'age': _days_to_years_float(row[age_index]), + 'bmi': _to_float_or_none(row[bmi_index]), + 'sex': row[sex_index], + 'hf_age': _days_to_years_float(row[hf_index]), + 'end_age': _days_to_years_float(row[end_index]), 'start_date': datetime.datetime.strptime(row[start_index], CARDIAC_SURGERY_DATE_FORMAT), } @@ -574,7 +579,7 @@ def tensor_from_wide( def tensor_from_file(tm: TensorMap, hd5: h5py.File, dependents=None): mrn_int = _hd5_filename_to_mrn_int(hd5.filename) if mrn_int not in patient_data: - raise KeyError(f'{tm.name} mrn not in legacy csv.') + raise KeyError(f'{tm.name} mrn not in csv.') if patient_data[mrn_int]['end_age'] is None or patient_data[mrn_int]['age'] is None: raise ValueError(f'{tm.name} could not find ages.') if patient_data[mrn_int]['end_age'] - patient_data[mrn_int]['age'] < 0: diff --git a/ml4h/tensormap/ukb/by_script.py b/ml4h/tensormap/ukb/by_script.py index 8f4591894..260f883e9 100644 --- a/ml4h/tensormap/ukb/by_script.py +++ b/ml4h/tensormap/ukb/by_script.py @@ -4275,8 +4275,10 @@ ukb_23098_0 = 
TensorMap('23098_Weight_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 78.03287707120329, 'std': 15.89979106473436}, annotation_units=1, channel_map={'23098_Weight_0_0': 0, }) ukb_23099_0 = TensorMap('23099_Body-fat-percentage_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.451786835511985, 'std': 8.547574421386416}, annotation_units=1, channel_map={'23099_Body-fat-percentage_0_0': 0, }) ukb_23099_1 = TensorMap('23099_Body-fat-percentage_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.264657151412308, 'std': 8.292322694561557}, annotation_units=1, channel_map={'23099_Body-fat-percentage_1_0': 0, }) +ukb_23099_2 = TensorMap('23099_Body-fat-percentage_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.264657151412308, 'std': 8.292322694561557}, annotation_units=1, channel_map={'23099_Body-fat-percentage_2_0': 0, }) ukb_23100_1 = TensorMap('23100_Whole-body-fat-mass_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 24.313408909308837, 'std': 9.148610869548651}, annotation_units=1, channel_map={'23100_Whole-body-fat-mass_1_0': 0, }) -ukb_23100_0 = TensorMap('23100_Whole-body-fat-mass_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 24.85684055488536, 'std': 9.565426323411884}, annotation_units=1, channel_map={'23100_Whole-body-fat-mass_0_0': 0, }) +ukb_23100_2 = TensorMap('23100_Whole-body-fat-mass_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 24.85684055488536, 'std': 9.565426323411884}, annotation_units=1, channel_map={'23100_Whole-body-fat-mass_2_0': 0, }) +ukb_23101_2 = TensorMap('23101_Whole-body-fatfree-mass_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 52.606156455797255, 'std': 11.126052649633138}, annotation_units=1, channel_map={'23101_Whole-body-fatfree-mass_2_0': 0, }) ukb_23101_1 = TensorMap('23101_Whole-body-fatfree-mass_1_0', loss='logcosh', 
path_prefix='continuous', normalization={'mean': 52.606156455797255, 'std': 11.126052649633138}, annotation_units=1, channel_map={'23101_Whole-body-fatfree-mass_1_0': 0, }) ukb_23101_0 = TensorMap('23101_Whole-body-fatfree-mass_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 53.21773908875434, 'std': 11.496421819167042}, annotation_units=1, channel_map={'23101_Whole-body-fatfree-mass_0_0': 0, }) ukb_23102_0 = TensorMap('23102_Whole-body-water-mass_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 38.947460814537706, 'std': 8.414129150401765}, annotation_units=1, channel_map={'23102_Whole-body-water-mass_0_0': 0, }) @@ -4300,44 +4302,59 @@ ukb_23110_1 = TensorMap('23110_Impedance-of-arm-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 337.08795210775764, 'std': 57.39185814366964}, annotation_units=1, channel_map={'23110_Impedance-of-arm-left_1_0': 0, }) ukb_23110_2 = TensorMap('23110_Impedance-of-arm-left_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 326.5779425175772, 'std': 54.696785486803726}, annotation_units=1, channel_map={'23110_Impedance-of-arm-left_2_0': 0, }) ukb_23110_0 = TensorMap('23110_Impedance-of-arm-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 332.26580591961624, 'std': 56.83465434478377}, annotation_units=1, channel_map={'23110_Impedance-of-arm-left_0_0': 0, }) +ukb_23111_2 = TensorMap('23111_Leg-fat-percentage-right_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.525313045647284, 'std': 10.627681452664147}, annotation_units=1, channel_map={'23111_Leg-fat-percentage-right_2_0': 0, }) ukb_23111_1 = TensorMap('23111_Leg-fat-percentage-right_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.525313045647284, 'std': 10.627681452664147}, annotation_units=1, channel_map={'23111_Leg-fat-percentage-right_1_0': 0, }) ukb_23111_0 = TensorMap('23111_Leg-fat-percentage-right_0_0', 
loss='logcosh', path_prefix='continuous', normalization={'mean': 32.04635516524227, 'std': 10.694715982950983}, annotation_units=1, channel_map={'23111_Leg-fat-percentage-right_0_0': 0, }) +ukb_23112_2 = TensorMap('23112_Leg-fat-mass-right_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 4.126630082314794, 'std': 1.8301929159686312}, annotation_units=1, channel_map={'23112_Leg-fat-mass-right_2_0': 0, }) ukb_23112_1 = TensorMap('23112_Leg-fat-mass-right_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 4.126630082314794, 'std': 1.8301929159686312}, annotation_units=1, channel_map={'23112_Leg-fat-mass-right_1_0': 0, }) ukb_23112_0 = TensorMap('23112_Leg-fat-mass-right_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 4.308648770510948, 'std': 1.898437368686469}, annotation_units=1, channel_map={'23112_Leg-fat-mass-right_0_0': 0, }) +ukb_23113_2 = TensorMap('23113_Leg-fatfree-mass-right_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.792466949363929, 'std': 1.9185583701790712}, annotation_units=1, channel_map={'23113_Leg-fatfree-mass-right_2_0': 0, }) ukb_23113_1 = TensorMap('23113_Leg-fatfree-mass-right_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.792466949363929, 'std': 1.9185583701790712}, annotation_units=1, channel_map={'23113_Leg-fatfree-mass-right_1_0': 0, }) ukb_23113_0 = TensorMap('23113_Leg-fatfree-mass-right_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.96966001620808, 'std': 2.0246336838262913}, annotation_units=1, channel_map={'23113_Leg-fatfree-mass-right_0_0': 0, }) ukb_23114_0 = TensorMap('23114_Leg-predicted-mass-right_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.48647434933969, 'std': 1.9273038036404626}, annotation_units=1, channel_map={'23114_Leg-predicted-mass-right_0_0': 0, }) ukb_23114_1 = TensorMap('23114_Leg-predicted-mass-right_1_0', loss='logcosh', 
path_prefix='continuous', normalization={'mean': 8.322803691693691, 'std': 1.828108029602454}, annotation_units=1, channel_map={'23114_Leg-predicted-mass-right_1_0': 0, }) +ukb_23115_2 = TensorMap('23115_Leg-fat-percentage-left_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.41810426540285, 'std': 10.597272488375483}, annotation_units=1, channel_map={'23115_Leg-fat-percentage-left_2_0': 0, }) ukb_23115_1 = TensorMap('23115_Leg-fat-percentage-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.41810426540285, 'std': 10.597272488375483}, annotation_units=1, channel_map={'23115_Leg-fat-percentage-left_1_0': 0, }) ukb_23115_0 = TensorMap('23115_Leg-fat-percentage-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.951824770075714, 'std': 10.648351373024937}, annotation_units=1, channel_map={'23115_Leg-fat-percentage-left_0_0': 0, }) ukb_23116_0 = TensorMap('23116_Leg-fat-mass-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 4.242730601973062, 'std': 1.871902668963224}, annotation_units=1, channel_map={'23116_Leg-fat-mass-left_0_0': 0, }) ukb_23116_1 = TensorMap('23116_Leg-fat-mass-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 4.064025941631329, 'std': 1.8044984851208248}, annotation_units=1, channel_map={'23116_Leg-fat-mass-left_1_0': 0, }) +ukb_23116_2 = TensorMap('23116_Leg-fat-mass-left_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 4.064025941631329, 'std': 1.8044984851208248}, annotation_units=1, channel_map={'23116_Leg-fat-mass-left_2_0': 0, }) +ukb_23117_2 = TensorMap('23117_Leg-fatfree-mass-left_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.697755051134946, 'std': 1.9100971996154277}, annotation_units=1, channel_map={'23117_Leg-fatfree-mass-left_2_0': 0, }) ukb_23117_1 = TensorMap('23117_Leg-fatfree-mass-left_1_0', loss='logcosh', path_prefix='continuous', 
normalization={'mean': 8.697755051134946, 'std': 1.9100971996154277}, annotation_units=1, channel_map={'23117_Leg-fatfree-mass-left_1_0': 0, }) ukb_23117_0 = TensorMap('23117_Leg-fatfree-mass-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.866400981491417, 'std': 2.01067768245527}, annotation_units=1, channel_map={'23117_Leg-fatfree-mass-left_0_0': 0, }) ukb_23118_1 = TensorMap('23118_Leg-predicted-mass-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.23430281865802, 'std': 1.819746251807764}, annotation_units=1, channel_map={'23118_Leg-predicted-mass-left_1_0': 0, }) ukb_23118_0 = TensorMap('23118_Leg-predicted-mass-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 8.389605651769429, 'std': 1.914070754426116}, annotation_units=1, channel_map={'23118_Leg-predicted-mass-left_0_0': 0, }) ukb_23119_0 = TensorMap('23119_Arm-fat-percentage-right_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 29.531533189516836, 'std': 10.173278005428738}, annotation_units=1, channel_map={'23119_Arm-fat-percentage-right_0_0': 0, }) ukb_23119_1 = TensorMap('23119_Arm-fat-percentage-right_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 29.08498378648042, 'std': 9.660940777028525}, annotation_units=1, channel_map={'23119_Arm-fat-percentage-right_1_0': 0, }) +ukb_23119_2 = TensorMap('23119_Arm-fat-percentage-right_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 29.08498378648042, 'std': 9.660940777028525}, annotation_units=1, channel_map={'23119_Arm-fat-percentage-right_2_0': 0, }) +ukb_23120_2 = TensorMap('23120_Arm-fat-mass-right_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 1.190551259665752, 'std': 0.5962184768943481}, annotation_units=1, channel_map={'23120_Arm-fat-mass-right_2_0': 0, }) ukb_23120_1 = TensorMap('23120_Arm-fat-mass-right_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 
1.190551259665752, 'std': 0.5962184768943481}, annotation_units=1, channel_map={'23120_Arm-fat-mass-right_1_0': 0, }) ukb_23120_0 = TensorMap('23120_Arm-fat-mass-right_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 1.239808434567098, 'std': 0.6384610930020305}, annotation_units=1, channel_map={'23120_Arm-fat-mass-right_0_0': 0, }) +ukb_23121_2 = TensorMap('23121_Arm-fatfree-mass-right_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.838897480668496, 'std': 0.7716166268521579}, annotation_units=1, channel_map={'23121_Arm-fatfree-mass-right_2_0': 0, }) ukb_23121_1 = TensorMap('23121_Arm-fatfree-mass-right_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.838897480668496, 'std': 0.7716166268521579}, annotation_units=1, channel_map={'23121_Arm-fatfree-mass-right_1_0': 0, }) ukb_23121_0 = TensorMap('23121_Arm-fatfree-mass-right_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.8931570158616644, 'std': 0.8215411754827264}, annotation_units=1, channel_map={'23121_Arm-fatfree-mass-right_0_0': 0, }) ukb_23122_0 = TensorMap('23122_Arm-predicted-mass-right_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.7094604752076186, 'std': 0.7833267524137338}, annotation_units=1, channel_map={'23122_Arm-predicted-mass-right_0_0': 0, }) ukb_23122_1 = TensorMap('23122_Arm-predicted-mass-right_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.6591219755550015, 'std': 0.7367273170117283}, annotation_units=1, channel_map={'23122_Arm-predicted-mass-right_1_0': 0, }) +ukb_23123_2 = TensorMap('23123_Arm-fat-percentage-left_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 29.945781990521326, 'std': 9.82011062734792}, annotation_units=1, channel_map={'23123_Arm-fat-percentage-left_2_0': 0, }) ukb_23123_1 = TensorMap('23123_Arm-fat-percentage-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 29.945781990521326, 
'std': 9.82011062734792}, annotation_units=1, channel_map={'23123_Arm-fat-percentage-left_1_0': 0, }) ukb_23123_0 = TensorMap('23123_Arm-fat-percentage-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 30.42513586266444, 'std': 10.267361725229156}, annotation_units=1, channel_map={'23123_Arm-fat-percentage-left_0_0': 0, }) ukb_23124_0 = TensorMap('23124_Arm-fat-mass-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 1.3197648029289917, 'std': 0.7131871500892241}, annotation_units=1, channel_map={'23124_Arm-fat-mass-left_0_0': 0, }) ukb_23124_1 = TensorMap('23124_Arm-fat-mass-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 1.2664488126122535, 'std': 0.6681665072338635}, annotation_units=1, channel_map={'23124_Arm-fat-mass-left_1_0': 0, }) +ukb_23124_2 = TensorMap('23124_Arm-fat-mass-left_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 1.2664488126122535, 'std': 0.6681665072338635}, annotation_units=1, channel_map={'23124_Arm-fat-mass-left_2_0': 0, }) +ukb_23125_2 = TensorMap('23125_Arm-fatfree-mass-left_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.8760987777500615, 'std': 0.7893732349099633}, annotation_units=1, channel_map={'23125_Arm-fatfree-mass-left_2_0': 0, }) ukb_23125_1 = TensorMap('23125_Arm-fatfree-mass-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.8760987777500615, 'std': 0.7893732349099633}, annotation_units=1, channel_map={'23125_Arm-fatfree-mass-left_1_0': 0, }) ukb_23125_0 = TensorMap('23125_Arm-fatfree-mass-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.9256461729453624, 'std': 0.8372290130379668}, annotation_units=1, channel_map={'23125_Arm-fatfree-mass-left_0_0': 0, }) ukb_23126_0 = TensorMap('23126_Arm-predicted-mass-left_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.7400753431894462, 'std': 0.7977945107270243}, annotation_units=1, 
channel_map={'23126_Arm-predicted-mass-left_0_0': 0, }) ukb_23126_1 = TensorMap('23126_Arm-predicted-mass-left_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 2.6941082564230485, 'std': 0.7532883471284786}, annotation_units=1, channel_map={'23126_Arm-predicted-mass-left_1_0': 0, }) +ukb_23127_2 = TensorMap('23127_Trunk-fat-percentage_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.227587927163885, 'std': 7.733697629204932}, annotation_units=1, channel_map={'23127_Trunk-fat-percentage_2_0': 0, }) ukb_23127_1 = TensorMap('23127_Trunk-fat-percentage_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.227587927163885, 'std': 7.733697629204932}, annotation_units=1, channel_map={'23127_Trunk-fat-percentage_1_0': 0, }) ukb_23127_0 = TensorMap('23127_Trunk-fat-percentage_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 31.17342071305921, 'std': 8.008510881672985}, annotation_units=1, channel_map={'23127_Trunk-fat-percentage_0_0': 0, }) +ukb_23128_2 = TensorMap('23128_Trunk-fat-mass_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 13.669508605637319, 'std': 5.019106497864643}, annotation_units=1, channel_map={'23128_Trunk-fat-mass_2_0': 0, }) ukb_23128_1 = TensorMap('23128_Trunk-fat-mass_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 13.669508605637319, 'std': 5.019106497864643}, annotation_units=1, channel_map={'23128_Trunk-fat-mass_1_0': 0, }) ukb_23128_0 = TensorMap('23128_Trunk-fat-mass_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 13.735808957116548, 'std': 5.1711708664366345}, annotation_units=1, channel_map={'23128_Trunk-fat-mass_0_0': 0, }) ukb_23129_0 = TensorMap('23129_Trunk-fatfree-mass_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 29.585953690492165, 'std': 5.981176227799396}, annotation_units=1, channel_map={'23129_Trunk-fatfree-mass_0_0': 0, }) ukb_23129_1 = 
TensorMap('23129_Trunk-fatfree-mass_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 29.42263906211025, 'std': 5.885068008414232}, annotation_units=1, channel_map={'23129_Trunk-fatfree-mass_1_0': 0, }) +ukb_23129_2 = TensorMap('23129_Trunk-fatfree-mass_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 29.42263906211025, 'std': 5.885068008414232}, annotation_units=1, channel_map={'23129_Trunk-fatfree-mass_2_0': 0, }) ukb_23130_0 = TensorMap('23130_Trunk-predicted-mass_0_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 28.36869182530567, 'std': 5.804851285767162}, annotation_units=1, channel_map={'23130_Trunk-predicted-mass_0_0': 0, }) ukb_23130_1 = TensorMap('23130_Trunk-predicted-mass_1_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 28.2279920179596, 'std': 5.708378220103142}, annotation_units=1, channel_map={'23130_Trunk-predicted-mass_1_0': 0, }) ukb_23200_2 = TensorMap('23200_L1L4-area_2_0', loss='logcosh', path_prefix='continuous', normalization={'mean': 59.829805862939246, 'std': 8.280381408666097}, annotation_units=1, channel_map={'23200_L1L4-area_2_0': 0, }) diff --git a/ml4h/tensormap/ukb/genetics.py b/ml4h/tensormap/ukb/genetics.py index be785c8f5..4f2b6d86d 100644 --- a/ml4h/tensormap/ukb/genetics.py +++ b/ml4h/tensormap/ukb/genetics.py @@ -1,4 +1,5 @@ from ml4h.TensorMap import TensorMap, Interpretation +from ml4h.defines import StorageType from ml4h.metrics import weighted_crossentropy @@ -77,8 +78,11 @@ def _ttn_tensor_from_file(tm, hd5, dependents={}): }, ) -genetic_caucasian = TensorMap('Genetic-ethnic-grouping_Caucasian_0_0', Interpretation.CATEGORICAL, path_prefix='categorical', channel_map={'no_caucasian': 0, 'caucasian': 1}) +genetic_caucasian = TensorMap( + 'Genetic-ethnic-grouping_Caucasian_0_0', Interpretation.CATEGORICAL, path_prefix='categorical', storage_type=StorageType.CATEGORICAL_FLAG, + channel_map={'no_caucasian': 0, 
'Genetic-ethnic-grouping_Caucasian_0_0': 1}) + genetic_caucasian_weighted = TensorMap( - 'Genetic-ethnic-grouping_Caucasian_0_0', Interpretation.CATEGORICAL, path_prefix='categorical', - channel_map={'no_caucasian': 0, 'caucasian': 1}, loss=weighted_crossentropy([10.0, 1.0], 'caucasian_loss'), + 'Genetic-ethnic-grouping_Caucasian_0_0', Interpretation.CATEGORICAL, path_prefix='categorical', storage_type=StorageType.CATEGORICAL_FLAG, + channel_map={'no_caucasian': 0, 'Genetic-ethnic-grouping_Caucasian_0_0': 1}, loss=weighted_crossentropy([10.0, 1.0], 'caucasian_loss'), ) diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index d81ff9c90..c8a6bb323 100644 --- a/ml4h/tensormap/ukb/mri.py +++ b/ml4h/tensormap/ukb/mri.py @@ -215,7 +215,7 @@ def _slice_tensor_from_file(tm, hd5, dependents={}): def _segmented_dicom_slices(dicom_key_prefix, path_prefix='ukb_cardiac_mri', step=1, total_slices=50): def _segmented_dicom_tensor_from_file(tm, hd5, dependents={}): tensor = np.zeros(tm.shape, dtype=np.float32) - if path_prefix == 'ukb_liver_mri': + if tm.axes() == 3 or path_prefix == 'ukb_liver_mri': categorical_index_slice = get_tensor_at_first_date(hd5, path_prefix, f'{dicom_key_prefix}1') categorical_one_hot = to_categorical(categorical_index_slice, len(tm.channel_map)) tensor[..., :] = pad_or_crop_array_to_shape(tensor[..., :].shape, categorical_one_hot) @@ -896,8 +896,9 @@ def _mri_slice_blackout_tensor_from_file(tm, hd5, dependents={}): ), ) lax_4ch_diastole_slice0_224_3d = TensorMap( - 'lax_4ch_diastole_slice0_224_3d', Interpretation.CONTINUOUS, shape=(160, 224, 1), - normalization=ZeroMeanStd1(), tensor_from_file=_slice_tensor('ukb_cardiac_mri/cine_segmented_lax_4ch/instance_0', 0), + 'lax_4ch_diastole_slice0_224_3d', Interpretation.CONTINUOUS, shape=(160, 224, 1), loss='logcosh', + normalization=ZeroMeanStd1(), + tensor_from_file=_slice_tensor('ukb_cardiac_mri/cine_segmented_lax_4ch/instance_0', 0), ) lax_4ch_diastole_slice0_256_3d = TensorMap( 
'lax_4ch_diastole_slice0_256_3d', Interpretation.CONTINUOUS, shape=(192, 256, 1), @@ -953,8 +954,8 @@ def _mri_slice_blackout_tensor_from_file(tm, hd5, dependents={}): 'ukb_cardiac_mri/cine_segmented_ao_dist/instance_0', 0, ), ) -cine_segmented_ao_dist_slice0_3d = TensorMap( - 'cine_segmented_ao_dist_slice0_3d', Interpretation.CONTINUOUS, shape=(256, 256, 1), loss='logcosh', +aorta_diastole_slice0_3d = TensorMap( + 'aorta_diastole_slice0_3d', Interpretation.CONTINUOUS, shape=(192, 256, 1), loss='logcosh', normalization=ZeroMeanStd1(), tensor_from_file=_slice_tensor('ukb_cardiac_mri/cine_segmented_ao_dist/instance_0', 0), ) cine_segmented_lvot_slice0_3d = TensorMap( @@ -1179,6 +1180,10 @@ def _pad_crop_tensor(tm, hd5, dependents={}): channel_map=MRI_SAX_SEGMENTED_CHANNEL_MAP, ) +segmented_aorta_diastole = TensorMap( + 'segmented_aorta_diastole', Interpretation.CATEGORICAL, shape=(192, 256, len(MRI_AO_SEGMENTED_CHANNEL_MAP)), + tensor_from_file=_segmented_dicom_slices('cine_segmented_ao_dist_annotated_'), channel_map=MRI_AO_SEGMENTED_CHANNEL_MAP, +) cine_segmented_ao_dist = TensorMap( 'cine_segmented_ao_dist', Interpretation.CATEGORICAL, shape=(160, 192, 100, len(MRI_AO_SEGMENTED_CHANNEL_MAP)), tensor_from_file=_segmented_dicom_slices('cine_segmented_ao_dist_annotated_'), channel_map=MRI_AO_SEGMENTED_CHANNEL_MAP, @@ -1251,14 +1256,30 @@ def sax_tensor_from_file(tm, hd5, dependents={}): [1.0, 40.0, 40.0], 'sax_all_diastole_segmented', ), ) +sax_all_diastole_192_segmented_weighted = TensorMap( + 'sax_all_diastole_segmented', Interpretation.CATEGORICAL, shape=(192, 192, 13, 3), + channel_map=MRI_SEGMENTED_CHANNEL_MAP, + loss=weighted_crossentropy( + [1.0, 40.0, 40.0], 'sax_all_diastole_segmented', + ), +) sax_all_diastole = TensorMap( 'sax_all_diastole', shape=(256, 256, 13, 1), tensor_from_file=sax_tensor('diastole'), - dependent_map=sax_all_diastole_segmented, + path_prefix='ukb_cardiac_mri', ) sax_all_diastole_weighted = TensorMap( 'sax_all_diastole', shape=(256, 
256, 13, 1), tensor_from_file=sax_tensor('diastole'), - dependent_map=sax_all_diastole_segmented_weighted, + dependent_map=sax_all_diastole_segmented_weighted, path_prefix='ukb_cardiac_mri', +) + +sax_all_diastole_192 = TensorMap( + 'sax_all_diastole', shape=(192, 192, 13, 1), tensor_from_file=sax_tensor('diastole'), + dependent_map=sax_all_diastole_segmented, path_prefix='ukb_cardiac_mri', +) +sax_all_diastole_192_weighted = TensorMap( + 'sax_all_diastole', shape=(192, 192, 13, 1), tensor_from_file=sax_tensor('diastole'), + dependent_map=sax_all_diastole_segmented_weighted, path_prefix='ukb_cardiac_mri', ) sax_all_systole_segmented = TensorMap( @@ -1271,6 +1292,7 @@ def sax_tensor_from_file(tm, hd5, dependents={}): loss=weighted_crossentropy([1.0, 40.0, 40.0], 'sax_all_systole_segmented'), ) + sax_all_systole = TensorMap( 'sax_all_systole', shape=(256, 256, 13, 1), tensor_from_file=sax_tensor('systole'), dependent_map=sax_all_systole_segmented, @@ -1351,6 +1373,14 @@ def _slice_tensor_from_file(tm, hd5, dependents={}): 'aorta_slice_nekoui', shape=(200, 240, 1), normalization=ZeroMeanStd1(), tensor_from_file=_slice_tensor_with_segmentation('cine_segmented_ao_dist/instance_0', 'cine_segmented_ao_dist_nekoui_annotated_'), ) +lvot_slice_jamesp = TensorMap( + 'lvot_slice_jamesp', shape=(200, 240, 1), normalization=ZeroMeanStd1(), + tensor_from_file=_slice_tensor_with_segmentation('cine_segmented_lvot/instance_0', 'cine_segmented_lvot_jamesp_annotated_'), +) +lvot_slice_nekoui = TensorMap( + 'lvot_slice_nekoui', shape=(200, 240, 1), normalization=ZeroMeanStd1(), + tensor_from_file=_slice_tensor_with_segmentation('cine_segmented_lvot/instance_0', 'cine_segmented_lvot_nekoui_annotated_'), +) lax_2ch_slice_jamesp = TensorMap( 'lax_2ch_slice_jamesp', shape=(192, 160, 1), normalization=ZeroMeanStd1(), tensor_from_file=_slice_tensor_with_segmentation('cine_segmented_lax_2ch/instance_0', 'cine_segmented_lax_2ch_jamesp_annotated_'), @@ -1387,6 +1417,14 @@ def 
_segmented_dicom_tensor_from_file(tm, hd5, dependents={}): 'cine_segmented_ao_dist', Interpretation.CATEGORICAL, shape=(200, 240, len(MRI_AO_SEGMENTED_CHANNEL_MAP)), tensor_from_file=_segmented_dicom_slice('cine_segmented_ao_dist_nekoui_annotated_'), channel_map=MRI_AO_SEGMENTED_CHANNEL_MAP, ) +cine_segmented_lvot_jamesp = TensorMap( + 'cine_segmented_lvot', Interpretation.CATEGORICAL, shape=(200, 240, len(MRI_LVOT_SEGMENTED_CHANNEL_MAP)), + tensor_from_file=_segmented_dicom_slice('cine_segmented_lvot_jamesp_annotated_'), channel_map=MRI_LVOT_SEGMENTED_CHANNEL_MAP, +) +cine_segmented_lvot_nekoui = TensorMap( + 'cine_segmented_lvot', Interpretation.CATEGORICAL, shape=(200, 240, len(MRI_LVOT_SEGMENTED_CHANNEL_MAP)), + tensor_from_file=_segmented_dicom_slice('cine_segmented_lvot_nekoui_annotated_'), channel_map=MRI_LVOT_SEGMENTED_CHANNEL_MAP, +) cine_segmented_lax_2ch_jamesp = TensorMap( 'cine_segmented_lax_2ch_slice', Interpretation.CATEGORICAL, shape=(192, 160, len(MRI_LAX_2CH_SEGMENTED_CHANNEL_MAP)), tensor_from_file=_segmented_dicom_slice('cine_segmented_lax_2ch_jamesp_annotated_'), channel_map=MRI_LAX_2CH_SEGMENTED_CHANNEL_MAP, diff --git a/ml4h/test_utils.py b/ml4h/test_utils.py index e8d53c7fe..d729f1c23 100644 --- a/ml4h/test_utils.py +++ b/ml4h/test_utils.py @@ -35,6 +35,7 @@ ), ] +TENSOR_MAP_PAIRS = [[(CONTINUOUS_TMAPS[2], CONTINUOUS_TMAPS[3])], [(CONTINUOUS_TMAPS[2], CONTINUOUS_TMAPS[3]), (CONTINUOUS_TMAPS[2], CATEGORICAL_TMAPS[4])]] TMAPS_UP_TO_4D = CONTINUOUS_TMAPS[:-1] + CATEGORICAL_TMAPS[:-1] TMAPS_5D = CONTINUOUS_TMAPS[-1:] + CATEGORICAL_TMAPS[-1:] MULTIMODAL_UP_TO_4D = [list(x) for x in product(CONTINUOUS_TMAPS[:-1], CATEGORICAL_TMAPS[:-1])] diff --git a/ml4h/visualization_tools/annotation_storage.py b/ml4h/visualization_tools/annotation_storage.py index d0020b1ec..ac8e89249 100644 --- a/ml4h/visualization_tools/annotation_storage.py +++ b/ml4h/visualization_tools/annotation_storage.py @@ -2,11 +2,9 @@ import abc import datetime -from typing import 
Optional, Union - +import pandas as pd from google.cloud import bigquery from google.cloud.bigquery import magics as bqmagics -import pandas as pd class AnnotationStorage(abc.ABC): @@ -16,14 +14,12 @@ class AnnotationStorage(abc.ABC): """ @abc.abstractmethod - def describe(self) -> str: + def describe(self): """Return a string describing how annotations are stored.""" + pass @abc.abstractmethod - def submit_annotation( - self, sample_id: Union[int, str], annotator: str, key: str, - value_numeric: Optional[Union[int, float]], value_string: Optional[str], comment: str, - ) -> bool: + def submit_annotation(self, sample_id, annotator, key, value_numeric, value_string, comment): """Add an annotation to the collection of annotations. Args: @@ -36,9 +32,10 @@ def submit_annotation( Returns: Whether the submission was successful. Throws an Exception on failure. """ + pass @abc.abstractmethod - def view_recent_submissions(self, count: int = 10) -> pd.DataFrame: + def view_recent_submissions(self, count=10): """View a dataframe of up to [count] most recent submissions. Args: @@ -47,6 +44,7 @@ def view_recent_submissions(self, count: int = 10) -> pd.DataFrame: Returns: A dataframe of the most recent annotations. """ + pass class TransientAnnotationStorage(AnnotationStorage): @@ -58,14 +56,11 @@ class TransientAnnotationStorage(AnnotationStorage): def __init__(self): self.annotations = [] - def describe(self) -> str: + def describe(self): return '''Annotations will be stored in memory only during the duration of this demo.\n For durable storage of annotations, use BigQueryAnnotationStorage instead.''' - def submit_annotation( - self, sample_id: Union[int, str], annotator: str, key: str, - value_numeric: Optional[Union[int, float]], value_string: Optional[str], comment: str, - ) -> bool: + def submit_annotation(self, sample_id, annotator, key, value_numeric, value_string, comment): """Add this annotation to our in-memory collection of annotations. 
Args: @@ -90,7 +85,7 @@ def submit_annotation( self.annotations.append(annotation) return True - def view_recent_submissions(self, count: int = 10) -> pd.DataFrame: + def view_recent_submissions(self, count=10): """View a dataframe of up to [count] most recent submissions. Args: @@ -115,17 +110,14 @@ class BigQueryAnnotationStorage(AnnotationStorage): annotations_schema.json """ - def __init__(self, table: str): + def __init__(self, table): """This table should already exist.""" self.table = table - def describe(self) -> str: + def describe(self): return f'''Annotations are stored in BigQuery table {self.table}''' - def submit_annotation( - self, sample_id: Union[int, str], annotator: str, key: str, - value_numeric: Optional[Union[int, float]], value_string: Optional[str], comment: str, - ) -> bool: + def submit_annotation(self, sample_id, annotator, key, value_numeric, value_string, comment): """Call a BigQuery INSERT statement to add a row containing annotation information. Args: @@ -158,7 +150,7 @@ def submit_annotation( # Return whether the submission completed. return submission.done() - def view_recent_submissions(self, count: int = 10) -> pd.DataFrame: + def view_recent_submissions(self, count=10): """View a dataframe of up to [count] most recent submissions. This is a convenience method for use within the annotation flow. 
For full access to the underlying annotations, diff --git a/ml4h/visualization_tools/annotations.py b/ml4h/visualization_tools/annotations.py index 9ca9c1b44..2400a07d8 100644 --- a/ml4h/visualization_tools/annotations.py +++ b/ml4h/visualization_tools/annotations.py @@ -2,11 +2,8 @@ import os import socket -from typing import Any, Dict, Union - from IPython.display import display from IPython.display import HTML -import pandas as pd import ipywidgets as widgets from ml4h.visualization_tools.annotation_storage import AnnotationStorage from ml4h.visualization_tools.annotation_storage import TransientAnnotationStorage @@ -14,18 +11,14 @@ DEFAULT_ANNOTATION_STORAGE = TransientAnnotationStorage() -def _get_df_sample(sample_info: pd.DataFrame, sample_id: Union[int, str]) -> pd.DataFrame: +def _get_df_sample(sample_info, sample_id): """Return a dataframe containing only the row for the indicated sample_id.""" df_sample = sample_info[sample_info['sample_id'] == str(sample_id)] - if df_sample.shape[0] == 0: df_sample = sample_info.query('sample_id == ' + str(sample_id)) + if 0 == df_sample.shape[0]: df_sample = sample_info.query('sample_id == ' + str(sample_id)) return df_sample -def display_annotation_collector( - sample_info: pd.DataFrame, sample_id: Union[int, str], - annotation_storage: AnnotationStorage = DEFAULT_ANNOTATION_STORAGE, - custom_annotation_key: str = None, -) -> None: +def display_annotation_collector(sample_info, sample_id, annotation_storage: AnnotationStorage = DEFAULT_ANNOTATION_STORAGE, custom_annotation_key=None): """Method to create a gui (set of widgets) through which the user can create an annotation and submit it to storage. Args: @@ -33,16 +26,15 @@ def display_annotation_collector( sample_id: The selected sample for which the values will be displayed. annotation_storage: An instance of AnnotationStorage. custom_annotation_key: The key for an annotation of data other than the tabular fields. 
+ + Returns: + A notebook-friendly messages indicating the status of the submission. """ df_sample = _get_df_sample(sample_info, sample_id) if df_sample.shape[0] == 0: - display( - HTML(f'''
- Warning: Sample {sample_id} not present in sample_info DataFrame. -
'''), - ) - return + return HTML(f'''
+ Warning: Sample {sample_id} not present in sample_info DataFrame.
''') # Show the sample ID for this annotation. sample = widgets.HTML(value=f'For sample {sample_id}') @@ -90,7 +82,7 @@ def handle_key_change(change): submit_button = widgets.Button(description='Submit annotation', button_style='success') output = widgets.Output() - def cb_on_button_clicked(b): + def on_button_clicked(b): params = _format_annotation(sample_id=sample_id, key=key.value, keyvalue=keyvalue.value, comment=comment.value) try: success = annotation_storage.submit_annotation( @@ -101,38 +93,34 @@ def cb_on_button_clicked(b): value_string=params['value_string'], comment=params['comment'], ) - except Exception as e: # pylint: disable=broad-except + except Exception as e: display( HTML(f'''
- Warning: Unable to store annotation. -

{e}

-
'''), + Warning: Unable to store annotation. +

{e}

+ '''), ) - return + return() with output: if success: # Show the information that was submitted. display( HTML(f'''
- Submission successful\n[{annotation_storage.describe()}] -
'''), + Submission successful\n[{annotation_storage.describe()}]'''), ) display(annotation_storage.view_recent_submissions(1)) else: display( HTML('''
- Annotation not submitted. Please try again. -
'''), + Annotation not submitted. Please try again.'''), ) - submit_button.on_click(cb_on_button_clicked) + submit_button.on_click(on_button_clicked) # Display all the widgets. display(sample, box1, comment, submit_button, output) -def _format_annotation( - sample_id: Union[int, str], key: str, keyvalue: Union[int, float, str], comment: str, -) -> Dict[str, Any]: +def _format_annotation(sample_id, key, keyvalue, comment): """Helper method to clean and reshape info from the widgets and the environment into a dictionary representing the annotation.""" # Programmatically get the identity of the person running this Terra notebook. current_user = os.getenv('OWNER_EMAIL') @@ -140,10 +128,11 @@ def _format_annotation( if current_user is None: current_user = socket.gethostname() # By convention, we prefix the hostname with our username. - value_numeric = None - value_string = None # Check whether the value is string or numeric. - if keyvalue is not None: + if keyvalue is None: + value_numeric = None + value_string = None + else: try: value_numeric = float(keyvalue) # this will fail if the value is text value_string = None diff --git a/ml4h/visualization_tools/batch_image_annotations.py b/ml4h/visualization_tools/batch_image_annotations.py deleted file mode 100644 index 34ff731df..000000000 --- a/ml4h/visualization_tools/batch_image_annotations.py +++ /dev/null @@ -1,236 +0,0 @@ -"""Methods for batch annotations of images stored as 3D tensors, such as MRIs, from within notebooks.""" - -import json -import os -import socket -import tempfile -from typing import Any, Dict, List - -from IPython.display import display -import numpy as np -import pandas as pd -import h5py -from ipyannotations import PolygonAnnotator -import ipywidgets as widgets -from ml4h.visualization_tools.hd5_mri_plots import MRI_TMAPS -from ml4h.visualization_tools.annotation_storage import AnnotationStorage -from ml4h.visualization_tools.annotation_storage import TransientAnnotationStorage -from PIL import 
Image -import tensorflow as tf - - -class BatchImageAnnotator(): - """Annotate batches of images with polygons drawn over regions of interest.""" - - SUBMIT_BUTTON_DESCRIPTION = 'Submit polygons, goto next sample' - USE_INSTRUCTIONS = ''' -

    -
  • To draw a polygon, click anywhere you'd like to start. Continue to click - along the edge of the polygon until arrive back where you started. To - finish, simply click the first point (highlighted in red). It may be - helpful to increase the point size if you're struggling (using the slider).
  • - -
  • You can change the class of a polygon using the dropdown menu while the - polygon is still "open", or unfinished. If you make a mistake, use the Undo - button until the point that's wrong has disappeared. - -
  • You can move, but not add / subtract polygon points, by clicking the "Edit" - button. Simply drag a point you want to adjust. Again, if you have - difficulty aiming at the points, you can increase the point size.
  • - -
  • You can increase or decrease the contrast and brightness of the image - using the sliders to make it easier to annotate. Sometimes you need to see - what's behind already-created annotations, and for this purpose you can - make them more see-through using the "Opacity" slider.
  • -

- ''' - EXPECTED_COLUMN_NAMES = ['sample_id', 'tmap_name', 'instance_number', 'folder'] - DEFAULT_ANNOTATION_CLASSNAME = 'region_of_interest' - CSS = ''' - - ''' - - def __init__( - self, samples: pd.DataFrame, annotation_categories: List[str] = None, - zoom: float = 1.5, annotation_storage: AnnotationStorage = TransientAnnotationStorage(), - ): - """Initializes an instance of BatchImageAnnotator. - - Args: - samples: A dataframe of samples to annotate. Columns must include those - in BatchImageAnnotator.EXPECTED_COLUMN_NAMES. - annotation_categories: A list of one or more strings to serve as tags for the polygons. - zoom: Desired zoom level for the image. - annotation_storage: An instance of AnnotationStorage. This faciltates the use of a user-provided - strategy for the storage and processing of annotations. - - Raises: - ValueError: The provided dataframe does not contain the expected columns. - """ - if not set(self.EXPECTED_COLUMN_NAMES).issubset(samples.columns): - raise ValueError(f'samples Dataframe must contain columns {self.EXPECTED_COLUMN_NAMES}') - self.samples = samples - self.current_sample = 0 - # TODO(deflaux) remove this after https://github.com/janfreyberg/ipyannotations/issues/11 - self.zoom = zoom - self.annotation_storage = annotation_storage - if annotation_categories is None: - annotation_categories = [self.DEFAULT_ANNOTATION_CLASSNAME] - - self.annotation_widget = PolygonAnnotator( - options=annotation_categories, - canvas_size=(900, 280 * self.zoom), - ) - self.annotation_widget.on_submit(self._store_annotations) - self.annotation_widget.submit_button.description = self.SUBMIT_BUTTON_DESCRIPTION - self.annotation_widget.submit_button.layout = widgets.Layout(width='300px') - - self.title_widget = widgets.HTML('') - self.results_widget = widgets.HTML('') - - def _store_annotations(self, data: Dict[Any, Any]) -> None: - """Transfer widget state to the annotation storage and advance to the next sample.""" - if self.current_sample >= 
self.samples.shape[0]: - self.results_widget.value = '

Annotation batch complete!

Thank you for making the model better.' - return - - # Convert polygon points in canvas coordinates to tensor coordinates. - image_canvas_position = self.annotation_widget.canvas.image_extent - x_offset, y_offset, _, _ = image_canvas_position - tensor_coords = [ - ( - a['label'], - [( - int((p[0] - x_offset) / self.zoom), - int((p[1] - y_offset) / self.zoom), - ) for p in a['points']], - ) for a in data - ] - # Store the annotation using the provided annotation storage strategy. - self.annotation_storage.submit_annotation( - sample_id=self.samples.loc[self.current_sample, 'sample_id'], - annotator=os.getenv('OWNER_EMAIL') if os.getenv('OWNER_EMAIL') else socket.gethostname(), - key=self.samples.loc[self.current_sample, 'tmap_name'], - value_numeric=self.samples.loc[self.current_sample, 'instance_number'], - value_string=self.samples.loc[self.current_sample, 'folder'], - comment=json.dumps(tensor_coords), - ) - - # Display this annotation at the bottom of the widget. - results = f''' -
-

Prior sample's submitted annotations

- The {self.SUBMIT_BUTTON_DESCRIPTION} button is both printing out the polygons below and storing the polygons - via strategy {self.annotation_storage.__class__.__name__}.
- Details: {self.annotation_storage.describe()} -

sample info

- {self._format_info_for_current_sample()} -

canvas coordinates

- image extent {image_canvas_position} - {[f'
{json.dumps(x)}
' for x in data]} -

source tensor coordinates

- {[f'
{json.dumps(x)}
' for x in tensor_coords]} -
- ''' - self.results_widget.value = results - - # Advance to the next sample. - self.current_sample += 1 - self._annotate_image_for_current_sample() - - def _format_info_for_current_sample(self) -> str: - """Convert information about the current sample to an HTML table for display within the widget.""" - headings = ' '.join([f'{c}' for c in self.EXPECTED_COLUMN_NAMES] + ['TMAP shape']) - values = ' '.join([f'{self.samples.loc[self.current_sample, c]}' for c in self.EXPECTED_COLUMN_NAMES] - + [f'{MRI_TMAPS[self.samples.loc[self.current_sample, "tmap_name"]].shape}']) - return f''' - - {headings} - {values} -
- ''' - - def _annotate_image_for_current_sample(self) -> None: - """Retrieve the data for the current sample and display its image in the annotation widget. - - If all samples have been processed, display the completion message. - """ - if self.current_sample >= self.samples.shape[0]: - self.annotation_widget.canvas.clear() - # Note: the above command clears the canvas, but any incomplete polygons will be redrawn. Call this - # private method to clear those. TODO(deflaux) remove this after https://github.com/janfreyberg/ipyannotations/issues/15 - self.annotation_widget.canvas._init_empty_data() # pylint: disable=protected-access - self.title_widget.value = '

Annotation batch complete!

Thank you for making the model better.' - return - - sample_id = self.samples.loc[self.current_sample, 'sample_id'] - tmap_name = self.samples.loc[self.current_sample, 'tmap_name'] - instance_number = self.samples.loc[self.current_sample, 'instance_number'] - folder = self.samples.loc[self.current_sample, 'folder'] - - with tempfile.TemporaryDirectory() as tmpdirname: - sample_hd5 = str(sample_id) + '.hd5' - local_path = os.path.join(tmpdirname, sample_hd5) - try: - tf.io.gfile.copy(src=os.path.join(folder, sample_hd5), dst=local_path) - hd5 = h5py.File(local_path, mode='r') - except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e: - self.annotation_widget.canvas.clear() - # Note: the above command clears the canvas, but any incomplete polygons will be redrawn. Call this - # private method to clear those. TODO(deflaux) remove this after https://github.com/janfreyberg/ipyannotations/issues/15 - self.annotation_widget.canvas._init_empty_data() # pylint: disable=protected-access - self.title_widget.value = f''' -
-

Warning: MRI HD5 file not available for sample {sample_id} in folder {folder}

- Use the folder parameter to read HD5s from a different local directory or Cloud Storage bucket. -

{e.message}

-
- ''' - return - - tensor = MRI_TMAPS[tmap_name].tensor_from_file(MRI_TMAPS[tmap_name], hd5) - tensor_instance = tensor[:, :, instance_number] - if self.zoom > 1.0: - # TODO(deflaux) remove this after https://github.com/janfreyberg/ipyannotations/issues/11 - img = Image.fromarray(tensor_instance) - zoomed_img = img.resize([int(self.zoom * s) for s in img.size], Image.LANCZOS) - tensor_instance = np.asarray(zoomed_img) - - self.annotation_widget.display(tensor_instance) - self.title_widget.value = f''' - {self.CSS} -
-

Batch annotation of {self.samples.shape[0]} samples

- {self.USE_INSTRUCTIONS} -
-

Current sample

- {self._format_info_for_current_sample()} -
- ''' - - def annotate_images(self) -> None: - """Begin the batch annotation task by displaying the annotation widget populated with the first sample. - - The submit button is used to proceed to the next sample until all samples have been processed. - """ - self._annotate_image_for_current_sample() - display(widgets.VBox([self.title_widget, self.annotation_widget, self.results_widget])) - - def view_recent_submissions(self, count: int = 10) -> pd.DataFrame: - """View a dataframe of up to [count] most recent submissions. - - Args: - count: The number of the most recent submissions to return. - - Returns: - A dataframe of the most recent annotations. - """ - return self.annotation_storage.view_recent_submissions(count=count) diff --git a/ml4h/visualization_tools/dicom_interactive_plots.py b/ml4h/visualization_tools/dicom_interactive_plots.py index ec9d63834..d9850e841 100644 --- a/ml4h/visualization_tools/dicom_interactive_plots.py +++ b/ml4h/visualization_tools/dicom_interactive_plots.py @@ -1,4 +1,4 @@ -"""Methods for integration of interactive DICOM plots within notebooks. +"""Methods for integration of interactive dicom plots within notebooks. TODO: * Continue to *pragmatically* improve this to make the visualization controls @@ -8,15 +8,14 @@ import collections import os import tempfile -from typing import Any, DefaultDict, Dict, Optional, Tuple import zipfile from IPython.display import display from IPython.display import HTML -import numpy as np import ipywidgets as widgets import matplotlib.pyplot as plt from ml4h.runtime_data_defines import get_mri_folders +import numpy as np import pydicom import tensorflow as tf @@ -28,12 +27,15 @@ MAX_COLOR_RANGE = 6000 -def choose_mri(sample_id, folder: Optional[str] = None) -> None: +def choose_mri(sample_id, folder=None): """Render widget to choose the MRI to plot. Args: sample_id: The id of the sample to retrieve. folder: The local or Cloud Storage folder under which the files reside. 
+ + Returns: + ipywidget or HTML upon error. """ if folder is None: folders = get_mri_folders(sample_id) @@ -43,26 +45,22 @@ def choose_mri(sample_id, folder: Optional[str] = None) -> None: sample_mris = [] sample_mri_glob = str(sample_id) + '_*.zip' try: - for f in folders: - sample_mris.extend(tf.io.gfile.glob(pattern=os.path.join(f, sample_mri_glob))) + for folder in folders: + sample_mris.extend(tf.io.gfile.glob(pattern=os.path.join(folder, sample_mri_glob))) except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e: - display( - HTML(f'''
+ return HTML(f''' +
Warning: MRI not available for sample {sample_id} in {folders}:

{e.message}

Use the folder parameter to read DICOMs from a different local directory or Cloud Storage bucket. -
'''), - ) - return +
''') if not sample_mris: - display( - HTML(f'''
+ return HTML(f''' +
Warning: MRI DICOMs not available for sample {sample_id} in {folders}.
Use the folder parameter to read DICOMs from a different local directory or Cloud Storage bucket. -
'''), - ) - return +
''') mri_chooser = widgets.Dropdown( options=sample_mris, @@ -79,11 +77,14 @@ def choose_mri(sample_id, folder: Optional[str] = None) -> None: display(file_controls_ui, file_controls_output) -def choose_mri_series(sample_mri: str) -> None: +def choose_mri_series(sample_mri): """Render widgets and interactive plots for MRIs. Args: sample_mri: The local or Cloud Storage path to the MRI file. + + Returns: + ipywidget or HTML upon error. """ with tempfile.TemporaryDirectory() as tmpdirname: local_path = os.path.join(tmpdirname, os.path.basename(sample_mri)) @@ -92,15 +93,13 @@ def choose_mri_series(sample_mri: str) -> None: with zipfile.ZipFile(local_path, 'r') as zip_ref: zip_ref.extractall(tmpdirname) except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e: - display( - HTML(f'''
+ return HTML(f''' +
Warning: Cardiac MRI not available for sample {os.path.basename(sample_mri)}:

{e.message}

-
'''), - ) - return +
''') - unordered_dicoms: DefaultDict[Any, Any] = collections.defaultdict(dict) + unordered_dicoms = collections.defaultdict(dict) for dcm_file in os.listdir(tmpdirname): if not dcm_file.endswith('.dcm'): continue @@ -113,13 +112,8 @@ def choose_mri_series(sample_mri: str) -> None: unordered_dicoms[key1][key2] = dcm if not unordered_dicoms: - display( - HTML(f'''
- No series available in MRI for sample {os.path.basename(sample_mri)}. - Try a different MRI. -
'''), - ) - return + print(f'\n\nNo series available in MRI for sample {os.path.basename(sample_mri)}\n\nTry a different MRI.') + return None # Convert from dict of dicts to dict of ordered lists. dicoms = {} @@ -140,7 +134,7 @@ def choose_mri_series(sample_mri: str) -> None: style={'description_width': 'initial'}, layout=widgets.Layout(width='800px'), ) - # Slide through DICOM image instances using a slide bar. + # Slide through dicom image instances using a slide bar. instance_chooser = widgets.IntSlider( continuous_update=True, value=default_instance_value, @@ -218,25 +212,25 @@ def on_value_change(change): display(viz_controls_ui, viz_controls_output) -def compute_color_range(dicoms: Dict[str, Any], series_name: str) -> Tuple[int, int]: +def compute_color_range(dicoms, series_name): """Compute the mean values for the color ranges of instances in the series.""" vmin = np.mean([np.min(d.pixel_array) for d in dicoms[series_name]]) vmax = np.mean([np.max(d.pixel_array) for d in dicoms[series_name]]) - return (vmin, vmax) + return(vmin, vmax) -def compute_instance_range(dicoms: Dict[str, Any], series_name: str) -> Tuple[int, int]: +def compute_instance_range(dicoms, series_name): """Compute middle and max instances.""" middle_instance = int(len(dicoms[series_name]) / 2) max_instance = len(dicoms[series_name]) - return (middle_instance, max_instance) + return(middle_instance, max_instance) def dicom_animation( - dicoms: Dict[str, Any], series_name: str, instance: int, vmin: int, vmax: int, transpose: bool, - fig_width: int, title_prefix: str = '', -) -> None: - """Render one frame of a DICOM animation. + dicoms, series_name, instance, vmin, vmax, transpose, + fig_width, title_prefix='', +): + """Render one frame of a dicom animation. Args: dicoms: the dictionary DICOM series and instances lists @@ -256,7 +250,7 @@ def dicom_animation( dcm = dicoms[series_name][instance - 1] if instance != dcm.InstanceNumber: # Notice invalid input, but don't throw an error. 
- print(f'WARNING: Instance parameter {str(instance)} and instance number {str(dcm.InstanceNumber)} do not match.') + print(f'WARNING: Instance parameter {str(instance)} and dicom instance number {str(dcm.InstanceNumber)} do not match.') if transpose: height = dcm.pixel_array.T.shape[0] diff --git a/ml4h/visualization_tools/dicom_plots.py b/ml4h/visualization_tools/dicom_plots.py index 093691382..ce2b3e083 100644 --- a/ml4h/visualization_tools/dicom_plots.py +++ b/ml4h/visualization_tools/dicom_plots.py @@ -1,17 +1,16 @@ -"""Methods for integration of DICOM plots within notebooks.""" +"""Methods for integration of dicom plots within notebooks.""" import collections import os import tempfile -from typing import Dict, List, Optional, Tuple, Union import zipfile from IPython.display import display from IPython.display import HTML -import numpy as np import ipywidgets as widgets import matplotlib.pyplot as plt from ml4h.runtime_data_defines import get_cardiac_mri_folder +import numpy as np import pydicom from scipy.ndimage.morphology import binary_closing from scipy.ndimage.morphology import binary_erosion @@ -28,21 +27,21 @@ MRI_SEGMENTED_CHANNEL_MAP = {'background': 0, 'ventricle': 1, 'myocardium': 2} -def _is_mitral_valve_segmentation(d: pydicom.FileDataset) -> bool: - """Determine whether a DICOM has mitral valve segmentation. +def _is_mitral_valve_segmentation(d): # -> bool: + """Determine whether a dicom has mitral valve segmentation. This is used for visualization of CINE_segmented_SAX_InlineVF. Args: - d: the DICOM file + d: the dicom file Returns: - Whether or not the DICOM has mitral valve segmentation + Whether or not the dicom has mitral valve segmentation """ return d.SliceThickness == 6 -def _get_overlay_from_dicom(d: pydicom.FileDataset) -> Tuple[int, int, int]: +def _get_overlay_from_dicom(d): """Get an overlay from a DICOM file. 
Morphological operators are used to transform the pixel outline of the @@ -50,7 +49,7 @@ def _get_overlay_from_dicom(d: pydicom.FileDataset) -> Tuple[int, int, int]: is used for visualization of CINE_segmented_SAX_InlineVF. Args: - d: the DICOM file + d: the dicom file Returns: Raw overlay array with myocardium outline, anatomical mask (a pixel @@ -78,30 +77,29 @@ def _get_overlay_from_dicom(d: pydicom.FileDataset) -> Tuple[int, int, int]: byte >>= 1 bit += 1 overlay = overlay[:expected_bit_length] - if overlay_frames != 1: - raise ValueError(f'DICOM has {overlay_frames} overlay frames, but only one expected.') - overlay = overlay.reshape(rows, cols) - idx = np.where(overlay == 1) - min_pos = (np.min(idx[0]), np.min(idx[1])) - max_pos = (np.max(idx[0]), np.max(idx[1])) - short_side = min((max_pos[0] - min_pos[0]), (max_pos[1] - min_pos[1])) - small_radius = max(MRI_MIN_RADIUS, short_side * MRI_SMALL_RADIUS_FACTOR) - big_radius = max(MRI_MIN_RADIUS+1, short_side * MRI_BIG_RADIUS_FACTOR) - small_structure = _unit_disk(small_radius) - m1 = binary_closing(overlay, small_structure).astype(np.int) - big_structure = _unit_disk(big_radius) - m2 = binary_closing(overlay, big_structure).astype(np.int) - anatomical_mask = m1 + m2 - ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle']) - myocardium_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['myocardium']) - if ventricle_pixels == 0 and myocardium_pixels > MRI_MAX_MYOCARDIUM: - erode_structure = _unit_disk(small_radius*1.5) - anatomical_mask = anatomical_mask - binary_erosion(m1, erode_structure).astype(np.int) + if overlay_frames == 1: + overlay = overlay.reshape(rows, cols) + idx = np.where(overlay == 1) + min_pos = (np.min(idx[0]), np.min(idx[1])) + max_pos = (np.max(idx[0]), np.max(idx[1])) + short_side = min((max_pos[0] - min_pos[0]), (max_pos[1] - min_pos[1])) + small_radius = max(MRI_MIN_RADIUS, short_side * MRI_SMALL_RADIUS_FACTOR) + big_radius = 
max(MRI_MIN_RADIUS+1, short_side * MRI_BIG_RADIUS_FACTOR) + small_structure = _unit_disk(small_radius) + m1 = binary_closing(overlay, small_structure).astype(np.int) + big_structure = _unit_disk(big_radius) + m2 = binary_closing(overlay, big_structure).astype(np.int) + anatomical_mask = m1 + m2 ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle']) - return overlay, anatomical_mask, ventricle_pixels + myocardium_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['myocardium']) + if ventricle_pixels == 0 and myocardium_pixels > MRI_MAX_MYOCARDIUM: + erode_structure = _unit_disk(small_radius*1.5) + anatomical_mask = anatomical_mask - binary_erosion(m1, erode_structure).astype(np.int) + ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle']) + return overlay, anatomical_mask, ventricle_pixels -def _unit_disk(r: int) -> np.ndarray: +def _unit_disk(r): # -> np.ndarray: """Get the unit disk for a radius. This is used for visualization of CINE_segmented_SAX_InlineVF. @@ -116,9 +114,7 @@ def _unit_disk(r: int) -> np.ndarray: return (x ** 2 + y ** 2 <= r ** 2).astype(np.int) -def plot_cardiac_long_axis( - b_series: List[pydicom.FileDataset], sides: int = 7, fig_width: int = 18, title_prefix: str = '', -) -> None: +def plot_cardiac_long_axis(b_series, sides=7, fig_width=18, title_prefix=''): """Visualize CINE_segmented_SAX_InlineVF series. Args: @@ -172,9 +168,9 @@ def plot_cardiac_long_axis( def plot_cardiac_short_axis( - series: List[pydicom.FileDataset], transpose: bool = False, fig_width: int = 18, - title_prefix: str = '', -) -> None: + series, transpose=False, fig_width=18, + title_prefix='', +): """Visualize CINE_segmented_LAX series. 
Args: @@ -229,14 +225,14 @@ def plot_cardiac_short_axis( def plot_mri_series( - sample_mri: str, dicoms: Dict[str, pydicom.FileDataset], series_name: str, sax_sides: int, - lax_transpose: bool, fig_width: int, -) -> None: + sample_mri, dicoms, series_name, sax_sides, + lax_transpose, fig_width, +): """Visualize the applicable series within this DICOM. Args: sample_mri: The local or Cloud Storage path to the MRI file. - dicoms: A dictionary of DICOMs. + dicoms: A dictionary of dicoms. series_name: The name of the chosen series. sax_sides: How many sides to display for CINE_segmented_SAX_InlineVF. lax_transpose: Whether to transpose when plotting CINE_segmented_LAX. @@ -262,9 +258,10 @@ def plot_mri_series( ) else: print(f'Visualization not currently implemented for {series_name}.') + return None -def choose_mri_series(sample_mri: str) -> None: +def choose_mri_series(sample_mri): """Render widgets and plots for cardiac MRIs. Visualization is supported for CINE_segmented_SAX_InlineVF series and @@ -272,6 +269,9 @@ def choose_mri_series(sample_mri: str) -> None: Args: sample_mri: The local or Cloud Storage path to the MRI file. + + Returns: + ipywidget or HTML upon error. """ with tempfile.TemporaryDirectory() as tmpdirname: local_path = os.path.join(tmpdirname, os.path.basename(sample_mri)) @@ -280,13 +280,11 @@ def choose_mri_series(sample_mri: str) -> None: with zipfile.ZipFile(local_path, 'r') as zip_ref: zip_ref.extractall(tmpdirname) except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e: - display( - HTML(f'''
+ return HTML(f''' +
Warning: Cardiac MRI not available for sample {os.path.basename(sample_mri)}:

{e.message}

-
'''), - ) - return +
''') filtered_dicoms = collections.defaultdict(list) series_descriptions = [] @@ -297,7 +295,7 @@ def choose_mri_series(sample_mri: str) -> None: series_descriptions.append(dcm.SeriesDescription) if 'cine_segmented_lax' in dcm.SeriesDescription.lower(): filtered_dicoms[dcm.SeriesDescription.lower()].append(dcm) - if dcm.SeriesDescription.lower() == 'cine_segmented_sax_inlinevf': + if 'cine_segmented_sax_inlinevf' == dcm.SeriesDescription.lower(): cur_angle = (dcm.InstanceNumber - 1) // MRI_FRAMES filtered_dicoms[f'{dcm.SeriesDescription.lower()}_angle_{str(cur_angle)}'].append(dcm) @@ -352,20 +350,22 @@ def choose_mri_series(sample_mri: str) -> None: ) display(viz_controls_ui, viz_controls_output) else: - display( - HTML(f'''
- Neither CINE_segmented_SAX_InlineVF nor CINE_segmented_LAX available in MRI for sample {os.path.basename(sample_mri)}. - Try a different MRI. -
'''), + print( + f'\n\nNeither CINE_segmented_SAX_InlineVF nor CINE_segmented_LAX available in MRI for sample {os.path.basename(sample_mri)}.', + '\n\nTry a different MRI.', ) + return None -def choose_cardiac_mri(sample_id: Union[int, str], folder: Optional[str] = None) -> None: +def choose_cardiac_mri(sample_id, folder=None): """Render widget to choose the cardiac MRI to plot. Args: sample_id: The id of the ECG sample to retrieve. folder: The local or Cloud Storage folder under which the files reside. + + Returns: + ipywidget or HTML upon error. """ if folder is None: folder = get_cardiac_mri_folder(sample_id) @@ -374,23 +374,19 @@ def choose_cardiac_mri(sample_id: Union[int, str], folder: Optional[str] = None) try: sample_mris = tf.io.gfile.glob(pattern=os.path.join(folder, sample_mri_glob)) except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e: - display( - HTML(f'''
+ return HTML(f''' +
Warning: Cardiac MRI not available for sample {sample_id} in {folder}:

{e.message}

Use the folder parameter to read DICOMs from a different local directory or Cloud Storage bucket. -
'''), - ) - return +
''') if not sample_mris: - display( - HTML(f'''
+ return HTML(f''' +
Warning: Cardiac MRI DICOM not available for sample {sample_id} in {folder}.
Use the folder parameter to read DICOMs from a different local directory or Cloud Storage bucket. -
'''), - ) - return +
''') mri_chooser = widgets.Dropdown( options=[(os.path.basename(mri), mri) for mri in sample_mris], diff --git a/ml4h/visualization_tools/ecg_interactive_plots.py b/ml4h/visualization_tools/ecg_interactive_plots.py index 18ed39a9b..97a4e1547 100644 --- a/ml4h/visualization_tools/ecg_interactive_plots.py +++ b/ml4h/visualization_tools/ecg_interactive_plots.py @@ -2,12 +2,10 @@ import os import tempfile -from typing import Optional, Union -from IPython.display import HTML import altair as alt # Interactive data visualization for plots. -from ml4h.TensorMap import TensorMap -from ml4h.visualization_tools.ecg_reshape import DEFAULT_RESTING_ECG_SIGNAL_TMAP +from IPython.display import HTML +from ml4h.visualization_tools.ecg_reshape import DEFAULT_RESTING_ECG_SIGNAL_TMAP_NAME from ml4h.visualization_tools.ecg_reshape import reshape_exercise_ecg_to_tidy from ml4h.visualization_tools.ecg_reshape import reshape_resting_ecg_to_tidy @@ -33,21 +31,18 @@ ) -def resting_ecg_interactive_plot( - sample_id: Union[int, str], folder: Optional[str] = None, - tmap: TensorMap = DEFAULT_RESTING_ECG_SIGNAL_TMAP, -) -> Union[HTML, alt.Chart]: +def resting_ecg_interactive_plot(sample_id, folder=None, tmap_name=DEFAULT_RESTING_ECG_SIGNAL_TMAP_NAME): """Wrangle resting ECG data to tidy and present it as an interactive plot. Args: sample_id: The id of the ECG sample to retrieve. folder: The local or Cloud Storage folder under which the files reside. - tmap: The TensorMap to use for ECG input. + tmap_name: The name of the TMAP to use for ecg input. Returns: An Altair plot or a notebook-friendly error. """ - tidy_resting_ecg_signal = reshape_resting_ecg_to_tidy(sample_id, folder, tmap) + tidy_resting_ecg_signal = reshape_resting_ecg_to_tidy(sample_id, folder, tmap_name) if tidy_resting_ecg_signal.shape[0] == 0: return HTML(f'''
@@ -90,9 +85,7 @@ def resting_ecg_interactive_plot( return upper & lower -def exercise_ecg_interactive_plot( - sample_id: Union[int, str], folder: Optional[str] = None, time_interval_seconds: int = 10, -) -> Union[HTML, alt.Chart]: +def exercise_ecg_interactive_plot(sample_id, folder=None, time_interval_seconds=10): """Wrangle exercise ECG data to tidy and present it as an interactive plot. Args: @@ -147,8 +140,7 @@ def exercise_ecg_interactive_plot( lead_select, ).transform_filter( # https://github.com/altair-viz/altair/issues/1960 - f'''((toNumber({brush.name}.time) - {time_interval_seconds/2.0}) < datum.time) - && (datum.time < toNumber({brush.name}.time) + {time_interval_seconds/2.0})''', + f'((toNumber({brush.name}.time) - {time_interval_seconds/2.0}) < datum.time) && (datum.time < toNumber({brush.name}.time) + {time_interval_seconds/2.0})', ) return trend.encode(y='heartrate:Q') & trend.encode(y='load:Q') & signal diff --git a/ml4h/visualization_tools/ecg_reshape.py b/ml4h/visualization_tools/ecg_reshape.py index 167eb5012..b3213d359 100644 --- a/ml4h/visualization_tools/ecg_reshape.py +++ b/ml4h/visualization_tools/ecg_reshape.py @@ -1,57 +1,53 @@ """Methods for reshaping raw ECG signal data for use in the pandas ecosystem.""" import os import tempfile -from typing import Any, Dict, Optional, Tuple, Union -import numpy as np -import pandas as pd from biosppy.signals.tools import filter_signal import h5py from ml4h.defines import ECG_BIKE_LEADS from ml4h.defines import ECG_REST_LEADS from ml4h.runtime_data_defines import get_exercise_ecg_hd5_folder from ml4h.runtime_data_defines import get_resting_ecg_hd5_folder -from ml4h.TensorMap import TensorMap -import ml4h.tensormap.ukb.ecg as ecg_tmaps +from ml4h.tensor_maps_by_hand import TMAPS +import numpy as np +import pandas as pd import tensorflow as tf RAW_SCALE = 0.005 # Convert to mV. 
SAMPLING_RATE = 500.0 -DEFAULT_RESTING_ECG_SIGNAL_TMAP = ecg_tmaps.ecg_rest +DEFAULT_RESTING_ECG_SIGNAL_TMAP_NAME = 'ecg_rest' # TODO(deflaux): parameterize exercise ECG by TMAP name if there is similar ECG data from other studies. -EXERCISE_ECG_SIGNAL_TMAP = ecg_tmaps.ecg_bike_raw_full +EXERCISE_ECG_SIGNAL_TMAP = TMAPS['ecg-bike-raw-full'] EXERCISE_ECG_TREND_TMAPS = [ - ecg_tmaps.ecg_bike_raw_trend_hr, - ecg_tmaps.ecg_bike_raw_trend_load, - ecg_tmaps.ecg_bike_raw_trend_grade, - ecg_tmaps.ecg_bike_raw_trend_artifact, - ecg_tmaps.ecg_bike_raw_trend_mets, - ecg_tmaps.ecg_bike_raw_trend_pacecount, - ecg_tmaps.ecg_bike_raw_trend_phasename, - ecg_tmaps.ecg_bike_raw_trend_phasetime, - ecg_tmaps.ecg_bike_raw_trend_time, - ecg_tmaps.ecg_bike_raw_trend_vecount, + TMAPS['ecg-bike-raw-trend-hr'], + TMAPS['ecg-bike-raw-trend-load'], + TMAPS['ecg-bike-raw-trend-grade'], + TMAPS['ecg-bike-raw-trend-artifact'], + TMAPS['ecg-bike-raw-trend-mets'], + TMAPS['ecg-bike-raw-trend-pacecount'], + TMAPS['ecg-bike-raw-trend-phasename'], + TMAPS['ecg-bike-raw-trend-phasetime'], + TMAPS['ecg-bike-raw-trend-time'], + TMAPS['ecg-bike-raw-trend-vecount'], ] EXERCISE_PHASES = {0.0: 'Pretest', 1.0: 'Exercise', 2.0: 'Recovery'} -def _examine_available_keys(hd5: Dict[str, Any]) -> None: +def _examine_available_keys(hd5): print(f'hd5 ECG keys {[k for k in hd5.keys() if "ecg" in k]}') for key in [k for k in hd5.keys() if 'ecg' in k]: - print(f'hd5 {key} keys {k for k in hd5[key]}') + print(f'hd5 {key} keys {[k for k in hd5[key].keys()]}') -def reshape_resting_ecg_to_tidy( - sample_id: Union[int, str], folder: Optional[str] = None, tmap: TensorMap = DEFAULT_RESTING_ECG_SIGNAL_TMAP, -) -> pd.DataFrame: +def reshape_resting_ecg_to_tidy(sample_id, folder=None, tmap_name=DEFAULT_RESTING_ECG_SIGNAL_TMAP_NAME): """Wrangle resting ECG data to tidy. Args: sample_id: The id of the ECG sample to retrieve. folder: The local or Cloud Storage folder under which the files reside. 
- tmap: The TensorMap to use for ECG input. + tmap_name: The name of the TMAP to use for ecg input. Returns: A pandas dataframe in tidy format or print a notebook-friendly error and return an empty dataframe. @@ -59,7 +55,7 @@ def reshape_resting_ecg_to_tidy( if folder is None: folder = get_resting_ecg_hd5_folder(sample_id) - data: Dict[str, Any] = {'lead': [], 'raw': [], 'ts_reference': [], 'filtered': [], 'filtered_1': [], 'filtered_2': []} + data = {'lead': [], 'raw': [], 'ts_reference': [], 'filtered': [], 'filtered_1': [], 'filtered_2': []} with tempfile.TemporaryDirectory() as tmpdirname: sample_hd5 = str(sample_id) + '.hd5' @@ -73,10 +69,10 @@ def reshape_resting_ecg_to_tidy( with h5py.File(local_path, mode='r') as hd5: try: - signals = tmap.tensor_from_file(tmap, hd5) + signals = TMAPS[tmap_name].tensor_from_file(TMAPS[tmap_name], hd5) except (KeyError, ValueError) as e: - print(f'''Warning: Resting ECG TMAP {tmap.name} not available for sample {sample_id}. - Use the tmap parameter to choose a different TMAP.\n\n{e}''') + print(f'''Warning: Resting ECG TMAP {tmap_name} not available for sample {sample_id}. + Use the tmap_name parameter to choose a different TMAP.\n\n{e}''') _examine_available_keys(hd5) return pd.DataFrame(data) for (lead, channel) in ECG_REST_LEADS.items(): @@ -140,9 +136,7 @@ def reshape_resting_ecg_to_tidy( return tidy_signal_df -def reshape_exercise_ecg_to_tidy( - sample_id: Union[int, str], folder: Optional[str] = None, -) -> Tuple[pd.DataFrame, pd.DataFrame]: +def reshape_exercise_ecg_to_tidy(sample_id, folder=None): """Wrangle exercise ECG signal data to tidy format. Args: @@ -214,9 +208,7 @@ def reshape_exercise_ecg_to_tidy( return (trend_df, tidy_signal_df) -def reshape_exercise_ecg_and_trend_to_tidy( - sample_id: Union[int, str], folder: Optional[str] = None, -) -> Tuple[pd.DataFrame, pd.DataFrame]: +def reshape_exercise_ecg_and_trend_to_tidy(sample_id, folder=None): """Wrangle exercise ECG signal and trend data to tidy format. 
Args: diff --git a/ml4h/visualization_tools/ecg_static_plots.py b/ml4h/visualization_tools/ecg_static_plots.py index ac7283237..2ebcfc3e1 100644 --- a/ml4h/visualization_tools/ecg_static_plots.py +++ b/ml4h/visualization_tools/ecg_static_plots.py @@ -1,18 +1,17 @@ """Methods for integration of static plots within notebooks.""" import os import tempfile -from typing import List, Optional, Union from IPython.display import HTML from IPython.display import SVG -import numpy as np from ml4h.plots import plot_ecg_rest from ml4h.runtime_data_defines import get_resting_ecg_hd5_folder from ml4h.runtime_data_defines import get_resting_ecg_svg_folder +import numpy as np import tensorflow as tf -def display_resting_ecg(sample_id: Union[int, str], folder: Optional[str] = None) -> Union[HTML, SVG]: +def display_resting_ecg(sample_id, folder=None): """Retrieve (or render) and display the SVG of the resting ECG. Args: @@ -54,8 +53,8 @@ def display_resting_ecg(sample_id: Union[int, str], folder: Optional[str] = None try: # We don't need the resulting SVG, so send it to a temporary directory. with tempfile.TemporaryDirectory() as tmpdirname: - return plot_ecg_rest(tensor_paths=[local_path], rows=[0], out_folder=tmpdirname, is_blind=False) - except Exception as e: # pylint: disable=broad-except + plot_ecg_rest(tensor_paths = [local_path], rows=[0], out_folder=tmpdirname, is_blind=False) + except Exception as e: return HTML(f'''
Warning: Unable to render static plot of resting ECG for sample {sample_id} from {hd5_folder}: @@ -63,7 +62,7 @@ def display_resting_ecg(sample_id: Union[int, str], folder: Optional[str] = None
''') -def major_breaks_x_resting_ecg(limits: List[float]) -> np.array: +def major_breaks_x_resting_ecg(limits): """Method to compute breaks for plotnine plots of ECG resting data. Args: diff --git a/ml4h/visualization_tools/facets.py b/ml4h/visualization_tools/facets.py index 18f96327d..a45ea88da 100644 --- a/ml4h/visualization_tools/facets.py +++ b/ml4h/visualization_tools/facets.py @@ -2,7 +2,6 @@ import base64 import os -import pandas as pd from facets_overview.generic_feature_statistics_generator import GenericFeatureStatisticsGenerator FACETS_DEPENDENCIES = { @@ -26,10 +25,10 @@ FACETS_DEPENDENCIES[dep] = os.path.basename(url) -class FacetsOverview(): +class FacetsOverview(object): """Methods for Facets Overview notebook integration.""" - def __init__(self, data: pd.DataFrame): + def __init__(self, data): # This takes the dataframe and computes all the inputs to the Facets # Overview plots such as: # - numeric variables: histogram bins, mean, min, median, max, etc.. @@ -40,7 +39,7 @@ def __init__(self, data: pd.DataFrame): [{'name': 'data', 'table': data}], ) - def _repr_html_(self) -> str: + def _repr_html_(self): """Html representation of Facets Overview for use in a Jupyter notebook.""" protostr = base64.b64encode(self._proto.SerializeToString()).decode('utf-8') html_template = ''' @@ -58,14 +57,14 @@ def _repr_html_(self) -> str: return html -class FacetsDive(): +class FacetsDive(object): """Methods for Facets Dive notebook integration.""" - def __init__(self, data: pd.DataFrame, height: int = 1000): + def __init__(self, data, height=1000): self._data = data self.height = height - def _repr_html_(self) -> str: + def _repr_html_(self): """Html representation of Facets Dive for use in a Jupyter notebook.""" html_template = """ diff --git a/ml4h/visualization_tools/hd5_mri_plots.py b/ml4h/visualization_tools/hd5_mri_plots.py index d3894b39d..20b3305b1 100644 --- a/ml4h/visualization_tools/hd5_mri_plots.py +++ b/ml4h/visualization_tools/hd5_mri_plots.py @@ 
-1,34 +1,29 @@ """Methods for integration of plots of mri data processed to 3D tensors from within notebooks.""" -from collections import OrderedDict from enum import Enum, auto import os import tempfile -from typing import Any, Dict, List, Optional, Tuple, Union +import h5py from IPython.display import display from IPython.display import HTML -import numpy as np -import h5py import ipywidgets as widgets import matplotlib.pyplot as plt from ml4h.runtime_data_defines import get_mri_hd5_folder -import ml4h.tensormap.ukb.mri as ukb_mri -import ml4h.tensormap.ukb.mri_vtk as ukb_mri_vtk -from ml4h.TensorMap import Interpretation, TensorMap +from ml4h.tensor_maps_by_hand import TMAPS +from ml4h.TensorMap import Interpretation +import numpy as np import tensorflow as tf -# Discover applicable TensorMaps. -MRI_TMAPS = { - key: value for key, value in ukb_mri.__dict__.items() if isinstance(value, TensorMap) - and value.interpretation == Interpretation.CONTINUOUS and value.axes() == 3 -} -MRI_TMAPS.update( - { - key: value for key, value in ukb_mri_vtk.__dict__.items() - if isinstance(value, TensorMap) and value.interpretation == Interpretation.CONTINUOUS and value.axes() == 3 - }, +# Discover applicable TMAPS. +CARDIAC_MRI_TMAP_NAMES = [k for k in TMAPS.keys() if ('_lax_' in k or '_sax_' in k) and TMAPS[k].axes() == 3] +CARDIAC_MRI_TMAP_NAMES.extend( + [k for k in TMAPS.keys() if TMAPS[k].path_prefix == 'ukb_cardiac_mri' and TMAPS[k].axes() == 3], ) +LIVER_MRI_TMAP_NAMES = [k for k in TMAPS.keys() if TMAPS[k].path_prefix == 'ukb_liver_mri' and TMAPS[k].axes() == 3] +BRAIN_MRI_TMAP_NAMES = [k for k in TMAPS.keys() if TMAPS[k].path_prefix == 'ukb_brain_mri' and TMAPS[k].axes() == 3] +# This includes more than just MRI TMAPS, it is a best effort. 
+BEST_EFFORT_MRI_TMAP_NAMES = [k for k in TMAPS.keys() if TMAPS[k].interpretation == Interpretation.CONTINUOUS and TMAPS[k].axes() == 3] MIN_IMAGE_WIDTH = 8 DEFAULT_IMAGE_WIDTH = 12 @@ -46,30 +41,42 @@ class PlotType(Enum): class TensorMapCache: """Cache the tensor to display for reuse when re-plotting the same TMAP with different plot parameters.""" - def __init__(self, hd5: Dict[str, Any], tmap: TensorMap): + def __init__(self, hd5, tmap_name): self.hd5 = hd5 - self.tmap: Optional[TensorMap] = None + self.tmap_name = None self.tensor = None - _ = self.get(tmap) + _ = self.get(tmap_name) - def get(self, tmap: TensorMap) -> np.array: - if self.tmap != tmap: - self.tensor = tmap.tensor_from_file(tmap, self.hd5) - self.tmap = tmap + def get(self, tmap_name): + if self.tmap_name != tmap_name: + self.tensor = TMAPS[tmap_name].tensor_from_file(TMAPS[tmap_name], self.hd5) + self.tmap_name = tmap_name return self.tensor -def choose_mri_tmap( - sample_id: Union[int, str], folder: Optional[str] = None, tmap: Optional[TensorMap] = None, - default_tmaps: Dict[str, TensorMap] = MRI_TMAPS, -) -> None: +def choose_cardiac_mri_tmap(sample_id, folder=None, tmap_name='cine_lax_4ch_192', default_tmap_names=CARDIAC_MRI_TMAP_NAMES): + choose_mri_tmap(sample_id, folder, tmap_name, default_tmap_names) + + +def choose_brain_mri_tmap(sample_id, folder=None, tmap_name='t2_flair_sag_p2_1mm_fs_ellip_pf78_1', default_tmap_names=BRAIN_MRI_TMAP_NAMES): + choose_mri_tmap(sample_id, folder, tmap_name, default_tmap_names) + + +def choose_liver_mri_tmap(sample_id, folder=None, tmap_name='liver_shmolli_segmented', default_tmap_names=LIVER_MRI_TMAP_NAMES): + choose_mri_tmap(sample_id, folder, tmap_name, default_tmap_names) + + +def choose_mri_tmap(sample_id, folder=None, tmap_name=None, default_tmap_names=BEST_EFFORT_MRI_TMAP_NAMES): """Render widgets and plots for MRI tensors. Args: sample_id: The id of the sample to retrieve. folder: The local or Cloud Storage folder under which the files reside. 
- tmap: The TensorMap for the 3D MRI tensor to visualize. - default_tmaps: Other TensorMaps to offer for visualization, if present in the hd5. + tmap_name: The TMAP name for the 3D MRI tensor to visualize. + default_tmap_names: Other TMAP names to offer for visualization, if present in the hd5. + + Returns: + ipywidget or HTML upon error. """ if folder is None: folder = get_mri_hd5_folder(sample_id) @@ -81,45 +88,42 @@ def choose_mri_tmap( tf.io.gfile.copy(src=os.path.join(folder, sample_hd5), dst=local_path) hd5 = h5py.File(local_path, mode='r') except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e: - display( - HTML(f'''
+ return HTML(f''' +
Warning: MRI HD5 file not available for sample {sample_id} in folder {folder}:

{e.message}

Use the folder parameter to read HD5s from a different local directory or Cloud Storage bucket. -
'''), - ) - return - - sample_tmaps = OrderedDict() - # Add the passed tmap parameter, if it is present in this hd5. - if tmap: - if tmap.hd5_key_guess() in hd5: - if len(tmap.shape) == 3: - sample_tmaps[tmap.name] = tmap +
''') + + sample_tmap_names = [] + # Add the passed tmap_name parameter, if it is present in this hd5. + if tmap_name: + if TMAPS[tmap_name].hd5_key_guess() in hd5: + if len(TMAPS[tmap_name].shape) == 3: + sample_tmap_names.append(tmap_name) else: - print(f'{tmap} is not a 3D tensor, skipping it') + print(f'{tmap_name} is not a 3D tensor, skipping it') else: - print(f'{tmap} is not available in {sample_id}') - # Also discover applicable TensorMaps for this particular sample's HD5 file. - sample_tmaps.update({n: t for n, t in sorted(default_tmaps.items(), key=lambda t: t[0]) if t.hd5_key_guess() in hd5}) - - if not sample_tmaps: - display( - HTML(f'''
- Neither {tmap.name} nor any of {default_tmaps.keys()} are present in this HD5 for sample {sample_id} in {folder}. - Use the tmap parameter to try a different TensorMap or the folder parameter to try a different hd5 for the sample. -
'''), - ) - return - - default_tmap_value = next(iter(sample_tmaps.values())) + print(f'{tmap_name} is not available in {sample_id}') + # Also discover applicable TMAPS for this particular sample's HD5 file. + sample_tmap_names.extend( + sorted(set([k for k in default_tmap_names if TMAPS[k].hd5_key_guess() in hd5])), + ) + + if not sample_tmap_names: + return HTML(f'''
+ Neither {tmap_name} nor any of {default_tmap_names} are present in this HD5 for sample {sample_id} in {folder}. + Use the tmap_name parameter to try a different TMAP or the folder parameter to try a different hd5 for the sample. +
''') + + default_tmap_name_value = sample_tmap_names[0] # Display the middle instance by default in the interactive view. - default_instance_value, max_instance_value = compute_instance_range(default_tmap_value) - default_vmin_value, default_vmax_value = compute_color_range(hd5, default_tmap_value) + default_instance_value, max_instance_value = compute_instance_range(default_tmap_name_value) + default_vmin_value, default_vmax_value = compute_color_range(hd5, default_tmap_name_value) - tmap_chooser = widgets.Dropdown( - options=sample_tmaps, - value=default_tmap_value, + tmap_name_chooser = widgets.Dropdown( + options=sample_tmap_names, + value=default_tmap_name_value, description='Choose the MRI tensor TMAP name to visualize:', style={'description_width': 'initial'}, layout=widgets.Layout(width='900px'), @@ -170,20 +174,20 @@ def choose_mri_tmap( viz_controls_ui = widgets.VBox( [ widgets.HTML('

Visualization controls

'), - tmap_chooser, + tmap_name_chooser, widgets.HBox([transpose_chooser, fig_width_chooser]), widgets.HBox([flip_chooser, color_range_chooser]), widgets.HBox([plot_type_chooser, instance_chooser]), ], layout=widgets.Layout(width='auto', border='solid 1px grey'), ) - tmap_cache = TensorMapCache(hd5=hd5, tmap=tmap_chooser.value) + tmap_cache = TensorMapCache(hd5=hd5, tmap_name=tmap_name_chooser.value) viz_controls_output = widgets.interactive_output( plot_mri_tmap, { 'sample_id': widgets.fixed(sample_id), 'tmap_cache': widgets.fixed(tmap_cache), - 'tmap': tmap_chooser, + 'tmap_name': tmap_name_chooser, 'plot_type': plot_type_chooser, 'instance': instance_chooser, 'color_range': color_range_chooser, @@ -205,36 +209,33 @@ def on_plot_type_change(change): else: instance_chooser.layout.visibility = 'hidden' - tmap_chooser.observe(on_tmap_value_change, names='value') + tmap_name_chooser.observe(on_tmap_value_change, names='value') plot_type_chooser.observe(on_plot_type_change, names='value') display(viz_controls_ui, viz_controls_output) -def compute_color_range(hd5: Dict[str, Any], tmap: TensorMap) -> List[int]: +def compute_color_range(hd5, tmap_name): """Compute the mean values for the color ranges of instances in the MRI series.""" - mri_tensor = tmap.tensor_from_file(tmap, hd5) + mri_tensor = TMAPS[tmap_name].tensor_from_file(TMAPS[tmap_name], hd5) vmin = np.mean([np.min(mri_tensor[:, :, i]) for i in range(0, mri_tensor.shape[2])]) vmax = np.mean([np.max(mri_tensor[:, :, i]) for i in range(0, mri_tensor.shape[2])]) - return [vmin, vmax] + return[vmin, vmax] -def compute_instance_range(tmap: TensorMap) -> Tuple[int, int]: +def compute_instance_range(tmap_name): """Compute middle and max instances.""" - middle_instance = int(tmap.shape[2] / 2) - max_instance = tmap.shape[2] - return (middle_instance, max_instance) + middle_instance = int(TMAPS[tmap_name].shape[2] / 2) + max_instance = TMAPS[tmap_name].shape[2] + return(middle_instance, max_instance) -def plot_mri_tmap( 
- sample_id: Union[int, str], tmap_cache: TensorMapCache, tmap: TensorMap, plot_type: PlotType, - instance: int, color_range: Tuple[int, int], transpose: bool, flip: bool, fig_width: int, -) -> None: +def plot_mri_tmap(sample_id, tmap_cache, tmap_name, plot_type, instance, color_range, transpose, flip, fig_width): """Visualize the applicable MRI series within this HD5 file. Args: sample_id: The local or Cloud Storage path to the MRI file. tmap_cache: The cache from which to retrieve the tensor to be plotted. - tmap: The chosen TensorMap for the MRI series. + tmap_name: The name of the chosen TMAP for the MRI series. plot_type: Whether to display instances interactively or in a panel view. instance: The particular instance to display, if interactive. color_range: Array of minimum and maximum value for the color range. @@ -242,9 +243,12 @@ def plot_mri_tmap( flip: Whether to flip the image on its vertical axis fig_width: The desired width of the figure. Note that height computed as the proportion of the width based on the data to be plotted. + + Returns: + The plot or a notebook-friendly error message. """ - title_prefix = f'{tmap.name} from MRI {sample_id}' - mri_tensor = tmap_cache.get(tmap) + title_prefix = f'{tmap_name} from MRI {sample_id}' + mri_tensor = tmap_cache.get(tmap_name) if plot_type == PlotType.INTERACTIVE: plot_mri_tensor_as_animation( mri_tensor=mri_tensor, @@ -271,13 +275,10 @@ def plot_mri_tmap( title_prefix=title_prefix, ) else: - HTML(f'''
Invalid plot type: {plot_type}
''') + return HTML(f'''
Invalid plot type: {plot_type}
''') -def plot_mri_tensor_as_panels( - mri_tensor: np.array, vmin: int, vmax: int, transpose: bool = False, flip: bool = False, - fig_width: int = DEFAULT_IMAGE_WIDTH, title_prefix: str = '', -) -> None: +def plot_mri_tensor_as_panels(mri_tensor, vmin, vmax, transpose=False, flip=False, fig_width=DEFAULT_IMAGE_WIDTH, title_prefix=''): """Visualize an MRI series from a 3D tensor as a panel of static plots. Args: @@ -313,7 +314,7 @@ def plot_mri_tensor_as_panels( axes[row, col].set_yticklabels([]) axes[row, col].set_xticklabels([]) fig.suptitle( - f'{title_prefix}\nColor range: {vmin}-{vmax}, Transpose: {transpose}, Flip: {flip}, Figure size:{fig_width}x{fig_height}', # pylint: disable=line-too-long + f'{title_prefix}\nColor range: {vmin}-{vmax}, Transpose: {transpose}, Flip: {flip}, Figure size:{fig_width}x{fig_height}', fontsize=fig_width, ) fig.subplots_adjust( @@ -325,11 +326,7 @@ def plot_mri_tensor_as_panels( ) -def plot_mri_tensor_as_animation( - mri_tensor: np.array, instance: int, vmin: int, vmax: int, - transpose: bool = False, flip: bool = False, - fig_width: int = DEFAULT_IMAGE_WIDTH, title_prefix: str = '', -) -> None: +def plot_mri_tensor_as_animation(mri_tensor, instance, vmin, vmax, transpose=False, flip=False, fig_width=DEFAULT_IMAGE_WIDTH, title_prefix=''): """Visualize an MRI series from a 3D tensor as an animation rendered one panel at a time. 
Args: @@ -361,7 +358,7 @@ def plot_mri_tensor_as_animation( _, ax = plt.subplots(figsize=(fig_width, fig_height), facecolor='beige') ax.imshow(pixels, cmap='gray', vmin=vmin, vmax=vmax) ax.set_title( - f'{title_prefix}, Instance: {instance}\nColor range: {vmin}-{vmax}, Transpose: {transpose}, Flip: {flip}, Figure size:{fig_width}x{fig_height}', # pylint: disable=line-too-long + f'{title_prefix}, Instance: {instance}\nColor range: {vmin}-{vmax}, Transpose: {transpose}, Flip: {flip}, Figure size:{fig_width}x{fig_height}', fontsize=fig_width, ) ax.set_yticklabels([]) diff --git a/notebooks/autoencoders/paired_multimodal_autoencoder.ipynb b/notebooks/autoencoders/paired_multimodal_autoencoder.ipynb index 67870c053..bd58f6d25 100644 --- a/notebooks/autoencoders/paired_multimodal_autoencoder.ipynb +++ b/notebooks/autoencoders/paired_multimodal_autoencoder.ipynb @@ -328,9 +328,11 @@ ") -> Model:\n", " inputs = {tm: Input(shape=tm.shape, name=tm.input_name()) for tm in args.tensor_maps_in}\n", " original_outputs = {tm:1 for tm in args.tensor_maps_out}\n", + " real_serial_layers = kwargs['model_layers']\n", + " args.model_layers = None\n", " multimodal_activations = []\n", - " desired_distance_tm = []\n", - " my_metrics = {}\n", + " encoders = {}\n", + " decoders = {}\n", " outputs = []\n", " losses = []\n", " for left, right in pairs:\n", @@ -345,15 +347,22 @@ " h_right = encode_right(inputs[right]) \n", " \n", " if pair_loss == 'cosine':\n", - " loss_layer = CosineLossLayer(100.0)\n", + " loss_layer = CosineLossLayer(1.0)\n", " elif pair_loss == 'euclid':\n", - " loss_layer = L2LossLayer(100.0)\n", + " loss_layer = L2LossLayer(1.0)\n", " \n", " paired_embeddings = loss_layer([h_left, h_right])\n", " multimodal_activations.extend(paired_embeddings)\n", + " if left not in encoders:\n", + " encoders[left] = encode_left\n", + " if right not in encoders:\n", + " encoders[right] = encode_right \n", " \n", " multimodal_activation = Concatenate()(multimodal_activations)\n", + " 
encoder = Model(inputs=list(inputs.values()), outputs=[multimodal_activation], name='encoder')\n", " \n", + " # build decoder models\n", + " latent_inputs = Input(shape=(args.dense_layers[0]*len(inputs)), name='input_concept_space')\n", " pre_decoder_shapes: Dict[TensorMap, Optional[Tuple[int, ...]]] = {}\n", " for tm in args.tensor_maps_out:\n", " shape = _calc_start_shape(num_upsamples=len(args.dense_blocks), output_shape=tm.shape, \n", @@ -380,24 +389,26 @@ " upsample_z=args.pool_z,\n", " )\n", " \n", - " outputs.append(decode(restructure(multimodal_activation)))\n", + " reconstruction = decode(restructure(latent_inputs))\n", + " decoder = Model(latent_inputs, reconstruction, name=tm.output_name())\n", + " decoders[tm] = decoder\n", + " outputs.append(decoder(multimodal_activation))\n", " losses.append(tm.loss)\n", "\n", - " args.tensor_maps_out = list(original_outputs.keys()) + desired_distance_tm\n", + " args.tensor_maps_out = list(original_outputs.keys())\n", " args.tensor_maps_in = list(inputs.keys())\n", " \n", - " opt = Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)\n", - " #outputs.reverse() # Make paired loss last\n", - " #losses.reverse()\n", " m = Model(inputs=list(inputs.values()), outputs=outputs)\n", - " m.compile(optimizer=opt, loss=losses)\n", + " my_metrics = {tm.output_name(): tm.metrics for tm in args.tensor_maps_out}\n", + " opt = Adam(lr=kwargs['learning_rate'], beta_1=0.9, beta_2=0.999, epsilon=1e-08)\n", + " m.compile(optimizer=opt, loss=losses, metrics=my_metrics)\n", " m.summary()\n", " \n", - " if kwargs['model_layers'] is not None:\n", + " if real_serial_layers is not None:\n", " m.load_weights(kwargs['model_layers'], by_name=True)\n", " print(f\"Loaded model weights from:{kwargs['model_layers']}\")\n", " \n", - " return m" + " return m, encoders, decoders" ] }, { @@ -407,7 +418,7 @@ "outputs": [], "source": [ "sys.argv = ['train', \n", - " '--tensors', '/mnt/disks/segmented-sax-lax/2020-07-07/', \n", + " '--tensors', 
'/mnt/disks/sax-lax-40k-lvm/2020-01-29/', \n", " '--input_tensors', 'lax_2ch_diastole_slice0_3d', 'lax_3ch_diastole_slice0_3d', \n", " '--output_tensors', 'lax_2ch_diastole_slice0_3d', 'lax_3ch_diastole_slice0_3d',\n", " '--activation', 'swish',\n", @@ -417,13 +428,13 @@ " '--conv_z', '3', '3', '3', \n", " '--dense_blocks', '32', '32', '32',\n", " '--block_size', '3',\n", - " '--dense_layers', '512',\n", + " '--dense_layers', '64',\n", " '--pool_x', '2',\n", " '--pool_y', '2',\n", " '--batch_size', '1',\n", " '--patience', '32',\n", - " '--epochs', '248',\n", - " '--learning_rate', '0.001',\n", + " '--epochs', '292',\n", + " '--learning_rate', '0.0001',\n", " '--training_steps', '256',\n", " '--validation_steps', '30',\n", " '--test_steps', '2',\n", @@ -433,12 +444,17 @@ " '--id', 'lax_2ch_3ch_diastole_pair_cosine_loss']\n", "args = parse_args()\n", "pairs = [(args.tensor_maps_in[0], args.tensor_maps_in[1])]\n", - "overparameterized_model = make_paired_autoencoder_model(pairs, pair_loss='cosine', **args.__dict__)\n", + "overparameterized_model, encoders, decoders = make_paired_autoencoder_model(pairs, pair_loss='cosine', **args.__dict__)\n", "generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__)\n", "train_model_from_generators(\n", " overparameterized_model, generate_train, generate_valid, args.training_steps, args.validation_steps, args.batch_size,\n", " args.epochs, args.patience, args.output_folder, args.id, args.inspect_model, args.inspect_show_labels,\n", - ")" + " plot=False, save_last_model=True\n", + ")\n", + "for tm in encoders:\n", + " encoders[tm].save(f'{args.output_folder}{args.id}/encoder_{tm.name}.h5')\n", + "for tm in decoders:\n", + " decoders[tm].save(f'{args.output_folder}{args.id}/decoder_{tm.name}.h5')" ] }, { @@ -502,6 +518,55 @@ "latent_df.info()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "out_path = 
os.path.join(args.output_folder, args.id + '/')\n", + "test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps*5)\n", + "print(list(test_data.keys()))\n", + "\n", + "preds = overparameterized_model.predict(test_data)\n", + "print([p.shape for p in preds])\n", + "print([tm.name for tm in args.tensor_maps_out])\n", + "print(test_paths)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ml4h.plots import _plot_reconstruction\n", + "_plot_reconstruction(args.tensor_maps_out[0], test_data['input_lax_2ch_diastole_slice0_3d_continuous'], preds[0], out_path, test_paths, num_samples=2)\n", + "from ml4h.explorations import predictions_to_pngs\n", + "predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_out, \n", + " test_data, test_labels, test_paths, out_path)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for i, etm in enumerate(encoders):\n", + " embed = encoders[etm].predict(test_data[etm.input_name()])\n", + " double = np.tile(embed, 2)\n", + " print(f'embed shape: {embed.shape} double shape: {double.shape}')\n", + " for dtm in decoders:\n", + " predictions = decoders[dtm].predict(double)\n", + " print(f'prediction shape: {predictions.shape}')\n", + " out_path = os.path.join(args.output_folder, args.id, f'decoding_{dtm.name}_from_{etm.name}/')\n", + " if not os.path.exists(os.path.dirname(out_path)):\n", + " os.makedirs(os.path.dirname(out_path))\n", + " _plot_reconstruction(dtm, test_data[dtm.input_name()], predictions.copy(), out_path, test_paths, 8)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -711,6 +776,18 @@ "display_name": "Python 3", "language": "python", "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + 
"nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" } }, "nbformat": 4, diff --git a/notebooks/autoencoders/paired_multimodal_segmenter_mri_ecg.ipynb b/notebooks/autoencoders/paired_multimodal_segmenter_mri_ecg.ipynb index 442e27a18..3ea3e32a2 100644 --- a/notebooks/autoencoders/paired_multimodal_segmenter_mri_ecg.ipynb +++ b/notebooks/autoencoders/paired_multimodal_segmenter_mri_ecg.ipynb @@ -38,6 +38,7 @@ "# ml4h Imports\n", "from ml4h.TensorMap import TensorMap\n", "from ml4h.arguments import parse_args\n", + "from ml4h.plots import _plot_reconstruction\n", "from ml4h.models import make_multimodal_multitask_model, train_model_from_generators, make_hidden_layer_model, _conv_layer_from_kind_and_dimension\n", "from ml4h.tensor_generators import TensorGenerator, big_batch_from_minibatch_generator, test_train_valid_tensor_generators\n", "from ml4h.recipes import plot_predictions, infer_hidden_layer_multimodal_multitask\n", @@ -271,13 +272,6 @@ " return norm\n", "\n", "def pairwise_cosine_difference(t1, t2):\n", - " \"\"\"\n", - " A [batch x n x d] tensor of n rows with d dimensions\n", - " B [batch x m x d] tensor of n rows with d dimensions\n", - "\n", - " returns:\n", - " D [batch x n x m] tensor of cosine similarity scores between each point i Model:\n", " inputs = {tm: Input(shape=tm.shape, name=tm.input_name()) for tm in args.tensor_maps_in}\n", " original_outputs = {tm:1 for tm in args.tensor_maps_out}\n", + " real_serial_layers = kwargs['model_layers']\n", + " args.model_layers = None\n", " multimodal_activations = []\n", + " encoders = {}\n", + " decoders = {}\n", " outputs = []\n", " losses = []\n", " for left, right in pairs:\n", @@ -348,9 +346,16 @@ " \n", " paired_embeddings = loss_layer([h_left, h_right])\n", " multimodal_activations.extend(paired_embeddings)\n", + " if left not in encoders:\n", + " encoders[left] = encode_left\n", + " if right not in encoders:\n", + " encoders[right] = encode_right \n", " \n", " 
multimodal_activation = Concatenate()(multimodal_activations)\n", + " encoder = Model(inputs=list(inputs.values()), outputs=[multimodal_activation], name='encoder')\n", " \n", + " # build decoder models\n", + " latent_inputs = Input(shape=(args.dense_layers[0]*len(inputs)), name='input_concept_space')\n", " pre_decoder_shapes: Dict[TensorMap, Optional[Tuple[int, ...]]] = {}\n", " for tm in args.tensor_maps_out:\n", " shape = _calc_start_shape(num_upsamples=len(args.dense_blocks), output_shape=tm.shape, \n", @@ -377,7 +382,10 @@ " upsample_z=args.pool_z,\n", " )\n", " \n", - " outputs.append(decode(restructure(multimodal_activation)))\n", + " reconstruction = decode(restructure(latent_inputs))\n", + " decoder = Model(latent_inputs, reconstruction, name=tm.output_name())\n", + " decoders[tm] = decoder\n", + " outputs.append(decoder(multimodal_activation))\n", " losses.append(tm.loss)\n", "\n", " args.tensor_maps_out = list(original_outputs.keys())\n", @@ -389,11 +397,11 @@ " m.compile(optimizer=opt, loss=losses, metrics=my_metrics)\n", " m.summary()\n", " \n", - " if kwargs['model_layers'] is not None:\n", + " if real_serial_layers is not None:\n", " m.load_weights(kwargs['model_layers'], by_name=True)\n", " print(f\"Loaded model weights from:{kwargs['model_layers']}\")\n", " \n", - " return m" + " return m, encoders, decoders" ] }, { @@ -406,35 +414,90 @@ " '--tensors', '/mnt/disks/sax-lax-40k-lvm/2020-01-29/', \n", " '--input_tensors', 'ecg.ecg_rest', 'mri.cine_segmented_lax_4ch_diastole', \n", " '--output_tensors', 'ecg.ecg_rest', 'mri.cine_segmented_lax_4ch_diastole',\n", - " '--activation', 'swish',\n", + " '--activation', 'selu',\n", " '--conv_layers', '32',\n", - " '--conv_x', '9', '9', '9',\n", + " '--conv_x', '15', '15', '15',\n", " '--conv_y', '3', '3', '3', \n", " '--conv_z', '3', '3', '3',\n", " '--dense_blocks', '32', '32', '32',\n", " '--block_size', '3',\n", - " '--dense_layers', '512',\n", + " '--dense_layers', '64',\n", " '--pool_x', '2',\n", " 
'--pool_y', '2',\n", " '--batch_size', '1',\n", - " '--patience', '44',\n", - " '--epochs', '496',\n", - " '--learning_rate', '0.0001',\n", + " '--patience', '94',\n", + " '--epochs', '396',\n", + " '--learning_rate', '0.00005',\n", " '--training_steps', '128',\n", " '--validation_steps', '30',\n", " '--test_steps', '8',\n", " '--num_workers', '4',\n", " '--inspect_model',\n", " '--tensormap_prefix', 'ml4h.tensormap.ukb',\n", - " '--id', 'ecg_mri_lax_4ch_diastole_euclid_paired_segmenter_512d']\n", + " '--model_layers', './recipes_output/paired_ecg_segmented_mri_lax_4ch_diastole_cosine_64d_selu/paired_ecg_segmented_mri_lax_4ch_diastole_cosine_64d_selu.h5',\n", + " '--id', 'paired_ecg_segmented_mri_lax_4ch_diastole_cosine_64d_selu']\n", "args = parse_args()\n", "pairs = [(args.tensor_maps_in[0], args.tensor_maps_in[1])]\n", - "overparameterized_model = make_paired_autoencoder_model(pairs, pair_loss='euclid', **args.__dict__)\n", + "overparameterized_model, encoders, decoders = make_paired_autoencoder_model(pairs, pair_loss='cosine', **args.__dict__)\n", "generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__)\n", "train_model_from_generators(\n", " overparameterized_model, generate_train, generate_valid, args.training_steps, args.validation_steps, args.batch_size,\n", " args.epochs, args.patience, args.output_folder, args.id, args.inspect_model, args.inspect_show_labels,\n", - ")" + " plot=False, save_last_model=True\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for tm in encoders:\n", + " encoders[tm].save(f'{args.output_folder}{args.id}/encoder_{tm.name}.h5')\n", + "for tm in decoders:\n", + " decoders[tm].save(f'{args.output_folder}{args.id}/decoder_{tm.name}.h5')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sys.argv = ['train', \n", + " '--tensors', 
'/mnt/disks/sax-lax-40k-lvm/2020-01-29/', \n", + " '--input_tensors', 'ecg.ecg_rest', 'mri.cine_segmented_lax_4ch_diastole', \n", + " '--output_tensors', 'ecg.ecg_rest', 'mri.cine_segmented_lax_4ch_diastole',\n", + " '--activation', 'swish',\n", + " '--conv_layers', '32',\n", + " '--conv_x', '9', '9', '9',\n", + " '--conv_y', '3', '3', '3', \n", + " '--conv_z', '3', '3', '3', \n", + " '--dense_blocks', '32', '32', '32',\n", + " '--block_size', '3',\n", + " '--dense_layers', '256',\n", + " '--pool_x', '2',\n", + " '--pool_y', '2',\n", + " '--batch_size', '1',\n", + " '--patience', '44',\n", + " '--epochs', '532',\n", + " '--learning_rate', '0.0002',\n", + " '--training_steps', '72',\n", + " '--validation_steps', '30',\n", + " '--test_steps', '8',\n", + " '--num_workers', '4',\n", + " '--inspect_model',\n", + " '--tensormap_prefix', 'ml4h.tensormap.ukb',\n", + " '--hidden_layer', 'concatenate_36',\n", + " '--model_file', './recipes_output/paired_ecg_segmented_mri_lax_4ch_diastole_cosine_64d_selu/paired_ecg_segmented_mri_lax_4ch_diastole_cosine_64d_selu.h5',\n", + " #'--sample_csv', '/home/sam/lvh/lvh_hold_out.txt',\n", + " '--id', 'paired_ecg_segmented_mri_lax_4ch_diastole_cosine_64d_selu']\n", + "args = parse_args()\n", + "#overparameterized_model = make_multimodal_multitask_model(**args.__dict__)\n", + "generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__)\n", + "#plot_predictions(args)\n", + "#infer_hidden_layer_multimodal_multitask(args)" ] }, { @@ -480,7 +543,17 @@ "metadata": {}, "outputs": [], "source": [ - "print(list(test_data['input_strip_continuous'].shape))" + "for i, etm in enumerate(encoders):\n", + " embed = encoders[etm].predict(test_data[etm.input_name()])\n", + " double = np.tile(embed, 2)\n", + " print(f'embed shape: {embed.shape} double shape: {double.shape}')\n", + " for dtm in decoders:\n", + " predictions = decoders[dtm].predict(double)\n", + " print(f'prediction shape: {predictions.shape}')\n", + " 
out_path = os.path.join(args.output_folder, args.id, f'decoding_{dtm.name}_from_{etm.name}/')\n", + " if not os.path.exists(os.path.dirname(out_path)):\n", + " os.makedirs(os.path.dirname(out_path))\n", + " _plot_reconstruction(dtm, test_data[dtm.input_name()], predictions.copy(), out_path, test_paths, 8)" ] }, { @@ -515,7 +588,7 @@ " if i % sample_every == 0 and col < samples:\n", " for j in range(rows):\n", " if len(test_data[test_key].shape) == 4:\n", - " axes[j, col].imshow(test_data[test_key][j, :, :, 0], cmap = 'gray')\n", + " axes[j, col].imshow(test_data[test_key][j, :, :, 0], cmap = 'plasma')\n", " axes[j, col].set_yticks(())\n", " elif len(test_data[test_key].shape) == 3:\n", " for l in range(12):\n", @@ -541,14 +614,14 @@ "\n", "\n", "test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)\n", - "test_key = 'input_lax_4ch_diastole_slice0_224_3d_continuous'\n", + "test_key = 'input_cine_segmented_lax_4ch_diastole_categorical'\n", "test_shape = test_data[test_key].shape\n", "test_data[test_key] = np.random.random(test_shape)\n", "out_path = os.path.join(args.output_folder, args.id, test_key + '_noise/')\n", "if not os.path.exists(os.path.dirname(out_path)):\n", " os.makedirs(os.path.dirname(out_path))\n", "noise_preds = plot_ae_towards_attractor(overparameterized_model, test_data, test_labels, test_key, \n", - " test_index=1, rows=8, samples=4, steps = 18)\n", + " test_index=1, rows=8, samples=4, steps = 28)\n", "print(list(test_data.keys()))\n", "_plot_reconstruction(args.tensor_maps_out[0], test_data['input_strip_continuous'], \n", " noise_preds[0], out_path, test_paths)\n", @@ -592,7 +665,10 @@ "df = pd.read_csv('/home/sam/ml/trained_models/lax_4ch_diastole_autoencode_leaky_converge/tensors_all_union.csv')\n", "df['21003_Age-when-attended-assessment-centre_2_0'].plot.hist(bins=30)\n", "hidden_inference = 
'./recipes_output/ecg_mri_lax_4ch_diastole_paired_autoencoder_2blocks_256d_200_samples/hidden_inference_ecg_mri_lax_4ch_diastole_paired_autoencoder_2blocks_256d_200_samples.tsv'\n", + "hidden_inference = './recipes_output/paired_ecg_segmented_mri_lax_4ch_diastole_euclid_256d_swish/hidden_inference_paired_ecg_segmented_mri_lax_4ch_diastole_euclid_256d_swish.tsv'\n", + "hidden_inference = './recipes_output/paired_ecg_segmented_mri_lax_4ch_diastole_cosine_256d_selu/hidden_inference_paired_ecg_segmented_mri_lax_4ch_diastole_cosine_256d_selu.tsv'\n", "\n", + "hidden_inference = './recipes_output/paired_ecg_segmented_mri_lax_4ch_diastole_cosine_64d_selu/hidden_inference_paired_ecg_segmented_mri_lax_4ch_diastole_cosine_64d_selu.tsv'\n", "\n", "df2 = pd.read_csv(hidden_inference, sep='\\t')\n", "df['fpath'] = pd.to_numeric(df['fpath'], errors='coerce')\n", @@ -714,7 +790,7 @@ "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", - "latent_dimension = 256\n", + "latent_dimension = 128\n", "latent_cols = [f'latent_{i}' for i in range(latent_dimension)]\n", "pca, matrix_reduce = pca_on_matrix(df2[latent_cols].to_numpy(), 10)\n", "for strat in ['Sex_Female_0_0', 'has_ttntv', 'atrial_fibrillation_or_flutter', \n", @@ -742,7 +818,7 @@ " stratify_latent_space(strat, 1.0, latent_cols, latent_df)\n", "strats = ['LVEF', 'LVM', 'LVEDV', 'sample_id',\n", " '21001_Body-mass-index-BMI_0_0', '21003_Age-when-attended-assessment-centre_2_0']\n", - "theshes = [45, 100, 150, 3500000, 27.5, 70]\n", + "theshes = [45, 100, 150, 3500000, 27.5, 65]\n", "for strat, thresh in zip(strats, theshes):\n", " stratify_latent_space(strat, thresh, latent_cols, latent_df)" ] @@ -756,12 +832,12 @@ "latent_dimension = 512\n", "latent_cols = [f'latent_{i}' for i in range(latent_dimension)]\n", "pca, matrix_reduce = pca_on_matrix(df2[latent_cols].to_numpy(), 10)\n", - "for strat in ['Sex_Female_0_0', 'atrial_fibrillation_or_flutter', \n", + "for strat in ['Sex_Female_0_0', 'has_ttntv', 
'atrial_fibrillation_or_flutter', \n", " 'coronary_artery_disease', 'hypertension']:\n", " stratify_latent_space(strat, 1.0, latent_cols, latent_df)\n", "strats = ['LVEF', 'LVM', 'LVEDV', 'sample_id',\n", " '21001_Body-mass-index-BMI_0_0', '21003_Age-when-attended-assessment-centre_2_0']\n", - "theshes = [45, 100, 150, 3500000, 27.5, 70]\n", + "theshes = [45, 100, 150, 3500000, 27.5, 65]\n", "for strat, thresh in zip(strats, theshes):\n", " stratify_latent_space(strat, thresh, latent_cols, latent_df)" ] @@ -775,11 +851,11 @@ "latent_dimension = 256\n", "latent_cols = [f'latent_{256+i}' for i in range(latent_dimension)]\n", "pca, matrix_reduce = pca_on_matrix(df2[latent_cols].to_numpy(), 10)\n", - "c_strats = [ 'Sex_Female_0_0']\n", + "c_strats = [ 'Sex_Female_0_0', 'has_ttntv']\n", "for c_strat in c_strats:\n", " strats = ['LVEF', 'LVM', 'LVEDV', 'sample_id',\n", " '21001_Body-mass-index-BMI_0_0', '21003_Age-when-attended-assessment-centre_2_0']\n", - " theshes = [50, 100, 150, 3750000, 27.5, 70]\n", + " theshes = [50, 100, 150, 3750000, 27.5, 65]\n", " for strat, thresh in zip(strats, theshes):\n", " directions_in_latent_space(c_strat, 1.0, strat, thresh, latent_cols, latent_df)" ] @@ -824,7 +900,8 @@ "metadata": {}, "outputs": [], "source": [ - "print(f'{ecg_encode[:5,:5]} \\n{mri_encode[:5,:5]}')" + "print(f\"{np.mean(np.sqrt(np.einsum('ij, ij->ij', ecg_encode, ecg_encode)))}\")\n", + "print(f\"{np.mean(np.sqrt(np.einsum('ij, ij->ij', mri_encode, mri_encode)))}\")" ] }, { @@ -836,7 +913,7 @@ "latent_dimension = 256\n", "latent_cols = [f'latent_{i}' for i in range(latent_dimension)]\n", "ecg_encode = latent_df[latent_cols].to_numpy()\n", - "latent_cols = [f'latent_{18+i}' for i in range(latent_dimension)]\n", + "latent_cols = [f'latent_{250+i}' for i in range(latent_dimension)]\n", "mri_encode = latent_df[latent_cols].to_numpy()\n", "diff = np.sqrt(np.einsum('ij, ij->ij', ecg_encode - mri_encode, ecg_encode - mri_encode))\n", "print(diff.shape) \n", @@ -849,8 
+926,8 @@ "metadata": {}, "outputs": [], "source": [ - "ch2_random = np.random.random((4452, 256))\n", - "ch3_random = np.random.random((4452, 256))\n", + "ch2_random = np.random.random((8520, 256))\n", + "ch3_random = np.random.random((8520, 256))\n", "diff = np.sqrt(np.einsum('ij, ij->ij', ch2_random - ch3_random, ch2_random - ch3_random))\n", "print(diff.shape) \n", "print(np.mean(diff))" @@ -862,44 +939,35 @@ "metadata": {}, "outputs": [], "source": [ - "sys.argv = ['train', \n", - " '--tensors', '/mnt/disks/sax-lax-40k-lvm/2020-01-29/', \n", - " '--input_tensors', 'ecg.ecg_rest', 'mri.lax_4ch_diastole_slice0_224_3d', \n", - " '--output_tensors', 'ecg.ecg_rest', 'mri.lax_4ch_diastole_slice0_224_3d',\n", - " '--activation', 'swish',\n", - " '--conv_layers', '32',\n", - " '--conv_x', '9', '9', '9',\n", - " '--conv_y', '3', '3', '3', \n", - " '--conv_z', '3', '3', '3', \n", - " '--dense_blocks', '32', '32', '32',\n", - " '--block_size', '3',\n", - " '--dense_layers', '256',\n", - " '--pool_x', '2',\n", - " '--pool_y', '2',\n", - " '--batch_size', '1',\n", - " '--patience', '44',\n", - " '--epochs', '532',\n", - " '--learning_rate', '0.0002',\n", - " '--training_steps', '72',\n", - " '--validation_steps', '30',\n", - " '--test_steps', '8',\n", - " '--num_workers', '4',\n", - " '--inspect_model',\n", - " '--tensormap_prefix', 'ml4h.tensormap.ukb',\n", - " '--hidden_layer', 'concatenate_36',\n", - " '--model_file', './recipes_output/ecg_mri_lax_4ch_diastole_paired_autoencoder_2blocks_256d_200_samples/ecg_mri_lax_4ch_diastole_paired_autoencoder_2blocks_256d_200_samples.h5',\n", - " '--train_csv', '/home/sam/lvh/small_set.csv',\n", - " #'--sample_csv', '/home/sam/lvh/lvh_hold_out.txt',\n", - " '--id', 'ecg_mri_lax_4ch_diastole_paired_autoencoder_2blocks_256d_200_samples']\n", - "args = parse_args()\n", - "\n", - "#overparameterized_model = make_multimodal_multitask_model(**args.__dict__)\n", - "#infer_hidden_layer_multimodal_multitask(args)\n", - "#generate_train, 
generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__)\n", - "# train_model_from_generators(\n", - "# overparameterized_model, generate_train, generate_valid, args.training_steps, args.validation_steps, args.batch_size,\n", - "# args.epochs, args.patience, args.output_folder, args.id, args.inspect_model, args.inspect_show_labels,\n", - "# )" + "latent_dimension = 128\n", + "latent_cols = [f'latent_{i}' for i in range(latent_dimension)]\n", + "all_encode = latent_df[latent_cols].to_numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ml4h.plots import _plot_reconstruction\n", + "for tm in decoders:\n", + " predictions = decoders[tm].predict(all_encode[:4])\n", + " print(predictions.shape)\n", + " out_path = os.path.join(args.output_folder, args.id, 'decodings/')\n", + " if not os.path.exists(os.path.dirname(out_path)):\n", + " os.makedirs(os.path.dirname(out_path))\n", + " samples = list(map(str, list(latent_df['sample_id'])[:10]))\n", + " _plot_reconstruction(tm, predictions, predictions, out_path, samples)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print()" ] }, { diff --git a/notebooks/autoencoders/vae_mri_slice.ipynb b/notebooks/autoencoders/vae_mri_slice.ipynb index 21785359e..0aa84e195 100644 --- a/notebooks/autoencoders/vae_mri_slice.ipynb +++ b/notebooks/autoencoders/vae_mri_slice.ipynb @@ -286,7 +286,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.8" + "version": "3.6.9" } }, "nbformat": 4, diff --git a/notebooks/mnist_demo.ipynb b/notebooks/mnist_demo.ipynb index 88393f90f..b4fc5bcbf 100644 --- a/notebooks/mnist_demo.ipynb +++ b/notebooks/mnist_demo.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -34,7 +34,7 @@ "from 
tensorflow.keras.layers import Dense, Conv2D, Flatten\n", "\n", "from ml4h.defines import StorageType\n", - "from ml4h.arguments import parse_args, TMAPS, _get_tmap\n", + "from ml4h.arguments import parse_args\n", "from ml4h.TensorMap import TensorMap, Interpretation\n", "from ml4h.tensor_generators import test_train_valid_tensor_generators\n", "from ml4h.models import train_model_from_generators, make_multimodal_multitask_model, _inspect_model\n", @@ -47,7 +47,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -58,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -95,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -113,9 +113,28 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loading data...\n", + "(50000, 784)\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "plot_mnist(4)" ] @@ -139,7 +158,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -158,9 +177,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loading data...\n", + "Wrote 5000 MNIST images and labels as HD5 files\n", + "Wrote 10000 MNIST images and labels as HD5 files\n", + "Wrote 15000 MNIST images and labels as HD5 files\n", + "Wrote 20000 MNIST images and labels as HD5 files\n", + "Wrote 25000 MNIST images and labels as HD5 files\n", + "Wrote 30000 MNIST images and labels as HD5 files\n", + "Wrote 35000 MNIST images and labels as HD5 files\n", + "Wrote 40000 MNIST images and labels as HD5 files\n", + "Wrote 45000 MNIST images and labels as HD5 files\n", + "Wrote 50000 MNIST images and labels as HD5 files\n" + ] + } + ], "source": [ "mnist_as_hd5(HD5_FOLDER)" ] @@ -210,12 +247,41 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2020-09-18 16:17:13,404 - logger:25 - INFO - Logging configuration was loaded. 
Log messages can be found at ./runs/learn_mnist/log_2020-09-18_16-17_0.log.\n", + "2020-09-18 16:17:13,410 - arguments:444 - INFO - Command Line was: \n", + "./scripts/tf.sh train --tensors ./mnist_hd5s/ --tensormap_prefix ml4h.tensormap.mnist --input_tensors mnist_image --output_tensors mnist_label --batch_size 64 --test_steps 64 --epochs 24 --output_folder ./runs/ --id learn_mnist\n", + "\n", + "2020-09-18 16:17:13,410 - arguments:445 - INFO - Arguments are Namespace(activation='relu', aligned_dimension=16, alpha=0.5, anneal_max=2.0, anneal_rate=0.0, anneal_shift=0.0, app_csv=None, b_slice_force=None, balance_csvs=[], batch_size=64, bigquery_credentials_file='/mnt/ml4cvd/projects/jamesp/bigquery/bigquery-viewer-credentials.json', bigquery_dataset='broad-ml4cvd.ukbb7089_r10data', block_size=3, bottleneck_type=, cache_size=437500000.0, categorical_field_ids=[], continuous_field_ids=[], continuous_file=None, continuous_file_column=None, continuous_file_discretization_bounds=[], continuous_file_normalize=False, conv_dilate=False, conv_layers=[32], conv_normalize=None, conv_regularize=None, conv_regularize_rate=0.0, conv_type='conv', conv_x=[3], conv_y=[3], conv_z=[2], debug=False, dense_blocks=[32, 24, 16], dense_layers=[16, 64], dense_normalize=None, dense_regularize=None, dense_regularize_rate=0.0, dicom_series='cine_segmented_sax_b6', dicoms='./dicoms/', eager=False, embed_visualization=None, epochs=24, explore_export_errors=False, freeze_model_layers=False, hidden_layer='embed', id='learn_mnist', imputation_method_for_continuous_fields='random', include_array=False, include_instance=False, include_missing_continuous_channel=False, input_tensors=['mnist_image'], inspect_model=False, inspect_show_labels=True, join_tensors=['partners_ecg_patientid_clean'], label_weights=None, language_layer='ecg_rest_text', language_prefix='ukb_ecg_rest', learning_rate=0.0002, learning_rate_schedule=None, logging_level='INFO', match_any_window=False, max_models=16, 
max_parameters=9000000, max_patients=999999, max_pools=[], max_sample_id=7000000, max_samples=None, max_slices=999999, min_sample_id=0, min_samples=3, min_values=10, mixup_alpha=0, mlp_concat=False, mode='mlp', model_file=None, model_files=[], model_layers=None, mri_field_ids=['20208', '20209'], num_workers=8, number_per_window=1, optimizer='radam', order_in_window=None, output_folder='./runs/', output_tensors=['mnist_label'], padding='same', patience=8, phecode_definitions='/mnt/ml4cvd/projects/jamesp/data/phecode_definitions1.2.csv', phenos_folder='gs://ml4cvd/phenotypes/', plot_hist=True, plot_mode='clinical', pool_type='max', pool_x=2, pool_y=2, pool_z=1, protected_tensors=[], random_seed=12878, reference_end_time_tensor=None, reference_join_tensors=None, reference_labels=None, reference_name='Reference', reference_start_time_tensor=None, reference_tensors=None, sample_csv=None, sample_weight=None, save_last_model=False, t=48, tensor_maps_in=[TensorMap(mnist_image, (28, 28, 1), continuous)], tensor_maps_out=[TensorMap(mnist_label, (10,), categorical)], tensor_maps_protected=[], tensormap_prefix='ml4h.tensormap.mnist', tensors='./mnist_hd5s/', tensors_name='Tensors', tensors_source=None, test_csv=None, test_ratio=0.1, test_steps=64, text_file=None, text_one_hot=False, text_window=32, time_frequency='3M', time_tensor='partners_ecg_datetime', train_csv=None, training_steps=72, tsv_style='standard', u_connect=defaultdict(, {}), valid_csv=None, valid_ratio=0.2, validation_steps=18, window_name=None, write_pngs=False, x=256, xml_field_ids=['20205', '6025'], xml_folder='/mnt/disks/ecg-rest-xml/', y=256, z=48, zip_folder='/mnt/disks/sax-mri-zip/', zoom_height=96, zoom_width=96, zoom_x=50, zoom_y=35)\n", + "\n", + "2020-09-18 16:17:13,411 - tensor_generators:661 - INFO - Found 0 train, 0 validation, and 0 testing tensors at: ./mnist_hd5s/\n" + ] + }, + { + "ename": "ValueError", + "evalue": "Not enough tensors at ./mnist_hd5s/\nFound 0 training, 0 validation, and 0 
testing tensors\nDiscarded 0 tensors", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 12\u001b[0m ]\n\u001b[1;32m 13\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparse_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mtrain_multimodal_multitask\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/home/sam/ml/ml4h/recipes.py\u001b[0m in \u001b[0;36mtrain_multimodal_multitask\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtrain_multimodal_multitask\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m \u001b[0mgenerate_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgenerate_valid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgenerate_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtest_train_valid_tensor_generators\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__dict__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmake_multimodal_multitask_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__dict__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m model = train_model_from_generators(\n", + 
"\u001b[0;32m/home/sam/ml/ml4h/tensor_generators.py\u001b[0m in \u001b[0;36mtest_train_valid_tensor_generators\u001b[0;34m(tensor_maps_in, tensor_maps_out, tensor_maps_protected, tensors, batch_size, num_workers, training_steps, validation_steps, cache_size, balance_csvs, keep_paths, keep_paths_test, mixup_alpha, sample_csv, valid_ratio, test_ratio, train_csv, valid_csv, test_csv, siamese, sample_weight, **kwargs)\u001b[0m\n\u001b[1;32m 796\u001b[0m \u001b[0mtrain_csv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrain_csv\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 797\u001b[0m \u001b[0mvalid_csv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mvalid_csv\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 798\u001b[0;31m \u001b[0mtest_csv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtest_csv\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 799\u001b[0m )\n\u001b[1;32m 800\u001b[0m \u001b[0mweights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/home/sam/ml/ml4h/tensor_generators.py\u001b[0m in \u001b[0;36mget_train_valid_test_paths\u001b[0;34m(tensors, sample_csv, valid_ratio, test_ratio, train_csv, valid_csv, test_csv)\u001b[0m\n\u001b[1;32m 663\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_paths\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalid_paths\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_paths\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 664\u001b[0m raise 
ValueError(\n\u001b[0;32m--> 665\u001b[0;31m \u001b[0;34mf'Not enough tensors at {tensors}\\n'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 666\u001b[0m \u001b[0;34mf'Found {len(train_paths)} training, {len(valid_paths)} validation, and {len(test_paths)} testing tensors\\n'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 667\u001b[0m \u001b[0;34mf'Discarded {len(discard_paths)} tensors'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: Not enough tensors at ./mnist_hd5s/\nFound 0 training, 0 validation, and 0 testing tensors\nDiscarded 0 tensors" + ] + } + ], "source": [ "sys.argv = ['train', \n", " '--tensors', HD5_FOLDER, \n", + " '--tensormap_prefix', 'ml4h.tensormap.mnist',\n", " '--input_tensors', 'mnist_image',\n", " '--output_tensors', 'mnist_label',\n", " '--batch_size', '64',\n", diff --git a/notebooks/mri/mri_cardiac_long_axis_sketch.ipynb b/notebooks/mri/mri_cardiac_long_axis_sketch.ipynb index f9f99b4ad..26c08c73a 100644 --- a/notebooks/mri/mri_cardiac_long_axis_sketch.ipynb +++ b/notebooks/mri/mri_cardiac_long_axis_sketch.ipynb @@ -67,8 +67,8 @@ "outputs": [], "source": [ "def plot_lax(series, transpose=False, size=18):\n", - " cols = 5\n", - " rows = 10\n", + " cols = 2\n", + " rows = 25\n", " _, axes = plt.subplots(rows, cols, figsize=(size, size))\n", " for dcm in series:\n", " col = (dcm.InstanceNumber-1)%cols\n", @@ -76,7 +76,7 @@ " if transpose:\n", " axes[row, col].imshow(dcm.pixel_array.T)\n", " else:\n", - " axes[row, col].imshow(dcm.pixel_array)\n", + " axes[row, col].imshow(dcm.pixel_array, cmap='gray')\n", " axes[row, col].set_yticklabels([])\n", " axes[row, col].set_xticklabels([])" ] @@ -132,7 +132,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.8" + "version": "3.6.9" } }, "nbformat": 4, diff --git a/notebooks/mri/mri_cardiac_short_axis_sketch.ipynb 
b/notebooks/mri/mri_cardiac_short_axis_sketch.ipynb index 20b93d12f..64493df15 100644 --- a/notebooks/mri/mri_cardiac_short_axis_sketch.ipynb +++ b/notebooks/mri/mri_cardiac_short_axis_sketch.ipynb @@ -32,8 +32,9 @@ "source": [ "!mkdir ./dcm_scratch\n", "!rm ./dcm_scratch/*\n", - "!cp /mnt/ml4cvd/projects/bulk/cardiac_mri/1000387_20209_2_0.zip ./dcm_scratch/\n", - "!unzip ./dcm_scratch/1000387_20209_2_0.zip -d ./dcm_scratch/" + "\n", + "!cp /mnt/ml4cvd/projects/bulk/cardiac_mri/2467677_20209_2_0.zip ./dcm_scratch/\n", + "!unzip ./dcm_scratch/2467677_20209_2_0.zip -d ./dcm_scratch/" ] }, { @@ -154,11 +155,12 @@ " if idx >= sides*sides:\n", " continue\n", " if _is_mitral_valve_segmentation(dcm):\n", - " axes[idx%sides, idx//sides].imshow(dcm.pixel_array)\n", + " axes[idx%sides, idx//sides].imshow(dcm.pixel_array, cmap='gray')\n", " else:\n", " try:\n", " overlay, anatomical_mask, ventricle_pixels = _get_overlay_from_dicom(dcm)\n", - " axes[idx%sides, idx//sides].imshow(np.ma.masked_where(anatomical_mask == 2, dcm.pixel_array))\n", + " #axes[idx%sides, idx//sides].imshow(np.ma.masked_where(anatomical_mask == 2, dcm.pixel_array), cmap='gray')\n", + " axes[idx%sides, idx//sides].imshow(dcm.pixel_array, cmap='gray')\n", " except KeyError:\n", " print(f'Could not get overlay at {dcm.InstanceNumber}, angle {s}')\n", " axes[idx, idx//sides].imshow(dcm.pixel_array)\n", @@ -172,7 +174,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot_b_series(series[2], sides=7)" + "plot_b_series(series[4], sides=2)" ] }, { @@ -261,7 +263,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.8" + "version": "3.6.9" } }, "nbformat": 4, diff --git a/notebooks/review_results/identify_a_sample_to_review.ipynb b/notebooks/review_results/identify_a_sample_to_review.ipynb index f26d76309..a5c659d24 100644 --- a/notebooks/review_results/identify_a_sample_to_review.ipynb +++ b/notebooks/review_results/identify_a_sample_to_review.ipynb @@ -16,7 
+16,7 @@ "
\n", " This notebook assumes:\n", "
    \n", - "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
  • \n", + "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200729_091732.
  • \n", "
  • ml4h is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
  • \n", "
\n", "
" @@ -32,7 +32,7 @@ "from ml4h.runtime_data_defines import determine_runtime\n", "from ml4h.runtime_data_defines import Runtime\n", "\n", - "if Runtime.ML4H_VM == determine_runtime():\n", + "if Runtime.ml4h_VM == determine_runtime():\n", " !pip3 install --user --upgrade pandas_gbq pyarrow\n", " # Be sure to restart the kernel if pip installs anything." ] @@ -259,7 +259,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.8" + "version": "3.7.7" }, "toc": { "base_numbering": 1, diff --git a/notebooks/review_results/image_annotations.ipynb b/notebooks/review_results/image_annotations.ipynb deleted file mode 100644 index 6e644f15a..000000000 --- a/notebooks/review_results/image_annotations.ipynb +++ /dev/null @@ -1,268 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Image annotations for a batch of samples\n", - "\n", - "Using this notebook, cardiologists are able to quickly view and annotate MRI images for a batch of samples. These annotated images become the training data for the next round of modeling." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Setup\n", - "\n", - "
\n", - " This notebook assumes\n", - "
    \n", - "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
  • \n", - "
  • ml4h is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
  • \n", - "
\n", - "
" - ] - }, - { - "attachments": { - "Screen%20Shot%202020-06-22%20at%202.50.48%20PM.png": { - "image/png": "" - } - }, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "![Screen%20Shot%202020-06-22%20at%202.50.48%20PM.png](attachment:Screen%20Shot%202020-06-22%20at%202.50.48%20PM.png)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# TODO(deflaux): remove this cell after gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu has this preinstalled.\n", - "from ml4h.runtime_data_defines import determine_runtime\n", - "from ml4h.runtime_data_defines import Runtime\n", - "\n", - "if Runtime.ML4H_VM == determine_runtime():\n", - " !pip3 install --user ipycanvas==0.4.1 ipyannotations==0.2.0\n", - " !jupyter nbextension install --user --py ipycanvas\n", - " !jupyter nbextension enable --user --py ipycanvas\n", - " # Be sure to restart the kernel if pip installs anything.\n", - " # Also, shift-reload the browser page after the notebook extension installation." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from ml4h.visualization_tools.annotation_storage import BigQueryAnnotationStorage\n", - "from ml4h.visualization_tools.batch_image_annotations import BatchImageAnnotator\n", - "import pandas as pd\n", - "import tensorflow as tf" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "code_folding": [] - }, - "outputs": [], - "source": [ - "%%javascript\n", - "// Display cell outputs to full height (no vertical scroll bar)\n", - "IPython.OutputArea.auto_scroll_threshold = 9999;" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pd.set_option('display.max_colwidth', -1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "BIG_QUERY_ANNOTATIONS_STORAGE = BigQueryAnnotationStorage('uk-biobank-sek-data.ml_results.annotations')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Define the batch of samples to annotate\n", - "\n", - "
\n", - " Edit the CSV file path below, if needed, to either a local file or one in Cloud Storage.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#---[ EDIT AND RUN THIS CELL TO READ FROM A LOCAL FILE OR A FILE IN CLOUD STORAGE ]---\n", - "SAMPLE_BATCH_FILE = None" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "if SAMPLE_BATCH_FILE:\n", - " samples_df = pd.read_csv(tf.io.gfile.GFile(SAMPLE_BATCH_FILE))\n", - "\n", - "else:\n", - " # Normally these would all be the same or similar TMAP. We are using different ones here just to make it\n", - " # more obvious in this demo that we are processing different samples.\n", - " samples_df = pd.DataFrame(\n", - " columns=BatchImageAnnotator.EXPECTED_COLUMN_NAMES,\n", - " data=[\n", - " [1655349, 'cine_lax_3ch_192', 25, 'gs://ml4cvd/deflaux/ukbb_tensors/'],\n", - " [1655349, 't2_flair_sag_p2_1mm_fs_ellip_pf78_1', 50, 'gs://ml4cvd/deflaux/ukbb_tensors/'],\n", - " [1655349, 'cine_lax_4ch_192', 25, 'gs://ml4cvd/deflaux/ukbb_tensors/'],\n", - " [1655349, 't2_flair_sag_p2_1mm_fs_ellip_pf78_2', 50, 'gs://ml4cvd/deflaux/ukbb_tensors/'],\n", - " [2403657, 'cine_lax_3ch_192', 25, 'gs://ml4cvd/deflaux/ukbb_tensors/'],\n", - " ])\n", - "\n", - "samples_df.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "samples_df.head(n = 10)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Annotate the batch! " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Note: a zoom level of 1.0 displays the tensor as-is. 
For higher zoom levels, this code currently\n", - "# use the PIL library to scale the image.\n", - "\n", - "annotator = BatchImageAnnotator(samples=samples_df,\n", - " zoom=2.0,\n", - " annotation_categories=['region_of_interest'],\n", - " annotation_storage=BIG_QUERY_ANNOTATIONS_STORAGE)\n", - "annotator.annotate_images()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# View the stored annotations " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "annotator.view_recent_submissions(count=10)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Provenance" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import datetime\n", - "print(datetime.datetime.now())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%bash\n", - "pip3 freeze" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Questions about these particular notebooks? Reach out to Puneet Batra pbatra@broadinstitute.org, Paolo Di Achille pdiachil@broadinstitute.org, and Nicole Deflaux deflaux@verily.com." 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.8" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": { - "height": "calc(100% - 180px)", - "left": "10px", - "top": "150px", - "width": "199px" - }, - "toc_section_display": true, - "toc_window_display": true - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/review_results/review_one_sample.ipynb b/notebooks/review_results/review_one_sample.ipynb index 9569380aa..9fb1fc2b2 100644 --- a/notebooks/review_results/review_one_sample.ipynb +++ b/notebooks/review_results/review_one_sample.ipynb @@ -16,7 +16,7 @@ "
\n", " This notebook assumes\n", "
    \n", - "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
  • \n", + "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200729_091732.
  • \n", "
  • ml4h is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
  • \n", "
\n", "
" @@ -41,7 +41,7 @@ "from ml4h.runtime_data_defines import determine_runtime\n", "from ml4h.runtime_data_defines import Runtime\n", "\n", - "if Runtime.ML4H_VM == determine_runtime():\n", + "if Runtime.ml4h_VM == determine_runtime():\n", " !pip3 install --user --upgrade pandas_gbq pyarrow\n", " # Be sure to restart the kernel if pip installs anything." ] @@ -523,7 +523,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## MRI DICOM visualization" + "## MRI dicom visualization" ] }, { @@ -641,7 +641,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.8" + "version": "3.7.7" }, "toc": { "base_numbering": 1, diff --git a/notebooks/review_results/test_error_handling_for_notebook_visualizations.ipynb b/notebooks/review_results/test_error_handling_for_notebook_visualizations.ipynb index 97eee7d12..0c4a1c97a 100644 --- a/notebooks/review_results/test_error_handling_for_notebook_visualizations.ipynb +++ b/notebooks/review_results/test_error_handling_for_notebook_visualizations.ipynb @@ -21,7 +21,7 @@ "metadata": {}, "source": [ "
\n", - " Terra Users test with the most recent custom Docker image which has all the software dependencies preinstalled. (e.g., more recent than gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608)\n", + " Terra Users test with the most recent custom Docker image which has all the software dependencies preinstalled. (e.g., more recent than gcr.io/uk-biobank-sek-data/ml4h_terra:20200729_091732)\n", "
" ] }, @@ -41,9 +41,7 @@ "from ml4h.visualization_tools.annotation_storage import BigQueryAnnotationStorage\n", "\n", "import pandas as pd\n", - "import tensorflow as tf\n", - "\n", - "%matplotlib inline" + "import tensorflow as tf" ] }, { @@ -64,7 +62,7 @@ "outputs": [], "source": [ "#---[ EDIT THIS VARIABLE VALUE IF YOU LIKE ]---\n", - "MODEL_RESULTS_FILE = 'gs://uk-biobank-sek-data-us-east1/phenotypes/ml4cvd/ukbiobank_query_results_plus_four_fake_samples.csv'" + "MODEL_RESULTS_FILE = 'gs://uk-biobank-sek-data-us-east1/phenotypes/ml4h/ukbiobank_query_results_plus_four_fake_samples.csv'" ] }, { @@ -257,7 +255,7 @@ "metadata": {}, "outputs": [], "source": [ - "#dicom_interactive_plots.choose_mri(sample_id=SAMPLE_TO_REVIEW)" + "dicom_interactive_plots.choose_mri(sample_id=SAMPLE_TO_REVIEW)" ] }, { @@ -266,7 +264,7 @@ "metadata": {}, "outputs": [], "source": [ - "#dicom_plots.choose_cardiac_mri(sample_id=SAMPLE_TO_REVIEW)" + "dicom_plots.choose_cardiac_mri(sample_id=SAMPLE_TO_REVIEW)" ] }, { @@ -283,7 +281,7 @@ "outputs": [], "source": [ "SAMPLE_TO_REVIEW = 5993648\n", - "folder = 'gs://deflaux-test-001/'" + "folder = 'gs://broad-ml4cvd-vcm/'" ] }, { @@ -328,7 +326,7 @@ "metadata": {}, "outputs": [], "source": [ - "#dicom_interactive_plots.choose_mri(sample_id=SAMPLE_TO_REVIEW, folder=folder)" + "dicom_interactive_plots.choose_mri(sample_id=SAMPLE_TO_REVIEW, folder=folder)" ] }, { @@ -337,7 +335,7 @@ "metadata": {}, "outputs": [], "source": [ - "#dicom_plots.choose_cardiac_mri(sample_id=SAMPLE_TO_REVIEW, folder=folder)" + "dicom_plots.choose_cardiac_mri(sample_id=SAMPLE_TO_REVIEW, folder=folder)" ] }, { @@ -460,6 +458,24 @@ "hd5_mri_plots.choose_mri_tmap(sample_id=SAMPLE_TO_REVIEW)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dicom_interactive_plots.choose_mri(sample_id=SAMPLE_TO_REVIEW)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + 
"dicom_plots.choose_cardiac_mri(sample_id=SAMPLE_TO_REVIEW)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -615,7 +631,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.8" + "version": "3.7.7" }, "toc": { "base_numbering": 1, @@ -630,7 +646,7 @@ "height": "calc(100% - 180px)", "left": "10px", "top": "150px", - "width": "197.756px" + "width": "342.756px" }, "toc_section_display": true, "toc_window_display": true diff --git a/notebooks/terra_featured_workspace/identify_a_sample_to_review_interactive.ipynb b/notebooks/terra_featured_workspace/identify_a_sample_to_review_interactive.ipynb index e17eb08e2..e166223b1 100644 --- a/notebooks/terra_featured_workspace/identify_a_sample_to_review_interactive.ipynb +++ b/notebooks/terra_featured_workspace/identify_a_sample_to_review_interactive.ipynb @@ -20,7 +20,7 @@ "
\n", " This notebook assumes:\n", "
    \n", - "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
  • \n", + "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200729_091732.
  • \n", "
  • ml4h is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
  • \n", "
\n", "
" @@ -79,7 +79,7 @@ "source": [ "#---[ EDIT THIS VARIABLE VALUE IF YOU LIKE ]---\n", "# TODO(paolo and team): provide CSV with phenotypes and ML results for fake samples.\n", - "MODEL_RESULTS_FILE = 'gs://uk-biobank-sek-data-us-east1/phenotypes/ml4cvd/ukbiobank_query_results_plus_four_fake_samples.csv'" + "MODEL_RESULTS_FILE = 'gs://uk-biobank-sek-data-us-east1/phenotypes/ml4h/ukbiobank_query_results_plus_four_fake_samples.csv'" ] }, { diff --git a/notebooks/terra_featured_workspace/image_annotations_demo.ipynb b/notebooks/terra_featured_workspace/image_annotations_demo.ipynb deleted file mode 100644 index ce15e4d73..000000000 --- a/notebooks/terra_featured_workspace/image_annotations_demo.ipynb +++ /dev/null @@ -1,259 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Image annotations for a batch of samples\n", - "\n", - "Using this notebook, cardiologists are able to quickly view and annotate MRI images for a batch of samples. These annotated images become the training data for the next round of modeling." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Setup\n", - "\n", - "
\n", - " This notebook assumes\n", - "
    \n", - "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
  • \n", - "
  • ml4cvd is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
  • \n", - "
\n", - "
" - ] - }, - { - "attachments": { - "Screen%20Shot%202020-06-22%20at%202.50.48%20PM.png": { - "image/png": "" - } - }, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "![Screen%20Shot%202020-06-22%20at%202.50.48%20PM.png](attachment:Screen%20Shot%202020-06-22%20at%202.50.48%20PM.png)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from ml4cvd.visualization_tools.batch_image_annotations import BatchImageAnnotator\n", - "import pandas as pd" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "code_folding": [] - }, - "outputs": [], - "source": [ - "%%javascript\n", - "// Display cell outputs to full height (no vertical scroll bar)\n", - "IPython.OutputArea.auto_scroll_threshold = 9999;" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pd.set_option('display.max_colwidth', -1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Define the batch of samples to annotate\n", - "\n", - "In general, we would read in a CSV file but for this demo we define the batch right here." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Normally these would all be the same or similar TMAP. 
We are using different ones here just to make it\n", - "# more obvious in this demo that we are processing different samples.\n", - "samples_df = pd.DataFrame(\n", - " columns=BatchImageAnnotator.EXPECTED_COLUMN_NAMES,\n", - " data=[\n", - " ['fake_1', 'cine_lax_3ch_192', 25, 'gs://ml4cvd/projects/fake_hd5s/'],\n", - " ['fake_1', 't2_flair_sag_p2_1mm_fs_ellip_pf78_1', 50, 'gs://ml4cvd/projects/fake_hd5s/'],\n", - " ['fake_1', 'cine_lax_4ch_192', 25, 'gs://ml4cvd/projects/fake_hd5s/'],\n", - " ['fake_1', 't2_flair_sag_p2_1mm_fs_ellip_pf78_2', 50, 'gs://ml4cvd/projects/fake_hd5s/'],\n", - " ['fake_2', 'cine_lax_3ch_192', 25, 'gs://ml4cvd/projects/fake_hd5s/'],\n", - " ])\n", - "\n", - "samples_df" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Annotate the batch! " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Note: a zoom level of 1.0 displays the tensor as-is. For higher zoom levels, this code currently\n", - "# use the PIL library to scale the image.\n", - "\n", - "annotator = BatchImageAnnotator(samples=samples_df, zoom=2.0, annotation_categories=['region_of_interest'])\n", - "annotator.annotate_images()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## via BigQuery annotation storage " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from ml4cvd.visualization_tools.annotation_storage import BigQueryAnnotationStorage" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "BIG_QUERY_ANNOTATIONS_STORAGE = BigQueryAnnotationStorage('uk-biobank-sek-data.ml_results.annotations')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Note: a zoom level of 1.0 displays the tensor as-is. 
For higher zoom levels, this code currently\n", - "# use the PIL library to scale the image.\n", - "\n", - "annotator = BatchImageAnnotator(samples=samples_df,\n", - " zoom=2.0,\n", - " annotation_categories=['region_of_interest'],\n", - " annotation_storage=BIG_QUERY_ANNOTATIONS_STORAGE)\n", - "annotator.annotate_images()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# View the stored annotations " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "annotator.view_recent_submissions(count=10)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Provenance" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import datetime\n", - "print(datetime.datetime.now())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%bash\n", - "pip3 freeze" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Questions about these particular notebooks? Reach out to Puneet Batra pbatra@broadinstitute.org, Paolo Di Achille pdiachil@broadinstitute.org, and Nicole Deflaux deflaux@verily.com." 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.8" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": { - "height": "calc(100% - 180px)", - "left": "10px", - "top": "150px", - "width": "199px" - }, - "toc_section_display": true, - "toc_window_display": true - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/terra_featured_workspace/review_one_sample_interactive.ipynb b/notebooks/terra_featured_workspace/review_one_sample_interactive.ipynb index 8e863a7e4..44379f740 100644 --- a/notebooks/terra_featured_workspace/review_one_sample_interactive.ipynb +++ b/notebooks/terra_featured_workspace/review_one_sample_interactive.ipynb @@ -18,7 +18,7 @@ "
\n", " This notebook assumes\n", "
    \n", - "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
  • \n", + "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200729_091732.
  • \n", "
  • ml4h is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
  • \n", "
\n", "
" @@ -84,7 +84,7 @@ "source": [ "#---[ EDIT THIS VARIABLE VALUE IF YOU LIKE ]---\n", "# TODO(paolo and team): provide CSV with phenotypes and ML results for fake samples.\n", - "MODEL_RESULTS_FILE = 'gs://uk-biobank-sek-data-us-east1/phenotypes/ml4cvd/ukbiobank_query_results_plus_four_fake_samples.csv'" + "MODEL_RESULTS_FILE = 'gs://uk-biobank-sek-data-us-east1/phenotypes/ml4h/ukbiobank_query_results_plus_four_fake_samples.csv'" ] }, { diff --git a/pylintrc b/pylintrc deleted file mode 100644 index 8a5e40122..000000000 --- a/pylintrc +++ /dev/null @@ -1,337 +0,0 @@ -# This configuration was copied from https://github.com/tensorflow/tensorflow/blob/18ebe824d2f6f20b09839cb0a0073032a2d6c5fe/tensorflow/tools/ci_build/pylintrc and then further modified. - -[MASTER] - -# Specify a configuration file. -#rcfile= - -# Python code to execute, usually for sys.path manipulation such as -# pygtk.require(). -#init-hook= - -# Profiled execution. -profile=no - -# Add files or directories to the denylist. They should be base names, not -# paths. -ignore=CVS - -# Pickle collected data for later comparisons. -persistent=yes - -# List of plugins (as comma separated values of python modules names) to load, -# usually to register additional checkers. -load-plugins= - - -[MESSAGES CONTROL] - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time. See also the "--disable" option for examples. -enable=indexing-exception,old-raise-syntax - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once).You can also use "--disable=all" to -# disable everything first and then reenable specific checks. 
For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use"--disable=all --enable=classes -# --disable=W" -disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager - - -# Set the cache size for astng objects. -cache-size=500 - - -[REPORTS] - -# Set the output format. Available formats are text, parseable, colorized, msvs -# (visual studio) and html. You can also give a reporter class, eg -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Put messages in a separate file for each module / package specified on the -# command line instead of printing them on stdout. Reports (if any) will be -# written in a file name "pylint_global.[txt|html]". -files-output=no - -# Tells whether to display a full report or only the messages -reports=no - -# Python expression which should return a note less than 10 (10 is the highest -# note). You have access to the variables errors warning, statement which -# respectively contain the number of errors / warnings messages and the total -# number of statements analyzed. This is used by the global evaluation report -# (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Add a comment according to your evaluation note. This is used by the global -# evaluation report (RP0004). -comment=no - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details -#msg-template= - - -[TYPECHECK] - -# Tells whether missing members accessed in mixin class should be ignored. 
A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes - -# List of classes names for which member attributes should not be checked -# (useful for classes with attributes dynamically set). -ignored-classes=SQLObject - -# When zope mode is activated, add a predefined set of Zope acquired attributes -# to generated-members. -zope=no - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E0201 when accessed. Python regular -# expressions are accepted. -generated-members=REQUEST,acl_users,aq_parent - -# List of decorators that create context managers from functions, such as -# contextlib.contextmanager. -contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager - - -[VARIABLES] - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# A regular expression matching the beginning of the name of dummy variables -# (i.e. not used). -dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid to define new builtins when possible. -additional-builtins= - - -[BASIC] - -# Required attributes for module, separated by a comma -required-attributes= - -# List of builtins function names that should not be used, separated by a comma -bad-functions=apply,input,reduce - - -# Disable the report(s) with the given id(s). -# All non-Google reports are disabled by default. 
-disable-report=R0001,R0002,R0003,R0004,R0101,R0102,R0201,R0202,R0220,R0401,R0402,R0701,R0801,R0901,R0902,R0903,R0904,R0911,R0912,R0913,R0914,R0915,R0921,R0922,R0923 - -# Regular expression which should only match correct module names -module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ - -# Regular expression which should only match correct module level names -const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ - -# Regular expression which should only match correct class names -class-rgx=^_?[A-Z][a-zA-Z0-9]*$ - -# Regular expression which should only match correct function names -function-rgx=^(?:(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ - -# Regular expression which should only match correct method names -method-rgx=^(?:(?P__[a-z0-9_]+__|next)|(?P_{0,2}[A-Z][a-zA-Z0-9]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ - -# Regular expression which should only match correct instance attribute names -attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ - -# Regular expression which should only match correct argument names -argument-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression which should only match correct variable names -variable-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression which should only match correct attribute names in class -# bodies -class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ - -# Regular expression which should only match correct list comprehension / -# generator expression variable names -inlinevar-rgx=^[a-z][a-z0-9_]*$ - -# Good variable names which should always be accepted, separated by a comma -good-names=main,_ - -# Bad variable names which should always be refused, separated by a comma -bad-names= - -# Regular expression which should only match function or class names that do -# not require a docstring. -no-docstring-rgx=(__.*__|main) - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=10 - - -[FORMAT] - -# Maximum number of characters on a single line. 
-max-line-length=120 - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=(?x) - (^\s*(import|from)\s - |\$Id:\s\/\/depot\/.+#\d+\s\$ - |^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*("[^"]\S+"|'[^']\S+') - |^\s*\#\ LINT\.ThenChange - |^[^#]*\#\ type:\ [a-zA-Z_][a-zA-Z0-9_.,[\] ]*$ - |pylint - |""" - |\# - |lambda - |(https?|ftp):) - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=y - -# List of optional constructs for which whitespace checking is disabled -no-space-check= - -# Maximum number of lines in a module -max-module-lines=99999 - -# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 -# tab). -indent-string=' ' - - -[SIMILARITIES] - -# Minimum lines number of a similarity. -min-similarity-lines=4 - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=no - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes= - - -[IMPORTS] - -# Deprecated modules which should not be used, separated by a comma -deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets - -# Create a graph of every (i.e. internal and external) dependencies in the -# given file (report RP0402 must not be disabled) -import-graph= - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled) -ext-import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled) -int-import-graph= - - -[CLASSES] - -# List of interface methods to ignore, separated by a comma. This is used for -# instance to not check methods defines in Zope's Interface base class. 
-ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__,__new__,setUp - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls,class_ - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=mcs - - -[DESIGN] - -# Maximum number of arguments for function / method -max-args=5 - -# Argument names that match this expression will be ignored. Default to name -# with leading underscore -ignored-argument-names=_.* - -# Maximum number of locals for function / method body -max-locals=15 - -# Maximum number of return / yield for function / method body -max-returns=6 - -# Maximum number of branch for function / method body -max-branches=12 - -# Maximum number of statements in function / method body -max-statements=50 - -# Maximum number of parents for a class (see R0901). -max-parents=7 - -# Maximum number of attributes for a class (see R0902). -max-attributes=7 - -# Minimum number of public methods for a class (see R0903). -min-public-methods=2 - -# Maximum number of public methods for a class (see R0904). -max-public-methods=20 - - -[EXCEPTIONS] - -# Exceptions that will emit a warning when being caught. Defaults to -# "Exception" -overgeneral-exceptions=Exception,StandardError,BaseException - - -[AST] - -# Maximum line length for lambdas -short-func-length=1 - -# List of module members that should be marked as deprecated. -# All of the string functions are listed in 4.1.4 Deprecated string functions -# in the Python 2.4 docs. 
-deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc - - -[DOCSTRING] - -# List of exceptions that do not need to be mentioned in the Raises section of -# a docstring. -ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError - - - -[TOKENS] - -# Number of spaces of indent required when the last token on the preceding line -# is an open (, [, or {. -indent-after-paren=4 - - -[GOOGLE LINES] - -# Regexp for a proper copyright notice. -copyright=Copyright \d{4} The TensorFlow Authors\. +All [Rr]ights [Rr]eserved\. diff --git a/scripts/jupyter.sh b/scripts/jupyter.sh index 32edfb1f3..a83fd582d 100755 --- a/scripts/jupyter.sh +++ b/scripts/jupyter.sh @@ -54,7 +54,7 @@ while getopts ":ip:ch" opt ; do ;; c) DOCKER_IMAGE=${DOCKER_IMAGE_NO_GPU} - GPU_DEVICE="" + GPU_DEVICE="" ;; :) echo "ERROR: Option -${OPTARG} requires an argument." 
1>&2 @@ -99,7 +99,6 @@ ${DOCKER_COMMAND} run -it \ ${GPU_DEVICE} \ --rm \ --ipc=host \ ---hostname=$(hostname) \ -v /home/${USER}/:/home/${USER}/ \ -v /mnt/:/mnt/ \ -p 0.0.0.0:${PORT}:${PORT} \ diff --git a/tests/test_models.py b/tests/test_models.py index 61df950cb..976a6b846 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -7,7 +7,8 @@ from typing import List, Optional, Dict, Tuple, Iterator from ml4h.TensorMap import TensorMap -from ml4h.models import make_multimodal_multitask_model, parent_sort, BottleneckType, ACTIVATION_FUNCTIONS, MODEL_EXT, train_model_from_generators, check_no_bottleneck +from ml4h.models import make_multimodal_multitask_model, parent_sort, BottleneckType, ACTIVATION_FUNCTIONS, MODEL_EXT, train_model_from_generators, \ + check_no_bottleneck, make_paired_autoencoder_model from ml4h.test_utils import TMAPS_UP_TO_4D, MULTIMODAL_UP_TO_4D, CATEGORICAL_TMAPS, CONTINUOUS_TMAPS, SEGMENT_IN, SEGMENT_OUT, PARENT_TMAPS, CYCLE_PARENTS from ml4h.test_utils import LANGUAGE_TMAP_1HOT_WINDOW, LANGUAGE_TMAP_1HOT_SOFTMAX @@ -18,14 +19,13 @@ 'dense_layers': [4, 2], 'dense_blocks': [5, 3], 'block_size': 3, - 'conv_width': 3, 'learning_rate': 1e-3, 'optimizer': 'adam', 'conv_type': 'conv', 'conv_layers': [6, 5, 3], - 'conv_x': [3], - 'conv_y': [3], - 'conv_z': [2], + 'conv_x': [3]*5, + 'conv_y': [3]*5, + 'conv_z': [2]*5, 'padding': 'same', 'max_pools': [], 'pool_type': 'max', @@ -39,6 +39,16 @@ 'dense_regularize_rate': .1, 'dense_normalize': 'batch_norm', 'bottleneck_type': BottleneckType.FlattenRestructure, + 'pair_loss': 'cosine', + 'training_steps': 12, + 'learning_rate': 0.00001, + 'epochs': 6, + 'optimizer': 'adam', + 'learning_rate_schedule': None, + 'model_layers': None, + 'model_file': None, + 'hidden_layer': 'embed', + 'u_connect': {}, } @@ -54,19 +64,20 @@ def make_training_data(input_tmaps: List[TensorMap], output_tmaps: List[TensorMa ), ]) -def assert_model_trains(input_tmaps: List[TensorMap], output_tmaps: List[TensorMap], m: 
Optional[tf.keras.Model] = None): +def assert_model_trains(input_tmaps: List[TensorMap], output_tmaps: List[TensorMap], m: Optional[tf.keras.Model] = None, skip_shape_check: bool = False): if m is None: m = make_multimodal_multitask_model( input_tmaps, output_tmaps, **DEFAULT_PARAMS, ) - for tmap, tensor in zip(input_tmaps, m.inputs): - assert tensor.shape[1:] == tmap.shape - assert tensor.shape[1:] == tmap.shape - for tmap, tensor in zip(parent_sort(output_tmaps), m.outputs): - assert tensor.shape[1:] == tmap.shape - assert tensor.shape[1:] == tmap.shape + if not skip_shape_check: + for tmap, tensor in zip(input_tmaps, m.inputs): + assert tensor.shape[1:] == tmap.shape + assert tensor.shape[1:] == tmap.shape + for tmap, tensor in zip(parent_sort(output_tmaps), m.outputs): + assert tensor.shape[1:] == tmap.shape + assert tensor.shape[1:] == tmap.shape data = make_training_data(input_tmaps, output_tmaps) history = m.fit(data, steps_per_epoch=2, epochs=2, validation_data=data, validation_steps=2) for tmap in output_tmaps: @@ -294,8 +305,8 @@ def test_parents(self, output_tmaps): def test_language_models(self, input_output_tmaps, tmpdir): params = DEFAULT_PARAMS.copy() m = make_multimodal_multitask_model( - input_output_tmaps[0], - input_output_tmaps[1], + tensor_maps_in=input_output_tmaps[0], + tensor_maps_out=input_output_tmaps[1], **params ) assert_model_trains(input_output_tmaps[0], input_output_tmaps[1], m) @@ -309,6 +320,36 @@ def test_language_models(self, input_output_tmaps, tmpdir): **DEFAULT_PARAMS, ) + @pytest.mark.parametrize( + 'pairs', + [ + [(CONTINUOUS_TMAPS[2], CONTINUOUS_TMAPS[1])], + [(CATEGORICAL_TMAPS[2], CATEGORICAL_TMAPS[1])], + [(CONTINUOUS_TMAPS[2], CONTINUOUS_TMAPS[1]), (CONTINUOUS_TMAPS[2], CATEGORICAL_TMAPS[3])] + ], + ) + def test_paired_models(self, pairs, tmpdir): + params = DEFAULT_PARAMS.copy() + pair_list = list(set([p[0] for p in pairs] + [p[1] for p in pairs])) + params['u_connect'] = {tm: [] for tm in pair_list} + m, encoders, 
decoders = make_paired_autoencoder_model( + pairs=pairs, + tensor_maps_in=pair_list, + tensor_maps_out=pair_list, + **params + ) + assert_model_trains(pair_list, pair_list, m, skip_shape_check=True) + m.save(os.path.join(tmpdir, 'paired_ae.h5')) + path = os.path.join(tmpdir, f'm{MODEL_EXT}') + m.save(path) + make_paired_autoencoder_model( + pairs=pairs, + tensor_maps_in=pair_list, + tensor_maps_out=pair_list, + **params + ) + + @pytest.mark.parametrize( 'tmaps', [_rotate(PARENT_TMAPS, i) for i in range(len(PARENT_TMAPS))], diff --git a/tests/test_recipes.py b/tests/test_recipes.py index 468dc44c6..df8f4a1e3 100644 --- a/tests/test_recipes.py +++ b/tests/test_recipes.py @@ -3,7 +3,7 @@ import pandas as pd import numpy as np -from ml4h.recipes import inference_file_name, hidden_inference_file_name +from ml4h.recipes import inference_file_name, _hidden_file_name from ml4h.recipes import train_multimodal_multitask, compare_multimodal_multitask_models from ml4h.recipes import infer_multimodal_multitask, infer_hidden_layer_multimodal_multitask from ml4h.recipes import compare_multimodal_scalar_task_models, _find_learning_rate @@ -42,7 +42,7 @@ def test_infer_genetics(self, default_arguments): def test_infer_hidden(self, default_arguments): infer_hidden_layer_multimodal_multitask(default_arguments) - tsv = hidden_inference_file_name(default_arguments.output_folder, default_arguments.id) + tsv = _hidden_file_name(default_arguments.output_folder, default_arguments.id) inferred = pd.read_csv(tsv, sep='\t') assert len(set(inferred['sample_id'])) == pytest.N_TENSORS @@ -50,7 +50,7 @@ def test_infer_hidden_genetics(self, default_arguments): default_arguments.tsv_style = 'genetics' infer_hidden_layer_multimodal_multitask(default_arguments) default_arguments.tsv_style = 'standard' - tsv = hidden_inference_file_name(default_arguments.output_folder, default_arguments.id) + tsv = _hidden_file_name(default_arguments.output_folder, default_arguments.id) inferred = pd.read_csv(tsv, 
sep='\t') assert len(set(inferred['FID'])) == pytest.N_TENSORS From 46fe794081c3d6b42bc20cd7e97ac1ff1d508dad Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Tue, 29 Sep 2020 17:15:39 -0400 Subject: [PATCH 02/21] paired --- CONTRIBUTING.md | 192 ++++++++++ .../batch_image_annotations.py | 236 ++++++++++++ .../review_results/image_annotations.ipynb | 268 ++++++++++++++ .../image_annotations_demo.ipynb | 259 ++++++++++++++ pylintrc | 337 ++++++++++++++++++ 5 files changed, 1292 insertions(+) create mode 100644 CONTRIBUTING.md create mode 100644 ml4h/visualization_tools/batch_image_annotations.py create mode 100644 notebooks/review_results/image_annotations.ipynb create mode 100644 notebooks/terra_featured_workspace/image_annotations_demo.ipynb create mode 100644 pylintrc diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..2e2006c07 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,192 @@ +# Contributing + +1. Before making a substantial pull request, consider first [filing an issue](https://github.com/broadinstitute/ml/issues) describing the feature addition or change you wish to make. +1. [Get setup](#setup-for-code-contributions) +1. [Follow the coding style](#python-coding-style) +1. [Test your code](#testing) +1. Send a [pull request](https://github.com/broadinstitute/ml/pulls) + +## Setup for code contributions + +### Get setup for GitHub + +Small typos in code or documentation may be edited directly using the GitHub web interface. Otherwise: + +1. If you are new to GitHub, don't start here. Instead, work through a GitHub tutorial such as https://guides.github.com/activities/hello-world/. +1. Create a fork of https://github.com/broadinstitute/ml +1. Clone your fork. +1. Work from a feature branch. See the [Appendix](#appendix) for detailed `git` commands. + +### Install precommit + +[`pre-commit`](https://pre-commit.com/) is a framework for managing and maintaining multi-language pre-commit hooks. 
+ +``` +# Install pre-commit +pip3 install pre-commit +# Install the git hook scripts by running this within the git clone directory +cd ${HOME}/ml +pre-commit install +``` + +See [.pre-commit-config.yaml](https://github.com/broadinstitute/ml/blob/master/.pre-commit-config.yaml) for the currently configured pre-commit hooks for ml4cvd. + +### Install git-secrets + +```git-secrets``` helps us avoid committing secrets (e.g. private keys) and other critical data (e.g. PHI) to our +repositories. ```git-secrets``` can be obtained via [github](https://github.com/awslabs/git-secrets) or on MacOS can be +installed with Homebrew by running ```brew install git-secrets```. + +To add hooks to all repositories that you initialize or clone in the future: + +```git secrets --install --global``` + +To add hooks to all local repositories: + +``` +git secrets --install ~/.git-templates/git-secrets +git config --global init.templateDir ~/.git-templates/git-secrets +``` + +We maintain our own custom "provider" to cover any private keys or other critical data that we would like to avoid +committing to our repositories. Feel free to add ```egrep```-compatible regular expressions to +```git_secrets_provider_ml4cvd.txt``` to match types of critical data that are not currently covered by the patterns in that +file. To register the patterns in this file with ```git-secrets```: + +``` +git secrets --add-provider -- cat ${HOME}/ml/git_secrets_provider_ml4cvd.txt +``` + +### Install pylint + +[`pylint`](https://www.pylint.org/) is a Python static code analysis tool which looks for programming errors, helps enforce a coding standard, sniffs for code smells and offers simple refactoring suggestions. + +``` +# Install pylint +pip3 install pylint +``` + +See [pylintrc](https://github.com/broadinstitute/ml/blob/master/pylintrc) for the current lint configuration for ml4cvd.
+ +# Python coding style + +Changes to ml4cvd should conform to [PEP 8 -- Style Guide for Python Code](https://www.python.org/dev/peps/pep-0008/). See also [Google Python Style Guide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md) as another description of this coding style. + +Use `pylint` to check your Python changes: + +```bash +pylint --rcfile=${HOME}/ml/pylintrc myfile.py +``` + +Any messages returned by `pylint` are intended to be self-explanatory, but that isn't always the case. + +* Search for `pylint ` or `pylint ` for more details on the recommended code change to resolve the lint issue. +* Or add comment `# pylint: disable=` to the end of the line of code. + +# Testing + +## Testing of `recipes` + +Unit tests can be run in Docker with +``` +${HOME}/ml/scripts/tf.sh -T ${HOME}/ml/tests +``` +Unit tests can be run locally in a conda environment with +``` +python -m pytest ${HOME}/ml/tests +``` +Some of the unit tests are slow due to creating, saving and loading `tensorflow` models. +To skip those tests to move quickly, run +``` +python -m pytest ${HOME}/ml/tests -m "not slow" +``` +pytest can also run specific tests using `::`. For example + +``` +python -m pytest ${HOME}/ml/tests/test_models.py::TestMakeMultimodalMultitaskModel::test_u_connect_segment +``` + +For more pytest usage information, check out the [usage guide](https://docs.pytest.org/en/latest/usage.html). + +## Testing of `visualization_tools` + +The code in [ml4cvd/visualization_tools](https://github.com/broadinstitute/ml/tree/master/ml4cvd/visualization_tools) is primarily interactive so we add test cases to notebook [test_error_handling_for_notebook_visualizations.ipynb](https://github.com/broadinstitute/ml/blob/master/notebooks/review_results/test_error_handling_for_notebook_visualizations.ipynb) and visually inspect the output of `Cells -> Run all`. + +# Appendix + +For the ml4cvd GitHub repository, we are doing ‘merge and squash’ of pull requests.
So that means your fork does not match upstream after your pull request has been merged. The easiest way to manage this is to always work in a feature branch, instead of checking changes into your fork’s master branch. + + +## How to work on a new feature + +(1) Get the latest version of the upstream repo + +``` +git fetch upstream +``` + +Note: If you get an error saying that upstream is unknown, run the following remote add command and then re-run the fetch command. You only need to do this once per git clone. + +``` +git remote add upstream https://github.com/broadinstitute/ml.git +``` + +(2) Make sure your master branch is “even” with upstream. + +``` +git checkout master +git merge --ff-only upstream/master +git push +``` + +Now the master branch of your fork on GitHub should say *"This branch is even with broadinstitute:master."*. + + +(3) Create a feature branch for your change. + +``` +git checkout -b my-feature-branch-name +``` + +Because you created this feature branch from your master branch that was up to date with upstream (step 2), your feature branch is also up to date with upstream. Commit your changes to this branch until you are happy with them. + +(4) Push your changes to GitHub and send a pull request. + +``` +git push --set-upstream origin my-feature-branch-name +``` + +After your pull request is merged, it's safe to delete your branch! + +## I accidentally checked a new change to my master branch instead of a feature branch. How to fix this? + +(1) Soft undo your change(s). This leaves the changes in the files on disk but undoes the commit. + +``` +git checkout master +# Moves pointer back to previous HEAD +git reset --soft HEAD@{1} +``` + +Or if you need to move back several commits to the most recent one in common with upstream, you can change ‘1’ to be however many commits back you need to go. + +(2) “stash” your now-unchecked-in changes so that you can get them back later.
+ +``` +git stash +``` + +(3) Now do the [How to work on a new feature](#how-to-work-on-a-new-feature) step to bring master up to date and create your new feature branch that is “even” with upstream. Here are those commands again: + +``` +git fetch upstream +git merge --ff-only upstream/master +git checkout -b my-feature-branch-name +``` + +(4) “unstash” your changes. + +``` +git stash pop +``` +Now you can proceed with your work! diff --git a/ml4h/visualization_tools/batch_image_annotations.py b/ml4h/visualization_tools/batch_image_annotations.py new file mode 100644 index 000000000..34ff731df --- /dev/null +++ b/ml4h/visualization_tools/batch_image_annotations.py @@ -0,0 +1,236 @@ +"""Methods for batch annotations of images stored as 3D tensors, such as MRIs, from within notebooks.""" + +import json +import os +import socket +import tempfile +from typing import Any, Dict, List + +from IPython.display import display +import numpy as np +import pandas as pd +import h5py +from ipyannotations import PolygonAnnotator +import ipywidgets as widgets +from ml4h.visualization_tools.hd5_mri_plots import MRI_TMAPS +from ml4h.visualization_tools.annotation_storage import AnnotationStorage +from ml4h.visualization_tools.annotation_storage import TransientAnnotationStorage +from PIL import Image +import tensorflow as tf + + +class BatchImageAnnotator(): + """Annotate batches of images with polygons drawn over regions of interest.""" + + SUBMIT_BUTTON_DESCRIPTION = 'Submit polygons, goto next sample' + USE_INSTRUCTIONS = ''' +

    +
  • To draw a polygon, click anywhere you'd like to start. Continue to click + along the edge of the polygon until arrive back where you started. To + finish, simply click the first point (highlighted in red). It may be + helpful to increase the point size if you're struggling (using the slider).
  • + +
  • You can change the class of a polygon using the dropdown menu while the + polygon is still "open", or unfinished. If you make a mistake, use the Undo + button until the point that's wrong has disappeared. + +
  • You can move, but not add / subtract polygon points, by clicking the "Edit" + button. Simply drag a point you want to adjust. Again, if you have + difficulty aiming at the points, you can increase the point size.
  • + +
  • You can increase or decrease the contrast and brightness of the image + using the sliders to make it easier to annotate. Sometimes you need to see + what's behind already-created annotations, and for this purpose you can + make them more see-through using the "Opacity" slider.
  • +

+ ''' + EXPECTED_COLUMN_NAMES = ['sample_id', 'tmap_name', 'instance_number', 'folder'] + DEFAULT_ANNOTATION_CLASSNAME = 'region_of_interest' + CSS = ''' + + ''' + + def __init__( + self, samples: pd.DataFrame, annotation_categories: List[str] = None, + zoom: float = 1.5, annotation_storage: AnnotationStorage = TransientAnnotationStorage(), + ): + """Initializes an instance of BatchImageAnnotator. + + Args: + samples: A dataframe of samples to annotate. Columns must include those + in BatchImageAnnotator.EXPECTED_COLUMN_NAMES. + annotation_categories: A list of one or more strings to serve as tags for the polygons. + zoom: Desired zoom level for the image. + annotation_storage: An instance of AnnotationStorage. This faciltates the use of a user-provided + strategy for the storage and processing of annotations. + + Raises: + ValueError: The provided dataframe does not contain the expected columns. + """ + if not set(self.EXPECTED_COLUMN_NAMES).issubset(samples.columns): + raise ValueError(f'samples Dataframe must contain columns {self.EXPECTED_COLUMN_NAMES}') + self.samples = samples + self.current_sample = 0 + # TODO(deflaux) remove this after https://github.com/janfreyberg/ipyannotations/issues/11 + self.zoom = zoom + self.annotation_storage = annotation_storage + if annotation_categories is None: + annotation_categories = [self.DEFAULT_ANNOTATION_CLASSNAME] + + self.annotation_widget = PolygonAnnotator( + options=annotation_categories, + canvas_size=(900, 280 * self.zoom), + ) + self.annotation_widget.on_submit(self._store_annotations) + self.annotation_widget.submit_button.description = self.SUBMIT_BUTTON_DESCRIPTION + self.annotation_widget.submit_button.layout = widgets.Layout(width='300px') + + self.title_widget = widgets.HTML('') + self.results_widget = widgets.HTML('') + + def _store_annotations(self, data: Dict[Any, Any]) -> None: + """Transfer widget state to the annotation storage and advance to the next sample.""" + if self.current_sample >= 
self.samples.shape[0]: + self.results_widget.value = '

Annotation batch complete!

Thank you for making the model better.' + return + + # Convert polygon points in canvas coordinates to tensor coordinates. + image_canvas_position = self.annotation_widget.canvas.image_extent + x_offset, y_offset, _, _ = image_canvas_position + tensor_coords = [ + ( + a['label'], + [( + int((p[0] - x_offset) / self.zoom), + int((p[1] - y_offset) / self.zoom), + ) for p in a['points']], + ) for a in data + ] + # Store the annotation using the provided annotation storage strategy. + self.annotation_storage.submit_annotation( + sample_id=self.samples.loc[self.current_sample, 'sample_id'], + annotator=os.getenv('OWNER_EMAIL') if os.getenv('OWNER_EMAIL') else socket.gethostname(), + key=self.samples.loc[self.current_sample, 'tmap_name'], + value_numeric=self.samples.loc[self.current_sample, 'instance_number'], + value_string=self.samples.loc[self.current_sample, 'folder'], + comment=json.dumps(tensor_coords), + ) + + # Display this annotation at the bottom of the widget. + results = f''' +
+

Prior sample's submitted annotations

+ The {self.SUBMIT_BUTTON_DESCRIPTION} button is both printing out the polygons below and storing the polygons + via strategy {self.annotation_storage.__class__.__name__}.
+ Details: {self.annotation_storage.describe()} +

sample info

+ {self._format_info_for_current_sample()} +

canvas coordinates

+ image extent {image_canvas_position} + {[f'
{json.dumps(x)}
' for x in data]} +

source tensor coordinates

+ {[f'
{json.dumps(x)}
' for x in tensor_coords]} +
+ ''' + self.results_widget.value = results + + # Advance to the next sample. + self.current_sample += 1 + self._annotate_image_for_current_sample() + + def _format_info_for_current_sample(self) -> str: + """Convert information about the current sample to an HTML table for display within the widget.""" + headings = ' '.join([f'{c}' for c in self.EXPECTED_COLUMN_NAMES] + ['TMAP shape']) + values = ' '.join([f'{self.samples.loc[self.current_sample, c]}' for c in self.EXPECTED_COLUMN_NAMES] + + [f'{MRI_TMAPS[self.samples.loc[self.current_sample, "tmap_name"]].shape}']) + return f''' + + {headings} + {values} +
+ ''' + + def _annotate_image_for_current_sample(self) -> None: + """Retrieve the data for the current sample and display its image in the annotation widget. + + If all samples have been processed, display the completion message. + """ + if self.current_sample >= self.samples.shape[0]: + self.annotation_widget.canvas.clear() + # Note: the above command clears the canvas, but any incomplete polygons will be redrawn. Call this + # private method to clear those. TODO(deflaux) remove this after https://github.com/janfreyberg/ipyannotations/issues/15 + self.annotation_widget.canvas._init_empty_data() # pylint: disable=protected-access + self.title_widget.value = '

Annotation batch complete!

Thank you for making the model better.' + return + + sample_id = self.samples.loc[self.current_sample, 'sample_id'] + tmap_name = self.samples.loc[self.current_sample, 'tmap_name'] + instance_number = self.samples.loc[self.current_sample, 'instance_number'] + folder = self.samples.loc[self.current_sample, 'folder'] + + with tempfile.TemporaryDirectory() as tmpdirname: + sample_hd5 = str(sample_id) + '.hd5' + local_path = os.path.join(tmpdirname, sample_hd5) + try: + tf.io.gfile.copy(src=os.path.join(folder, sample_hd5), dst=local_path) + hd5 = h5py.File(local_path, mode='r') + except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e: + self.annotation_widget.canvas.clear() + # Note: the above command clears the canvas, but any incomplete polygons will be redrawn. Call this + # private method to clear those. TODO(deflaux) remove this after https://github.com/janfreyberg/ipyannotations/issues/15 + self.annotation_widget.canvas._init_empty_data() # pylint: disable=protected-access + self.title_widget.value = f''' +
+

Warning: MRI HD5 file not available for sample {sample_id} in folder {folder}

+ Use the folder parameter to read HD5s from a different local directory or Cloud Storage bucket. +

{e.message}

+
+ ''' + return + + tensor = MRI_TMAPS[tmap_name].tensor_from_file(MRI_TMAPS[tmap_name], hd5) + tensor_instance = tensor[:, :, instance_number] + if self.zoom > 1.0: + # TODO(deflaux) remove this after https://github.com/janfreyberg/ipyannotations/issues/11 + img = Image.fromarray(tensor_instance) + zoomed_img = img.resize([int(self.zoom * s) for s in img.size], Image.LANCZOS) + tensor_instance = np.asarray(zoomed_img) + + self.annotation_widget.display(tensor_instance) + self.title_widget.value = f''' + {self.CSS} +
+

Batch annotation of {self.samples.shape[0]} samples

+ {self.USE_INSTRUCTIONS} +
+

Current sample

+ {self._format_info_for_current_sample()} +
+ ''' + + def annotate_images(self) -> None: + """Begin the batch annotation task by displaying the annotation widget populated with the first sample. + + The submit button is used to proceed to the next sample until all samples have been processed. + """ + self._annotate_image_for_current_sample() + display(widgets.VBox([self.title_widget, self.annotation_widget, self.results_widget])) + + def view_recent_submissions(self, count: int = 10) -> pd.DataFrame: + """View a dataframe of up to [count] most recent submissions. + + Args: + count: The number of the most recent submissions to return. + + Returns: + A dataframe of the most recent annotations. + """ + return self.annotation_storage.view_recent_submissions(count=count) diff --git a/notebooks/review_results/image_annotations.ipynb b/notebooks/review_results/image_annotations.ipynb new file mode 100644 index 000000000..6e644f15a --- /dev/null +++ b/notebooks/review_results/image_annotations.ipynb @@ -0,0 +1,268 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Image annotations for a batch of samples\n", + "\n", + "Using this notebook, cardiologists are able to quickly view and annotate MRI images for a batch of samples. These annotated images become the training data for the next round of modeling." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup\n", + "\n", + "
\n", + " This notebook assumes\n", + "
    \n", + "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
  • \n", + "
  • ml4h is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
  • \n", + "
\n", + "
" + ] + }, + { + "attachments": { + "Screen%20Shot%202020-06-22%20at%202.50.48%20PM.png": { + "image/png": "" + } + }, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![Screen%20Shot%202020-06-22%20at%202.50.48%20PM.png](attachment:Screen%20Shot%202020-06-22%20at%202.50.48%20PM.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO(deflaux): remove this cell after gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu has this preinstalled.\n", + "from ml4h.runtime_data_defines import determine_runtime\n", + "from ml4h.runtime_data_defines import Runtime\n", + "\n", + "if Runtime.ML4H_VM == determine_runtime():\n", + " !pip3 install --user ipycanvas==0.4.1 ipyannotations==0.2.0\n", + " !jupyter nbextension install --user --py ipycanvas\n", + " !jupyter nbextension enable --user --py ipycanvas\n", + " # Be sure to restart the kernel if pip installs anything.\n", + " # Also, shift-reload the browser page after the notebook extension installation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ml4h.visualization_tools.annotation_storage import BigQueryAnnotationStorage\n", + "from ml4h.visualization_tools.batch_image_annotations import BatchImageAnnotator\n", + "import pandas as pd\n", + "import tensorflow as tf" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "code_folding": [] + }, + "outputs": [], + "source": [ + "%%javascript\n", + "// Display cell outputs to full height (no vertical scroll bar)\n", + "IPython.OutputArea.auto_scroll_threshold = 9999;" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pd.set_option('display.max_colwidth', -1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "BIG_QUERY_ANNOTATIONS_STORAGE = BigQueryAnnotationStorage('uk-biobank-sek-data.ml_results.annotations')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Define the batch of samples to annotate\n", + "\n", + "
\n", + " Edit the CSV file path below, if needed, to either a local file or one in Cloud Storage.\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#---[ EDIT AND RUN THIS CELL TO READ FROM A LOCAL FILE OR A FILE IN CLOUD STORAGE ]---\n", + "SAMPLE_BATCH_FILE = None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if SAMPLE_BATCH_FILE:\n", + " samples_df = pd.read_csv(tf.io.gfile.GFile(SAMPLE_BATCH_FILE))\n", + "\n", + "else:\n", + " # Normally these would all be the same or similar TMAP. We are using different ones here just to make it\n", + " # more obvious in this demo that we are processing different samples.\n", + " samples_df = pd.DataFrame(\n", + " columns=BatchImageAnnotator.EXPECTED_COLUMN_NAMES,\n", + " data=[\n", + " [1655349, 'cine_lax_3ch_192', 25, 'gs://ml4cvd/deflaux/ukbb_tensors/'],\n", + " [1655349, 't2_flair_sag_p2_1mm_fs_ellip_pf78_1', 50, 'gs://ml4cvd/deflaux/ukbb_tensors/'],\n", + " [1655349, 'cine_lax_4ch_192', 25, 'gs://ml4cvd/deflaux/ukbb_tensors/'],\n", + " [1655349, 't2_flair_sag_p2_1mm_fs_ellip_pf78_2', 50, 'gs://ml4cvd/deflaux/ukbb_tensors/'],\n", + " [2403657, 'cine_lax_3ch_192', 25, 'gs://ml4cvd/deflaux/ukbb_tensors/'],\n", + " ])\n", + "\n", + "samples_df.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "samples_df.head(n = 10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Annotate the batch! " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Note: a zoom level of 1.0 displays the tensor as-is. 
For higher zoom levels, this code currently\n", + "# use the PIL library to scale the image.\n", + "\n", + "annotator = BatchImageAnnotator(samples=samples_df,\n", + " zoom=2.0,\n", + " annotation_categories=['region_of_interest'],\n", + " annotation_storage=BIG_QUERY_ANNOTATIONS_STORAGE)\n", + "annotator.annotate_images()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# View the stored annotations " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "annotator.view_recent_submissions(count=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Provenance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import datetime\n", + "print(datetime.datetime.now())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "pip3 freeze" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Questions about these particular notebooks? Reach out to Puneet Batra pbatra@broadinstitute.org, Paolo Di Achille pdiachil@broadinstitute.org, and Nicole Deflaux deflaux@verily.com." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.8" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": { + "height": "calc(100% - 180px)", + "left": "10px", + "top": "150px", + "width": "199px" + }, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/terra_featured_workspace/image_annotations_demo.ipynb b/notebooks/terra_featured_workspace/image_annotations_demo.ipynb new file mode 100644 index 000000000..ce15e4d73 --- /dev/null +++ b/notebooks/terra_featured_workspace/image_annotations_demo.ipynb @@ -0,0 +1,259 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Image annotations for a batch of samples\n", + "\n", + "Using this notebook, cardiologists are able to quickly view and annotate MRI images for a batch of samples. These annotated images become the training data for the next round of modeling." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup\n", + "\n", + "
\n", + " This notebook assumes\n", + "
    \n", + "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
  • \n", + "
  • ml4cvd is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
  • \n", + "
\n", + "
" + ] + }, + { + "attachments": { + "Screen%20Shot%202020-06-22%20at%202.50.48%20PM.png": { + "image/png": "" + } + }, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![Screen%20Shot%202020-06-22%20at%202.50.48%20PM.png](attachment:Screen%20Shot%202020-06-22%20at%202.50.48%20PM.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ml4cvd.visualization_tools.batch_image_annotations import BatchImageAnnotator\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "code_folding": [] + }, + "outputs": [], + "source": [ + "%%javascript\n", + "// Display cell outputs to full height (no vertical scroll bar)\n", + "IPython.OutputArea.auto_scroll_threshold = 9999;" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pd.set_option('display.max_colwidth', -1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Define the batch of samples to annotate\n", + "\n", + "In general, we would read in a CSV file but for this demo we define the batch right here." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Normally these would all be the same or similar TMAP. 
We are using different ones here just to make it\n", + "# more obvious in this demo that we are processing different samples.\n", + "samples_df = pd.DataFrame(\n", + " columns=BatchImageAnnotator.EXPECTED_COLUMN_NAMES,\n", + " data=[\n", + " ['fake_1', 'cine_lax_3ch_192', 25, 'gs://ml4cvd/projects/fake_hd5s/'],\n", + " ['fake_1', 't2_flair_sag_p2_1mm_fs_ellip_pf78_1', 50, 'gs://ml4cvd/projects/fake_hd5s/'],\n", + " ['fake_1', 'cine_lax_4ch_192', 25, 'gs://ml4cvd/projects/fake_hd5s/'],\n", + " ['fake_1', 't2_flair_sag_p2_1mm_fs_ellip_pf78_2', 50, 'gs://ml4cvd/projects/fake_hd5s/'],\n", + " ['fake_2', 'cine_lax_3ch_192', 25, 'gs://ml4cvd/projects/fake_hd5s/'],\n", + " ])\n", + "\n", + "samples_df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Annotate the batch! " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Note: a zoom level of 1.0 displays the tensor as-is. For higher zoom levels, this code currently\n", + "# use the PIL library to scale the image.\n", + "\n", + "annotator = BatchImageAnnotator(samples=samples_df, zoom=2.0, annotation_categories=['region_of_interest'])\n", + "annotator.annotate_images()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## via BigQuery annotation storage " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ml4cvd.visualization_tools.annotation_storage import BigQueryAnnotationStorage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "BIG_QUERY_ANNOTATIONS_STORAGE = BigQueryAnnotationStorage('uk-biobank-sek-data.ml_results.annotations')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Note: a zoom level of 1.0 displays the tensor as-is. 
For higher zoom levels, this code currently\n", + "# use the PIL library to scale the image.\n", + "\n", + "annotator = BatchImageAnnotator(samples=samples_df,\n", + " zoom=2.0,\n", + " annotation_categories=['region_of_interest'],\n", + " annotation_storage=BIG_QUERY_ANNOTATIONS_STORAGE)\n", + "annotator.annotate_images()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# View the stored annotations " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "annotator.view_recent_submissions(count=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Provenance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import datetime\n", + "print(datetime.datetime.now())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "pip3 freeze" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Questions about these particular notebooks? Reach out to Puneet Batra pbatra@broadinstitute.org, Paolo Di Achille pdiachil@broadinstitute.org, and Nicole Deflaux deflaux@verily.com." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.8" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": { + "height": "calc(100% - 180px)", + "left": "10px", + "top": "150px", + "width": "199px" + }, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/pylintrc b/pylintrc new file mode 100644 index 000000000..8a5e40122 --- /dev/null +++ b/pylintrc @@ -0,0 +1,337 @@ +# This configuration was copied from https://github.com/tensorflow/tensorflow/blob/18ebe824d2f6f20b09839cb0a0073032a2d6c5fe/tensorflow/tools/ci_build/pylintrc and then further modified. + +[MASTER] + +# Specify a configuration file. +#rcfile= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Profiled execution. +profile=no + +# Add files or directories to the denylist. They should be base names, not +# paths. +ignore=CVS + +# Pickle collected data for later comparisons. +persistent=yes + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + + +[MESSAGES CONTROL] + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time. See also the "--disable" option for examples. 
+enable=indexing-exception,old-raise-syntax + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager + + +# Set the cache size for astng objects. +cache-size=500 + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Put messages in a separate file for each module / package specified on the +# command line instead of printing them on stdout. Reports (if any) will be +# written in a file name "pylint_global.[txt|html]". +files-output=no + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). 
+evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Add a comment according to your evaluation note. This is used by the global +# evaluation report (RP0004). +comment=no + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[TYPECHECK] + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of classes names for which member attributes should not be checked +# (useful for classes with attributes dynamically set). +ignored-classes=SQLObject + +# When zope mode is activated, add a predefined set of Zope acquired attributes +# to generated-members. +zope=no + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E0201 when accessed. Python regular +# expressions are accepted. +generated-members=REQUEST,acl_users,aq_parent + +# List of decorators that create context managers from functions, such as +# contextlib.contextmanager. +contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the beginning of the name of dummy variables +# (i.e. not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + + +[BASIC] + +# Required attributes for module, separated by a comma +required-attributes= + +# List of builtins function names that should not be used, separated by a comma +bad-functions=apply,input,reduce + + +# Disable the report(s) with the given id(s). +# All non-Google reports are disabled by default. 
+disable-report=R0001,R0002,R0003,R0004,R0101,R0102,R0201,R0202,R0220,R0401,R0402,R0701,R0801,R0901,R0902,R0903,R0904,R0911,R0912,R0913,R0914,R0915,R0921,R0922,R0923 + +# Regular expression which should only match correct module names +module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Regular expression which should only match correct module level names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression which should only match correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression which should only match correct function names +function-rgx=^(?:(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression which should only match correct method names +method-rgx=^(?:(?P__[a-z0-9_]+__|next)|(?P_{0,2}[A-Z][a-zA-Z0-9]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match correct instance attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct attribute names in class +# bodies +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression which should only match correct list comprehension / +# generator expression variable names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main) + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=10 + + +[FORMAT] + +# Maximum number of characters on a single line. 
+max-line-length=120 + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=(?x) + (^\s*(import|from)\s + |\$Id:\s\/\/depot\/.+#\d+\s\$ + |^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*("[^"]\S+"|'[^']\S+') + |^\s*\#\ LINT\.ThenChange + |^[^#]*\#\ type:\ [a-zA-Z_][a-zA-Z0-9_.,[\] ]*$ + |pylint + |""" + |\# + |lambda + |(https?|ftp):) + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=y + +# List of optional constructs for which whitespace checking is disabled +no-space-check= + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes= + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + + +[CLASSES] + +# List of interface methods to ignore, separated by a comma. This is used for +# instance to not check methods defines in Zope's Interface base class. 
+ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__,__new__,setUp + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls,class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# Maximum number of arguments for function / method +max-args=5 + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore +ignored-argument-names=_.* + +# Maximum number of locals for function / method body +max-locals=15 + +# Maximum number of return / yield for function / method body +max-returns=6 + +# Maximum number of branch for function / method body +max-branches=12 + +# Maximum number of statements in function / method body +max-statements=50 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception" +overgeneral-exceptions=Exception,StandardError,BaseException + + +[AST] + +# Maximum line length for lambdas +short-func-length=1 + +# List of module members that should be marked as deprecated. +# All of the string functions are listed in 4.1.4 Deprecated string functions +# in the Python 2.4 docs. 
+deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc + + +[DOCSTRING] + +# List of exceptions that do not need to be mentioned in the Raises section of +# a docstring. +ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError + + + +[TOKENS] + +# Number of spaces of indent required when the last token on the preceding line +# is an open (, [, or {. +indent-after-paren=4 + + +[GOOGLE LINES] + +# Regexp for a proper copyright notice. +copyright=Copyright \d{4} The TensorFlow Authors\. +All [Rr]ights [Rr]eserved\. From 837d019b5090b87827c804e6184ebf5149ca6709 Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Tue, 29 Sep 2020 17:26:25 -0400 Subject: [PATCH 03/21] paired --- README.md | 32 +--- docker/terra_image/Dockerfile | 8 +- docker/terra_image/README.md | 8 +- docker/vm_boot_images/Dockerfile | 5 +- .../config/tensorflow-requirements.txt | 2 + ml4h/plots.py | 36 ++-- ml4h/tensorize/tensor_writer_ukbb.py | 181 +++++++++++------- 7 files changed, 154 insertions(+), 118 deletions(-) diff --git a/README.md b/README.md index 0335a4885..8996bfe39 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # ml4h `ml4h` is a project aimed at using machine learning to model multi-modal cardiovascular time series and imaging data. `ml4h` began as a set of tools to make it easy to work -with the UK Biobank on the Google Cloud and has since expanded to include other data sources +with the UK Biobank on Google Cloud Platform and has since expanded to include other data sources and functionality. 
@@ -9,6 +9,7 @@ Getting Started * [Setting up your local environment](#setting-up-your-local-environment) * [Setting up a remote VM](#setting-up-a-remote-vm) * Modeling/Data Sources/Tests [(`ml4h/DATA_MODELING_TESTS.md`)](ml4h/DATA_MODELING_TESTS.md) +* [Contributing Code](#contributing-code) Advanced Topics: * Tensorizing Data (going from raw data to arrays suitable for modeling, in `ml4h/tensorize/README.md, TENSORIZE.md` ) @@ -19,7 +20,7 @@ Clone the repo ``` git clone git@github.com:broadinstitute/ml.git ``` -Make sure you have installed the [google cloud tools (gcloud)](https://cloud.google.com/storage/docs/gsutil_install). With [Homebrew](https://brew.sh/), you can use +Make sure you have installed the [Google Cloud SDK (gcloud)](https://cloud.google.com/sdk/docs/downloads-interactive). With [Homebrew](https://brew.sh/), you can use ``` brew cask install google-cloud-sdk ``` @@ -145,29 +146,6 @@ If you get a public key error run: `gcloud compute config-ssh` Now open a browser on your laptop and go to the URL `http://localhost:8888` +## Contributing code -### Installing git-secrets - -```git-secrets``` helps us avoid committing secrets (e.g. private keys) and other critical data (e.g. PHI) to our -repositories. ```git-secrets``` can be obtained via [github](https://github.com/awslabs/git-secrets) or on MacOS can be -installed with Homebrew by running ```brew install git-secrets```. - -To add hooks to all repositories that you initialize or clone in the future: - -```git secrets --install --global``` - -To add hooks to all local repositories: - -``` -git secrets --install ~/.git-templates/git-secrets -git config --global init.templateDir ~/.git-templates/git-secrets -``` - -We maintain our own custom "provider" to cover any private keys or other critical data that we would like to avoid -committing to our repositories. 
Feel free to add ```egrep```-compatible regular expressions to -```git_secrets_provider_ml4h.txt``` to match types of critical data that are not currently covered by the patterns in that -file. To register the patterns in this file with ```git-secrets```: - -``` -git secrets --add-provider -- cat ${HOME}/ml/git_secrets_provider_ml4h.txt -``` +Want to contribute code to this project? Please see [CONTRIBUTING](./CONTRIBUTING.md) for developer setup and other details. diff --git a/docker/terra_image/Dockerfile b/docker/terra_image/Dockerfile index 721f94500..a59ecd6ae 100644 --- a/docker/terra_image/Dockerfile +++ b/docker/terra_image/Dockerfile @@ -1,4 +1,4 @@ -FROM us.gcr.io/broad-dsp-gcr-public/terra-jupyter-gatk:1.0.0 +FROM us.gcr.io/broad-dsp-gcr-public/terra-jupyter-gatk:1.0.6 # https://github.com/DataBiosphere/terra-docker/blob/master/terra-jupyter-gatk/CHANGELOG.md USER root @@ -19,6 +19,10 @@ RUN pip3 install --user -r $HOME/ml4h_pkg/config/tensorflow-requirements.txt \ # first few rows of the downloaded dataframe of query results. # Pin version due to https://github.com/googleapis/google-cloud-python/issues/9965 && pip3 install --upgrade --user google-cloud-bigquery[pandas]==1.22.0 \ + # Upgrade to a newer version. The one on the base Terra image was a bit too old. + && pip3 install --upgrade --user numpy \ # Configure notebook extensions. 
&& jupyter nbextension install --user --py vega \ - && jupyter nbextension enable --user --py vega + && jupyter nbextension enable --user --py vega \ + && jupyter nbextension install --user --py ipycanvas \ + && jupyter nbextension enable --user --py ipycanvas diff --git a/docker/terra_image/README.md b/docker/terra_image/README.md index 71284c0bd..9a81dc74a 100644 --- a/docker/terra_image/README.md +++ b/docker/terra_image/README.md @@ -2,13 +2,13 @@ To build and push: ``` -mv ml4cvd ml4cvdBAK_$(date +"%Y%m%d_%H%M%S") \ +mv ml4h ml4hBAK_$(date +"%Y%m%d_%H%M%S") \ && mv config configBAK_$(date +"%Y%m%d_%H%M%S") \ - && cp -r ../../ml4cvd . \ + && cp -r ../../ml4h . \ && cp -r ../vm_boot_images/config . \ && gcloud --project uk-biobank-sek-data builds submit \ --timeout 20m \ - --tag gcr.io/uk-biobank-sek-data/ml4cvd_terra:`date +"%Y%m%d_%H%M%S"` . + --tag gcr.io/uk-biobank-sek-data/ml4h_terra:`date +"%Y%m%d_%H%M%S"` . ``` Notes: @@ -20,5 +20,5 @@ available to docker. cd notebooks find . -name "*.ipynb" -type f -print0 | \ xargs -0 perl -i -pe \ - 's/gcr.io\/uk-biobank-sek-data\/ml4cvd_terra:\d{8}_\d{6}/gcr.io\/uk-biobank-sek-data\/ml4cvd_terra:20200623_145127/g' + 's/gcr.io\/uk-biobank-sek-data\/ml4h_terra:\d{8}_\d{6}/gcr.io\/uk-biobank-sek-data\/ml4h_terra:20200623_145127/g' ``` diff --git a/docker/vm_boot_images/Dockerfile b/docker/vm_boot_images/Dockerfile index a62694ca0..59e5b32be 100644 --- a/docker/vm_boot_images/Dockerfile +++ b/docker/vm_boot_images/Dockerfile @@ -34,4 +34,7 @@ RUN apt-get install python3-tk libgl1-mesa-glx libxt-dev -y # Requirements for the tensorflow project RUN pip3 install --upgrade pip RUN pip3 install -r pre_requirements.txt -RUN pip3 install -r tensorflow-requirements.txt +RUN pip3 install -r tensorflow-requirements.txt \ + # Configure notebook extensions. 
+ && jupyter nbextension install --user --py ipycanvas \ + && jupyter nbextension enable --user --py ipycanvas diff --git a/docker/vm_boot_images/config/tensorflow-requirements.txt b/docker/vm_boot_images/config/tensorflow-requirements.txt index bb6a1e777..d782967af 100644 --- a/docker/vm_boot_images/config/tensorflow-requirements.txt +++ b/docker/vm_boot_images/config/tensorflow-requirements.txt @@ -28,3 +28,5 @@ altair facets-overview plotnine vega +ipycanvas==0.4.1 +ipyannotations==0.2.0 diff --git a/ml4h/plots.py b/ml4h/plots.py index c60282803..6b5ad6a78 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -40,6 +40,9 @@ from scipy.ndimage.filters import gaussian_filter from scipy import stats +import ml4h.tensormap.ukb.ecg +import ml4h.tensormap.mgb.ecg +from ml4h.tensormap.mgb.dynamic import make_waveform_maps from ml4h.TensorMap import TensorMap from ml4h.metrics import concordance_index, coefficient_of_determination from ml4h.defines import IMAGE_EXT, JOIN_CHAR, PDF_EXT, TENSOR_EXT, ECG_REST_LEADS, ECG_REST_MEDIAN_LEADS, PARTNERS_DATETIME_FORMAT, PARTNERS_DATE_FORMAT, HD5_GROUP_CHAR @@ -1227,16 +1230,15 @@ def _plot_partners_figure( def plot_partners_ecgs(args): plot_tensors = [ - 'partners_ecg_patientid', 'partners_ecg_firstname', 'partners_ecg_lastname', - 'partners_ecg_sex', 'partners_ecg_dob', 'partners_ecg_age', - 'partners_ecg_datetime', 'partners_ecg_sitename', 'partners_ecg_location', - 'partners_ecg_read_md', 'partners_ecg_taxis_md', 'partners_ecg_rate_md', - 'partners_ecg_pr_md', 'partners_ecg_qrs_md', 'partners_ecg_qt_md', - 'partners_ecg_paxis_md', 'partners_ecg_raxis_md', 'partners_ecg_qtc_md', + ml4h.tensormap.mgb.ecg.partners_ecg_patientid, ml4h.tensormap.mgb.ecg.partners_ecg_firstname, ml4h.tensormap.mgb.ecg.partners_ecg_lastname, + ml4h.tensormap.mgb.ecg.partners_ecg_sex, ml4h.tensormap.mgb.ecg.partners_ecg_dob, ml4h.tensormap.mgb.ecg.partners_ecg_age, + ml4h.tensormap.mgb.ecg.partners_ecg_datetime, 
ml4h.tensormap.mgb.ecg.partners_ecg_sitename, ml4h.tensormap.mgb.ecg.partners_ecg_location, + ml4h.tensormap.mgb.ecg.partners_ecg_read_md, ml4h.tensormap.mgb.ecg.partners_ecg_taxis_md, ml4h.tensormap.mgb.ecg.partners_ecg_rate_md, + ml4h.tensormap.mgb.ecg.partners_ecg_pr_md, ml4h.tensormap.mgb.ecg.partners_ecg_qrs_md, ml4h.tensormap.mgb.ecg.partners_ecg_qt_md, + ml4h.tensormap.mgb.ecg.partners_ecg_paxis_md, ml4h.tensormap.mgb.ecg.partners_ecg_raxis_md, ml4h.tensormap.mgb.ecg.partners_ecg_qtc_md, ] - voltage_tensor = 'partners_ecg_2500_raw' - from ml4h.tensor_maps_partners_ecg_labels import TMAPS - tensor_maps_in = [TMAPS[it] for it in plot_tensors + [voltage_tensor]] + voltage_tensor = make_waveform_maps('partners_ecg_2500_raw') + tensor_maps_in = plot_tensors + [voltage_tensor] tensor_paths = [os.path.join(args.tensors, tp) for tp in os.listdir(args.tensors) if os.path.splitext(tp)[-1].lower()==TENSOR_EXT] if 'clinical' == args.plot_mode: @@ -1503,13 +1505,13 @@ def plot_ecg_rest( :param is_blind: if True, the plot gets blinded (helpful for review and annotation) """ map_fields_to_tmaps = { - 'ramp': 'ecg_rest_ramplitude_raw', - 'samp': 'ecg_rest_samplitude_raw', - 'aVL': 'ecg_rest_lvh_avl', - 'Sokolow_Lyon': 'ecg_rest_lvh_sokolow_lyon', - 'Cornell': 'ecg_rest_lvh_cornell', - } - from ml4h.tensor_from_file import TMAPS + 'ramp': ml4h.tensormap.ukb.ecg.ecg_rest_ramplitude_raw, + 'samp': ml4h.tensormap.ukb.ecg.ecg_rest_samplitude_raw, + 'aVL': ml4h.tensormap.ukb.ecg.ecg_rest_lvh_avl, + 'Sokolow_Lyon': ml4h.tensormap.ukb.ecg.ecg_rest_lvh_sokolow_lyon, + 'Cornell': ml4h.tensormap.ukb.ecg.ecg_rest_lvh_cornell, + } + raw_scale = 0.005 # Conversion from raw to mV default_yrange = ECG_REST_PLOT_DEFAULT_YRANGE # mV time_interval = 2.5 # time-interval per plot in seconds. 
ts_Reference data is in s, voltage measurement is 5 uv per lsb @@ -1521,7 +1523,7 @@ def plot_ecg_rest( with h5py.File(tensor_path, 'r') as hd5: traces, text = _ecg_rest_traces_and_text(hd5) for field in map_fields_to_tmaps: - tm = TMAPS[map_fields_to_tmaps[field]] + tm = map_fields_to_tmaps[field] patient_dic[field] = np.zeros(tm.shape) try: patient_dic[field][:] = tm.tensor_from_file(tm, hd5) diff --git a/ml4h/tensorize/tensor_writer_ukbb.py b/ml4h/tensorize/tensor_writer_ukbb.py index 64242bcbe..ceb4389e8 100644 --- a/ml4h/tensorize/tensor_writer_ukbb.py +++ b/ml4h/tensorize/tensor_writer_ukbb.py @@ -87,6 +87,10 @@ def write_tensors( mri_unzip: str, mri_field_ids: List[int], xml_field_ids: List[int], + zoom_x: int, + zoom_y: int, + zoom_width: int, + zoom_height: int, write_pngs: bool, min_sample_id: int, max_sample_id: int, @@ -105,6 +109,13 @@ def write_tensors( :param mri_unzip: Folder where zipped DICOM will be decompressed :param mri_field_ids: List of MRI field IDs from UKBB :param xml_field_ids: List of ECG field IDs from UKBB + :param x: Maximum x dimension of MRIs + :param y: Maximum y dimension of MRIs + :param z: Maximum z dimension of MRIs + :param zoom_x: x coordinate of the zoom + :param zoom_y: y coordinate of the zoom + :param zoom_width: width of the zoom + :param zoom_height: height of the zoom :param write_pngs: write MRIs as PNG images for debugging :param min_sample_id: Minimum sample id to generate, for parallelization :param max_sample_id: Maximum sample id to generate, for parallelization @@ -126,7 +137,7 @@ def write_tensors( continue try: with h5py.File(tp, 'w') as hd5: - _write_tensors_from_zipped_dicoms(write_pngs, tensors, mri_unzip, mri_field_ids, zip_folder, hd5, sample_id, stats) + _write_tensors_from_zipped_dicoms(zoom_x, zoom_y, zoom_width, zoom_height, write_pngs, tensors, mri_unzip, mri_field_ids, zip_folder, hd5, sample_id, stats) _write_tensors_from_zipped_niftis(zip_folder, mri_field_ids, hd5, sample_id, stats) 
_write_tensors_from_xml(xml_field_ids, xml_folder, hd5, sample_id, write_pngs, stats, continuous_stats) stats['Tensors written'] += 1 @@ -177,26 +188,19 @@ def write_tensors_from_dicom_pngs( continue stats[sample_header + '_' + sample_id] += 1 dicom_file = row[dicom_index] - try: png = imageio.imread(os.path.join(png_path, dicom_file + png_postfix)) - if len(png.shape) == 3 and png.mean() == png[:, :, 0].mean(): - png = png[:, :, 0] - elif len(png.shape) == 3: - raise ValueError(f'PNG has color information but no method to tensorize it {png.mean()}, 0ch :{png[:, :, 0].mean()}, 1ch :{png[:, :, 1].mean()}, 2ch :{png[:, :, 2].mean()}.') full_tensor = np.zeros((x, y), dtype=np.float32) full_tensor[:png.shape[0], :png.shape[1]] = png tensor_file = os.path.join(tensors, str(sample_id) + TENSOR_EXT) if not os.path.exists(os.path.dirname(tensor_file)): os.makedirs(os.path.dirname(tensor_file)) with h5py.File(tensor_file, 'a') as hd5: - tensor_name = series.lower() + '_annotated_' + row[instance_index] + tensor_name = series + '_annotated_' + row[instance_index] tp = tensor_path(path_prefix, tensor_name) if tp in hd5: tensor = first_dataset_at_path(hd5, tp) - min_x = min(png.shape[0], tensor.shape[0]) - min_y = min(png.shape[1], tensor.shape[1]) - tensor[:min_x, :min_y] = full_tensor[:min_x, :min_y] + tensor[:] = full_tensor stats['updated'] += 1 else: create_tensor_in_hd5(hd5, path_prefix, tensor_name, full_tensor, stats) @@ -324,7 +328,7 @@ def _dicts_and_plots_from_tensorization( continuous = {} value_counter = Counter() for k in sorted(list(stats.keys())): - #logging.info("{} has {}".format(k, stats[k])) + logging.info("{} has {}".format(k, stats[k])) if 'categorical' not in k and 'continuous' not in k: continue @@ -342,10 +346,10 @@ def _dicts_and_plots_from_tensorization( plot_value_counter(list(categories.keys()), value_counter, a_id + '_v_count', os.path.join(output_folder, a_id)) plot_histograms(continuous_stats, a_id, os.path.join(output_folder, a_id)) - # 
logging.info("Continuous tensor map: {}".format(continuous)) - # logging.info("Continuous Columns: {}".format(len(continuous))) - # logging.info("Category tensor map: {}".format(categories)) - # logging.info("Categories Columns: {}".format(len(categories))) + logging.info("Continuous tensor map: {}".format(continuous)) + logging.info("Continuous Columns: {}".format(len(continuous))) + logging.info("Category tensor map: {}".format(categories)) + logging.info("Categories Columns: {}".format(len(categories))) def _to_float_or_false(s): @@ -363,6 +367,10 @@ def _to_float_or_nan(s): def _write_tensors_from_zipped_dicoms( + zoom_x: int, + zoom_y: int, + zoom_width: int, + zoom_height: int, write_pngs: bool, tensors: str, dicoms: str, @@ -382,8 +390,10 @@ def _write_tensors_from_zipped_dicoms( os.makedirs(dicom_folder) with zipfile.ZipFile(zipped, "r") as zip_ref: zip_ref.extractall(dicom_folder) - ukb_instance = zipped.split('_')[2] - _write_tensors_from_dicoms(write_pngs, tensors, dicom_folder, hd5, sample_str, ukb_instance, stats) + _write_tensors_from_dicoms( + zoom_x, zoom_y, zoom_width, zoom_height, write_pngs, tensors, dicom_folder, + hd5, sample_str, stats, + ) stats['MRI fields written'] += 1 shutil.rmtree(dicom_folder) @@ -400,31 +410,36 @@ def _write_tensors_from_zipped_niftis(zip_folder: str, mri_field_ids: List[str], def _write_tensors_from_dicoms( - write_pngs: bool, tensors: str, dicom_folder: str, hd5: h5py.File, sample_str: str, ukb_instance: str, stats: Dict[str, int], + zoom_x: int, zoom_y: int, zoom_width: int, zoom_height: int, write_pngs: bool, tensors: str, + dicom_folder: str, hd5: h5py.File, sample_str: str, stats: Dict[str, int], ) -> None: """Convert a folder of DICOMs from a sample into tensors for each series Segmented dicoms require special processing and are written to tensor per-slice Arguments + :param x: Width of the tensors (actual MRI width will be padded with 0s or cropped to this number) + :param y: Height of the tensors (actual MRI 
width will be padded with 0s or cropped to this number) + :param z: Minimum number of slices to include in the each tensor if more slices are found they will be kept + :param zoom_x: x coordinate of the zoom + :param zoom_y: y coordinate of the zoom + :param zoom_width: width of the zoom + :param zoom_height: height of the zoom :param write_pngs: write MRIs as PNG images for debugging :param tensors: Folder where hd5 tensor files are being written :param dicom_folder: Folder with all dicoms associated with one sample. :param hd5: Tensor file in which to create datasets for each series and each segmented slice :param sample_str: The current sample ID as a string - :param ukb_instance: The UK Biobank assessment visit instance number :param stats: Counter to keep track of summary statistics """ views = defaultdict(list) - series_to_numbers = defaultdict(set) min_ideal_series = 9e9 for dicom in os.listdir(dicom_folder): if os.path.splitext(dicom)[-1] != DICOM_EXT: continue d = pydicom.read_file(os.path.join(dicom_folder, dicom)) series = d.SeriesDescription.lower().replace(' ', '_') - series_to_numbers[series].add(int(d.SeriesNumber)) if series + '_12bit' in MRI_LIVER_SERIES_12BIT and d.LargestImagePixelValue > 2048: views[series + '_12bit'].append(d) stats[series + '_12bit'] += 1 @@ -447,61 +462,99 @@ def _write_tensors_from_dicoms( else: mri_group = 'ukb_mri' - if len(series_to_numbers[v]) > 1 and v not in MRI_BRAIN_SERIES: - max_series = max(series_to_numbers[v]) - single_series = [dicom for dicom in views[v] if int(dicom.SeriesNumber) == max_series] - # for d in views[v]: - # logging.warning(f'{d.SeriesNumber} with Date: {_datetime_from_dicom(d)} Time {d.AcquisitionTime}') - logging.warning(f'{v} has {len(views[v])} series:{series_to_numbers[v]} Using only max series: {max_series} with {len(single_series)}') - views[v] = single_series if v == MRI_TO_SEGMENT: - _tensorize_short_and_long_axis_segmented_cardiac_mri(views[v], v, ukb_instance, hd5, mri_date, mri_group, 
stats) + _tensorize_short_and_long_axis_segmented_cardiac_mri(views[v], v, zoom_x, zoom_y, zoom_width, zoom_height, write_pngs, tensors, hd5, mri_date, mri_group, stats) elif v in MRI_BRAIN_SERIES: _tensorize_brain_mri(views[v], v, mri_date, mri_group, hd5) else: - pass - # mri_data = np.zeros((views[v][0].Rows, views[v][0].Columns, len(views[v])), dtype=np.float32) - # for slicer in views[v]: - # _save_pixel_dimensions_if_missing(slicer, v, hd5) - # _save_slice_thickness_if_missing(slicer, v, hd5) - # _save_series_orientation_and_position_if_missing(slicer, v, hd5) - # slice_index = slicer.InstanceNumber - 1 - # if v in MRI_LIVER_IDEAL_PROTOCOL: - # slice_index = _slice_index_from_ideal_protocol(slicer, min_ideal_series) - # mri_data[..., slice_index] = slicer.pixel_array.astype(np.float32) - # create_tensor_in_hd5(hd5, mri_group, f'{v}/{ukb_instance}', mri_data, stats, mri_date) + mri_data = np.zeros((views[v][0].Rows, views[v][0].Columns, len(views[v])), dtype=np.float32) + for slicer in views[v]: + _save_pixel_dimensions_if_missing(slicer, v, hd5) + _save_slice_thickness_if_missing(slicer, v, hd5) + _save_series_orientation_and_position_if_missing(slicer, v, hd5) + slice_index = slicer.InstanceNumber - 1 + if v in MRI_LIVER_IDEAL_PROTOCOL: + slice_index = _slice_index_from_ideal_protocol(slicer, min_ideal_series) + mri_data[..., slice_index] = slicer.pixel_array.astype(np.float32) + create_tensor_in_hd5(hd5, mri_group, v, mri_data, stats, mri_date) def _tensorize_short_and_long_axis_segmented_cardiac_mri( - slices: List[pydicom.Dataset], series: str, instance: str, - hd5: h5py.File, mri_date: datetime.datetime, mri_group: str, stats: Dict[str, int], + slices: List[pydicom.Dataset], series: str, zoom_x: int, zoom_y: int, + zoom_width: int, zoom_height: int, write_pngs: bool, tensors: str, + hd5: h5py.File, mri_date: datetime.datetime, mri_group: str, + stats: Dict[str, int], ) -> None: + systoles = {} + diastoles = {} + systoles_pix = {} + systoles_masks = {} + 
diastoles_masks = {} + for slicer in slices: - #full_slice = np.zeros((slicer.Rows, slicer.Columns), dtype=np.float32) + full_mask = np.zeros((slicer.Rows, slicer.Columns), dtype=np.float32) + full_slice = np.zeros((slicer.Rows, slicer.Columns), dtype=np.float32) + if _has_overlay(slicer): if _is_mitral_valve_segmentation(slicer): series = series.replace('sax', 'lax') else: series = series.replace('lax', 'sax') - series_segmented = f'{series}_segmented' + series_zoom = f'{series}_zoom' + series_zoom_segmented = f'{series}_zoom_segmented' + try: overlay, mask, ventricle_pixels, _ = _get_overlay_from_dicom(slicer) except KeyError: logging.exception(f'Got key error trying to make anatomical mask, skipping.') continue - # _save_pixel_dimensions_if_missing(slicer, series, hd5) - # _save_slice_thickness_if_missing(slicer, series, hd5) - # _save_series_orientation_and_position_if_missing(slicer, series, hd5, str(slicer.InstanceNumber)) + _save_pixel_dimensions_if_missing(slicer, series, hd5) + _save_slice_thickness_if_missing(slicer, series, hd5) + _save_series_orientation_and_position_if_missing(slicer, series, hd5, str(slicer.InstanceNumber)) _save_pixel_dimensions_if_missing(slicer, series_segmented, hd5) _save_slice_thickness_if_missing(slicer, series_segmented, hd5) _save_series_orientation_and_position_if_missing(slicer, series_segmented, hd5, str(slicer.InstanceNumber)) - # - # cur_angle = (slicer.InstanceNumber - 1) // MRI_FRAMES # dicom InstanceNumber is 1-based - #full_slice[:] = slicer.pixel_array.astype(np.float32) - #create_tensor_in_hd5(hd5, mri_group, f'{series}{HD5_GROUP_CHAR}{instance}', full_slice, stats, mri_date, slicer.InstanceNumber) - create_tensor_in_hd5(hd5, mri_group, f'{series_segmented}{HD5_GROUP_CHAR}{instance}', mask, stats, mri_date, slicer.InstanceNumber) + + cur_angle = (slicer.InstanceNumber - 1) // MRI_FRAMES # dicom InstanceNumber is 1-based + full_slice[:] = slicer.pixel_array.astype(np.float32) + create_tensor_in_hd5(hd5, mri_group, 
f'{series}{HD5_GROUP_CHAR}{slicer.InstanceNumber}', full_slice, stats, mri_date) + create_tensor_in_hd5(hd5, mri_group, f'{series_zoom_segmented}{HD5_GROUP_CHAR}{slicer.InstanceNumber}', mask, stats, mri_date) + + zoom_slice = full_slice[zoom_x: zoom_x + zoom_width, zoom_y: zoom_y + zoom_height] + zoom_mask = mask[zoom_x: zoom_x + zoom_width, zoom_y: zoom_y + zoom_height] + create_tensor_in_hd5(hd5, mri_group, f'{series_zoom}{HD5_GROUP_CHAR}{slicer.InstanceNumber}', zoom_slice, stats, mri_date) + create_tensor_in_hd5(hd5, mri_group, f'{series_zoom_segmented}{HD5_GROUP_CHAR}{slicer.InstanceNumber}', zoom_mask, stats, mri_date) + + if (slicer.InstanceNumber - 1) % MRI_FRAMES == 0: # Diastole frame is always the first + diastoles[cur_angle] = slicer + diastoles_masks[cur_angle] = mask + if cur_angle not in systoles: + systoles[cur_angle] = slicer + systoles_pix[cur_angle] = ventricle_pixels + systoles_masks[cur_angle] = mask + else: + if ventricle_pixels < systoles_pix[cur_angle]: + systoles[cur_angle] = slicer + systoles_pix[cur_angle] = ventricle_pixels + systoles_masks[cur_angle] = mask + + for angle in diastoles: + logging.info(f'Found systole, instance:{systoles[angle].InstanceNumber} ventricle pixels:{systoles_pix[angle]}') + full_slice = diastoles[angle].pixel_array.astype(np.float32) + create_tensor_in_hd5(hd5, mri_group, f'diastole_frame_b{angle}', full_slice, stats, mri_date) + create_tensor_in_hd5(hd5, mri_group, f'diastole_mask_b{angle}', diastoles_masks[angle], stats, mri_date) + if write_pngs: + plt.imsave(tensors + 'diastole_frame_b' + str(angle) + IMAGE_EXT, full_slice) + plt.imsave(tensors + 'diastole_mask_b' + str(angle) + IMAGE_EXT, full_mask) + + full_slice = systoles[angle].pixel_array.astype(np.float32) + create_tensor_in_hd5(hd5, mri_group, f'systole_frame_b{angle}', full_slice, stats, mri_date) + create_tensor_in_hd5(hd5, mri_group, f'systole_mask_b{angle}', systoles_masks[angle], stats, mri_date) + if write_pngs: + plt.imsave(tensors + 
'systole_frame_b' + str(angle) + IMAGE_EXT, full_slice) + plt.imsave(tensors + 'systole_mask_b' + str(angle) + IMAGE_EXT, full_mask) def _tensorize_brain_mri(slices: List[pydicom.Dataset], series: str, mri_date: datetime.datetime, mri_group: str, hd5: h5py.File) -> None: @@ -535,16 +588,13 @@ def _save_slice_thickness_if_missing(slicer, series, hd5): def _save_series_orientation_and_position_if_missing(slicer, series, hd5, instance=None): orientation_ds_name = MRI_PATIENT_ORIENTATION + '_' + series position_ds_name = MRI_PATIENT_POSITION + '_' + series - if instance is not None: - orientation_ds_name = f'{orientation_ds_name}_{instance}' - position_ds_name = f'{position_ds_name}_{instance}' - try: - if orientation_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT: - hd5.create_dataset(orientation_ds_name, data=[float(x) for x in slicer.ImageOrientationPatient]) - if position_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT: - hd5.create_dataset(position_ds_name, data=[float(x) for x in slicer.ImagePositionPatient]) - except RuntimeError as e: - logging.warning(f' got error {e} \n orientation : {orientation_ds_name} {slicer.ImageOrientationPatient} and pos: {position_ds_name} {slicer.ImagePositionPatient}') + if instance: + orientation_ds_name += HD5_GROUP_CHAR + instance + position_ds_name += HD5_GROUP_CHAR + instance + if orientation_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT: + hd5.create_dataset(orientation_ds_name, data=[float(x) for x in slicer.ImageOrientationPatient]) + if position_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT: + hd5.create_dataset(position_ds_name, 
data=[float(x) for x in slicer.ImagePositionPatient]) def _has_overlay(d) -> bool: @@ -695,16 +745,13 @@ def _write_ecg_rest_tensors(ecgs, xml_field, hd5, sample_id, write_pngs, stats, def create_tensor_in_hd5( hd5: h5py.File, path_prefix: str, name: str, value, stats: Counter = None, date: datetime.datetime = None, - instance: str = None, storage_type: StorageType = None, attributes: Dict[str, Any] = None, + storage_type: StorageType = None, attributes: Dict[str, Any] = None, ): hd5_path = tensor_path(path_prefix, name) - if instance is not None: - hd5_path = f'{hd5_path}instance_{instance}/' if hd5_path in hd5: hd5_path = f'{hd5_path}instance_{len(hd5[hd5_path])}' - elif instance is None: + else: hd5_path = f'{hd5_path}instance_0' - if stats is not None: stats[hd5_path] += 1 if storage_type == StorageType.STRING: From 6021abb4b65527349b8042cd52a29c25f3d0532e Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Tue, 29 Sep 2020 17:27:41 -0400 Subject: [PATCH 04/21] paired --- .../visualization_tools/annotation_storage.py | 36 ++-- ml4h/visualization_tools/annotations.py | 55 +++--- .../dicom_interactive_plots.py | 74 +++---- ml4h/visualization_tools/dicom_plots.py | 122 ++++++------ .../ecg_interactive_plots.py | 22 ++- ml4h/visualization_tools/ecg_reshape.py | 58 +++--- ml4h/visualization_tools/ecg_static_plots.py | 11 +- ml4h/visualization_tools/facets.py | 13 +- ml4h/visualization_tools/hd5_mri_plots.py | 181 +++++++++--------- 9 files changed, 311 insertions(+), 261 deletions(-) diff --git a/ml4h/visualization_tools/annotation_storage.py b/ml4h/visualization_tools/annotation_storage.py index ac8e89249..d0020b1ec 100644 --- a/ml4h/visualization_tools/annotation_storage.py +++ b/ml4h/visualization_tools/annotation_storage.py @@ -2,9 +2,11 @@ import abc import datetime -import pandas as pd +from typing import Optional, Union + from google.cloud import bigquery from google.cloud.bigquery import magics as bqmagics +import pandas as pd class 
AnnotationStorage(abc.ABC): @@ -14,12 +16,14 @@ class AnnotationStorage(abc.ABC): """ @abc.abstractmethod - def describe(self): + def describe(self) -> str: """Return a string describing how annotations are stored.""" - pass @abc.abstractmethod - def submit_annotation(self, sample_id, annotator, key, value_numeric, value_string, comment): + def submit_annotation( + self, sample_id: Union[int, str], annotator: str, key: str, + value_numeric: Optional[Union[int, float]], value_string: Optional[str], comment: str, + ) -> bool: """Add an annotation to the collection of annotations. Args: @@ -32,10 +36,9 @@ def submit_annotation(self, sample_id, annotator, key, value_numeric, value_stri Returns: Whether the submission was successful. Throws an Exception on failure. """ - pass @abc.abstractmethod - def view_recent_submissions(self, count=10): + def view_recent_submissions(self, count: int = 10) -> pd.DataFrame: """View a dataframe of up to [count] most recent submissions. Args: @@ -44,7 +47,6 @@ def view_recent_submissions(self, count=10): Returns: A dataframe of the most recent annotations. """ - pass class TransientAnnotationStorage(AnnotationStorage): @@ -56,11 +58,14 @@ class TransientAnnotationStorage(AnnotationStorage): def __init__(self): self.annotations = [] - def describe(self): + def describe(self) -> str: return '''Annotations will be stored in memory only during the duration of this demo.\n For durable storage of annotations, use BigQueryAnnotationStorage instead.''' - def submit_annotation(self, sample_id, annotator, key, value_numeric, value_string, comment): + def submit_annotation( + self, sample_id: Union[int, str], annotator: str, key: str, + value_numeric: Optional[Union[int, float]], value_string: Optional[str], comment: str, + ) -> bool: """Add this annotation to our in-memory collection of annotations. 
Args: @@ -85,7 +90,7 @@ def submit_annotation(self, sample_id, annotator, key, value_numeric, value_stri self.annotations.append(annotation) return True - def view_recent_submissions(self, count=10): + def view_recent_submissions(self, count: int = 10) -> pd.DataFrame: """View a dataframe of up to [count] most recent submissions. Args: @@ -110,14 +115,17 @@ class BigQueryAnnotationStorage(AnnotationStorage): annotations_schema.json """ - def __init__(self, table): + def __init__(self, table: str): """This table should already exist.""" self.table = table - def describe(self): + def describe(self) -> str: return f'''Annotations are stored in BigQuery table {self.table}''' - def submit_annotation(self, sample_id, annotator, key, value_numeric, value_string, comment): + def submit_annotation( + self, sample_id: Union[int, str], annotator: str, key: str, + value_numeric: Optional[Union[int, float]], value_string: Optional[str], comment: str, + ) -> bool: """Call a BigQuery INSERT statement to add a row containing annotation information. Args: @@ -150,7 +158,7 @@ def submit_annotation(self, sample_id, annotator, key, value_numeric, value_stri # Return whether the submission completed. return submission.done() - def view_recent_submissions(self, count=10): + def view_recent_submissions(self, count: int = 10) -> pd.DataFrame: """View a dataframe of up to [count] most recent submissions. This is a convenience method for use within the annotation flow. 
For full access to the underlying annotations, diff --git a/ml4h/visualization_tools/annotations.py b/ml4h/visualization_tools/annotations.py index 2400a07d8..9ca9c1b44 100644 --- a/ml4h/visualization_tools/annotations.py +++ b/ml4h/visualization_tools/annotations.py @@ -2,8 +2,11 @@ import os import socket +from typing import Any, Dict, Union + from IPython.display import display from IPython.display import HTML +import pandas as pd import ipywidgets as widgets from ml4h.visualization_tools.annotation_storage import AnnotationStorage from ml4h.visualization_tools.annotation_storage import TransientAnnotationStorage @@ -11,14 +14,18 @@ DEFAULT_ANNOTATION_STORAGE = TransientAnnotationStorage() -def _get_df_sample(sample_info, sample_id): +def _get_df_sample(sample_info: pd.DataFrame, sample_id: Union[int, str]) -> pd.DataFrame: """Return a dataframe containing only the row for the indicated sample_id.""" df_sample = sample_info[sample_info['sample_id'] == str(sample_id)] - if 0 == df_sample.shape[0]: df_sample = sample_info.query('sample_id == ' + str(sample_id)) + if df_sample.shape[0] == 0: df_sample = sample_info.query('sample_id == ' + str(sample_id)) return df_sample -def display_annotation_collector(sample_info, sample_id, annotation_storage: AnnotationStorage = DEFAULT_ANNOTATION_STORAGE, custom_annotation_key=None): +def display_annotation_collector( + sample_info: pd.DataFrame, sample_id: Union[int, str], + annotation_storage: AnnotationStorage = DEFAULT_ANNOTATION_STORAGE, + custom_annotation_key: str = None, +) -> None: """Method to create a gui (set of widgets) through which the user can create an annotation and submit it to storage. Args: @@ -26,15 +33,16 @@ def display_annotation_collector(sample_info, sample_id, annotation_storage: Ann sample_id: The selected sample for which the values will be displayed. annotation_storage: An instance of AnnotationStorage. custom_annotation_key: The key for an annotation of data other than the tabular fields. 
- - Returns: - A notebook-friendly messages indicating the status of the submission. """ df_sample = _get_df_sample(sample_info, sample_id) if df_sample.shape[0] == 0: - return HTML(f'''
- Warning: Sample {sample_id} not present in sample_info DataFrame.
''') + display( + HTML(f'''
+ Warning: Sample {sample_id} not present in sample_info DataFrame. +
'''), + ) + return # Show the sample ID for this annotation. sample = widgets.HTML(value=f'For sample {sample_id}') @@ -82,7 +90,7 @@ def handle_key_change(change): submit_button = widgets.Button(description='Submit annotation', button_style='success') output = widgets.Output() - def on_button_clicked(b): + def cb_on_button_clicked(b): params = _format_annotation(sample_id=sample_id, key=key.value, keyvalue=keyvalue.value, comment=comment.value) try: success = annotation_storage.submit_annotation( @@ -93,34 +101,38 @@ def on_button_clicked(b): value_string=params['value_string'], comment=params['comment'], ) - except Exception as e: + except Exception as e: # pylint: disable=broad-except display( HTML(f'''
- Warning: Unable to store annotation. -

{e}

-
'''), + Warning: Unable to store annotation. +

{e}

+
'''), ) - return() + return with output: if success: # Show the information that was submitted. display( HTML(f'''
- Submission successful\n[{annotation_storage.describe()}]
'''), + Submission successful\n[{annotation_storage.describe()}] + '''), ) display(annotation_storage.view_recent_submissions(1)) else: display( HTML('''
- Annotation not submitted. Please try again.
'''), + Annotation not submitted. Please try again. + '''), ) - submit_button.on_click(on_button_clicked) + submit_button.on_click(cb_on_button_clicked) # Display all the widgets. display(sample, box1, comment, submit_button, output) -def _format_annotation(sample_id, key, keyvalue, comment): +def _format_annotation( + sample_id: Union[int, str], key: str, keyvalue: Union[int, float, str], comment: str, +) -> Dict[str, Any]: """Helper method to clean and reshape info from the widgets and the environment into a dictionary representing the annotation.""" # Programmatically get the identity of the person running this Terra notebook. current_user = os.getenv('OWNER_EMAIL') @@ -128,11 +140,10 @@ def _format_annotation(sample_id, key, keyvalue, comment): if current_user is None: current_user = socket.gethostname() # By convention, we prefix the hostname with our username. + value_numeric = None + value_string = None # Check whether the value is string or numeric. - if keyvalue is None: - value_numeric = None - value_string = None - else: + if keyvalue is not None: try: value_numeric = float(keyvalue) # this will fail if the value is text value_string = None diff --git a/ml4h/visualization_tools/dicom_interactive_plots.py b/ml4h/visualization_tools/dicom_interactive_plots.py index d9850e841..ec9d63834 100644 --- a/ml4h/visualization_tools/dicom_interactive_plots.py +++ b/ml4h/visualization_tools/dicom_interactive_plots.py @@ -1,4 +1,4 @@ -"""Methods for integration of interactive dicom plots within notebooks. +"""Methods for integration of interactive DICOM plots within notebooks. 
TODO: * Continue to *pragmatically* improve this to make the visualization controls @@ -8,14 +8,15 @@ import collections import os import tempfile +from typing import Any, DefaultDict, Dict, Optional, Tuple import zipfile from IPython.display import display from IPython.display import HTML +import numpy as np import ipywidgets as widgets import matplotlib.pyplot as plt from ml4h.runtime_data_defines import get_mri_folders -import numpy as np import pydicom import tensorflow as tf @@ -27,15 +28,12 @@ MAX_COLOR_RANGE = 6000 -def choose_mri(sample_id, folder=None): +def choose_mri(sample_id, folder: Optional[str] = None) -> None: """Render widget to choose the MRI to plot. Args: sample_id: The id of the sample to retrieve. folder: The local or Cloud Storage folder under which the files reside. - - Returns: - ipywidget or HTML upon error. """ if folder is None: folders = get_mri_folders(sample_id) @@ -45,22 +43,26 @@ def choose_mri(sample_id, folder=None): sample_mris = [] sample_mri_glob = str(sample_id) + '_*.zip' try: - for folder in folders: - sample_mris.extend(tf.io.gfile.glob(pattern=os.path.join(folder, sample_mri_glob))) + for f in folders: + sample_mris.extend(tf.io.gfile.glob(pattern=os.path.join(f, sample_mri_glob))) except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e: - return HTML(f''' -
+ display( + HTML(f'''
Warning: MRI not available for sample {sample_id} in {folders}:

{e.message}

Use the folder parameter to read DICOMs from a different local directory or Cloud Storage bucket. -
''') +
'''), + ) + return if not sample_mris: - return HTML(f''' -
+ display( + HTML(f'''
Warning: MRI DICOMs not available for sample {sample_id} in {folders}.
Use the folder parameter to read DICOMs from a different local directory or Cloud Storage bucket. -
''') +
'''), + ) + return mri_chooser = widgets.Dropdown( options=sample_mris, @@ -77,14 +79,11 @@ def choose_mri(sample_id, folder=None): display(file_controls_ui, file_controls_output) -def choose_mri_series(sample_mri): +def choose_mri_series(sample_mri: str) -> None: """Render widgets and interactive plots for MRIs. Args: sample_mri: The local or Cloud Storage path to the MRI file. - - Returns: - ipywidget or HTML upon error. """ with tempfile.TemporaryDirectory() as tmpdirname: local_path = os.path.join(tmpdirname, os.path.basename(sample_mri)) @@ -93,13 +92,15 @@ def choose_mri_series(sample_mri): with zipfile.ZipFile(local_path, 'r') as zip_ref: zip_ref.extractall(tmpdirname) except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e: - return HTML(f''' -
+ display( + HTML(f'''
Warning: Cardiac MRI not available for sample {os.path.basename(sample_mri)}:

{e.message}

-
''') +
'''), + ) + return - unordered_dicoms = collections.defaultdict(dict) + unordered_dicoms: DefaultDict[Any, Any] = collections.defaultdict(dict) for dcm_file in os.listdir(tmpdirname): if not dcm_file.endswith('.dcm'): continue @@ -112,8 +113,13 @@ def choose_mri_series(sample_mri): unordered_dicoms[key1][key2] = dcm if not unordered_dicoms: - print(f'\n\nNo series available in MRI for sample {os.path.basename(sample_mri)}\n\nTry a different MRI.') - return None + display( + HTML(f'''
+ No series available in MRI for sample {os.path.basename(sample_mri)}. + Try a different MRI. +
'''), + ) + return # Convert from dict of dicts to dict of ordered lists. dicoms = {} @@ -134,7 +140,7 @@ def choose_mri_series(sample_mri): style={'description_width': 'initial'}, layout=widgets.Layout(width='800px'), ) - # Slide through dicom image instances using a slide bar. + # Slide through DICOM image instances using a slide bar. instance_chooser = widgets.IntSlider( continuous_update=True, value=default_instance_value, @@ -212,25 +218,25 @@ def on_value_change(change): display(viz_controls_ui, viz_controls_output) -def compute_color_range(dicoms, series_name): +def compute_color_range(dicoms: Dict[str, Any], series_name: str) -> Tuple[int, int]: """Compute the mean values for the color ranges of instances in the series.""" vmin = np.mean([np.min(d.pixel_array) for d in dicoms[series_name]]) vmax = np.mean([np.max(d.pixel_array) for d in dicoms[series_name]]) - return(vmin, vmax) + return (vmin, vmax) -def compute_instance_range(dicoms, series_name): +def compute_instance_range(dicoms: Dict[str, Any], series_name: str) -> Tuple[int, int]: """Compute middle and max instances.""" middle_instance = int(len(dicoms[series_name]) / 2) max_instance = len(dicoms[series_name]) - return(middle_instance, max_instance) + return (middle_instance, max_instance) def dicom_animation( - dicoms, series_name, instance, vmin, vmax, transpose, - fig_width, title_prefix='', -): - """Render one frame of a dicom animation. + dicoms: Dict[str, Any], series_name: str, instance: int, vmin: int, vmax: int, transpose: bool, + fig_width: int, title_prefix: str = '', +) -> None: + """Render one frame of a DICOM animation. Args: dicoms: the dictionary DICOM series and instances lists @@ -250,7 +256,7 @@ def dicom_animation( dcm = dicoms[series_name][instance - 1] if instance != dcm.InstanceNumber: # Notice invalid input, but don't throw an error. 
- print(f'WARNING: Instance parameter {str(instance)} and dicom instance number {str(dcm.InstanceNumber)} do not match.') + print(f'WARNING: Instance parameter {str(instance)} and instance number {str(dcm.InstanceNumber)} do not match.') if transpose: height = dcm.pixel_array.T.shape[0] diff --git a/ml4h/visualization_tools/dicom_plots.py b/ml4h/visualization_tools/dicom_plots.py index ce2b3e083..093691382 100644 --- a/ml4h/visualization_tools/dicom_plots.py +++ b/ml4h/visualization_tools/dicom_plots.py @@ -1,16 +1,17 @@ -"""Methods for integration of dicom plots within notebooks.""" +"""Methods for integration of DICOM plots within notebooks.""" import collections import os import tempfile +from typing import Dict, List, Optional, Tuple, Union import zipfile from IPython.display import display from IPython.display import HTML +import numpy as np import ipywidgets as widgets import matplotlib.pyplot as plt from ml4h.runtime_data_defines import get_cardiac_mri_folder -import numpy as np import pydicom from scipy.ndimage.morphology import binary_closing from scipy.ndimage.morphology import binary_erosion @@ -27,21 +28,21 @@ MRI_SEGMENTED_CHANNEL_MAP = {'background': 0, 'ventricle': 1, 'myocardium': 2} -def _is_mitral_valve_segmentation(d): # -> bool: - """Determine whether a dicom has mitral valve segmentation. +def _is_mitral_valve_segmentation(d: pydicom.FileDataset) -> bool: + """Determine whether a DICOM has mitral valve segmentation. This is used for visualization of CINE_segmented_SAX_InlineVF. Args: - d: the dicom file + d: the DICOM file Returns: - Whether or not the dicom has mitral valve segmentation + Whether or not the DICOM has mitral valve segmentation """ return d.SliceThickness == 6 -def _get_overlay_from_dicom(d): +def _get_overlay_from_dicom(d: pydicom.FileDataset) -> Tuple[int, int, int]: """Get an overlay from a DICOM file. 
Morphological operators are used to transform the pixel outline of the @@ -49,7 +50,7 @@ def _get_overlay_from_dicom(d): is used for visualization of CINE_segmented_SAX_InlineVF. Args: - d: the dicom file + d: the DICOM file Returns: Raw overlay array with myocardium outline, anatomical mask (a pixel @@ -77,29 +78,30 @@ def _get_overlay_from_dicom(d): byte >>= 1 bit += 1 overlay = overlay[:expected_bit_length] - if overlay_frames == 1: - overlay = overlay.reshape(rows, cols) - idx = np.where(overlay == 1) - min_pos = (np.min(idx[0]), np.min(idx[1])) - max_pos = (np.max(idx[0]), np.max(idx[1])) - short_side = min((max_pos[0] - min_pos[0]), (max_pos[1] - min_pos[1])) - small_radius = max(MRI_MIN_RADIUS, short_side * MRI_SMALL_RADIUS_FACTOR) - big_radius = max(MRI_MIN_RADIUS+1, short_side * MRI_BIG_RADIUS_FACTOR) - small_structure = _unit_disk(small_radius) - m1 = binary_closing(overlay, small_structure).astype(np.int) - big_structure = _unit_disk(big_radius) - m2 = binary_closing(overlay, big_structure).astype(np.int) - anatomical_mask = m1 + m2 + if overlay_frames != 1: + raise ValueError(f'DICOM has {overlay_frames} overlay frames, but only one expected.') + overlay = overlay.reshape(rows, cols) + idx = np.where(overlay == 1) + min_pos = (np.min(idx[0]), np.min(idx[1])) + max_pos = (np.max(idx[0]), np.max(idx[1])) + short_side = min((max_pos[0] - min_pos[0]), (max_pos[1] - min_pos[1])) + small_radius = max(MRI_MIN_RADIUS, short_side * MRI_SMALL_RADIUS_FACTOR) + big_radius = max(MRI_MIN_RADIUS+1, short_side * MRI_BIG_RADIUS_FACTOR) + small_structure = _unit_disk(small_radius) + m1 = binary_closing(overlay, small_structure).astype(np.int) + big_structure = _unit_disk(big_radius) + m2 = binary_closing(overlay, big_structure).astype(np.int) + anatomical_mask = m1 + m2 + ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle']) + myocardium_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['myocardium']) + if 
ventricle_pixels == 0 and myocardium_pixels > MRI_MAX_MYOCARDIUM: + erode_structure = _unit_disk(small_radius*1.5) + anatomical_mask = anatomical_mask - binary_erosion(m1, erode_structure).astype(np.int) ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle']) - myocardium_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['myocardium']) - if ventricle_pixels == 0 and myocardium_pixels > MRI_MAX_MYOCARDIUM: - erode_structure = _unit_disk(small_radius*1.5) - anatomical_mask = anatomical_mask - binary_erosion(m1, erode_structure).astype(np.int) - ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle']) - return overlay, anatomical_mask, ventricle_pixels + return overlay, anatomical_mask, ventricle_pixels -def _unit_disk(r): # -> np.ndarray: +def _unit_disk(r: int) -> np.ndarray: """Get the unit disk for a radius. This is used for visualization of CINE_segmented_SAX_InlineVF. @@ -114,7 +116,9 @@ def _unit_disk(r): # -> np.ndarray: return (x ** 2 + y ** 2 <= r ** 2).astype(np.int) -def plot_cardiac_long_axis(b_series, sides=7, fig_width=18, title_prefix=''): +def plot_cardiac_long_axis( + b_series: List[pydicom.FileDataset], sides: int = 7, fig_width: int = 18, title_prefix: str = '', +) -> None: """Visualize CINE_segmented_SAX_InlineVF series. Args: @@ -168,9 +172,9 @@ def plot_cardiac_long_axis(b_series, sides=7, fig_width=18, title_prefix=''): def plot_cardiac_short_axis( - series, transpose=False, fig_width=18, - title_prefix='', -): + series: List[pydicom.FileDataset], transpose: bool = False, fig_width: int = 18, + title_prefix: str = '', +) -> None: """Visualize CINE_segmented_LAX series. 
Args: @@ -225,14 +229,14 @@ def plot_cardiac_short_axis( def plot_mri_series( - sample_mri, dicoms, series_name, sax_sides, - lax_transpose, fig_width, -): + sample_mri: str, dicoms: Dict[str, pydicom.FileDataset], series_name: str, sax_sides: int, + lax_transpose: bool, fig_width: int, +) -> None: """Visualize the applicable series within this DICOM. Args: sample_mri: The local or Cloud Storage path to the MRI file. - dicoms: A dictionary of dicoms. + dicoms: A dictionary of DICOMs. series_name: The name of the chosen series. sax_sides: How many sides to display for CINE_segmented_SAX_InlineVF. lax_transpose: Whether to transpose when plotting CINE_segmented_LAX. @@ -258,10 +262,9 @@ def plot_mri_series( ) else: print(f'Visualization not currently implemented for {series_name}.') - return None -def choose_mri_series(sample_mri): +def choose_mri_series(sample_mri: str) -> None: """Render widgets and plots for cardiac MRIs. Visualization is supported for CINE_segmented_SAX_InlineVF series and @@ -269,9 +272,6 @@ def choose_mri_series(sample_mri): Args: sample_mri: The local or Cloud Storage path to the MRI file. - - Returns: - ipywidget or HTML upon error. """ with tempfile.TemporaryDirectory() as tmpdirname: local_path = os.path.join(tmpdirname, os.path.basename(sample_mri)) @@ -280,11 +280,13 @@ def choose_mri_series(sample_mri): with zipfile.ZipFile(local_path, 'r') as zip_ref: zip_ref.extractall(tmpdirname) except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e: - return HTML(f''' -
+ display( + HTML(f'''
Warning: Cardiac MRI not available for sample {os.path.basename(sample_mri)}:

{e.message}

-
''') +
'''), + ) + return filtered_dicoms = collections.defaultdict(list) series_descriptions = [] @@ -295,7 +297,7 @@ def choose_mri_series(sample_mri): series_descriptions.append(dcm.SeriesDescription) if 'cine_segmented_lax' in dcm.SeriesDescription.lower(): filtered_dicoms[dcm.SeriesDescription.lower()].append(dcm) - if 'cine_segmented_sax_inlinevf' == dcm.SeriesDescription.lower(): + if dcm.SeriesDescription.lower() == 'cine_segmented_sax_inlinevf': cur_angle = (dcm.InstanceNumber - 1) // MRI_FRAMES filtered_dicoms[f'{dcm.SeriesDescription.lower()}_angle_{str(cur_angle)}'].append(dcm) @@ -350,22 +352,20 @@ def choose_mri_series(sample_mri): ) display(viz_controls_ui, viz_controls_output) else: - print( - f'\n\nNeither CINE_segmented_SAX_InlineVF nor CINE_segmented_LAX available in MRI for sample {os.path.basename(sample_mri)}.', - '\n\nTry a different MRI.', + display( + HTML(f'''
+ Neither CINE_segmented_SAX_InlineVF nor CINE_segmented_LAX available in MRI for sample {os.path.basename(sample_mri)}. + Try a different MRI. +
'''), ) - return None -def choose_cardiac_mri(sample_id, folder=None): +def choose_cardiac_mri(sample_id: Union[int, str], folder: Optional[str] = None) -> None: """Render widget to choose the cardiac MRI to plot. Args: sample_id: The id of the ECG sample to retrieve. folder: The local or Cloud Storage folder under which the files reside. - - Returns: - ipywidget or HTML upon error. """ if folder is None: folder = get_cardiac_mri_folder(sample_id) @@ -374,19 +374,23 @@ def choose_cardiac_mri(sample_id, folder=None): try: sample_mris = tf.io.gfile.glob(pattern=os.path.join(folder, sample_mri_glob)) except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e: - return HTML(f''' -
+ display( + HTML(f'''
Warning: Cardiac MRI not available for sample {sample_id} in {folder}:

{e.message}

Use the folder parameter to read DICOMs from a different local directory or Cloud Storage bucket. -
''') +
'''), + ) + return if not sample_mris: - return HTML(f''' -
+ display( + HTML(f'''
Warning: Cardiac MRI DICOM not available for sample {sample_id} in {folder}.
Use the folder parameter to read DICOMs from a different local directory or Cloud Storage bucket. -
''') +
'''), + ) + return mri_chooser = widgets.Dropdown( options=[(os.path.basename(mri), mri) for mri in sample_mris], diff --git a/ml4h/visualization_tools/ecg_interactive_plots.py b/ml4h/visualization_tools/ecg_interactive_plots.py index 97a4e1547..18ed39a9b 100644 --- a/ml4h/visualization_tools/ecg_interactive_plots.py +++ b/ml4h/visualization_tools/ecg_interactive_plots.py @@ -2,10 +2,12 @@ import os import tempfile +from typing import Optional, Union -import altair as alt # Interactive data visualization for plots. from IPython.display import HTML -from ml4h.visualization_tools.ecg_reshape import DEFAULT_RESTING_ECG_SIGNAL_TMAP_NAME +import altair as alt # Interactive data visualization for plots. +from ml4h.TensorMap import TensorMap +from ml4h.visualization_tools.ecg_reshape import DEFAULT_RESTING_ECG_SIGNAL_TMAP from ml4h.visualization_tools.ecg_reshape import reshape_exercise_ecg_to_tidy from ml4h.visualization_tools.ecg_reshape import reshape_resting_ecg_to_tidy @@ -31,18 +33,21 @@ ) -def resting_ecg_interactive_plot(sample_id, folder=None, tmap_name=DEFAULT_RESTING_ECG_SIGNAL_TMAP_NAME): +def resting_ecg_interactive_plot( + sample_id: Union[int, str], folder: Optional[str] = None, + tmap: TensorMap = DEFAULT_RESTING_ECG_SIGNAL_TMAP, +) -> Union[HTML, alt.Chart]: """Wrangle resting ECG data to tidy and present it as an interactive plot. Args: sample_id: The id of the ECG sample to retrieve. folder: The local or Cloud Storage folder under which the files reside. - tmap_name: The name of the TMAP to use for ecg input. + tmap: The TensorMap to use for ECG input. Returns: An Altair plot or a notebook-friendly error. """ - tidy_resting_ecg_signal = reshape_resting_ecg_to_tidy(sample_id, folder, tmap_name) + tidy_resting_ecg_signal = reshape_resting_ecg_to_tidy(sample_id, folder, tmap) if tidy_resting_ecg_signal.shape[0] == 0: return HTML(f'''
@@ -85,7 +90,9 @@ def resting_ecg_interactive_plot(sample_id, folder=None, tmap_name=DEFAULT_RESTI return upper & lower -def exercise_ecg_interactive_plot(sample_id, folder=None, time_interval_seconds=10): +def exercise_ecg_interactive_plot( + sample_id: Union[int, str], folder: Optional[str] = None, time_interval_seconds: int = 10, +) -> Union[HTML, alt.Chart]: """Wrangle exercise ECG data to tidy and present it as an interactive plot. Args: @@ -140,7 +147,8 @@ def exercise_ecg_interactive_plot(sample_id, folder=None, time_interval_seconds= lead_select, ).transform_filter( # https://github.com/altair-viz/altair/issues/1960 - f'((toNumber({brush.name}.time) - {time_interval_seconds/2.0}) < datum.time) && (datum.time < toNumber({brush.name}.time) + {time_interval_seconds/2.0})', + f'''((toNumber({brush.name}.time) - {time_interval_seconds/2.0}) < datum.time) + && (datum.time < toNumber({brush.name}.time) + {time_interval_seconds/2.0})''', ) return trend.encode(y='heartrate:Q') & trend.encode(y='load:Q') & signal diff --git a/ml4h/visualization_tools/ecg_reshape.py b/ml4h/visualization_tools/ecg_reshape.py index b3213d359..167eb5012 100644 --- a/ml4h/visualization_tools/ecg_reshape.py +++ b/ml4h/visualization_tools/ecg_reshape.py @@ -1,53 +1,57 @@ """Methods for reshaping raw ECG signal data for use in the pandas ecosystem.""" import os import tempfile +from typing import Any, Dict, Optional, Tuple, Union +import numpy as np +import pandas as pd from biosppy.signals.tools import filter_signal import h5py from ml4h.defines import ECG_BIKE_LEADS from ml4h.defines import ECG_REST_LEADS from ml4h.runtime_data_defines import get_exercise_ecg_hd5_folder from ml4h.runtime_data_defines import get_resting_ecg_hd5_folder -from ml4h.tensor_maps_by_hand import TMAPS -import numpy as np -import pandas as pd +from ml4h.TensorMap import TensorMap +import ml4h.tensormap.ukb.ecg as ecg_tmaps import tensorflow as tf RAW_SCALE = 0.005 # Convert to mV. 
SAMPLING_RATE = 500.0 -DEFAULT_RESTING_ECG_SIGNAL_TMAP_NAME = 'ecg_rest' +DEFAULT_RESTING_ECG_SIGNAL_TMAP = ecg_tmaps.ecg_rest # TODO(deflaux): parameterize exercise ECG by TMAP name if there is similar ECG data from other studies. -EXERCISE_ECG_SIGNAL_TMAP = TMAPS['ecg-bike-raw-full'] +EXERCISE_ECG_SIGNAL_TMAP = ecg_tmaps.ecg_bike_raw_full EXERCISE_ECG_TREND_TMAPS = [ - TMAPS['ecg-bike-raw-trend-hr'], - TMAPS['ecg-bike-raw-trend-load'], - TMAPS['ecg-bike-raw-trend-grade'], - TMAPS['ecg-bike-raw-trend-artifact'], - TMAPS['ecg-bike-raw-trend-mets'], - TMAPS['ecg-bike-raw-trend-pacecount'], - TMAPS['ecg-bike-raw-trend-phasename'], - TMAPS['ecg-bike-raw-trend-phasetime'], - TMAPS['ecg-bike-raw-trend-time'], - TMAPS['ecg-bike-raw-trend-vecount'], + ecg_tmaps.ecg_bike_raw_trend_hr, + ecg_tmaps.ecg_bike_raw_trend_load, + ecg_tmaps.ecg_bike_raw_trend_grade, + ecg_tmaps.ecg_bike_raw_trend_artifact, + ecg_tmaps.ecg_bike_raw_trend_mets, + ecg_tmaps.ecg_bike_raw_trend_pacecount, + ecg_tmaps.ecg_bike_raw_trend_phasename, + ecg_tmaps.ecg_bike_raw_trend_phasetime, + ecg_tmaps.ecg_bike_raw_trend_time, + ecg_tmaps.ecg_bike_raw_trend_vecount, ] EXERCISE_PHASES = {0.0: 'Pretest', 1.0: 'Exercise', 2.0: 'Recovery'} -def _examine_available_keys(hd5): +def _examine_available_keys(hd5: Dict[str, Any]) -> None: print(f'hd5 ECG keys {[k for k in hd5.keys() if "ecg" in k]}') for key in [k for k in hd5.keys() if 'ecg' in k]: - print(f'hd5 {key} keys {[k for k in hd5[key].keys()]}') + print(f'hd5 {key} keys {k for k in hd5[key]}') -def reshape_resting_ecg_to_tidy(sample_id, folder=None, tmap_name=DEFAULT_RESTING_ECG_SIGNAL_TMAP_NAME): +def reshape_resting_ecg_to_tidy( + sample_id: Union[int, str], folder: Optional[str] = None, tmap: TensorMap = DEFAULT_RESTING_ECG_SIGNAL_TMAP, +) -> pd.DataFrame: """Wrangle resting ECG data to tidy. Args: sample_id: The id of the ECG sample to retrieve. folder: The local or Cloud Storage folder under which the files reside. 
- tmap_name: The name of the TMAP to use for ecg input. + tmap: The TensorMap to use for ECG input. Returns: A pandas dataframe in tidy format or print a notebook-friendly error and return an empty dataframe. @@ -55,7 +59,7 @@ def reshape_resting_ecg_to_tidy(sample_id, folder=None, tmap_name=DEFAULT_RESTIN if folder is None: folder = get_resting_ecg_hd5_folder(sample_id) - data = {'lead': [], 'raw': [], 'ts_reference': [], 'filtered': [], 'filtered_1': [], 'filtered_2': []} + data: Dict[str, Any] = {'lead': [], 'raw': [], 'ts_reference': [], 'filtered': [], 'filtered_1': [], 'filtered_2': []} with tempfile.TemporaryDirectory() as tmpdirname: sample_hd5 = str(sample_id) + '.hd5' @@ -69,10 +73,10 @@ def reshape_resting_ecg_to_tidy(sample_id, folder=None, tmap_name=DEFAULT_RESTIN with h5py.File(local_path, mode='r') as hd5: try: - signals = TMAPS[tmap_name].tensor_from_file(TMAPS[tmap_name], hd5) + signals = tmap.tensor_from_file(tmap, hd5) except (KeyError, ValueError) as e: - print(f'''Warning: Resting ECG TMAP {tmap_name} not available for sample {sample_id}. - Use the tmap_name parameter to choose a different TMAP.\n\n{e}''') + print(f'''Warning: Resting ECG TMAP {tmap.name} not available for sample {sample_id}. + Use the tmap parameter to choose a different TMAP.\n\n{e}''') _examine_available_keys(hd5) return pd.DataFrame(data) for (lead, channel) in ECG_REST_LEADS.items(): @@ -136,7 +140,9 @@ def reshape_resting_ecg_to_tidy(sample_id, folder=None, tmap_name=DEFAULT_RESTIN return tidy_signal_df -def reshape_exercise_ecg_to_tidy(sample_id, folder=None): +def reshape_exercise_ecg_to_tidy( + sample_id: Union[int, str], folder: Optional[str] = None, +) -> Tuple[pd.DataFrame, pd.DataFrame]: """Wrangle exercise ECG signal data to tidy format. 
Args: @@ -208,7 +214,9 @@ def reshape_exercise_ecg_to_tidy(sample_id, folder=None): return (trend_df, tidy_signal_df) -def reshape_exercise_ecg_and_trend_to_tidy(sample_id, folder=None): +def reshape_exercise_ecg_and_trend_to_tidy( + sample_id: Union[int, str], folder: Optional[str] = None, +) -> Tuple[pd.DataFrame, pd.DataFrame]: """Wrangle exercise ECG signal and trend data to tidy format. Args: diff --git a/ml4h/visualization_tools/ecg_static_plots.py b/ml4h/visualization_tools/ecg_static_plots.py index 2ebcfc3e1..ac7283237 100644 --- a/ml4h/visualization_tools/ecg_static_plots.py +++ b/ml4h/visualization_tools/ecg_static_plots.py @@ -1,17 +1,18 @@ """Methods for integration of static plots within notebooks.""" import os import tempfile +from typing import List, Optional, Union from IPython.display import HTML from IPython.display import SVG +import numpy as np from ml4h.plots import plot_ecg_rest from ml4h.runtime_data_defines import get_resting_ecg_hd5_folder from ml4h.runtime_data_defines import get_resting_ecg_svg_folder -import numpy as np import tensorflow as tf -def display_resting_ecg(sample_id, folder=None): +def display_resting_ecg(sample_id: Union[int, str], folder: Optional[str] = None) -> Union[HTML, SVG]: """Retrieve (or render) and display the SVG of the resting ECG. Args: @@ -53,8 +54,8 @@ def display_resting_ecg(sample_id, folder=None): try: # We don't need the resulting SVG, so send it to a temporary directory. with tempfile.TemporaryDirectory() as tmpdirname: - plot_ecg_rest(tensor_paths = [local_path], rows=[0], out_folder=tmpdirname, is_blind=False) - except Exception as e: + return plot_ecg_rest(tensor_paths=[local_path], rows=[0], out_folder=tmpdirname, is_blind=False) + except Exception as e: # pylint: disable=broad-except return HTML(f'''
Warning: Unable to render static plot of resting ECG for sample {sample_id} from {hd5_folder}: @@ -62,7 +63,7 @@ def display_resting_ecg(sample_id, folder=None):
''') -def major_breaks_x_resting_ecg(limits): +def major_breaks_x_resting_ecg(limits: List[float]) -> np.array: """Method to compute breaks for plotnine plots of ECG resting data. Args: diff --git a/ml4h/visualization_tools/facets.py b/ml4h/visualization_tools/facets.py index a45ea88da..18f96327d 100644 --- a/ml4h/visualization_tools/facets.py +++ b/ml4h/visualization_tools/facets.py @@ -2,6 +2,7 @@ import base64 import os +import pandas as pd from facets_overview.generic_feature_statistics_generator import GenericFeatureStatisticsGenerator FACETS_DEPENDENCIES = { @@ -25,10 +26,10 @@ FACETS_DEPENDENCIES[dep] = os.path.basename(url) -class FacetsOverview(object): +class FacetsOverview(): """Methods for Facets Overview notebook integration.""" - def __init__(self, data): + def __init__(self, data: pd.DataFrame): # This takes the dataframe and computes all the inputs to the Facets # Overview plots such as: # - numeric variables: histogram bins, mean, min, median, max, etc.. @@ -39,7 +40,7 @@ def __init__(self, data): [{'name': 'data', 'table': data}], ) - def _repr_html_(self): + def _repr_html_(self) -> str: """Html representation of Facets Overview for use in a Jupyter notebook.""" protostr = base64.b64encode(self._proto.SerializeToString()).decode('utf-8') html_template = ''' @@ -57,14 +58,14 @@ def _repr_html_(self): return html -class FacetsDive(object): +class FacetsDive(): """Methods for Facets Dive notebook integration.""" - def __init__(self, data, height=1000): + def __init__(self, data: pd.DataFrame, height: int = 1000): self._data = data self.height = height - def _repr_html_(self): + def _repr_html_(self) -> str: """Html representation of Facets Dive for use in a Jupyter notebook.""" html_template = """ diff --git a/ml4h/visualization_tools/hd5_mri_plots.py b/ml4h/visualization_tools/hd5_mri_plots.py index 20b3305b1..d3894b39d 100644 --- a/ml4h/visualization_tools/hd5_mri_plots.py +++ b/ml4h/visualization_tools/hd5_mri_plots.py @@ -1,29 +1,34 @@ 
"""Methods for integration of plots of mri data processed to 3D tensors from within notebooks.""" +from collections import OrderedDict from enum import Enum, auto import os import tempfile +from typing import Any, Dict, List, Optional, Tuple, Union -import h5py from IPython.display import display from IPython.display import HTML +import numpy as np +import h5py import ipywidgets as widgets import matplotlib.pyplot as plt from ml4h.runtime_data_defines import get_mri_hd5_folder -from ml4h.tensor_maps_by_hand import TMAPS -from ml4h.TensorMap import Interpretation -import numpy as np +import ml4h.tensormap.ukb.mri as ukb_mri +import ml4h.tensormap.ukb.mri_vtk as ukb_mri_vtk +from ml4h.TensorMap import Interpretation, TensorMap import tensorflow as tf -# Discover applicable TMAPS. -CARDIAC_MRI_TMAP_NAMES = [k for k in TMAPS.keys() if ('_lax_' in k or '_sax_' in k) and TMAPS[k].axes() == 3] -CARDIAC_MRI_TMAP_NAMES.extend( - [k for k in TMAPS.keys() if TMAPS[k].path_prefix == 'ukb_cardiac_mri' and TMAPS[k].axes() == 3], +# Discover applicable TensorMaps. +MRI_TMAPS = { + key: value for key, value in ukb_mri.__dict__.items() if isinstance(value, TensorMap) + and value.interpretation == Interpretation.CONTINUOUS and value.axes() == 3 +} +MRI_TMAPS.update( + { + key: value for key, value in ukb_mri_vtk.__dict__.items() + if isinstance(value, TensorMap) and value.interpretation == Interpretation.CONTINUOUS and value.axes() == 3 + }, ) -LIVER_MRI_TMAP_NAMES = [k for k in TMAPS.keys() if TMAPS[k].path_prefix == 'ukb_liver_mri' and TMAPS[k].axes() == 3] -BRAIN_MRI_TMAP_NAMES = [k for k in TMAPS.keys() if TMAPS[k].path_prefix == 'ukb_brain_mri' and TMAPS[k].axes() == 3] -# This includes more than just MRI TMAPS, it is a best effort. 
-BEST_EFFORT_MRI_TMAP_NAMES = [k for k in TMAPS.keys() if TMAPS[k].interpretation == Interpretation.CONTINUOUS and TMAPS[k].axes() == 3] MIN_IMAGE_WIDTH = 8 DEFAULT_IMAGE_WIDTH = 12 @@ -41,42 +46,30 @@ class PlotType(Enum): class TensorMapCache: """Cache the tensor to display for reuse when re-plotting the same TMAP with different plot parameters.""" - def __init__(self, hd5, tmap_name): + def __init__(self, hd5: Dict[str, Any], tmap: TensorMap): self.hd5 = hd5 - self.tmap_name = None + self.tmap: Optional[TensorMap] = None self.tensor = None - _ = self.get(tmap_name) + _ = self.get(tmap) - def get(self, tmap_name): - if self.tmap_name != tmap_name: - self.tensor = TMAPS[tmap_name].tensor_from_file(TMAPS[tmap_name], self.hd5) - self.tmap_name = tmap_name + def get(self, tmap: TensorMap) -> np.array: + if self.tmap != tmap: + self.tensor = tmap.tensor_from_file(tmap, self.hd5) + self.tmap = tmap return self.tensor -def choose_cardiac_mri_tmap(sample_id, folder=None, tmap_name='cine_lax_4ch_192', default_tmap_names=CARDIAC_MRI_TMAP_NAMES): - choose_mri_tmap(sample_id, folder, tmap_name, default_tmap_names) - - -def choose_brain_mri_tmap(sample_id, folder=None, tmap_name='t2_flair_sag_p2_1mm_fs_ellip_pf78_1', default_tmap_names=BRAIN_MRI_TMAP_NAMES): - choose_mri_tmap(sample_id, folder, tmap_name, default_tmap_names) - - -def choose_liver_mri_tmap(sample_id, folder=None, tmap_name='liver_shmolli_segmented', default_tmap_names=LIVER_MRI_TMAP_NAMES): - choose_mri_tmap(sample_id, folder, tmap_name, default_tmap_names) - - -def choose_mri_tmap(sample_id, folder=None, tmap_name=None, default_tmap_names=BEST_EFFORT_MRI_TMAP_NAMES): +def choose_mri_tmap( + sample_id: Union[int, str], folder: Optional[str] = None, tmap: Optional[TensorMap] = None, + default_tmaps: Dict[str, TensorMap] = MRI_TMAPS, +) -> None: """Render widgets and plots for MRI tensors. Args: sample_id: The id of the sample to retrieve. folder: The local or Cloud Storage folder under which the files reside. 
- tmap_name: The TMAP name for the 3D MRI tensor to visualize. - default_tmap_names: Other TMAP names to offer for visualization, if present in the hd5. - - Returns: - ipywidget or HTML upon error. + tmap: The TensorMap for the 3D MRI tensor to visualize. + default_tmaps: Other TensorMaps to offer for visualization, if present in the hd5. """ if folder is None: folder = get_mri_hd5_folder(sample_id) @@ -88,42 +81,45 @@ def choose_mri_tmap(sample_id, folder=None, tmap_name=None, default_tmap_names=B tf.io.gfile.copy(src=os.path.join(folder, sample_hd5), dst=local_path) hd5 = h5py.File(local_path, mode='r') except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e: - return HTML(f''' -
+ display( + HTML(f'''
Warning: MRI HD5 file not available for sample {sample_id} in folder {folder}:

{e.message}

Use the folder parameter to read HD5s from a different local directory or Cloud Storage bucket. -
''') - - sample_tmap_names = [] - # Add the passed tmap_name parameter, if it is present in this hd5. - if tmap_name: - if TMAPS[tmap_name].hd5_key_guess() in hd5: - if len(TMAPS[tmap_name].shape) == 3: - sample_tmap_names.append(tmap_name) +
'''), + ) + return + + sample_tmaps = OrderedDict() + # Add the passed tmap parameter, if it is present in this hd5. + if tmap: + if tmap.hd5_key_guess() in hd5: + if len(tmap.shape) == 3: + sample_tmaps[tmap.name] = tmap else: - print(f'{tmap_name} is not a 3D tensor, skipping it') + print(f'{tmap} is not a 3D tensor, skipping it') else: - print(f'{tmap_name} is not available in {sample_id}') - # Also discover applicable TMAPS for this particular sample's HD5 file. - sample_tmap_names.extend( - sorted(set([k for k in default_tmap_names if TMAPS[k].hd5_key_guess() in hd5])), - ) - - if not sample_tmap_names: - return HTML(f'''
- Neither {tmap_name} nor any of {default_tmap_names} are present in this HD5 for sample {sample_id} in {folder}. - Use the tmap_name parameter to try a different TMAP or the folder parameter to try a different hd5 for the sample. -
''') - - default_tmap_name_value = sample_tmap_names[0] + print(f'{tmap} is not available in {sample_id}') + # Also discover applicable TensorMaps for this particular sample's HD5 file. + sample_tmaps.update({n: t for n, t in sorted(default_tmaps.items(), key=lambda t: t[0]) if t.hd5_key_guess() in hd5}) + + if not sample_tmaps: + display( + HTML(f'''
+ Neither {tmap.name} nor any of {default_tmaps.keys()} are present in this HD5 for sample {sample_id} in {folder}. + Use the tmap parameter to try a different TensorMap or the folder parameter to try a different hd5 for the sample. +
'''), + ) + return + + default_tmap_value = next(iter(sample_tmaps.values())) # Display the middle instance by default in the interactive view. - default_instance_value, max_instance_value = compute_instance_range(default_tmap_name_value) - default_vmin_value, default_vmax_value = compute_color_range(hd5, default_tmap_name_value) + default_instance_value, max_instance_value = compute_instance_range(default_tmap_value) + default_vmin_value, default_vmax_value = compute_color_range(hd5, default_tmap_value) - tmap_name_chooser = widgets.Dropdown( - options=sample_tmap_names, - value=default_tmap_name_value, + tmap_chooser = widgets.Dropdown( + options=sample_tmaps, + value=default_tmap_value, description='Choose the MRI tensor TMAP name to visualize:', style={'description_width': 'initial'}, layout=widgets.Layout(width='900px'), @@ -174,20 +170,20 @@ def choose_mri_tmap(sample_id, folder=None, tmap_name=None, default_tmap_names=B viz_controls_ui = widgets.VBox( [ widgets.HTML('

Visualization controls

'), - tmap_name_chooser, + tmap_chooser, widgets.HBox([transpose_chooser, fig_width_chooser]), widgets.HBox([flip_chooser, color_range_chooser]), widgets.HBox([plot_type_chooser, instance_chooser]), ], layout=widgets.Layout(width='auto', border='solid 1px grey'), ) - tmap_cache = TensorMapCache(hd5=hd5, tmap_name=tmap_name_chooser.value) + tmap_cache = TensorMapCache(hd5=hd5, tmap=tmap_chooser.value) viz_controls_output = widgets.interactive_output( plot_mri_tmap, { 'sample_id': widgets.fixed(sample_id), 'tmap_cache': widgets.fixed(tmap_cache), - 'tmap_name': tmap_name_chooser, + 'tmap': tmap_chooser, 'plot_type': plot_type_chooser, 'instance': instance_chooser, 'color_range': color_range_chooser, @@ -209,33 +205,36 @@ def on_plot_type_change(change): else: instance_chooser.layout.visibility = 'hidden' - tmap_name_chooser.observe(on_tmap_value_change, names='value') + tmap_chooser.observe(on_tmap_value_change, names='value') plot_type_chooser.observe(on_plot_type_change, names='value') display(viz_controls_ui, viz_controls_output) -def compute_color_range(hd5, tmap_name): +def compute_color_range(hd5: Dict[str, Any], tmap: TensorMap) -> List[int]: """Compute the mean values for the color ranges of instances in the MRI series.""" - mri_tensor = TMAPS[tmap_name].tensor_from_file(TMAPS[tmap_name], hd5) + mri_tensor = tmap.tensor_from_file(tmap, hd5) vmin = np.mean([np.min(mri_tensor[:, :, i]) for i in range(0, mri_tensor.shape[2])]) vmax = np.mean([np.max(mri_tensor[:, :, i]) for i in range(0, mri_tensor.shape[2])]) - return[vmin, vmax] + return [vmin, vmax] -def compute_instance_range(tmap_name): +def compute_instance_range(tmap: TensorMap) -> Tuple[int, int]: """Compute middle and max instances.""" - middle_instance = int(TMAPS[tmap_name].shape[2] / 2) - max_instance = TMAPS[tmap_name].shape[2] - return(middle_instance, max_instance) + middle_instance = int(tmap.shape[2] / 2) + max_instance = tmap.shape[2] + return (middle_instance, max_instance) -def 
plot_mri_tmap(sample_id, tmap_cache, tmap_name, plot_type, instance, color_range, transpose, flip, fig_width): +def plot_mri_tmap( + sample_id: Union[int, str], tmap_cache: TensorMapCache, tmap: TensorMap, plot_type: PlotType, + instance: int, color_range: Tuple[int, int], transpose: bool, flip: bool, fig_width: int, +) -> None: """Visualize the applicable MRI series within this HD5 file. Args: sample_id: The local or Cloud Storage path to the MRI file. tmap_cache: The cache from which to retrieve the tensor to be plotted. - tmap_name: The name of the chosen TMAP for the MRI series. + tmap: The chosen TensorMap for the MRI series. plot_type: Whether to display instances interactively or in a panel view. instance: The particular instance to display, if interactive. color_range: Array of minimum and maximum value for the color range. @@ -243,12 +242,9 @@ def plot_mri_tmap(sample_id, tmap_cache, tmap_name, plot_type, instance, color_r flip: Whether to flip the image on its vertical axis fig_width: The desired width of the figure. Note that height computed as the proportion of the width based on the data to be plotted. - - Returns: - The plot or a notebook-friendly error message. """ - title_prefix = f'{tmap_name} from MRI {sample_id}' - mri_tensor = tmap_cache.get(tmap_name) + title_prefix = f'{tmap.name} from MRI {sample_id}' + mri_tensor = tmap_cache.get(tmap) if plot_type == PlotType.INTERACTIVE: plot_mri_tensor_as_animation( mri_tensor=mri_tensor, @@ -275,10 +271,13 @@ def plot_mri_tmap(sample_id, tmap_cache, tmap_name, plot_type, instance, color_r title_prefix=title_prefix, ) else: - return HTML(f'''
Invalid plot type: {plot_type}
''') + HTML(f'''
Invalid plot type: {plot_type}
''') -def plot_mri_tensor_as_panels(mri_tensor, vmin, vmax, transpose=False, flip=False, fig_width=DEFAULT_IMAGE_WIDTH, title_prefix=''): +def plot_mri_tensor_as_panels( + mri_tensor: np.array, vmin: int, vmax: int, transpose: bool = False, flip: bool = False, + fig_width: int = DEFAULT_IMAGE_WIDTH, title_prefix: str = '', +) -> None: """Visualize an MRI series from a 3D tensor as a panel of static plots. Args: @@ -314,7 +313,7 @@ def plot_mri_tensor_as_panels(mri_tensor, vmin, vmax, transpose=False, flip=Fals axes[row, col].set_yticklabels([]) axes[row, col].set_xticklabels([]) fig.suptitle( - f'{title_prefix}\nColor range: {vmin}-{vmax}, Transpose: {transpose}, Flip: {flip}, Figure size:{fig_width}x{fig_height}', + f'{title_prefix}\nColor range: {vmin}-{vmax}, Transpose: {transpose}, Flip: {flip}, Figure size:{fig_width}x{fig_height}', # pylint: disable=line-too-long fontsize=fig_width, ) fig.subplots_adjust( @@ -326,7 +325,11 @@ def plot_mri_tensor_as_panels(mri_tensor, vmin, vmax, transpose=False, flip=Fals ) -def plot_mri_tensor_as_animation(mri_tensor, instance, vmin, vmax, transpose=False, flip=False, fig_width=DEFAULT_IMAGE_WIDTH, title_prefix=''): +def plot_mri_tensor_as_animation( + mri_tensor: np.array, instance: int, vmin: int, vmax: int, + transpose: bool = False, flip: bool = False, + fig_width: int = DEFAULT_IMAGE_WIDTH, title_prefix: str = '', +) -> None: """Visualize an MRI series from a 3D tensor as an animation rendered one panel at a time. 
Args: @@ -358,7 +361,7 @@ def plot_mri_tensor_as_animation(mri_tensor, instance, vmin, vmax, transpose=Fal _, ax = plt.subplots(figsize=(fig_width, fig_height), facecolor='beige') ax.imshow(pixels, cmap='gray', vmin=vmin, vmax=vmax) ax.set_title( - f'{title_prefix}, Instance: {instance}\nColor range: {vmin}-{vmax}, Transpose: {transpose}, Flip: {flip}, Figure size:{fig_width}x{fig_height}', + f'{title_prefix}, Instance: {instance}\nColor range: {vmin}-{vmax}, Transpose: {transpose}, Flip: {flip}, Figure size:{fig_width}x{fig_height}', # pylint: disable=line-too-long fontsize=fig_width, ) ax.set_yticklabels([]) From 6ff7a2d2ca558c2f77e4cf5af94553b0fb4379a2 Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Tue, 29 Sep 2020 17:28:20 -0400 Subject: [PATCH 05/21] paired --- .../paired_multimodal_autoencoder.ipynb | 113 ++-------- .../paired_multimodal_segmenter_mri_ecg.ipynb | 210 ++++++------------ notebooks/autoencoders/vae_mri_slice.ipynb | 2 +- notebooks/mnist_demo.ipynb | 90 +------- .../mri/mri_cardiac_long_axis_sketch.ipynb | 8 +- .../mri/mri_cardiac_short_axis_sketch.ipynb | 14 +- .../identify_a_sample_to_review.ipynb | 6 +- .../review_results/review_one_sample.ipynb | 8 +- ...handling_for_notebook_visualizations.ipynb | 40 +--- ...ntify_a_sample_to_review_interactive.ipynb | 4 +- .../review_one_sample_interactive.ipynb | 4 +- 11 files changed, 135 insertions(+), 364 deletions(-) diff --git a/notebooks/autoencoders/paired_multimodal_autoencoder.ipynb b/notebooks/autoencoders/paired_multimodal_autoencoder.ipynb index bd58f6d25..67870c053 100644 --- a/notebooks/autoencoders/paired_multimodal_autoencoder.ipynb +++ b/notebooks/autoencoders/paired_multimodal_autoencoder.ipynb @@ -328,11 +328,9 @@ ") -> Model:\n", " inputs = {tm: Input(shape=tm.shape, name=tm.input_name()) for tm in args.tensor_maps_in}\n", " original_outputs = {tm:1 for tm in args.tensor_maps_out}\n", - " real_serial_layers = kwargs['model_layers']\n", - " args.model_layers = None\n", " 
multimodal_activations = []\n", - " encoders = {}\n", - " decoders = {}\n", + " desired_distance_tm = []\n", + " my_metrics = {}\n", " outputs = []\n", " losses = []\n", " for left, right in pairs:\n", @@ -347,22 +345,15 @@ " h_right = encode_right(inputs[right]) \n", " \n", " if pair_loss == 'cosine':\n", - " loss_layer = CosineLossLayer(1.0)\n", + " loss_layer = CosineLossLayer(100.0)\n", " elif pair_loss == 'euclid':\n", - " loss_layer = L2LossLayer(1.0)\n", + " loss_layer = L2LossLayer(100.0)\n", " \n", " paired_embeddings = loss_layer([h_left, h_right])\n", " multimodal_activations.extend(paired_embeddings)\n", - " if left not in encoders:\n", - " encoders[left] = encode_left\n", - " if right not in encoders:\n", - " encoders[right] = encode_right \n", " \n", " multimodal_activation = Concatenate()(multimodal_activations)\n", - " encoder = Model(inputs=list(inputs.values()), outputs=[multimodal_activation], name='encoder')\n", " \n", - " # build decoder models\n", - " latent_inputs = Input(shape=(args.dense_layers[0]*len(inputs)), name='input_concept_space')\n", " pre_decoder_shapes: Dict[TensorMap, Optional[Tuple[int, ...]]] = {}\n", " for tm in args.tensor_maps_out:\n", " shape = _calc_start_shape(num_upsamples=len(args.dense_blocks), output_shape=tm.shape, \n", @@ -389,26 +380,24 @@ " upsample_z=args.pool_z,\n", " )\n", " \n", - " reconstruction = decode(restructure(latent_inputs))\n", - " decoder = Model(latent_inputs, reconstruction, name=tm.output_name())\n", - " decoders[tm] = decoder\n", - " outputs.append(decoder(multimodal_activation))\n", + " outputs.append(decode(restructure(multimodal_activation)))\n", " losses.append(tm.loss)\n", "\n", - " args.tensor_maps_out = list(original_outputs.keys())\n", + " args.tensor_maps_out = list(original_outputs.keys()) + desired_distance_tm\n", " args.tensor_maps_in = list(inputs.keys())\n", " \n", + " opt = Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)\n", + " #outputs.reverse() # Make paired loss 
last\n", + " #losses.reverse()\n", " m = Model(inputs=list(inputs.values()), outputs=outputs)\n", - " my_metrics = {tm.output_name(): tm.metrics for tm in args.tensor_maps_out}\n", - " opt = Adam(lr=kwargs['learning_rate'], beta_1=0.9, beta_2=0.999, epsilon=1e-08)\n", - " m.compile(optimizer=opt, loss=losses, metrics=my_metrics)\n", + " m.compile(optimizer=opt, loss=losses)\n", " m.summary()\n", " \n", - " if real_serial_layers is not None:\n", + " if kwargs['model_layers'] is not None:\n", " m.load_weights(kwargs['model_layers'], by_name=True)\n", " print(f\"Loaded model weights from:{kwargs['model_layers']}\")\n", " \n", - " return m, encoders, decoders" + " return m" ] }, { @@ -418,7 +407,7 @@ "outputs": [], "source": [ "sys.argv = ['train', \n", - " '--tensors', '/mnt/disks/sax-lax-40k-lvm/2020-01-29/', \n", + " '--tensors', '/mnt/disks/segmented-sax-lax/2020-07-07/', \n", " '--input_tensors', 'lax_2ch_diastole_slice0_3d', 'lax_3ch_diastole_slice0_3d', \n", " '--output_tensors', 'lax_2ch_diastole_slice0_3d', 'lax_3ch_diastole_slice0_3d',\n", " '--activation', 'swish',\n", @@ -428,13 +417,13 @@ " '--conv_z', '3', '3', '3', \n", " '--dense_blocks', '32', '32', '32',\n", " '--block_size', '3',\n", - " '--dense_layers', '64',\n", + " '--dense_layers', '512',\n", " '--pool_x', '2',\n", " '--pool_y', '2',\n", " '--batch_size', '1',\n", " '--patience', '32',\n", - " '--epochs', '292',\n", - " '--learning_rate', '0.0001',\n", + " '--epochs', '248',\n", + " '--learning_rate', '0.001',\n", " '--training_steps', '256',\n", " '--validation_steps', '30',\n", " '--test_steps', '2',\n", @@ -444,17 +433,12 @@ " '--id', 'lax_2ch_3ch_diastole_pair_cosine_loss']\n", "args = parse_args()\n", "pairs = [(args.tensor_maps_in[0], args.tensor_maps_in[1])]\n", - "overparameterized_model, encoders, decoders = make_paired_autoencoder_model(pairs, pair_loss='cosine', **args.__dict__)\n", + "overparameterized_model = make_paired_autoencoder_model(pairs, pair_loss='cosine', 
**args.__dict__)\n", "generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__)\n", "train_model_from_generators(\n", " overparameterized_model, generate_train, generate_valid, args.training_steps, args.validation_steps, args.batch_size,\n", " args.epochs, args.patience, args.output_folder, args.id, args.inspect_model, args.inspect_show_labels,\n", - " plot=False, save_last_model=True\n", - ")\n", - "for tm in encoders:\n", - " encoders[tm].save(f'{args.output_folder}{args.id}/encoder_{tm.name}.h5')\n", - "for tm in decoders:\n", - " decoders[tm].save(f'{args.output_folder}{args.id}/decoder_{tm.name}.h5')" + ")" ] }, { @@ -518,55 +502,6 @@ "latent_df.info()" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "out_path = os.path.join(args.output_folder, args.id + '/')\n", - "test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps*5)\n", - "print(list(test_data.keys()))\n", - "\n", - "preds = overparameterized_model.predict(test_data)\n", - "print([p.shape for p in preds])\n", - "print([tm.name for tm in args.tensor_maps_out])\n", - "print(test_paths)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from ml4h.plots import _plot_reconstruction\n", - "_plot_reconstruction(args.tensor_maps_out[0], test_data['input_lax_2ch_diastole_slice0_3d_continuous'], preds[0], out_path, test_paths, num_samples=2)\n", - "from ml4h.explorations import predictions_to_pngs\n", - "predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_out, \n", - " test_data, test_labels, test_paths, out_path)\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i, etm in enumerate(encoders):\n", - " embed = encoders[etm].predict(test_data[etm.input_name()])\n", - " double = np.tile(embed, 2)\n", - " 
print(f'embed shape: {embed.shape} double shape: {double.shape}')\n", - " for dtm in decoders:\n", - " predictions = decoders[dtm].predict(double)\n", - " print(f'prediction shape: {predictions.shape}')\n", - " out_path = os.path.join(args.output_folder, args.id, f'decoding_{dtm.name}_from_{etm.name}/')\n", - " if not os.path.exists(os.path.dirname(out_path)):\n", - " os.makedirs(os.path.dirname(out_path))\n", - " _plot_reconstruction(dtm, test_data[dtm.input_name()], predictions.copy(), out_path, test_paths, 8)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -776,18 +711,6 @@ "display_name": "Python 3", "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" } }, "nbformat": 4, diff --git a/notebooks/autoencoders/paired_multimodal_segmenter_mri_ecg.ipynb b/notebooks/autoencoders/paired_multimodal_segmenter_mri_ecg.ipynb index 3ea3e32a2..442e27a18 100644 --- a/notebooks/autoencoders/paired_multimodal_segmenter_mri_ecg.ipynb +++ b/notebooks/autoencoders/paired_multimodal_segmenter_mri_ecg.ipynb @@ -38,7 +38,6 @@ "# ml4h Imports\n", "from ml4h.TensorMap import TensorMap\n", "from ml4h.arguments import parse_args\n", - "from ml4h.plots import _plot_reconstruction\n", "from ml4h.models import make_multimodal_multitask_model, train_model_from_generators, make_hidden_layer_model, _conv_layer_from_kind_and_dimension\n", "from ml4h.tensor_generators import TensorGenerator, big_batch_from_minibatch_generator, test_train_valid_tensor_generators\n", "from ml4h.recipes import plot_predictions, infer_hidden_layer_multimodal_multitask\n", @@ -272,6 +271,13 @@ " return norm\n", "\n", "def pairwise_cosine_difference(t1, t2):\n", + " \"\"\"\n", + " A [batch x n x d] tensor of n rows with d dimensions\n", + " B [batch x m x d] tensor 
of n rows with d dimensions\n", + "\n", + " returns:\n", + " D [batch x n x m] tensor of cosine similarity scores between each point i Model:\n", " inputs = {tm: Input(shape=tm.shape, name=tm.input_name()) for tm in args.tensor_maps_in}\n", " original_outputs = {tm:1 for tm in args.tensor_maps_out}\n", - " real_serial_layers = kwargs['model_layers']\n", - " args.model_layers = None\n", " multimodal_activations = []\n", - " encoders = {}\n", - " decoders = {}\n", " outputs = []\n", " losses = []\n", " for left, right in pairs:\n", @@ -346,16 +348,9 @@ " \n", " paired_embeddings = loss_layer([h_left, h_right])\n", " multimodal_activations.extend(paired_embeddings)\n", - " if left not in encoders:\n", - " encoders[left] = encode_left\n", - " if right not in encoders:\n", - " encoders[right] = encode_right \n", " \n", " multimodal_activation = Concatenate()(multimodal_activations)\n", - " encoder = Model(inputs=list(inputs.values()), outputs=[multimodal_activation], name='encoder')\n", " \n", - " # build decoder models\n", - " latent_inputs = Input(shape=(args.dense_layers[0]*len(inputs)), name='input_concept_space')\n", " pre_decoder_shapes: Dict[TensorMap, Optional[Tuple[int, ...]]] = {}\n", " for tm in args.tensor_maps_out:\n", " shape = _calc_start_shape(num_upsamples=len(args.dense_blocks), output_shape=tm.shape, \n", @@ -382,10 +377,7 @@ " upsample_z=args.pool_z,\n", " )\n", " \n", - " reconstruction = decode(restructure(latent_inputs))\n", - " decoder = Model(latent_inputs, reconstruction, name=tm.output_name())\n", - " decoders[tm] = decoder\n", - " outputs.append(decoder(multimodal_activation))\n", + " outputs.append(decode(restructure(multimodal_activation)))\n", " losses.append(tm.loss)\n", "\n", " args.tensor_maps_out = list(original_outputs.keys())\n", @@ -397,11 +389,11 @@ " m.compile(optimizer=opt, loss=losses, metrics=my_metrics)\n", " m.summary()\n", " \n", - " if real_serial_layers is not None:\n", + " if kwargs['model_layers'] is not None:\n", " 
m.load_weights(kwargs['model_layers'], by_name=True)\n", " print(f\"Loaded model weights from:{kwargs['model_layers']}\")\n", " \n", - " return m, encoders, decoders" + " return m" ] }, { @@ -414,90 +406,35 @@ " '--tensors', '/mnt/disks/sax-lax-40k-lvm/2020-01-29/', \n", " '--input_tensors', 'ecg.ecg_rest', 'mri.cine_segmented_lax_4ch_diastole', \n", " '--output_tensors', 'ecg.ecg_rest', 'mri.cine_segmented_lax_4ch_diastole',\n", - " '--activation', 'selu',\n", + " '--activation', 'swish',\n", " '--conv_layers', '32',\n", - " '--conv_x', '15', '15', '15',\n", + " '--conv_x', '9', '9', '9',\n", " '--conv_y', '3', '3', '3', \n", " '--conv_z', '3', '3', '3',\n", " '--dense_blocks', '32', '32', '32',\n", " '--block_size', '3',\n", - " '--dense_layers', '64',\n", + " '--dense_layers', '512',\n", " '--pool_x', '2',\n", " '--pool_y', '2',\n", " '--batch_size', '1',\n", - " '--patience', '94',\n", - " '--epochs', '396',\n", - " '--learning_rate', '0.00005',\n", + " '--patience', '44',\n", + " '--epochs', '496',\n", + " '--learning_rate', '0.0001',\n", " '--training_steps', '128',\n", " '--validation_steps', '30',\n", " '--test_steps', '8',\n", " '--num_workers', '4',\n", " '--inspect_model',\n", " '--tensormap_prefix', 'ml4h.tensormap.ukb',\n", - " '--model_layers', './recipes_output/paired_ecg_segmented_mri_lax_4ch_diastole_cosine_64d_selu/paired_ecg_segmented_mri_lax_4ch_diastole_cosine_64d_selu.h5',\n", - " '--id', 'paired_ecg_segmented_mri_lax_4ch_diastole_cosine_64d_selu']\n", + " '--id', 'ecg_mri_lax_4ch_diastole_euclid_paired_segmenter_512d']\n", "args = parse_args()\n", "pairs = [(args.tensor_maps_in[0], args.tensor_maps_in[1])]\n", - "overparameterized_model, encoders, decoders = make_paired_autoencoder_model(pairs, pair_loss='cosine', **args.__dict__)\n", + "overparameterized_model = make_paired_autoencoder_model(pairs, pair_loss='euclid', **args.__dict__)\n", "generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__)\n", 
"train_model_from_generators(\n", " overparameterized_model, generate_train, generate_valid, args.training_steps, args.validation_steps, args.batch_size,\n", " args.epochs, args.patience, args.output_folder, args.id, args.inspect_model, args.inspect_show_labels,\n", - " plot=False, save_last_model=True\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for tm in encoders:\n", - " encoders[tm].save(f'{args.output_folder}{args.id}/encoder_{tm.name}.h5')\n", - "for tm in decoders:\n", - " decoders[tm].save(f'{args.output_folder}{args.id}/decoder_{tm.name}.h5')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sys.argv = ['train', \n", - " '--tensors', '/mnt/disks/sax-lax-40k-lvm/2020-01-29/', \n", - " '--input_tensors', 'ecg.ecg_rest', 'mri.cine_segmented_lax_4ch_diastole', \n", - " '--output_tensors', 'ecg.ecg_rest', 'mri.cine_segmented_lax_4ch_diastole',\n", - " '--activation', 'swish',\n", - " '--conv_layers', '32',\n", - " '--conv_x', '9', '9', '9',\n", - " '--conv_y', '3', '3', '3', \n", - " '--conv_z', '3', '3', '3', \n", - " '--dense_blocks', '32', '32', '32',\n", - " '--block_size', '3',\n", - " '--dense_layers', '256',\n", - " '--pool_x', '2',\n", - " '--pool_y', '2',\n", - " '--batch_size', '1',\n", - " '--patience', '44',\n", - " '--epochs', '532',\n", - " '--learning_rate', '0.0002',\n", - " '--training_steps', '72',\n", - " '--validation_steps', '30',\n", - " '--test_steps', '8',\n", - " '--num_workers', '4',\n", - " '--inspect_model',\n", - " '--tensormap_prefix', 'ml4h.tensormap.ukb',\n", - " '--hidden_layer', 'concatenate_36',\n", - " '--model_file', './recipes_output/paired_ecg_segmented_mri_lax_4ch_diastole_cosine_64d_selu/paired_ecg_segmented_mri_lax_4ch_diastole_cosine_64d_selu.h5',\n", - " #'--sample_csv', '/home/sam/lvh/lvh_hold_out.txt',\n", - " '--id', 
'paired_ecg_segmented_mri_lax_4ch_diastole_cosine_64d_selu']\n", - "args = parse_args()\n", - "#overparameterized_model = make_multimodal_multitask_model(**args.__dict__)\n", - "generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__)\n", - "#plot_predictions(args)\n", - "#infer_hidden_layer_multimodal_multitask(args)" + ")" ] }, { @@ -543,17 +480,7 @@ "metadata": {}, "outputs": [], "source": [ - "for i, etm in enumerate(encoders):\n", - " embed = encoders[etm].predict(test_data[etm.input_name()])\n", - " double = np.tile(embed, 2)\n", - " print(f'embed shape: {embed.shape} double shape: {double.shape}')\n", - " for dtm in decoders:\n", - " predictions = decoders[dtm].predict(double)\n", - " print(f'prediction shape: {predictions.shape}')\n", - " out_path = os.path.join(args.output_folder, args.id, f'decoding_{dtm.name}_from_{etm.name}/')\n", - " if not os.path.exists(os.path.dirname(out_path)):\n", - " os.makedirs(os.path.dirname(out_path))\n", - " _plot_reconstruction(dtm, test_data[dtm.input_name()], predictions.copy(), out_path, test_paths, 8)" + "print(list(test_data['input_strip_continuous'].shape))" ] }, { @@ -588,7 +515,7 @@ " if i % sample_every == 0 and col < samples:\n", " for j in range(rows):\n", " if len(test_data[test_key].shape) == 4:\n", - " axes[j, col].imshow(test_data[test_key][j, :, :, 0], cmap = 'plasma')\n", + " axes[j, col].imshow(test_data[test_key][j, :, :, 0], cmap = 'gray')\n", " axes[j, col].set_yticks(())\n", " elif len(test_data[test_key].shape) == 3:\n", " for l in range(12):\n", @@ -614,14 +541,14 @@ "\n", "\n", "test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)\n", - "test_key = 'input_cine_segmented_lax_4ch_diastole_categorical'\n", + "test_key = 'input_lax_4ch_diastole_slice0_224_3d_continuous'\n", "test_shape = test_data[test_key].shape\n", "test_data[test_key] = np.random.random(test_shape)\n", "out_path = 
os.path.join(args.output_folder, args.id, test_key + '_noise/')\n", "if not os.path.exists(os.path.dirname(out_path)):\n", " os.makedirs(os.path.dirname(out_path))\n", "noise_preds = plot_ae_towards_attractor(overparameterized_model, test_data, test_labels, test_key, \n", - " test_index=1, rows=8, samples=4, steps = 28)\n", + " test_index=1, rows=8, samples=4, steps = 18)\n", "print(list(test_data.keys()))\n", "_plot_reconstruction(args.tensor_maps_out[0], test_data['input_strip_continuous'], \n", " noise_preds[0], out_path, test_paths)\n", @@ -665,10 +592,7 @@ "df = pd.read_csv('/home/sam/ml/trained_models/lax_4ch_diastole_autoencode_leaky_converge/tensors_all_union.csv')\n", "df['21003_Age-when-attended-assessment-centre_2_0'].plot.hist(bins=30)\n", "hidden_inference = './recipes_output/ecg_mri_lax_4ch_diastole_paired_autoencoder_2blocks_256d_200_samples/hidden_inference_ecg_mri_lax_4ch_diastole_paired_autoencoder_2blocks_256d_200_samples.tsv'\n", - "hidden_inference = './recipes_output/paired_ecg_segmented_mri_lax_4ch_diastole_euclid_256d_swish/hidden_inference_paired_ecg_segmented_mri_lax_4ch_diastole_euclid_256d_swish.tsv'\n", - "hidden_inference = './recipes_output/paired_ecg_segmented_mri_lax_4ch_diastole_cosine_256d_selu/hidden_inference_paired_ecg_segmented_mri_lax_4ch_diastole_cosine_256d_selu.tsv'\n", "\n", - "hidden_inference = './recipes_output/paired_ecg_segmented_mri_lax_4ch_diastole_cosine_64d_selu/hidden_inference_paired_ecg_segmented_mri_lax_4ch_diastole_cosine_64d_selu.tsv'\n", "\n", "df2 = pd.read_csv(hidden_inference, sep='\\t')\n", "df['fpath'] = pd.to_numeric(df['fpath'], errors='coerce')\n", @@ -790,7 +714,7 @@ "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", - "latent_dimension = 128\n", + "latent_dimension = 256\n", "latent_cols = [f'latent_{i}' for i in range(latent_dimension)]\n", "pca, matrix_reduce = pca_on_matrix(df2[latent_cols].to_numpy(), 10)\n", "for strat in ['Sex_Female_0_0', 'has_ttntv', 
'atrial_fibrillation_or_flutter', \n", @@ -818,7 +742,7 @@ " stratify_latent_space(strat, 1.0, latent_cols, latent_df)\n", "strats = ['LVEF', 'LVM', 'LVEDV', 'sample_id',\n", " '21001_Body-mass-index-BMI_0_0', '21003_Age-when-attended-assessment-centre_2_0']\n", - "theshes = [45, 100, 150, 3500000, 27.5, 65]\n", + "theshes = [45, 100, 150, 3500000, 27.5, 70]\n", "for strat, thresh in zip(strats, theshes):\n", " stratify_latent_space(strat, thresh, latent_cols, latent_df)" ] @@ -832,12 +756,12 @@ "latent_dimension = 512\n", "latent_cols = [f'latent_{i}' for i in range(latent_dimension)]\n", "pca, matrix_reduce = pca_on_matrix(df2[latent_cols].to_numpy(), 10)\n", - "for strat in ['Sex_Female_0_0', 'has_ttntv', 'atrial_fibrillation_or_flutter', \n", + "for strat in ['Sex_Female_0_0', 'atrial_fibrillation_or_flutter', \n", " 'coronary_artery_disease', 'hypertension']:\n", " stratify_latent_space(strat, 1.0, latent_cols, latent_df)\n", "strats = ['LVEF', 'LVM', 'LVEDV', 'sample_id',\n", " '21001_Body-mass-index-BMI_0_0', '21003_Age-when-attended-assessment-centre_2_0']\n", - "theshes = [45, 100, 150, 3500000, 27.5, 65]\n", + "theshes = [45, 100, 150, 3500000, 27.5, 70]\n", "for strat, thresh in zip(strats, theshes):\n", " stratify_latent_space(strat, thresh, latent_cols, latent_df)" ] @@ -851,11 +775,11 @@ "latent_dimension = 256\n", "latent_cols = [f'latent_{256+i}' for i in range(latent_dimension)]\n", "pca, matrix_reduce = pca_on_matrix(df2[latent_cols].to_numpy(), 10)\n", - "c_strats = [ 'Sex_Female_0_0', 'has_ttntv']\n", + "c_strats = [ 'Sex_Female_0_0']\n", "for c_strat in c_strats:\n", " strats = ['LVEF', 'LVM', 'LVEDV', 'sample_id',\n", " '21001_Body-mass-index-BMI_0_0', '21003_Age-when-attended-assessment-centre_2_0']\n", - " theshes = [50, 100, 150, 3750000, 27.5, 65]\n", + " theshes = [50, 100, 150, 3750000, 27.5, 70]\n", " for strat, thresh in zip(strats, theshes):\n", " directions_in_latent_space(c_strat, 1.0, strat, thresh, latent_cols, latent_df)" ] @@ 
-900,8 +824,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(f\"{np.mean(np.sqrt(np.einsum('ij, ij->ij', ecg_encode, ecg_encode)))}\")\n", - "print(f\"{np.mean(np.sqrt(np.einsum('ij, ij->ij', mri_encode, mri_encode)))}\")" + "print(f'{ecg_encode[:5,:5]} \\n{mri_encode[:5,:5]}')" ] }, { @@ -913,7 +836,7 @@ "latent_dimension = 256\n", "latent_cols = [f'latent_{i}' for i in range(latent_dimension)]\n", "ecg_encode = latent_df[latent_cols].to_numpy()\n", - "latent_cols = [f'latent_{250+i}' for i in range(latent_dimension)]\n", + "latent_cols = [f'latent_{18+i}' for i in range(latent_dimension)]\n", "mri_encode = latent_df[latent_cols].to_numpy()\n", "diff = np.sqrt(np.einsum('ij, ij->ij', ecg_encode - mri_encode, ecg_encode - mri_encode))\n", "print(diff.shape) \n", @@ -926,8 +849,8 @@ "metadata": {}, "outputs": [], "source": [ - "ch2_random = np.random.random((8520, 256))\n", - "ch3_random = np.random.random((8520, 256))\n", + "ch2_random = np.random.random((4452, 256))\n", + "ch3_random = np.random.random((4452, 256))\n", "diff = np.sqrt(np.einsum('ij, ij->ij', ch2_random - ch3_random, ch2_random - ch3_random))\n", "print(diff.shape) \n", "print(np.mean(diff))" @@ -939,35 +862,44 @@ "metadata": {}, "outputs": [], "source": [ - "latent_dimension = 128\n", - "latent_cols = [f'latent_{i}' for i in range(latent_dimension)]\n", - "all_encode = latent_df[latent_cols].to_numpy()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from ml4h.plots import _plot_reconstruction\n", - "for tm in decoders:\n", - " predictions = decoders[tm].predict(all_encode[:4])\n", - " print(predictions.shape)\n", - " out_path = os.path.join(args.output_folder, args.id, 'decodings/')\n", - " if not os.path.exists(os.path.dirname(out_path)):\n", - " os.makedirs(os.path.dirname(out_path))\n", - " samples = list(map(str, list(latent_df['sample_id'])[:10]))\n", - " _plot_reconstruction(tm, predictions, predictions, out_path, 
samples)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print()" + "sys.argv = ['train', \n", + " '--tensors', '/mnt/disks/sax-lax-40k-lvm/2020-01-29/', \n", + " '--input_tensors', 'ecg.ecg_rest', 'mri.lax_4ch_diastole_slice0_224_3d', \n", + " '--output_tensors', 'ecg.ecg_rest', 'mri.lax_4ch_diastole_slice0_224_3d',\n", + " '--activation', 'swish',\n", + " '--conv_layers', '32',\n", + " '--conv_x', '9', '9', '9',\n", + " '--conv_y', '3', '3', '3', \n", + " '--conv_z', '3', '3', '3', \n", + " '--dense_blocks', '32', '32', '32',\n", + " '--block_size', '3',\n", + " '--dense_layers', '256',\n", + " '--pool_x', '2',\n", + " '--pool_y', '2',\n", + " '--batch_size', '1',\n", + " '--patience', '44',\n", + " '--epochs', '532',\n", + " '--learning_rate', '0.0002',\n", + " '--training_steps', '72',\n", + " '--validation_steps', '30',\n", + " '--test_steps', '8',\n", + " '--num_workers', '4',\n", + " '--inspect_model',\n", + " '--tensormap_prefix', 'ml4h.tensormap.ukb',\n", + " '--hidden_layer', 'concatenate_36',\n", + " '--model_file', './recipes_output/ecg_mri_lax_4ch_diastole_paired_autoencoder_2blocks_256d_200_samples/ecg_mri_lax_4ch_diastole_paired_autoencoder_2blocks_256d_200_samples.h5',\n", + " '--train_csv', '/home/sam/lvh/small_set.csv',\n", + " #'--sample_csv', '/home/sam/lvh/lvh_hold_out.txt',\n", + " '--id', 'ecg_mri_lax_4ch_diastole_paired_autoencoder_2blocks_256d_200_samples']\n", + "args = parse_args()\n", + "\n", + "#overparameterized_model = make_multimodal_multitask_model(**args.__dict__)\n", + "#infer_hidden_layer_multimodal_multitask(args)\n", + "#generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__)\n", + "# train_model_from_generators(\n", + "# overparameterized_model, generate_train, generate_valid, args.training_steps, args.validation_steps, args.batch_size,\n", + "# args.epochs, args.patience, args.output_folder, args.id, 
args.inspect_model, args.inspect_show_labels,\n", + "# )" ] }, { diff --git a/notebooks/autoencoders/vae_mri_slice.ipynb b/notebooks/autoencoders/vae_mri_slice.ipynb index 0aa84e195..21785359e 100644 --- a/notebooks/autoencoders/vae_mri_slice.ipynb +++ b/notebooks/autoencoders/vae_mri_slice.ipynb @@ -286,7 +286,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.6.8" } }, "nbformat": 4, diff --git a/notebooks/mnist_demo.ipynb b/notebooks/mnist_demo.ipynb index b4fc5bcbf..88393f90f 100644 --- a/notebooks/mnist_demo.ipynb +++ b/notebooks/mnist_demo.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -34,7 +34,7 @@ "from tensorflow.keras.layers import Dense, Conv2D, Flatten\n", "\n", "from ml4h.defines import StorageType\n", - "from ml4h.arguments import parse_args\n", + "from ml4h.arguments import parse_args, TMAPS, _get_tmap\n", "from ml4h.TensorMap import TensorMap, Interpretation\n", "from ml4h.tensor_generators import test_train_valid_tensor_generators\n", "from ml4h.models import train_model_from_generators, make_multimodal_multitask_model, _inspect_model\n", @@ -47,7 +47,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -58,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -95,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -113,28 +113,9 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "loading data...\n", - "(50000, 784)\n" - ] - }, - { - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "plot_mnist(4)" ] @@ -158,7 +139,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -177,27 +158,9 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "loading data...\n", - "Wrote 5000 MNIST images and labels as HD5 files\n", - "Wrote 10000 MNIST images and labels as HD5 files\n", - "Wrote 15000 MNIST images and labels as HD5 files\n", - "Wrote 20000 MNIST images and labels as HD5 files\n", - "Wrote 25000 MNIST images and labels as HD5 files\n", - "Wrote 30000 MNIST images and labels as HD5 files\n", - "Wrote 35000 MNIST images and labels as HD5 files\n", - "Wrote 40000 MNIST images and labels as HD5 files\n", - "Wrote 45000 MNIST images and labels as HD5 files\n", - "Wrote 50000 MNIST images and labels as HD5 files\n" - ] - } - ], + "outputs": [], "source": [ "mnist_as_hd5(HD5_FOLDER)" ] @@ -247,41 +210,12 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2020-09-18 16:17:13,404 - logger:25 - INFO - Logging configuration was loaded. 
Log messages can be found at ./runs/learn_mnist/log_2020-09-18_16-17_0.log.\n", - "2020-09-18 16:17:13,410 - arguments:444 - INFO - Command Line was: \n", - "./scripts/tf.sh train --tensors ./mnist_hd5s/ --tensormap_prefix ml4h.tensormap.mnist --input_tensors mnist_image --output_tensors mnist_label --batch_size 64 --test_steps 64 --epochs 24 --output_folder ./runs/ --id learn_mnist\n", - "\n", - "2020-09-18 16:17:13,410 - arguments:445 - INFO - Arguments are Namespace(activation='relu', aligned_dimension=16, alpha=0.5, anneal_max=2.0, anneal_rate=0.0, anneal_shift=0.0, app_csv=None, b_slice_force=None, balance_csvs=[], batch_size=64, bigquery_credentials_file='/mnt/ml4cvd/projects/jamesp/bigquery/bigquery-viewer-credentials.json', bigquery_dataset='broad-ml4cvd.ukbb7089_r10data', block_size=3, bottleneck_type=, cache_size=437500000.0, categorical_field_ids=[], continuous_field_ids=[], continuous_file=None, continuous_file_column=None, continuous_file_discretization_bounds=[], continuous_file_normalize=False, conv_dilate=False, conv_layers=[32], conv_normalize=None, conv_regularize=None, conv_regularize_rate=0.0, conv_type='conv', conv_x=[3], conv_y=[3], conv_z=[2], debug=False, dense_blocks=[32, 24, 16], dense_layers=[16, 64], dense_normalize=None, dense_regularize=None, dense_regularize_rate=0.0, dicom_series='cine_segmented_sax_b6', dicoms='./dicoms/', eager=False, embed_visualization=None, epochs=24, explore_export_errors=False, freeze_model_layers=False, hidden_layer='embed', id='learn_mnist', imputation_method_for_continuous_fields='random', include_array=False, include_instance=False, include_missing_continuous_channel=False, input_tensors=['mnist_image'], inspect_model=False, inspect_show_labels=True, join_tensors=['partners_ecg_patientid_clean'], label_weights=None, language_layer='ecg_rest_text', language_prefix='ukb_ecg_rest', learning_rate=0.0002, learning_rate_schedule=None, logging_level='INFO', match_any_window=False, max_models=16, 
max_parameters=9000000, max_patients=999999, max_pools=[], max_sample_id=7000000, max_samples=None, max_slices=999999, min_sample_id=0, min_samples=3, min_values=10, mixup_alpha=0, mlp_concat=False, mode='mlp', model_file=None, model_files=[], model_layers=None, mri_field_ids=['20208', '20209'], num_workers=8, number_per_window=1, optimizer='radam', order_in_window=None, output_folder='./runs/', output_tensors=['mnist_label'], padding='same', patience=8, phecode_definitions='/mnt/ml4cvd/projects/jamesp/data/phecode_definitions1.2.csv', phenos_folder='gs://ml4cvd/phenotypes/', plot_hist=True, plot_mode='clinical', pool_type='max', pool_x=2, pool_y=2, pool_z=1, protected_tensors=[], random_seed=12878, reference_end_time_tensor=None, reference_join_tensors=None, reference_labels=None, reference_name='Reference', reference_start_time_tensor=None, reference_tensors=None, sample_csv=None, sample_weight=None, save_last_model=False, t=48, tensor_maps_in=[TensorMap(mnist_image, (28, 28, 1), continuous)], tensor_maps_out=[TensorMap(mnist_label, (10,), categorical)], tensor_maps_protected=[], tensormap_prefix='ml4h.tensormap.mnist', tensors='./mnist_hd5s/', tensors_name='Tensors', tensors_source=None, test_csv=None, test_ratio=0.1, test_steps=64, text_file=None, text_one_hot=False, text_window=32, time_frequency='3M', time_tensor='partners_ecg_datetime', train_csv=None, training_steps=72, tsv_style='standard', u_connect=defaultdict(, {}), valid_csv=None, valid_ratio=0.2, validation_steps=18, window_name=None, write_pngs=False, x=256, xml_field_ids=['20205', '6025'], xml_folder='/mnt/disks/ecg-rest-xml/', y=256, z=48, zip_folder='/mnt/disks/sax-mri-zip/', zoom_height=96, zoom_width=96, zoom_x=50, zoom_y=35)\n", - "\n", - "2020-09-18 16:17:13,411 - tensor_generators:661 - INFO - Found 0 train, 0 validation, and 0 testing tensors at: ./mnist_hd5s/\n" - ] - }, - { - "ename": "ValueError", - "evalue": "Not enough tensors at ./mnist_hd5s/\nFound 0 training, 0 validation, and 0 
testing tensors\nDiscarded 0 tensors", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 12\u001b[0m ]\n\u001b[1;32m 13\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparse_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mtrain_multimodal_multitask\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/home/sam/ml/ml4h/recipes.py\u001b[0m in \u001b[0;36mtrain_multimodal_multitask\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtrain_multimodal_multitask\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m \u001b[0mgenerate_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgenerate_valid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgenerate_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtest_train_valid_tensor_generators\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__dict__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmake_multimodal_multitask_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__dict__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m model = train_model_from_generators(\n", - 
"\u001b[0;32m/home/sam/ml/ml4h/tensor_generators.py\u001b[0m in \u001b[0;36mtest_train_valid_tensor_generators\u001b[0;34m(tensor_maps_in, tensor_maps_out, tensor_maps_protected, tensors, batch_size, num_workers, training_steps, validation_steps, cache_size, balance_csvs, keep_paths, keep_paths_test, mixup_alpha, sample_csv, valid_ratio, test_ratio, train_csv, valid_csv, test_csv, siamese, sample_weight, **kwargs)\u001b[0m\n\u001b[1;32m 796\u001b[0m \u001b[0mtrain_csv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrain_csv\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 797\u001b[0m \u001b[0mvalid_csv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mvalid_csv\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 798\u001b[0;31m \u001b[0mtest_csv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtest_csv\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 799\u001b[0m )\n\u001b[1;32m 800\u001b[0m \u001b[0mweights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/home/sam/ml/ml4h/tensor_generators.py\u001b[0m in \u001b[0;36mget_train_valid_test_paths\u001b[0;34m(tensors, sample_csv, valid_ratio, test_ratio, train_csv, valid_csv, test_csv)\u001b[0m\n\u001b[1;32m 663\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_paths\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalid_paths\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_paths\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 664\u001b[0m raise 
ValueError(\n\u001b[0;32m--> 665\u001b[0;31m \u001b[0;34mf'Not enough tensors at {tensors}\\n'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 666\u001b[0m \u001b[0;34mf'Found {len(train_paths)} training, {len(valid_paths)} validation, and {len(test_paths)} testing tensors\\n'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 667\u001b[0m \u001b[0;34mf'Discarded {len(discard_paths)} tensors'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mValueError\u001b[0m: Not enough tensors at ./mnist_hd5s/\nFound 0 training, 0 validation, and 0 testing tensors\nDiscarded 0 tensors" - ] - } - ], + "outputs": [], "source": [ "sys.argv = ['train', \n", " '--tensors', HD5_FOLDER, \n", - " '--tensormap_prefix', 'ml4h.tensormap.mnist',\n", " '--input_tensors', 'mnist_image',\n", " '--output_tensors', 'mnist_label',\n", " '--batch_size', '64',\n", diff --git a/notebooks/mri/mri_cardiac_long_axis_sketch.ipynb b/notebooks/mri/mri_cardiac_long_axis_sketch.ipynb index 26c08c73a..f9f99b4ad 100644 --- a/notebooks/mri/mri_cardiac_long_axis_sketch.ipynb +++ b/notebooks/mri/mri_cardiac_long_axis_sketch.ipynb @@ -67,8 +67,8 @@ "outputs": [], "source": [ "def plot_lax(series, transpose=False, size=18):\n", - " cols = 2\n", - " rows = 25\n", + " cols = 5\n", + " rows = 10\n", " _, axes = plt.subplots(rows, cols, figsize=(size, size))\n", " for dcm in series:\n", " col = (dcm.InstanceNumber-1)%cols\n", @@ -76,7 +76,7 @@ " if transpose:\n", " axes[row, col].imshow(dcm.pixel_array.T)\n", " else:\n", - " axes[row, col].imshow(dcm.pixel_array, cmap='gray')\n", + " axes[row, col].imshow(dcm.pixel_array)\n", " axes[row, col].set_yticklabels([])\n", " axes[row, col].set_xticklabels([])" ] @@ -132,7 +132,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.6.8" } }, "nbformat": 4, diff --git 
a/notebooks/mri/mri_cardiac_short_axis_sketch.ipynb b/notebooks/mri/mri_cardiac_short_axis_sketch.ipynb index 64493df15..20b93d12f 100644 --- a/notebooks/mri/mri_cardiac_short_axis_sketch.ipynb +++ b/notebooks/mri/mri_cardiac_short_axis_sketch.ipynb @@ -32,9 +32,8 @@ "source": [ "!mkdir ./dcm_scratch\n", "!rm ./dcm_scratch/*\n", - "\n", - "!cp /mnt/ml4cvd/projects/bulk/cardiac_mri/2467677_20209_2_0.zip ./dcm_scratch/\n", - "!unzip ./dcm_scratch/2467677_20209_2_0.zip -d ./dcm_scratch/" + "!cp /mnt/ml4cvd/projects/bulk/cardiac_mri/1000387_20209_2_0.zip ./dcm_scratch/\n", + "!unzip ./dcm_scratch/1000387_20209_2_0.zip -d ./dcm_scratch/" ] }, { @@ -155,12 +154,11 @@ " if idx >= sides*sides:\n", " continue\n", " if _is_mitral_valve_segmentation(dcm):\n", - " axes[idx%sides, idx//sides].imshow(dcm.pixel_array, cmap='gray')\n", + " axes[idx%sides, idx//sides].imshow(dcm.pixel_array)\n", " else:\n", " try:\n", " overlay, anatomical_mask, ventricle_pixels = _get_overlay_from_dicom(dcm)\n", - " #axes[idx%sides, idx//sides].imshow(np.ma.masked_where(anatomical_mask == 2, dcm.pixel_array), cmap='gray')\n", - " axes[idx%sides, idx//sides].imshow(dcm.pixel_array, cmap='gray')\n", + " axes[idx%sides, idx//sides].imshow(np.ma.masked_where(anatomical_mask == 2, dcm.pixel_array))\n", " except KeyError:\n", " print(f'Could not get overlay at {dcm.InstanceNumber}, angle {s}')\n", " axes[idx, idx//sides].imshow(dcm.pixel_array)\n", @@ -174,7 +172,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot_b_series(series[4], sides=2)" + "plot_b_series(series[2], sides=7)" ] }, { @@ -263,7 +261,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.6.8" } }, "nbformat": 4, diff --git a/notebooks/review_results/identify_a_sample_to_review.ipynb b/notebooks/review_results/identify_a_sample_to_review.ipynb index a5c659d24..f26d76309 100644 --- a/notebooks/review_results/identify_a_sample_to_review.ipynb +++ 
b/notebooks/review_results/identify_a_sample_to_review.ipynb @@ -16,7 +16,7 @@ "
\n", " This notebook assumes:\n", "
    \n", - "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200729_091732.
  • \n", + "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
  • \n", "
  • ml4h is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
  • \n", "
\n", "
" @@ -32,7 +32,7 @@ "from ml4h.runtime_data_defines import determine_runtime\n", "from ml4h.runtime_data_defines import Runtime\n", "\n", - "if Runtime.ml4h_VM == determine_runtime():\n", + "if Runtime.ML4H_VM == determine_runtime():\n", " !pip3 install --user --upgrade pandas_gbq pyarrow\n", " # Be sure to restart the kernel if pip installs anything." ] @@ -259,7 +259,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.8" }, "toc": { "base_numbering": 1, diff --git a/notebooks/review_results/review_one_sample.ipynb b/notebooks/review_results/review_one_sample.ipynb index 9fb1fc2b2..9569380aa 100644 --- a/notebooks/review_results/review_one_sample.ipynb +++ b/notebooks/review_results/review_one_sample.ipynb @@ -16,7 +16,7 @@ "
\n", " This notebook assumes\n", "
    \n", - "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200729_091732.
  • \n", + "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
  • \n", "
  • ml4h is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
  • \n", "
\n", "
" @@ -41,7 +41,7 @@ "from ml4h.runtime_data_defines import determine_runtime\n", "from ml4h.runtime_data_defines import Runtime\n", "\n", - "if Runtime.ml4h_VM == determine_runtime():\n", + "if Runtime.ML4H_VM == determine_runtime():\n", " !pip3 install --user --upgrade pandas_gbq pyarrow\n", " # Be sure to restart the kernel if pip installs anything." ] @@ -523,7 +523,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## MRI dicom visualization" + "## MRI DICOM visualization" ] }, { @@ -641,7 +641,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.8" }, "toc": { "base_numbering": 1, diff --git a/notebooks/review_results/test_error_handling_for_notebook_visualizations.ipynb b/notebooks/review_results/test_error_handling_for_notebook_visualizations.ipynb index 0c4a1c97a..97eee7d12 100644 --- a/notebooks/review_results/test_error_handling_for_notebook_visualizations.ipynb +++ b/notebooks/review_results/test_error_handling_for_notebook_visualizations.ipynb @@ -21,7 +21,7 @@ "metadata": {}, "source": [ "
\n", - " Terra Users test with the most recent custom Docker image which has all the software dependencies preinstalled. (e.g., more recent than gcr.io/uk-biobank-sek-data/ml4h_terra:20200729_091732)\n", + " Terra Users test with the most recent custom Docker image which has all the software dependencies preinstalled. (e.g., more recent than gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608)\n", "
" ] }, @@ -41,7 +41,9 @@ "from ml4h.visualization_tools.annotation_storage import BigQueryAnnotationStorage\n", "\n", "import pandas as pd\n", - "import tensorflow as tf" + "import tensorflow as tf\n", + "\n", + "%matplotlib inline" ] }, { @@ -62,7 +64,7 @@ "outputs": [], "source": [ "#---[ EDIT THIS VARIABLE VALUE IF YOU LIKE ]---\n", - "MODEL_RESULTS_FILE = 'gs://uk-biobank-sek-data-us-east1/phenotypes/ml4h/ukbiobank_query_results_plus_four_fake_samples.csv'" + "MODEL_RESULTS_FILE = 'gs://uk-biobank-sek-data-us-east1/phenotypes/ml4cvd/ukbiobank_query_results_plus_four_fake_samples.csv'" ] }, { @@ -255,7 +257,7 @@ "metadata": {}, "outputs": [], "source": [ - "dicom_interactive_plots.choose_mri(sample_id=SAMPLE_TO_REVIEW)" + "#dicom_interactive_plots.choose_mri(sample_id=SAMPLE_TO_REVIEW)" ] }, { @@ -264,7 +266,7 @@ "metadata": {}, "outputs": [], "source": [ - "dicom_plots.choose_cardiac_mri(sample_id=SAMPLE_TO_REVIEW)" + "#dicom_plots.choose_cardiac_mri(sample_id=SAMPLE_TO_REVIEW)" ] }, { @@ -281,7 +283,7 @@ "outputs": [], "source": [ "SAMPLE_TO_REVIEW = 5993648\n", - "folder = 'gs://broad-ml4cvd-vcm/'" + "folder = 'gs://deflaux-test-001/'" ] }, { @@ -326,7 +328,7 @@ "metadata": {}, "outputs": [], "source": [ - "dicom_interactive_plots.choose_mri(sample_id=SAMPLE_TO_REVIEW, folder=folder)" + "#dicom_interactive_plots.choose_mri(sample_id=SAMPLE_TO_REVIEW, folder=folder)" ] }, { @@ -335,7 +337,7 @@ "metadata": {}, "outputs": [], "source": [ - "dicom_plots.choose_cardiac_mri(sample_id=SAMPLE_TO_REVIEW, folder=folder)" + "#dicom_plots.choose_cardiac_mri(sample_id=SAMPLE_TO_REVIEW, folder=folder)" ] }, { @@ -458,24 +460,6 @@ "hd5_mri_plots.choose_mri_tmap(sample_id=SAMPLE_TO_REVIEW)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dicom_interactive_plots.choose_mri(sample_id=SAMPLE_TO_REVIEW)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - 
"dicom_plots.choose_cardiac_mri(sample_id=SAMPLE_TO_REVIEW)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -631,7 +615,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.8" }, "toc": { "base_numbering": 1, @@ -646,7 +630,7 @@ "height": "calc(100% - 180px)", "left": "10px", "top": "150px", - "width": "342.756px" + "width": "197.756px" }, "toc_section_display": true, "toc_window_display": true diff --git a/notebooks/terra_featured_workspace/identify_a_sample_to_review_interactive.ipynb b/notebooks/terra_featured_workspace/identify_a_sample_to_review_interactive.ipynb index e166223b1..e17eb08e2 100644 --- a/notebooks/terra_featured_workspace/identify_a_sample_to_review_interactive.ipynb +++ b/notebooks/terra_featured_workspace/identify_a_sample_to_review_interactive.ipynb @@ -20,7 +20,7 @@ "
\n", " This notebook assumes:\n", "
    \n", - "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200729_091732.
  • \n", + "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
  • \n", "
  • ml4h is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
  • \n", "
\n", "
" @@ -79,7 +79,7 @@ "source": [ "#---[ EDIT THIS VARIABLE VALUE IF YOU LIKE ]---\n", "# TODO(paolo and team): provide CSV with phenotypes and ML results for fake samples.\n", - "MODEL_RESULTS_FILE = 'gs://uk-biobank-sek-data-us-east1/phenotypes/ml4h/ukbiobank_query_results_plus_four_fake_samples.csv'" + "MODEL_RESULTS_FILE = 'gs://uk-biobank-sek-data-us-east1/phenotypes/ml4cvd/ukbiobank_query_results_plus_four_fake_samples.csv'" ] }, { diff --git a/notebooks/terra_featured_workspace/review_one_sample_interactive.ipynb b/notebooks/terra_featured_workspace/review_one_sample_interactive.ipynb index 44379f740..8e863a7e4 100644 --- a/notebooks/terra_featured_workspace/review_one_sample_interactive.ipynb +++ b/notebooks/terra_featured_workspace/review_one_sample_interactive.ipynb @@ -18,7 +18,7 @@ "
\n", " This notebook assumes\n", "
    \n", - "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200729_091732.
  • \n", + "
  • Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
  • \n", "
  • ml4h is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
  • \n", "
\n", "
" @@ -84,7 +84,7 @@ "source": [ "#---[ EDIT THIS VARIABLE VALUE IF YOU LIKE ]---\n", "# TODO(paolo and team): provide CSV with phenotypes and ML results for fake samples.\n", - "MODEL_RESULTS_FILE = 'gs://uk-biobank-sek-data-us-east1/phenotypes/ml4h/ukbiobank_query_results_plus_four_fake_samples.csv'" + "MODEL_RESULTS_FILE = 'gs://uk-biobank-sek-data-us-east1/phenotypes/ml4cvd/ukbiobank_query_results_plus_four_fake_samples.csv'" ] }, { From 509b32ff1e9e64f16cef4dfc5a6dbf30f511c76c Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Wed, 30 Sep 2020 06:41:16 -0400 Subject: [PATCH 06/21] paired --- tests/conftest.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index b553cf848..b288c57bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,8 +7,9 @@ from ml4h.test_utils import build_hdf5s -def pytest_configure(): +def pytest_configure(config): pytest.N_TENSORS = 50 + config.addinivalue_line("markers", "env(slow): mark tests as slow") #@mock.patch.dict(TMAPS, MOCK_TMAPS) From cf7a04278cecc8633b5307314988d9a9bdf4b055 Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Wed, 30 Sep 2020 06:44:51 -0400 Subject: [PATCH 07/21] paired --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index b288c57bb..1b5cfde2e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ def pytest_configure(config): pytest.N_TENSORS = 50 - config.addinivalue_line("markers", "env(slow): mark tests as slow") + config.addinivalue_line("markers", "slow: mark tests as slow") #@mock.patch.dict(TMAPS, MOCK_TMAPS) From a145dac7622a2966a4b97953bc8038132580e3cf Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Wed, 30 Sep 2020 06:58:00 -0400 Subject: [PATCH 08/21] paired --- tests/test_models.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/test_models.py b/tests/test_models.py index 
976a6b846..e3195df1d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -349,6 +349,42 @@ def test_paired_models(self, pairs, tmpdir): **params ) + @pytest.mark.parametrize( + 'pairs', + [ + [(CONTINUOUS_TMAPS[2], CONTINUOUS_TMAPS[1])], + [(CATEGORICAL_TMAPS[2], CATEGORICAL_TMAPS[1])], + [(CONTINUOUS_TMAPS[2], CONTINUOUS_TMAPS[1]), (CONTINUOUS_TMAPS[2], CATEGORICAL_TMAPS[3])] + ], + ) + @pytest.mark.parametrize( + 'output_tmaps', + [ + [CONTINUOUS_TMAPS[0]], + [CATEGORICAL_TMAPS[0]], + [CONTINUOUS_TMAPS[0], CATEGORICAL_TMAPS[0]], + ], + ) + def test_semi_supervised_paired_models(self, pairs, output_tmaps, tmpdir): + params = DEFAULT_PARAMS.copy() + pair_list = list(set([p[0] for p in pairs] + [p[1] for p in pairs])) + params['u_connect'] = {tm: [] for tm in pair_list} + m, encoders, decoders = make_paired_autoencoder_model( + pairs=pairs, + tensor_maps_in=pair_list, + tensor_maps_out=pair_list+output_tmaps, + **params + ) + assert_model_trains(pair_list, pair_list+output_tmaps, m, skip_shape_check=True) + m.save(os.path.join(tmpdir, 'paired_ae.h5')) + path = os.path.join(tmpdir, f'm{MODEL_EXT}') + m.save(path) + make_paired_autoencoder_model( + pairs=pairs, + tensor_maps_in=pair_list, + tensor_maps_out=pair_list+output_tmaps, + **params + ) @pytest.mark.parametrize( 'tmaps', From 9010e1bf4d25ca8f8910256427951333b41e8cb2 Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Thu, 1 Oct 2020 07:38:51 -0400 Subject: [PATCH 09/21] paired --- ml4h/models.py | 1 - ml4h/plots.py | 8 +++----- ml4h/recipes.py | 3 +-- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/ml4h/models.py b/ml4h/models.py index 3eb53ccbd..ac2f7a350 100755 --- a/ml4h/models.py +++ b/ml4h/models.py @@ -1206,7 +1206,6 @@ def make_paired_autoencoder_model( multimodal_activation = Concatenate()(multimodal_activations) multimodal_activation = Dense(units=kwargs['dense_layers'][0])(multimodal_activation) - #multimodal_activation = 
_activation_layer(kwargs['activation'])(multimodal_activation) latent_inputs = Input(shape=(kwargs['dense_layers'][0]), name='input_concept_space') # build decoder models diff --git a/ml4h/plots.py b/ml4h/plots.py index 6b5ad6a78..c5da87886 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -2205,10 +2205,8 @@ def plot_hit_to_miss_transforms(latent_df, decoders, feature='Sex_Female_0_0', p fig, axes = plt.subplots(samples, 2, figsize=(18, samples * 4)) for i in range(samples): axes[i, 0].set_title(f"{feature}: {sexes[i]} ?>== thresh: @@ -2228,4 +2226,4 @@ def plot_hit_to_miss_transforms(latent_df, decoders, feature='Sex_Female_0_0', p figure_path = f'{prefix}/{dtm.name}_{feature}_transform_scalar_{scalar}.png' if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) - plt.savefig(figure_path) \ No newline at end of file + plt.savefig(figure_path) diff --git a/ml4h/recipes.py b/ml4h/recipes.py index 148c2c517..c6b0581a1 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -421,8 +421,7 @@ def train_paired_model(args): reconstruction = decoders[dtm].predict(embed) logging.info(f'{dtm.name} has prediction shape: {reconstruction.shape} from embed shape: {embed.shape}') my_out_path = os.path.join(out_path, f'decoding_{dtm.name}_from_{etm.name}/') - if not os.path.exists(os.path.dirname(my_out_path)): - os.makedirs(os.path.dirname(my_out_path)) + os.makedirs(os.path.dirname(my_out_path), exist_ok=True) plot_reconstruction(dtm, test_data[dtm.input_name()], reconstruction, my_out_path, test_paths, samples) else: y_truth = np.array(test_labels[dtm.output_name()]) From c50a11643b3b0305a470206ce063d50929ae1f31 Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Thu, 15 Oct 2020 07:27:38 -0400 Subject: [PATCH 10/21] lots of fixes --- ml4h/arguments.py | 6 ++-- ml4h/models.py | 84 +++++++++++++++++++++++++++-------------------- ml4h/plots.py | 4 --- ml4h/recipes.py | 20 +++++++---- 4 files changed, 65 insertions(+), 49 deletions(-) diff 
--git a/ml4h/arguments.py b/ml4h/arguments.py index 6692bcc74..3a8e53516 100644 --- a/ml4h/arguments.py +++ b/ml4h/arguments.py @@ -148,6 +148,7 @@ def parse_args(): parser.add_argument('--dense_normalize', default=None, choices=list(NORMALIZATION_CLASSES), help='Type of normalization layer for dense layers.') parser.add_argument('--activation', default='relu', help='Activation function for hidden units in neural nets dense layers.') parser.add_argument('--conv_layers', nargs='*', default=[32], type=int, help='List of number of kernels in convolutional layers.') + parser.add_argument('--conv_width', default=[71], nargs='*', type=int, help='X dimension of convolutional kernel for 1D models. Filter sizes are specified per layer given by conv_layers and per block given by dense_blocks. Filter sizes are repeated if there are less than the number of layers/blocks.') parser.add_argument('--conv_x', default=[3], nargs='*', type=int, help='X dimension of convolutional kernel. Filter sizes are specified per layer given by conv_layers and per block given by dense_blocks. Filter sizes are repeated if there are less than the number of layers/blocks.') parser.add_argument('--conv_y', default=[3], nargs='*', type=int, help='Y dimension of convolutional kernel. Filter sizes are specified per layer given by conv_layers and per block given by dense_blocks. Filter sizes are repeated if there are less than the number of layers/blocks.') parser.add_argument('--conv_z', default=[2], nargs='*', type=int, help='Z dimension of convolutional kernel. Filter sizes are specified per layer given by conv_layers and per block given by dense_blocks. Filter sizes are repeated if there are less than the number of layers/blocks.') @@ -172,7 +173,9 @@ def parse_args(): '--pairs', nargs=2, action='append', help='TensorMap pairs for paired autoencoder. The pair_loss metric will encourage similar embeddings for each two input TensorMap pairs. 
Can be provided multiple times.', ) - parser.add_argument('--aligned_dimension', default=16, type=int, help='Dimensionality of aligned embedded space for multi-modal alignment models.') + parser.add_argument('--pair_loss', default='cosine', help='Distance metric between paired embeddings', choices=['euclid', 'cosine']) + parser.add_argument('--pair_loss_weight', type=float, default=1.0, help='Weight on the pair loss term relative to other losses') + parser.add_argument('--multimodal_merge', default='average', choices=['average', 'concatenate'], help='How to merge modality specific encodings.') parser.add_argument( '--max_parameters', default=9000000, type=int, help='Maximum number of trainable parameters in a model during hyperparameter optimization.', @@ -205,7 +208,6 @@ def parse_args(): parser.add_argument('--validation_steps', default=18, type=int, help='Number of validation batches to examine in an epoch validation.') parser.add_argument('--learning_rate', default=0.0002, type=float, help='Learning rate during training.') parser.add_argument('--mixup_alpha', default=0, type=float, help='If positive apply mixup and sample from a Beta with this value as shape parameter alpha.') - parser.add_argument('--pair_loss', default='cosine', help='Distance metric between paired embeddings', choices=['euclid', 'cosine']) parser.add_argument( '--label_weights', nargs='*', type=float, help='List of per-label weights for weighted categorical cross entropy. 
If provided, must map 1:1 to number of labels.', diff --git a/ml4h/models.py b/ml4h/models.py index ac2f7a350..67996f272 100755 --- a/ml4h/models.py +++ b/ml4h/models.py @@ -525,10 +525,8 @@ def l2_norm(x, axis=None): """ takes an input tensor and returns the l2 norm along specified axis """ - square_sum = K.sum(K.square(x), axis=axis, keepdims=True) norm = K.sqrt(K.maximum(square_sum, K.epsilon())) - return norm @@ -536,12 +534,11 @@ def pairwise_cosine_difference(t1, t2): t1_norm = t1 / l2_norm(t1, axis=-1) t2_norm = t2 / l2_norm(t2, axis=-1) dot = K.clip(K.batch_dot(t1_norm, t2_norm), -1, 1) - return tf.acos(dot) + return K.mean(tf.acos(dot)) class CosineLossLayer(Layer): """Layer that creates an Cosine loss.""" - def __init__(self, weight, **kwargs): super(CosineLossLayer, self).__init__(**kwargs) self.weight = weight @@ -552,15 +549,13 @@ def get_config(self): return config def call(self, inputs): - # We use `add_loss` to create a regularization loss - # that depends on the inputs. + # We use `add_loss` to create a regularization loss that depends on the inputs. 
self.add_loss(self.weight * pairwise_cosine_difference(inputs[0], inputs[1])) return inputs class L2LossLayer(Layer): """Layer that creates an L2 loss.""" - def __init__(self, weight, **kwargs): super(L2LossLayer, self).__init__(**kwargs) self.weight = weight @@ -573,6 +568,7 @@ def get_config(self): def call(self, inputs): self.add_loss(self.weight * tf.reduce_sum(tf.square(inputs[0] - inputs[1]))) return inputs + return inputs class VariationalDiagNormal(Layer): @@ -1150,7 +1146,9 @@ def _make_multimodal_multitask_model( def make_paired_autoencoder_model( pairs: List[Tuple[TensorMap, TensorMap]], - pair_loss='cosine', + pair_loss: str = 'cosine', + pair_loss_weight: float = 1.0, + multimodal_merge: str = 'average', **kwargs ) -> Model: opt = get_optimizer( @@ -1172,8 +1170,7 @@ def make_paired_autoencoder_model( logging.info(f"Loaded model file from: {kwargs['model_file']}") return m, encoders, decoders - inputs = {tm: Input(shape=tm.shape, name=tm.input_name()) for tm in kwargs['tensor_maps_in']} - original_outputs = {tm: 1 for tm in kwargs['tensor_maps_out']} + inputs = {tm.input_name(): Input(shape=tm.shape, name=tm.input_name()) for tm in kwargs['tensor_maps_in']} real_serial_layers = kwargs['model_layers'] kwargs['model_layers'] = None multimodal_activations = [] @@ -1182,30 +1179,40 @@ def make_paired_autoencoder_model( outputs = {} losses = [] for left, right in pairs: - kwargs['tensor_maps_in'] = [left] - left_model = make_multimodal_multitask_model(**kwargs) - encode_left = make_hidden_layer_model(left_model, [left], kwargs['hidden_layer']) - h_left = encode_left(inputs[left]) + if left in encoders: + encode_left = encoders[left] + else: + kwargs['tensor_maps_in'] = [left] + left_model = make_multimodal_multitask_model(**kwargs) + encode_left = make_hidden_layer_model(left_model, [left], kwargs['hidden_layer']) + h_left = encode_left(inputs[left.input_name()]) - kwargs['tensor_maps_in'] = [right] - right_model = make_multimodal_multitask_model(**kwargs) 
- encode_right = make_hidden_layer_model(right_model, [right], kwargs['hidden_layer']) - h_right = encode_right(inputs[right]) + if right in encoders: + encode_right = encoders[right] + else: + kwargs['tensor_maps_in'] = [right] + right_model = make_multimodal_multitask_model(**kwargs) + encode_right = make_hidden_layer_model(right_model, [right], kwargs['hidden_layer']) + h_right = encode_right(inputs[right.input_name()]) if pair_loss == 'cosine': - loss_layer = CosineLossLayer(1.0) + loss_layer = CosineLossLayer(pair_loss_weight) elif pair_loss == 'euclid': - loss_layer = L2LossLayer(1.0) + loss_layer = L2LossLayer(pair_loss_weight) - paired_embeddings = loss_layer([h_left, h_right]) - multimodal_activations.extend(paired_embeddings) - if left not in encoders: - encoders[left] = encode_left - if right not in encoders: - encoders[right] = encode_right + multimodal_activations.extend(loss_layer([h_left, h_right])) + encoders[left] = encode_left + encoders[right] = encode_right - multimodal_activation = Concatenate()(multimodal_activations) - multimodal_activation = Dense(units=kwargs['dense_layers'][0])(multimodal_activation) + kwargs['tensor_maps_in'] = list(inputs.keys()) + if multimodal_merge == 'average': + multimodal_activation = Average()(multimodal_activations) + elif multimodal_merge == 'concatenate': + multimodal_activation = Concatenate()(multimodal_activations) + multimodal_activation = Dense(units=kwargs['dense_layers'][0], use_bias=False)(multimodal_activation) + multimodal_activation = _activation_layer(kwargs['activation'])(multimodal_activation) + else: + raise NotImplementedError(f'No merge architecture for method: {multimodal_merge}') latent_inputs = Input(shape=(kwargs['dense_layers'][0]), name='input_concept_space') # build decoder models @@ -1222,7 +1229,7 @@ def make_paired_autoencoder_model( tensor_map_out=tm, filters_per_dense_block=kwargs['dense_blocks'][::-1], conv_layer_type=kwargs['conv_type'], - conv_x=kwargs['conv_x'], + 
conv_x=kwargs['conv_x'] if tm.axes() > 2 else kwargs['conv_width'], conv_y=kwargs['conv_y'], conv_z=kwargs['conv_z'], block_size=kwargs['block_size'], @@ -1237,25 +1244,30 @@ def make_paired_autoencoder_model( ) reconstruction = decode(restructure(latent_inputs), {}, {}) else: + dense_block = FullyConnectedBlock( + widths=kwargs['dense_layers'], + activation=kwargs['activation'], + normalization=kwargs['dense_normalize'], + regularization=kwargs['dense_regularize'], + regularization_rate=kwargs['dense_regularize_rate'], + is_encoder=False, + ) decode = DenseDecoder(tensor_map_out=tm, parents=tm.parents, activation=kwargs['activation']) - reconstruction = decode(latent_inputs, {}, {}) + reconstruction = decode(dense_block(latent_inputs), {}, {}) decoder = Model(latent_inputs, reconstruction, name=tm.output_name()) decoders[tm] = decoder outputs[tm.output_name()] = decoder(multimodal_activation) losses.append(tm.loss) - kwargs['tensor_maps_out'] = list(original_outputs.keys()) - kwargs['tensor_maps_in'] = list(inputs.keys()) - - m = Model(inputs=list(inputs.values()), outputs=outputs) + m = Model(inputs=list(inputs.values()), outputs=list(outputs.values())) my_metrics = {tm.output_name(): tm.metrics for tm in kwargs['tensor_maps_out']} m.compile(optimizer=opt, loss=losses, metrics=my_metrics) m.summary() if real_serial_layers is not None: - m.load_weights(kwargs['model_layers'], by_name=True) - logging.info(f"Loaded model weights from:{kwargs['model_layers']}") + m.load_weights(real_serial_layers, by_name=True) + logging.info(f"Loaded model weights from:{real_serial_layers}") return m, encoders, decoders diff --git a/ml4h/plots.py b/ml4h/plots.py index c5da87886..d3b61fe89 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -195,10 +195,6 @@ def plot_metric_history(history, training_steps: int, title: str, prefix='./figu if not k.startswith('val_'): if isinstance(history.history[k][0], LearningRateSchedule): history.history[k] = [history.history[k][0](i * 
training_steps) for i in range(len(history.history[k]))] - if len(np.array(history.history[k]).shape) > 1: # Hack for models with paired loss - history.history[k] = np.array(history.history[k])[:, 0, 0] - if 'val_' + k in history.history: - history.history['val_' + k] = np.array(history.history['val_' + k])[:, 0, 0] axes[row, col].plot(history.history[k]) k_split = str(k).replace('output_', '').split('_') k_title = " ".join(OrderedDict.fromkeys(k_split)) diff --git a/ml4h/recipes.py b/ml4h/recipes.py index c6b0581a1..56ea2a00a 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -408,24 +408,30 @@ def train_paired_model(args): decoders[tm].save(f'{args.output_folder}{args.id}/decoder_{tm.name}.h5') out_path = os.path.join(args.output_folder, args.id, 'reconstructions/') test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps) - samples = args.test_steps * args.batch_size + samples = min(args.test_steps * args.batch_size, 12) predictions_list = full_model.predict(test_data) predictions_dict = {name: pred for name, pred in zip(full_model.output_names, predictions_list)} logging.info(f'Predictions and shapes are: {[(p, predictions_dict[p].shape) for p in predictions_dict]}') performance_metrics = {} + for tm in args.tensor_maps_out: + if tm.axes() == 1: + y = predictions_dict[tm.output_name()] + y_truth = np.array(test_labels[tm.output_name()]) + metrics = evaluate_predictions(tm, y, y_truth, {}, tm.name, os.path.join(args.output_folder, args.id), test_paths) + performance_metrics.update(metrics) for i, etm in enumerate(encoders): embed = encoders[etm].predict(test_data[etm.input_name()]) plot_reconstruction(etm, test_data[etm.input_name()], predictions_dict[etm.output_name()], out_path, test_paths, samples) for dtm in decoders: + reconstruction = decoders[dtm].predict(embed) + logging.info(f'{dtm.name} has prediction shape: {reconstruction.shape} from embed shape: {embed.shape}') + my_out_path = os.path.join(out_path, 
f'decoding_{dtm.name}_from_{etm.name}/') + if not os.path.exists(os.path.dirname(my_out_path)): + os.makedirs(os.path.dirname(my_out_path)) if dtm.axes() > 1: - reconstruction = decoders[dtm].predict(embed) - logging.info(f'{dtm.name} has prediction shape: {reconstruction.shape} from embed shape: {embed.shape}') - my_out_path = os.path.join(out_path, f'decoding_{dtm.name}_from_{etm.name}/') - os.makedirs(os.path.dirname(my_out_path), exist_ok=True) plot_reconstruction(dtm, test_data[dtm.input_name()], reconstruction, my_out_path, test_paths, samples) else: - y_truth = np.array(test_labels[dtm.output_name()]) - performance_metrics.update(evaluate_predictions(dtm, decoders[dtm].predict(embed), y_truth, {}, dtm.name, os.path.join(args.output_folder, args.id), test_paths)) + evaluate_predictions(dtm, reconstruction, test_labels[dtm.output_name()], {}, dtm.name, my_out_path, test_paths) return performance_metrics From a7a42703678c1f60fc1e7d70ab5b3e6c8b3c7fd4 Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Thu, 15 Oct 2020 07:31:51 -0400 Subject: [PATCH 11/21] lots of fixes --- scripts/tf.sh | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/scripts/tf.sh b/scripts/tf.sh index 779a0daca..d3adb6b10 100755 --- a/scripts/tf.sh +++ b/scripts/tf.sh @@ -35,19 +35,7 @@ done export GROUP_NAMES GROUP_IDS # Create string to be called in Docker's bash shell via eval; -# this creates a user, adds groups, adds user to groups, then calls the Python script -CALL_DOCKER_AS_USER=" - apt-get -y install sudo; - useradd -u $(id -u) ${USER}; - GROUP_NAMES_ARR=( \${GROUP_NAMES} ); - GROUP_IDS_ARR=( \${GROUP_IDS} ); - for (( i=0; i<\${#GROUP_NAMES_ARR[@]}; ++i )); do - echo \"Creating group\" \${GROUP_NAMES_ARR[i]} \"with gid\" \${GROUP_IDS_ARR[i]}; - groupadd -f -g \${GROUP_IDS_ARR[i]} \${GROUP_NAMES_ARR[i]}; - echo \"Adding user ${USER} to group\" \${GROUP_NAMES_ARR[i]} - usermod -aG \${GROUP_NAMES_ARR[i]} ${USER} - done; - sudo -u 
${USER}" +CALL_DOCKER_AS_USER= ################### HELP TEXT ############################################ @@ -85,7 +73,7 @@ USAGE_MESSAGE ################### OPTION PARSING ####################################### -while getopts ":i:d:m:ctjrhT" opt ; do +while getopts ":i:d:m:ctjuhT" opt ; do case ${opt} in h) usage @@ -113,8 +101,19 @@ while getopts ":i:d:m:ctjrhT" opt ; do mkdir -p /home/${USER}/jupyter/root/ mkdir -p /mnt/ml4cvd/projects/${USER}/projects/jupyter/auto/ ;; - r) # Output owned by root - CALL_DOCKER_AS_USER="" + u) # this creates a user, adds groups, adds user to groups, then calls the Python script + CALL_DOCKER_AS_USER=" + apt-get -y install sudo; + useradd -u $(id -u) ${USER}; + GROUP_NAMES_ARR=( \${GROUP_NAMES} ); + GROUP_IDS_ARR=( \${GROUP_IDS} ); + for (( i=0; i<\${#GROUP_NAMES_ARR[@]}; ++i )); do + echo \"Creating group\" \${GROUP_NAMES_ARR[i]} \"with gid\" \${GROUP_IDS_ARR[i]}; + groupadd -f -g \${GROUP_IDS_ARR[i]} \${GROUP_NAMES_ARR[i]}; + echo \"Adding user ${USER} to group\" \${GROUP_NAMES_ARR[i]} + usermod -aG \${GROUP_NAMES_ARR[i]} ${USER} + done; + sudo -u ${USER}" ;; T) PYTHON_COMMAND=${TEST_COMMAND} From 5258555124ececd4ba25f1d3d4df321c7f6f9297 Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Thu, 15 Oct 2020 07:33:03 -0400 Subject: [PATCH 12/21] lots of fixes --- ml4h/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models.py b/ml4h/models.py index 67996f272..3224faa02 100755 --- a/ml4h/models.py +++ b/ml4h/models.py @@ -25,7 +25,7 @@ from tensorflow.keras.layers import SpatialDropout1D, SpatialDropout2D, SpatialDropout3D, add, concatenate from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization, Activation, Flatten, LSTM, RepeatVector from tensorflow.keras.layers import Conv1D, Conv2D, Conv3D, UpSampling1D, UpSampling2D, UpSampling3D, MaxPooling1D -from tensorflow.keras.layers import MaxPooling2D, MaxPooling3D, AveragePooling1D, AveragePooling2D, AveragePooling3D, Layer +from 
tensorflow.keras.layers import MaxPooling2D, MaxPooling3D, Average, AveragePooling1D, AveragePooling2D, AveragePooling3D, Layer from tensorflow.keras.layers import SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Concatenate, Add from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalAveragePooling3D import tensorflow_probability as tfp From e4a2dc0c9ddeddfeebaa57a7e8cc8dd2c9cb0320 Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Thu, 15 Oct 2020 07:42:50 -0400 Subject: [PATCH 13/21] lots of fixes --- ml4h/plots.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/ml4h/plots.py b/ml4h/plots.py index d3b61fe89..15a1fff55 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -2085,16 +2085,17 @@ def plot_reconstruction( logging.info(f'Plotting {num_samples} reconstructions of {tm}.') if None in tm.shape: # can't handle dynamic shapes return + os.makedirs(os.path.dirname(folder), exist_ok=True) for i in range(num_samples): sample_id = os.path.basename(paths[i]).replace(TENSOR_EXT, '') title = f'{tm.name}_{sample_id}_reconstruction' - y = y_true[i].reshape(tm.shape) - yp = y_pred[i].reshape(tm.shape) + y = y_true[i] + yp = y_pred[i] if tm.axes() == 2: index2channel = {v: k for k, v in tm.channel_map.items()} - fig, axes = plt.subplots(tm.shape[1], 2, figsize=(2 * SUBPLOT_SIZE, 6*SUBPLOT_SIZE)) + fig, axes = plt.subplots(tm.shape[1], 2, figsize=(2 * SUBPLOT_SIZE, 6*SUBPLOT_SIZE)) #, sharey=True) for j in range(tm.shape[1]): - axes[j, 0].plot(y[:, j], c='k', linestyle='--', label='original') + axes[j, 0].plot(y[:, j], c='k', label='original') axes[j, 1].plot(yp[:, j], c='b', label='reconstruction') axes[j, 0].set_title(f'Lead: {index2channel[j]}') axes[j, 0].legend() @@ -2112,13 +2113,13 @@ def plot_reconstruction( for j in range(y.shape[3]): image_path_base = f'{folder}{sample_id}_{tm.name}_{i:03d}_{j:03d}' if tm.is_categorical(): - truth = np.argmax(yp[tm.output_name()][:, :, j, :], axis=-1) + truth = 
np.argmax(yp[:, :, j, :], axis=-1) prediction = np.argmax(y[:, :, j, :], axis=-1) plt.imsave(f'{image_path_base}_truth{IMAGE_EXT}', truth, cmap='plasma') plt.imsave(f'{image_path_base}_prediction{IMAGE_EXT}', prediction, cmap='plasma') else: plt.imsave(f'{image_path_base}_truth{IMAGE_EXT}', y[:, :, j, 0], cmap='gray') - plt.imsave(f'{image_path_base}_prediction{IMAGE_EXT}', yp[:, :, j, :], cmap='gray') + plt.imsave(f'{image_path_base}_prediction{IMAGE_EXT}', yp[:, :, j, 0], cmap='gray') plt.clf() From 08b357abd5f429bafc6b91889edbb1a7b6845817 Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Thu, 15 Oct 2020 07:46:20 -0400 Subject: [PATCH 14/21] lots of fixes --- ml4h/plots.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/ml4h/plots.py b/ml4h/plots.py index 15a1fff55..28ef0d5e8 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -22,6 +22,9 @@ from tensorflow.keras.optimizers.schedules import LearningRateSchedule import matplotlib + +from ml4h.tensor_generators import _sample_csv_to_set + matplotlib.use('Agg') # Need this to write images from the GSA servers. 
Order matters: import matplotlib.pyplot as plt # First import matplotlib, then use Agg, then import plt from matplotlib.ticker import NullFormatter @@ -2181,25 +2184,30 @@ def stratify_latent_space(stratify_column, stratify_thresh, latent_cols, latent_ def plot_hit_to_miss_transforms(latent_df, decoders, feature='Sex_Female_0_0', prefix='./figures/', - thresh=1.0, latent_dimension=256, samples=16, scalar=3.0, cmap='plasma'): + thresh=1.0, latent_dimension=256, samples=16, scalar=3.0, cmap='plasma', test_csv=None): latent_cols = [f'latent_{i}' for i in range(latent_dimension)] female, male = stratify_latent_space(feature, thresh, latent_cols, latent_df) sex_vector = female - male + if test_csv is not None: + sample_ids = [int(s) for s in _sample_csv_to_set(test_csv) if len(s) > 4 and int(s) in latent_df.index] + latent_df = latent_df.loc[sample_ids] + latent_df.info() + logging.info(f'Subset to test set with {len(sample_ids)} samples') + + samples = min(len(latent_df.index), samples) embeddings = latent_df.iloc[:samples][latent_cols].to_numpy() sexes = latent_df.iloc[:samples][feature].to_numpy() - print(f'Embedding shape: {embeddings.shape} sexes shape: {sexes.shape}') + logging.info(f'Embedding shape: {embeddings.shape} sexes shape: {sexes.shape}') sex_vectors = np.tile(sex_vector, (samples, 1)) male_to_female = embeddings + (scalar * sex_vectors) female_to_male = embeddings - (scalar * sex_vectors) - print(f'embeddings shape: {embeddings.shape} features vectors shape: {sex_vectors.shape}') for dtm in decoders: predictions = decoders[dtm].predict(embeddings) m2f = decoders[dtm].predict(male_to_female) f2m = decoders[dtm].predict(female_to_male) - print(f'prediction shape: {predictions.shape}') if dtm.axes() == 3: - fig, axes = plt.subplots(samples, 2, figsize=(18, samples * 4)) + fig, axes = plt.subplots(max(2, samples), 2, figsize=(18, samples * 4)) for i in range(samples): axes[i, 0].set_title(f"{feature}: {sexes[i]} ?>= Date: Thu, 15 Oct 2020 08:04:57 -0400 
Subject: [PATCH 15/21] lots of fixes --- tests/test_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index e3195df1d..5c19868b1 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -23,6 +23,7 @@ 'optimizer': 'adam', 'conv_type': 'conv', 'conv_layers': [6, 5, 3], + 'conv_width': [71]*5, 'conv_x': [3]*5, 'conv_y': [3]*5, 'conv_z': [2]*5, @@ -48,7 +49,7 @@ 'model_layers': None, 'model_file': None, 'hidden_layer': 'embed', - 'u_connect': {}, + 'u_connect': defaultdict(dict), } From 442a87a56901af11d13b6c15fdda1d550f531d04 Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Thu, 15 Oct 2020 08:18:17 -0400 Subject: [PATCH 16/21] lots of fixes --- tests/test_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 5c19868b1..0e4b3c5cb 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -332,7 +332,7 @@ def test_language_models(self, input_output_tmaps, tmpdir): def test_paired_models(self, pairs, tmpdir): params = DEFAULT_PARAMS.copy() pair_list = list(set([p[0] for p in pairs] + [p[1] for p in pairs])) - params['u_connect'] = {tm: [] for tm in pair_list} + params['u_connect'] = {tm.input_name(): [] for tm in pair_list} m, encoders, decoders = make_paired_autoencoder_model( pairs=pairs, tensor_maps_in=pair_list, @@ -369,7 +369,7 @@ def test_paired_models(self, pairs, tmpdir): def test_semi_supervised_paired_models(self, pairs, output_tmaps, tmpdir): params = DEFAULT_PARAMS.copy() pair_list = list(set([p[0] for p in pairs] + [p[1] for p in pairs])) - params['u_connect'] = {tm: [] for tm in pair_list} + params['u_connect'] = {tm.input_name(): [] for tm in pair_list} m, encoders, decoders = make_paired_autoencoder_model( pairs=pairs, tensor_maps_in=pair_list, From 0b76a95fde9169c72a62ac4ce5f08e7153f24ae8 Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Thu, 15 Oct 2020 08:22:45 -0400 Subject: [PATCH 
17/21] lots of fixes --- ml4h/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models.py b/ml4h/models.py index 3224faa02..360afd94c 100755 --- a/ml4h/models.py +++ b/ml4h/models.py @@ -1170,7 +1170,7 @@ def make_paired_autoencoder_model( logging.info(f"Loaded model file from: {kwargs['model_file']}") return m, encoders, decoders - inputs = {tm.input_name(): Input(shape=tm.shape, name=tm.input_name()) for tm in kwargs['tensor_maps_in']} + inputs = {tm: Input(shape=tm.shape, name=tm.input_name()) for tm in kwargs['tensor_maps_in']} real_serial_layers = kwargs['model_layers'] kwargs['model_layers'] = None multimodal_activations = [] From 377a467b6a268f62ce12de9c39faec1bec1cce57 Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Thu, 15 Oct 2020 08:23:37 -0400 Subject: [PATCH 18/21] lots of fixes --- tests/test_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 0e4b3c5cb..5c19868b1 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -332,7 +332,7 @@ def test_language_models(self, input_output_tmaps, tmpdir): def test_paired_models(self, pairs, tmpdir): params = DEFAULT_PARAMS.copy() pair_list = list(set([p[0] for p in pairs] + [p[1] for p in pairs])) - params['u_connect'] = {tm.input_name(): [] for tm in pair_list} + params['u_connect'] = {tm: [] for tm in pair_list} m, encoders, decoders = make_paired_autoencoder_model( pairs=pairs, tensor_maps_in=pair_list, @@ -369,7 +369,7 @@ def test_paired_models(self, pairs, tmpdir): def test_semi_supervised_paired_models(self, pairs, output_tmaps, tmpdir): params = DEFAULT_PARAMS.copy() pair_list = list(set([p[0] for p in pairs] + [p[1] for p in pairs])) - params['u_connect'] = {tm.input_name(): [] for tm in pair_list} + params['u_connect'] = {tm: [] for tm in pair_list} m, encoders, decoders = make_paired_autoencoder_model( pairs=pairs, tensor_maps_in=pair_list, From 
23c438c734ce6294d5f9ac522ff4d191fd8b0093 Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Thu, 15 Oct 2020 08:27:50 -0400 Subject: [PATCH 19/21] lots of fixes --- ml4h/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml4h/models.py b/ml4h/models.py index 360afd94c..cbe0efc57 100755 --- a/ml4h/models.py +++ b/ml4h/models.py @@ -1185,7 +1185,7 @@ def make_paired_autoencoder_model( kwargs['tensor_maps_in'] = [left] left_model = make_multimodal_multitask_model(**kwargs) encode_left = make_hidden_layer_model(left_model, [left], kwargs['hidden_layer']) - h_left = encode_left(inputs[left.input_name()]) + h_left = encode_left(inputs[left]) if right in encoders: encode_right = encoders[right] @@ -1193,7 +1193,7 @@ def make_paired_autoencoder_model( kwargs['tensor_maps_in'] = [right] right_model = make_multimodal_multitask_model(**kwargs) encode_right = make_hidden_layer_model(right_model, [right], kwargs['hidden_layer']) - h_right = encode_right(inputs[right.input_name()]) + h_right = encode_right(inputs[right]) if pair_loss == 'cosine': loss_layer = CosineLossLayer(pair_loss_weight) From 5aeebe8fb3c4eebf6bb41af955d30d5fc5c23fa5 Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Thu, 15 Oct 2020 08:32:49 -0400 Subject: [PATCH 20/21] lots of fixes --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 5c19868b1..c7039b540 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -16,7 +16,7 @@ MEAN_PRECISION_EPS = .02 # how much mean precision degradation is acceptable DEFAULT_PARAMS = { 'activation': 'relu', - 'dense_layers': [4, 2], + 'dense_layers': [4], 'dense_blocks': [5, 3], 'block_size': 3, 'learning_rate': 1e-3, From 0fd4138549fa4873cdfce3cf827af0a954c53205 Mon Sep 17 00:00:00 2001 From: Samwell Freeman Date: Thu, 15 Oct 2020 08:34:55 -0400 Subject: [PATCH 21/21] lots of fixes --- ml4h/models.py | 2 +- tests/test_models.py | 2 +- 2 files 
changed, 2 insertions(+), 2 deletions(-) diff --git a/ml4h/models.py b/ml4h/models.py index cbe0efc57..0fde8d284 100755 --- a/ml4h/models.py +++ b/ml4h/models.py @@ -1213,7 +1213,7 @@ def make_paired_autoencoder_model( multimodal_activation = _activation_layer(kwargs['activation'])(multimodal_activation) else: raise NotImplementedError(f'No merge architecture for method: {multimodal_merge}') - latent_inputs = Input(shape=(kwargs['dense_layers'][0]), name='input_concept_space') + latent_inputs = Input(shape=(kwargs['dense_layers'][-1]), name='input_concept_space') # build decoder models for tm in kwargs['tensor_maps_out']: diff --git a/tests/test_models.py b/tests/test_models.py index c7039b540..5c19868b1 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -16,7 +16,7 @@ MEAN_PRECISION_EPS = .02 # how much mean precision degradation is acceptable DEFAULT_PARAMS = { 'activation': 'relu', - 'dense_layers': [4], + 'dense_layers': [4, 2], 'dense_blocks': [5, 3], 'block_size': 3, 'learning_rate': 1e-3,