Skip to content

Commit 1ac8c61

Browse files
committed
address comments
1 parent 9fa97b6 commit 1ac8c61

File tree

7 files changed

+63
-50
lines changed

7 files changed

+63
-50
lines changed

doc/whats_new.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ API
142142

143143
- Added :class:`mne.decoding.UnsupervisedSpatialFilter` providing interface for scikit-learn decomposition algorithms to be used with MNE data, by `Jean-Remi King`_ and `Asish Panda`_
144144

145-
- Deprecated :func:`mne.time_frequency.cwt_morlet` and :func:`mne.time_frequency.single_trial_power` in favour of :func:`mne.time_frequency.tfr_morlet` with parameter average=False, by `Jean-Remi King`_ and 'Alex Gramfort'_
145+
- Deprecated :func:`mne.time_frequency.cwt_morlet` and :func:`mne.time_frequency.single_trial_power` in favour of :func:`mne.time_frequency.tfr_morlet` with parameter average=False, by `Jean-Remi King`_ and `Alex Gramfort`_
146146

147147

148148
.. _changes_0_12:

mne/epochs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
pick_channels, pick_info, _pick_data_channels,
3131
_pick_aux_channels, _DATA_CH_TYPES_SPLIT)
3232
from .io.proj import setup_proj, ProjMixin, _proj_equal
33-
from .io.base import _BaseRaw, ToDataFrameMixin, TimeMixin, SizeMixin
33+
from .io.base import _BaseRaw, ToDataFrameMixin, TimeMixin
3434
from .bem import _check_origin
3535
from .evoked import EvokedArray, _check_decim
3636
from .baseline import rescale, _log_rescale
@@ -43,7 +43,7 @@
4343
plot_epochs_image, plot_topo_image_epochs)
4444
from .utils import (check_fname, logger, verbose, _check_type_picks,
4545
_time_mask, check_random_state, warn, _check_copy_dep,
46-
sizeof_fmt)
46+
sizeof_fmt, SizeMixin)
4747
from .externals.six import iteritems, string_types
4848
from .externals.six.moves import zip
4949

mne/evoked.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .filter import resample, detrend, FilterMixin
1717
from .fixes import in1d
1818
from .utils import check_fname, logger, verbose, _time_mask, warn, sizeof_fmt
19+
from .utils import SizeMixin
1920
from .viz import (plot_evoked, plot_evoked_topomap, plot_evoked_field,
2021
plot_evoked_image, plot_evoked_topo)
2122
from .viz.evoked import (_plot_evoked_white, plot_evoked_joint,
@@ -33,7 +34,7 @@
3334
from .io.write import (start_file, start_block, end_file, end_block,
3435
write_int, write_string, write_float_matrix,
3536
write_id)
36-
from .io.base import ToDataFrameMixin, TimeMixin, SizeMixin
37+
from .io.base import ToDataFrameMixin, TimeMixin
3738

3839
_aspect_dict = {'average': FIFF.FIFFV_ASPECT_AVERAGE,
3940
'standard_error': FIFF.FIFFV_ASPECT_STD_ERR}

mne/io/base.py

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -34,45 +34,15 @@
3434
from ..parallel import parallel_func
3535
from ..utils import (_check_fname, _check_pandas_installed, sizeof_fmt,
3636
_check_pandas_index_arguments, _check_copy_dep,
37-
check_fname, _get_stim_channel, object_hash,
38-
logger, verbose, _time_mask, warn, object_size)
37+
check_fname, _get_stim_channel,
38+
logger, verbose, _time_mask, warn, SizeMixin)
3939
from ..viz import plot_raw, plot_raw_psd, plot_raw_psd_topo
4040
from ..defaults import _handle_default
4141
from ..externals.six import string_types
4242
from ..event import find_events, concatenate_events
4343
from ..annotations import _combine_annotations, _onset_to_seconds
4444

4545

46-
class SizeMixin(object):
47-
"""Class to estimate MNE object sizes"""
48-
@property
49-
def _size(self):
50-
"""Estimate of the object size"""
51-
try:
52-
size = object_size(self.info)
53-
except Exception:
54-
warn('Could not get size for self.info')
55-
return -1
56-
if hasattr(self, 'data'):
57-
size += object_size(self.data)
58-
elif hasattr(self, '_data'):
59-
size += object_size(self._data)
60-
return size
61-
62-
def __hash__(self):
63-
from ..evoked import Evoked
64-
from ..epochs import _BaseEpochs
65-
if isinstance(self, Evoked):
66-
return object_hash(dict(info=self.info, data=self.data))
67-
elif isinstance(self, (_BaseEpochs, _BaseRaw)):
68-
if not self.preload:
69-
raise RuntimeError('Cannot hash %s unless data are loaded'
70-
% self.__class__.__name__)
71-
return object_hash(dict(info=self.info, data=self._data))
72-
else:
73-
raise RuntimeError('Hashing unknown object type: %s' % type(self))
74-
75-
7646
class ToDataFrameMixin(object):
7747
"""Class to add to_data_frame capabilities to certain classes."""
7848
def _get_check_picks(self, picks, picks_check):

