MRG: Add second-order sections filtering (#3324)

ENH: Add second-order sections filtering
This commit is contained in:
Eric Larson
2016-06-27 11:53:55 -04:00
committed by Alexandre Gramfort
parent 0a52ab33fa
commit 42678aa7ec
9 changed files with 651 additions and 109 deletions
+1
View File
@@ -6,3 +6,4 @@ omit =
*/mne/externals/*
*/bin/*
*/setup.py
*/mne/fixes*
+2
View File
@@ -461,6 +461,8 @@ EEG referencing:
band_pass_filter
construct_iir_filter
estimate_ringing_samples
filter_data
high_pass_filter
low_pass_filter
notch_filter
+4 -4
View File
@@ -25,6 +25,8 @@ Changelog
- Added :func:`mne.viz.ica.plot_ica_properties` that allows ploting of independent component properties similar to ``pop_prop`` in EEGLAB. Also :class:`mne.preprocessing.ica.ICA` has :func:`mne.preprocessing.ica.ICA.plot_properties` method now. Added by `Mikołaj Magnuski`_
- Add second-order sections (instead of ``(b, a)`` form) IIR filtering for reduced numerical error by `Eric Larson`_
BUG
~~~
@@ -63,6 +65,8 @@ API
- Added option to pass a list of axes to :func:`mne.viz.epochs.plot_epochs_image` by `Mikołaj Magnuski`_
- Constructing IIR filters in :func:`mne.filter.construct_iir_filter` defaults to ``output='ba'`` in 0.13 but this will be changed to ``output='sos'`` by `Eric Larson`_
.. _changes_0_12:
Version 0.12
@@ -1605,7 +1609,3 @@ of commits):
.. _Jon Houck: http://www.unm.edu/~jhouck/
.. _Pablo-Arias: https://github.com/Pablo-Arias
.. _Alexander Rudiuk: https://github.com/ARudiuk
.. _Mikołaj Magnuski: https://github.com/mmagnuski
+169 -69
View File
@@ -8,7 +8,7 @@ from scipy.fftpack import fft, ifftshift, fftfreq
from .cuda import (setup_cuda_fft_multiply_repeated, fft_multiply_repeated,
setup_cuda_fft_resample, fft_resample, _smart_pad)
from .externals.six import string_types, integer_types
from .fixes import get_firwin2, get_filtfilt
from .fixes import get_firwin2, get_filtfilt, get_sosfiltfilt, partial
from .parallel import parallel_func, check_n_jobs
from .time_frequency.multitaper import dpss_windows, _mt_spectra
from .utils import logger, verbose, sum_squared, check_version, warn
@@ -376,49 +376,100 @@ def _filter(x, Fs, freq, gain, filter_length='10s', picks=None, n_jobs=1,
return x
def _check_coefficients(b, a):
def _check_coefficients(system):
"""Check for filter stability"""
from scipy.signal import tf2zpk
z, p, k = tf2zpk(b, a)
if isinstance(system, tuple):
from scipy.signal import tf2zpk
z, p, k = tf2zpk(*system)
else: # sos
from scipy.signal import sos2zpk
z, p, k = sos2zpk(system)
if np.any(np.abs(p) > 1.0):
raise RuntimeError('Filter poles outside unit circle, filter will be '
'unstable. Consider using different filter '
'coefficients.')
def _filtfilt(x, b, a, padlen, picks, n_jobs, copy):
def _filtfilt(x, iir_params, picks, n_jobs, copy):
"""Helper to more easily call filtfilt"""
# set up array for filtering, reshape to 2D, operate on last axis
filtfilt = get_filtfilt()
padlen = min(iir_params['padlen'], len(x))
n_jobs = check_n_jobs(n_jobs)
x, orig_shape, picks = _prep_for_filtering(x, copy, picks)
_check_coefficients(b, a)
if 'sos' in iir_params:
sosfiltfilt = get_sosfiltfilt()
fun = partial(sosfiltfilt, sos=iir_params['sos'], padlen=padlen)
_check_coefficients(iir_params['sos'])
else:
filtfilt = get_filtfilt()
fun = partial(filtfilt, b=iir_params['b'], a=iir_params['a'],
padlen=padlen)
_check_coefficients((iir_params['b'], iir_params['a']))
if n_jobs == 1:
for p in picks:
x[p] = filtfilt(b, a, x[p], padlen=padlen)
x[p] = fun(x=x[p])
else:
parallel, p_fun, _ = parallel_func(filtfilt, n_jobs)
data_new = parallel(p_fun(b, a, x[p], padlen=padlen)
for p in picks)
parallel, p_fun, _ = parallel_func(fun, n_jobs)
data_new = parallel(p_fun(x=x[p]) for p in picks)
for pp, p in enumerate(picks):
x[p] = data_new[pp]
x.shape = orig_shape
return x
def _estimate_ringing_samples(b, a):
"""Helper function for determining IIR padding"""
# XXX Need to extend this to more than 1000 samples for long IIR filters!
from scipy.signal import lfilter
x = np.zeros(1000)
def estimate_ringing_samples(system, max_try=100000):
"""Estimate filter ringing
Parameters
----------
system : tuple | ndarray
A tuple of (b, a) or ndarray of second-order sections coefficients.
max_try : int
Approximate maximum number of samples to try.
This will be changed to a multple of 1000.
Returns
-------
n : int
The approximate ringing.
"""
from scipy import signal
if isinstance(system, tuple): # TF
kind = 'ba'
b, a = system
zi = [0.] * (len(a) - 1)
else:
kind = 'sos'
sos = system
zi = [[0.] * 2] * len(sos)
n_per_chunk = 1000
n_chunks_max = int(np.ceil(max_try / float(n_per_chunk)))
x = np.zeros(n_per_chunk)
x[0] = 1
h = lfilter(b, a, x)
return np.where(np.abs(h) > 0.001 * np.max(np.abs(h)))[0][-1]
last_good = n_per_chunk
thresh_val = 0
for ii in range(n_chunks_max):
if kind == 'ba':
h, zi = signal.lfilter(b, a, x, zi=zi)
else:
h, zi = signal.sosfilt(sos, x, zi=zi)
x[0] = 0 # for subsequent iterations we want zero input
h = np.abs(h)
thresh_val = max(0.001 * np.max(h), thresh_val)
idx = np.where(np.abs(h) > thresh_val)[0]
if len(idx) > 0:
last_good = idx[-1]
else: # this iteration had no sufficiently lange values
idx = (ii - 1) * n_per_chunk + last_good
break
else:
warn('Could not properly estimate ringing for the filter')
idx = n_per_chunk * n_chunks_max
return idx
def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
f_pass=None, f_stop=None, sfreq=None, btype=None,
return_copy=True):
def construct_iir_filter(iir_params, f_pass=None, f_stop=None, sfreq=None,
btype=None, return_copy=True):
"""Use IIR parameters to get filtering coefficients
This function works like a wrapper for iirdesign and iirfilter in
@@ -428,19 +479,39 @@ def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
function) with the filter coefficients ('b' and 'a') and an estimate
of the padding necessary ('padlen') so IIR filtering can be performed.
.. note:: As of 0.14, second-order sections will be used in filter
design by default (replacing ``output='ba'`` by
``output='sos'``) to help ensure filter stability and
reduce numerical error. Second-order sections filtering
requires SciPy >= 16.0.
Parameters
----------
iir_params : dict
Dictionary of parameters to use for IIR filtering.
If iir_params['b'] and iir_params['a'] exist, these will be used
as coefficients to perform IIR filtering. Otherwise, if
iir_params['order'] and iir_params['ftype'] exist, these will be
used with scipy.signal.iirfilter to make a filter. Otherwise, if
iir_params['gpass'] and iir_params['gstop'] exist, these will be
used with scipy.signal.iirdesign to design a filter.
iir_params['padlen'] defines the number of samples to pad (and
an estimate will be calculated if it is not given). See Notes for
more details.
* If ``iir_params['sos']`` exists, it will be used as
second-order sections to perform IIR filtering.
.. versionadded:: 0.13
* Otherwise, if ``iir_params['b']`` and ``iir_params['a']``
exist, these will be used as coefficients to perform IIR
filtering.
* Otherwise, if ``iir_params['order']`` and
``iir_params['ftype']`` exist, these will be used with
`scipy.signal.iirfilter` to make a filter.
* Otherwise, if ``iir_params['gpass']`` and
``iir_params['gstop']`` exist, these will be used with
`scipy.signal.iirdesign` to design a filter.
* ``iir_params['padlen']`` defines the number of samples to pad
(and an estimate will be calculated if it is not given).
See Notes for more details.
* ``iir_params['output']`` defines the system output kind when
designing filters, either "sos" or "ba". For 0.13 the
default is 'ba' but will change to 'sos' in 0.14.
f_pass : float or list of float
Frequency for the pass-band. Low-pass and high-pass filters should
be a float, band-pass should be a 2-element list of float.
@@ -451,23 +522,31 @@ def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
The sample rate.
btype : str
Type of filter. Should be 'lowpass', 'highpass', or 'bandpass'
(or analogous string representations known to scipy.signal).
(or analogous string representations known to
:func:`scipy.signal.iirfilter`).
return_copy : bool
If False, the 'b', 'a', and 'padlen' entries in iir_params will be
set inplace (if they weren't already). Otherwise, a new iir_params
instance will be created and returned with these entries.
If False, the 'sos', 'b', 'a', and 'padlen' entries in
``iir_params`` will be set inplace (if they weren't already).
Otherwise, a new ``iir_params`` instance will be created and
returned with these entries.
Returns
-------
iir_params : dict
Updated iir_params dict, with the entries (set only if they didn't
exist before) for 'b', 'a', and 'padlen' for IIR filtering.
exist before) for 'sos' (or 'b', 'a'), and 'padlen' for
IIR filtering.
See Also
--------
mne.filter.filter_data
mne.io.Raw.filter
Notes
-----
This function triages calls to scipy.signal.iirfilter and iirdesign
based on the input arguments (see descriptions of these functions
and scipy's scipy.signal.filter_design documentation for details).
This function triages calls to :func:`scipy.signal.iirfilter` and
:func:`scipy.signal.iirdesign` based on the input arguments (see
linked functions for more details).
Examples
--------
@@ -478,20 +557,20 @@ def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
filter 'N' and the type of filtering 'ftype' are specified. To get
coefficients for a 4th-order Butterworth filter, this would be:
>>> iir_params = dict(order=4, ftype='butter')
>>> iir_params = dict(order=4, ftype='butter', output='sos')
>>> iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low', return_copy=False)
>>> print((len(iir_params['b']), len(iir_params['a']), iir_params['padlen']))
(5, 5, 82)
>>> print((2 * len(iir_params['sos']), iir_params['padlen']))
(4, 82)
Filters can also be constructed using filter design methods. To get a
40 Hz Chebyshev type 1 lowpass with specific gain characteristics in the
pass and stop bands (assuming the desired stop band is at 45 Hz), this
would be a filter with much longer ringing:
>>> iir_params = dict(ftype='cheby1', gpass=3, gstop=20)
>>> iir_params = dict(ftype='cheby1', gpass=3, gstop=20, output='sos')
>>> iir_params = construct_iir_filter(iir_params, 40, 50, 1000, 'low')
>>> print((len(iir_params['b']), len(iir_params['a']), iir_params['padlen']))
(6, 6, 439)
>>> print((2 * len(iir_params['sos']), iir_params['padlen']))
(6, 439)
Padding and/or filter coefficients can also be manually specified. For
a 10-sample moving window with no padding during filtering, for example,
@@ -502,17 +581,32 @@ def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
>>> print((iir_params['b'], iir_params['a'], iir_params['padlen']))
(array([ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), [1, 0], 0)
For more information, see the tutorials :ref:`tut_background_filtering`
and :ref:`tut_artifacts_filter`.
""" # noqa
from scipy.signal import iirfilter, iirdesign
known_filters = ('bessel', 'butter', 'butterworth', 'cauer', 'cheby1',
'cheby2', 'chebyshev1', 'chebyshev2', 'chebyshevi',
'chebyshevii', 'ellip', 'elliptic')
a = None
b = None
if not isinstance(iir_params, dict):
raise TypeError('iir_params must be a dict, got %s' % type(iir_params))
system = None
# if the filter has been designed, we're good to go
if 'a' in iir_params and 'b' in iir_params:
[b, a] = [iir_params['b'], iir_params['a']]
if 'sos' in iir_params:
system = iir_params['sos']
output = 'sos'
elif 'a' in iir_params and 'b' in iir_params:
system = (iir_params['b'], iir_params['a'])
output = 'ba'
else:
output = iir_params.get('output', None)
if output is None:
warn('The default output type is "ba" in 0.13 but will change '
'to "sos" in 0.14')
output = 'ba'
if not isinstance(output, string_types) or output not in ('ba', 'sos'):
raise ValueError('Output must be "ba" or "sos", got %s'
% (output,))
# ensure we have a valid ftype
if 'ftype' not in iir_params:
raise RuntimeError('ftype must be an entry in iir_params if ''b'' '
@@ -526,30 +620,36 @@ def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
# use order-based design
Wp = np.asanyarray(f_pass) / (float(sfreq) / 2)
if 'order' in iir_params:
[b, a] = iirfilter(iir_params['order'], Wp, btype=btype,
ftype=ftype)
system = iirfilter(iir_params['order'], Wp, btype=btype,
ftype=ftype, output=output)
else:
# use gpass / gstop design
Ws = np.asanyarray(f_stop) / (float(sfreq) / 2)
if 'gpass' not in iir_params or 'gstop' not in iir_params:
raise ValueError('iir_params must have at least ''gstop'' and'
' ''gpass'' (or ''N'') entries')
[b, a] = iirdesign(Wp, Ws, iir_params['gpass'],
iir_params['gstop'], ftype=ftype)
system = iirdesign(Wp, Ws, iir_params['gpass'],
iir_params['gstop'], ftype=ftype, output=output)
if a is None or b is None:
if system is None:
raise RuntimeError('coefficients could not be created from iir_params')
# do some sanity checks
_check_coefficients(system)
# now deal with padding
if 'padlen' not in iir_params:
padlen = _estimate_ringing_samples(b, a)
padlen = estimate_ringing_samples(system)
else:
padlen = iir_params['padlen']
if return_copy:
iir_params = deepcopy(iir_params)
iir_params.update(dict(b=b, a=a, padlen=padlen))
iir_params.update(dict(padlen=padlen))
if output == 'sos':
iir_params.update(sos=system)
else:
iir_params.update(b=system[0], a=system[1])
return iir_params
@@ -564,8 +664,6 @@ def _check_method(method, iir_params, extra_types):
if method == 'iir':
if iir_params is None:
iir_params = dict(order=4, ftype='butter')
if not isinstance(iir_params, dict):
raise ValueError('iir_params must be a dict')
elif iir_params is not None:
raise ValueError('iir_params must be None if method != "iir"')
method = method.lower()
@@ -641,6 +739,16 @@ def filter_data(data, sfreq, l_freq, h_freq, picks=None, filter_length='10s',
-------
data : ndarray, shape (n_channels, n_times)
The filtered data.
See Also
--------
mne.filter.construct_iir_filter
mne.io.Raw.filter
Notes
-----
For more information, see the tutorials :ref:`tut_background_filtering`
and :ref:`tut_artifacts_filter`.
"""
if not isinstance(data, np.ndarray) or data.ndim != 2:
raise ValueError('data must be an array with two dimensions')
@@ -795,9 +903,7 @@ def band_pass_filter(x, Fs, Fp1, Fp2, filter_length='10s',
else:
iir_params = construct_iir_filter(iir_params, [Fp1, Fp2],
[Fs1, Fs2], Fs, 'bandpass')
padlen = min(iir_params['padlen'], len(x))
xf = _filtfilt(x, iir_params['b'], iir_params['a'], padlen,
picks, n_jobs, copy)
xf = _filtfilt(x, iir_params, picks, n_jobs, copy)
return xf
@@ -911,9 +1017,7 @@ def band_stop_filter(x, Fs, Fp1, Fp2, filter_length='10s',
for fp_1, fp_2, fs_1, fs_2 in zip(Fp1, Fp2, Fs1, Fs2):
iir_params_new = construct_iir_filter(iir_params, [fp_1, fp_2],
[fs_1, fs_2], Fs, 'bandstop')
padlen = min(iir_params_new['padlen'], len(x))
xf = _filtfilt(x, iir_params_new['b'], iir_params_new['a'], padlen,
picks, n_jobs, copy)
xf = _filtfilt(x, iir_params_new, picks, n_jobs, copy)
return xf
@@ -1002,9 +1106,7 @@ def low_pass_filter(x, Fs, Fp, filter_length='10s', trans_bandwidth=0.5,
xf = _filter(x, Fs, freq, gain, filter_length, picks, n_jobs, copy)
else:
iir_params = construct_iir_filter(iir_params, Fp, Fstop, Fs, 'low')
padlen = min(iir_params['padlen'], len(x))
xf = _filtfilt(x, iir_params['b'], iir_params['a'], padlen,
picks, n_jobs, copy)
xf = _filtfilt(x, iir_params, picks, n_jobs, copy)
return xf
@@ -1095,9 +1197,7 @@ def high_pass_filter(x, Fs, Fp, filter_length='10s', trans_bandwidth=0.5,
xf = _filter(x, Fs, freq, gain, filter_length, picks, n_jobs, copy)
else:
iir_params = construct_iir_filter(iir_params, Fp, Fstop, Fs, 'high')
padlen = min(iir_params['padlen'], len(x))
xf = _filtfilt(x, iir_params['b'], iir_params['a'], padlen,
picks, n_jobs, copy)
xf = _filtfilt(x, iir_params, picks, n_jobs, copy)
return xf
+246 -11
View File
@@ -589,25 +589,260 @@ def get_firwin2():
return firwin2
def _filtfilt(*args, **kwargs):
"""wrap filtfilt, excluding padding arguments"""
from scipy.signal import filtfilt
# cut out filter args
if len(args) > 4:
args = args[:4]
if 'padlen' in kwargs:
del kwargs['padlen']
return filtfilt(*args, **kwargs)
def _filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None):
"""copy of modern SciPy filtfilt without "method" or "irlen" arguments"""
from scipy.signal import lfilter_zi, lfilter
b = np.atleast_1d(b)
a = np.atleast_1d(a)
x = np.asarray(x)
# method == "pad"
edge, ext = _validate_pad(padtype, padlen, x, axis,
ntaps=max(len(a), len(b)))
# Get the steady state of the filter's step response.
zi = lfilter_zi(b, a)
# Reshape zi and create x0 so that zi*x0 broadcasts
# to the correct value for the 'zi' keyword argument
# to lfilter.
zi_shape = [1] * x.ndim
zi_shape[axis] = zi.size
zi = np.reshape(zi, zi_shape)
x0 = axis_slice(ext, stop=1, axis=axis)
# Forward filter.
(y, zf) = lfilter(b, a, ext, axis=axis, zi=zi * x0)
# Backward filter.
# Create y0 so zi*y0 broadcasts appropriately.
y0 = axis_slice(y, start=-1, axis=axis)
(y, zf) = lfilter(b, a, axis_reverse(y, axis=axis), axis=axis, zi=zi * y0)
# Reverse y.
y = axis_reverse(y, axis=axis)
if edge > 0:
# Slice the actual signal from the extended signal.
y = axis_slice(y, start=edge, stop=-edge, axis=axis)
return y
def _sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None):
"""copy of SciPy sosfiltfilt"""
sos, n_sections = _validate_sos(sos)
# `method` is "pad"...
ntaps = 2 * n_sections + 1
ntaps -= min((sos[:, 2] == 0).sum(), (sos[:, 5] == 0).sum())
edge, ext = _validate_pad(padtype, padlen, x, axis,
ntaps=ntaps)
# These steps follow the same form as filtfilt with modifications
zi = sosfilt_zi(sos) # shape (n_sections, 2) --> (n_sections, ..., 2, ...)
zi_shape = [1] * x.ndim
zi_shape[axis] = 2
zi.shape = [n_sections] + zi_shape
x_0 = axis_slice(ext, stop=1, axis=axis)
(y, zf) = sosfilt(sos, ext, axis=axis, zi=zi * x_0)
y_0 = axis_slice(y, start=-1, axis=axis)
(y, zf) = sosfilt(sos, axis_reverse(y, axis=axis), axis=axis, zi=zi * y_0)
y = axis_reverse(y, axis=axis)
if edge > 0:
y = axis_slice(y, start=edge, stop=-edge, axis=axis)
return y
def axis_slice(a, start=None, stop=None, step=None, axis=-1):
"""Take a slice along axis 'axis' from 'a'"""
a_slice = [slice(None)] * a.ndim
a_slice[axis] = slice(start, stop, step)
b = a[a_slice]
return b
def axis_reverse(a, axis=-1):
"""Reverse the 1-d slices of `a` along axis `axis`."""
return axis_slice(a, step=-1, axis=axis)
def _validate_pad(padtype, padlen, x, axis, ntaps):
"""Helper to validate padding for filtfilt"""
if padtype not in ['even', 'odd', 'constant', None]:
raise ValueError(("Unknown value '%s' given to padtype. padtype "
"must be 'even', 'odd', 'constant', or None.") %
padtype)
if padtype is None:
padlen = 0
if padlen is None:
# Original padding; preserved for backwards compatibility.
edge = ntaps * 3
else:
edge = padlen
# x's 'axis' dimension must be bigger than edge.
if x.shape[axis] <= edge:
raise ValueError("The length of the input vector x must be at least "
"padlen, which is %d." % edge)
if padtype is not None and edge > 0:
# Make an extension of length `edge` at each
# end of the input array.
if padtype == 'even':
ext = even_ext(x, edge, axis=axis)
elif padtype == 'odd':
ext = odd_ext(x, edge, axis=axis)
else:
ext = const_ext(x, edge, axis=axis)
else:
ext = x
return edge, ext
def _validate_sos(sos):
"""Helper to validate a SOS input"""
sos = np.atleast_2d(sos)
if sos.ndim != 2:
raise ValueError('sos array must be 2D')
n_sections, m = sos.shape
if m != 6:
raise ValueError('sos array must be shape (n_sections, 6)')
if not (sos[:, 3] == 1).all():
raise ValueError('sos[:, 3] should be all ones')
return sos, n_sections
def odd_ext(x, n, axis=-1):
"""Generate a new ndarray by making an odd extension of x along an axis."""
if n < 1:
return x
if n > x.shape[axis] - 1:
raise ValueError(("The extension length n (%d) is too big. " +
"It must not exceed x.shape[axis]-1, which is %d.")
% (n, x.shape[axis] - 1))
left_end = axis_slice(x, start=0, stop=1, axis=axis)
left_ext = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
right_end = axis_slice(x, start=-1, axis=axis)
right_ext = axis_slice(x, start=-2, stop=-(n + 2), step=-1, axis=axis)
ext = np.concatenate((2 * left_end - left_ext,
x,
2 * right_end - right_ext),
axis=axis)
return ext
def even_ext(x, n, axis=-1):
"""Create an ndarray that is an even extension of x along an axis."""
if n < 1:
return x
if n > x.shape[axis] - 1:
raise ValueError(("The extension length n (%d) is too big. " +
"It must not exceed x.shape[axis]-1, which is %d.")
% (n, x.shape[axis] - 1))
left_ext = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
right_ext = axis_slice(x, start=-2, stop=-(n + 2), step=-1, axis=axis)
ext = np.concatenate((left_ext,
x,
right_ext),
axis=axis)
return ext
def const_ext(x, n, axis=-1):
"""Create an ndarray that is a constant extension of x along an axis"""
if n < 1:
return x
left_end = axis_slice(x, start=0, stop=1, axis=axis)
ones_shape = [1] * x.ndim
ones_shape[axis] = n
ones = np.ones(ones_shape, dtype=x.dtype)
left_ext = ones * left_end
right_end = axis_slice(x, start=-1, axis=axis)
right_ext = ones * right_end
ext = np.concatenate((left_ext,
x,
right_ext),
axis=axis)
return ext
def sosfilt_zi(sos):
"""Compute an initial state `zi` for the sosfilt function"""
from scipy.signal import lfilter_zi
sos = np.asarray(sos)
if sos.ndim != 2 or sos.shape[1] != 6:
raise ValueError('sos must be shape (n_sections, 6)')
n_sections = sos.shape[0]
zi = np.empty((n_sections, 2))
scale = 1.0
for section in range(n_sections):
b = sos[section, :3]
a = sos[section, 3:]
zi[section] = scale * lfilter_zi(b, a)
# If H(z) = B(z)/A(z) is this section's transfer function, then
# b.sum()/a.sum() is H(1), the gain at omega=0. That's the steady
# state value of this section's step response.
scale *= b.sum() / a.sum()
return zi
def sosfilt(sos, x, axis=-1, zi=None):
"""Filter data along one dimension using cascaded second-order sections"""
from scipy.signal import lfilter
x = np.asarray(x)
sos = np.atleast_2d(sos)
if sos.ndim != 2:
raise ValueError('sos array must be 2D')
n_sections, m = sos.shape
if m != 6:
raise ValueError('sos array must be shape (n_sections, 6)')
use_zi = zi is not None
if use_zi:
zi = np.asarray(zi)
x_zi_shape = list(x.shape)
x_zi_shape[axis] = 2
x_zi_shape = tuple([n_sections] + x_zi_shape)
if zi.shape != x_zi_shape:
raise ValueError('Invalid zi shape. With axis=%r, an input with '
'shape %r, and an sos array with %d sections, zi '
'must have shape %r.' %
(axis, x.shape, n_sections, x_zi_shape))
zf = np.zeros_like(zi)
for section in range(n_sections):
if use_zi:
x, zf[section] = lfilter(sos[section, :3], sos[section, 3:],
x, axis, zi=zi[section])
else:
x = lfilter(sos[section, :3], sos[section, 3:], x, axis)
out = (x, zf) if use_zi else x
return out
def get_filtfilt():
"""Helper to get filtfilt from scipy"""
from scipy.signal import filtfilt
if 'padlen' in _get_args(filtfilt):
return filtfilt
else:
return _filtfilt
return _filtfilt
def get_sosfiltfilt():
"""Helper to get sosfiltfilt from scipy"""
try:
from scipy.signal import sosfiltfilt
except ImportError:
sosfiltfilt = _sosfiltfilt
return sosfiltfilt
def _get_argrelmax():
+5
View File
@@ -893,6 +893,11 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
mne.io.Raw.notch_filter
mne.io.Raw.resample
mne.filter.filter_data
Notes
-----
For more information, see the tutorials :ref:`tut_background_filtering`
and :ref:`tut_artifacts_filter`.
"""
_check_preload(self, 'raw.filter')
data_picks = _pick_data_or_ica(self.info)
+83 -17
View File
@@ -3,19 +3,39 @@ from numpy.testing import (assert_array_almost_equal, assert_almost_equal,
assert_array_equal, assert_allclose)
from nose.tools import assert_equal, assert_true, assert_raises
import warnings
from scipy.signal import resample as sp_resample
from scipy.signal import resample as sp_resample, butter
from mne.filter import (band_pass_filter, high_pass_filter, low_pass_filter,
band_stop_filter, resample, _resample_stim_channels,
construct_iir_filter, notch_filter, detrend,
_overlap_add_filter, _smart_pad)
_overlap_add_filter, _smart_pad,
estimate_ringing_samples, filter_data)
from mne.utils import sum_squared, run_tests_if_main, slow_test, catch_logging
from mne.utils import (sum_squared, run_tests_if_main, slow_test,
catch_logging, requires_version)
warnings.simplefilter('always') # enable b/c these tests throw warnings
rng = np.random.RandomState(0)
@requires_version('scipy', '0.16')
def test_estimate_ringing():
"""Test our ringing estimation function"""
# Actual values might differ based on system, so let's be approximate
for kind in ('ba', 'sos'):
for thresh, lims in ((0.1, (30, 60)), # 47
(0.01, (300, 600)), # 475
(0.001, (3000, 6000)), # 4758
(0.0001, (30000, 60000))): # 37993
n_ring = estimate_ringing_samples(butter(3, thresh, output=kind))
assert_true(lims[0] <= n_ring <= lims[1],
msg='%s %s: %s <= %s <= %s'
% (kind, thresh, lims[0], n_ring, lims[1]))
with warnings.catch_warnings(record=True) as w:
assert_equal(estimate_ringing_samples(butter(4, 0.00001)), 100000)
assert_true(any('properly estimate' in str(ww.message) for ww in w))
def test_1d_filter():
"""Test our private overlap-add filtering function"""
# make some random signals and filters
@@ -68,17 +88,35 @@ def test_1d_filter():
assert_allclose(x_expected, x_filtered)
@requires_version('scipy', '0.16')
def test_iir_stability():
"""Test IIR filter stability check
"""
"""Test IIR filter stability check"""
sig = np.empty(1000)
sfreq = 1000
# This will make an unstable filter, should throw RuntimeError
assert_raises(RuntimeError, high_pass_filter, sig, sfreq, 0.6,
method='iir', iir_params=dict(ftype='butter', order=8))
# can't pass iir_params if method='fir'
method='iir', iir_params=dict(ftype='butter', order=8,
output='ba'))
# This one should work just fine
high_pass_filter(sig, sfreq, 0.6, method='iir',
iir_params=dict(ftype='butter', order=8, output='sos'))
# bad system type
assert_raises(ValueError, high_pass_filter, sig, sfreq, 0.6, method='iir',
iir_params=dict(ftype='butter', order=8, output='foo'))
# missing ftype
assert_raises(RuntimeError, high_pass_filter, sig, sfreq, 0.6,
method='iir', iir_params=dict(order=8, output='sos'))
# bad ftype
assert_raises(RuntimeError, high_pass_filter, sig, sfreq, 0.6,
method='iir',
iir_params=dict(order=8, ftype='foo', output='sos'))
# missing gstop
assert_raises(RuntimeError, high_pass_filter, sig, sfreq, 0.6,
method='iir', iir_params=dict(gpass=0.5, output='sos'))
# can't pass iir_params if method='fft'
assert_raises(ValueError, high_pass_filter, sig, sfreq, 0.1,
method='fir', iir_params=dict(ftype='butter', order=2))
method='fft', iir_params=dict(ftype='butter', order=2,
output='sos'))
# method must be string
assert_raises(TypeError, high_pass_filter, sig, sfreq, 0.1,
method=1)
@@ -86,12 +124,26 @@ def test_iir_stability():
assert_raises(ValueError, high_pass_filter, sig, sfreq, 0.1,
method='blah')
# bad iir_params
assert_raises(TypeError, high_pass_filter, sig, sfreq, 0.1,
method='iir', iir_params='blah')
assert_raises(ValueError, high_pass_filter, sig, sfreq, 0.1,
method='fir', iir_params='blah')
method='fft', iir_params=dict())
# should pass because dafault trans_bandwidth is not relevant
high_pass_filter(sig, 250, 0.5, method='iir',
iir_params=dict(ftype='butter', order=6))
iir_params = dict(ftype='butter', order=2, output='sos')
x_sos = high_pass_filter(sig, 250, 0.5, method='iir',
iir_params=iir_params)
iir_params_sos = construct_iir_filter(iir_params, f_pass=0.5, sfreq=250,
btype='highpass')
x_sos_2 = high_pass_filter(sig, 250, 0.5, method='iir',
iir_params=iir_params_sos)
assert_allclose(x_sos[100:-100], x_sos_2[100:-100])
x_ba = high_pass_filter(sig, 250, 0.5, method='iir',
iir_params=dict(ftype='butter', order=2,
output='ba'))
# Note that this will fail for higher orders (e.g., 6) showing the
# hopefully decreased numerical error of SOS
assert_allclose(x_sos[100:-100], x_ba[100:-100])
def test_notch_filters():
@@ -178,6 +230,7 @@ def test_resample_stim_channel():
assert_equal(new_data.shape[1], new_data_len)
@requires_version('scipy', '0.16')
@slow_test
def test_filters():
"""Test low-, band-, high-pass, and band-stop filters plus resampling
@@ -261,15 +314,22 @@ def test_filters():
assert_array_almost_equal(np.zeros_like(sig_gone), sig_gone, 2)
# let's construct some filters
iir_params = dict(ftype='cheby1', gpass=1, gstop=20)
iir_params = dict(ftype='cheby1', gpass=1, gstop=20, output='ba')
iir_params = construct_iir_filter(iir_params, 40, 80, 1000, 'low')
# this should be a third order filter
assert_true(iir_params['a'].size - 1 == 3)
assert_true(iir_params['b'].size - 1 == 3)
iir_params = dict(ftype='butter', order=4)
assert_equal(iir_params['a'].size - 1, 3)
assert_equal(iir_params['b'].size - 1, 3)
iir_params = dict(ftype='butter', order=4, output='ba')
iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low')
assert_true(iir_params['a'].size - 1 == 4)
assert_true(iir_params['b'].size - 1 == 4)
assert_equal(iir_params['a'].size - 1, 4)
assert_equal(iir_params['b'].size - 1, 4)
iir_params = dict(ftype='cheby1', gpass=1, gstop=20, output='sos')
iir_params = construct_iir_filter(iir_params, 40, 80, 1000, 'low')
# this should be a third order filter, which requires 2 SOS ((2, 6))
assert_equal(iir_params['sos'].shape, (2, 6))
iir_params = dict(ftype='butter', order=4, output='sos')
iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low')
assert_equal(iir_params['sos'].shape, (2, 6))
# check that picks work for 3d array with one channel and picks=[0]
a = rng.randn(5 * sfreq, 5 * sfreq)
@@ -299,6 +359,12 @@ def test_filters():
# the firwin2 function gets us this close
assert_allclose(x, x_filt, rtol=1e-3, atol=1e-3)
# degenerate conditions
assert_raises(ValueError, filter_data, x, sfreq, 1, 10) # not 2D
assert_raises(ValueError, filter_data, x[np.newaxis], -sfreq, 1, 10)
assert_raises(ValueError, filter_data, x[np.newaxis], sfreq, 1,
sfreq * 0.75)
def test_cuda():
"""Test CUDA-based filtering
+6 -3
View File
@@ -14,9 +14,10 @@ from mne.utils import run_tests_if_main
from mne.fixes import (_in1d, _tril_indices, _copysign, _unravel_index,
_Counter, _unique, _bincount, _digitize,
_sparse_block_diag, _matrix_rank, _meshgrid,
_isclose)
from mne.fixes import _firwin2 as mne_firwin2
from mne.fixes import _filtfilt as mne_filtfilt
_isclose,
_firwin2 as mne_firwin2,
_filtfilt as mne_filtfilt,
_sosfiltfilt as mne_sosfiltfilt)
rng = np.random.RandomState(0)
@@ -148,6 +149,8 @@ def test_filtfilt():
# Filter with an impulse
y = mne_filtfilt([1, 0], [1, 0], x, padlen=0)
assert_array_equal(x, y)
y = mne_sosfiltfilt(np.array([[1., 0., 0., 1, 0., 0.]]), x, padlen=0)
assert_array_equal(x, y)
def test_sparse_block_diag():
+135 -5
View File
@@ -2,6 +2,7 @@
r"""
.. _tut_background_filtering:
===================================
Background information on filtering
===================================
@@ -13,8 +14,10 @@ in MNE-Python on actual data, see the :ref:`tut_artifacts_filter` tutorial.
.. contents::
.. _filtering-basics:
Filtering basics
----------------
================
Let's get some of the basic math down. In the frequency domain, digital
filters have a transfer function that is given by:
@@ -76,7 +79,9 @@ In general, the sharper something is in frequency, the broader it is in time,
and vice-versa. This is a fundamental time-frequency tradeoff, and it will
show up below.
Here we will focus first on FIR filters, which are the default filters used by
===========
First we will focus first on FIR filters, which are the default filters used by
MNE-Python.
"""
@@ -166,13 +171,24 @@ h = np.sinc(2 * f_p * t) / (4 * np.pi)
def plot_filter(h, title, freq, gain, show=True):
if h.ndim == 2: # second-order sections
sos = h
n = mne.filter.estimate_ringing_samples(sos)
h = np.zeros(n)
h[0] = 1
h = signal.sosfilt(sos, h)
H = np.ones(512, np.complex128)
for section in sos:
f, this_H = signal.freqz(section[:3], section[3:])
H *= this_H
else:
f, H = signal.freqz(h)
fig, axs = plt.subplots(2)
t = np.arange(len(h)) / sfreq
axs[0].plot(t, h, color=blue)
axs[0].set(xlim=t[[0, -1]], xlabel='Time (sec)',
ylabel='Amplitude h(n)', title=title)
box_off(axs[0])
f, H = signal.freqz(h)
f *= sfreq / (2 * np.pi)
axs[1].semilogx(f, 10 * np.log10((H * H.conj()).real), color=blue,
linewidth=2, zorder=4)
@@ -366,8 +382,121 @@ mne.viz.tight_layout()
plt.show()
###############################################################################
# IIR filters
# ===========
# MNE-Python also offers IIR filtering functionality that is based on the
# methods from :mod:`scipy.signal`. Specifically, we use the general-purpose
# functions :func:`scipy.signal.iirfilter` and :func:`scipy.signal.iirdesign`,
# which provide unified interfaces to IIR filter design.
#
# Designing IIR filters
# ---------------------
# Let's continue with our design of a 40 Hz low-pass filter, and look at
# some trade-offs of different IIR filters.
#
# Often the default IIR filter is a `Butterworth filter`_, which is designed
# to have a *maximally flat pass-band*. Let's look at a few orders of filter,
# i.e., a few different number of coefficients used and therefore steepness
# of the filter:
sos = signal.iirfilter(2, f_p / nyq, btype='low', ftype='butter', output='sos')
plot_filter(sos, 'Butterworth order=2', freq, gain)
# Eventually this will just be from scipy signal.sosfiltfilt, but 0.18 is
# not widely adopted yet (as of June 2016), so we use our wrapper...
sosfiltfilt = mne.fixes.get_sosfiltfilt()
x_shallow = sosfiltfilt(sos, x)
###############################################################################
# The falloff of this filter is not very steep.
#
# .. warning:: For brevity, we do not show the phase of these filters here.
# In the FIR case, we can design linear-phase filters, and
# compensate for the delay if necessary. This cannot be done
# with IIR filters, and as the filter order increases, the
# phase distortion near and in the transition band worsens.
# However, if acausal (forward-backward) filtering can be used,
# e.g. with :func:`scipy.signal.filtfilt`, these phase issues
# can be mitigated.
#
# .. note:: Here we have made use of second-order sections (SOS)
# by using :func:`scipy.signal.sosfilt` and, under the
# hood, :func:`scipy.signal.zpk2sos` when passing the
# ``output='sos'`` keyword argument to
# :func:`scipy.signal.iirfilter`. The filter definitions
# given in :ref:`filtering-basics` use the polynomial
# numerator/denominator (sometimes called "tf") form ``(b, a)``,
# which are theoretically equivalent to the SOS form used here.
# In practice, however, the SOS form can give much better results
# due to issues with numerical precision (see
# :func:`scipy.signal.sosfilt` for an example), so SOS should be
# used when possible to do IIR filtering.
#
# Let's increase the order, and note that now we have better attenuation,
# with a longer impulse response:
sos = signal.iirfilter(8, f_p / nyq, btype='low', ftype='butter', output='sos')
plot_filter(sos, 'Butterworth order=8', freq, gain)
x_steep = sosfiltfilt(sos, x)
###############################################################################
# There are other types of IIR filters that we can use. For a complete list,
# check out the documentation for :func:`scipy.signal.iirdesign`. Let's
# try a Chebychev (type I) filter, which trades off ripple in the pass-band
# to get better attenuation in the stop-band:
sos = signal.iirfilter(8, f_p / nyq, btype='low', ftype='cheby1', output='sos',
rp=1) # dB of acceptable pass-band ripple
plot_filter(sos, 'Chebychev-1 order=8, ripple=1 dB', freq, gain)
###############################################################################
# And if we can live with even more ripple, we can get it slightly steeper,
# but the impulse response begins to ring substantially longer (note the
# different x-axis scale):
sos = signal.iirfilter(8, f_p / nyq, btype='low', ftype='cheby1', output='sos',
rp=6)
plot_filter(sos, 'Chebychev-1 order=8, ripple=6 dB', freq, gain)
###############################################################################
# Applying IIR filters
# --------------------
# Now let's look at how our shallow and steep Butterworth IIR filters
# perform on our morlet signal from before:
axs = plt.subplots(2)[1]
yticks = np.arange(4) / -30.
yticklabels = ['Original', 'Noisy', 'Butterworth-2', 'Butterworth-8']
plot_signal(x_orig, offset=yticks[0])
plot_signal(x, offset=yticks[1])
plot_signal(x_shallow, offset=yticks[2])
plot_signal(x_steep, offset=yticks[3])
axs[0].set(xlim=tlim, title='Lowpass=%d Hz' % f_p, xticks=tticks,
ylim=[-0.125, 0.025], yticks=yticks, yticklabels=yticklabels,)
for text in axs[0].get_yticklabels():
text.set(rotation=45, size=8)
axs[1].set(xlim=flim, ylim=ylim, xlabel='Frequency (Hz)',
ylabel='Magnitude (dB)')
box_off(axs[0])
box_off(axs[1])
mne.viz.tight_layout()
plt.show()
###############################################################################
# Filtering in MNE-Python
# =======================
# Most often, filtering in MNE-Python is done at the :class:`mne.io.Raw` level,
# and thus :func:`mne.io.Raw.filter` is used. This function under the hood
# (among other things) calls :func:`mne.filter.filter_data` to actually
# filter the data.
#
# :func:`mne.filter.filter_data` by default applies a FIR filter designed using
# :func:`scipy.signal.firwin2`. For more information on how to use the
# MNE-Python filtering functions with real data, consult the preprocessing
# tutorial on :ref:`tut_artifacts_filter`.
#
# Summary
# -------
# =======
# When filtering, there are always tradeoffs that should be considered.
# One important tradeoff is between time-domain characteristics (like ringing)
# and frequency-domain attenuation characteristics (like effective transition
@@ -379,7 +508,7 @@ plt.show()
###############################################################################
# References
# ----------
# ==========
# .. [1] Parks TW, Burrus CS. Digital Filter Design.
# New York: Wiley-Interscience, 1987.
#
@@ -394,3 +523,4 @@ plt.show()
# .. _scipy firwin2: http://scipy.github.io/devdocs/generated/scipy.signal.firwin2.html # noqa
# .. _matlab fir2: http://www.mathworks.com/help/signal/ref/fir2.html
# .. _matlab firls: http://www.mathworks.com/help/signal/ref/firls.html
# .. _Butterworth filter: https://en.wikipedia.org/wiki/Butterworth_filter