Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13856.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed an indexing bug in ``_interpolate_bads_nirs`` that could use the wrong donor channels when :func:`mne.io.Raw.interpolate_bads` interpolated fNIRS channels reordered by ``_validate_nirs_info``, by :newcontrib:`Kalle Makela`.
1 change: 1 addition & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@
.. _Jukka Nenonen: https://www.linkedin.com/pub/jukka-nenonen/28/b5a/684
.. _Jussi Nurminen: https://github.com/jjnurminen
.. _Kaisu Lankinen: http://bishoplab.berkeley.edu/Kaisu.html
.. _Kalle Makela: https://github.com/Kallemakela
.. _Katarina Slama: https://github.com/katarinaslama
.. _Katia Al-Amir: https://github.com/katia-sentry
.. _Kay Robbins: https://github.com/VisLab
Expand Down
12 changes: 8 additions & 4 deletions mne/channels/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,15 +278,19 @@ def _interpolate_bads_nirs(inst, exclude=(), verbose=None):
dist = pdist(locs3d)
dist = squareform(dist)

for bad in picks_bad:
dists_to_bad = dist[bad]
for bad_raw_idx in picks_bad:
# `bad_raw_idx` is the index of the bad channel in `inst`
# `bad_dist_idx` is the index of the bad channel in `dist`
bad_dist_idx = np.where(picks_nirs == bad_raw_idx)[0][0]
dists_to_bad = dist[bad_dist_idx].copy()
# Ignore distances to self
dists_to_bad[dists_to_bad == 0] = np.inf
# Ignore distances to other bad channels
dists_to_bad[bads_mask] = np.inf
# Find closest remaining channels for same frequency
closest_idx = np.argmin(dists_to_bad) + (bad % 2)
inst._data[bad] = inst._data[closest_idx]
closest_dist_idx = np.argmin(dists_to_bad) + (bad_dist_idx % 2)
closest_raw_idx = picks_nirs[closest_dist_idx]
inst._data[bad_raw_idx] = inst._data[closest_raw_idx]

# TODO: this seems like a bug because it does not respect reset_bads
inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude]
Expand Down
40 changes: 39 additions & 1 deletion mne/channels/tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from numpy.testing import assert_allclose, assert_array_equal

import mne.channels.channels
from mne import Epochs, pick_channels, pick_types, read_events
from mne import Epochs, create_info, pick_channels, pick_types, read_events
from mne._fiff.constants import FIFF
from mne._fiff.proj import _has_eeg_average_ref_proj
from mne.channels import make_dig_montage, make_standard_montage
Expand Down Expand Up @@ -333,6 +333,44 @@ def test_interpolation_nirs():
assert raw_haemo.info["bads"] == []


def test_interpolation_nirs_reordered_picks():
"""Test NIRS interpolation uses the closest donor in raw channel space."""
ch_names = [
"S1_D1 760",
"S1_D1 850",
"S2_D2 760",
"S2_D2 850",
"S3_D3 760",
"S3_D3 850",
"S10_D10 760",
"S10_D10 850",
]
info = create_info(ch_names, sfreq=1.0, ch_types=["fnirs_cw_amplitude"] * 8)
pair_positions = {
"S1_D1": (0.009, 0.0, 0.0),
"S2_D2": (0.010, 0.0, 0.0),
"S3_D3": (0.030, 0.0, 0.0),
"S10_D10": (0.040, 0.0, 0.0),
}
for idx, ch in enumerate(info["chs"]):
pair = ch["ch_name"].rsplit(" ", 1)[0]
ch["loc"][:3] = pair_positions[pair]
ch["loc"][9] = 760.0 if idx % 2 == 0 else 850.0
data = np.arange(len(ch_names), dtype=float).reshape(-1, 1)
data = np.repeat(data, 5, axis=1)
raw = RawArray(data, info, verbose=False)
raw.info["bads"] = ["S2_D2 760", "S2_D2 850"]

raw.interpolate_bads(
method=dict(fnirs="nearest"), origin=(0.0, 0.0, 0.0), verbose=False
)

# Bad S2_D2 should copy from the nearest good pair, S1_D1.
picks_bad = pick_channels(raw.ch_names, ["S2_D2 760", "S2_D2 850"], exclude=[])
picks_want = pick_channels(raw.ch_names, ["S1_D1 760", "S1_D1 850"], exclude=[])
assert_allclose(raw.get_data(picks=picks_bad), raw.get_data(picks=picks_want))


@testing.requires_testing_data
def test_interpolation_ecog():
"""Test interpolation for ECoG."""
Expand Down