MRG: Add second-order sections filtering (#3324)
ENH: Add second-order sections filtering
This commit is contained in:
committed by
Alexandre Gramfort
parent
0a52ab33fa
commit
42678aa7ec
@@ -6,3 +6,4 @@ omit =
|
||||
*/mne/externals/*
|
||||
*/bin/*
|
||||
*/setup.py
|
||||
*/mne/fixes*
|
||||
|
||||
@@ -461,6 +461,8 @@ EEG referencing:
|
||||
|
||||
band_pass_filter
|
||||
construct_iir_filter
|
||||
estimate_ringing_samples
|
||||
filter_data
|
||||
high_pass_filter
|
||||
low_pass_filter
|
||||
notch_filter
|
||||
|
||||
+4
-4
@@ -25,6 +25,8 @@ Changelog
|
||||
|
||||
- Added :func:`mne.viz.ica.plot_ica_properties` that allows ploting of independent component properties similar to ``pop_prop`` in EEGLAB. Also :class:`mne.preprocessing.ica.ICA` has :func:`mne.preprocessing.ica.ICA.plot_properties` method now. Added by `Mikołaj Magnuski`_
|
||||
|
||||
- Add second-order sections (instead of ``(b, a)`` form) IIR filtering for reduced numerical error by `Eric Larson`_
|
||||
|
||||
BUG
|
||||
~~~
|
||||
|
||||
@@ -63,6 +65,8 @@ API
|
||||
|
||||
- Added option to pass a list of axes to :func:`mne.viz.epochs.plot_epochs_image` by `Mikołaj Magnuski`_
|
||||
|
||||
- Constructing IIR filters in :func:`mne.filter.construct_iir_filter` defaults to ``output='ba'`` in 0.13 but this will be changed to ``output='sos'`` by `Eric Larson`_
|
||||
|
||||
.. _changes_0_12:
|
||||
|
||||
Version 0.12
|
||||
@@ -1605,7 +1609,3 @@ of commits):
|
||||
.. _Jon Houck: http://www.unm.edu/~jhouck/
|
||||
|
||||
.. _Pablo-Arias: https://github.com/Pablo-Arias
|
||||
|
||||
.. _Alexander Rudiuk: https://github.com/ARudiuk
|
||||
|
||||
.. _Mikołaj Magnuski: https://github.com/mmagnuski
|
||||
|
||||
+169
-69
@@ -8,7 +8,7 @@ from scipy.fftpack import fft, ifftshift, fftfreq
|
||||
from .cuda import (setup_cuda_fft_multiply_repeated, fft_multiply_repeated,
|
||||
setup_cuda_fft_resample, fft_resample, _smart_pad)
|
||||
from .externals.six import string_types, integer_types
|
||||
from .fixes import get_firwin2, get_filtfilt
|
||||
from .fixes import get_firwin2, get_filtfilt, get_sosfiltfilt, partial
|
||||
from .parallel import parallel_func, check_n_jobs
|
||||
from .time_frequency.multitaper import dpss_windows, _mt_spectra
|
||||
from .utils import logger, verbose, sum_squared, check_version, warn
|
||||
@@ -376,49 +376,100 @@ def _filter(x, Fs, freq, gain, filter_length='10s', picks=None, n_jobs=1,
|
||||
return x
|
||||
|
||||
|
||||
def _check_coefficients(b, a):
|
||||
def _check_coefficients(system):
|
||||
"""Check for filter stability"""
|
||||
from scipy.signal import tf2zpk
|
||||
z, p, k = tf2zpk(b, a)
|
||||
if isinstance(system, tuple):
|
||||
from scipy.signal import tf2zpk
|
||||
z, p, k = tf2zpk(*system)
|
||||
else: # sos
|
||||
from scipy.signal import sos2zpk
|
||||
z, p, k = sos2zpk(system)
|
||||
if np.any(np.abs(p) > 1.0):
|
||||
raise RuntimeError('Filter poles outside unit circle, filter will be '
|
||||
'unstable. Consider using different filter '
|
||||
'coefficients.')
|
||||
|
||||
|
||||
def _filtfilt(x, b, a, padlen, picks, n_jobs, copy):
|
||||
def _filtfilt(x, iir_params, picks, n_jobs, copy):
|
||||
"""Helper to more easily call filtfilt"""
|
||||
# set up array for filtering, reshape to 2D, operate on last axis
|
||||
filtfilt = get_filtfilt()
|
||||
padlen = min(iir_params['padlen'], len(x))
|
||||
n_jobs = check_n_jobs(n_jobs)
|
||||
x, orig_shape, picks = _prep_for_filtering(x, copy, picks)
|
||||
_check_coefficients(b, a)
|
||||
if 'sos' in iir_params:
|
||||
sosfiltfilt = get_sosfiltfilt()
|
||||
fun = partial(sosfiltfilt, sos=iir_params['sos'], padlen=padlen)
|
||||
_check_coefficients(iir_params['sos'])
|
||||
else:
|
||||
filtfilt = get_filtfilt()
|
||||
fun = partial(filtfilt, b=iir_params['b'], a=iir_params['a'],
|
||||
padlen=padlen)
|
||||
_check_coefficients((iir_params['b'], iir_params['a']))
|
||||
if n_jobs == 1:
|
||||
for p in picks:
|
||||
x[p] = filtfilt(b, a, x[p], padlen=padlen)
|
||||
x[p] = fun(x=x[p])
|
||||
else:
|
||||
parallel, p_fun, _ = parallel_func(filtfilt, n_jobs)
|
||||
data_new = parallel(p_fun(b, a, x[p], padlen=padlen)
|
||||
for p in picks)
|
||||
parallel, p_fun, _ = parallel_func(fun, n_jobs)
|
||||
data_new = parallel(p_fun(x=x[p]) for p in picks)
|
||||
for pp, p in enumerate(picks):
|
||||
x[p] = data_new[pp]
|
||||
x.shape = orig_shape
|
||||
return x
|
||||
|
||||
|
||||
def _estimate_ringing_samples(b, a):
|
||||
"""Helper function for determining IIR padding"""
|
||||
# XXX Need to extend this to more than 1000 samples for long IIR filters!
|
||||
from scipy.signal import lfilter
|
||||
x = np.zeros(1000)
|
||||
def estimate_ringing_samples(system, max_try=100000):
|
||||
"""Estimate filter ringing
|
||||
|
||||
Parameters
|
||||
----------
|
||||
system : tuple | ndarray
|
||||
A tuple of (b, a) or ndarray of second-order sections coefficients.
|
||||
max_try : int
|
||||
Approximate maximum number of samples to try.
|
||||
This will be changed to a multple of 1000.
|
||||
|
||||
Returns
|
||||
-------
|
||||
n : int
|
||||
The approximate ringing.
|
||||
"""
|
||||
from scipy import signal
|
||||
if isinstance(system, tuple): # TF
|
||||
kind = 'ba'
|
||||
b, a = system
|
||||
zi = [0.] * (len(a) - 1)
|
||||
else:
|
||||
kind = 'sos'
|
||||
sos = system
|
||||
zi = [[0.] * 2] * len(sos)
|
||||
n_per_chunk = 1000
|
||||
n_chunks_max = int(np.ceil(max_try / float(n_per_chunk)))
|
||||
x = np.zeros(n_per_chunk)
|
||||
x[0] = 1
|
||||
h = lfilter(b, a, x)
|
||||
return np.where(np.abs(h) > 0.001 * np.max(np.abs(h)))[0][-1]
|
||||
last_good = n_per_chunk
|
||||
thresh_val = 0
|
||||
for ii in range(n_chunks_max):
|
||||
if kind == 'ba':
|
||||
h, zi = signal.lfilter(b, a, x, zi=zi)
|
||||
else:
|
||||
h, zi = signal.sosfilt(sos, x, zi=zi)
|
||||
x[0] = 0 # for subsequent iterations we want zero input
|
||||
h = np.abs(h)
|
||||
thresh_val = max(0.001 * np.max(h), thresh_val)
|
||||
idx = np.where(np.abs(h) > thresh_val)[0]
|
||||
if len(idx) > 0:
|
||||
last_good = idx[-1]
|
||||
else: # this iteration had no sufficiently lange values
|
||||
idx = (ii - 1) * n_per_chunk + last_good
|
||||
break
|
||||
else:
|
||||
warn('Could not properly estimate ringing for the filter')
|
||||
idx = n_per_chunk * n_chunks_max
|
||||
return idx
|
||||
|
||||
|
||||
def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
|
||||
f_pass=None, f_stop=None, sfreq=None, btype=None,
|
||||
return_copy=True):
|
||||
def construct_iir_filter(iir_params, f_pass=None, f_stop=None, sfreq=None,
|
||||
btype=None, return_copy=True):
|
||||
"""Use IIR parameters to get filtering coefficients
|
||||
|
||||
This function works like a wrapper for iirdesign and iirfilter in
|
||||
@@ -428,19 +479,39 @@ def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
|
||||
function) with the filter coefficients ('b' and 'a') and an estimate
|
||||
of the padding necessary ('padlen') so IIR filtering can be performed.
|
||||
|
||||
.. note:: As of 0.14, second-order sections will be used in filter
|
||||
design by default (replacing ``output='ba'`` by
|
||||
``output='sos'``) to help ensure filter stability and
|
||||
reduce numerical error. Second-order sections filtering
|
||||
requires SciPy >= 16.0.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
iir_params : dict
|
||||
Dictionary of parameters to use for IIR filtering.
|
||||
If iir_params['b'] and iir_params['a'] exist, these will be used
|
||||
as coefficients to perform IIR filtering. Otherwise, if
|
||||
iir_params['order'] and iir_params['ftype'] exist, these will be
|
||||
used with scipy.signal.iirfilter to make a filter. Otherwise, if
|
||||
iir_params['gpass'] and iir_params['gstop'] exist, these will be
|
||||
used with scipy.signal.iirdesign to design a filter.
|
||||
iir_params['padlen'] defines the number of samples to pad (and
|
||||
an estimate will be calculated if it is not given). See Notes for
|
||||
more details.
|
||||
|
||||
* If ``iir_params['sos']`` exists, it will be used as
|
||||
second-order sections to perform IIR filtering.
|
||||
|
||||
.. versionadded:: 0.13
|
||||
|
||||
* Otherwise, if ``iir_params['b']`` and ``iir_params['a']``
|
||||
exist, these will be used as coefficients to perform IIR
|
||||
filtering.
|
||||
* Otherwise, if ``iir_params['order']`` and
|
||||
``iir_params['ftype']`` exist, these will be used with
|
||||
`scipy.signal.iirfilter` to make a filter.
|
||||
* Otherwise, if ``iir_params['gpass']`` and
|
||||
``iir_params['gstop']`` exist, these will be used with
|
||||
`scipy.signal.iirdesign` to design a filter.
|
||||
* ``iir_params['padlen']`` defines the number of samples to pad
|
||||
(and an estimate will be calculated if it is not given).
|
||||
See Notes for more details.
|
||||
* ``iir_params['output']`` defines the system output kind when
|
||||
designing filters, either "sos" or "ba". For 0.13 the
|
||||
default is 'ba' but will change to 'sos' in 0.14.
|
||||
|
||||
f_pass : float or list of float
|
||||
Frequency for the pass-band. Low-pass and high-pass filters should
|
||||
be a float, band-pass should be a 2-element list of float.
|
||||
@@ -451,23 +522,31 @@ def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
|
||||
The sample rate.
|
||||
btype : str
|
||||
Type of filter. Should be 'lowpass', 'highpass', or 'bandpass'
|
||||
(or analogous string representations known to scipy.signal).
|
||||
(or analogous string representations known to
|
||||
:func:`scipy.signal.iirfilter`).
|
||||
return_copy : bool
|
||||
If False, the 'b', 'a', and 'padlen' entries in iir_params will be
|
||||
set inplace (if they weren't already). Otherwise, a new iir_params
|
||||
instance will be created and returned with these entries.
|
||||
If False, the 'sos', 'b', 'a', and 'padlen' entries in
|
||||
``iir_params`` will be set inplace (if they weren't already).
|
||||
Otherwise, a new ``iir_params`` instance will be created and
|
||||
returned with these entries.
|
||||
|
||||
Returns
|
||||
-------
|
||||
iir_params : dict
|
||||
Updated iir_params dict, with the entries (set only if they didn't
|
||||
exist before) for 'b', 'a', and 'padlen' for IIR filtering.
|
||||
exist before) for 'sos' (or 'b', 'a'), and 'padlen' for
|
||||
IIR filtering.
|
||||
|
||||
See Also
|
||||
--------
|
||||
mne.filter.filter_data
|
||||
mne.io.Raw.filter
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function triages calls to scipy.signal.iirfilter and iirdesign
|
||||
based on the input arguments (see descriptions of these functions
|
||||
and scipy's scipy.signal.filter_design documentation for details).
|
||||
This function triages calls to :func:`scipy.signal.iirfilter` and
|
||||
:func:`scipy.signal.iirdesign` based on the input arguments (see
|
||||
linked functions for more details).
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -478,20 +557,20 @@ def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
|
||||
filter 'N' and the type of filtering 'ftype' are specified. To get
|
||||
coefficients for a 4th-order Butterworth filter, this would be:
|
||||
|
||||
>>> iir_params = dict(order=4, ftype='butter')
|
||||
>>> iir_params = dict(order=4, ftype='butter', output='sos')
|
||||
>>> iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low', return_copy=False)
|
||||
>>> print((len(iir_params['b']), len(iir_params['a']), iir_params['padlen']))
|
||||
(5, 5, 82)
|
||||
>>> print((2 * len(iir_params['sos']), iir_params['padlen']))
|
||||
(4, 82)
|
||||
|
||||
Filters can also be constructed using filter design methods. To get a
|
||||
40 Hz Chebyshev type 1 lowpass with specific gain characteristics in the
|
||||
pass and stop bands (assuming the desired stop band is at 45 Hz), this
|
||||
would be a filter with much longer ringing:
|
||||
|
||||
>>> iir_params = dict(ftype='cheby1', gpass=3, gstop=20)
|
||||
>>> iir_params = dict(ftype='cheby1', gpass=3, gstop=20, output='sos')
|
||||
>>> iir_params = construct_iir_filter(iir_params, 40, 50, 1000, 'low')
|
||||
>>> print((len(iir_params['b']), len(iir_params['a']), iir_params['padlen']))
|
||||
(6, 6, 439)
|
||||
>>> print((2 * len(iir_params['sos']), iir_params['padlen']))
|
||||
(6, 439)
|
||||
|
||||
Padding and/or filter coefficients can also be manually specified. For
|
||||
a 10-sample moving window with no padding during filtering, for example,
|
||||
@@ -502,17 +581,32 @@ def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
|
||||
>>> print((iir_params['b'], iir_params['a'], iir_params['padlen']))
|
||||
(array([ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), [1, 0], 0)
|
||||
|
||||
For more information, see the tutorials :ref:`tut_background_filtering`
|
||||
and :ref:`tut_artifacts_filter`.
|
||||
""" # noqa
|
||||
from scipy.signal import iirfilter, iirdesign
|
||||
known_filters = ('bessel', 'butter', 'butterworth', 'cauer', 'cheby1',
|
||||
'cheby2', 'chebyshev1', 'chebyshev2', 'chebyshevi',
|
||||
'chebyshevii', 'ellip', 'elliptic')
|
||||
a = None
|
||||
b = None
|
||||
if not isinstance(iir_params, dict):
|
||||
raise TypeError('iir_params must be a dict, got %s' % type(iir_params))
|
||||
system = None
|
||||
# if the filter has been designed, we're good to go
|
||||
if 'a' in iir_params and 'b' in iir_params:
|
||||
[b, a] = [iir_params['b'], iir_params['a']]
|
||||
if 'sos' in iir_params:
|
||||
system = iir_params['sos']
|
||||
output = 'sos'
|
||||
elif 'a' in iir_params and 'b' in iir_params:
|
||||
system = (iir_params['b'], iir_params['a'])
|
||||
output = 'ba'
|
||||
else:
|
||||
output = iir_params.get('output', None)
|
||||
if output is None:
|
||||
warn('The default output type is "ba" in 0.13 but will change '
|
||||
'to "sos" in 0.14')
|
||||
output = 'ba'
|
||||
if not isinstance(output, string_types) or output not in ('ba', 'sos'):
|
||||
raise ValueError('Output must be "ba" or "sos", got %s'
|
||||
% (output,))
|
||||
# ensure we have a valid ftype
|
||||
if 'ftype' not in iir_params:
|
||||
raise RuntimeError('ftype must be an entry in iir_params if ''b'' '
|
||||
@@ -526,30 +620,36 @@ def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
|
||||
# use order-based design
|
||||
Wp = np.asanyarray(f_pass) / (float(sfreq) / 2)
|
||||
if 'order' in iir_params:
|
||||
[b, a] = iirfilter(iir_params['order'], Wp, btype=btype,
|
||||
ftype=ftype)
|
||||
system = iirfilter(iir_params['order'], Wp, btype=btype,
|
||||
ftype=ftype, output=output)
|
||||
else:
|
||||
# use gpass / gstop design
|
||||
Ws = np.asanyarray(f_stop) / (float(sfreq) / 2)
|
||||
if 'gpass' not in iir_params or 'gstop' not in iir_params:
|
||||
raise ValueError('iir_params must have at least ''gstop'' and'
|
||||
' ''gpass'' (or ''N'') entries')
|
||||
[b, a] = iirdesign(Wp, Ws, iir_params['gpass'],
|
||||
iir_params['gstop'], ftype=ftype)
|
||||
system = iirdesign(Wp, Ws, iir_params['gpass'],
|
||||
iir_params['gstop'], ftype=ftype, output=output)
|
||||
|
||||
if a is None or b is None:
|
||||
if system is None:
|
||||
raise RuntimeError('coefficients could not be created from iir_params')
|
||||
# do some sanity checks
|
||||
_check_coefficients(system)
|
||||
|
||||
# now deal with padding
|
||||
if 'padlen' not in iir_params:
|
||||
padlen = _estimate_ringing_samples(b, a)
|
||||
padlen = estimate_ringing_samples(system)
|
||||
else:
|
||||
padlen = iir_params['padlen']
|
||||
|
||||
if return_copy:
|
||||
iir_params = deepcopy(iir_params)
|
||||
|
||||
iir_params.update(dict(b=b, a=a, padlen=padlen))
|
||||
iir_params.update(dict(padlen=padlen))
|
||||
if output == 'sos':
|
||||
iir_params.update(sos=system)
|
||||
else:
|
||||
iir_params.update(b=system[0], a=system[1])
|
||||
return iir_params
|
||||
|
||||
|
||||
@@ -564,8 +664,6 @@ def _check_method(method, iir_params, extra_types):
|
||||
if method == 'iir':
|
||||
if iir_params is None:
|
||||
iir_params = dict(order=4, ftype='butter')
|
||||
if not isinstance(iir_params, dict):
|
||||
raise ValueError('iir_params must be a dict')
|
||||
elif iir_params is not None:
|
||||
raise ValueError('iir_params must be None if method != "iir"')
|
||||
method = method.lower()
|
||||
@@ -641,6 +739,16 @@ def filter_data(data, sfreq, l_freq, h_freq, picks=None, filter_length='10s',
|
||||
-------
|
||||
data : ndarray, shape (n_channels, n_times)
|
||||
The filtered data.
|
||||
|
||||
See Also
|
||||
--------
|
||||
mne.filter.construct_iir_filter
|
||||
mne.io.Raw.filter
|
||||
|
||||
Notes
|
||||
-----
|
||||
For more information, see the tutorials :ref:`tut_background_filtering`
|
||||
and :ref:`tut_artifacts_filter`.
|
||||
"""
|
||||
if not isinstance(data, np.ndarray) or data.ndim != 2:
|
||||
raise ValueError('data must be an array with two dimensions')
|
||||
@@ -795,9 +903,7 @@ def band_pass_filter(x, Fs, Fp1, Fp2, filter_length='10s',
|
||||
else:
|
||||
iir_params = construct_iir_filter(iir_params, [Fp1, Fp2],
|
||||
[Fs1, Fs2], Fs, 'bandpass')
|
||||
padlen = min(iir_params['padlen'], len(x))
|
||||
xf = _filtfilt(x, iir_params['b'], iir_params['a'], padlen,
|
||||
picks, n_jobs, copy)
|
||||
xf = _filtfilt(x, iir_params, picks, n_jobs, copy)
|
||||
|
||||
return xf
|
||||
|
||||
@@ -911,9 +1017,7 @@ def band_stop_filter(x, Fs, Fp1, Fp2, filter_length='10s',
|
||||
for fp_1, fp_2, fs_1, fs_2 in zip(Fp1, Fp2, Fs1, Fs2):
|
||||
iir_params_new = construct_iir_filter(iir_params, [fp_1, fp_2],
|
||||
[fs_1, fs_2], Fs, 'bandstop')
|
||||
padlen = min(iir_params_new['padlen'], len(x))
|
||||
xf = _filtfilt(x, iir_params_new['b'], iir_params_new['a'], padlen,
|
||||
picks, n_jobs, copy)
|
||||
xf = _filtfilt(x, iir_params_new, picks, n_jobs, copy)
|
||||
|
||||
return xf
|
||||
|
||||
@@ -1002,9 +1106,7 @@ def low_pass_filter(x, Fs, Fp, filter_length='10s', trans_bandwidth=0.5,
|
||||
xf = _filter(x, Fs, freq, gain, filter_length, picks, n_jobs, copy)
|
||||
else:
|
||||
iir_params = construct_iir_filter(iir_params, Fp, Fstop, Fs, 'low')
|
||||
padlen = min(iir_params['padlen'], len(x))
|
||||
xf = _filtfilt(x, iir_params['b'], iir_params['a'], padlen,
|
||||
picks, n_jobs, copy)
|
||||
xf = _filtfilt(x, iir_params, picks, n_jobs, copy)
|
||||
|
||||
return xf
|
||||
|
||||
@@ -1095,9 +1197,7 @@ def high_pass_filter(x, Fs, Fp, filter_length='10s', trans_bandwidth=0.5,
|
||||
xf = _filter(x, Fs, freq, gain, filter_length, picks, n_jobs, copy)
|
||||
else:
|
||||
iir_params = construct_iir_filter(iir_params, Fp, Fstop, Fs, 'high')
|
||||
padlen = min(iir_params['padlen'], len(x))
|
||||
xf = _filtfilt(x, iir_params['b'], iir_params['a'], padlen,
|
||||
picks, n_jobs, copy)
|
||||
xf = _filtfilt(x, iir_params, picks, n_jobs, copy)
|
||||
|
||||
return xf
|
||||
|
||||
|
||||
+246
-11
@@ -589,25 +589,260 @@ def get_firwin2():
|
||||
return firwin2
|
||||
|
||||
|
||||
def _filtfilt(*args, **kwargs):
|
||||
"""wrap filtfilt, excluding padding arguments"""
|
||||
from scipy.signal import filtfilt
|
||||
# cut out filter args
|
||||
if len(args) > 4:
|
||||
args = args[:4]
|
||||
if 'padlen' in kwargs:
|
||||
del kwargs['padlen']
|
||||
return filtfilt(*args, **kwargs)
|
||||
def _filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None):
|
||||
"""copy of modern SciPy filtfilt without "method" or "irlen" arguments"""
|
||||
from scipy.signal import lfilter_zi, lfilter
|
||||
b = np.atleast_1d(b)
|
||||
a = np.atleast_1d(a)
|
||||
x = np.asarray(x)
|
||||
|
||||
# method == "pad"
|
||||
edge, ext = _validate_pad(padtype, padlen, x, axis,
|
||||
ntaps=max(len(a), len(b)))
|
||||
|
||||
# Get the steady state of the filter's step response.
|
||||
zi = lfilter_zi(b, a)
|
||||
|
||||
# Reshape zi and create x0 so that zi*x0 broadcasts
|
||||
# to the correct value for the 'zi' keyword argument
|
||||
# to lfilter.
|
||||
zi_shape = [1] * x.ndim
|
||||
zi_shape[axis] = zi.size
|
||||
zi = np.reshape(zi, zi_shape)
|
||||
x0 = axis_slice(ext, stop=1, axis=axis)
|
||||
|
||||
# Forward filter.
|
||||
(y, zf) = lfilter(b, a, ext, axis=axis, zi=zi * x0)
|
||||
|
||||
# Backward filter.
|
||||
# Create y0 so zi*y0 broadcasts appropriately.
|
||||
y0 = axis_slice(y, start=-1, axis=axis)
|
||||
(y, zf) = lfilter(b, a, axis_reverse(y, axis=axis), axis=axis, zi=zi * y0)
|
||||
|
||||
# Reverse y.
|
||||
y = axis_reverse(y, axis=axis)
|
||||
|
||||
if edge > 0:
|
||||
# Slice the actual signal from the extended signal.
|
||||
y = axis_slice(y, start=edge, stop=-edge, axis=axis)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
def _sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None):
|
||||
"""copy of SciPy sosfiltfilt"""
|
||||
sos, n_sections = _validate_sos(sos)
|
||||
|
||||
# `method` is "pad"...
|
||||
ntaps = 2 * n_sections + 1
|
||||
ntaps -= min((sos[:, 2] == 0).sum(), (sos[:, 5] == 0).sum())
|
||||
edge, ext = _validate_pad(padtype, padlen, x, axis,
|
||||
ntaps=ntaps)
|
||||
|
||||
# These steps follow the same form as filtfilt with modifications
|
||||
zi = sosfilt_zi(sos) # shape (n_sections, 2) --> (n_sections, ..., 2, ...)
|
||||
zi_shape = [1] * x.ndim
|
||||
zi_shape[axis] = 2
|
||||
zi.shape = [n_sections] + zi_shape
|
||||
x_0 = axis_slice(ext, stop=1, axis=axis)
|
||||
(y, zf) = sosfilt(sos, ext, axis=axis, zi=zi * x_0)
|
||||
y_0 = axis_slice(y, start=-1, axis=axis)
|
||||
(y, zf) = sosfilt(sos, axis_reverse(y, axis=axis), axis=axis, zi=zi * y_0)
|
||||
y = axis_reverse(y, axis=axis)
|
||||
if edge > 0:
|
||||
y = axis_slice(y, start=edge, stop=-edge, axis=axis)
|
||||
return y
|
||||
|
||||
|
||||
def axis_slice(a, start=None, stop=None, step=None, axis=-1):
|
||||
"""Take a slice along axis 'axis' from 'a'"""
|
||||
a_slice = [slice(None)] * a.ndim
|
||||
a_slice[axis] = slice(start, stop, step)
|
||||
b = a[a_slice]
|
||||
return b
|
||||
|
||||
|
||||
def axis_reverse(a, axis=-1):
|
||||
"""Reverse the 1-d slices of `a` along axis `axis`."""
|
||||
return axis_slice(a, step=-1, axis=axis)
|
||||
|
||||
|
||||
def _validate_pad(padtype, padlen, x, axis, ntaps):
|
||||
"""Helper to validate padding for filtfilt"""
|
||||
if padtype not in ['even', 'odd', 'constant', None]:
|
||||
raise ValueError(("Unknown value '%s' given to padtype. padtype "
|
||||
"must be 'even', 'odd', 'constant', or None.") %
|
||||
padtype)
|
||||
|
||||
if padtype is None:
|
||||
padlen = 0
|
||||
|
||||
if padlen is None:
|
||||
# Original padding; preserved for backwards compatibility.
|
||||
edge = ntaps * 3
|
||||
else:
|
||||
edge = padlen
|
||||
|
||||
# x's 'axis' dimension must be bigger than edge.
|
||||
if x.shape[axis] <= edge:
|
||||
raise ValueError("The length of the input vector x must be at least "
|
||||
"padlen, which is %d." % edge)
|
||||
|
||||
if padtype is not None and edge > 0:
|
||||
# Make an extension of length `edge` at each
|
||||
# end of the input array.
|
||||
if padtype == 'even':
|
||||
ext = even_ext(x, edge, axis=axis)
|
||||
elif padtype == 'odd':
|
||||
ext = odd_ext(x, edge, axis=axis)
|
||||
else:
|
||||
ext = const_ext(x, edge, axis=axis)
|
||||
else:
|
||||
ext = x
|
||||
return edge, ext
|
||||
|
||||
|
||||
def _validate_sos(sos):
|
||||
"""Helper to validate a SOS input"""
|
||||
sos = np.atleast_2d(sos)
|
||||
if sos.ndim != 2:
|
||||
raise ValueError('sos array must be 2D')
|
||||
n_sections, m = sos.shape
|
||||
if m != 6:
|
||||
raise ValueError('sos array must be shape (n_sections, 6)')
|
||||
if not (sos[:, 3] == 1).all():
|
||||
raise ValueError('sos[:, 3] should be all ones')
|
||||
return sos, n_sections
|
||||
|
||||
|
||||
def odd_ext(x, n, axis=-1):
|
||||
"""Generate a new ndarray by making an odd extension of x along an axis."""
|
||||
if n < 1:
|
||||
return x
|
||||
if n > x.shape[axis] - 1:
|
||||
raise ValueError(("The extension length n (%d) is too big. " +
|
||||
"It must not exceed x.shape[axis]-1, which is %d.")
|
||||
% (n, x.shape[axis] - 1))
|
||||
left_end = axis_slice(x, start=0, stop=1, axis=axis)
|
||||
left_ext = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
|
||||
right_end = axis_slice(x, start=-1, axis=axis)
|
||||
right_ext = axis_slice(x, start=-2, stop=-(n + 2), step=-1, axis=axis)
|
||||
ext = np.concatenate((2 * left_end - left_ext,
|
||||
x,
|
||||
2 * right_end - right_ext),
|
||||
axis=axis)
|
||||
return ext
|
||||
|
||||
|
||||
def even_ext(x, n, axis=-1):
|
||||
"""Create an ndarray that is an even extension of x along an axis."""
|
||||
if n < 1:
|
||||
return x
|
||||
if n > x.shape[axis] - 1:
|
||||
raise ValueError(("The extension length n (%d) is too big. " +
|
||||
"It must not exceed x.shape[axis]-1, which is %d.")
|
||||
% (n, x.shape[axis] - 1))
|
||||
left_ext = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
|
||||
right_ext = axis_slice(x, start=-2, stop=-(n + 2), step=-1, axis=axis)
|
||||
ext = np.concatenate((left_ext,
|
||||
x,
|
||||
right_ext),
|
||||
axis=axis)
|
||||
return ext
|
||||
|
||||
|
||||
def const_ext(x, n, axis=-1):
|
||||
"""Create an ndarray that is a constant extension of x along an axis"""
|
||||
if n < 1:
|
||||
return x
|
||||
left_end = axis_slice(x, start=0, stop=1, axis=axis)
|
||||
ones_shape = [1] * x.ndim
|
||||
ones_shape[axis] = n
|
||||
ones = np.ones(ones_shape, dtype=x.dtype)
|
||||
left_ext = ones * left_end
|
||||
right_end = axis_slice(x, start=-1, axis=axis)
|
||||
right_ext = ones * right_end
|
||||
ext = np.concatenate((left_ext,
|
||||
x,
|
||||
right_ext),
|
||||
axis=axis)
|
||||
return ext
|
||||
|
||||
|
||||
def sosfilt_zi(sos):
|
||||
"""Compute an initial state `zi` for the sosfilt function"""
|
||||
from scipy.signal import lfilter_zi
|
||||
sos = np.asarray(sos)
|
||||
if sos.ndim != 2 or sos.shape[1] != 6:
|
||||
raise ValueError('sos must be shape (n_sections, 6)')
|
||||
|
||||
n_sections = sos.shape[0]
|
||||
zi = np.empty((n_sections, 2))
|
||||
scale = 1.0
|
||||
for section in range(n_sections):
|
||||
b = sos[section, :3]
|
||||
a = sos[section, 3:]
|
||||
zi[section] = scale * lfilter_zi(b, a)
|
||||
# If H(z) = B(z)/A(z) is this section's transfer function, then
|
||||
# b.sum()/a.sum() is H(1), the gain at omega=0. That's the steady
|
||||
# state value of this section's step response.
|
||||
scale *= b.sum() / a.sum()
|
||||
|
||||
return zi
|
||||
|
||||
|
||||
def sosfilt(sos, x, axis=-1, zi=None):
|
||||
"""Filter data along one dimension using cascaded second-order sections"""
|
||||
from scipy.signal import lfilter
|
||||
x = np.asarray(x)
|
||||
|
||||
sos = np.atleast_2d(sos)
|
||||
if sos.ndim != 2:
|
||||
raise ValueError('sos array must be 2D')
|
||||
|
||||
n_sections, m = sos.shape
|
||||
if m != 6:
|
||||
raise ValueError('sos array must be shape (n_sections, 6)')
|
||||
|
||||
use_zi = zi is not None
|
||||
if use_zi:
|
||||
zi = np.asarray(zi)
|
||||
x_zi_shape = list(x.shape)
|
||||
x_zi_shape[axis] = 2
|
||||
x_zi_shape = tuple([n_sections] + x_zi_shape)
|
||||
if zi.shape != x_zi_shape:
|
||||
raise ValueError('Invalid zi shape. With axis=%r, an input with '
|
||||
'shape %r, and an sos array with %d sections, zi '
|
||||
'must have shape %r.' %
|
||||
(axis, x.shape, n_sections, x_zi_shape))
|
||||
zf = np.zeros_like(zi)
|
||||
|
||||
for section in range(n_sections):
|
||||
if use_zi:
|
||||
x, zf[section] = lfilter(sos[section, :3], sos[section, 3:],
|
||||
x, axis, zi=zi[section])
|
||||
else:
|
||||
x = lfilter(sos[section, :3], sos[section, 3:], x, axis)
|
||||
out = (x, zf) if use_zi else x
|
||||
return out
|
||||
|
||||
|
||||
def get_filtfilt():
|
||||
"""Helper to get filtfilt from scipy"""
|
||||
from scipy.signal import filtfilt
|
||||
|
||||
if 'padlen' in _get_args(filtfilt):
|
||||
return filtfilt
|
||||
else:
|
||||
return _filtfilt
|
||||
|
||||
return _filtfilt
|
||||
|
||||
def get_sosfiltfilt():
|
||||
"""Helper to get sosfiltfilt from scipy"""
|
||||
try:
|
||||
from scipy.signal import sosfiltfilt
|
||||
except ImportError:
|
||||
sosfiltfilt = _sosfiltfilt
|
||||
return sosfiltfilt
|
||||
|
||||
|
||||
def _get_argrelmax():
|
||||
|
||||
@@ -893,6 +893,11 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
|
||||
mne.io.Raw.notch_filter
|
||||
mne.io.Raw.resample
|
||||
mne.filter.filter_data
|
||||
|
||||
Notes
|
||||
-----
|
||||
For more information, see the tutorials :ref:`tut_background_filtering`
|
||||
and :ref:`tut_artifacts_filter`.
|
||||
"""
|
||||
_check_preload(self, 'raw.filter')
|
||||
data_picks = _pick_data_or_ica(self.info)
|
||||
|
||||
+83
-17
@@ -3,19 +3,39 @@ from numpy.testing import (assert_array_almost_equal, assert_almost_equal,
|
||||
assert_array_equal, assert_allclose)
|
||||
from nose.tools import assert_equal, assert_true, assert_raises
|
||||
import warnings
|
||||
from scipy.signal import resample as sp_resample
|
||||
from scipy.signal import resample as sp_resample, butter
|
||||
|
||||
from mne.filter import (band_pass_filter, high_pass_filter, low_pass_filter,
|
||||
band_stop_filter, resample, _resample_stim_channels,
|
||||
construct_iir_filter, notch_filter, detrend,
|
||||
_overlap_add_filter, _smart_pad)
|
||||
_overlap_add_filter, _smart_pad,
|
||||
estimate_ringing_samples, filter_data)
|
||||
|
||||
from mne.utils import sum_squared, run_tests_if_main, slow_test, catch_logging
|
||||
from mne.utils import (sum_squared, run_tests_if_main, slow_test,
|
||||
catch_logging, requires_version)
|
||||
|
||||
warnings.simplefilter('always') # enable b/c these tests throw warnings
|
||||
rng = np.random.RandomState(0)
|
||||
|
||||
|
||||
@requires_version('scipy', '0.16')
|
||||
def test_estimate_ringing():
|
||||
"""Test our ringing estimation function"""
|
||||
# Actual values might differ based on system, so let's be approximate
|
||||
for kind in ('ba', 'sos'):
|
||||
for thresh, lims in ((0.1, (30, 60)), # 47
|
||||
(0.01, (300, 600)), # 475
|
||||
(0.001, (3000, 6000)), # 4758
|
||||
(0.0001, (30000, 60000))): # 37993
|
||||
n_ring = estimate_ringing_samples(butter(3, thresh, output=kind))
|
||||
assert_true(lims[0] <= n_ring <= lims[1],
|
||||
msg='%s %s: %s <= %s <= %s'
|
||||
% (kind, thresh, lims[0], n_ring, lims[1]))
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
assert_equal(estimate_ringing_samples(butter(4, 0.00001)), 100000)
|
||||
assert_true(any('properly estimate' in str(ww.message) for ww in w))
|
||||
|
||||
|
||||
def test_1d_filter():
|
||||
"""Test our private overlap-add filtering function"""
|
||||
# make some random signals and filters
|
||||
@@ -68,17 +88,35 @@ def test_1d_filter():
|
||||
assert_allclose(x_expected, x_filtered)
|
||||
|
||||
|
||||
@requires_version('scipy', '0.16')
|
||||
def test_iir_stability():
|
||||
"""Test IIR filter stability check
|
||||
"""
|
||||
"""Test IIR filter stability check"""
|
||||
sig = np.empty(1000)
|
||||
sfreq = 1000
|
||||
# This will make an unstable filter, should throw RuntimeError
|
||||
assert_raises(RuntimeError, high_pass_filter, sig, sfreq, 0.6,
|
||||
method='iir', iir_params=dict(ftype='butter', order=8))
|
||||
# can't pass iir_params if method='fir'
|
||||
method='iir', iir_params=dict(ftype='butter', order=8,
|
||||
output='ba'))
|
||||
# This one should work just fine
|
||||
high_pass_filter(sig, sfreq, 0.6, method='iir',
|
||||
iir_params=dict(ftype='butter', order=8, output='sos'))
|
||||
# bad system type
|
||||
assert_raises(ValueError, high_pass_filter, sig, sfreq, 0.6, method='iir',
|
||||
iir_params=dict(ftype='butter', order=8, output='foo'))
|
||||
# missing ftype
|
||||
assert_raises(RuntimeError, high_pass_filter, sig, sfreq, 0.6,
|
||||
method='iir', iir_params=dict(order=8, output='sos'))
|
||||
# bad ftype
|
||||
assert_raises(RuntimeError, high_pass_filter, sig, sfreq, 0.6,
|
||||
method='iir',
|
||||
iir_params=dict(order=8, ftype='foo', output='sos'))
|
||||
# missing gstop
|
||||
assert_raises(RuntimeError, high_pass_filter, sig, sfreq, 0.6,
|
||||
method='iir', iir_params=dict(gpass=0.5, output='sos'))
|
||||
# can't pass iir_params if method='fft'
|
||||
assert_raises(ValueError, high_pass_filter, sig, sfreq, 0.1,
|
||||
method='fir', iir_params=dict(ftype='butter', order=2))
|
||||
method='fft', iir_params=dict(ftype='butter', order=2,
|
||||
output='sos'))
|
||||
# method must be string
|
||||
assert_raises(TypeError, high_pass_filter, sig, sfreq, 0.1,
|
||||
method=1)
|
||||
@@ -86,12 +124,26 @@ def test_iir_stability():
|
||||
assert_raises(ValueError, high_pass_filter, sig, sfreq, 0.1,
|
||||
method='blah')
|
||||
# bad iir_params
|
||||
assert_raises(TypeError, high_pass_filter, sig, sfreq, 0.1,
|
||||
method='iir', iir_params='blah')
|
||||
assert_raises(ValueError, high_pass_filter, sig, sfreq, 0.1,
|
||||
method='fir', iir_params='blah')
|
||||
method='fft', iir_params=dict())
|
||||
|
||||
# should pass because dafault trans_bandwidth is not relevant
|
||||
high_pass_filter(sig, 250, 0.5, method='iir',
|
||||
iir_params=dict(ftype='butter', order=6))
|
||||
iir_params = dict(ftype='butter', order=2, output='sos')
|
||||
x_sos = high_pass_filter(sig, 250, 0.5, method='iir',
|
||||
iir_params=iir_params)
|
||||
iir_params_sos = construct_iir_filter(iir_params, f_pass=0.5, sfreq=250,
|
||||
btype='highpass')
|
||||
x_sos_2 = high_pass_filter(sig, 250, 0.5, method='iir',
|
||||
iir_params=iir_params_sos)
|
||||
assert_allclose(x_sos[100:-100], x_sos_2[100:-100])
|
||||
x_ba = high_pass_filter(sig, 250, 0.5, method='iir',
|
||||
iir_params=dict(ftype='butter', order=2,
|
||||
output='ba'))
|
||||
# Note that this will fail for higher orders (e.g., 6) showing the
|
||||
# hopefully decreased numerical error of SOS
|
||||
assert_allclose(x_sos[100:-100], x_ba[100:-100])
|
||||
|
||||
|
||||
def test_notch_filters():
|
||||
@@ -178,6 +230,7 @@ def test_resample_stim_channel():
|
||||
assert_equal(new_data.shape[1], new_data_len)
|
||||
|
||||
|
||||
@requires_version('scipy', '0.16')
|
||||
@slow_test
|
||||
def test_filters():
|
||||
"""Test low-, band-, high-pass, and band-stop filters plus resampling
|
||||
@@ -261,15 +314,22 @@ def test_filters():
|
||||
assert_array_almost_equal(np.zeros_like(sig_gone), sig_gone, 2)
|
||||
|
||||
# let's construct some filters
|
||||
iir_params = dict(ftype='cheby1', gpass=1, gstop=20)
|
||||
iir_params = dict(ftype='cheby1', gpass=1, gstop=20, output='ba')
|
||||
iir_params = construct_iir_filter(iir_params, 40, 80, 1000, 'low')
|
||||
# this should be a third order filter
|
||||
assert_true(iir_params['a'].size - 1 == 3)
|
||||
assert_true(iir_params['b'].size - 1 == 3)
|
||||
iir_params = dict(ftype='butter', order=4)
|
||||
assert_equal(iir_params['a'].size - 1, 3)
|
||||
assert_equal(iir_params['b'].size - 1, 3)
|
||||
iir_params = dict(ftype='butter', order=4, output='ba')
|
||||
iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low')
|
||||
assert_true(iir_params['a'].size - 1 == 4)
|
||||
assert_true(iir_params['b'].size - 1 == 4)
|
||||
assert_equal(iir_params['a'].size - 1, 4)
|
||||
assert_equal(iir_params['b'].size - 1, 4)
|
||||
iir_params = dict(ftype='cheby1', gpass=1, gstop=20, output='sos')
|
||||
iir_params = construct_iir_filter(iir_params, 40, 80, 1000, 'low')
|
||||
# this should be a third order filter, which requires 2 SOS ((2, 6))
|
||||
assert_equal(iir_params['sos'].shape, (2, 6))
|
||||
iir_params = dict(ftype='butter', order=4, output='sos')
|
||||
iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low')
|
||||
assert_equal(iir_params['sos'].shape, (2, 6))
|
||||
|
||||
# check that picks work for 3d array with one channel and picks=[0]
|
||||
a = rng.randn(5 * sfreq, 5 * sfreq)
|
||||
@@ -299,6 +359,12 @@ def test_filters():
|
||||
# the firwin2 function gets us this close
|
||||
assert_allclose(x, x_filt, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# degenerate conditions
|
||||
assert_raises(ValueError, filter_data, x, sfreq, 1, 10) # not 2D
|
||||
assert_raises(ValueError, filter_data, x[np.newaxis], -sfreq, 1, 10)
|
||||
assert_raises(ValueError, filter_data, x[np.newaxis], sfreq, 1,
|
||||
sfreq * 0.75)
|
||||
|
||||
|
||||
def test_cuda():
|
||||
"""Test CUDA-based filtering
|
||||
|
||||
@@ -14,9 +14,10 @@ from mne.utils import run_tests_if_main
|
||||
from mne.fixes import (_in1d, _tril_indices, _copysign, _unravel_index,
|
||||
_Counter, _unique, _bincount, _digitize,
|
||||
_sparse_block_diag, _matrix_rank, _meshgrid,
|
||||
_isclose)
|
||||
from mne.fixes import _firwin2 as mne_firwin2
|
||||
from mne.fixes import _filtfilt as mne_filtfilt
|
||||
_isclose,
|
||||
_firwin2 as mne_firwin2,
|
||||
_filtfilt as mne_filtfilt,
|
||||
_sosfiltfilt as mne_sosfiltfilt)
|
||||
|
||||
rng = np.random.RandomState(0)
|
||||
|
||||
@@ -148,6 +149,8 @@ def test_filtfilt():
|
||||
# Filter with an impulse
|
||||
y = mne_filtfilt([1, 0], [1, 0], x, padlen=0)
|
||||
assert_array_equal(x, y)
|
||||
y = mne_sosfiltfilt(np.array([[1., 0., 0., 1, 0., 0.]]), x, padlen=0)
|
||||
assert_array_equal(x, y)
|
||||
|
||||
|
||||
def test_sparse_block_diag():
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
r"""
|
||||
.. _tut_background_filtering:
|
||||
|
||||
===================================
|
||||
Background information on filtering
|
||||
===================================
|
||||
|
||||
@@ -13,8 +14,10 @@ in MNE-Python on actual data, see the :ref:`tut_artifacts_filter` tutorial.
|
||||
|
||||
.. contents::
|
||||
|
||||
.. _filtering-basics:
|
||||
|
||||
Filtering basics
|
||||
----------------
|
||||
================
|
||||
Let's get some of the basic math down. In the frequency domain, digital
|
||||
filters have a transfer function that is given by:
|
||||
|
||||
@@ -76,7 +79,9 @@ In general, the sharper something is in frequency, the broader it is in time,
|
||||
and vice-versa. This is a fundamental time-frequency tradeoff, and it will
|
||||
show up below.
|
||||
|
||||
Here we will focus first on FIR filters, which are the default filters used by
|
||||
===========
|
||||
|
||||
First we will focus first on FIR filters, which are the default filters used by
|
||||
MNE-Python.
|
||||
"""
|
||||
|
||||
@@ -166,13 +171,24 @@ h = np.sinc(2 * f_p * t) / (4 * np.pi)
|
||||
|
||||
|
||||
def plot_filter(h, title, freq, gain, show=True):
|
||||
if h.ndim == 2: # second-order sections
|
||||
sos = h
|
||||
n = mne.filter.estimate_ringing_samples(sos)
|
||||
h = np.zeros(n)
|
||||
h[0] = 1
|
||||
h = signal.sosfilt(sos, h)
|
||||
H = np.ones(512, np.complex128)
|
||||
for section in sos:
|
||||
f, this_H = signal.freqz(section[:3], section[3:])
|
||||
H *= this_H
|
||||
else:
|
||||
f, H = signal.freqz(h)
|
||||
fig, axs = plt.subplots(2)
|
||||
t = np.arange(len(h)) / sfreq
|
||||
axs[0].plot(t, h, color=blue)
|
||||
axs[0].set(xlim=t[[0, -1]], xlabel='Time (sec)',
|
||||
ylabel='Amplitude h(n)', title=title)
|
||||
box_off(axs[0])
|
||||
f, H = signal.freqz(h)
|
||||
f *= sfreq / (2 * np.pi)
|
||||
axs[1].semilogx(f, 10 * np.log10((H * H.conj()).real), color=blue,
|
||||
linewidth=2, zorder=4)
|
||||
@@ -366,8 +382,121 @@ mne.viz.tight_layout()
|
||||
plt.show()
|
||||
|
||||
###############################################################################
|
||||
# IIR filters
|
||||
# ===========
|
||||
# MNE-Python also offers IIR filtering functionality that is based on the
|
||||
# methods from :mod:`scipy.signal`. Specifically, we use the general-purpose
|
||||
# functions :func:`scipy.signal.iirfilter` and :func:`scipy.signal.iirdesign`,
|
||||
# which provide unified interfaces to IIR filter design.
|
||||
#
|
||||
# Designing IIR filters
|
||||
# ---------------------
|
||||
# Let's continue with our design of a 40 Hz low-pass filter, and look at
|
||||
# some trade-offs of different IIR filters.
|
||||
#
|
||||
# Often the default IIR filter is a `Butterworth filter`_, which is designed
|
||||
# to have a *maximally flat pass-band*. Let's look at a few orders of filter,
|
||||
# i.e., a few different number of coefficients used and therefore steepness
|
||||
# of the filter:
|
||||
|
||||
sos = signal.iirfilter(2, f_p / nyq, btype='low', ftype='butter', output='sos')
|
||||
plot_filter(sos, 'Butterworth order=2', freq, gain)
|
||||
|
||||
# Eventually this will just be from scipy signal.sosfiltfilt, but 0.18 is
|
||||
# not widely adopted yet (as of June 2016), so we use our wrapper...
|
||||
sosfiltfilt = mne.fixes.get_sosfiltfilt()
|
||||
x_shallow = sosfiltfilt(sos, x)
|
||||
|
||||
###############################################################################
|
||||
# The falloff of this filter is not very steep.
|
||||
#
|
||||
# .. warning:: For brevity, we do not show the phase of these filters here.
|
||||
# In the FIR case, we can design linear-phase filters, and
|
||||
# compensate for the delay if necessary. This cannot be done
|
||||
# with IIR filters, and as the filter order increases, the
|
||||
# phase distortion near and in the transition band worsens.
|
||||
# However, if acausal (forward-backward) filtering can be used,
|
||||
# e.g. with :func:`scipy.signal.filtfilt`, these phase issues
|
||||
# can be mitigated.
|
||||
#
|
||||
# .. note:: Here we have made use of second-order sections (SOS)
|
||||
# by using :func:`scipy.signal.sosfilt` and, under the
|
||||
# hood, :func:`scipy.signal.zpk2sos` when passing the
|
||||
# ``output='sos'`` keyword argument to
|
||||
# :func:`scipy.signal.iirfilter`. The filter definitions
|
||||
# given in :ref:`filtering-basics` use the polynomial
|
||||
# numerator/denominator (sometimes called "tf") form ``(b, a)``,
|
||||
# which are theoretically equivalent to the SOS form used here.
|
||||
# In practice, however, the SOS form can give much better results
|
||||
# due to issues with numerical precision (see
|
||||
# :func:`scipy.signal.sosfilt` for an example), so SOS should be
|
||||
# used when possible to do IIR filtering.
|
||||
#
|
||||
# Let's increase the order, and note that now we have better attenuation,
|
||||
# with a longer impulse response:
|
||||
|
||||
sos = signal.iirfilter(8, f_p / nyq, btype='low', ftype='butter', output='sos')
|
||||
plot_filter(sos, 'Butterworth order=8', freq, gain)
|
||||
x_steep = sosfiltfilt(sos, x)
|
||||
|
||||
###############################################################################
|
||||
# There are other types of IIR filters that we can use. For a complete list,
|
||||
# check out the documentation for :func:`scipy.signal.iirdesign`. Let's
|
||||
# try a Chebychev (type I) filter, which trades off ripple in the pass-band
|
||||
# to get better attenuation in the stop-band:
|
||||
|
||||
sos = signal.iirfilter(8, f_p / nyq, btype='low', ftype='cheby1', output='sos',
|
||||
rp=1) # dB of acceptable pass-band ripple
|
||||
plot_filter(sos, 'Chebychev-1 order=8, ripple=1 dB', freq, gain)
|
||||
|
||||
###############################################################################
|
||||
# And if we can live with even more ripple, we can get it slightly steeper,
|
||||
# but the impulse response begins to ring substantially longer (note the
|
||||
# different x-axis scale):
|
||||
|
||||
sos = signal.iirfilter(8, f_p / nyq, btype='low', ftype='cheby1', output='sos',
|
||||
rp=6)
|
||||
plot_filter(sos, 'Chebychev-1 order=8, ripple=6 dB', freq, gain)
|
||||
|
||||
###############################################################################
|
||||
# Applying IIR filters
|
||||
# --------------------
|
||||
# Now let's look at how our shallow and steep Butterworth IIR filters
|
||||
# perform on our morlet signal from before:
|
||||
|
||||
axs = plt.subplots(2)[1]
|
||||
yticks = np.arange(4) / -30.
|
||||
yticklabels = ['Original', 'Noisy', 'Butterworth-2', 'Butterworth-8']
|
||||
plot_signal(x_orig, offset=yticks[0])
|
||||
plot_signal(x, offset=yticks[1])
|
||||
plot_signal(x_shallow, offset=yticks[2])
|
||||
plot_signal(x_steep, offset=yticks[3])
|
||||
axs[0].set(xlim=tlim, title='Lowpass=%d Hz' % f_p, xticks=tticks,
|
||||
ylim=[-0.125, 0.025], yticks=yticks, yticklabels=yticklabels,)
|
||||
for text in axs[0].get_yticklabels():
|
||||
text.set(rotation=45, size=8)
|
||||
axs[1].set(xlim=flim, ylim=ylim, xlabel='Frequency (Hz)',
|
||||
ylabel='Magnitude (dB)')
|
||||
box_off(axs[0])
|
||||
box_off(axs[1])
|
||||
mne.viz.tight_layout()
|
||||
plt.show()
|
||||
|
||||
###############################################################################
|
||||
# Filtering in MNE-Python
|
||||
# =======================
|
||||
# Most often, filtering in MNE-Python is done at the :class:`mne.io.Raw` level,
|
||||
# and thus :func:`mne.io.Raw.filter` is used. This function under the hood
|
||||
# (among other things) calls :func:`mne.filter.filter_data` to actually
|
||||
# filter the data.
|
||||
#
|
||||
# :func:`mne.filter.filter_data` by default applies a FIR filter designed using
|
||||
# :func:`scipy.signal.firwin2`. For more information on how to use the
|
||||
# MNE-Python filtering functions with real data, consult the preprocessing
|
||||
# tutorial on :ref:`tut_artifacts_filter`.
|
||||
#
|
||||
# Summary
|
||||
# -------
|
||||
# =======
|
||||
# When filtering, there are always tradeoffs that should be considered.
|
||||
# One important tradeoff is between time-domain characteristics (like ringing)
|
||||
# and frequency-domain attenuation characteristics (like effective transition
|
||||
@@ -379,7 +508,7 @@ plt.show()
|
||||
|
||||
###############################################################################
|
||||
# References
|
||||
# ----------
|
||||
# ==========
|
||||
# .. [1] Parks TW, Burrus CS. Digital Filter Design.
|
||||
# New York: Wiley-Interscience, 1987.
|
||||
#
|
||||
@@ -394,3 +523,4 @@ plt.show()
|
||||
# .. _scipy firwin2: http://scipy.github.io/devdocs/generated/scipy.signal.firwin2.html # noqa
|
||||
# .. _matlab fir2: http://www.mathworks.com/help/signal/ref/fir2.html
|
||||
# .. _matlab firls: http://www.mathworks.com/help/signal/ref/firls.html
|
||||
# .. _Butterworth filter: https://en.wikipedia.org/wiki/Butterworth_filter
|
||||
|
||||
Reference in New Issue
Block a user