"""Util function to baseline correct data."""

# Authors: Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
#
# License: BSD (3-clause)

import numpy as np
|
|
|
|
from .utils import logger, verbose
|
|
|
|
|
|
def _log_rescale(baseline, mode='mean'):
|
|
"""Log the rescaling method."""
|
|
if baseline is not None:
|
|
valid_modes = ('logratio', 'ratio', 'zscore', 'mean', 'percent',
|
|
'zlogratio')
|
|
if mode not in valid_modes:
|
|
raise Exception('mode should be any of : %s' % (valid_modes, ))
|
|
msg = 'Applying baseline correction (mode: %s)' % mode
|
|
else:
|
|
msg = 'No baseline correction applied'
|
|
return msg
|
|
|
|
|
|
@verbose
def rescale(data, times, baseline, mode='mean', copy=True, picks=None,
            verbose=None):
    """Rescale (baseline correct) data.

    Parameters
    ----------
    data : array
        It can be of any shape. The only constraint is that the last
        dimension should be time.
    times : 1D array
        Time instants in seconds. Any sequence accepted by
        ``np.asarray`` can be passed.
    baseline : tuple or list of length 2, or None
        The time interval to apply rescaling / baseline correction.
        If None do not apply it. If baseline is ``(bmin, bmax)``
        the interval is between ``bmin`` (s) and ``bmax`` (s).
        If ``bmin is None`` the beginning of the data is used
        and if ``bmax is None`` then ``bmax`` is set to the end of the
        interval. If baseline is ``(None, None)`` the entire time
        interval is used. If baseline is None, no correction is applied.
    mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio'
        Perform baseline correction by

        - subtracting the mean of baseline values ('mean')
        - dividing by the mean of baseline values ('ratio')
        - dividing by the mean of baseline values and taking the log
          ('logratio')
        - subtracting the mean of baseline values followed by dividing by
          the mean of baseline values ('percent')
        - subtracting the mean of baseline values and dividing by the
          standard deviation of baseline values ('zscore')
        - dividing by the mean of baseline values, taking the log, and
          dividing by the standard deviation of log baseline values
          ('zlogratio')

    copy : bool
        Whether to return a new instance or modify in place.
    picks : list of int | None
        Data to process along the axis=-2 (None, default, processes all).
    verbose : bool, str, int, or None
        If not None, override default verbose level (see :func:`mne.verbose`
        and :ref:`Logging documentation <tut_logging>` for more).

    Returns
    -------
    data_scaled: array
        Array of same shape as data after rescaling.

    Raises
    ------
    ValueError
        If ``bmin`` is larger than every value in ``times``, if ``bmax``
        is smaller than every value in ``times``, or if the resulting
        baseline slice is empty.

    Notes
    -----
    When a baseline is given, ``mode`` is checked by ``_log_rescale``
    before any value of ``data`` is modified. All arithmetic is done
    in place on ``data`` (or on its copy when ``copy`` is True).
    """
    data = data.copy() if copy else data
    msg = _log_rescale(baseline, mode)
    logger.info(msg)
    if baseline is None or data.shape[-1] == 0:
        return data

    # Plain lists/tuples do not support element-wise comparison with a
    # scalar, so make sure we are working with an array.
    times = np.asarray(times)

    # Translate the (bmin, bmax) interval in seconds into a half-open
    # sample slice [imin, imax).
    bmin, bmax = baseline
    if bmin is None:
        imin = 0
    else:
        imin = np.where(times >= bmin)[0]
        if len(imin) == 0:
            raise ValueError('bmin is too large (%s), it exceeds the largest '
                             'time value' % (bmin,))
        imin = int(imin[0])
    if bmax is None:
        imax = len(times)
    else:
        imax = np.where(times <= bmax)[0]
        if len(imax) == 0:
            raise ValueError('bmax is too small (%s), it is smaller than the '
                             'smallest time value' % (bmax,))
        # +1 so that the last sample satisfying ``t <= bmax`` is included
        imax = int(imax[-1]) + 1
    if imin >= imax:
        raise ValueError('Bad rescaling slice (%s:%s) from time values %s, %s'
                         % (imin, imax, bmin, bmax))

    # technically this is inefficient when `picks` is given, but assuming
    # that we generally pick most channels for rescaling, it's not so bad
    mean = np.mean(data[..., imin:imax], axis=-1, keepdims=True)

    # Each ``fun`` rescales ``d`` in place given the baseline mean ``m``.
    # The standard deviations below are taken over the baseline window of
    # the already-transformed values.
    if mode == 'mean':
        def fun(d, m):
            d -= m
    elif mode == 'ratio':
        def fun(d, m):
            d /= m
    elif mode == 'logratio':
        def fun(d, m):
            d /= m
            np.log10(d, out=d)
    elif mode == 'percent':
        def fun(d, m):
            d -= m
            d /= m
    elif mode == 'zscore':
        def fun(d, m):
            d -= m
            d /= np.std(d[..., imin:imax], axis=-1, keepdims=True)
    elif mode == 'zlogratio':
        def fun(d, m):
            d /= m
            np.log10(d, out=d)
            d /= np.std(d[..., imin:imax], axis=-1, keepdims=True)

    if picks is None:
        fun(data, mean)
    else:
        # Basic indexing yields views, so the in-place operations in
        # ``fun`` write straight into ``data``.
        for pi in picks:
            fun(data[..., pi, :], mean[..., pi, :])
    return data
|