'''),
+ Annotation not submitted. Please try again.'''),
)
- submit_button.on_click(cb_on_button_clicked)
+ submit_button.on_click(on_button_clicked)
# Display all the widgets.
display(sample, box1, comment, submit_button, output)
-def _format_annotation(
- sample_id: Union[int, str], key: str, keyvalue: Union[int, float, str], comment: str,
-) -> Dict[str, Any]:
+def _format_annotation(sample_id, key, keyvalue, comment):
"""Helper method to clean and reshape info from the widgets and the environment into a dictionary representing the annotation."""
# Programmatically get the identity of the person running this Terra notebook.
current_user = os.getenv('OWNER_EMAIL')
@@ -140,10 +128,11 @@ def _format_annotation(
if current_user is None:
current_user = socket.gethostname() # By convention, we prefix the hostname with our username.
- value_numeric = None
- value_string = None
# Check whether the value is string or numeric.
- if keyvalue is not None:
+ if keyvalue is None:
+ value_numeric = None
+ value_string = None
+ else:
try:
value_numeric = float(keyvalue) # this will fail if the value is text
value_string = None
diff --git a/ml4h/visualization_tools/batch_image_annotations.py b/ml4h/visualization_tools/batch_image_annotations.py
deleted file mode 100644
index 34ff731df..000000000
--- a/ml4h/visualization_tools/batch_image_annotations.py
+++ /dev/null
@@ -1,236 +0,0 @@
-"""Methods for batch annotations of images stored as 3D tensors, such as MRIs, from within notebooks."""
-
-import json
-import os
-import socket
-import tempfile
-from typing import Any, Dict, List
-
-from IPython.display import display
-import numpy as np
-import pandas as pd
-import h5py
-from ipyannotations import PolygonAnnotator
-import ipywidgets as widgets
-from ml4h.visualization_tools.hd5_mri_plots import MRI_TMAPS
-from ml4h.visualization_tools.annotation_storage import AnnotationStorage
-from ml4h.visualization_tools.annotation_storage import TransientAnnotationStorage
-from PIL import Image
-import tensorflow as tf
-
-
-class BatchImageAnnotator():
- """Annotate batches of images with polygons drawn over regions of interest."""
-
- SUBMIT_BUTTON_DESCRIPTION = 'Submit polygons, goto next sample'
- USE_INSTRUCTIONS = '''
-
-
To draw a polygon, click anywhere you'd like to start. Continue to click
- along the edge of the polygon until arrive back where you started. To
- finish, simply click the first point (highlighted in red). It may be
- helpful to increase the point size if you're struggling (using the slider).
-
-
You can change the class of a polygon using the dropdown menu while the
- polygon is still "open", or unfinished. If you make a mistake, use the Undo
- button until the point that's wrong has disappeared.
-
-
You can move, but not add / subtract polygon points, by clicking the "Edit"
- button. Simply drag a point you want to adjust. Again, if you have
- difficulty aiming at the points, you can increase the point size.
-
-
You can increase or decrease the contrast and brightness of the image
- using the sliders to make it easier to annotate. Sometimes you need to see
- what's behind already-created annotations, and for this purpose you can
- make them more see-through using the "Opacity" slider.
-
- '''
- EXPECTED_COLUMN_NAMES = ['sample_id', 'tmap_name', 'instance_number', 'folder']
- DEFAULT_ANNOTATION_CLASSNAME = 'region_of_interest'
- CSS = '''
-
- '''
-
- def __init__(
- self, samples: pd.DataFrame, annotation_categories: List[str] = None,
- zoom: float = 1.5, annotation_storage: AnnotationStorage = TransientAnnotationStorage(),
- ):
- """Initializes an instance of BatchImageAnnotator.
-
- Args:
- samples: A dataframe of samples to annotate. Columns must include those
- in BatchImageAnnotator.EXPECTED_COLUMN_NAMES.
- annotation_categories: A list of one or more strings to serve as tags for the polygons.
- zoom: Desired zoom level for the image.
- annotation_storage: An instance of AnnotationStorage. This faciltates the use of a user-provided
- strategy for the storage and processing of annotations.
-
- Raises:
- ValueError: The provided dataframe does not contain the expected columns.
- """
- if not set(self.EXPECTED_COLUMN_NAMES).issubset(samples.columns):
- raise ValueError(f'samples Dataframe must contain columns {self.EXPECTED_COLUMN_NAMES}')
- self.samples = samples
- self.current_sample = 0
- # TODO(deflaux) remove this after https://github.com/janfreyberg/ipyannotations/issues/11
- self.zoom = zoom
- self.annotation_storage = annotation_storage
- if annotation_categories is None:
- annotation_categories = [self.DEFAULT_ANNOTATION_CLASSNAME]
-
- self.annotation_widget = PolygonAnnotator(
- options=annotation_categories,
- canvas_size=(900, 280 * self.zoom),
- )
- self.annotation_widget.on_submit(self._store_annotations)
- self.annotation_widget.submit_button.description = self.SUBMIT_BUTTON_DESCRIPTION
- self.annotation_widget.submit_button.layout = widgets.Layout(width='300px')
-
- self.title_widget = widgets.HTML('')
- self.results_widget = widgets.HTML('')
-
- def _store_annotations(self, data: Dict[Any, Any]) -> None:
- """Transfer widget state to the annotation storage and advance to the next sample."""
- if self.current_sample >= self.samples.shape[0]:
- self.results_widget.value = '
Annotation batch complete!
Thank you for making the model better.'
- return
-
- # Convert polygon points in canvas coordinates to tensor coordinates.
- image_canvas_position = self.annotation_widget.canvas.image_extent
- x_offset, y_offset, _, _ = image_canvas_position
- tensor_coords = [
- (
- a['label'],
- [(
- int((p[0] - x_offset) / self.zoom),
- int((p[1] - y_offset) / self.zoom),
- ) for p in a['points']],
- ) for a in data
- ]
- # Store the annotation using the provided annotation storage strategy.
- self.annotation_storage.submit_annotation(
- sample_id=self.samples.loc[self.current_sample, 'sample_id'],
- annotator=os.getenv('OWNER_EMAIL') if os.getenv('OWNER_EMAIL') else socket.gethostname(),
- key=self.samples.loc[self.current_sample, 'tmap_name'],
- value_numeric=self.samples.loc[self.current_sample, 'instance_number'],
- value_string=self.samples.loc[self.current_sample, 'folder'],
- comment=json.dumps(tensor_coords),
- )
-
- # Display this annotation at the bottom of the widget.
- results = f'''
-
-
Prior sample's submitted annotations
- The {self.SUBMIT_BUTTON_DESCRIPTION} button is both printing out the polygons below and storing the polygons
- via strategy {self.annotation_storage.__class__.__name__}.
- Details: {self.annotation_storage.describe()}
-
sample info
- {self._format_info_for_current_sample()}
-
canvas coordinates
- image extent {image_canvas_position}
- {[f'
{json.dumps(x)}
' for x in data]}
-
source tensor coordinates
- {[f'
{json.dumps(x)}
' for x in tensor_coords]}
-
- '''
- self.results_widget.value = results
-
- # Advance to the next sample.
- self.current_sample += 1
- self._annotate_image_for_current_sample()
-
- def _format_info_for_current_sample(self) -> str:
- """Convert information about the current sample to an HTML table for display within the widget."""
- headings = ' '.join([f'
- '''
-
- def _annotate_image_for_current_sample(self) -> None:
- """Retrieve the data for the current sample and display its image in the annotation widget.
-
- If all samples have been processed, display the completion message.
- """
- if self.current_sample >= self.samples.shape[0]:
- self.annotation_widget.canvas.clear()
- # Note: the above command clears the canvas, but any incomplete polygons will be redrawn. Call this
- # private method to clear those. TODO(deflaux) remove this after https://github.com/janfreyberg/ipyannotations/issues/15
- self.annotation_widget.canvas._init_empty_data() # pylint: disable=protected-access
- self.title_widget.value = '
Annotation batch complete!
Thank you for making the model better.'
- return
-
- sample_id = self.samples.loc[self.current_sample, 'sample_id']
- tmap_name = self.samples.loc[self.current_sample, 'tmap_name']
- instance_number = self.samples.loc[self.current_sample, 'instance_number']
- folder = self.samples.loc[self.current_sample, 'folder']
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- sample_hd5 = str(sample_id) + '.hd5'
- local_path = os.path.join(tmpdirname, sample_hd5)
- try:
- tf.io.gfile.copy(src=os.path.join(folder, sample_hd5), dst=local_path)
- hd5 = h5py.File(local_path, mode='r')
- except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e:
- self.annotation_widget.canvas.clear()
- # Note: the above command clears the canvas, but any incomplete polygons will be redrawn. Call this
- # private method to clear those. TODO(deflaux) remove this after https://github.com/janfreyberg/ipyannotations/issues/15
- self.annotation_widget.canvas._init_empty_data() # pylint: disable=protected-access
- self.title_widget.value = f'''
-
-
Warning: MRI HD5 file not available for sample {sample_id} in folder {folder}
- Use the folder parameter to read HD5s from a different local directory or Cloud Storage bucket.
-
{e.message}
-
- '''
- return
-
- tensor = MRI_TMAPS[tmap_name].tensor_from_file(MRI_TMAPS[tmap_name], hd5)
- tensor_instance = tensor[:, :, instance_number]
- if self.zoom > 1.0:
- # TODO(deflaux) remove this after https://github.com/janfreyberg/ipyannotations/issues/11
- img = Image.fromarray(tensor_instance)
- zoomed_img = img.resize([int(self.zoom * s) for s in img.size], Image.LANCZOS)
- tensor_instance = np.asarray(zoomed_img)
-
- self.annotation_widget.display(tensor_instance)
- self.title_widget.value = f'''
- {self.CSS}
-
-
Batch annotation of {self.samples.shape[0]} samples
- {self.USE_INSTRUCTIONS}
-
-
Current sample
- {self._format_info_for_current_sample()}
-
- '''
-
- def annotate_images(self) -> None:
- """Begin the batch annotation task by displaying the annotation widget populated with the first sample.
-
- The submit button is used to proceed to the next sample until all samples have been processed.
- """
- self._annotate_image_for_current_sample()
- display(widgets.VBox([self.title_widget, self.annotation_widget, self.results_widget]))
-
- def view_recent_submissions(self, count: int = 10) -> pd.DataFrame:
- """View a dataframe of up to [count] most recent submissions.
-
- Args:
- count: The number of the most recent submissions to return.
-
- Returns:
- A dataframe of the most recent annotations.
- """
- return self.annotation_storage.view_recent_submissions(count=count)
diff --git a/ml4h/visualization_tools/dicom_interactive_plots.py b/ml4h/visualization_tools/dicom_interactive_plots.py
index ec9d63834..d9850e841 100644
--- a/ml4h/visualization_tools/dicom_interactive_plots.py
+++ b/ml4h/visualization_tools/dicom_interactive_plots.py
@@ -1,4 +1,4 @@
-"""Methods for integration of interactive DICOM plots within notebooks.
+"""Methods for integration of interactive dicom plots within notebooks.
TODO:
* Continue to *pragmatically* improve this to make the visualization controls
@@ -8,15 +8,14 @@
import collections
import os
import tempfile
-from typing import Any, DefaultDict, Dict, Optional, Tuple
import zipfile
from IPython.display import display
from IPython.display import HTML
-import numpy as np
import ipywidgets as widgets
import matplotlib.pyplot as plt
from ml4h.runtime_data_defines import get_mri_folders
+import numpy as np
import pydicom
import tensorflow as tf
@@ -28,12 +27,15 @@
MAX_COLOR_RANGE = 6000
-def choose_mri(sample_id, folder: Optional[str] = None) -> None:
+def choose_mri(sample_id, folder=None):
"""Render widget to choose the MRI to plot.
Args:
sample_id: The id of the sample to retrieve.
folder: The local or Cloud Storage folder under which the files reside.
+
+ Returns:
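+ HTML describing the problem if the MRI files cannot be listed or found; otherwise None (the chooser widgets are displayed as a side effect).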
"""
if folder is None:
folders = get_mri_folders(sample_id)
@@ -43,26 +45,22 @@ def choose_mri(sample_id, folder: Optional[str] = None) -> None:
sample_mris = []
sample_mri_glob = str(sample_id) + '_*.zip'
try:
- for f in folders:
- sample_mris.extend(tf.io.gfile.glob(pattern=os.path.join(f, sample_mri_glob)))
+ for folder in folders:
+ sample_mris.extend(tf.io.gfile.glob(pattern=os.path.join(folder, sample_mri_glob)))
except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e:
- display(
- HTML(f'''
+ return HTML(f'''
+
Warning: MRI not available for sample {sample_id} in {folders}:
{e.message}
Use the folder parameter to read DICOMs from a different local directory or Cloud Storage bucket.
-
'''),
- )
- return
+
''')
if not sample_mris:
- display(
- HTML(f'''
+ return HTML(f'''
+
Warning: MRI DICOMs not available for sample {sample_id} in {folders}.
Use the folder parameter to read DICOMs from a different local directory or Cloud Storage bucket.
-
'''),
- )
- return
+
''')
mri_chooser = widgets.Dropdown(
options=sample_mris,
@@ -79,11 +77,14 @@ def choose_mri(sample_id, folder: Optional[str] = None) -> None:
display(file_controls_ui, file_controls_output)
-def choose_mri_series(sample_mri: str) -> None:
+def choose_mri_series(sample_mri):
"""Render widgets and interactive plots for MRIs.
Args:
sample_mri: The local or Cloud Storage path to the MRI file.
+
+ Returns:
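+ HTML upon a file access error; otherwise None (either a message is printed because no series are available, or the widgets are displayed).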
"""
with tempfile.TemporaryDirectory() as tmpdirname:
local_path = os.path.join(tmpdirname, os.path.basename(sample_mri))
@@ -92,15 +93,13 @@ def choose_mri_series(sample_mri: str) -> None:
with zipfile.ZipFile(local_path, 'r') as zip_ref:
zip_ref.extractall(tmpdirname)
except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e:
- display(
- HTML(f'''
+ return HTML(f'''
+
Warning: Cardiac MRI not available for sample {os.path.basename(sample_mri)}:
{e.message}
-
'''),
- )
- return
+
''')
- unordered_dicoms: DefaultDict[Any, Any] = collections.defaultdict(dict)
+ unordered_dicoms = collections.defaultdict(dict)
for dcm_file in os.listdir(tmpdirname):
if not dcm_file.endswith('.dcm'):
continue
@@ -113,13 +112,8 @@ def choose_mri_series(sample_mri: str) -> None:
unordered_dicoms[key1][key2] = dcm
if not unordered_dicoms:
- display(
- HTML(f'''
- No series available in MRI for sample {os.path.basename(sample_mri)}.
- Try a different MRI.
-
'''),
- )
- return
+ print(f'\n\nNo series available in MRI for sample {os.path.basename(sample_mri)}\n\nTry a different MRI.')
+ return None
# Convert from dict of dicts to dict of ordered lists.
dicoms = {}
@@ -140,7 +134,7 @@ def choose_mri_series(sample_mri: str) -> None:
style={'description_width': 'initial'},
layout=widgets.Layout(width='800px'),
)
- # Slide through DICOM image instances using a slide bar.
+ # Slide through dicom image instances using a slide bar.
instance_chooser = widgets.IntSlider(
continuous_update=True,
value=default_instance_value,
@@ -218,25 +212,25 @@ def on_value_change(change):
display(viz_controls_ui, viz_controls_output)
-def compute_color_range(dicoms: Dict[str, Any], series_name: str) -> Tuple[int, int]:
+def compute_color_range(dicoms, series_name):
"""Compute the mean values for the color ranges of instances in the series."""
vmin = np.mean([np.min(d.pixel_array) for d in dicoms[series_name]])
vmax = np.mean([np.max(d.pixel_array) for d in dicoms[series_name]])
- return (vmin, vmax)
+ return(vmin, vmax)
-def compute_instance_range(dicoms: Dict[str, Any], series_name: str) -> Tuple[int, int]:
+def compute_instance_range(dicoms, series_name):
"""Compute middle and max instances."""
middle_instance = int(len(dicoms[series_name]) / 2)
max_instance = len(dicoms[series_name])
- return (middle_instance, max_instance)
+ return(middle_instance, max_instance)
def dicom_animation(
- dicoms: Dict[str, Any], series_name: str, instance: int, vmin: int, vmax: int, transpose: bool,
- fig_width: int, title_prefix: str = '',
-) -> None:
- """Render one frame of a DICOM animation.
+ dicoms, series_name, instance, vmin, vmax, transpose,
+ fig_width, title_prefix='',
+):
+ """Render one frame of a dicom animation.
Args:
dicoms: the dictionary DICOM series and instances lists
@@ -256,7 +250,7 @@ def dicom_animation(
dcm = dicoms[series_name][instance - 1]
if instance != dcm.InstanceNumber:
# Notice invalid input, but don't throw an error.
- print(f'WARNING: Instance parameter {str(instance)} and instance number {str(dcm.InstanceNumber)} do not match.')
+ print(f'WARNING: Instance parameter {str(instance)} and dicom instance number {str(dcm.InstanceNumber)} do not match.')
if transpose:
height = dcm.pixel_array.T.shape[0]
diff --git a/ml4h/visualization_tools/dicom_plots.py b/ml4h/visualization_tools/dicom_plots.py
index 093691382..ce2b3e083 100644
--- a/ml4h/visualization_tools/dicom_plots.py
+++ b/ml4h/visualization_tools/dicom_plots.py
@@ -1,17 +1,16 @@
-"""Methods for integration of DICOM plots within notebooks."""
+"""Methods for integration of dicom plots within notebooks."""
import collections
import os
import tempfile
-from typing import Dict, List, Optional, Tuple, Union
import zipfile
from IPython.display import display
from IPython.display import HTML
-import numpy as np
import ipywidgets as widgets
import matplotlib.pyplot as plt
from ml4h.runtime_data_defines import get_cardiac_mri_folder
+import numpy as np
import pydicom
from scipy.ndimage.morphology import binary_closing
from scipy.ndimage.morphology import binary_erosion
@@ -28,21 +27,21 @@
MRI_SEGMENTED_CHANNEL_MAP = {'background': 0, 'ventricle': 1, 'myocardium': 2}
-def _is_mitral_valve_segmentation(d: pydicom.FileDataset) -> bool:
- """Determine whether a DICOM has mitral valve segmentation.
+def _is_mitral_valve_segmentation(d): # -> bool:
+ """Determine whether a dicom has mitral valve segmentation.
This is used for visualization of CINE_segmented_SAX_InlineVF.
Args:
- d: the DICOM file
+ d: the dicom file
Returns:
- Whether or not the DICOM has mitral valve segmentation
+ Whether or not the dicom has mitral valve segmentation
"""
return d.SliceThickness == 6
-def _get_overlay_from_dicom(d: pydicom.FileDataset) -> Tuple[int, int, int]:
+def _get_overlay_from_dicom(d):
"""Get an overlay from a DICOM file.
Morphological operators are used to transform the pixel outline of the
@@ -50,7 +49,7 @@ def _get_overlay_from_dicom(d: pydicom.FileDataset) -> Tuple[int, int, int]:
is used for visualization of CINE_segmented_SAX_InlineVF.
Args:
- d: the DICOM file
+ d: the dicom file
Returns:
Raw overlay array with myocardium outline, anatomical mask (a pixel
@@ -78,30 +77,29 @@ def _get_overlay_from_dicom(d: pydicom.FileDataset) -> Tuple[int, int, int]:
byte >>= 1
bit += 1
overlay = overlay[:expected_bit_length]
- if overlay_frames != 1:
- raise ValueError(f'DICOM has {overlay_frames} overlay frames, but only one expected.')
- overlay = overlay.reshape(rows, cols)
- idx = np.where(overlay == 1)
- min_pos = (np.min(idx[0]), np.min(idx[1]))
- max_pos = (np.max(idx[0]), np.max(idx[1]))
- short_side = min((max_pos[0] - min_pos[0]), (max_pos[1] - min_pos[1]))
- small_radius = max(MRI_MIN_RADIUS, short_side * MRI_SMALL_RADIUS_FACTOR)
- big_radius = max(MRI_MIN_RADIUS+1, short_side * MRI_BIG_RADIUS_FACTOR)
- small_structure = _unit_disk(small_radius)
- m1 = binary_closing(overlay, small_structure).astype(np.int)
- big_structure = _unit_disk(big_radius)
- m2 = binary_closing(overlay, big_structure).astype(np.int)
- anatomical_mask = m1 + m2
- ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle'])
- myocardium_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['myocardium'])
- if ventricle_pixels == 0 and myocardium_pixels > MRI_MAX_MYOCARDIUM:
- erode_structure = _unit_disk(small_radius*1.5)
- anatomical_mask = anatomical_mask - binary_erosion(m1, erode_structure).astype(np.int)
+ if overlay_frames == 1:
+ overlay = overlay.reshape(rows, cols)
+ idx = np.where(overlay == 1)
+ min_pos = (np.min(idx[0]), np.min(idx[1]))
+ max_pos = (np.max(idx[0]), np.max(idx[1]))
+ short_side = min((max_pos[0] - min_pos[0]), (max_pos[1] - min_pos[1]))
+ small_radius = max(MRI_MIN_RADIUS, short_side * MRI_SMALL_RADIUS_FACTOR)
+ big_radius = max(MRI_MIN_RADIUS+1, short_side * MRI_BIG_RADIUS_FACTOR)
+ small_structure = _unit_disk(small_radius)
+ m1 = binary_closing(overlay, small_structure).astype(np.int)
+ big_structure = _unit_disk(big_radius)
+ m2 = binary_closing(overlay, big_structure).astype(np.int)
+ anatomical_mask = m1 + m2
ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle'])
- return overlay, anatomical_mask, ventricle_pixels
+ myocardium_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['myocardium'])
+ if ventricle_pixels == 0 and myocardium_pixels > MRI_MAX_MYOCARDIUM:
+ erode_structure = _unit_disk(small_radius*1.5)
+ anatomical_mask = anatomical_mask - binary_erosion(m1, erode_structure).astype(np.int)
+ ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle'])
+ return overlay, anatomical_mask, ventricle_pixels
-def _unit_disk(r: int) -> np.ndarray:
+def _unit_disk(r): # -> np.ndarray:
"""Get the unit disk for a radius.
This is used for visualization of CINE_segmented_SAX_InlineVF.
@@ -116,9 +114,7 @@ def _unit_disk(r: int) -> np.ndarray:
return (x ** 2 + y ** 2 <= r ** 2).astype(np.int)
-def plot_cardiac_long_axis(
- b_series: List[pydicom.FileDataset], sides: int = 7, fig_width: int = 18, title_prefix: str = '',
-) -> None:
+def plot_cardiac_long_axis(b_series, sides=7, fig_width=18, title_prefix=''):
"""Visualize CINE_segmented_SAX_InlineVF series.
Args:
@@ -172,9 +168,9 @@ def plot_cardiac_long_axis(
def plot_cardiac_short_axis(
- series: List[pydicom.FileDataset], transpose: bool = False, fig_width: int = 18,
- title_prefix: str = '',
-) -> None:
+ series, transpose=False, fig_width=18,
+ title_prefix='',
+):
"""Visualize CINE_segmented_LAX series.
Args:
@@ -229,14 +225,14 @@ def plot_cardiac_short_axis(
def plot_mri_series(
- sample_mri: str, dicoms: Dict[str, pydicom.FileDataset], series_name: str, sax_sides: int,
- lax_transpose: bool, fig_width: int,
-) -> None:
+ sample_mri, dicoms, series_name, sax_sides,
+ lax_transpose, fig_width,
+):
"""Visualize the applicable series within this DICOM.
Args:
sample_mri: The local or Cloud Storage path to the MRI file.
- dicoms: A dictionary of DICOMs.
+ dicoms: A dictionary of dicoms.
series_name: The name of the chosen series.
sax_sides: How many sides to display for CINE_segmented_SAX_InlineVF.
lax_transpose: Whether to transpose when plotting CINE_segmented_LAX.
@@ -262,9 +258,10 @@ def plot_mri_series(
)
else:
print(f'Visualization not currently implemented for {series_name}.')
+ return None
-def choose_mri_series(sample_mri: str) -> None:
+def choose_mri_series(sample_mri):
"""Render widgets and plots for cardiac MRIs.
Visualization is supported for CINE_segmented_SAX_InlineVF series and
@@ -272,6 +269,9 @@ def choose_mri_series(sample_mri: str) -> None:
Args:
sample_mri: The local or Cloud Storage path to the MRI file.
+
+ Returns:
+ ipywidget or HTML upon error.
"""
with tempfile.TemporaryDirectory() as tmpdirname:
local_path = os.path.join(tmpdirname, os.path.basename(sample_mri))
@@ -280,13 +280,11 @@ def choose_mri_series(sample_mri: str) -> None:
with zipfile.ZipFile(local_path, 'r') as zip_ref:
zip_ref.extractall(tmpdirname)
except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e:
- display(
- HTML(f'''
+ return HTML(f'''
+
Warning: Cardiac MRI not available for sample {os.path.basename(sample_mri)}:
- Neither CINE_segmented_SAX_InlineVF nor CINE_segmented_LAX available in MRI for sample {os.path.basename(sample_mri)}.
- Try a different MRI.
-
'''),
+ print(
+ f'\n\nNeither CINE_segmented_SAX_InlineVF nor CINE_segmented_LAX available in MRI for sample {os.path.basename(sample_mri)}.',
+ '\n\nTry a different MRI.',
)
+ return None
-def choose_cardiac_mri(sample_id: Union[int, str], folder: Optional[str] = None) -> None:
+def choose_cardiac_mri(sample_id, folder=None):
"""Render widget to choose the cardiac MRI to plot.
Args:
sample_id: The id of the ECG sample to retrieve.
folder: The local or Cloud Storage folder under which the files reside.
+
+ Returns:
+ ipywidget or HTML upon error.
"""
if folder is None:
folder = get_cardiac_mri_folder(sample_id)
@@ -374,23 +374,19 @@ def choose_cardiac_mri(sample_id: Union[int, str], folder: Optional[str] = None)
try:
sample_mris = tf.io.gfile.glob(pattern=os.path.join(folder, sample_mri_glob))
except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e:
- display(
- HTML(f'''
+ return HTML(f'''
+
Warning: Cardiac MRI not available for sample {sample_id} in {folder}:
{e.message}
Use the folder parameter to read DICOMs from a different local directory or Cloud Storage bucket.
-
'''),
- )
- return
+
''')
if not sample_mris:
- display(
- HTML(f'''
+ return HTML(f'''
+
Warning: Cardiac MRI DICOM not available for sample {sample_id} in {folder}.
Use the folder parameter to read DICOMs from a different local directory or Cloud Storage bucket.
-
'''),
- )
- return
+
''')
mri_chooser = widgets.Dropdown(
options=[(os.path.basename(mri), mri) for mri in sample_mris],
diff --git a/ml4h/visualization_tools/ecg_interactive_plots.py b/ml4h/visualization_tools/ecg_interactive_plots.py
index 18ed39a9b..97a4e1547 100644
--- a/ml4h/visualization_tools/ecg_interactive_plots.py
+++ b/ml4h/visualization_tools/ecg_interactive_plots.py
@@ -2,12 +2,10 @@
import os
import tempfile
-from typing import Optional, Union
-from IPython.display import HTML
import altair as alt # Interactive data visualization for plots.
-from ml4h.TensorMap import TensorMap
-from ml4h.visualization_tools.ecg_reshape import DEFAULT_RESTING_ECG_SIGNAL_TMAP
+from IPython.display import HTML
+from ml4h.visualization_tools.ecg_reshape import DEFAULT_RESTING_ECG_SIGNAL_TMAP_NAME
from ml4h.visualization_tools.ecg_reshape import reshape_exercise_ecg_to_tidy
from ml4h.visualization_tools.ecg_reshape import reshape_resting_ecg_to_tidy
@@ -33,21 +31,18 @@
)
-def resting_ecg_interactive_plot(
- sample_id: Union[int, str], folder: Optional[str] = None,
- tmap: TensorMap = DEFAULT_RESTING_ECG_SIGNAL_TMAP,
-) -> Union[HTML, alt.Chart]:
+def resting_ecg_interactive_plot(sample_id, folder=None, tmap_name=DEFAULT_RESTING_ECG_SIGNAL_TMAP_NAME):
"""Wrangle resting ECG data to tidy and present it as an interactive plot.
Args:
sample_id: The id of the ECG sample to retrieve.
folder: The local or Cloud Storage folder under which the files reside.
- tmap: The TensorMap to use for ECG input.
+ tmap_name: The name of the TMAP to use for ECG input.
Returns:
An Altair plot or a notebook-friendly error.
"""
- tidy_resting_ecg_signal = reshape_resting_ecg_to_tidy(sample_id, folder, tmap)
+ tidy_resting_ecg_signal = reshape_resting_ecg_to_tidy(sample_id, folder, tmap_name)
if tidy_resting_ecg_signal.shape[0] == 0:
return HTML(f'''
@@ -90,9 +85,7 @@ def resting_ecg_interactive_plot(
return upper & lower
-def exercise_ecg_interactive_plot(
- sample_id: Union[int, str], folder: Optional[str] = None, time_interval_seconds: int = 10,
-) -> Union[HTML, alt.Chart]:
+def exercise_ecg_interactive_plot(sample_id, folder=None, time_interval_seconds=10):
"""Wrangle exercise ECG data to tidy and present it as an interactive plot.
Args:
@@ -147,8 +140,7 @@ def exercise_ecg_interactive_plot(
lead_select,
).transform_filter(
# https://github.com/altair-viz/altair/issues/1960
- f'''((toNumber({brush.name}.time) - {time_interval_seconds/2.0}) < datum.time)
- && (datum.time < toNumber({brush.name}.time) + {time_interval_seconds/2.0})''',
+ f'((toNumber({brush.name}.time) - {time_interval_seconds/2.0}) < datum.time) && (datum.time < toNumber({brush.name}.time) + {time_interval_seconds/2.0})',
)
return trend.encode(y='heartrate:Q') & trend.encode(y='load:Q') & signal
diff --git a/ml4h/visualization_tools/ecg_reshape.py b/ml4h/visualization_tools/ecg_reshape.py
index 167eb5012..b3213d359 100644
--- a/ml4h/visualization_tools/ecg_reshape.py
+++ b/ml4h/visualization_tools/ecg_reshape.py
@@ -1,57 +1,53 @@
"""Methods for reshaping raw ECG signal data for use in the pandas ecosystem."""
import os
import tempfile
-from typing import Any, Dict, Optional, Tuple, Union
-import numpy as np
-import pandas as pd
from biosppy.signals.tools import filter_signal
import h5py
from ml4h.defines import ECG_BIKE_LEADS
from ml4h.defines import ECG_REST_LEADS
from ml4h.runtime_data_defines import get_exercise_ecg_hd5_folder
from ml4h.runtime_data_defines import get_resting_ecg_hd5_folder
-from ml4h.TensorMap import TensorMap
-import ml4h.tensormap.ukb.ecg as ecg_tmaps
+from ml4h.tensor_maps_by_hand import TMAPS
+import numpy as np
+import pandas as pd
import tensorflow as tf
RAW_SCALE = 0.005 # Convert to mV.
SAMPLING_RATE = 500.0
-DEFAULT_RESTING_ECG_SIGNAL_TMAP = ecg_tmaps.ecg_rest
+DEFAULT_RESTING_ECG_SIGNAL_TMAP_NAME = 'ecg_rest'
# TODO(deflaux): parameterize exercise ECG by TMAP name if there is similar ECG data from other studies.
-EXERCISE_ECG_SIGNAL_TMAP = ecg_tmaps.ecg_bike_raw_full
+EXERCISE_ECG_SIGNAL_TMAP = TMAPS['ecg-bike-raw-full']
EXERCISE_ECG_TREND_TMAPS = [
- ecg_tmaps.ecg_bike_raw_trend_hr,
- ecg_tmaps.ecg_bike_raw_trend_load,
- ecg_tmaps.ecg_bike_raw_trend_grade,
- ecg_tmaps.ecg_bike_raw_trend_artifact,
- ecg_tmaps.ecg_bike_raw_trend_mets,
- ecg_tmaps.ecg_bike_raw_trend_pacecount,
- ecg_tmaps.ecg_bike_raw_trend_phasename,
- ecg_tmaps.ecg_bike_raw_trend_phasetime,
- ecg_tmaps.ecg_bike_raw_trend_time,
- ecg_tmaps.ecg_bike_raw_trend_vecount,
+ TMAPS['ecg-bike-raw-trend-hr'],
+ TMAPS['ecg-bike-raw-trend-load'],
+ TMAPS['ecg-bike-raw-trend-grade'],
+ TMAPS['ecg-bike-raw-trend-artifact'],
+ TMAPS['ecg-bike-raw-trend-mets'],
+ TMAPS['ecg-bike-raw-trend-pacecount'],
+ TMAPS['ecg-bike-raw-trend-phasename'],
+ TMAPS['ecg-bike-raw-trend-phasetime'],
+ TMAPS['ecg-bike-raw-trend-time'],
+ TMAPS['ecg-bike-raw-trend-vecount'],
]
EXERCISE_PHASES = {0.0: 'Pretest', 1.0: 'Exercise', 2.0: 'Recovery'}
-def _examine_available_keys(hd5: Dict[str, Any]) -> None:
+def _examine_available_keys(hd5):
print(f'hd5 ECG keys {[k for k in hd5.keys() if "ecg" in k]}')
for key in [k for k in hd5.keys() if 'ecg' in k]:
- print(f'hd5 {key} keys {k for k in hd5[key]}')
+ print(f'hd5 {key} keys {[k for k in hd5[key].keys()]}')
-def reshape_resting_ecg_to_tidy(
- sample_id: Union[int, str], folder: Optional[str] = None, tmap: TensorMap = DEFAULT_RESTING_ECG_SIGNAL_TMAP,
-) -> pd.DataFrame:
+def reshape_resting_ecg_to_tidy(sample_id, folder=None, tmap_name=DEFAULT_RESTING_ECG_SIGNAL_TMAP_NAME):
"""Wrangle resting ECG data to tidy.
Args:
sample_id: The id of the ECG sample to retrieve.
folder: The local or Cloud Storage folder under which the files reside.
- tmap: The TensorMap to use for ECG input.
+ tmap_name: The name of the TMAP to use for ECG input.
Returns:
A pandas dataframe in tidy format or print a notebook-friendly error and return an empty dataframe.
@@ -59,7 +55,7 @@ def reshape_resting_ecg_to_tidy(
if folder is None:
folder = get_resting_ecg_hd5_folder(sample_id)
- data: Dict[str, Any] = {'lead': [], 'raw': [], 'ts_reference': [], 'filtered': [], 'filtered_1': [], 'filtered_2': []}
+ data = {'lead': [], 'raw': [], 'ts_reference': [], 'filtered': [], 'filtered_1': [], 'filtered_2': []}
with tempfile.TemporaryDirectory() as tmpdirname:
sample_hd5 = str(sample_id) + '.hd5'
@@ -73,10 +69,10 @@ def reshape_resting_ecg_to_tidy(
with h5py.File(local_path, mode='r') as hd5:
try:
- signals = tmap.tensor_from_file(tmap, hd5)
+ signals = TMAPS[tmap_name].tensor_from_file(TMAPS[tmap_name], hd5)
except (KeyError, ValueError) as e:
- print(f'''Warning: Resting ECG TMAP {tmap.name} not available for sample {sample_id}.
- Use the tmap parameter to choose a different TMAP.\n\n{e}''')
+ print(f'''Warning: Resting ECG TMAP {tmap_name} not available for sample {sample_id}.
+ Use the tmap_name parameter to choose a different TMAP.\n\n{e}''')
_examine_available_keys(hd5)
return pd.DataFrame(data)
for (lead, channel) in ECG_REST_LEADS.items():
@@ -140,9 +136,7 @@ def reshape_resting_ecg_to_tidy(
return tidy_signal_df
-def reshape_exercise_ecg_to_tidy(
- sample_id: Union[int, str], folder: Optional[str] = None,
-) -> Tuple[pd.DataFrame, pd.DataFrame]:
+def reshape_exercise_ecg_to_tidy(sample_id, folder=None):
"""Wrangle exercise ECG signal data to tidy format.
Args:
@@ -214,9 +208,7 @@ def reshape_exercise_ecg_to_tidy(
return (trend_df, tidy_signal_df)
-def reshape_exercise_ecg_and_trend_to_tidy(
- sample_id: Union[int, str], folder: Optional[str] = None,
-) -> Tuple[pd.DataFrame, pd.DataFrame]:
+def reshape_exercise_ecg_and_trend_to_tidy(sample_id, folder=None):
"""Wrangle exercise ECG signal and trend data to tidy format.
Args:
diff --git a/ml4h/visualization_tools/ecg_static_plots.py b/ml4h/visualization_tools/ecg_static_plots.py
index ac7283237..2ebcfc3e1 100644
--- a/ml4h/visualization_tools/ecg_static_plots.py
+++ b/ml4h/visualization_tools/ecg_static_plots.py
@@ -1,18 +1,17 @@
"""Methods for integration of static plots within notebooks."""
import os
import tempfile
-from typing import List, Optional, Union
from IPython.display import HTML
from IPython.display import SVG
-import numpy as np
from ml4h.plots import plot_ecg_rest
from ml4h.runtime_data_defines import get_resting_ecg_hd5_folder
from ml4h.runtime_data_defines import get_resting_ecg_svg_folder
+import numpy as np
import tensorflow as tf
-def display_resting_ecg(sample_id: Union[int, str], folder: Optional[str] = None) -> Union[HTML, SVG]:
+def display_resting_ecg(sample_id, folder=None):
"""Retrieve (or render) and display the SVG of the resting ECG.
Args:
@@ -54,8 +53,8 @@ def display_resting_ecg(sample_id: Union[int, str], folder: Optional[str] = None
try:
# We don't need the resulting SVG, so send it to a temporary directory.
with tempfile.TemporaryDirectory() as tmpdirname:
- return plot_ecg_rest(tensor_paths=[local_path], rows=[0], out_folder=tmpdirname, is_blind=False)
- except Exception as e: # pylint: disable=broad-except
+ plot_ecg_rest(tensor_paths = [local_path], rows=[0], out_folder=tmpdirname, is_blind=False)
+ except Exception as e:
return HTML(f'''
Warning: Unable to render static plot of resting ECG for sample {sample_id} from {hd5_folder}:
@@ -63,7 +62,7 @@ def display_resting_ecg(sample_id: Union[int, str], folder: Optional[str] = None
''')
-def major_breaks_x_resting_ecg(limits: List[float]) -> np.array:
+def major_breaks_x_resting_ecg(limits):
"""Method to compute breaks for plotnine plots of ECG resting data.
Args:
diff --git a/ml4h/visualization_tools/facets.py b/ml4h/visualization_tools/facets.py
index 18f96327d..a45ea88da 100644
--- a/ml4h/visualization_tools/facets.py
+++ b/ml4h/visualization_tools/facets.py
@@ -2,7 +2,6 @@
import base64
import os
-import pandas as pd
from facets_overview.generic_feature_statistics_generator import GenericFeatureStatisticsGenerator
FACETS_DEPENDENCIES = {
@@ -26,10 +25,10 @@
FACETS_DEPENDENCIES[dep] = os.path.basename(url)
-class FacetsOverview():
+class FacetsOverview(object):
"""Methods for Facets Overview notebook integration."""
- def __init__(self, data: pd.DataFrame):
+ def __init__(self, data):
# This takes the dataframe and computes all the inputs to the Facets
# Overview plots such as:
# - numeric variables: histogram bins, mean, min, median, max, etc..
@@ -40,7 +39,7 @@ def __init__(self, data: pd.DataFrame):
[{'name': 'data', 'table': data}],
)
- def _repr_html_(self) -> str:
+ def _repr_html_(self):
"""Html representation of Facets Overview for use in a Jupyter notebook."""
protostr = base64.b64encode(self._proto.SerializeToString()).decode('utf-8')
html_template = '''
@@ -58,14 +57,14 @@ def _repr_html_(self) -> str:
return html
-class FacetsDive():
+class FacetsDive(object):
"""Methods for Facets Dive notebook integration."""
- def __init__(self, data: pd.DataFrame, height: int = 1000):
+ def __init__(self, data, height=1000):
self._data = data
self.height = height
- def _repr_html_(self) -> str:
+ def _repr_html_(self):
"""Html representation of Facets Dive for use in a Jupyter notebook."""
html_template = """
diff --git a/ml4h/visualization_tools/hd5_mri_plots.py b/ml4h/visualization_tools/hd5_mri_plots.py
index d3894b39d..20b3305b1 100644
--- a/ml4h/visualization_tools/hd5_mri_plots.py
+++ b/ml4h/visualization_tools/hd5_mri_plots.py
@@ -1,34 +1,29 @@
"""Methods for integration of plots of mri data processed to 3D tensors from within notebooks."""
-from collections import OrderedDict
from enum import Enum, auto
import os
import tempfile
-from typing import Any, Dict, List, Optional, Tuple, Union
+import h5py
from IPython.display import display
from IPython.display import HTML
-import numpy as np
-import h5py
import ipywidgets as widgets
import matplotlib.pyplot as plt
from ml4h.runtime_data_defines import get_mri_hd5_folder
-import ml4h.tensormap.ukb.mri as ukb_mri
-import ml4h.tensormap.ukb.mri_vtk as ukb_mri_vtk
-from ml4h.TensorMap import Interpretation, TensorMap
+from ml4h.tensor_maps_by_hand import TMAPS
+from ml4h.TensorMap import Interpretation
+import numpy as np
import tensorflow as tf
-# Discover applicable TensorMaps.
-MRI_TMAPS = {
- key: value for key, value in ukb_mri.__dict__.items() if isinstance(value, TensorMap)
- and value.interpretation == Interpretation.CONTINUOUS and value.axes() == 3
-}
-MRI_TMAPS.update(
- {
- key: value for key, value in ukb_mri_vtk.__dict__.items()
- if isinstance(value, TensorMap) and value.interpretation == Interpretation.CONTINUOUS and value.axes() == 3
- },
+# Discover applicable TMAPS.
+CARDIAC_MRI_TMAP_NAMES = [k for k in TMAPS.keys() if ('_lax_' in k or '_sax_' in k) and TMAPS[k].axes() == 3]
+CARDIAC_MRI_TMAP_NAMES.extend(
+ [k for k in TMAPS.keys() if TMAPS[k].path_prefix == 'ukb_cardiac_mri' and TMAPS[k].axes() == 3],
)
+LIVER_MRI_TMAP_NAMES = [k for k in TMAPS.keys() if TMAPS[k].path_prefix == 'ukb_liver_mri' and TMAPS[k].axes() == 3]
+BRAIN_MRI_TMAP_NAMES = [k for k in TMAPS.keys() if TMAPS[k].path_prefix == 'ukb_brain_mri' and TMAPS[k].axes() == 3]
+# This includes more than just MRI TMAPS, it is a best effort.
+BEST_EFFORT_MRI_TMAP_NAMES = [k for k in TMAPS.keys() if TMAPS[k].interpretation == Interpretation.CONTINUOUS and TMAPS[k].axes() == 3]
MIN_IMAGE_WIDTH = 8
DEFAULT_IMAGE_WIDTH = 12
@@ -46,30 +41,42 @@ class PlotType(Enum):
class TensorMapCache:
"""Cache the tensor to display for reuse when re-plotting the same TMAP with different plot parameters."""
- def __init__(self, hd5: Dict[str, Any], tmap: TensorMap):
+ def __init__(self, hd5, tmap_name):
self.hd5 = hd5
- self.tmap: Optional[TensorMap] = None
+ self.tmap_name = None
self.tensor = None
- _ = self.get(tmap)
+ _ = self.get(tmap_name)
- def get(self, tmap: TensorMap) -> np.array:
- if self.tmap != tmap:
- self.tensor = tmap.tensor_from_file(tmap, self.hd5)
- self.tmap = tmap
+ def get(self, tmap_name):
+ if self.tmap_name != tmap_name:
+ self.tensor = TMAPS[tmap_name].tensor_from_file(TMAPS[tmap_name], self.hd5)
+ self.tmap_name = tmap_name
return self.tensor
-def choose_mri_tmap(
- sample_id: Union[int, str], folder: Optional[str] = None, tmap: Optional[TensorMap] = None,
- default_tmaps: Dict[str, TensorMap] = MRI_TMAPS,
-) -> None:
+def choose_cardiac_mri_tmap(sample_id, folder=None, tmap_name='cine_lax_4ch_192', default_tmap_names=CARDIAC_MRI_TMAP_NAMES):
+ choose_mri_tmap(sample_id, folder, tmap_name, default_tmap_names)
+
+
+def choose_brain_mri_tmap(sample_id, folder=None, tmap_name='t2_flair_sag_p2_1mm_fs_ellip_pf78_1', default_tmap_names=BRAIN_MRI_TMAP_NAMES):
+ choose_mri_tmap(sample_id, folder, tmap_name, default_tmap_names)
+
+
+def choose_liver_mri_tmap(sample_id, folder=None, tmap_name='liver_shmolli_segmented', default_tmap_names=LIVER_MRI_TMAP_NAMES):
+ choose_mri_tmap(sample_id, folder, tmap_name, default_tmap_names)
+
+
+def choose_mri_tmap(sample_id, folder=None, tmap_name=None, default_tmap_names=BEST_EFFORT_MRI_TMAP_NAMES):
"""Render widgets and plots for MRI tensors.
Args:
sample_id: The id of the sample to retrieve.
folder: The local or Cloud Storage folder under which the files reside.
- tmap: The TensorMap for the 3D MRI tensor to visualize.
- default_tmaps: Other TensorMaps to offer for visualization, if present in the hd5.
+ tmap_name: The TMAP name for the 3D MRI tensor to visualize.
+ default_tmap_names: Other TMAP names to offer for visualization, if present in the hd5.
+
+ Returns:
+ ipywidget or HTML upon error.
"""
if folder is None:
folder = get_mri_hd5_folder(sample_id)
@@ -81,45 +88,42 @@ def choose_mri_tmap(
tf.io.gfile.copy(src=os.path.join(folder, sample_hd5), dst=local_path)
hd5 = h5py.File(local_path, mode='r')
except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e:
- display(
- HTML(f'''
+ return HTML(f'''
+
Warning: MRI HD5 file not available for sample {sample_id} in folder {folder}:
{e.message}
Use the folder parameter to read HD5s from a different local directory or Cloud Storage bucket.
-
'''),
- )
- return
-
- sample_tmaps = OrderedDict()
- # Add the passed tmap parameter, if it is present in this hd5.
- if tmap:
- if tmap.hd5_key_guess() in hd5:
- if len(tmap.shape) == 3:
- sample_tmaps[tmap.name] = tmap
+
''')
+
+ sample_tmap_names = []
+ # Add the passed tmap_name parameter, if it is present in this hd5.
+ if tmap_name:
+ if TMAPS[tmap_name].hd5_key_guess() in hd5:
+ if len(TMAPS[tmap_name].shape) == 3:
+ sample_tmap_names.append(tmap_name)
else:
- print(f'{tmap} is not a 3D tensor, skipping it')
+ print(f'{tmap_name} is not a 3D tensor, skipping it')
else:
- print(f'{tmap} is not available in {sample_id}')
- # Also discover applicable TensorMaps for this particular sample's HD5 file.
- sample_tmaps.update({n: t for n, t in sorted(default_tmaps.items(), key=lambda t: t[0]) if t.hd5_key_guess() in hd5})
-
- if not sample_tmaps:
- display(
- HTML(f'''
- Neither {tmap.name} nor any of {default_tmaps.keys()} are present in this HD5 for sample {sample_id} in {folder}.
- Use the tmap parameter to try a different TensorMap or the folder parameter to try a different hd5 for the sample.
-
'''),
- )
- return
-
- default_tmap_value = next(iter(sample_tmaps.values()))
+ print(f'{tmap_name} is not available in {sample_id}')
+ # Also discover applicable TMAPS for this particular sample's HD5 file.
+ sample_tmap_names.extend(
+ sorted(set([k for k in default_tmap_names if TMAPS[k].hd5_key_guess() in hd5])),
+ )
+
+ if not sample_tmap_names:
+ return HTML(f'''
+ Neither {tmap_name} nor any of {default_tmap_names} are present in this HD5 for sample {sample_id} in {folder}.
+ Use the tmap_name parameter to try a different TMAP or the folder parameter to try a different hd5 for the sample.
+
'),
- tmap_chooser,
+ tmap_name_chooser,
widgets.HBox([transpose_chooser, fig_width_chooser]),
widgets.HBox([flip_chooser, color_range_chooser]),
widgets.HBox([plot_type_chooser, instance_chooser]),
],
layout=widgets.Layout(width='auto', border='solid 1px grey'),
)
- tmap_cache = TensorMapCache(hd5=hd5, tmap=tmap_chooser.value)
+ tmap_cache = TensorMapCache(hd5=hd5, tmap_name=tmap_name_chooser.value)
viz_controls_output = widgets.interactive_output(
plot_mri_tmap,
{
'sample_id': widgets.fixed(sample_id),
'tmap_cache': widgets.fixed(tmap_cache),
- 'tmap': tmap_chooser,
+ 'tmap_name': tmap_name_chooser,
'plot_type': plot_type_chooser,
'instance': instance_chooser,
'color_range': color_range_chooser,
@@ -205,36 +209,33 @@ def on_plot_type_change(change):
else:
instance_chooser.layout.visibility = 'hidden'
- tmap_chooser.observe(on_tmap_value_change, names='value')
+ tmap_name_chooser.observe(on_tmap_value_change, names='value')
plot_type_chooser.observe(on_plot_type_change, names='value')
display(viz_controls_ui, viz_controls_output)
-def compute_color_range(hd5: Dict[str, Any], tmap: TensorMap) -> List[int]:
+def compute_color_range(hd5, tmap_name):
"""Compute the mean values for the color ranges of instances in the MRI series."""
- mri_tensor = tmap.tensor_from_file(tmap, hd5)
+ mri_tensor = TMAPS[tmap_name].tensor_from_file(TMAPS[tmap_name], hd5)
vmin = np.mean([np.min(mri_tensor[:, :, i]) for i in range(0, mri_tensor.shape[2])])
vmax = np.mean([np.max(mri_tensor[:, :, i]) for i in range(0, mri_tensor.shape[2])])
- return [vmin, vmax]
+ return[vmin, vmax]
-def compute_instance_range(tmap: TensorMap) -> Tuple[int, int]:
+def compute_instance_range(tmap_name):
"""Compute middle and max instances."""
- middle_instance = int(tmap.shape[2] / 2)
- max_instance = tmap.shape[2]
- return (middle_instance, max_instance)
+ middle_instance = int(TMAPS[tmap_name].shape[2] / 2)
+ max_instance = TMAPS[tmap_name].shape[2]
+ return(middle_instance, max_instance)
-def plot_mri_tmap(
- sample_id: Union[int, str], tmap_cache: TensorMapCache, tmap: TensorMap, plot_type: PlotType,
- instance: int, color_range: Tuple[int, int], transpose: bool, flip: bool, fig_width: int,
-) -> None:
+def plot_mri_tmap(sample_id, tmap_cache, tmap_name, plot_type, instance, color_range, transpose, flip, fig_width):
"""Visualize the applicable MRI series within this HD5 file.
Args:
sample_id: The local or Cloud Storage path to the MRI file.
tmap_cache: The cache from which to retrieve the tensor to be plotted.
- tmap: The chosen TensorMap for the MRI series.
+ tmap_name: The name of the chosen TMAP for the MRI series.
plot_type: Whether to display instances interactively or in a panel view.
instance: The particular instance to display, if interactive.
color_range: Array of minimum and maximum value for the color range.
@@ -242,9 +243,12 @@ def plot_mri_tmap(
flip: Whether to flip the image on its vertical axis
fig_width: The desired width of the figure. Note that height computed as
the proportion of the width based on the data to be plotted.
+
+ Returns:
+ The plot or a notebook-friendly error message.
"""
- title_prefix = f'{tmap.name} from MRI {sample_id}'
- mri_tensor = tmap_cache.get(tmap)
+ title_prefix = f'{tmap_name} from MRI {sample_id}'
+ mri_tensor = tmap_cache.get(tmap_name)
if plot_type == PlotType.INTERACTIVE:
plot_mri_tensor_as_animation(
mri_tensor=mri_tensor,
@@ -271,13 +275,10 @@ def plot_mri_tmap(
title_prefix=title_prefix,
)
else:
- HTML(f'''
\n",
- " Terra Users test with the most recent custom Docker image which has all the software dependencies preinstalled. (e.g., more recent than gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608)\n",
+ " Terra Users test with the most recent custom Docker image which has all the software dependencies preinstalled. (e.g., more recent than gcr.io/uk-biobank-sek-data/ml4h_terra:20200729_091732)\n",
"
Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
\n",
+ "
Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200729_091732.
\n",
"
ml4h is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
\n",
"
\n",
"
"
@@ -79,7 +79,7 @@
"source": [
"#---[ EDIT THIS VARIABLE VALUE IF YOU LIKE ]---\n",
"# TODO(paolo and team): provide CSV with phenotypes and ML results for fake samples.\n",
- "MODEL_RESULTS_FILE = 'gs://uk-biobank-sek-data-us-east1/phenotypes/ml4cvd/ukbiobank_query_results_plus_four_fake_samples.csv'"
+ "MODEL_RESULTS_FILE = 'gs://uk-biobank-sek-data-us-east1/phenotypes/ml4h/ukbiobank_query_results_plus_four_fake_samples.csv'"
]
},
{
diff --git a/notebooks/terra_featured_workspace/image_annotations_demo.ipynb b/notebooks/terra_featured_workspace/image_annotations_demo.ipynb
deleted file mode 100644
index ce15e4d73..000000000
--- a/notebooks/terra_featured_workspace/image_annotations_demo.ipynb
+++ /dev/null
@@ -1,259 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Image annotations for a batch of samples\n",
- "\n",
- "Using this notebook, cardiologists are able to quickly view and annotate MRI images for a batch of samples. These annotated images become the training data for the next round of modeling."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Setup\n",
- "\n",
- "
\n",
- " This notebook assumes\n",
- "
\n",
- "
Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
\n",
- "
ml4cvd is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
\n",
+ "
Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200729_091732.
\n",
"
ml4h is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
\n",
"
\n",
"
"
@@ -84,7 +84,7 @@
"source": [
"#---[ EDIT THIS VARIABLE VALUE IF YOU LIKE ]---\n",
"# TODO(paolo and team): provide CSV with phenotypes and ML results for fake samples.\n",
- "MODEL_RESULTS_FILE = 'gs://uk-biobank-sek-data-us-east1/phenotypes/ml4cvd/ukbiobank_query_results_plus_four_fake_samples.csv'"
+ "MODEL_RESULTS_FILE = 'gs://uk-biobank-sek-data-us-east1/phenotypes/ml4h/ukbiobank_query_results_plus_four_fake_samples.csv'"
]
},
{
diff --git a/pylintrc b/pylintrc
deleted file mode 100644
index 8a5e40122..000000000
--- a/pylintrc
+++ /dev/null
@@ -1,337 +0,0 @@
-# This configuration was copied from https://github.com/tensorflow/tensorflow/blob/18ebe824d2f6f20b09839cb0a0073032a2d6c5fe/tensorflow/tools/ci_build/pylintrc and then further modified.
-
-[MASTER]
-
-# Specify a configuration file.
-#rcfile=
-
-# Python code to execute, usually for sys.path manipulation such as
-# pygtk.require().
-#init-hook=
-
-# Profiled execution.
-profile=no
-
-# Add files or directories to the denylist. They should be base names, not
-# paths.
-ignore=CVS
-
-# Pickle collected data for later comparisons.
-persistent=yes
-
-# List of plugins (as comma separated values of python modules names) to load,
-# usually to register additional checkers.
-load-plugins=
-
-
-[MESSAGES CONTROL]
-
-# Enable the message, report, category or checker with the given id(s). You can
-# either give multiple identifier separated by comma (,) or put this option
-# multiple time. See also the "--disable" option for examples.
-enable=indexing-exception,old-raise-syntax
-
-# Disable the message, report, category or checker with the given id(s). You
-# can either give multiple identifiers separated by comma (,) or put this
-# option multiple times (only on the command line, not in the configuration
-# file where it should appear only once).You can also use "--disable=all" to
-# disable everything first and then reenable specific checks. For example, if
-# you want to run only the similarities checker, you can use "--disable=all
-# --enable=similarities". If you want to run only the classes checker, but have
-# no Warning level messages displayed, use"--disable=all --enable=classes
-# --disable=W"
-disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager
-
-
-# Set the cache size for astng objects.
-cache-size=500
-
-
-[REPORTS]
-
-# Set the output format. Available formats are text, parseable, colorized, msvs
-# (visual studio) and html. You can also give a reporter class, eg
-# mypackage.mymodule.MyReporterClass.
-output-format=text
-
-# Put messages in a separate file for each module / package specified on the
-# command line instead of printing them on stdout. Reports (if any) will be
-# written in a file name "pylint_global.[txt|html]".
-files-output=no
-
-# Tells whether to display a full report or only the messages
-reports=no
-
-# Python expression which should return a note less than 10 (10 is the highest
-# note). You have access to the variables errors warning, statement which
-# respectively contain the number of errors / warnings messages and the total
-# number of statements analyzed. This is used by the global evaluation report
-# (RP0004).
-evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
-
-# Add a comment according to your evaluation note. This is used by the global
-# evaluation report (RP0004).
-comment=no
-
-# Template used to display messages. This is a python new-style format string
-# used to format the message information. See doc for all details
-#msg-template=
-
-
-[TYPECHECK]
-
-# Tells whether missing members accessed in mixin class should be ignored. A
-# mixin class is detected if its name ends with "mixin" (case insensitive).
-ignore-mixin-members=yes
-
-# List of classes names for which member attributes should not be checked
-# (useful for classes with attributes dynamically set).
-ignored-classes=SQLObject
-
-# When zope mode is activated, add a predefined set of Zope acquired attributes
-# to generated-members.
-zope=no
-
-# List of members which are set dynamically and missed by pylint inference
-# system, and so shouldn't trigger E0201 when accessed. Python regular
-# expressions are accepted.
-generated-members=REQUEST,acl_users,aq_parent
-
-# List of decorators that create context managers from functions, such as
-# contextlib.contextmanager.
-contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
-
-
-[VARIABLES]
-
-# Tells whether we should check for unused import in __init__ files.
-init-import=no
-
-# A regular expression matching the beginning of the name of dummy variables
-# (i.e. not used).
-dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
-
-# List of additional names supposed to be defined in builtins. Remember that
-# you should avoid to define new builtins when possible.
-additional-builtins=
-
-
-[BASIC]
-
-# Required attributes for module, separated by a comma
-required-attributes=
-
-# List of builtins function names that should not be used, separated by a comma
-bad-functions=apply,input,reduce
-
-
-# Disable the report(s) with the given id(s).
-# All non-Google reports are disabled by default.
-disable-report=R0001,R0002,R0003,R0004,R0101,R0102,R0201,R0202,R0220,R0401,R0402,R0701,R0801,R0901,R0902,R0903,R0904,R0911,R0912,R0913,R0914,R0915,R0921,R0922,R0923
-
-# Regular expression which should only match correct module names
-module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
-
-# Regular expression which should only match correct module level names
-const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
-
-# Regular expression which should only match correct class names
-class-rgx=^_?[A-Z][a-zA-Z0-9]*$
-
-# Regular expression which should only match correct function names
-function-rgx=^(?:(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$
-
-# Regular expression which should only match correct method names
-method-rgx=^(?:(?P__[a-z0-9_]+__|next)|(?P_{0,2}[A-Z][a-zA-Z0-9]*)|(?P_{0,2}[a-z][a-z0-9_]*))$
-
-# Regular expression which should only match correct instance attribute names
-attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
-
-# Regular expression which should only match correct argument names
-argument-rgx=^[a-z][a-z0-9_]*$
-
-# Regular expression which should only match correct variable names
-variable-rgx=^[a-z][a-z0-9_]*$
-
-# Regular expression which should only match correct attribute names in class
-# bodies
-class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
-
-# Regular expression which should only match correct list comprehension /
-# generator expression variable names
-inlinevar-rgx=^[a-z][a-z0-9_]*$
-
-# Good variable names which should always be accepted, separated by a comma
-good-names=main,_
-
-# Bad variable names which should always be refused, separated by a comma
-bad-names=
-
-# Regular expression which should only match function or class names that do
-# not require a docstring.
-no-docstring-rgx=(__.*__|main)
-
-# Minimum line length for functions/classes that require docstrings, shorter
-# ones are exempt.
-docstring-min-length=10
-
-
-[FORMAT]
-
-# Maximum number of characters on a single line.
-max-line-length=120
-
-# Regexp for a line that is allowed to be longer than the limit.
-ignore-long-lines=(?x)
- (^\s*(import|from)\s
- |\$Id:\s\/\/depot\/.+#\d+\s\$
- |^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*("[^"]\S+"|'[^']\S+')
- |^\s*\#\ LINT\.ThenChange
- |^[^#]*\#\ type:\ [a-zA-Z_][a-zA-Z0-9_.,[\] ]*$
- |pylint
- |"""
- |\#
- |lambda
- |(https?|ftp):)
-
-# Allow the body of an if to be on the same line as the test if there is no
-# else.
-single-line-if-stmt=y
-
-# List of optional constructs for which whitespace checking is disabled
-no-space-check=
-
-# Maximum number of lines in a module
-max-module-lines=99999
-
-# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
-# tab).
-indent-string=' '
-
-
-[SIMILARITIES]
-
-# Minimum lines number of a similarity.
-min-similarity-lines=4
-
-# Ignore comments when computing similarities.
-ignore-comments=yes
-
-# Ignore docstrings when computing similarities.
-ignore-docstrings=yes
-
-# Ignore imports when computing similarities.
-ignore-imports=no
-
-
-[MISCELLANEOUS]
-
-# List of note tags to take in consideration, separated by a comma.
-notes=
-
-
-[IMPORTS]
-
-# Deprecated modules which should not be used, separated by a comma
-deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets
-
-# Create a graph of every (i.e. internal and external) dependencies in the
-# given file (report RP0402 must not be disabled)
-import-graph=
-
-# Create a graph of external dependencies in the given file (report RP0402 must
-# not be disabled)
-ext-import-graph=
-
-# Create a graph of internal dependencies in the given file (report RP0402 must
-# not be disabled)
-int-import-graph=
-
-
-[CLASSES]
-
-# List of interface methods to ignore, separated by a comma. This is used for
-# instance to not check methods defines in Zope's Interface base class.
-ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by
-
-# List of method names used to declare (i.e. assign) instance attributes.
-defining-attr-methods=__init__,__new__,setUp
-
-# List of valid names for the first argument in a class method.
-valid-classmethod-first-arg=cls,class_
-
-# List of valid names for the first argument in a metaclass class method.
-valid-metaclass-classmethod-first-arg=mcs
-
-
-[DESIGN]
-
-# Maximum number of arguments for function / method
-max-args=5
-
-# Argument names that match this expression will be ignored. Default to name
-# with leading underscore
-ignored-argument-names=_.*
-
-# Maximum number of locals for function / method body
-max-locals=15
-
-# Maximum number of return / yield for function / method body
-max-returns=6
-
-# Maximum number of branch for function / method body
-max-branches=12
-
-# Maximum number of statements in function / method body
-max-statements=50
-
-# Maximum number of parents for a class (see R0901).
-max-parents=7
-
-# Maximum number of attributes for a class (see R0902).
-max-attributes=7
-
-# Minimum number of public methods for a class (see R0903).
-min-public-methods=2
-
-# Maximum number of public methods for a class (see R0904).
-max-public-methods=20
-
-
-[EXCEPTIONS]
-
-# Exceptions that will emit a warning when being caught. Defaults to
-# "Exception"
-overgeneral-exceptions=Exception,StandardError,BaseException
-
-
-[AST]
-
-# Maximum line length for lambdas
-short-func-length=1
-
-# List of module members that should be marked as deprecated.
-# All of the string functions are listed in 4.1.4 Deprecated string functions
-# in the Python 2.4 docs.
-deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc
-
-
-[DOCSTRING]
-
-# List of exceptions that do not need to be mentioned in the Raises section of
-# a docstring.
-ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError
-
-
-
-[TOKENS]
-
-# Number of spaces of indent required when the last token on the preceding line
-# is an open (, [, or {.
-indent-after-paren=4
-
-
-[GOOGLE LINES]
-
-# Regexp for a proper copyright notice.
-copyright=Copyright \d{4} The TensorFlow Authors\. +All [Rr]ights [Rr]eserved\.
diff --git a/scripts/jupyter.sh b/scripts/jupyter.sh
index 32edfb1f3..a83fd582d 100755
--- a/scripts/jupyter.sh
+++ b/scripts/jupyter.sh
@@ -54,7 +54,7 @@ while getopts ":ip:ch" opt ; do
;;
c)
DOCKER_IMAGE=${DOCKER_IMAGE_NO_GPU}
- GPU_DEVICE=""
+ GPU_DEVICE=""
;;
:)
echo "ERROR: Option -${OPTARG} requires an argument." 1>&2
@@ -99,7 +99,6 @@ ${DOCKER_COMMAND} run -it \
${GPU_DEVICE} \
--rm \
--ipc=host \
---hostname=$(hostname) \
-v /home/${USER}/:/home/${USER}/ \
-v /mnt/:/mnt/ \
-p 0.0.0.0:${PORT}:${PORT} \
diff --git a/tests/test_models.py b/tests/test_models.py
index 61df950cb..976a6b846 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -7,7 +7,8 @@
from typing import List, Optional, Dict, Tuple, Iterator
from ml4h.TensorMap import TensorMap
-from ml4h.models import make_multimodal_multitask_model, parent_sort, BottleneckType, ACTIVATION_FUNCTIONS, MODEL_EXT, train_model_from_generators, check_no_bottleneck
+from ml4h.models import make_multimodal_multitask_model, parent_sort, BottleneckType, ACTIVATION_FUNCTIONS, MODEL_EXT, train_model_from_generators, \
+ check_no_bottleneck, make_paired_autoencoder_model
from ml4h.test_utils import TMAPS_UP_TO_4D, MULTIMODAL_UP_TO_4D, CATEGORICAL_TMAPS, CONTINUOUS_TMAPS, SEGMENT_IN, SEGMENT_OUT, PARENT_TMAPS, CYCLE_PARENTS
from ml4h.test_utils import LANGUAGE_TMAP_1HOT_WINDOW, LANGUAGE_TMAP_1HOT_SOFTMAX
@@ -18,14 +19,13 @@
'dense_layers': [4, 2],
'dense_blocks': [5, 3],
'block_size': 3,
- 'conv_width': 3,
'learning_rate': 1e-3,
'optimizer': 'adam',
'conv_type': 'conv',
'conv_layers': [6, 5, 3],
- 'conv_x': [3],
- 'conv_y': [3],
- 'conv_z': [2],
+ 'conv_x': [3]*5,
+ 'conv_y': [3]*5,
+ 'conv_z': [2]*5,
'padding': 'same',
'max_pools': [],
'pool_type': 'max',
@@ -39,6 +39,16 @@
'dense_regularize_rate': .1,
'dense_normalize': 'batch_norm',
'bottleneck_type': BottleneckType.FlattenRestructure,
+ 'pair_loss': 'cosine',
+ 'training_steps': 12,
+ 'learning_rate': 0.00001,
+ 'epochs': 6,
+ 'optimizer': 'adam',
+ 'learning_rate_schedule': None,
+ 'model_layers': None,
+ 'model_file': None,
+ 'hidden_layer': 'embed',
+ 'u_connect': {},
}
@@ -54,19 +64,20 @@ def make_training_data(input_tmaps: List[TensorMap], output_tmaps: List[TensorMa
), ])
-def assert_model_trains(input_tmaps: List[TensorMap], output_tmaps: List[TensorMap], m: Optional[tf.keras.Model] = None):
+def assert_model_trains(input_tmaps: List[TensorMap], output_tmaps: List[TensorMap], m: Optional[tf.keras.Model] = None, skip_shape_check: bool = False):
if m is None:
m = make_multimodal_multitask_model(
input_tmaps,
output_tmaps,
**DEFAULT_PARAMS,
)
- for tmap, tensor in zip(input_tmaps, m.inputs):
- assert tensor.shape[1:] == tmap.shape
- assert tensor.shape[1:] == tmap.shape
- for tmap, tensor in zip(parent_sort(output_tmaps), m.outputs):
- assert tensor.shape[1:] == tmap.shape
- assert tensor.shape[1:] == tmap.shape
+ if not skip_shape_check:
+ for tmap, tensor in zip(input_tmaps, m.inputs):
+ assert tensor.shape[1:] == tmap.shape
+ assert tensor.shape[1:] == tmap.shape
+ for tmap, tensor in zip(parent_sort(output_tmaps), m.outputs):
+ assert tensor.shape[1:] == tmap.shape
+ assert tensor.shape[1:] == tmap.shape
data = make_training_data(input_tmaps, output_tmaps)
history = m.fit(data, steps_per_epoch=2, epochs=2, validation_data=data, validation_steps=2)
for tmap in output_tmaps:
@@ -294,8 +305,8 @@ def test_parents(self, output_tmaps):
def test_language_models(self, input_output_tmaps, tmpdir):
params = DEFAULT_PARAMS.copy()
m = make_multimodal_multitask_model(
- input_output_tmaps[0],
- input_output_tmaps[1],
+ tensor_maps_in=input_output_tmaps[0],
+ tensor_maps_out=input_output_tmaps[1],
**params
)
assert_model_trains(input_output_tmaps[0], input_output_tmaps[1], m)
@@ -309,6 +320,36 @@ def test_language_models(self, input_output_tmaps, tmpdir):
**DEFAULT_PARAMS,
)
+ @pytest.mark.parametrize(
+ 'pairs',
+ [
+ [(CONTINUOUS_TMAPS[2], CONTINUOUS_TMAPS[1])],
+ [(CATEGORICAL_TMAPS[2], CATEGORICAL_TMAPS[1])],
+ [(CONTINUOUS_TMAPS[2], CONTINUOUS_TMAPS[1]), (CONTINUOUS_TMAPS[2], CATEGORICAL_TMAPS[3])]
+ ],
+ )
+ def test_paired_models(self, pairs, tmpdir):
+ params = DEFAULT_PARAMS.copy()
+ pair_list = list(set([p[0] for p in pairs] + [p[1] for p in pairs]))
+ params['u_connect'] = {tm: [] for tm in pair_list}
+ m, encoders, decoders = make_paired_autoencoder_model(
+ pairs=pairs,
+ tensor_maps_in=pair_list,
+ tensor_maps_out=pair_list,
+ **params
+ )
+ assert_model_trains(pair_list, pair_list, m, skip_shape_check=True)
+ m.save(os.path.join(tmpdir, 'paired_ae.h5'))
+ path = os.path.join(tmpdir, f'm{MODEL_EXT}')
+ m.save(path)
+ make_paired_autoencoder_model(
+ pairs=pairs,
+ tensor_maps_in=pair_list,
+ tensor_maps_out=pair_list,
+ **params
+ )
+
+
@pytest.mark.parametrize(
'tmaps',
[_rotate(PARENT_TMAPS, i) for i in range(len(PARENT_TMAPS))],
diff --git a/tests/test_recipes.py b/tests/test_recipes.py
index 468dc44c6..df8f4a1e3 100644
--- a/tests/test_recipes.py
+++ b/tests/test_recipes.py
@@ -3,7 +3,7 @@
import pandas as pd
import numpy as np
-from ml4h.recipes import inference_file_name, hidden_inference_file_name
+from ml4h.recipes import inference_file_name, _hidden_file_name
from ml4h.recipes import train_multimodal_multitask, compare_multimodal_multitask_models
from ml4h.recipes import infer_multimodal_multitask, infer_hidden_layer_multimodal_multitask
from ml4h.recipes import compare_multimodal_scalar_task_models, _find_learning_rate
@@ -42,7 +42,7 @@ def test_infer_genetics(self, default_arguments):
def test_infer_hidden(self, default_arguments):
infer_hidden_layer_multimodal_multitask(default_arguments)
- tsv = hidden_inference_file_name(default_arguments.output_folder, default_arguments.id)
+ tsv = _hidden_file_name(default_arguments.output_folder, default_arguments.id)
inferred = pd.read_csv(tsv, sep='\t')
assert len(set(inferred['sample_id'])) == pytest.N_TENSORS
@@ -50,7 +50,7 @@ def test_infer_hidden_genetics(self, default_arguments):
default_arguments.tsv_style = 'genetics'
infer_hidden_layer_multimodal_multitask(default_arguments)
default_arguments.tsv_style = 'standard'
- tsv = hidden_inference_file_name(default_arguments.output_folder, default_arguments.id)
+ tsv = _hidden_file_name(default_arguments.output_folder, default_arguments.id)
inferred = pd.read_csv(tsv, sep='\t')
assert len(set(inferred['FID'])) == pytest.N_TENSORS
From 46fe794081c3d6b42bc20cd7e97ac1ff1d508dad Mon Sep 17 00:00:00 2001
From: Samwell Freeman
Date: Tue, 29 Sep 2020 17:15:39 -0400
Subject: [PATCH 02/21] paired
---
CONTRIBUTING.md | 192 ++++++++++
.../batch_image_annotations.py | 236 ++++++++++++
.../review_results/image_annotations.ipynb | 268 ++++++++++++++
.../image_annotations_demo.ipynb | 259 ++++++++++++++
pylintrc | 337 ++++++++++++++++++
5 files changed, 1292 insertions(+)
create mode 100644 CONTRIBUTING.md
create mode 100644 ml4h/visualization_tools/batch_image_annotations.py
create mode 100644 notebooks/review_results/image_annotations.ipynb
create mode 100644 notebooks/terra_featured_workspace/image_annotations_demo.ipynb
create mode 100644 pylintrc
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 000000000..2e2006c07
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,192 @@
+# Contributing
+
+1. Before making a substantial pull request, consider first [filing an issue](https://github.com/broadinstitute/ml/issues) describing the feature addition or change you wish to make.
+1. [Get setup](#setup-for-code-contributions)
+1. [Follow the coding style](#python-coding-style)
+1. [Test your code](#testing)
+1. Send a [pull request](https://github.com/broadinstitute/ml/pulls)
+
+## Setup for code contributions
+
+### Get setup for GitHub
+
+Small typos in code or documentation may be edited directly using the GitHub web interface. Otherwise:
+
+1. If you are new to GitHub, don't start here. Instead, work through a GitHub tutorial such as https://guides.github.com/activities/hello-world/.
+1. Create a fork of https://github.com/broadinstitute/ml
+1. Clone your fork.
+1. Work from a feature branch. See the [Appendix](#appendix) for detailed `git` commands.
+
+### Install precommit
+
+[`pre-commit`](https://pre-commit.com/) is a framework for managing and maintaining multi-language pre-commit hooks.
+
+```
+# Install pre-commit
+pip3 install pre-commit
+# Install the git hook scripts by running this within the git clone directory
+cd ${HOME}/ml
+pre-commit install
+```
+
+See [.pre-commit-config.yaml](https://github.com/broadinstitute/ml/blob/master/.pre-commit-config.yaml) for the currently configured pre-commit hooks for ml4cvd.
+
+### Install git-secrets
+
+```git-secrets``` helps us avoid committing secrets (e.g. private keys) and other critical data (e.g. PHI) to our
+repositories. ```git-secrets``` can be obtained via [github](https://github.com/awslabs/git-secrets) or on MacOS can be
+installed with Homebrew by running ```brew install git-secrets```.
+
+To add hooks to all repositories that you initialize or clone in the future:
+
+```git secrets --install --global```
+
+To add hooks to all local repositories:
+
+```
+git secrets --install ~/.git-templates/git-secrets
+git config --global init.templateDir ~/.git-templates/git-secrets
+```
+
+We maintain our own custom "provider" to cover any private keys or other critical data that we would like to avoid
+committing to our repositories. Feel free to add ```egrep```-compatible regular expressions to
+```git_secrets_provider_ml4cvd.txt``` to match types of critical data that are not currently covered by the patterns in that
+file. To register the patterns in this file with ```git-secrets```:
+
+```
+git secrets --add-provider -- cat ${HOME}/ml/git_secrets_provider_ml4cvd.txt
+```
+
+### Install pylint
+
+[`pylint`](https://www.pylint.org/) is a Python static code analysis tool which looks for programming errors, helps enforcing a coding standard, sniffs for code smells and offers simple refactoring suggestions.
+
+```
+# Install pylint
+pip3 install pylint
+```
+
+See [pylintrc](https://github.com/broadinstitute/ml/blob/master/pylintrc) for the current lint configuration for ml4cvd.
+
+# Python coding style
+
+Changes to ml4cvd should conform to [PEP 8 -- Style Guide for Python Code](https://www.python.org/dev/peps/pep-0008/). See also the [Google Python Style Guide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md) for another description of this coding style.
+
+Use `pylint` to check your Python changes:
+
+```bash
+pylint --rcfile=${HOME}/ml/pylintrc myfile.py
+```
+
+Any messages returned by `pylint` are intended to be self-explanatory, but that isn't always the case.
+
+* Search for `pylint <message-name>` or `pylint <message-id>` for more details on the recommended code change to resolve the lint issue.
+* Or add the comment `# pylint: disable=<message-name>` to the end of the line of code.
+
+# Testing
+
+## Testing of `recipes`
+
+Unit tests can be run in Docker with
+```
+${HOME}/ml/scripts/tf.sh -T ${HOME}/ml/tests
+```
+Unit tests can be run locally in a conda environment with
+```
+python -m pytest ${HOME}/ml/tests
+```
+Some of the unit tests are slow due to creating, saving and loading `tensorflow` models.
+To skip those tests to move quickly, run
+```
+python -m pytest ${HOME}/ml/tests -m "not slow"
+```
+pytest can also run specific tests using `::`. For example
+
+```
+python -m pytest ${HOME}/ml/tests/test_models.py::TestMakeMultimodalMultitaskModel::test_u_connect_segment
+```
+
+For more pytest usage information, checkout the [usage guide](https://docs.pytest.org/en/latest/usage.html).
+
+## Testing of `visualization_tools`
+
+The code in [ml4cvd/visualization_tools](https://github.com/broadinstitute/ml/tree/master/ml4cvd/visualization_tools) is primarily interactive so we add test cases to notebook [test_error_handling_for_notebook_visualizations.ipynb](https://github.com/broadinstitute/ml/blob/master/notebooks/review_results/test_error_handling_for_notebook_visualizations.ipynb) and visually inspect the output of `Cells -> Run all`.
+
+# Appendix
+
+For the ml4cvd GitHub repository, we are doing ‘merge and squash’ of pull requests. So that means your fork does not match upstream after your pull request has been merged. The easiest way to manage this is to always work in a feature branch, instead of checking changes into your fork’s master branch.
+
+
+## How to work on a new feature
+
+(1) Get the latest version of the upstream repo
+
+```
+git fetch upstream
+```
+
+Note: If you get an error saying that upstream is unknown, run the following remote add command and then re-run the fetch command. You only need to do this once per git clone.
+
+```
+git remote add upstream https://github.com/broadinstitute/ml.git
+```
+
+(2) Make sure your master branch is “even” with upstream.
+
+```
+git checkout master
+git merge --ff-only upstream/master
+git push
+```
+
+Now the master branch of your fork on GitHub should say *"This branch is even with broadinstitute:master."*.
+
+
+(3) Create a feature branch for your change.
+
+```
+git checkout -b my-feature-branch-name
+```
+
+Because you created this feature branch from your master branch that was up to date with upstream (step 2), your feature branch is also up to date with upstream. Commit your changes to this branch until you are happy with them.
+
+(4) Push your changes to GitHub and send a pull request.
+
+```
+git push --set-upstream origin my-feature-branch-name
+```
+
+After your pull request is merged, it's safe to delete your branch!
+
+## I accidentally checked a new change to my master branch instead of a feature branch. How to fix this?
+
+(1) Soft undo your change(s). This leaves the changes in the files on disk but undoes the commit.
+
+```
+git checkout master
+# Moves pointer back to previous HEAD
+git reset --soft HEAD@{1}
+```
+
+Or if you need to move back several commits to the most recent one in common with upstream, you can change ‘1’ to be however many commits back you need to go.
+
+(2) “stash” your now-unchecked-in changes so that you can get them back later.
+
+```
+git stash
+```
+
+(3) Now do the [How to work on a new feature](#how-to-work-on-a-new-feature) step to bring master up to date and create your new feature branch that is “even” with upstream. Here are those commands again:
+
+```
+git fetch upstream
+git merge --ff-only upstream/master
+git checkout -b my-feature-branch-name
+```
+
+(4) “unstash” your changes.
+
+```
+git stash pop
+```
+Now you can proceed with your work!
diff --git a/ml4h/visualization_tools/batch_image_annotations.py b/ml4h/visualization_tools/batch_image_annotations.py
new file mode 100644
index 000000000..34ff731df
--- /dev/null
+++ b/ml4h/visualization_tools/batch_image_annotations.py
@@ -0,0 +1,236 @@
+"""Methods for batch annotations of images stored as 3D tensors, such as MRIs, from within notebooks."""
+
+import json
+import os
+import socket
+import tempfile
+from typing import Any, Dict, List
+
+from IPython.display import display
+import numpy as np
+import pandas as pd
+import h5py
+from ipyannotations import PolygonAnnotator
+import ipywidgets as widgets
+from ml4h.visualization_tools.hd5_mri_plots import MRI_TMAPS
+from ml4h.visualization_tools.annotation_storage import AnnotationStorage
+from ml4h.visualization_tools.annotation_storage import TransientAnnotationStorage
+from PIL import Image
+import tensorflow as tf
+
+
+class BatchImageAnnotator():
+ """Annotate batches of images with polygons drawn over regions of interest."""
+
+ SUBMIT_BUTTON_DESCRIPTION = 'Submit polygons, goto next sample'
+ USE_INSTRUCTIONS = '''
+
+
To draw a polygon, click anywhere you'd like to start. Continue to click
+ along the edge of the polygon until arrive back where you started. To
+ finish, simply click the first point (highlighted in red). It may be
+ helpful to increase the point size if you're struggling (using the slider).
+
+
You can change the class of a polygon using the dropdown menu while the
+ polygon is still "open", or unfinished. If you make a mistake, use the Undo
+ button until the point that's wrong has disappeared.
+
+
You can move, but not add / subtract polygon points, by clicking the "Edit"
+ button. Simply drag a point you want to adjust. Again, if you have
+ difficulty aiming at the points, you can increase the point size.
+
+
You can increase or decrease the contrast and brightness of the image
+ using the sliders to make it easier to annotate. Sometimes you need to see
+ what's behind already-created annotations, and for this purpose you can
+ make them more see-through using the "Opacity" slider.
+
+ '''
+ EXPECTED_COLUMN_NAMES = ['sample_id', 'tmap_name', 'instance_number', 'folder']
+ DEFAULT_ANNOTATION_CLASSNAME = 'region_of_interest'
+ CSS = '''
+
+ '''
+
+ def __init__(
+ self, samples: pd.DataFrame, annotation_categories: List[str] = None,
+ zoom: float = 1.5, annotation_storage: AnnotationStorage = TransientAnnotationStorage(),
+ ):
+ """Initializes an instance of BatchImageAnnotator.
+
+ Args:
+ samples: A dataframe of samples to annotate. Columns must include those
+ in BatchImageAnnotator.EXPECTED_COLUMN_NAMES.
+ annotation_categories: A list of one or more strings to serve as tags for the polygons.
+ zoom: Desired zoom level for the image.
+ annotation_storage: An instance of AnnotationStorage. This facilitates the use of a user-provided
+ strategy for the storage and processing of annotations.
+
+ Raises:
+ ValueError: The provided dataframe does not contain the expected columns.
+ """
+ if not set(self.EXPECTED_COLUMN_NAMES).issubset(samples.columns):
+ raise ValueError(f'samples Dataframe must contain columns {self.EXPECTED_COLUMN_NAMES}')
+ self.samples = samples
+ self.current_sample = 0
+ # TODO(deflaux) remove this after https://github.com/janfreyberg/ipyannotations/issues/11
+ self.zoom = zoom
+ self.annotation_storage = annotation_storage
+ if annotation_categories is None:
+ annotation_categories = [self.DEFAULT_ANNOTATION_CLASSNAME]
+
+ self.annotation_widget = PolygonAnnotator(
+ options=annotation_categories,
+ canvas_size=(900, 280 * self.zoom),
+ )
+ self.annotation_widget.on_submit(self._store_annotations)
+ self.annotation_widget.submit_button.description = self.SUBMIT_BUTTON_DESCRIPTION
+ self.annotation_widget.submit_button.layout = widgets.Layout(width='300px')
+
+ self.title_widget = widgets.HTML('')
+ self.results_widget = widgets.HTML('')
+
+ def _store_annotations(self, data: Dict[Any, Any]) -> None:
+ """Transfer widget state to the annotation storage and advance to the next sample."""
+ if self.current_sample >= self.samples.shape[0]:
+ self.results_widget.value = '
Annotation batch complete!
Thank you for making the model better.'
+ return
+
+ # Convert polygon points in canvas coordinates to tensor coordinates.
+ image_canvas_position = self.annotation_widget.canvas.image_extent
+ x_offset, y_offset, _, _ = image_canvas_position
+ tensor_coords = [
+ (
+ a['label'],
+ [(
+ int((p[0] - x_offset) / self.zoom),
+ int((p[1] - y_offset) / self.zoom),
+ ) for p in a['points']],
+ ) for a in data
+ ]
+ # Store the annotation using the provided annotation storage strategy.
+ self.annotation_storage.submit_annotation(
+ sample_id=self.samples.loc[self.current_sample, 'sample_id'],
+ annotator=os.getenv('OWNER_EMAIL') if os.getenv('OWNER_EMAIL') else socket.gethostname(),
+ key=self.samples.loc[self.current_sample, 'tmap_name'],
+ value_numeric=self.samples.loc[self.current_sample, 'instance_number'],
+ value_string=self.samples.loc[self.current_sample, 'folder'],
+ comment=json.dumps(tensor_coords),
+ )
+
+ # Display this annotation at the bottom of the widget.
+ results = f'''
+
+
Prior sample's submitted annotations
+ The {self.SUBMIT_BUTTON_DESCRIPTION} button is both printing out the polygons below and storing the polygons
+ via strategy {self.annotation_storage.__class__.__name__}.
+ Details: {self.annotation_storage.describe()}
+
sample info
+ {self._format_info_for_current_sample()}
+
canvas coordinates
+ image extent {image_canvas_position}
+ {[f'
{json.dumps(x)}
' for x in data]}
+
source tensor coordinates
+ {[f'
{json.dumps(x)}
' for x in tensor_coords]}
+
+ '''
+ self.results_widget.value = results
+
+ # Advance to the next sample.
+ self.current_sample += 1
+ self._annotate_image_for_current_sample()
+
+ def _format_info_for_current_sample(self) -> str:
+ """Convert information about the current sample to an HTML table for display within the widget."""
+ headings = ' '.join([f'
+ '''
+
+ def _annotate_image_for_current_sample(self) -> None:
+ """Retrieve the data for the current sample and display its image in the annotation widget.
+
+ If all samples have been processed, display the completion message.
+ """
+ if self.current_sample >= self.samples.shape[0]:
+ self.annotation_widget.canvas.clear()
+ # Note: the above command clears the canvas, but any incomplete polygons will be redrawn. Call this
+ # private method to clear those. TODO(deflaux) remove this after https://github.com/janfreyberg/ipyannotations/issues/15
+ self.annotation_widget.canvas._init_empty_data() # pylint: disable=protected-access
+ self.title_widget.value = '
Annotation batch complete!
Thank you for making the model better.'
+ return
+
+ sample_id = self.samples.loc[self.current_sample, 'sample_id']
+ tmap_name = self.samples.loc[self.current_sample, 'tmap_name']
+ instance_number = self.samples.loc[self.current_sample, 'instance_number']
+ folder = self.samples.loc[self.current_sample, 'folder']
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ sample_hd5 = str(sample_id) + '.hd5'
+ local_path = os.path.join(tmpdirname, sample_hd5)
+ try:
+ tf.io.gfile.copy(src=os.path.join(folder, sample_hd5), dst=local_path)
+ hd5 = h5py.File(local_path, mode='r')
+ except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e:
+ self.annotation_widget.canvas.clear()
+ # Note: the above command clears the canvas, but any incomplete polygons will be redrawn. Call this
+ # private method to clear those. TODO(deflaux) remove this after https://github.com/janfreyberg/ipyannotations/issues/15
+ self.annotation_widget.canvas._init_empty_data() # pylint: disable=protected-access
+ self.title_widget.value = f'''
+
+
Warning: MRI HD5 file not available for sample {sample_id} in folder {folder}
+ Use the folder parameter to read HD5s from a different local directory or Cloud Storage bucket.
+
{e.message}
+
+ '''
+ return
+
+ tensor = MRI_TMAPS[tmap_name].tensor_from_file(MRI_TMAPS[tmap_name], hd5)
+ tensor_instance = tensor[:, :, instance_number]
+ if self.zoom > 1.0:
+ # TODO(deflaux) remove this after https://github.com/janfreyberg/ipyannotations/issues/11
+ img = Image.fromarray(tensor_instance)
+ zoomed_img = img.resize([int(self.zoom * s) for s in img.size], Image.LANCZOS)
+ tensor_instance = np.asarray(zoomed_img)
+
+ self.annotation_widget.display(tensor_instance)
+ self.title_widget.value = f'''
+ {self.CSS}
+
+
Batch annotation of {self.samples.shape[0]} samples
+ {self.USE_INSTRUCTIONS}
+
+
Current sample
+ {self._format_info_for_current_sample()}
+
+ '''
+
+ def annotate_images(self) -> None:
+ """Begin the batch annotation task by displaying the annotation widget populated with the first sample.
+
+ The submit button is used to proceed to the next sample until all samples have been processed.
+ """
+ self._annotate_image_for_current_sample()
+ display(widgets.VBox([self.title_widget, self.annotation_widget, self.results_widget]))
+
+ def view_recent_submissions(self, count: int = 10) -> pd.DataFrame:
+ """View a dataframe of up to [count] most recent submissions.
+
+ Args:
+ count: The number of the most recent submissions to return.
+
+ Returns:
+ A dataframe of the most recent annotations.
+ """
+ return self.annotation_storage.view_recent_submissions(count=count)
diff --git a/notebooks/review_results/image_annotations.ipynb b/notebooks/review_results/image_annotations.ipynb
new file mode 100644
index 000000000..6e644f15a
--- /dev/null
+++ b/notebooks/review_results/image_annotations.ipynb
@@ -0,0 +1,268 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Image annotations for a batch of samples\n",
+ "\n",
+ "Using this notebook, cardiologists are able to quickly view and annotate MRI images for a batch of samples. These annotated images become the training data for the next round of modeling."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Setup\n",
+ "\n",
+ "
\n",
+ " This notebook assumes\n",
+ "
\n",
+ "
Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
\n",
+ "
ml4h is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
\n",
+ " Edit the CSV file path below, if needed, to either a local file or one in Cloud Storage.\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#---[ EDIT AND RUN THIS CELL TO READ FROM A LOCAL FILE OR A FILE IN CLOUD STORAGE ]---\n",
+ "SAMPLE_BATCH_FILE = None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if SAMPLE_BATCH_FILE:\n",
+ " samples_df = pd.read_csv(tf.io.gfile.GFile(SAMPLE_BATCH_FILE))\n",
+ "\n",
+ "else:\n",
+ " # Normally these would all be the same or similar TMAP. We are using different ones here just to make it\n",
+ " # more obvious in this demo that we are processing different samples.\n",
+ " samples_df = pd.DataFrame(\n",
+ " columns=BatchImageAnnotator.EXPECTED_COLUMN_NAMES,\n",
+ " data=[\n",
+ " [1655349, 'cine_lax_3ch_192', 25, 'gs://ml4cvd/deflaux/ukbb_tensors/'],\n",
+ " [1655349, 't2_flair_sag_p2_1mm_fs_ellip_pf78_1', 50, 'gs://ml4cvd/deflaux/ukbb_tensors/'],\n",
+ " [1655349, 'cine_lax_4ch_192', 25, 'gs://ml4cvd/deflaux/ukbb_tensors/'],\n",
+ " [1655349, 't2_flair_sag_p2_1mm_fs_ellip_pf78_2', 50, 'gs://ml4cvd/deflaux/ukbb_tensors/'],\n",
+ " [2403657, 'cine_lax_3ch_192', 25, 'gs://ml4cvd/deflaux/ukbb_tensors/'],\n",
+ " ])\n",
+ "\n",
+ "samples_df.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "samples_df.head(n = 10)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Annotate the batch! "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Note: a zoom level of 1.0 displays the tensor as-is. For higher zoom levels, this code currently\n",
+ "# uses the PIL library to scale the image.\n",
+ "\n",
+ "annotator = BatchImageAnnotator(samples=samples_df,\n",
+ " zoom=2.0,\n",
+ " annotation_categories=['region_of_interest'],\n",
+ " annotation_storage=BIG_QUERY_ANNOTATIONS_STORAGE)\n",
+ "annotator.annotate_images()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# View the stored annotations "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "annotator.view_recent_submissions(count=10)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Provenance"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import datetime\n",
+ "print(datetime.datetime.now())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%bash\n",
+ "pip3 freeze"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Questions about these particular notebooks? Reach out to Puneet Batra pbatra@broadinstitute.org, Paolo Di Achille pdiachil@broadinstitute.org, and Nicole Deflaux deflaux@verily.com."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.8"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {
+ "height": "calc(100% - 180px)",
+ "left": "10px",
+ "top": "150px",
+ "width": "199px"
+ },
+ "toc_section_display": true,
+ "toc_window_display": true
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/terra_featured_workspace/image_annotations_demo.ipynb b/notebooks/terra_featured_workspace/image_annotations_demo.ipynb
new file mode 100644
index 000000000..ce15e4d73
--- /dev/null
+++ b/notebooks/terra_featured_workspace/image_annotations_demo.ipynb
@@ -0,0 +1,259 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Image annotations for a batch of samples\n",
+ "\n",
+ "Using this notebook, cardiologists are able to quickly view and annotate MRI images for a batch of samples. These annotated images become the training data for the next round of modeling."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Setup\n",
+ "\n",
+ "
\n",
+ " This notebook assumes\n",
+ "
\n",
+ "
Terra is running custom Docker image gcr.io/uk-biobank-sek-data/ml4h_terra:20200918_091608.
\n",
+ "
ml4cvd is running custom Docker image gcr.io/broad-ml4cvd/deeplearning:tf2-latest-gpu.
\n",
+ "
\n",
+ "
"
+ ]
+ },
+ {
+ "attachments": {
+ "Screen%20Shot%202020-06-22%20at%202.50.48%20PM.png": {
+ "image/png": ""
+ }
+ },
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "![Screen%20Shot%202020-06-22%20at%202.50.48%20PM.png](attachment:Screen%20Shot%202020-06-22%20at%202.50.48%20PM.png)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from ml4cvd.visualization_tools.batch_image_annotations import BatchImageAnnotator\n",
+ "import pandas as pd"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "code_folding": []
+ },
+ "outputs": [],
+ "source": [
+ "%%javascript\n",
+ "// Display cell outputs to full height (no vertical scroll bar)\n",
+ "IPython.OutputArea.auto_scroll_threshold = 9999;"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pd.set_option('display.max_colwidth', -1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Define the batch of samples to annotate\n",
+ "\n",
+ "In general, we would read in a CSV file but for this demo we define the batch right here."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Normally these would all be the same or similar TMAP. We are using different ones here just to make it\n",
+ "# more obvious in this demo that we are processing different samples.\n",
+ "samples_df = pd.DataFrame(\n",
+ " columns=BatchImageAnnotator.EXPECTED_COLUMN_NAMES,\n",
+ " data=[\n",
+ " ['fake_1', 'cine_lax_3ch_192', 25, 'gs://ml4cvd/projects/fake_hd5s/'],\n",
+ " ['fake_1', 't2_flair_sag_p2_1mm_fs_ellip_pf78_1', 50, 'gs://ml4cvd/projects/fake_hd5s/'],\n",
+ " ['fake_1', 'cine_lax_4ch_192', 25, 'gs://ml4cvd/projects/fake_hd5s/'],\n",
+ " ['fake_1', 't2_flair_sag_p2_1mm_fs_ellip_pf78_2', 50, 'gs://ml4cvd/projects/fake_hd5s/'],\n",
+ " ['fake_2', 'cine_lax_3ch_192', 25, 'gs://ml4cvd/projects/fake_hd5s/'],\n",
+ " ])\n",
+ "\n",
+ "samples_df"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Annotate the batch! "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Note: a zoom level of 1.0 displays the tensor as-is. For higher zoom levels, this code currently\n",
+ "# uses the PIL library to scale the image.\n",
+ "\n",
+ "annotator = BatchImageAnnotator(samples=samples_df, zoom=2.0, annotation_categories=['region_of_interest'])\n",
+ "annotator.annotate_images()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## via BigQuery annotation storage "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from ml4cvd.visualization_tools.annotation_storage import BigQueryAnnotationStorage"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "BIG_QUERY_ANNOTATIONS_STORAGE = BigQueryAnnotationStorage('uk-biobank-sek-data.ml_results.annotations')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Note: a zoom level of 1.0 displays the tensor as-is. For higher zoom levels, this code currently\n",
+ "# uses the PIL library to scale the image.\n",
+ "\n",
+ "annotator = BatchImageAnnotator(samples=samples_df,\n",
+ " zoom=2.0,\n",
+ " annotation_categories=['region_of_interest'],\n",
+ " annotation_storage=BIG_QUERY_ANNOTATIONS_STORAGE)\n",
+ "annotator.annotate_images()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# View the stored annotations "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "annotator.view_recent_submissions(count=10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Provenance"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import datetime\n",
+ "print(datetime.datetime.now())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%bash\n",
+ "pip3 freeze"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Questions about these particular notebooks? Reach out to Puneet Batra pbatra@broadinstitute.org, Paolo Di Achille pdiachil@broadinstitute.org, and Nicole Deflaux deflaux@verily.com."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.8"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {
+ "height": "calc(100% - 180px)",
+ "left": "10px",
+ "top": "150px",
+ "width": "199px"
+ },
+ "toc_section_display": true,
+ "toc_window_display": true
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/pylintrc b/pylintrc
new file mode 100644
index 000000000..8a5e40122
--- /dev/null
+++ b/pylintrc
@@ -0,0 +1,337 @@
+# This configuration was copied from https://github.com/tensorflow/tensorflow/blob/18ebe824d2f6f20b09839cb0a0073032a2d6c5fe/tensorflow/tools/ci_build/pylintrc and then further modified.
+
+[MASTER]
+
+# Specify a configuration file.
+#rcfile=
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Profiled execution.
+profile=no
+
+# Add files or directories to the denylist. They should be base names, not
+# paths.
+ignore=CVS
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# List of plugins (as comma separated values of python modules names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+
+[MESSAGES CONTROL]
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifiers separated by comma (,) or put this option
+# multiple times. See also the "--disable" option for examples.
+enable=indexing-exception,old-raise-syntax
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once). You can also use "--disable=all" to
+# disable everything first and then reenable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use "--disable=all --enable=classes
+# --disable=W"
+disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager
+
+
+# Set the cache size for astng objects.
+cache-size=500
+
+
+[REPORTS]
+
+# Set the output format. Available formats are text, parseable, colorized, msvs
+# (visual studio) and html. You can also give a reporter class, eg
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Put messages in a separate file for each module / package specified on the
+# command line instead of printing them on stdout. Reports (if any) will be
+# written to a file named "pylint_global.[txt|html]".
+files-output=no
+
+# Tells whether to display a full report or only the messages
+reports=no
+
+# Python expression which should return a note less than 10 (10 is the highest
+# note). You have access to the variables errors warning, statement which
+# respectively contain the number of errors / warnings messages and the total
+# number of statements analyzed. This is used by the global evaluation report
+# (RP0004).
+evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
+
+# Add a comment according to your evaluation note. This is used by the global
+# evaluation report (RP0004).
+comment=no
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details
+#msg-template=
+
+
+[TYPECHECK]
+
+# Tells whether missing members accessed in mixin class should be ignored. A
+# mixin class is detected if its name ends with "mixin" (case insensitive).
+ignore-mixin-members=yes
+
+# List of classes names for which member attributes should not be checked
+# (useful for classes with attributes dynamically set).
+ignored-classes=SQLObject
+
+# When zope mode is activated, add a predefined set of Zope acquired attributes
+# to generated-members.
+zope=no
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E0201 when accessed. Python regular
+# expressions are accepted.
+generated-members=REQUEST,acl_users,aq_parent
+
+# List of decorators that create context managers from functions, such as
+# contextlib.contextmanager.
+contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
+
+
+[VARIABLES]
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# A regular expression matching the beginning of the name of dummy variables
+# (i.e. not used).
+dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid defining new builtins when possible.
+additional-builtins=
+
+
+[BASIC]
+
+# Required attributes for module, separated by a comma
+required-attributes=
+
+# List of builtins function names that should not be used, separated by a comma
+bad-functions=apply,input,reduce
+
+
+# Disable the report(s) with the given id(s).
+# All non-Google reports are disabled by default.
+disable-report=R0001,R0002,R0003,R0004,R0101,R0102,R0201,R0202,R0220,R0401,R0402,R0701,R0801,R0901,R0902,R0903,R0904,R0911,R0912,R0913,R0914,R0915,R0921,R0922,R0923
+
+# Regular expression which should only match correct module names
+module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
+
+# Regular expression which should only match correct module level names
+const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
+
+# Regular expression which should only match correct class names
+class-rgx=^_?[A-Z][a-zA-Z0-9]*$
+
+# Regular expression which should only match correct function names
+function-rgx=^(?:(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
+
+# Regular expression which should only match correct method names
+method-rgx=^(?:(?P<exempt>__[a-z0-9_]+__|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
+
+# Regular expression which should only match correct instance attribute names
+attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
+
+# Regular expression which should only match correct argument names
+argument-rgx=^[a-z][a-z0-9_]*$
+
+# Regular expression which should only match correct variable names
+variable-rgx=^[a-z][a-z0-9_]*$
+
+# Regular expression which should only match correct attribute names in class
+# bodies
+class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
+
+# Regular expression which should only match correct list comprehension /
+# generator expression variable names
+inlinevar-rgx=^[a-z][a-z0-9_]*$
+
+# Good variable names which should always be accepted, separated by a comma
+good-names=main,_
+
+# Bad variable names which should always be refused, separated by a comma
+bad-names=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=(__.*__|main)
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=10
+
+
+[FORMAT]
+
+# Maximum number of characters on a single line.
+max-line-length=120
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=(?x)
+ (^\s*(import|from)\s
+ |\$Id:\s\/\/depot\/.+#\d+\s\$
+ |^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*("[^"]\S+"|'[^']\S+')
+ |^\s*\#\ LINT\.ThenChange
+ |^[^#]*\#\ type:\ [a-zA-Z_][a-zA-Z0-9_.,[\] ]*$
+ |pylint
+ |"""
+ |\#
+ |lambda
+ |(https?|ftp):)
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=y
+
+# List of optional constructs for which whitespace checking is disabled
+no-space-check=
+
+# Maximum number of lines in a module
+max-module-lines=99999
+
+# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
+# tab).
+indent-string=' '
+
+
+[SIMILARITIES]
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+# Ignore comments when computing similarities.
+ignore-comments=yes
+
+# Ignore docstrings when computing similarities.
+ignore-docstrings=yes
+
+# Ignore imports when computing similarities.
+ignore-imports=no
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=
+
+
+[IMPORTS]
+
+# Deprecated modules which should not be used, separated by a comma
+deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets
+
+# Create a graph of every (i.e. internal and external) dependencies in the
+# given file (report RP0402 must not be disabled)
+import-graph=
+
+# Create a graph of external dependencies in the given file (report RP0402 must
+# not be disabled)
+ext-import-graph=
+
+# Create a graph of internal dependencies in the given file (report RP0402 must
+# not be disabled)
+int-import-graph=
+
+
+[CLASSES]
+
+# List of interface methods to ignore, separated by a comma. This is used for
+# instance to not check methods defines in Zope's Interface base class.
+ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,__new__,setUp
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls,class_
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=mcs
+
+
+[DESIGN]
+
+# Maximum number of arguments for function / method
+max-args=5
+
+# Argument names that match this expression will be ignored. Default to name
+# with leading underscore
+ignored-argument-names=_.*
+
+# Maximum number of locals for function / method body
+max-locals=15
+
+# Maximum number of return / yield for function / method body
+max-returns=6
+
+# Maximum number of branch for function / method body
+max-branches=12
+
+# Maximum number of statements in function / method body
+max-statements=50
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when being caught. Defaults to
+# "Exception"
+overgeneral-exceptions=Exception,StandardError,BaseException
+
+
+[AST]
+
+# Maximum line length for lambdas
+short-func-length=1
+
+# List of module members that should be marked as deprecated.
+# All of the string functions are listed in 4.1.4 Deprecated string functions
+# in the Python 2.4 docs.
+deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc
+
+
+[DOCSTRING]
+
+# List of exceptions that do not need to be mentioned in the Raises section of
+# a docstring.
+ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError
+
+
+
+[TOKENS]
+
+# Number of spaces of indent required when the last token on the preceding line
+# is an open (, [, or {.
+indent-after-paren=4
+
+
+[GOOGLE LINES]
+
+# Regexp for a proper copyright notice.
+copyright=Copyright \d{4} The TensorFlow Authors\. +All [Rr]ights [Rr]eserved\.
From 837d019b5090b87827c804e6184ebf5149ca6709 Mon Sep 17 00:00:00 2001
From: Samwell Freeman
Date: Tue, 29 Sep 2020 17:26:25 -0400
Subject: [PATCH 03/21] paired
---
README.md | 32 +---
docker/terra_image/Dockerfile | 8 +-
docker/terra_image/README.md | 8 +-
docker/vm_boot_images/Dockerfile | 5 +-
.../config/tensorflow-requirements.txt | 2 +
ml4h/plots.py | 36 ++--
ml4h/tensorize/tensor_writer_ukbb.py | 181 +++++++++++-------
7 files changed, 154 insertions(+), 118 deletions(-)
diff --git a/README.md b/README.md
index 0335a4885..8996bfe39 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
# ml4h
`ml4h` is a project aimed at using machine learning to model multi-modal cardiovascular
time series and imaging data. `ml4h` began as a set of tools to make it easy to work
-with the UK Biobank on the Google Cloud and has since expanded to include other data sources
+with the UK Biobank on Google Cloud Platform and has since expanded to include other data sources
and functionality.
@@ -9,6 +9,7 @@ Getting Started
* [Setting up your local environment](#setting-up-your-local-environment)
* [Setting up a remote VM](#setting-up-a-remote-vm)
* Modeling/Data Sources/Tests [(`ml4h/DATA_MODELING_TESTS.md`)](ml4h/DATA_MODELING_TESTS.md)
+* [Contributing Code](#contributing-code)
Advanced Topics:
* Tensorizing Data (going from raw data to arrays suitable for modeling, in `ml4h/tensorize/README.md, TENSORIZE.md` )
@@ -19,7 +20,7 @@ Clone the repo
```
git clone git@github.com:broadinstitute/ml.git
```
-Make sure you have installed the [google cloud tools (gcloud)](https://cloud.google.com/storage/docs/gsutil_install). With [Homebrew](https://brew.sh/), you can use
+Make sure you have installed the [Google Cloud SDK (gcloud)](https://cloud.google.com/sdk/docs/downloads-interactive). With [Homebrew](https://brew.sh/), you can use
```
brew cask install google-cloud-sdk
```
@@ -145,29 +146,6 @@ If you get a public key error run: `gcloud compute config-ssh`
Now open a browser on your laptop and go to the URL `http://localhost:8888`
+## Contributing code
-### Installing git-secrets
-
-```git-secrets``` helps us avoid committing secrets (e.g. private keys) and other critical data (e.g. PHI) to our
-repositories. ```git-secrets``` can be obtained via [github](https://github.com/awslabs/git-secrets) or on MacOS can be
-installed with Homebrew by running ```brew install git-secrets```.
-
-To add hooks to all repositories that you initialize or clone in the future:
-
-```git secrets --install --global```
-
-To add hooks to all local repositories:
-
-```
-git secrets --install ~/.git-templates/git-secrets
-git config --global init.templateDir ~/.git-templates/git-secrets
-```
-
-We maintain our own custom "provider" to cover any private keys or other critical data that we would like to avoid
-committing to our repositories. Feel free to add ```egrep```-compatible regular expressions to
-```git_secrets_provider_ml4h.txt``` to match types of critical data that are not currently covered by the patterns in that
-file. To register the patterns in this file with ```git-secrets```:
-
-```
-git secrets --add-provider -- cat ${HOME}/ml/git_secrets_provider_ml4h.txt
-```
+Want to contribute code to this project? Please see [CONTRIBUTING](./CONTRIBUTING.md) for developer setup and other details.
diff --git a/docker/terra_image/Dockerfile b/docker/terra_image/Dockerfile
index 721f94500..a59ecd6ae 100644
--- a/docker/terra_image/Dockerfile
+++ b/docker/terra_image/Dockerfile
@@ -1,4 +1,4 @@
-FROM us.gcr.io/broad-dsp-gcr-public/terra-jupyter-gatk:1.0.0
+FROM us.gcr.io/broad-dsp-gcr-public/terra-jupyter-gatk:1.0.6
# https://github.com/DataBiosphere/terra-docker/blob/master/terra-jupyter-gatk/CHANGELOG.md
USER root
@@ -19,6 +19,10 @@ RUN pip3 install --user -r $HOME/ml4h_pkg/config/tensorflow-requirements.txt \
# first few rows of the downloaded dataframe of query results.
# Pin version due to https://github.com/googleapis/google-cloud-python/issues/9965
&& pip3 install --upgrade --user google-cloud-bigquery[pandas]==1.22.0 \
+ # Upgrade to a newer version. The one on the base Terra image was a bit too old.
+ && pip3 install --upgrade --user numpy \
# Configure notebook extensions.
&& jupyter nbextension install --user --py vega \
- && jupyter nbextension enable --user --py vega
+ && jupyter nbextension enable --user --py vega \
+ && jupyter nbextension install --user --py ipycanvas \
+ && jupyter nbextension enable --user --py ipycanvas
diff --git a/docker/terra_image/README.md b/docker/terra_image/README.md
index 71284c0bd..9a81dc74a 100644
--- a/docker/terra_image/README.md
+++ b/docker/terra_image/README.md
@@ -2,13 +2,13 @@
To build and push:
```
-mv ml4cvd ml4cvdBAK_$(date +"%Y%m%d_%H%M%S") \
+mv ml4h ml4hBAK_$(date +"%Y%m%d_%H%M%S") \
&& mv config configBAK_$(date +"%Y%m%d_%H%M%S") \
- && cp -r ../../ml4cvd . \
+ && cp -r ../../ml4h . \
&& cp -r ../vm_boot_images/config . \
&& gcloud --project uk-biobank-sek-data builds submit \
--timeout 20m \
- --tag gcr.io/uk-biobank-sek-data/ml4cvd_terra:`date +"%Y%m%d_%H%M%S"` .
+ --tag gcr.io/uk-biobank-sek-data/ml4h_terra:`date +"%Y%m%d_%H%M%S"` .
```
Notes:
@@ -20,5 +20,5 @@ available to docker.
cd notebooks
find . -name "*.ipynb" -type f -print0 | \
xargs -0 perl -i -pe \
- 's/gcr.io\/uk-biobank-sek-data\/ml4cvd_terra:\d{8}_\d{6}/gcr.io\/uk-biobank-sek-data\/ml4cvd_terra:20200623_145127/g'
+ 's/gcr.io\/uk-biobank-sek-data\/ml4h_terra:\d{8}_\d{6}/gcr.io\/uk-biobank-sek-data\/ml4h_terra:20200623_145127/g'
```
diff --git a/docker/vm_boot_images/Dockerfile b/docker/vm_boot_images/Dockerfile
index a62694ca0..59e5b32be 100644
--- a/docker/vm_boot_images/Dockerfile
+++ b/docker/vm_boot_images/Dockerfile
@@ -34,4 +34,7 @@ RUN apt-get install python3-tk libgl1-mesa-glx libxt-dev -y
# Requirements for the tensorflow project
RUN pip3 install --upgrade pip
RUN pip3 install -r pre_requirements.txt
-RUN pip3 install -r tensorflow-requirements.txt
+RUN pip3 install -r tensorflow-requirements.txt \
+ # Configure notebook extensions.
+ && jupyter nbextension install --user --py ipycanvas \
+ && jupyter nbextension enable --user --py ipycanvas
diff --git a/docker/vm_boot_images/config/tensorflow-requirements.txt b/docker/vm_boot_images/config/tensorflow-requirements.txt
index bb6a1e777..d782967af 100644
--- a/docker/vm_boot_images/config/tensorflow-requirements.txt
+++ b/docker/vm_boot_images/config/tensorflow-requirements.txt
@@ -28,3 +28,5 @@ altair
facets-overview
plotnine
vega
+ipycanvas==0.4.1
+ipyannotations==0.2.0
diff --git a/ml4h/plots.py b/ml4h/plots.py
index c60282803..6b5ad6a78 100755
--- a/ml4h/plots.py
+++ b/ml4h/plots.py
@@ -40,6 +40,9 @@
from scipy.ndimage.filters import gaussian_filter
from scipy import stats
+import ml4h.tensormap.ukb.ecg
+import ml4h.tensormap.mgb.ecg
+from ml4h.tensormap.mgb.dynamic import make_waveform_maps
from ml4h.TensorMap import TensorMap
from ml4h.metrics import concordance_index, coefficient_of_determination
from ml4h.defines import IMAGE_EXT, JOIN_CHAR, PDF_EXT, TENSOR_EXT, ECG_REST_LEADS, ECG_REST_MEDIAN_LEADS, PARTNERS_DATETIME_FORMAT, PARTNERS_DATE_FORMAT, HD5_GROUP_CHAR
@@ -1227,16 +1230,15 @@ def _plot_partners_figure(
def plot_partners_ecgs(args):
plot_tensors = [
- 'partners_ecg_patientid', 'partners_ecg_firstname', 'partners_ecg_lastname',
- 'partners_ecg_sex', 'partners_ecg_dob', 'partners_ecg_age',
- 'partners_ecg_datetime', 'partners_ecg_sitename', 'partners_ecg_location',
- 'partners_ecg_read_md', 'partners_ecg_taxis_md', 'partners_ecg_rate_md',
- 'partners_ecg_pr_md', 'partners_ecg_qrs_md', 'partners_ecg_qt_md',
- 'partners_ecg_paxis_md', 'partners_ecg_raxis_md', 'partners_ecg_qtc_md',
+ ml4h.tensormap.mgb.ecg.partners_ecg_patientid, ml4h.tensormap.mgb.ecg.partners_ecg_firstname, ml4h.tensormap.mgb.ecg.partners_ecg_lastname,
+ ml4h.tensormap.mgb.ecg.partners_ecg_sex, ml4h.tensormap.mgb.ecg.partners_ecg_dob, ml4h.tensormap.mgb.ecg.partners_ecg_age,
+ ml4h.tensormap.mgb.ecg.partners_ecg_datetime, ml4h.tensormap.mgb.ecg.partners_ecg_sitename, ml4h.tensormap.mgb.ecg.partners_ecg_location,
+ ml4h.tensormap.mgb.ecg.partners_ecg_read_md, ml4h.tensormap.mgb.ecg.partners_ecg_taxis_md, ml4h.tensormap.mgb.ecg.partners_ecg_rate_md,
+ ml4h.tensormap.mgb.ecg.partners_ecg_pr_md, ml4h.tensormap.mgb.ecg.partners_ecg_qrs_md, ml4h.tensormap.mgb.ecg.partners_ecg_qt_md,
+ ml4h.tensormap.mgb.ecg.partners_ecg_paxis_md, ml4h.tensormap.mgb.ecg.partners_ecg_raxis_md, ml4h.tensormap.mgb.ecg.partners_ecg_qtc_md,
]
- voltage_tensor = 'partners_ecg_2500_raw'
- from ml4h.tensor_maps_partners_ecg_labels import TMAPS
- tensor_maps_in = [TMAPS[it] for it in plot_tensors + [voltage_tensor]]
+ voltage_tensor = make_waveform_maps('partners_ecg_2500_raw')
+ tensor_maps_in = plot_tensors + [voltage_tensor]
tensor_paths = [os.path.join(args.tensors, tp) for tp in os.listdir(args.tensors) if os.path.splitext(tp)[-1].lower()==TENSOR_EXT]
if 'clinical' == args.plot_mode:
@@ -1503,13 +1505,13 @@ def plot_ecg_rest(
:param is_blind: if True, the plot gets blinded (helpful for review and annotation)
"""
map_fields_to_tmaps = {
- 'ramp': 'ecg_rest_ramplitude_raw',
- 'samp': 'ecg_rest_samplitude_raw',
- 'aVL': 'ecg_rest_lvh_avl',
- 'Sokolow_Lyon': 'ecg_rest_lvh_sokolow_lyon',
- 'Cornell': 'ecg_rest_lvh_cornell',
- }
- from ml4h.tensor_from_file import TMAPS
+ 'ramp': ml4h.tensormap.ukb.ecg.ecg_rest_ramplitude_raw,
+ 'samp': ml4h.tensormap.ukb.ecg.ecg_rest_samplitude_raw,
+ 'aVL': ml4h.tensormap.ukb.ecg.ecg_rest_lvh_avl,
+ 'Sokolow_Lyon': ml4h.tensormap.ukb.ecg.ecg_rest_lvh_sokolow_lyon,
+ 'Cornell': ml4h.tensormap.ukb.ecg.ecg_rest_lvh_cornell,
+ }
+
raw_scale = 0.005 # Conversion from raw to mV
default_yrange = ECG_REST_PLOT_DEFAULT_YRANGE # mV
time_interval = 2.5 # time-interval per plot in seconds. ts_Reference data is in s, voltage measurement is 5 uv per lsb
@@ -1521,7 +1523,7 @@ def plot_ecg_rest(
with h5py.File(tensor_path, 'r') as hd5:
traces, text = _ecg_rest_traces_and_text(hd5)
for field in map_fields_to_tmaps:
- tm = TMAPS[map_fields_to_tmaps[field]]
+ tm = map_fields_to_tmaps[field]
patient_dic[field] = np.zeros(tm.shape)
try:
patient_dic[field][:] = tm.tensor_from_file(tm, hd5)
diff --git a/ml4h/tensorize/tensor_writer_ukbb.py b/ml4h/tensorize/tensor_writer_ukbb.py
index 64242bcbe..ceb4389e8 100644
--- a/ml4h/tensorize/tensor_writer_ukbb.py
+++ b/ml4h/tensorize/tensor_writer_ukbb.py
@@ -87,6 +87,10 @@ def write_tensors(
mri_unzip: str,
mri_field_ids: List[int],
xml_field_ids: List[int],
+ zoom_x: int,
+ zoom_y: int,
+ zoom_width: int,
+ zoom_height: int,
write_pngs: bool,
min_sample_id: int,
max_sample_id: int,
@@ -105,6 +109,13 @@ def write_tensors(
:param mri_unzip: Folder where zipped DICOM will be decompressed
:param mri_field_ids: List of MRI field IDs from UKBB
:param xml_field_ids: List of ECG field IDs from UKBB
+ :param x: Maximum x dimension of MRIs
+ :param y: Maximum y dimension of MRIs
+ :param z: Maximum z dimension of MRIs
+ :param zoom_x: x coordinate of the zoom
+ :param zoom_y: y coordinate of the zoom
+ :param zoom_width: width of the zoom
+ :param zoom_height: height of the zoom
:param write_pngs: write MRIs as PNG images for debugging
:param min_sample_id: Minimum sample id to generate, for parallelization
:param max_sample_id: Maximum sample id to generate, for parallelization
@@ -126,7 +137,7 @@ def write_tensors(
continue
try:
with h5py.File(tp, 'w') as hd5:
- _write_tensors_from_zipped_dicoms(write_pngs, tensors, mri_unzip, mri_field_ids, zip_folder, hd5, sample_id, stats)
+ _write_tensors_from_zipped_dicoms(zoom_x, zoom_y, zoom_width, zoom_height, write_pngs, tensors, mri_unzip, mri_field_ids, zip_folder, hd5, sample_id, stats)
_write_tensors_from_zipped_niftis(zip_folder, mri_field_ids, hd5, sample_id, stats)
_write_tensors_from_xml(xml_field_ids, xml_folder, hd5, sample_id, write_pngs, stats, continuous_stats)
stats['Tensors written'] += 1
@@ -177,26 +188,19 @@ def write_tensors_from_dicom_pngs(
continue
stats[sample_header + '_' + sample_id] += 1
dicom_file = row[dicom_index]
-
try:
png = imageio.imread(os.path.join(png_path, dicom_file + png_postfix))
- if len(png.shape) == 3 and png.mean() == png[:, :, 0].mean():
- png = png[:, :, 0]
- elif len(png.shape) == 3:
- raise ValueError(f'PNG has color information but no method to tensorize it {png.mean()}, 0ch :{png[:, :, 0].mean()}, 1ch :{png[:, :, 1].mean()}, 2ch :{png[:, :, 2].mean()}.')
full_tensor = np.zeros((x, y), dtype=np.float32)
full_tensor[:png.shape[0], :png.shape[1]] = png
tensor_file = os.path.join(tensors, str(sample_id) + TENSOR_EXT)
if not os.path.exists(os.path.dirname(tensor_file)):
os.makedirs(os.path.dirname(tensor_file))
with h5py.File(tensor_file, 'a') as hd5:
- tensor_name = series.lower() + '_annotated_' + row[instance_index]
+ tensor_name = series + '_annotated_' + row[instance_index]
tp = tensor_path(path_prefix, tensor_name)
if tp in hd5:
tensor = first_dataset_at_path(hd5, tp)
- min_x = min(png.shape[0], tensor.shape[0])
- min_y = min(png.shape[1], tensor.shape[1])
- tensor[:min_x, :min_y] = full_tensor[:min_x, :min_y]
+ tensor[:] = full_tensor
stats['updated'] += 1
else:
create_tensor_in_hd5(hd5, path_prefix, tensor_name, full_tensor, stats)
@@ -324,7 +328,7 @@ def _dicts_and_plots_from_tensorization(
continuous = {}
value_counter = Counter()
for k in sorted(list(stats.keys())):
- #logging.info("{} has {}".format(k, stats[k]))
+ logging.info("{} has {}".format(k, stats[k]))
if 'categorical' not in k and 'continuous' not in k:
continue
@@ -342,10 +346,10 @@ def _dicts_and_plots_from_tensorization(
plot_value_counter(list(categories.keys()), value_counter, a_id + '_v_count', os.path.join(output_folder, a_id))
plot_histograms(continuous_stats, a_id, os.path.join(output_folder, a_id))
- # logging.info("Continuous tensor map: {}".format(continuous))
- # logging.info("Continuous Columns: {}".format(len(continuous)))
- # logging.info("Category tensor map: {}".format(categories))
- # logging.info("Categories Columns: {}".format(len(categories)))
+ logging.info("Continuous tensor map: {}".format(continuous))
+ logging.info("Continuous Columns: {}".format(len(continuous)))
+ logging.info("Category tensor map: {}".format(categories))
+ logging.info("Categories Columns: {}".format(len(categories)))
def _to_float_or_false(s):
@@ -363,6 +367,10 @@ def _to_float_or_nan(s):
def _write_tensors_from_zipped_dicoms(
+ zoom_x: int,
+ zoom_y: int,
+ zoom_width: int,
+ zoom_height: int,
write_pngs: bool,
tensors: str,
dicoms: str,
@@ -382,8 +390,10 @@ def _write_tensors_from_zipped_dicoms(
os.makedirs(dicom_folder)
with zipfile.ZipFile(zipped, "r") as zip_ref:
zip_ref.extractall(dicom_folder)
- ukb_instance = zipped.split('_')[2]
- _write_tensors_from_dicoms(write_pngs, tensors, dicom_folder, hd5, sample_str, ukb_instance, stats)
+ _write_tensors_from_dicoms(
+ zoom_x, zoom_y, zoom_width, zoom_height, write_pngs, tensors, dicom_folder,
+ hd5, sample_str, stats,
+ )
stats['MRI fields written'] += 1
shutil.rmtree(dicom_folder)
@@ -400,31 +410,36 @@ def _write_tensors_from_zipped_niftis(zip_folder: str, mri_field_ids: List[str],
def _write_tensors_from_dicoms(
- write_pngs: bool, tensors: str, dicom_folder: str, hd5: h5py.File, sample_str: str, ukb_instance: str, stats: Dict[str, int],
+ zoom_x: int, zoom_y: int, zoom_width: int, zoom_height: int, write_pngs: bool, tensors: str,
+ dicom_folder: str, hd5: h5py.File, sample_str: str, stats: Dict[str, int],
) -> None:
"""Convert a folder of DICOMs from a sample into tensors for each series
Segmented dicoms require special processing and are written to tensor per-slice
Arguments
+ :param x: Width of the tensors (actual MRI width will be padded with 0s or cropped to this number)
+ :param y: Height of the tensors (actual MRI width will be padded with 0s or cropped to this number)
+ :param z: Minimum number of slices to include in the each tensor if more slices are found they will be kept
+ :param zoom_x: x coordinate of the zoom
+ :param zoom_y: y coordinate of the zoom
+ :param zoom_width: width of the zoom
+ :param zoom_height: height of the zoom
:param write_pngs: write MRIs as PNG images for debugging
:param tensors: Folder where hd5 tensor files are being written
:param dicom_folder: Folder with all dicoms associated with one sample.
:param hd5: Tensor file in which to create datasets for each series and each segmented slice
:param sample_str: The current sample ID as a string
- :param ukb_instance: The UK Biobank assessment visit instance number
:param stats: Counter to keep track of summary statistics
"""
views = defaultdict(list)
- series_to_numbers = defaultdict(set)
min_ideal_series = 9e9
for dicom in os.listdir(dicom_folder):
if os.path.splitext(dicom)[-1] != DICOM_EXT:
continue
d = pydicom.read_file(os.path.join(dicom_folder, dicom))
series = d.SeriesDescription.lower().replace(' ', '_')
- series_to_numbers[series].add(int(d.SeriesNumber))
if series + '_12bit' in MRI_LIVER_SERIES_12BIT and d.LargestImagePixelValue > 2048:
views[series + '_12bit'].append(d)
stats[series + '_12bit'] += 1
@@ -447,61 +462,99 @@ def _write_tensors_from_dicoms(
else:
mri_group = 'ukb_mri'
- if len(series_to_numbers[v]) > 1 and v not in MRI_BRAIN_SERIES:
- max_series = max(series_to_numbers[v])
- single_series = [dicom for dicom in views[v] if int(dicom.SeriesNumber) == max_series]
- # for d in views[v]:
- # logging.warning(f'{d.SeriesNumber} with Date: {_datetime_from_dicom(d)} Time {d.AcquisitionTime}')
- logging.warning(f'{v} has {len(views[v])} series:{series_to_numbers[v]} Using only max series: {max_series} with {len(single_series)}')
- views[v] = single_series
if v == MRI_TO_SEGMENT:
- _tensorize_short_and_long_axis_segmented_cardiac_mri(views[v], v, ukb_instance, hd5, mri_date, mri_group, stats)
+ _tensorize_short_and_long_axis_segmented_cardiac_mri(views[v], v, zoom_x, zoom_y, zoom_width, zoom_height, write_pngs, tensors, hd5, mri_date, mri_group, stats)
elif v in MRI_BRAIN_SERIES:
_tensorize_brain_mri(views[v], v, mri_date, mri_group, hd5)
else:
- pass
- # mri_data = np.zeros((views[v][0].Rows, views[v][0].Columns, len(views[v])), dtype=np.float32)
- # for slicer in views[v]:
- # _save_pixel_dimensions_if_missing(slicer, v, hd5)
- # _save_slice_thickness_if_missing(slicer, v, hd5)
- # _save_series_orientation_and_position_if_missing(slicer, v, hd5)
- # slice_index = slicer.InstanceNumber - 1
- # if v in MRI_LIVER_IDEAL_PROTOCOL:
- # slice_index = _slice_index_from_ideal_protocol(slicer, min_ideal_series)
- # mri_data[..., slice_index] = slicer.pixel_array.astype(np.float32)
- # create_tensor_in_hd5(hd5, mri_group, f'{v}/{ukb_instance}', mri_data, stats, mri_date)
+ mri_data = np.zeros((views[v][0].Rows, views[v][0].Columns, len(views[v])), dtype=np.float32)
+ for slicer in views[v]:
+ _save_pixel_dimensions_if_missing(slicer, v, hd5)
+ _save_slice_thickness_if_missing(slicer, v, hd5)
+ _save_series_orientation_and_position_if_missing(slicer, v, hd5)
+ slice_index = slicer.InstanceNumber - 1
+ if v in MRI_LIVER_IDEAL_PROTOCOL:
+ slice_index = _slice_index_from_ideal_protocol(slicer, min_ideal_series)
+ mri_data[..., slice_index] = slicer.pixel_array.astype(np.float32)
+ create_tensor_in_hd5(hd5, mri_group, v, mri_data, stats, mri_date)
def _tensorize_short_and_long_axis_segmented_cardiac_mri(
- slices: List[pydicom.Dataset], series: str, instance: str,
- hd5: h5py.File, mri_date: datetime.datetime, mri_group: str, stats: Dict[str, int],
+ slices: List[pydicom.Dataset], series: str, zoom_x: int, zoom_y: int,
+ zoom_width: int, zoom_height: int, write_pngs: bool, tensors: str,
+ hd5: h5py.File, mri_date: datetime.datetime, mri_group: str,
+ stats: Dict[str, int],
) -> None:
+ systoles = {}
+ diastoles = {}
+ systoles_pix = {}
+ systoles_masks = {}
+ diastoles_masks = {}
+
for slicer in slices:
- #full_slice = np.zeros((slicer.Rows, slicer.Columns), dtype=np.float32)
+ full_mask = np.zeros((slicer.Rows, slicer.Columns), dtype=np.float32)
+ full_slice = np.zeros((slicer.Rows, slicer.Columns), dtype=np.float32)
+
if _has_overlay(slicer):
if _is_mitral_valve_segmentation(slicer):
series = series.replace('sax', 'lax')
else:
series = series.replace('lax', 'sax')
-
series_segmented = f'{series}_segmented'
+ series_zoom = f'{series}_zoom'
+ series_zoom_segmented = f'{series}_zoom_segmented'
+
try:
overlay, mask, ventricle_pixels, _ = _get_overlay_from_dicom(slicer)
except KeyError:
logging.exception(f'Got key error trying to make anatomical mask, skipping.')
continue
- # _save_pixel_dimensions_if_missing(slicer, series, hd5)
- # _save_slice_thickness_if_missing(slicer, series, hd5)
- # _save_series_orientation_and_position_if_missing(slicer, series, hd5, str(slicer.InstanceNumber))
+ _save_pixel_dimensions_if_missing(slicer, series, hd5)
+ _save_slice_thickness_if_missing(slicer, series, hd5)
+ _save_series_orientation_and_position_if_missing(slicer, series, hd5, str(slicer.InstanceNumber))
_save_pixel_dimensions_if_missing(slicer, series_segmented, hd5)
_save_slice_thickness_if_missing(slicer, series_segmented, hd5)
_save_series_orientation_and_position_if_missing(slicer, series_segmented, hd5, str(slicer.InstanceNumber))
- #
- # cur_angle = (slicer.InstanceNumber - 1) // MRI_FRAMES # dicom InstanceNumber is 1-based
- #full_slice[:] = slicer.pixel_array.astype(np.float32)
- #create_tensor_in_hd5(hd5, mri_group, f'{series}{HD5_GROUP_CHAR}{instance}', full_slice, stats, mri_date, slicer.InstanceNumber)
- create_tensor_in_hd5(hd5, mri_group, f'{series_segmented}{HD5_GROUP_CHAR}{instance}', mask, stats, mri_date, slicer.InstanceNumber)
+
+ cur_angle = (slicer.InstanceNumber - 1) // MRI_FRAMES # dicom InstanceNumber is 1-based
+ full_slice[:] = slicer.pixel_array.astype(np.float32)
+ create_tensor_in_hd5(hd5, mri_group, f'{series}{HD5_GROUP_CHAR}{slicer.InstanceNumber}', full_slice, stats, mri_date)
+ create_tensor_in_hd5(hd5, mri_group, f'{series_zoom_segmented}{HD5_GROUP_CHAR}{slicer.InstanceNumber}', mask, stats, mri_date)
+
+ zoom_slice = full_slice[zoom_x: zoom_x + zoom_width, zoom_y: zoom_y + zoom_height]
+ zoom_mask = mask[zoom_x: zoom_x + zoom_width, zoom_y: zoom_y + zoom_height]
+ create_tensor_in_hd5(hd5, mri_group, f'{series_zoom}{HD5_GROUP_CHAR}{slicer.InstanceNumber}', zoom_slice, stats, mri_date)
+ create_tensor_in_hd5(hd5, mri_group, f'{series_zoom_segmented}{HD5_GROUP_CHAR}{slicer.InstanceNumber}', zoom_mask, stats, mri_date)
+
+ if (slicer.InstanceNumber - 1) % MRI_FRAMES == 0: # Diastole frame is always the first
+ diastoles[cur_angle] = slicer
+ diastoles_masks[cur_angle] = mask
+ if cur_angle not in systoles:
+ systoles[cur_angle] = slicer
+ systoles_pix[cur_angle] = ventricle_pixels
+ systoles_masks[cur_angle] = mask
+ else:
+ if ventricle_pixels < systoles_pix[cur_angle]:
+ systoles[cur_angle] = slicer
+ systoles_pix[cur_angle] = ventricle_pixels
+ systoles_masks[cur_angle] = mask
+
+ for angle in diastoles:
+ logging.info(f'Found systole, instance:{systoles[angle].InstanceNumber} ventricle pixels:{systoles_pix[angle]}')
+ full_slice = diastoles[angle].pixel_array.astype(np.float32)
+ create_tensor_in_hd5(hd5, mri_group, f'diastole_frame_b{angle}', full_slice, stats, mri_date)
+ create_tensor_in_hd5(hd5, mri_group, f'diastole_mask_b{angle}', diastoles_masks[angle], stats, mri_date)
+ if write_pngs:
+ plt.imsave(tensors + 'diastole_frame_b' + str(angle) + IMAGE_EXT, full_slice)
+ plt.imsave(tensors + 'diastole_mask_b' + str(angle) + IMAGE_EXT, full_mask)
+
+ full_slice = systoles[angle].pixel_array.astype(np.float32)
+ create_tensor_in_hd5(hd5, mri_group, f'systole_frame_b{angle}', full_slice, stats, mri_date)
+ create_tensor_in_hd5(hd5, mri_group, f'systole_mask_b{angle}', systoles_masks[angle], stats, mri_date)
+ if write_pngs:
+ plt.imsave(tensors + 'systole_frame_b' + str(angle) + IMAGE_EXT, full_slice)
+ plt.imsave(tensors + 'systole_mask_b' + str(angle) + IMAGE_EXT, full_mask)
def _tensorize_brain_mri(slices: List[pydicom.Dataset], series: str, mri_date: datetime.datetime, mri_group: str, hd5: h5py.File) -> None:
@@ -535,16 +588,13 @@ def _save_slice_thickness_if_missing(slicer, series, hd5):
def _save_series_orientation_and_position_if_missing(slicer, series, hd5, instance=None):
orientation_ds_name = MRI_PATIENT_ORIENTATION + '_' + series
position_ds_name = MRI_PATIENT_POSITION + '_' + series
- if instance is not None:
- orientation_ds_name = f'{orientation_ds_name}_{instance}'
- position_ds_name = f'{position_ds_name}_{instance}'
- try:
- if orientation_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT:
- hd5.create_dataset(orientation_ds_name, data=[float(x) for x in slicer.ImageOrientationPatient])
- if position_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT:
- hd5.create_dataset(position_ds_name, data=[float(x) for x in slicer.ImagePositionPatient])
- except RuntimeError as e:
- logging.warning(f' got error {e} \n orientation : {orientation_ds_name} {slicer.ImageOrientationPatient} and pos: {position_ds_name} {slicer.ImagePositionPatient}')
+ if instance:
+ orientation_ds_name += HD5_GROUP_CHAR + instance
+ position_ds_name += HD5_GROUP_CHAR + instance
+ if orientation_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT:
+ hd5.create_dataset(orientation_ds_name, data=[float(x) for x in slicer.ImageOrientationPatient])
+ if position_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT:
+ hd5.create_dataset(position_ds_name, data=[float(x) for x in slicer.ImagePositionPatient])
def _has_overlay(d) -> bool:
@@ -695,16 +745,13 @@ def _write_ecg_rest_tensors(ecgs, xml_field, hd5, sample_id, write_pngs, stats,
def create_tensor_in_hd5(
hd5: h5py.File, path_prefix: str, name: str, value, stats: Counter = None, date: datetime.datetime = None,
- instance: str = None, storage_type: StorageType = None, attributes: Dict[str, Any] = None,
+ storage_type: StorageType = None, attributes: Dict[str, Any] = None,
):
hd5_path = tensor_path(path_prefix, name)
- if instance is not None:
- hd5_path = f'{hd5_path}instance_{instance}/'
if hd5_path in hd5:
hd5_path = f'{hd5_path}instance_{len(hd5[hd5_path])}'
- elif instance is None:
+ else:
hd5_path = f'{hd5_path}instance_0'
-
if stats is not None:
stats[hd5_path] += 1
if storage_type == StorageType.STRING:
From 6021abb4b65527349b8042cd52a29c25f3d0532e Mon Sep 17 00:00:00 2001
From: Samwell Freeman
Date: Tue, 29 Sep 2020 17:27:41 -0400
Subject: [PATCH 04/21] paired
---
.../visualization_tools/annotation_storage.py | 36 ++--
ml4h/visualization_tools/annotations.py | 55 +++---
.../dicom_interactive_plots.py | 74 +++----
ml4h/visualization_tools/dicom_plots.py | 122 ++++++------
.../ecg_interactive_plots.py | 22 ++-
ml4h/visualization_tools/ecg_reshape.py | 58 +++---
ml4h/visualization_tools/ecg_static_plots.py | 11 +-
ml4h/visualization_tools/facets.py | 13 +-
ml4h/visualization_tools/hd5_mri_plots.py | 181 +++++++++---------
9 files changed, 311 insertions(+), 261 deletions(-)
diff --git a/ml4h/visualization_tools/annotation_storage.py b/ml4h/visualization_tools/annotation_storage.py
index ac8e89249..d0020b1ec 100644
--- a/ml4h/visualization_tools/annotation_storage.py
+++ b/ml4h/visualization_tools/annotation_storage.py
@@ -2,9 +2,11 @@
import abc
import datetime
-import pandas as pd
+from typing import Optional, Union
+
from google.cloud import bigquery
from google.cloud.bigquery import magics as bqmagics
+import pandas as pd
class AnnotationStorage(abc.ABC):
@@ -14,12 +16,14 @@ class AnnotationStorage(abc.ABC):
"""
@abc.abstractmethod
- def describe(self):
+ def describe(self) -> str:
"""Return a string describing how annotations are stored."""
- pass
@abc.abstractmethod
- def submit_annotation(self, sample_id, annotator, key, value_numeric, value_string, comment):
+ def submit_annotation(
+ self, sample_id: Union[int, str], annotator: str, key: str,
+ value_numeric: Optional[Union[int, float]], value_string: Optional[str], comment: str,
+ ) -> bool:
"""Add an annotation to the collection of annotations.
Args:
@@ -32,10 +36,9 @@ def submit_annotation(self, sample_id, annotator, key, value_numeric, value_stri
Returns:
Whether the submission was successful. Throws an Exception on failure.
"""
- pass
@abc.abstractmethod
- def view_recent_submissions(self, count=10):
+ def view_recent_submissions(self, count: int = 10) -> pd.DataFrame:
"""View a dataframe of up to [count] most recent submissions.
Args:
@@ -44,7 +47,6 @@ def view_recent_submissions(self, count=10):
Returns:
A dataframe of the most recent annotations.
"""
- pass
class TransientAnnotationStorage(AnnotationStorage):
@@ -56,11 +58,14 @@ class TransientAnnotationStorage(AnnotationStorage):
def __init__(self):
self.annotations = []
- def describe(self):
+ def describe(self) -> str:
return '''Annotations will be stored in memory only during the duration of this demo.\n
For durable storage of annotations, use BigQueryAnnotationStorage instead.'''
- def submit_annotation(self, sample_id, annotator, key, value_numeric, value_string, comment):
+ def submit_annotation(
+ self, sample_id: Union[int, str], annotator: str, key: str,
+ value_numeric: Optional[Union[int, float]], value_string: Optional[str], comment: str,
+ ) -> bool:
"""Add this annotation to our in-memory collection of annotations.
Args:
@@ -85,7 +90,7 @@ def submit_annotation(self, sample_id, annotator, key, value_numeric, value_stri
self.annotations.append(annotation)
return True
- def view_recent_submissions(self, count=10):
+ def view_recent_submissions(self, count: int = 10) -> pd.DataFrame:
"""View a dataframe of up to [count] most recent submissions.
Args:
@@ -110,14 +115,17 @@ class BigQueryAnnotationStorage(AnnotationStorage):
annotations_schema.json
"""
- def __init__(self, table):
+ def __init__(self, table: str):
"""This table should already exist."""
self.table = table
- def describe(self):
+ def describe(self) -> str:
return f'''Annotations are stored in BigQuery table {self.table}'''
- def submit_annotation(self, sample_id, annotator, key, value_numeric, value_string, comment):
+ def submit_annotation(
+ self, sample_id: Union[int, str], annotator: str, key: str,
+ value_numeric: Optional[Union[int, float]], value_string: Optional[str], comment: str,
+ ) -> bool:
"""Call a BigQuery INSERT statement to add a row containing annotation information.
Args:
@@ -150,7 +158,7 @@ def submit_annotation(self, sample_id, annotator, key, value_numeric, value_stri
# Return whether the submission completed.
return submission.done()
- def view_recent_submissions(self, count=10):
+ def view_recent_submissions(self, count: int = 10) -> pd.DataFrame:
"""View a dataframe of up to [count] most recent submissions.
This is a convenience method for use within the annotation flow. For full access to the underlying annotations,
diff --git a/ml4h/visualization_tools/annotations.py b/ml4h/visualization_tools/annotations.py
index 2400a07d8..9ca9c1b44 100644
--- a/ml4h/visualization_tools/annotations.py
+++ b/ml4h/visualization_tools/annotations.py
@@ -2,8 +2,11 @@
import os
import socket
+from typing import Any, Dict, Union
+
from IPython.display import display
from IPython.display import HTML
+import pandas as pd
import ipywidgets as widgets
from ml4h.visualization_tools.annotation_storage import AnnotationStorage
from ml4h.visualization_tools.annotation_storage import TransientAnnotationStorage
@@ -11,14 +14,18 @@
DEFAULT_ANNOTATION_STORAGE = TransientAnnotationStorage()
-def _get_df_sample(sample_info, sample_id):
+def _get_df_sample(sample_info: pd.DataFrame, sample_id: Union[int, str]) -> pd.DataFrame:
"""Return a dataframe containing only the row for the indicated sample_id."""
df_sample = sample_info[sample_info['sample_id'] == str(sample_id)]
- if 0 == df_sample.shape[0]: df_sample = sample_info.query('sample_id == ' + str(sample_id))
+ if df_sample.shape[0] == 0: df_sample = sample_info.query('sample_id == ' + str(sample_id))
return df_sample
-def display_annotation_collector(sample_info, sample_id, annotation_storage: AnnotationStorage = DEFAULT_ANNOTATION_STORAGE, custom_annotation_key=None):
+def display_annotation_collector(
+ sample_info: pd.DataFrame, sample_id: Union[int, str],
+ annotation_storage: AnnotationStorage = DEFAULT_ANNOTATION_STORAGE,
+ custom_annotation_key: str = None,
+) -> None:
"""Method to create a gui (set of widgets) through which the user can create an annotation and submit it to storage.
Args:
@@ -26,15 +33,16 @@ def display_annotation_collector(sample_info, sample_id, annotation_storage: Ann
sample_id: The selected sample for which the values will be displayed.
annotation_storage: An instance of AnnotationStorage.
custom_annotation_key: The key for an annotation of data other than the tabular fields.
-
- Returns:
- A notebook-friendly messages indicating the status of the submission.
"""
df_sample = _get_df_sample(sample_info, sample_id)
if df_sample.shape[0] == 0:
- return HTML(f'''
- Warning: Sample {sample_id} not present in sample_info DataFrame.
''')
+ display(
+ HTML(f'''
+ Warning: Sample {sample_id} not present in sample_info DataFrame.
+
'''),
+ Annotation not submitted. Please try again.
+ '''),
)
- submit_button.on_click(on_button_clicked)
+ submit_button.on_click(cb_on_button_clicked)
# Display all the widgets.
display(sample, box1, comment, submit_button, output)
-def _format_annotation(sample_id, key, keyvalue, comment):
+def _format_annotation(
+ sample_id: Union[int, str], key: str, keyvalue: Union[int, float, str], comment: str,
+) -> Dict[str, Any]:
"""Helper method to clean and reshape info from the widgets and the environment into a dictionary representing the annotation."""
# Programmatically get the identity of the person running this Terra notebook.
current_user = os.getenv('OWNER_EMAIL')
@@ -128,11 +140,10 @@ def _format_annotation(sample_id, key, keyvalue, comment):
if current_user is None:
current_user = socket.gethostname() # By convention, we prefix the hostname with our username.
+ value_numeric = None
+ value_string = None
# Check whether the value is string or numeric.
- if keyvalue is None:
- value_numeric = None
- value_string = None
- else:
+ if keyvalue is not None:
try:
value_numeric = float(keyvalue) # this will fail if the value is text
value_string = None
diff --git a/ml4h/visualization_tools/dicom_interactive_plots.py b/ml4h/visualization_tools/dicom_interactive_plots.py
index d9850e841..ec9d63834 100644
--- a/ml4h/visualization_tools/dicom_interactive_plots.py
+++ b/ml4h/visualization_tools/dicom_interactive_plots.py
@@ -1,4 +1,4 @@
-"""Methods for integration of interactive dicom plots within notebooks.
+"""Methods for integration of interactive DICOM plots within notebooks.
TODO:
* Continue to *pragmatically* improve this to make the visualization controls
@@ -8,14 +8,15 @@
import collections
import os
import tempfile
+from typing import Any, DefaultDict, Dict, Optional, Tuple
import zipfile
from IPython.display import display
from IPython.display import HTML
+import numpy as np
import ipywidgets as widgets
import matplotlib.pyplot as plt
from ml4h.runtime_data_defines import get_mri_folders
-import numpy as np
import pydicom
import tensorflow as tf
@@ -27,15 +28,12 @@
MAX_COLOR_RANGE = 6000
-def choose_mri(sample_id, folder=None):
+def choose_mri(sample_id, folder: Optional[str] = None) -> None:
"""Render widget to choose the MRI to plot.
Args:
sample_id: The id of the sample to retrieve.
folder: The local or Cloud Storage folder under which the files reside.
-
- Returns:
- ipywidget or HTML upon error.
"""
if folder is None:
folders = get_mri_folders(sample_id)
@@ -45,22 +43,26 @@ def choose_mri(sample_id, folder=None):
sample_mris = []
sample_mri_glob = str(sample_id) + '_*.zip'
try:
- for folder in folders:
- sample_mris.extend(tf.io.gfile.glob(pattern=os.path.join(folder, sample_mri_glob)))
+ for f in folders:
+ sample_mris.extend(tf.io.gfile.glob(pattern=os.path.join(f, sample_mri_glob)))
except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e:
- return HTML(f'''
-
+ display(
+ HTML(f'''
Warning: MRI not available for sample {sample_id} in {folders}:
{e.message}
Use the folder parameter to read DICOMs from a different local directory or Cloud Storage bucket.
-
''')
+
'''),
+ )
+ return
if not sample_mris:
- return HTML(f'''
-
+ display(
+ HTML(f'''
Warning: MRI DICOMs not available for sample {sample_id} in {folders}.
Use the folder parameter to read DICOMs from a different local directory or Cloud Storage bucket.
-
''')
+
'''),
+ )
+ return
mri_chooser = widgets.Dropdown(
options=sample_mris,
@@ -77,14 +79,11 @@ def choose_mri(sample_id, folder=None):
display(file_controls_ui, file_controls_output)
-def choose_mri_series(sample_mri):
+def choose_mri_series(sample_mri: str) -> None:
"""Render widgets and interactive plots for MRIs.
Args:
sample_mri: The local or Cloud Storage path to the MRI file.
-
- Returns:
- ipywidget or HTML upon error.
"""
with tempfile.TemporaryDirectory() as tmpdirname:
local_path = os.path.join(tmpdirname, os.path.basename(sample_mri))
@@ -93,13 +92,15 @@ def choose_mri_series(sample_mri):
with zipfile.ZipFile(local_path, 'r') as zip_ref:
zip_ref.extractall(tmpdirname)
except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e:
- return HTML(f'''
-
+ display(
+ HTML(f'''
Warning: Cardiac MRI not available for sample {os.path.basename(sample_mri)}:
{e.message}
-
''')
+
'''),
+ )
+ return
- unordered_dicoms = collections.defaultdict(dict)
+ unordered_dicoms: DefaultDict[Any, Any] = collections.defaultdict(dict)
for dcm_file in os.listdir(tmpdirname):
if not dcm_file.endswith('.dcm'):
continue
@@ -112,8 +113,13 @@ def choose_mri_series(sample_mri):
unordered_dicoms[key1][key2] = dcm
if not unordered_dicoms:
- print(f'\n\nNo series available in MRI for sample {os.path.basename(sample_mri)}\n\nTry a different MRI.')
- return None
+ display(
+ HTML(f'''
+ No series available in MRI for sample {os.path.basename(sample_mri)}.
+ Try a different MRI.
+
'''),
+ )
+ return
# Convert from dict of dicts to dict of ordered lists.
dicoms = {}
@@ -134,7 +140,7 @@ def choose_mri_series(sample_mri):
style={'description_width': 'initial'},
layout=widgets.Layout(width='800px'),
)
- # Slide through dicom image instances using a slide bar.
+ # Slide through DICOM image instances using a slide bar.
instance_chooser = widgets.IntSlider(
continuous_update=True,
value=default_instance_value,
@@ -212,25 +218,25 @@ def on_value_change(change):
display(viz_controls_ui, viz_controls_output)
-def compute_color_range(dicoms, series_name):
+def compute_color_range(dicoms: Dict[str, Any], series_name: str) -> Tuple[int, int]:
"""Compute the mean values for the color ranges of instances in the series."""
vmin = np.mean([np.min(d.pixel_array) for d in dicoms[series_name]])
vmax = np.mean([np.max(d.pixel_array) for d in dicoms[series_name]])
- return(vmin, vmax)
+ return (vmin, vmax)
-def compute_instance_range(dicoms, series_name):
+def compute_instance_range(dicoms: Dict[str, Any], series_name: str) -> Tuple[int, int]:
"""Compute middle and max instances."""
middle_instance = int(len(dicoms[series_name]) / 2)
max_instance = len(dicoms[series_name])
- return(middle_instance, max_instance)
+ return (middle_instance, max_instance)
def dicom_animation(
- dicoms, series_name, instance, vmin, vmax, transpose,
- fig_width, title_prefix='',
-):
- """Render one frame of a dicom animation.
+ dicoms: Dict[str, Any], series_name: str, instance: int, vmin: int, vmax: int, transpose: bool,
+ fig_width: int, title_prefix: str = '',
+) -> None:
+ """Render one frame of a DICOM animation.
Args:
dicoms: the dictionary DICOM series and instances lists
@@ -250,7 +256,7 @@ def dicom_animation(
dcm = dicoms[series_name][instance - 1]
if instance != dcm.InstanceNumber:
# Notice invalid input, but don't throw an error.
- print(f'WARNING: Instance parameter {str(instance)} and dicom instance number {str(dcm.InstanceNumber)} do not match.')
+ print(f'WARNING: Instance parameter {str(instance)} and instance number {str(dcm.InstanceNumber)} do not match.')
if transpose:
height = dcm.pixel_array.T.shape[0]
diff --git a/ml4h/visualization_tools/dicom_plots.py b/ml4h/visualization_tools/dicom_plots.py
index ce2b3e083..093691382 100644
--- a/ml4h/visualization_tools/dicom_plots.py
+++ b/ml4h/visualization_tools/dicom_plots.py
@@ -1,16 +1,17 @@
-"""Methods for integration of dicom plots within notebooks."""
+"""Methods for integration of DICOM plots within notebooks."""
import collections
import os
import tempfile
+from typing import Dict, List, Optional, Tuple, Union
import zipfile
from IPython.display import display
from IPython.display import HTML
+import numpy as np
import ipywidgets as widgets
import matplotlib.pyplot as plt
from ml4h.runtime_data_defines import get_cardiac_mri_folder
-import numpy as np
import pydicom
from scipy.ndimage.morphology import binary_closing
from scipy.ndimage.morphology import binary_erosion
@@ -27,21 +28,21 @@
MRI_SEGMENTED_CHANNEL_MAP = {'background': 0, 'ventricle': 1, 'myocardium': 2}
-def _is_mitral_valve_segmentation(d): # -> bool:
- """Determine whether a dicom has mitral valve segmentation.
+def _is_mitral_valve_segmentation(d: pydicom.FileDataset) -> bool:
+ """Determine whether a DICOM has mitral valve segmentation.
This is used for visualization of CINE_segmented_SAX_InlineVF.
Args:
- d: the dicom file
+ d: the DICOM file
Returns:
- Whether or not the dicom has mitral valve segmentation
+ Whether or not the DICOM has mitral valve segmentation
"""
return d.SliceThickness == 6
-def _get_overlay_from_dicom(d):
+def _get_overlay_from_dicom(d: pydicom.FileDataset) -> Tuple[int, int, int]:
"""Get an overlay from a DICOM file.
Morphological operators are used to transform the pixel outline of the
@@ -49,7 +50,7 @@ def _get_overlay_from_dicom(d):
is used for visualization of CINE_segmented_SAX_InlineVF.
Args:
- d: the dicom file
+ d: the DICOM file
Returns:
Raw overlay array with myocardium outline, anatomical mask (a pixel
@@ -77,29 +78,30 @@ def _get_overlay_from_dicom(d):
byte >>= 1
bit += 1
overlay = overlay[:expected_bit_length]
- if overlay_frames == 1:
- overlay = overlay.reshape(rows, cols)
- idx = np.where(overlay == 1)
- min_pos = (np.min(idx[0]), np.min(idx[1]))
- max_pos = (np.max(idx[0]), np.max(idx[1]))
- short_side = min((max_pos[0] - min_pos[0]), (max_pos[1] - min_pos[1]))
- small_radius = max(MRI_MIN_RADIUS, short_side * MRI_SMALL_RADIUS_FACTOR)
- big_radius = max(MRI_MIN_RADIUS+1, short_side * MRI_BIG_RADIUS_FACTOR)
- small_structure = _unit_disk(small_radius)
- m1 = binary_closing(overlay, small_structure).astype(np.int)
- big_structure = _unit_disk(big_radius)
- m2 = binary_closing(overlay, big_structure).astype(np.int)
- anatomical_mask = m1 + m2
+ if overlay_frames != 1:
+ raise ValueError(f'DICOM has {overlay_frames} overlay frames, but only one expected.')
+ overlay = overlay.reshape(rows, cols)
+ idx = np.where(overlay == 1)
+ min_pos = (np.min(idx[0]), np.min(idx[1]))
+ max_pos = (np.max(idx[0]), np.max(idx[1]))
+ short_side = min((max_pos[0] - min_pos[0]), (max_pos[1] - min_pos[1]))
+ small_radius = max(MRI_MIN_RADIUS, short_side * MRI_SMALL_RADIUS_FACTOR)
+ big_radius = max(MRI_MIN_RADIUS+1, short_side * MRI_BIG_RADIUS_FACTOR)
+ small_structure = _unit_disk(small_radius)
+ m1 = binary_closing(overlay, small_structure).astype(np.int)
+ big_structure = _unit_disk(big_radius)
+ m2 = binary_closing(overlay, big_structure).astype(np.int)
+ anatomical_mask = m1 + m2
+ ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle'])
+ myocardium_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['myocardium'])
+ if ventricle_pixels == 0 and myocardium_pixels > MRI_MAX_MYOCARDIUM:
+ erode_structure = _unit_disk(small_radius*1.5)
+ anatomical_mask = anatomical_mask - binary_erosion(m1, erode_structure).astype(np.int)
ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle'])
- myocardium_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['myocardium'])
- if ventricle_pixels == 0 and myocardium_pixels > MRI_MAX_MYOCARDIUM:
- erode_structure = _unit_disk(small_radius*1.5)
- anatomical_mask = anatomical_mask - binary_erosion(m1, erode_structure).astype(np.int)
- ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle'])
- return overlay, anatomical_mask, ventricle_pixels
+ return overlay, anatomical_mask, ventricle_pixels
-def _unit_disk(r): # -> np.ndarray:
+def _unit_disk(r: int) -> np.ndarray:
"""Get the unit disk for a radius.
This is used for visualization of CINE_segmented_SAX_InlineVF.
@@ -114,7 +116,9 @@ def _unit_disk(r): # -> np.ndarray:
return (x ** 2 + y ** 2 <= r ** 2).astype(np.int)
-def plot_cardiac_long_axis(b_series, sides=7, fig_width=18, title_prefix=''):
+def plot_cardiac_long_axis(
+ b_series: List[pydicom.FileDataset], sides: int = 7, fig_width: int = 18, title_prefix: str = '',
+) -> None:
"""Visualize CINE_segmented_SAX_InlineVF series.
Args:
@@ -168,9 +172,9 @@ def plot_cardiac_long_axis(b_series, sides=7, fig_width=18, title_prefix=''):
def plot_cardiac_short_axis(
- series, transpose=False, fig_width=18,
- title_prefix='',
-):
+ series: List[pydicom.FileDataset], transpose: bool = False, fig_width: int = 18,
+ title_prefix: str = '',
+) -> None:
"""Visualize CINE_segmented_LAX series.
Args:
@@ -225,14 +229,14 @@ def plot_cardiac_short_axis(
def plot_mri_series(
- sample_mri, dicoms, series_name, sax_sides,
- lax_transpose, fig_width,
-):
+ sample_mri: str, dicoms: Dict[str, pydicom.FileDataset], series_name: str, sax_sides: int,
+ lax_transpose: bool, fig_width: int,
+) -> None:
"""Visualize the applicable series within this DICOM.
Args:
sample_mri: The local or Cloud Storage path to the MRI file.
- dicoms: A dictionary of dicoms.
+ dicoms: A dictionary of DICOMs.
series_name: The name of the chosen series.
sax_sides: How many sides to display for CINE_segmented_SAX_InlineVF.
lax_transpose: Whether to transpose when plotting CINE_segmented_LAX.
@@ -258,10 +262,9 @@ def plot_mri_series(
)
else:
print(f'Visualization not currently implemented for {series_name}.')
- return None
-def choose_mri_series(sample_mri):
+def choose_mri_series(sample_mri: str) -> None:
"""Render widgets and plots for cardiac MRIs.
Visualization is supported for CINE_segmented_SAX_InlineVF series and
@@ -269,9 +272,6 @@ def choose_mri_series(sample_mri):
Args:
sample_mri: The local or Cloud Storage path to the MRI file.
-
- Returns:
- ipywidget or HTML upon error.
"""
with tempfile.TemporaryDirectory() as tmpdirname:
local_path = os.path.join(tmpdirname, os.path.basename(sample_mri))
@@ -280,11 +280,13 @@ def choose_mri_series(sample_mri):
with zipfile.ZipFile(local_path, 'r') as zip_ref:
zip_ref.extractall(tmpdirname)
except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e:
- return HTML(f'''
-
+ display(
+ HTML(f'''
Warning: Cardiac MRI not available for sample {os.path.basename(sample_mri)}:
{e.message}
-
''')
+
'''),
+ )
+ return
filtered_dicoms = collections.defaultdict(list)
series_descriptions = []
@@ -295,7 +297,7 @@ def choose_mri_series(sample_mri):
series_descriptions.append(dcm.SeriesDescription)
if 'cine_segmented_lax' in dcm.SeriesDescription.lower():
filtered_dicoms[dcm.SeriesDescription.lower()].append(dcm)
- if 'cine_segmented_sax_inlinevf' == dcm.SeriesDescription.lower():
+ if dcm.SeriesDescription.lower() == 'cine_segmented_sax_inlinevf':
cur_angle = (dcm.InstanceNumber - 1) // MRI_FRAMES
filtered_dicoms[f'{dcm.SeriesDescription.lower()}_angle_{str(cur_angle)}'].append(dcm)
@@ -350,22 +352,20 @@ def choose_mri_series(sample_mri):
)
display(viz_controls_ui, viz_controls_output)
else:
- print(
- f'\n\nNeither CINE_segmented_SAX_InlineVF nor CINE_segmented_LAX available in MRI for sample {os.path.basename(sample_mri)}.',
- '\n\nTry a different MRI.',
+ display(
+ HTML(f'''
+ Neither CINE_segmented_SAX_InlineVF nor CINE_segmented_LAX available in MRI for sample {os.path.basename(sample_mri)}.
+ Try a different MRI.
+
'''),
)
- return None
-def choose_cardiac_mri(sample_id, folder=None):
+def choose_cardiac_mri(sample_id: Union[int, str], folder: Optional[str] = None) -> None:
"""Render widget to choose the cardiac MRI to plot.
Args:
sample_id: The id of the ECG sample to retrieve.
folder: The local or Cloud Storage folder under which the files reside.
-
- Returns:
- ipywidget or HTML upon error.
"""
if folder is None:
folder = get_cardiac_mri_folder(sample_id)
@@ -374,19 +374,23 @@ def choose_cardiac_mri(sample_id, folder=None):
try:
sample_mris = tf.io.gfile.glob(pattern=os.path.join(folder, sample_mri_glob))
except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e:
- return HTML(f'''
-
+ display(
+ HTML(f'''
Warning: Cardiac MRI not available for sample {sample_id} in {folder}:
{e.message}
Use the folder parameter to read DICOMs from a different local directory or Cloud Storage bucket.
-
''')
+
'''),
+ )
+ return
if not sample_mris:
- return HTML(f'''
-
+ display(
+ HTML(f'''
Warning: Cardiac MRI DICOM not available for sample {sample_id} in {folder}.
Use the folder parameter to read DICOMs from a different local directory or Cloud Storage bucket.
-
''')
+
'''),
+ )
+ return
mri_chooser = widgets.Dropdown(
options=[(os.path.basename(mri), mri) for mri in sample_mris],
diff --git a/ml4h/visualization_tools/ecg_interactive_plots.py b/ml4h/visualization_tools/ecg_interactive_plots.py
index 97a4e1547..18ed39a9b 100644
--- a/ml4h/visualization_tools/ecg_interactive_plots.py
+++ b/ml4h/visualization_tools/ecg_interactive_plots.py
@@ -2,10 +2,12 @@
import os
import tempfile
+from typing import Optional, Union
-import altair as alt # Interactive data visualization for plots.
from IPython.display import HTML
-from ml4h.visualization_tools.ecg_reshape import DEFAULT_RESTING_ECG_SIGNAL_TMAP_NAME
+import altair as alt # Interactive data visualization for plots.
+from ml4h.TensorMap import TensorMap
+from ml4h.visualization_tools.ecg_reshape import DEFAULT_RESTING_ECG_SIGNAL_TMAP
from ml4h.visualization_tools.ecg_reshape import reshape_exercise_ecg_to_tidy
from ml4h.visualization_tools.ecg_reshape import reshape_resting_ecg_to_tidy
@@ -31,18 +33,21 @@
)
-def resting_ecg_interactive_plot(sample_id, folder=None, tmap_name=DEFAULT_RESTING_ECG_SIGNAL_TMAP_NAME):
+def resting_ecg_interactive_plot(
+ sample_id: Union[int, str], folder: Optional[str] = None,
+ tmap: TensorMap = DEFAULT_RESTING_ECG_SIGNAL_TMAP,
+) -> Union[HTML, alt.Chart]:
"""Wrangle resting ECG data to tidy and present it as an interactive plot.
Args:
sample_id: The id of the ECG sample to retrieve.
folder: The local or Cloud Storage folder under which the files reside.
- tmap_name: The name of the TMAP to use for ecg input.
+ tmap: The TensorMap to use for ECG input.
Returns:
An Altair plot or a notebook-friendly error.
"""
- tidy_resting_ecg_signal = reshape_resting_ecg_to_tidy(sample_id, folder, tmap_name)
+ tidy_resting_ecg_signal = reshape_resting_ecg_to_tidy(sample_id, folder, tmap)
if tidy_resting_ecg_signal.shape[0] == 0:
return HTML(f'''
@@ -85,7 +90,9 @@ def resting_ecg_interactive_plot(sample_id, folder=None, tmap_name=DEFAULT_RESTI
return upper & lower
-def exercise_ecg_interactive_plot(sample_id, folder=None, time_interval_seconds=10):
+def exercise_ecg_interactive_plot(
+ sample_id: Union[int, str], folder: Optional[str] = None, time_interval_seconds: int = 10,
+) -> Union[HTML, alt.Chart]:
"""Wrangle exercise ECG data to tidy and present it as an interactive plot.
Args:
@@ -140,7 +147,8 @@ def exercise_ecg_interactive_plot(sample_id, folder=None, time_interval_seconds=
lead_select,
).transform_filter(
# https://github.com/altair-viz/altair/issues/1960
- f'((toNumber({brush.name}.time) - {time_interval_seconds/2.0}) < datum.time) && (datum.time < toNumber({brush.name}.time) + {time_interval_seconds/2.0})',
+ f'''((toNumber({brush.name}.time) - {time_interval_seconds/2.0}) < datum.time)
+ && (datum.time < toNumber({brush.name}.time) + {time_interval_seconds/2.0})''',
)
return trend.encode(y='heartrate:Q') & trend.encode(y='load:Q') & signal
diff --git a/ml4h/visualization_tools/ecg_reshape.py b/ml4h/visualization_tools/ecg_reshape.py
index b3213d359..167eb5012 100644
--- a/ml4h/visualization_tools/ecg_reshape.py
+++ b/ml4h/visualization_tools/ecg_reshape.py
@@ -1,53 +1,57 @@
"""Methods for reshaping raw ECG signal data for use in the pandas ecosystem."""
import os
import tempfile
+from typing import Any, Dict, Optional, Tuple, Union
+import numpy as np
+import pandas as pd
from biosppy.signals.tools import filter_signal
import h5py
from ml4h.defines import ECG_BIKE_LEADS
from ml4h.defines import ECG_REST_LEADS
from ml4h.runtime_data_defines import get_exercise_ecg_hd5_folder
from ml4h.runtime_data_defines import get_resting_ecg_hd5_folder
-from ml4h.tensor_maps_by_hand import TMAPS
-import numpy as np
-import pandas as pd
+from ml4h.TensorMap import TensorMap
+import ml4h.tensormap.ukb.ecg as ecg_tmaps
import tensorflow as tf
RAW_SCALE = 0.005 # Convert to mV.
SAMPLING_RATE = 500.0
-DEFAULT_RESTING_ECG_SIGNAL_TMAP_NAME = 'ecg_rest'
+DEFAULT_RESTING_ECG_SIGNAL_TMAP = ecg_tmaps.ecg_rest
# TODO(deflaux): parameterize exercise ECG by TMAP name if there is similar ECG data from other studies.
-EXERCISE_ECG_SIGNAL_TMAP = TMAPS['ecg-bike-raw-full']
+EXERCISE_ECG_SIGNAL_TMAP = ecg_tmaps.ecg_bike_raw_full
EXERCISE_ECG_TREND_TMAPS = [
- TMAPS['ecg-bike-raw-trend-hr'],
- TMAPS['ecg-bike-raw-trend-load'],
- TMAPS['ecg-bike-raw-trend-grade'],
- TMAPS['ecg-bike-raw-trend-artifact'],
- TMAPS['ecg-bike-raw-trend-mets'],
- TMAPS['ecg-bike-raw-trend-pacecount'],
- TMAPS['ecg-bike-raw-trend-phasename'],
- TMAPS['ecg-bike-raw-trend-phasetime'],
- TMAPS['ecg-bike-raw-trend-time'],
- TMAPS['ecg-bike-raw-trend-vecount'],
+ ecg_tmaps.ecg_bike_raw_trend_hr,
+ ecg_tmaps.ecg_bike_raw_trend_load,
+ ecg_tmaps.ecg_bike_raw_trend_grade,
+ ecg_tmaps.ecg_bike_raw_trend_artifact,
+ ecg_tmaps.ecg_bike_raw_trend_mets,
+ ecg_tmaps.ecg_bike_raw_trend_pacecount,
+ ecg_tmaps.ecg_bike_raw_trend_phasename,
+ ecg_tmaps.ecg_bike_raw_trend_phasetime,
+ ecg_tmaps.ecg_bike_raw_trend_time,
+ ecg_tmaps.ecg_bike_raw_trend_vecount,
]
EXERCISE_PHASES = {0.0: 'Pretest', 1.0: 'Exercise', 2.0: 'Recovery'}
-def _examine_available_keys(hd5):
+def _examine_available_keys(hd5: Dict[str, Any]) -> None:
+ """Print the top-level hd5 keys containing 'ecg' and, for each, the keys nested beneath it (debugging aid)."""
print(f'hd5 ECG keys {[k for k in hd5.keys() if "ecg" in k]}')
for key in [k for k in hd5.keys() if 'ecg' in k]:
- print(f'hd5 {key} keys {[k for k in hd5[key].keys()]}')
+ # Materialize the keys: a bare generator expression inside an f-string would print the generator object, not the keys.
+ print(f'hd5 {key} keys {list(hd5[key])}')
-def reshape_resting_ecg_to_tidy(sample_id, folder=None, tmap_name=DEFAULT_RESTING_ECG_SIGNAL_TMAP_NAME):
+def reshape_resting_ecg_to_tidy(
+ sample_id: Union[int, str], folder: Optional[str] = None, tmap: TensorMap = DEFAULT_RESTING_ECG_SIGNAL_TMAP,
+) -> pd.DataFrame:
"""Wrangle resting ECG data to tidy.
Args:
sample_id: The id of the ECG sample to retrieve.
folder: The local or Cloud Storage folder under which the files reside.
- tmap_name: The name of the TMAP to use for ecg input.
+ tmap: The TensorMap to use for ECG input.
Returns:
A pandas dataframe in tidy format or print a notebook-friendly error and return an empty dataframe.
@@ -55,7 +59,7 @@ def reshape_resting_ecg_to_tidy(sample_id, folder=None, tmap_name=DEFAULT_RESTIN
if folder is None:
folder = get_resting_ecg_hd5_folder(sample_id)
- data = {'lead': [], 'raw': [], 'ts_reference': [], 'filtered': [], 'filtered_1': [], 'filtered_2': []}
+ data: Dict[str, Any] = {'lead': [], 'raw': [], 'ts_reference': [], 'filtered': [], 'filtered_1': [], 'filtered_2': []}
with tempfile.TemporaryDirectory() as tmpdirname:
sample_hd5 = str(sample_id) + '.hd5'
@@ -69,10 +73,10 @@ def reshape_resting_ecg_to_tidy(sample_id, folder=None, tmap_name=DEFAULT_RESTIN
with h5py.File(local_path, mode='r') as hd5:
try:
- signals = TMAPS[tmap_name].tensor_from_file(TMAPS[tmap_name], hd5)
+ signals = tmap.tensor_from_file(tmap, hd5)
except (KeyError, ValueError) as e:
- print(f'''Warning: Resting ECG TMAP {tmap_name} not available for sample {sample_id}.
- Use the tmap_name parameter to choose a different TMAP.\n\n{e}''')
+ print(f'''Warning: Resting ECG TMAP {tmap.name} not available for sample {sample_id}.
+ Use the tmap parameter to choose a different TMAP.\n\n{e}''')
_examine_available_keys(hd5)
return pd.DataFrame(data)
for (lead, channel) in ECG_REST_LEADS.items():
@@ -136,7 +140,9 @@ def reshape_resting_ecg_to_tidy(sample_id, folder=None, tmap_name=DEFAULT_RESTIN
return tidy_signal_df
-def reshape_exercise_ecg_to_tidy(sample_id, folder=None):
+def reshape_exercise_ecg_to_tidy(
+ sample_id: Union[int, str], folder: Optional[str] = None,
+) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""Wrangle exercise ECG signal data to tidy format.
Args:
@@ -208,7 +214,9 @@ def reshape_exercise_ecg_to_tidy(sample_id, folder=None):
return (trend_df, tidy_signal_df)
-def reshape_exercise_ecg_and_trend_to_tidy(sample_id, folder=None):
+def reshape_exercise_ecg_and_trend_to_tidy(
+ sample_id: Union[int, str], folder: Optional[str] = None,
+) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""Wrangle exercise ECG signal and trend data to tidy format.
Args:
diff --git a/ml4h/visualization_tools/ecg_static_plots.py b/ml4h/visualization_tools/ecg_static_plots.py
index 2ebcfc3e1..ac7283237 100644
--- a/ml4h/visualization_tools/ecg_static_plots.py
+++ b/ml4h/visualization_tools/ecg_static_plots.py
@@ -1,17 +1,18 @@
"""Methods for integration of static plots within notebooks."""
import os
import tempfile
+from typing import List, Optional, Union
from IPython.display import HTML
from IPython.display import SVG
+import numpy as np
from ml4h.plots import plot_ecg_rest
from ml4h.runtime_data_defines import get_resting_ecg_hd5_folder
from ml4h.runtime_data_defines import get_resting_ecg_svg_folder
-import numpy as np
import tensorflow as tf
-def display_resting_ecg(sample_id, folder=None):
+def display_resting_ecg(sample_id: Union[int, str], folder: Optional[str] = None) -> Union[HTML, SVG]:
"""Retrieve (or render) and display the SVG of the resting ECG.
Args:
@@ -53,8 +54,8 @@ def display_resting_ecg(sample_id, folder=None):
try:
# We don't need the resulting SVG, so send it to a temporary directory.
with tempfile.TemporaryDirectory() as tmpdirname:
- plot_ecg_rest(tensor_paths = [local_path], rows=[0], out_folder=tmpdirname, is_blind=False)
- except Exception as e:
+ return plot_ecg_rest(tensor_paths=[local_path], rows=[0], out_folder=tmpdirname, is_blind=False)
+ except Exception as e: # pylint: disable=broad-except
return HTML(f'''
Warning: Unable to render static plot of resting ECG for sample {sample_id} from {hd5_folder}:
@@ -62,7 +63,7 @@ def display_resting_ecg(sample_id, folder=None):
''')
-def major_breaks_x_resting_ecg(limits):
+def major_breaks_x_resting_ecg(limits: List[float]) -> np.array:
"""Method to compute breaks for plotnine plots of ECG resting data.
Args:
diff --git a/ml4h/visualization_tools/facets.py b/ml4h/visualization_tools/facets.py
index a45ea88da..18f96327d 100644
--- a/ml4h/visualization_tools/facets.py
+++ b/ml4h/visualization_tools/facets.py
@@ -2,6 +2,7 @@
import base64
import os
+import pandas as pd
from facets_overview.generic_feature_statistics_generator import GenericFeatureStatisticsGenerator
FACETS_DEPENDENCIES = {
@@ -25,10 +26,10 @@
FACETS_DEPENDENCIES[dep] = os.path.basename(url)
-class FacetsOverview(object):
+class FacetsOverview():
"""Methods for Facets Overview notebook integration."""
- def __init__(self, data):
+ def __init__(self, data: pd.DataFrame):
# This takes the dataframe and computes all the inputs to the Facets
# Overview plots such as:
# - numeric variables: histogram bins, mean, min, median, max, etc..
@@ -39,7 +40,7 @@ def __init__(self, data):
[{'name': 'data', 'table': data}],
)
- def _repr_html_(self):
+ def _repr_html_(self) -> str:
"""Html representation of Facets Overview for use in a Jupyter notebook."""
protostr = base64.b64encode(self._proto.SerializeToString()).decode('utf-8')
html_template = '''
@@ -57,14 +58,14 @@ def _repr_html_(self):
return html
-class FacetsDive(object):
+class FacetsDive():
"""Methods for Facets Dive notebook integration."""
- def __init__(self, data, height=1000):
+ def __init__(self, data: pd.DataFrame, height: int = 1000):
self._data = data
self.height = height
- def _repr_html_(self):
+ def _repr_html_(self) -> str:
"""Html representation of Facets Dive for use in a Jupyter notebook."""
html_template = """
diff --git a/ml4h/visualization_tools/hd5_mri_plots.py b/ml4h/visualization_tools/hd5_mri_plots.py
index 20b3305b1..d3894b39d 100644
--- a/ml4h/visualization_tools/hd5_mri_plots.py
+++ b/ml4h/visualization_tools/hd5_mri_plots.py
@@ -1,29 +1,34 @@
"""Methods for integration of plots of mri data processed to 3D tensors from within notebooks."""
+from collections import OrderedDict
from enum import Enum, auto
import os
import tempfile
+from typing import Any, Dict, List, Optional, Tuple, Union
-import h5py
from IPython.display import display
from IPython.display import HTML
+import numpy as np
+import h5py
import ipywidgets as widgets
import matplotlib.pyplot as plt
from ml4h.runtime_data_defines import get_mri_hd5_folder
-from ml4h.tensor_maps_by_hand import TMAPS
-from ml4h.TensorMap import Interpretation
-import numpy as np
+import ml4h.tensormap.ukb.mri as ukb_mri
+import ml4h.tensormap.ukb.mri_vtk as ukb_mri_vtk
+from ml4h.TensorMap import Interpretation, TensorMap
import tensorflow as tf
-# Discover applicable TMAPS.
-CARDIAC_MRI_TMAP_NAMES = [k for k in TMAPS.keys() if ('_lax_' in k or '_sax_' in k) and TMAPS[k].axes() == 3]
-CARDIAC_MRI_TMAP_NAMES.extend(
- [k for k in TMAPS.keys() if TMAPS[k].path_prefix == 'ukb_cardiac_mri' and TMAPS[k].axes() == 3],
+# Discover applicable TensorMaps.
+MRI_TMAPS = {
+ key: value for key, value in ukb_mri.__dict__.items() if isinstance(value, TensorMap)
+ and value.interpretation == Interpretation.CONTINUOUS and value.axes() == 3
+}
+MRI_TMAPS.update(
+ {
+ key: value for key, value in ukb_mri_vtk.__dict__.items()
+ if isinstance(value, TensorMap) and value.interpretation == Interpretation.CONTINUOUS and value.axes() == 3
+ },
)
-LIVER_MRI_TMAP_NAMES = [k for k in TMAPS.keys() if TMAPS[k].path_prefix == 'ukb_liver_mri' and TMAPS[k].axes() == 3]
-BRAIN_MRI_TMAP_NAMES = [k for k in TMAPS.keys() if TMAPS[k].path_prefix == 'ukb_brain_mri' and TMAPS[k].axes() == 3]
-# This includes more than just MRI TMAPS, it is a best effort.
-BEST_EFFORT_MRI_TMAP_NAMES = [k for k in TMAPS.keys() if TMAPS[k].interpretation == Interpretation.CONTINUOUS and TMAPS[k].axes() == 3]
MIN_IMAGE_WIDTH = 8
DEFAULT_IMAGE_WIDTH = 12
@@ -41,42 +46,30 @@ class PlotType(Enum):
class TensorMapCache:
"""Cache the tensor to display for reuse when re-plotting the same TMAP with different plot parameters."""
- def __init__(self, hd5, tmap_name):
+ def __init__(self, hd5: Dict[str, Any], tmap: TensorMap):
self.hd5 = hd5
- self.tmap_name = None
+ self.tmap: Optional[TensorMap] = None
self.tensor = None
- _ = self.get(tmap_name)
+ # Prime the cache with the initially selected TensorMap.
+ _ = self.get(tmap)
- def get(self, tmap_name):
- if self.tmap_name != tmap_name:
- self.tensor = TMAPS[tmap_name].tensor_from_file(TMAPS[tmap_name], self.hd5)
- self.tmap_name = tmap_name
+ def get(self, tmap: TensorMap) -> np.ndarray:
+ """Return the tensor for tmap, re-reading it from the hd5 only when tmap differs from the cached TensorMap."""
+ if self.tmap != tmap:
+ self.tensor = tmap.tensor_from_file(tmap, self.hd5)
+ self.tmap = tmap
return self.tensor
-def choose_cardiac_mri_tmap(sample_id, folder=None, tmap_name='cine_lax_4ch_192', default_tmap_names=CARDIAC_MRI_TMAP_NAMES):
- choose_mri_tmap(sample_id, folder, tmap_name, default_tmap_names)
-
-
-def choose_brain_mri_tmap(sample_id, folder=None, tmap_name='t2_flair_sag_p2_1mm_fs_ellip_pf78_1', default_tmap_names=BRAIN_MRI_TMAP_NAMES):
- choose_mri_tmap(sample_id, folder, tmap_name, default_tmap_names)
-
-
-def choose_liver_mri_tmap(sample_id, folder=None, tmap_name='liver_shmolli_segmented', default_tmap_names=LIVER_MRI_TMAP_NAMES):
- choose_mri_tmap(sample_id, folder, tmap_name, default_tmap_names)
-
-
-def choose_mri_tmap(sample_id, folder=None, tmap_name=None, default_tmap_names=BEST_EFFORT_MRI_TMAP_NAMES):
+def choose_mri_tmap(
+ sample_id: Union[int, str], folder: Optional[str] = None, tmap: Optional[TensorMap] = None,
+ default_tmaps: Dict[str, TensorMap] = MRI_TMAPS,
+) -> None:
"""Render widgets and plots for MRI tensors.
Args:
sample_id: The id of the sample to retrieve.
folder: The local or Cloud Storage folder under which the files reside.
- tmap_name: The TMAP name for the 3D MRI tensor to visualize.
- default_tmap_names: Other TMAP names to offer for visualization, if present in the hd5.
-
- Returns:
- ipywidget or HTML upon error.
+ tmap: The TensorMap for the 3D MRI tensor to visualize.
+ default_tmaps: Other TensorMaps to offer for visualization, if present in the hd5.
"""
if folder is None:
folder = get_mri_hd5_folder(sample_id)
@@ -88,42 +81,45 @@ def choose_mri_tmap(sample_id, folder=None, tmap_name=None, default_tmap_names=B
tf.io.gfile.copy(src=os.path.join(folder, sample_hd5), dst=local_path)
hd5 = h5py.File(local_path, mode='r')
except (tf.errors.NotFoundError, tf.errors.PermissionDeniedError) as e:
- return HTML(f'''
-
+ display(
+ HTML(f'''
Warning: MRI HD5 file not available for sample {sample_id} in folder {folder}:
{e.message}
Use the folder parameter to read HD5s from a different local directory or Cloud Storage bucket.
-
''')
-
- sample_tmap_names = []
- # Add the passed tmap_name parameter, if it is present in this hd5.
- if tmap_name:
- if TMAPS[tmap_name].hd5_key_guess() in hd5:
- if len(TMAPS[tmap_name].shape) == 3:
- sample_tmap_names.append(tmap_name)
+
'''),
+ )
+ return
+
+ sample_tmaps = OrderedDict()
+ # Add the passed tmap parameter, if it is present in this hd5.
+ if tmap:
+ if tmap.hd5_key_guess() in hd5:
+ if len(tmap.shape) == 3:
+ sample_tmaps[tmap.name] = tmap
else:
- print(f'{tmap_name} is not a 3D tensor, skipping it')
+ print(f'{tmap} is not a 3D tensor, skipping it')
else:
- print(f'{tmap_name} is not available in {sample_id}')
- # Also discover applicable TMAPS for this particular sample's HD5 file.
- sample_tmap_names.extend(
- sorted(set([k for k in default_tmap_names if TMAPS[k].hd5_key_guess() in hd5])),
- )
-
- if not sample_tmap_names:
- return HTML(f'''
- Neither {tmap_name} nor any of {default_tmap_names} are present in this HD5 for sample {sample_id} in {folder}.
- Use the tmap_name parameter to try a different TMAP or the folder parameter to try a different hd5 for the sample.
-
''')
-
- default_tmap_name_value = sample_tmap_names[0]
+ print(f'{tmap} is not available in {sample_id}')
+ # Also discover applicable TensorMaps for this particular sample's HD5 file.
+ sample_tmaps.update({n: t for n, t in sorted(default_tmaps.items(), key=lambda t: t[0]) if t.hd5_key_guess() in hd5})
+
+ if not sample_tmaps:
+ display(
+ HTML(f'''
+ Neither {tmap.name} nor any of {default_tmaps.keys()} are present in this HD5 for sample {sample_id} in {folder}.
+ Use the tmap parameter to try a different TensorMap or the folder parameter to try a different hd5 for the sample.
+
'),
- tmap_name_chooser,
+ tmap_chooser,
widgets.HBox([transpose_chooser, fig_width_chooser]),
widgets.HBox([flip_chooser, color_range_chooser]),
widgets.HBox([plot_type_chooser, instance_chooser]),
],
layout=widgets.Layout(width='auto', border='solid 1px grey'),
)
- tmap_cache = TensorMapCache(hd5=hd5, tmap_name=tmap_name_chooser.value)
+ tmap_cache = TensorMapCache(hd5=hd5, tmap=tmap_chooser.value)
viz_controls_output = widgets.interactive_output(
plot_mri_tmap,
{
'sample_id': widgets.fixed(sample_id),
'tmap_cache': widgets.fixed(tmap_cache),
- 'tmap_name': tmap_name_chooser,
+ 'tmap': tmap_chooser,
'plot_type': plot_type_chooser,
'instance': instance_chooser,
'color_range': color_range_chooser,
@@ -209,33 +205,36 @@ def on_plot_type_change(change):
else:
instance_chooser.layout.visibility = 'hidden'
- tmap_name_chooser.observe(on_tmap_value_change, names='value')
+ tmap_chooser.observe(on_tmap_value_change, names='value')
plot_type_chooser.observe(on_plot_type_change, names='value')
display(viz_controls_ui, viz_controls_output)
-def compute_color_range(hd5, tmap_name):
+def compute_color_range(hd5: Dict[str, Any], tmap: TensorMap) -> List[float]:
"""Compute the mean values for the color ranges of instances in the MRI series."""
- mri_tensor = TMAPS[tmap_name].tensor_from_file(TMAPS[tmap_name], hd5)
+ mri_tensor = tmap.tensor_from_file(tmap, hd5)
+ # vmin/vmax are means of the per-instance minima/maxima along the third axis, so they are floats rather than ints.
vmin = np.mean([np.min(mri_tensor[:, :, i]) for i in range(0, mri_tensor.shape[2])])
vmax = np.mean([np.max(mri_tensor[:, :, i]) for i in range(0, mri_tensor.shape[2])])
- return[vmin, vmax]
+ return [vmin, vmax]
-def compute_instance_range(tmap_name):
+def compute_instance_range(tmap: TensorMap) -> Tuple[int, int]:
"""Compute middle and max instances."""
- middle_instance = int(TMAPS[tmap_name].shape[2] / 2)
- max_instance = TMAPS[tmap_name].shape[2]
- return(middle_instance, max_instance)
+ middle_instance = int(tmap.shape[2] / 2)
+ max_instance = tmap.shape[2]
+ return (middle_instance, max_instance)
-def plot_mri_tmap(sample_id, tmap_cache, tmap_name, plot_type, instance, color_range, transpose, flip, fig_width):
+def plot_mri_tmap(
+ sample_id: Union[int, str], tmap_cache: TensorMapCache, tmap: TensorMap, plot_type: PlotType,
+ instance: int, color_range: Tuple[int, int], transpose: bool, flip: bool, fig_width: int,
+) -> None:
"""Visualize the applicable MRI series within this HD5 file.
Args:
sample_id: The local or Cloud Storage path to the MRI file.
tmap_cache: The cache from which to retrieve the tensor to be plotted.
- tmap_name: The name of the chosen TMAP for the MRI series.
+ tmap: The chosen TensorMap for the MRI series.
plot_type: Whether to display instances interactively or in a panel view.
instance: The particular instance to display, if interactive.
color_range: Array of minimum and maximum value for the color range.
@@ -243,12 +242,9 @@ def plot_mri_tmap(sample_id, tmap_cache, tmap_name, plot_type, instance, color_r
flip: Whether to flip the image on its vertical axis
fig_width: The desired width of the figure. Note that height computed as
the proportion of the width based on the data to be plotted.
-
- Returns:
- The plot or a notebook-friendly error message.
"""
- title_prefix = f'{tmap_name} from MRI {sample_id}'
- mri_tensor = tmap_cache.get(tmap_name)
+ title_prefix = f'{tmap.name} from MRI {sample_id}'
+ mri_tensor = tmap_cache.get(tmap)
if plot_type == PlotType.INTERACTIVE:
plot_mri_tensor_as_animation(
mri_tensor=mri_tensor,
@@ -275,10 +271,13 @@ def plot_mri_tmap(sample_id, tmap_cache, tmap_name, plot_type, instance, color_r
title_prefix=title_prefix,
)
else:
- return HTML(f'''