mne/time_frequency/tests/test_tfr.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ def test_morlet():
3333

3434

3535
def test_time_frequency():
36-
"""Test the to-be-deprecated time frequency transform (PSD and ITC).
37-
"""
36+
"""Test the to-be-deprecated time frequency transform (PSD and ITC)"""
3837
# Set parameters
3938
event_id = 1
4039
tmin = -0.2
@@ -278,6 +277,8 @@ def test_tfr_multitaper():
278277
n_cycles=freqs / 2., time_bandwidth=4.0,
279278
return_itc=False, average=False).average()
280279

280+
print(power_evoked) # test repr for EpochsTFR
281+
281282
assert_raises(ValueError, tfr_multitaper, epochs,
282283
freqs=freqs, n_cycles=freqs / 2.,
283284
return_itc=True, average=False)
@@ -451,6 +452,7 @@ def test_add_channels():
451452

452453

453454
def test_compute_tfr():
455+
"""Test _compute_tfr function"""
454456
# Set parameters
455457
event_id = 1
456458
tmin = -0.2

mne/time_frequency/tfr.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919
from ..fixes import partial
2020
from ..baseline import rescale
2121
from ..parallel import parallel_func
22-
from ..utils import logger, verbose, _time_mask, check_fname, deprecated
22+
from ..utils import (logger, verbose, _time_mask, check_fname, deprecated,
23+
sizeof_fmt)
2324
from ..channels.channels import ContainsMixin, UpdateChannelsMixin
2425
from ..channels.layout import _pair_grad_sensors
2526
from ..io.pick import pick_info, pick_types
2627
from ..io.meas_info import Info
28+
from ..utils import SizeMixin
2729
from .multitaper import dpss_windows
2830
from ..viz.utils import figure_nobar, plt_show
2931
from ..externals.h5io import write_hdf5, read_hdf5
@@ -249,7 +251,7 @@ def _cwt(X, Ws, mode="same", decim=1, use_fft=True):
249251

250252

