Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
c8e5c86
find_bad_channels_maxwell() returns segment info
hoechenberger May 28, 2020
d59abd2
Return a dict
hoechenberger Jun 8, 2020
7dcdda0
Docstrings
hoechenberger Jun 9, 2020
9cbae20
Fixes
hoechenberger Jun 9, 2020
bbab584
Docstring consistency
hoechenberger Jun 9, 2020
60d2b20
flake
hoechenberger Jun 9, 2020
298bcd3
Add ch_names and ch_types to output
hoechenberger Jun 9, 2020
42d8d8d
Add tests
hoechenberger Jun 9, 2020
527030e
Add comment
hoechenberger Jun 9, 2020
fe82ce8
Docstring
hoechenberger Jun 9, 2020
0a76dab
Tutorial
hoechenberger Jun 9, 2020
9d8d735
Cleanup
hoechenberger Jun 9, 2020
a1098fb
Typo & phrasing
hoechenberger Jun 9, 2020
dd074c6
Improve tutorial
hoechenberger Jun 9, 2020
694a8b2
Return all channels (not only good) & improve test
hoechenberger Jun 9, 2020
a9a1a4f
Naming
hoechenberger Jun 9, 2020
050d681
Simplify
hoechenberger Jun 9, 2020
f48c84a
Tutorial layout
hoechenberger Jun 9, 2020
9252465
Tutorial layout
hoechenberger Jun 9, 2020
4ca998d
Fix indexing bug
hoechenberger Jun 10, 2020
0cc3b32
Improve tutorial
hoechenberger Jun 10, 2020
df8be80
Return value dims, pre-fill with NaN's
hoechenberger Jun 10, 2020
518e111
Update tests
hoechenberger Jun 10, 2020
625f726
Update docstring
hoechenberger Jun 10, 2020
609d26d
Flake & cleanup
hoechenberger Jun 10, 2020
729b844
Return start, stop times of all chunks
hoechenberger Jun 10, 2020
51b2bf9
Docstring
hoechenberger Jun 10, 2020
1883d3b
Fix returned bin times
hoechenberger Jun 10, 2020
b6d00fe
Label bins with window edges in secs in tutorial
hoechenberger Jun 10, 2020
25902b6
Use np.full()
hoechenberger Jun 10, 2020
f72217b
Remove plt.show()
hoechenberger Jun 10, 2020
9998f7f
Add vertical dashes lines between segments
hoechenberger Jun 10, 2020
709c458
Only plot gradiometers
hoechenberger Jun 10, 2020
674f149
Forgot to adjust comment
hoechenberger Jun 10, 2020
0cbb729
Typo
hoechenberger Jun 10, 2020
5f36bd1
Use inf, -inf instead of nan
hoechenberger Jun 10, 2020
d061258
Cleanup
hoechenberger Jun 10, 2020
704ca7e
Remove leftover, elaborate on masking
hoechenberger Jun 10, 2020
eefd90e
Fix circle
hoechenberger Jun 10, 2020
cb386ce
Revert "Use inf, -inf instead of nan"
hoechenberger Jun 16, 2020
b601383
Update tutorial
hoechenberger Jun 16, 2020
8bebd06
Fix docstring
hoechenberger Jun 16, 2020
fea5a8d
Add warning
hoechenberger Jun 16, 2020
b1cefb8
Remove double linebreak
hoechenberger Jun 16, 2020
37775cb
Phrasing
hoechenberger Jun 16, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 98 additions & 10 deletions mne/preprocessing/maxwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,7 +1858,7 @@ def _trans_sss_basis(exp, all_coils, trans=None, coil_scale=100.):
# st_only
@verbose
def find_bad_channels_maxwell(
raw, limit=7., duration=5., min_count=5,
raw, limit=7., duration=5., min_count=5, return_scores=False,
origin='auto', int_order=8, ext_order=3, calibration=None,
cross_talk=None, coord_frame='head', regularize='in', ignore_ref=False,
bad_condition='error', head_pos=None, mag_scale=100.,
Expand All @@ -1877,10 +1877,19 @@ def find_bad_channels_maxwell(
Detection limit (default is 7.). Smaller values will find more bad
channels at increased risk of including good ones.
duration : float
Duration into which to window the data for processing. Default is 5.
Duration of the segments into which to slice the data for processing,
in seconds. Default is 5.
min_count : int
Minimum number of times a channel must show up as bad in a chunk.
Default is 5.
return_scores : bool
If ``True``, return a dictionary with scoring information for each
evaluated segment of the data. Default is ``False``.

.. warning:: This feature is experimental and may change in a future
version of MNE-Python without prior notice. Please
report any problems and enhancement proposals to the
developers.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need ..versionadded

%(maxwell_origin_int_ext_calibration_cross)s
%(maxwell_coord)s
%(maxwell_reg_ref_cond_pos)s
Expand All @@ -1896,6 +1905,33 @@ def find_bad_channels_maxwell(
flat_chs : list
List of MEG channels that were detected as being flat in at least
``min_count`` segments.
scores : dict
A dictionary with information produced by the scoring algorithms.
Only returned when ``return_scores`` is ``True``. It contains the
following keys:

- ``ch_names`` : ndarray, shape (n_meg,)
The names of the MEG channels. Their order corresponds to the
order of rows in the ``scores`` and ``limits`` arrays.
- ``ch_types`` : ndarray, shape (n_meg,)
The types of the MEG channels in ``ch_names`` (``'mag'``,
``'grad'``).
- ``bins`` : ndarray, shape (n_windows, 2)
The inclusive window boundaries (start and stop; in seconds) used
to calculate the scores.
- ``scores_flat`` : ndarray, shape (n_meg, n_windows)
The scores for testing whether MEG channels are flat.
- ``limits_flat`` : ndarray, shape (n_meg, 1)
The score thresholds above which a segment was claffified as
"flat".
- ``scores_noisy`` : ndarray, shape (n_meg, n_windows)
The scores for testing whether MEG channels are noisy.
- ``limits_noisy`` : ndarray, shape (n_meg, 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why ndarray, shape (n_meg, 1)? It should be float, no? I see below you only set it for the good channels, but personally I would just leave it as a float (it's okay to set it for all channels because the bad ones get values of -np.inf so comparisons pass).

However, I just realized that a different, probably cleaner way around all of this: return ch_names : ndarray, shape (n_good_meg,) instead of shape (n_meg,) (and adjust all other vars, too). Then you only return what actually gets processed by the function.

Sorry for the run-around about this, hopefully just using n_good_meg makes everything cleaner...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, I just realized that a different, probably cleaner way around all of this: return ch_names : ndarray, shape (n_good_meg,) instead of shape (n_meg,) (and adjust all other vars, too). Then you only return what actually gets processed by the function.

That's what I started out with, and then made the decision to return all meg channels, including the bad ones (even though they would not have any values set): Because I expect this to make it easier to create an interactive visualization based on the arrays returned here. I'm thinking about something like the viz I put into the tutorial, but also displaying channels from info['bads'], and where clicking on a tile of the heatmap would take you straight to the corresponding raw segment. It would be very useful to show all channels, including bads, in such an interactive visualization. And this would be easier to achieve if we got an array of the correct shape (i.e., with n_meg rows) right from the beginning…

Thoughts?

Copy link
Member Author

@hoechenberger hoechenberger Jun 11, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why ndarray, shape (n_meg, 1)? It should be float, no?

Also while this would work for the noisy channels, it wouldn't be applicable to the flat detection, because there we use different thresholds depending on channel type (mag, grad)

The score thresholds above which a segment was claffified as
"noisy".

.. note:: The scores and limits for channels marked as ``bad`` in the
input data will will be set to ``np.nan``.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will will -> will


See Also
--------
Expand All @@ -1904,10 +1940,11 @@ def find_bad_channels_maxwell(

Notes
-----
All arguments after ``raw``, ``limit``, ``duration``, and ``min_count``
are the same as :func:`~maxwell_filter`, except that the following are
not allowed in this function because they are unused: ``st_duration``,
``st_correlation``, ``destination``, ``st_fixed``, and ``st_only``.
All arguments after ``raw``, ``limit``, ``duration``, ``min_count``, and
``return_scores`` are the same as :func:`~maxwell_filter`, except that the
following are not allowed in this function because they are unused:
``st_duration``, ``st_correlation``, ``destination``, ``st_fixed``, and
``st_only``.

This algorithm, for a given chunk of data:

Expand Down Expand Up @@ -1973,20 +2010,48 @@ def find_bad_channels_maxwell(
if pick in params['grad_picks'] else
flat_limits['mag']
for pick in good_meg_picks])

flat_step = max(20, int(30 * raw.info['sfreq'] / 1000.))
all_flats = set()

# Prepare variables to return if `return_scores=True`.
bins = np.empty((len(starts), 2)) # To store start, stop of each segment
# We create ndarrays with one row per channel, regardless of channel type
# and whether the channel has been marked as "bad" in info or not. This
# makes indexing in the loop easier. We only filter this down to the subset
# of MEG channels after all processing is done.
ch_names = np.array(raw.ch_names)
ch_types = np.array(raw.get_channel_types())

scores_flat = np.full((len(ch_names), len(starts)), np.nan)
scores_noisy = np.full_like(scores_flat, fill_value=np.nan)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, should you just use -np.inf?


thresh_flat = np.full((len(ch_names), 1), np.nan)
thresh_noisy = np.full_like(thresh_flat, fill_value=np.nan)

for si, (start, stop) in enumerate(zip(starts, stops)):
n_iter = 0
orig_data = raw.get_data(None, start, stop, verbose=False)
chunk_raw = RawArray(
orig_data, params['info'],
first_samp=raw.first_samp + start, copy='data', verbose=False)

t = chunk_raw.times[[0, -1]] + start / raw.info['sfreq']
logger.info(' Interval %3d: %8.3f - %8.3f'
% ((si + 1,) + tuple(t[[0, -1]])))

# Flat pass: var < 0.01 fT/cm or 0.01 fT for at 30 ms (or 20 samples)
n = stop - start
flat_stop = n - (n % flat_step)
data = chunk_raw.get_data(good_meg_picks, 0, flat_stop)
data.shape = (data.shape[0], -1, flat_step)
delta = np.std(data, axis=-1).min(-1) # min std across segments

# We may want to return this later if `return_scores=True`.
bins[si, :] = t[0], t[-1]
scores_flat[good_meg_picks, si] = delta
thresh_flat[good_meg_picks] = these_limits.reshape(-1, 1)

chunk_flats = delta < these_limits
chunk_flats = np.where(chunk_flats)[0]
chunk_flats = [raw.ch_names[good_meg_picks[chunk_flat]]
Expand All @@ -2000,9 +2065,6 @@ def find_bad_channels_maxwell(
chunk_noisy = list()
params['st_duration'] = int(round(
chunk_raw.times[-1] * raw.info['sfreq']))
t = chunk_raw.times[[0, -1]] + start / raw.info['sfreq']
logger.info(' Interval %3d: %8.3f - %8.3f'
% ((si + 1,) + tuple(t[[0, -1]])))
for n_iter in range(1, 101): # iteratively exclude the worst ones
assert set(raw.info['bads']) & set(chunk_noisy) == set()
params['good_mask'][:] = [
Expand All @@ -2028,8 +2090,14 @@ def find_bad_channels_maxwell(
z = (range_ - mean) / std
idx = np.argmax(z)
max_ = z[idx]

# We may want to return this later if `return_scores=True`.
scores_noisy[these_picks, si] = z
thresh_noisy[these_picks] = limit

if max_ < limit:
break

name = raw.ch_names[these_picks[idx]]
logger.debug(' Bad: %s %0.1f'
% (name, max_))
Expand All @@ -2040,7 +2108,27 @@ def find_bad_channels_maxwell(
key=lambda x: raw.ch_names.index(x))
flat_chs = sorted((f for f, c in flat_chs.items() if c >= min_count),
key=lambda x: raw.ch_names.index(x))

# Only include MEG channels.
ch_names = ch_names[params['meg_picks']]
ch_types = ch_types[params['meg_picks']]
scores_flat = scores_flat[params['meg_picks']]
thresh_flat = thresh_flat[params['meg_picks']]
scores_noisy = scores_noisy[params['meg_picks']]
thresh_noisy = thresh_noisy[params['meg_picks']]

logger.info(' Static bad channels: %s' % (noisy_chs,))
logger.info(' Static flat channels: %s' % (flat_chs,))
logger.info('[done]')
return noisy_chs, flat_chs

if return_scores:
scores = dict(ch_names=ch_names,
ch_types=ch_types,
bins=bins,
scores_flat=scores_flat,
limits_flat=thresh_flat,
scores_noisy=scores_noisy,
limits_noisy=thresh_noisy)
return noisy_chs, flat_chs, scores
else:
return noisy_chs, flat_chs
94 changes: 81 additions & 13 deletions mne/preprocessing/tests/test_maxwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,18 +1061,25 @@ def test_mf_skips():


@testing.requires_testing_data
@pytest.mark.parametrize('fname, bads, annot, add_ch, ignore_ref, want_bads', [
# Neuromag data tested against MF
(sample_fname, [], False, False, False, ['MEG 2443']),
# add 0111 to test picking, add annot to test it, and prepend chs for idx
(sample_fname, ['MEG 0111'], True, True, False, ['MEG 2443']),
# CTF data seems to be sensitive to linalg lib (?) because some channels
# are very close to the limit, so we just check that one shows up
(ctf_fname_continuous, [], False, False, False, {'BR1-4304'}),
(ctf_fname_continuous, [], False, False, True, ['MLC24-4304']), # faked
])
@pytest.mark.parametrize(
('fname', 'bads', 'annot', 'add_ch', 'ignore_ref', 'want_bads',
'return_scores'), [
# Neuromag data tested against MF
(sample_fname, [], False, False, False, ['MEG 2443'], False),
# add 0111 to test picking, add annot to test it, and prepend chs for
# idx
(sample_fname, ['MEG 0111'], True, True, False, ['MEG 2443'], False),
# CTF data seems to be sensitive to linalg lib (?) because some
# channels are very close to the limit, so we just check that one shows
# up
(ctf_fname_continuous, [], False, False, False, {'BR1-4304'}, False),
# faked
(ctf_fname_continuous, [], False, False, True, ['MLC24-4304'], False),
# For `return_scores=True`
(sample_fname, ['MEG 0111'], True, True, False, ['MEG 2443'], True)
])
def test_find_bad_channels_maxwell(fname, bads, annot, add_ch, ignore_ref,
want_bads):
want_bads, return_scores):
"""Test automatic bad channel detection."""
if fname.endswith('.ds'):
raw = read_raw_ctf(fname).load_data()
Expand All @@ -1086,6 +1093,9 @@ def test_find_bad_channels_maxwell(fname, bads, annot, add_ch, ignore_ref,
raw._data[flat_idx] = 0 # MaxFilter didn't have this but doesn't affect it
want_flats = [raw.ch_names[flat_idx]]
raw.apply_gradient_compensation(0)

min_count = 5

if add_ch:
raw_eeg = read_raw_fif(fname)
raw_eeg.pick_types(meg=False, eeg=True, exclude=()).load_data()
Expand All @@ -1105,10 +1115,19 @@ def test_find_bad_channels_maxwell(fname, bads, annot, add_ch, ignore_ref,
assert step == 1502
raw.annotations.append(step * dt + raw._first_time, dt, 'BAD')
with catch_logging() as log:
got_bads, got_flats = find_bad_channels_maxwell(
return_vals = find_bad_channels_maxwell(
raw, origin=(0., 0., 0.04), regularize=None,
bad_condition='ignore', skip_by_annotation='BAD', verbose=True,
ignore_ref=ignore_ref)
ignore_ref=ignore_ref, min_count=min_count,
return_scores=return_scores)

if return_scores:
assert len(return_vals) == 3
got_bads, got_flats, got_scores = return_vals
else:
assert len(return_vals) == 2
got_bads, got_flats = return_vals

if isinstance(want_bads, list):
assert got_bads == want_bads # from MaxFilter
else:
Expand All @@ -1118,5 +1137,54 @@ def test_find_bad_channels_maxwell(fname, bads, annot, add_ch, ignore_ref,
assert 'Interval 1: 0.00' in log
assert 'Interval 2: 5.00' in log

if return_scores:
meg_chs = raw.copy().pick_types(meg=True, exclude=[]).ch_names
ch_types = raw.get_channel_types(meg_chs)

assert list(got_scores['ch_names']) == meg_chs
assert list(got_scores['ch_types']) == ch_types
# Check that time is monotonically increasing.
assert (np.diff(got_scores['bins'].flatten()) >= 0).all()

assert (got_scores['scores_flat'].shape ==
got_scores['scores_noisy'].shape ==
(len(meg_chs), len(got_scores['bins'])))

assert (got_scores['limits_flat'].shape ==
got_scores['limits_noisy'].shape ==
(len(meg_chs), 1))

# Check "flat" scores.
scores_flat = got_scores['scores_flat']
limits_flat = got_scores['limits_flat']
# The following essentially is just this:
# n_segments_below_limit = (scores_flat < limits_flat).sum(-1)
# made to work with NaN's in the scores.
n_segments_below_limit = np.less(
scores_flat, limits_flat,
where=np.equal(np.isnan(scores_flat), False),
out=np.full_like(scores_flat, fill_value=False)).sum(-1)

ch_idx = np.where(n_segments_below_limit >=
min(min_count, len(got_scores['bins'])))
flats = set(got_scores['ch_names'][ch_idx])
assert flats == set(want_flats)

# Check "noisy" scores.
scores_noisy = got_scores['scores_noisy']
limits_noisy = got_scores['limits_noisy']
# The following essentially is just this:
# n_segments_beyond_limit = (scores_noisy > limits_noisy).sum(-1)
# made to work with NaN's in the scores.
n_segments_beyond_limit = np.greater(
scores_noisy, limits_noisy,
where=np.equal(np.isnan(scores_noisy), False),
out=np.full_like(scores_noisy, fill_value=False)).sum(-1)

ch_idx = np.where(n_segments_beyond_limit >=
min(min_count, len(got_scores['bins'])))
bads = set(got_scores['ch_names'][ch_idx])
assert bads == set(want_bads)


run_tests_if_main()
Loading