Skip to content

Commit

Permalink
Merge pull request #72 from the-virtual-brain/WID-69
Browse files Browse the repository at this point in the history
WID-69. Add support for EDF files
  • Loading branch information
romina1601 authored Apr 17, 2024
2 parents fcc188d + 767ba6e commit d18ef02
Show file tree
Hide file tree
Showing 7 changed files with 4,501 additions and 18 deletions.
202 changes: 202 additions & 0 deletions notebooks/TimeSeriesEDF.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "bae5e88c-1828-4349-b67a-01082e41aaf7",
"metadata": {},
"source": [
"## EDF TimeSeries Widget"
]
},
{
"cell_type": "markdown",
"id": "866da714-9a4c-4242-b469-8643ba5a8573",
"metadata": {},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"id": "a28b8eb4-4709-4056-9e6e-4530a3735011",
"metadata": {},
"source": [
"### This notebook is dedicated to showcasing the TimeSeries widget using data from an EDF file"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "56d5c4b6-7804-46d6-931f-fcc5593977d1",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib widget"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "28420169-2274-4e0e-a73d-c1ab54e6da9c",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"17-04-2024 05:02:37 - DEBUG - tvbwidgets - Package is not fully installed\n",
"17-04-2024 05:02:37 - DEBUG - tvbwidgets - Version read from the internal package.json file\n",
"17-04-2024 05:02:37 - INFO - tvbwidgets - Version: 2.0.0\n",
"Using matplotlib as 2D backend.\n",
" INFO Cannot import syncrypto library.\n",
"17-04-2024 05:02:49 - INFO - tvbwidgets.core.pse.parameters - ImportError: Dask dependency is not included, so this functionality won't be available\n"
]
}
],
"source": [
"import numpy as np\n",
"from tvbwidgets.api import plot_timeseries\n",
"from tvbwidgets.readers import read_edf_file"
]
},
{
"cell_type": "markdown",
"id": "59594729-3afd-4537-87c2-7413dcb66606",
"metadata": {},
"source": [
"#### Reading the data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "94687d23-0222-4e9f-887e-e97169737304",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting EDF parameters from C:\\Work\\TVB\\tvb-widgets\\tvbwidgets\\tests\\data\\test_file.edf...\n",
"EDF file detected\n",
"Setting channel info structure...\n",
"Creating raw.info structure...\n"
]
}
],
"source": [
"edf_file_path = '../tvbwidgets/tests/data/test_file.edf' # replace path with your actual EDF file path\n",
"data, freq, index = read_edf_file(edf_file_path)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "12da93b3-f96f-44cf-8555-52eee8aaecfd",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(92000, 32)\n",
"400.0\n"
]
}
],
"source": [
"print(data.shape)\n",
"print(freq)"
]
},
{
"cell_type": "markdown",
"id": "94d8b0ec-654f-47a7-a0c8-ce3b28651e22",
"metadata": {},
"source": [
"#### TS Viewer"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "cac83fe3-b4fe-486a-8955-70203dda1004",
"metadata": {},
"outputs": [],
"source": [
"backend = 'plotly' # change to 'matplotlib' to see the other TS widget"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "20c3eae2-49dd-4d01-af9a-06a0dba15888",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"17-04-2024 05:02:49 - INFO - tvbwidgets.ui.ts.plotly_ts_widget - TimeSeries Widget with Plotly initialized\n",
"Creating RawArray with float64 data, n_channels=32, n_times=92000\n",
" Range : 0 ... 91999 = 0.000 ... 229.998 secs\n",
"Ready.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "28d8ffb85be94f28b59e85b521e4aeeb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"TimeSeriesWidgetPlotly(children=(HBox(children=(Output(),)), VBox(children=(Dropdown(description='Colormap:', …"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tsw = plot_timeseries(data=data, sample_freq=freq, ch_idx=index, backend=backend)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83da35ea-9c05-4f4b-90e8-93ee7002554c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
24 changes: 24 additions & 0 deletions tvbwidgets/readers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
#
# "TheVirtualBrain - Widgets" package
#
# (c) 2022-2024, TVB Widgets Team
#

import mne
import numpy as np


def read_edf_file(filepath):
# type: (str) -> (numpy.ndarray, float, int)
"""
Reads an EDF file located at filepath and returns the data array, the sample frequency and the channel index,
all necessary for the `api.plot_timeseries` function
"""
raw = mne.io.read_raw_edf(filepath)
data, _ = raw[:]
data = np.transpose(data)
ch_idx = len(data.shape) - 1
sample_freq = raw.info['sfreq']

return data, sample_freq, ch_idx
4,221 changes: 4,221 additions & 0 deletions tvbwidgets/tests/data/test_file.edf

Large diffs are not rendered by default.

31 changes: 31 additions & 0 deletions tvbwidgets/tests/test_readers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
#
# "TheVirtualBrain - Widgets" package
#
# (c) 2022-2024, TVB Widgets Team
#

import numpy as np
import os
import tvbwidgets.readers as readers

TEST_DATA = os.path.join(os.path.dirname(__file__), 'data')


def test_read_edf_file():
edf_test_file = 'test_file.edf'
edf_path = os.path.join(TEST_DATA, edf_test_file)

data, freq, idx = readers.read_edf_file(edf_path)
assert data is not None
assert freq is not None
assert idx is not None

assert isinstance(data, np.ndarray)
assert data.shape == (92000, 32)

assert isinstance(freq, float)
assert freq == 400

assert isinstance(idx, int)
assert idx == 1
2 changes: 2 additions & 0 deletions tvbwidgets/tests/ts/test_data_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,5 @@ def test_get_hover_channel_value_tvb(wrapper_tvb):
val = wrapper_tvb.data.data[x_int, sel1, ch_index, sel2]
val = round(val, 4)
assert ch_value == val


3 changes: 2 additions & 1 deletion tvbwidgets/ui/ts/base_ts_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tvbwidgets.ui.ts.data_wrappers.tvb_data_wrapper import WrapperTVB
from tvbwidgets.ui.ts.data_wrappers.numpy_data_wrapper import WrapperNumpy


class TimeSeriesWidgetBase(widgets.VBox, TVBWidget):
# =========================================== SETUP ================================================================
def add_datatype(self, ts_tvb):
Expand Down Expand Up @@ -112,4 +113,4 @@ def _dimensions_selection_update(self, _):
# update self.raw
sel1, sel2 = self._get_selection_values()
new_slice = self.data.get_update_slice(sel1, sel2)
self.raw = self.data.build_raw(new_slice)
self.raw = self.data.build_raw(new_slice)
36 changes: 19 additions & 17 deletions tvbwidgets/ui/ts/plotly_ts_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,22 @@ def __init__(self, **kwargs):
self.channel_selection_area = widgets.HBox(layout=widgets.Layout(width='90%'))
self.info_and_channels_area = widgets.HBox(layout=widgets.Layout(margin='0px 0px 0px 80px'))
self.plot_area.children += (self.output,)
self.scaling_title = widgets.Label(value='Increase/Decrease signal scaling (current scaling value to the right)')
self.scaling_title = widgets.Label(
value='Increase/Decrease signal scaling (current scaling value to the right)')
self.scaling_slider = widgets.IntSlider(value=1, layout=widgets.Layout(width='30%'))
self.colormaps = ['turbo', 'brg', 'gist_stern_r', 'nipy_spectral_r', 'coolwarm','plasma', 'magma', 'viridis', \
'cividis', 'twilight', 'twilight_shifted', 'CMRmap_r', 'Blues', \
'BuGn', 'BuPu', 'Greens', 'PuRd', 'RdPu', 'Spectral', 'YlGnBu', \
'YlOrBr', 'YlOrRd', 'cubehelix_r', 'gist_earth_r', 'terrain_r', \
'rainbow_r', 'pink_r', 'gist_ncar_r', 'uni-color(black)']
self.colormaps = ['turbo', 'brg', 'gist_stern_r', 'nipy_spectral_r', 'coolwarm', 'plasma', 'magma', 'viridis', \
'cividis', 'twilight', 'twilight_shifted', 'CMRmap_r', 'Blues', \
'BuGn', 'BuPu', 'Greens', 'PuRd', 'RdPu', 'Spectral', 'YlGnBu', \
'YlOrBr', 'YlOrRd', 'cubehelix_r', 'gist_earth_r', 'terrain_r', \
'rainbow_r', 'pink_r', 'gist_ncar_r', 'uni-color(black)']
self.colormap_dropdown = widgets.Dropdown(options=self.colormaps, description='Colormap:', disabled=False)
self.colormap_dropdown.observe(self.update_colormap, names='value')

super().__init__([self.plot_area, widgets.VBox([self.colormap_dropdown, self.scaling_title, self.scaling_slider],
layout=widgets.Layout(margin='0px 0px 0px 80px')),
self.info_and_channels_area],
layout=self.DEFAULT_BORDER)
super().__init__(
[self.plot_area, widgets.VBox([self.colormap_dropdown, self.scaling_title, self.scaling_slider],
layout=widgets.Layout(margin='0px 0px 0px 80px')),
self.info_and_channels_area],
layout=self.DEFAULT_BORDER)
self.logger.info("TimeSeries Widget with Plotly initialized")

# =========================================== SETUP ================================================================
Expand All @@ -76,14 +78,15 @@ def add_traces_to_plot(self, data, ch_names):
ch_names = ch_names[::-1]
if self.colormap == "uni-color(black)":
colormap = plt.get_cmap('gray')
colors = colormap(np.linspace(0, 0, len(ch_names)))
colors = colormap(np.linspace(0, 0, len(ch_names)))
else:
colormap = plt.get_cmap(self.colormap)
colors = colormap(np.linspace(0.3, 1, len(ch_names)))
colors = [mlt.to_hex(color, keep_alpha=False) for color in colors]

self.fig.add_traces(
[dict(y=ts * self.amplitude + i * self.std_step, name=ch_name, customdata=ts, hovertemplate='%{customdata}', line_color = colors[i])
[dict(y=ts * self.amplitude + i * self.std_step, name=ch_name, customdata=ts, hovertemplate='%{customdata}',
line_color=colors[i])
for i, (ch_name, ts) in enumerate(zip(ch_names, data))]
)

Expand Down Expand Up @@ -156,14 +159,14 @@ def plot_ts_with_plotly(self, data=None, ch_names=None):
self.output.clear_output(wait=True)
display(self.fig)

def update_colormap(self,change):
self.colormap = change['new']
def update_colormap(self, change):
self.colormap = change['new']
self.fig.data = []
data = self.raw[:, :][0]
data = data[self.ch_picked, :]
ch_names = [self.ch_names[i] for i in self.ch_picked]
self.add_traces_to_plot(data, ch_names)

# ================================================= SCALING ========================================================
def _setup_scaling_slider(self):
# set min and max scaling values
Expand All @@ -183,7 +186,6 @@ def update_scaling(self, val):
ch_names = [self.ch_names[i] for i in self.ch_picked]
self.add_traces_to_plot(data, ch_names)


# =========================================== CHANNELS SELECTION ===================================================
def _create_channel_selection_area(self, array_wrapper, no_checkbox_columns=5):
# type: (ABCDataWrapper) -> widgets.Accordion
Expand Down Expand Up @@ -222,7 +224,7 @@ def _create_channel_selection_area(self, array_wrapper, no_checkbox_columns=5):
def _update_ts(self, btn):
self.logger.debug('Updating TS')
ch_names = list(self.ch_names)

# save selected channels using their index in the ch_names list
self.ch_picked = []
for cb in list(self.checkboxes.values()):
Expand Down

0 comments on commit d18ef02

Please sign in to comment.