Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Allow pass-through of correct units #11143

Merged
merged 1 commit into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions mne/io/edf/edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import numpy as np

from ...utils import verbose, logger, warn
from ...utils import verbose, logger, warn, _validate_type
from ..utils import _blk_read_lims, _mult_cal_one
from ..base import BaseRaw, _get_scaling
from ..meas_info import _empty_info, _unique_channel_names
Expand Down Expand Up @@ -141,17 +141,21 @@ def __init__(self, input_fname, eog=None, misc=None, stim_channel='auto',
preload, include)
logger.info('Creating raw.info structure...')

if units is not None and isinstance(units, str):
units = {ch_name: units for ch_name in info['ch_names']}
elif units is None:
_validate_type(units, (str, None, dict), 'units')
if units is None:
units = dict()
elif isinstance(units, str):
units = {ch_name: units for ch_name in info['ch_names']}

for k, (this_ch, this_unit) in enumerate(orig_units.items()):
if this_unit != "" and this_ch in units:
raise ValueError(f'Unit for channel {this_ch} is present in '
'the file. Cannot overwrite it with the '
'units argument.')
if this_unit == "" and this_ch in units:
if this_ch not in units:
continue
if this_unit not in ("", units[this_ch]):
raise ValueError(
f'Unit for channel {this_ch} is present in the file as '
f'{repr(this_unit)}, cannot overwrite it with the units '
f'argument {repr(units[this_ch])}.')
if this_unit == "":
orig_units[this_ch] = units[this_ch]
ch_type = edf_info["ch_types"][k]
scaling = _get_scaling(ch_type.lower(), orig_units[this_ch])
Expand Down
9 changes: 8 additions & 1 deletion mne/io/edf/tests/test_edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_orig_units():
def test_units_params():
"""Test enforcing original channel units."""
with pytest.raises(ValueError,
match=r"Unit for channel .* is present .* Cannot "
match=r"Unit for channel .* is present .* cannot "
"overwrite it"):
_ = read_raw_edf(edf_path, units='V', preload=True)

Expand Down Expand Up @@ -601,6 +601,7 @@ def test_ch_types():
assert raw.ch_names == labels

raw = read_raw_edf(edf_chtypes_path, infer_types=True)
data = raw.get_data()

labels = ['Fp1-Ref', 'Fp2-Ref', 'F3-Ref', 'F4-Ref', 'C3-Ref', 'C4-Ref',
'P3-Ref', 'P4-Ref', 'O1-Ref', 'O2-Ref', 'F7-Ref', 'F8-Ref',
Expand All @@ -617,3 +618,9 @@ def test_ch_types():

assert raw.get_channel_types() == types
assert raw.ch_names == labels

with pytest.raises(ValueError, match="cannot overwrite"):
read_raw_edf(edf_chtypes_path, units='V')
raw = read_raw_edf(edf_chtypes_path, units='uV') # should be okay
data_units = raw.get_data()
assert_allclose(data, data_units)