diff --git a/mne/io/edf/edf.py b/mne/io/edf/edf.py index f71b782cb8c..0b2c5f79589 100644 --- a/mne/io/edf/edf.py +++ b/mne/io/edf/edf.py @@ -147,7 +147,7 @@ def __init__(self, input_fname, eog=None, misc=None, stim_channel='auto', units = dict() for k, (this_ch, this_unit) in enumerate(orig_units.items()): - if this_unit != "" and this_unit in units: + if this_unit != "" and this_ch in units: raise ValueError(f'Unit for channel {this_ch} is present in ' 'the file. Cannot overwrite it with the ' 'units argument.') diff --git a/mne/io/edf/tests/test_edf.py b/mne/io/edf/tests/test_edf.py index 7426e76b66f..975059c328f 100644 --- a/mne/io/edf/tests/test_edf.py +++ b/mne/io/edf/tests/test_edf.py @@ -77,6 +77,14 @@ def test_orig_units(): assert orig_units['A1'] == 'µV' # formerly 'uV' edit by _check_orig_units +def test_units_params(): + """Test enforcing original channel units.""" + with pytest.raises(ValueError, + match=r"Unit for channel .* is present .* Cannot " + "overwrite it"): + _ = read_raw_edf(edf_path, units='V', preload=True) + + def test_subject_info(tmp_path): """Test exposure of original channel units.""" raw = read_raw_edf(edf_path)