diff --git a/doc/changes/dev/13856.bugfix.rst b/doc/changes/dev/13856.bugfix.rst new file mode 100644 index 00000000000..7ba1b6ba901 --- /dev/null +++ b/doc/changes/dev/13856.bugfix.rst @@ -0,0 +1 @@ +Fixed an indexing bug in fNIRS support in :meth:`mne.io.BaseRaw.interpolate_bads` (and related methods) that could errantly use incorrect donor channels, by :newcontrib:`Kalle Makela`. diff --git a/doc/changes/names.inc b/doc/changes/names.inc index d3dbb9c8995..ff8921c3e76 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -172,6 +172,7 @@ .. _Jukka Nenonen: https://www.linkedin.com/pub/jukka-nenonen/28/b5a/684 .. _Jussi Nurminen: https://github.com/jjnurminen .. _Kaisu Lankinen: http://bishoplab.berkeley.edu/Kaisu.html +.. _Kalle Makela: https://github.com/Kallemakela .. _Katarina Slama: https://github.com/katarinaslama .. _Katia Al-Amir: https://github.com/katia-sentry .. _Kay Robbins: https://github.com/VisLab diff --git a/mne/channels/interpolation.py b/mne/channels/interpolation.py index 0f80e004fd6..6c38406c1cd 100644 --- a/mne/channels/interpolation.py +++ b/mne/channels/interpolation.py @@ -278,15 +278,19 @@ def _interpolate_bads_nirs(inst, exclude=(), verbose=None): dist = pdist(locs3d) dist = squareform(dist) - for bad in picks_bad: - dists_to_bad = dist[bad] + for bad_raw_idx in picks_bad: + # `bad_raw_idx` is the index of the bad channel in `inst` + # `bad_dist_idx` is the index of the bad channel in `dist` + bad_dist_idx = np.where(picks_nirs == bad_raw_idx)[0][0] + dists_to_bad = dist[bad_dist_idx].copy() # Ignore distances to self dists_to_bad[dists_to_bad == 0] = np.inf # Ignore distances to other bad channels dists_to_bad[bads_mask] = np.inf # Find closest remaining channels for same frequency - closest_idx = np.argmin(dists_to_bad) + (bad % 2) - inst._data[bad] = inst._data[closest_idx] + closest_dist_idx = np.argmin(dists_to_bad) + (bad_dist_idx % 2) + closest_raw_idx = picks_nirs[closest_dist_idx] + inst._data[bad_raw_idx] = inst._data[closest_raw_idx] # TODO: this seems like a bug because it does not respect reset_bads inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude] diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index c52a34ed887..4e9ba68efa3 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -10,7 +10,7 @@ from numpy.testing import assert_allclose, assert_array_equal import mne.channels.channels -from mne import Epochs, pick_channels, pick_types, read_events +from mne import Epochs, create_info, pick_channels, pick_types, read_events from mne._fiff.constants import FIFF from mne._fiff.proj import _has_eeg_average_ref_proj from mne.channels import make_dig_montage, make_standard_montage @@ -333,6 +333,44 @@ def test_interpolation_nirs(): assert raw_haemo.info["bads"] == [] +def test_interpolation_nirs_reordered_picks(): + """Test NIRS interpolation uses the closest donor in raw channel space.""" + ch_names = [ + "S1_D1 760", + "S1_D1 850", + "S2_D2 760", + "S2_D2 850", + "S3_D3 760", + "S3_D3 850", + "S10_D10 760", + "S10_D10 850", + ] + info = create_info(ch_names, sfreq=1.0, ch_types=["fnirs_cw_amplitude"] * 8) + pair_positions = { + "S1_D1": (0.009, 0.0, 0.0), + "S2_D2": (0.010, 0.0, 0.0), + "S3_D3": (0.030, 0.0, 0.0), + "S10_D10": (0.040, 0.0, 0.0), + } + for idx, ch in enumerate(info["chs"]): + pair = ch["ch_name"].rsplit(" ", 1)[0] + ch["loc"][:3] = pair_positions[pair] + ch["loc"][9] = 760.0 if idx % 2 == 0 else 850.0 + data = np.arange(len(ch_names), dtype=float).reshape(-1, 1) + data = np.repeat(data, 5, axis=1) + raw = RawArray(data, info, verbose=False) + raw.info["bads"] = ["S2_D2 760", "S2_D2 850"] + + raw.interpolate_bads( + method=dict(fnirs="nearest"), origin=(0.0, 0.0, 0.0), verbose=False + ) + + # Bad S2_D2 should copy from the nearest good pair, S1_D1. + picks_bad = pick_channels(raw.ch_names, ["S2_D2 760", "S2_D2 850"], exclude=[]) + picks_want = pick_channels(raw.ch_names, ["S1_D1 760", "S1_D1 850"], exclude=[]) + assert_allclose(raw.get_data(picks=picks_bad), raw.get_data(picks=picks_want)) + + @testing.requires_testing_data def test_interpolation_ecog(): """Test interpolation for ECoG."""