251253
def _compute_tfr(epoch_data, frequencies, sfreq=1.0, method='morlet',
252-
n_cycles=7.0, zero_mean=None, time_bandwidth=4.0,
254+
n_cycles=7.0, zero_mean=None, time_bandwidth=None,
253255
use_fft=True, decim=1, output='complex', n_jobs=1,
254256
verbose=None):
255257
"""Computes time-frequency transforms.
@@ -271,7 +273,8 @@ def _compute_tfr(epoch_data, frequencies, sfreq=1.0, method='morlet',
271273
zero_mean : bool | None, defaults to None
272274
None means True for method='multitaper' and False for method='morlet'.
273275
If True, make sure the wavelets have a mean of zero.
274-
time_bandwidth : float, defaults to 4.0 (3 tapers)
276+
time_bandwidth : float, defaults to None
277+
If method=multitaper, will be set to 4.0 (3 tapers).
275278
Time x (Full) Bandwidth product. Only applies if
276279
method == 'multitaper'. The number of good tapers (low-bias) is
277280
chosen automatically based on this to equal floor(time_bandwidth - 1).
@@ -296,8 +299,8 @@ def _compute_tfr(epoch_data, frequencies, sfreq=1.0, method='morlet',
296299
coherence across trials.
297300
298301
n_jobs : int, defaults to 1
299-
The number of epochs to process at the same time. The parallization is
300-
implemented across channels.
302+
The number of epochs to process at the same time. The parallelization
303+
is implemented across channels.
301304
verbose : bool, str, int, or None, defaults to None
302305
If not None, override default verbose level (see mne.verbose).
303306
@@ -341,7 +344,8 @@ def _compute_tfr(epoch_data, frequencies, sfreq=1.0, method='morlet',
341344
# Default zero_mean = True if multitaper else False
342345
zero_mean = method == 'multitaper' if zero_mean is None else zero_mean
343346
if not isinstance(zero_mean, bool):
344-
raise ValueError('')
347+
raise ValueError('zero_mean should be of type bool. Got %s.'
348+
% type(zero_mean))
345349
frequencies = np.asarray(frequencies)
346350

347351
# XXX Can we compute single-trial phases with multitaper?
@@ -354,10 +358,12 @@ def _compute_tfr(epoch_data, frequencies, sfreq=1.0, method='morlet',
354358
if method == 'morlet':
355359
W = morlet(sfreq, frequencies, n_cycles=n_cycles, zero_mean=zero_mean)
356360
Ws = [W] # to have same dimensionality as the 'multitaper' case
357-
if time_bandwidth != 4.0:
361+
if time_bandwidth is not None:
358362
raise ValueError('time_bandwidth only applies to "multitaper"'
359363
' method.')
360364
elif method == 'multitaper':
365+
if time_bandwidth is None:
366+
time_bandwidth = 4.0
361367
Ws = _make_dpss(sfreq, frequencies, n_cycles=n_cycles,
362368
time_bandwidth=time_bandwidth, zero_mean=zero_mean)
363369

@@ -395,6 +401,7 @@ def _compute_tfr(epoch_data, frequencies, sfreq=1.0, method='morlet',
395401
out[channel_idx] = tfr
396402

397403
if ('avg_' not in output) and ('itc' not in output):
404+
# This is to enforce that the first dimension is for epochs
398405
out = out.transpose(1, 0, 2, 3)
399406
return out
400407

@@ -759,11 +766,11 @@ def tfr_morlet(inst, freqs, n_cycles, use_fft=False, return_itc=True, decim=1,
759766
zero_mean : bool, defaults to True
760767
Make sure the wavelet has a mean of zero.
761768
762-
.. versionadded:: 0.12.0
769+
.. versionadded:: 0.13.0
763770
average : bool, defaults to True
764771
If True average accross Epochs.
765772
766-
.. versionadded:: 0.12.0
773+
.. versionadded:: 0.13.0
767774
verbose : bool, str, int, or None, defaults to None
768775
If not None, override default verbose level (see mne.verbose).
769776
@@ -825,7 +832,7 @@ def tfr_multitaper(inst, freqs, n_cycles, time_bandwidth=4.0,
825832
average : bool, defaults to True
826833
If True average accross Epochs.
827834
828-
.. versionadded:: 0.12.0
835+
.. versionadded:: 0.13.0
829836
verbose : bool, str, int, or None, defaults to None
830837
If not None, override default verbose level (see mne.verbose).
831838
@@ -853,7 +860,7 @@ def tfr_multitaper(inst, freqs, n_cycles, time_bandwidth=4.0,
853860

854861
# TFR(s) class
855862

856-
class _BaseTFR(ContainsMixin, UpdateChannelsMixin):
863+
class _BaseTFR(ContainsMixin, UpdateChannelsMixin, SizeMixin):
857864
@property
858865
def ch_names(self):
859866
return self.info['ch_names']
@@ -1419,6 +1426,7 @@ def __repr__(self):
14191426
s += ", freq : [%f, %f]" % (self.freqs[0], self.freqs[-1])
14201427
s += ", nave : %d" % self.nave
14211428
s += ', channels : %d' % self.data.shape[0]
1429+
s += ', ~%s' % (sizeof_fmt(self._size),)
14221430
return "<AverageTFR | %s>" % s
14231431

14241432
def save(self, fname, overwrite=False):
@@ -1464,7 +1472,7 @@ class EpochsTFR(_BaseTFR):
14641472
14651473
Notes
14661474
-----
1467-
.. versionadded:: 0.12.0
1475+
.. versionadded:: 0.13.0
14681476
"""
14691477
@verbose
14701478
def __init__(self, info, data, times, freqs, comment=None,
@@ -1493,6 +1501,7 @@ def __repr__(self):
14931501
s += ", freq : [%f, %f]" % (self.freqs[0], self.freqs[-1])
14941502
s += ", epochs : %d" % self.data.shape[0]
14951503
s += ', channels : %d' % self.data.shape[1]
1504+
s += ', ~%s' % (sizeof_fmt(self._size),)
14961505
return "<EpochsTFR | %s>" % s
14971506

14981507
def average(self):

mne/utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,6 +1727,37 @@ def sizeof_fmt(num):
17271727
return '1 byte'
17281728

17291729

1730+
class SizeMixin(object):
1731+
"""Class to estimate MNE object sizes"""
1732+
@property
1733+
def _size(self):
1734+
"""Estimate of the object size"""
1735+
try:
1736+
size = object_size(self.info)
1737+
except Exception:
1738+
warn('Could not get size for self.info')
1739+
return -1
1740+
if hasattr(self, 'data'):
1741+
size += object_size(self.data)
1742+
elif hasattr(self, '_data'):
1743+
size += object_size(self._data)
1744+
return size
1745+
1746+
def __hash__(self):
1747+
from ..evoked import Evoked
1748+
from ..epochs import _BaseEpochs
1749+
from ..io.base import _BaseRaw
1750+
if isinstance(self, Evoked):
1751+
return object_hash(dict(info=self.info, data=self.data))
1752+
elif isinstance(self, (_BaseEpochs, _BaseRaw)):
1753+
if not self.preload:
1754+
raise RuntimeError('Cannot hash %s unless data are loaded'
1755+
% self.__class__.__name__)
1756+
return object_hash(dict(info=self.info, data=self._data))
1757+
else:
1758+
raise RuntimeError('Hashing unknown object type: %s' % type(self))
1759+
1760+
17301761
def _url_to_local_path(url, path):
17311762
"""Mirror a url path in a local destination (keeping folder structure)"""
17321763
destination = urllib.parse.urlparse(url).path

0 commit comments

Comments
 (0)