1919from ..fixes import partial
2020from ..baseline import rescale
2121from ..parallel import parallel_func
22- from ..utils import logger , verbose , _time_mask , check_fname , deprecated
22+ from ..utils import (logger , verbose , _time_mask , check_fname , deprecated ,
23+ sizeof_fmt )
2324from ..channels .channels import ContainsMixin , UpdateChannelsMixin
2425from ..channels .layout import _pair_grad_sensors
2526from ..io .pick import pick_info , pick_types
2627from ..io .meas_info import Info
28+ from ..utils import SizeMixin
2729from .multitaper import dpss_windows
2830from ..viz .utils import figure_nobar , plt_show
2931from ..externals .h5io import write_hdf5 , read_hdf5
@@ -249,7 +251,7 @@ def _cwt(X, Ws, mode="same", decim=1, use_fft=True):
249251
250252
251253def _compute_tfr (epoch_data , frequencies , sfreq = 1.0 , method = 'morlet' ,
252- n_cycles = 7.0 , zero_mean = None , time_bandwidth = 4.0 ,
254+ n_cycles = 7.0 , zero_mean = None , time_bandwidth = None ,
253255 use_fft = True , decim = 1 , output = 'complex' , n_jobs = 1 ,
254256 verbose = None ):
255257 """Computes time-frequency transforms.
@@ -271,7 +273,8 @@ def _compute_tfr(epoch_data, frequencies, sfreq=1.0, method='morlet',
271273 zero_mean : bool | None, defaults to None
272274 None means True for method='multitaper' and False for method='morlet'.
273275 If True, make sure the wavelets have a mean of zero.
274- time_bandwidth : float, defaults to 4.0 (3 tapers)
276+ time_bandwidth : float, defaults to None
277+ If method=multitaper, will be set to 4.0 (3 tapers).
275278 Time x (Full) Bandwidth product. Only applies if
276279 method == 'multitaper'. The number of good tapers (low-bias) is
277280 chosen automatically based on this to equal floor(time_bandwidth - 1).
@@ -296,8 +299,8 @@ def _compute_tfr(epoch_data, frequencies, sfreq=1.0, method='morlet',
296299 coherence across trials.
297300
298301 n_jobs : int, defaults to 1
299- The number of epochs to process at the same time. The parallization is
300- implemented across channels.
302+ The number of epochs to process at the same time. The parallelization
303+ is implemented across channels.
301304 verbose : bool, str, int, or None, defaults to None
302305 If not None, override default verbose level (see mne.verbose).
303306
@@ -341,7 +344,8 @@ def _compute_tfr(epoch_data, frequencies, sfreq=1.0, method='morlet',
341344 # Default zero_mean = True if multitaper else False
342345 zero_mean = method == 'multitaper' if zero_mean is None else zero_mean
343346 if not isinstance (zero_mean , bool ):
344- raise ValueError ('' )
347+ raise ValueError ('zero_mean should be of type bool. Got %s.'
348+ % type (zero_mean ))
345349 frequencies = np .asarray (frequencies )
346350
347351 # XXX Can we compute single-trial phases with multitaper?
@@ -354,10 +358,12 @@ def _compute_tfr(epoch_data, frequencies, sfreq=1.0, method='morlet',
354358 if method == 'morlet' :
355359 W = morlet (sfreq , frequencies , n_cycles = n_cycles , zero_mean = zero_mean )
356360 Ws = [W ] # to have same dimensionality as the 'multitaper' case
357- if time_bandwidth != 4.0 :
361+ if time_bandwidth is not None :
358362 raise ValueError ('time_bandwidth only applies to "multitaper"'
359363 ' method.' )
360364 elif method == 'multitaper' :
365+ if time_bandwidth is None :
366+ time_bandwidth = 4.0
361367 Ws = _make_dpss (sfreq , frequencies , n_cycles = n_cycles ,
362368 time_bandwidth = time_bandwidth , zero_mean = zero_mean )
363369
@@ -395,6 +401,7 @@ def _compute_tfr(epoch_data, frequencies, sfreq=1.0, method='morlet',
395401 out [channel_idx ] = tfr
396402
397403 if ('avg_' not in output ) and ('itc' not in output ):
404+ # This is to enforce that the first dimension is for epochs
398405 out = out .transpose (1 , 0 , 2 , 3 )
399406 return out
400407
@@ -759,11 +766,11 @@ def tfr_morlet(inst, freqs, n_cycles, use_fft=False, return_itc=True, decim=1,
759766 zero_mean : bool, defaults to True
760767 Make sure the wavelet has a mean of zero.
761768
762- .. versionadded:: 0.12 .0
769+ .. versionadded:: 0.13 .0
763770 average : bool, defaults to True
764771 If True average accross Epochs.
765772
766- .. versionadded:: 0.12 .0
773+ .. versionadded:: 0.13 .0
767774 verbose : bool, str, int, or None, defaults to None
768775 If not None, override default verbose level (see mne.verbose).
769776
@@ -825,7 +832,7 @@ def tfr_multitaper(inst, freqs, n_cycles, time_bandwidth=4.0,
825832 average : bool, defaults to True
826833 If True average accross Epochs.
827834
828- .. versionadded:: 0.12 .0
835+ .. versionadded:: 0.13 .0
829836 verbose : bool, str, int, or None, defaults to None
830837 If not None, override default verbose level (see mne.verbose).
831838
@@ -853,7 +860,7 @@ def tfr_multitaper(inst, freqs, n_cycles, time_bandwidth=4.0,
853860
854861# TFR(s) class
855862
856- class _BaseTFR (ContainsMixin , UpdateChannelsMixin ):
863+ class _BaseTFR (ContainsMixin , UpdateChannelsMixin , SizeMixin ):
857864 @property
858865 def ch_names (self ):
859866 return self .info ['ch_names' ]
@@ -1419,6 +1426,7 @@ def __repr__(self):
14191426 s += ", freq : [%f, %f]" % (self .freqs [0 ], self .freqs [- 1 ])
14201427 s += ", nave : %d" % self .nave
14211428 s += ', channels : %d' % self .data .shape [0 ]
1429+ s += ', ~%s' % (sizeof_fmt (self ._size ),)
14221430 return "<AverageTFR | %s>" % s
14231431
14241432 def save (self , fname , overwrite = False ):
@@ -1464,7 +1472,7 @@ class EpochsTFR(_BaseTFR):
14641472
14651473 Notes
14661474 -----
1467- .. versionadded:: 0.12 .0
1475+ .. versionadded:: 0.13 .0
14681476 """
14691477 @verbose
14701478 def __init__ (self , info , data , times , freqs , comment = None ,
@@ -1493,6 +1501,7 @@ def __repr__(self):
14931501 s += ", freq : [%f, %f]" % (self .freqs [0 ], self .freqs [- 1 ])
14941502 s += ", epochs : %d" % self .data .shape [0 ]
14951503 s += ', channels : %d' % self .data .shape [1 ]
1504+ s += ', ~%s' % (sizeof_fmt (self ._size ),)
14961505 return "<EpochsTFR | %s>" % s
14971506
14981507 def average (self ):
0 commit comments