# BPLabs/matrix_test_thread.py (485 lines, 17 KiB, Python)

import numpy as np
import matplotlib.pyplot as plt
from threading import Thread, Event
import io
import dill
import base64
import os
import random
from scipy.optimize import minimize
import csv
from shutil import copy2
from pysndfile import sndio
from matrix_test.filesystem import globDir
import sounddevice as sd
import pdb
def set_trace():
    '''
    Enter the pdb debugger after raising the 'werkzeug' and 'engineio'
    loggers to ERROR level.
    '''
    import logging
    for logger_name in ('werkzeug', 'engineio'):
        logging.getLogger(logger_name).setLevel(logging.ERROR)
    pdb.set_trace()
def find_nearest_idx(array, value):
    '''
    Return the index of the element of `array` closest to `value`.

    `array` is converted with np.asarray; the index is that of the smallest
    absolute difference (first occurrence on ties, as given by argmin).

    Adapted from: https://stackoverflow.com/questions/2566412/find-nearest-value-in-numpy-array
    '''
    distances = np.abs(np.asarray(array) - value)
    return distances.argmin()
def abline(slope, intercept):
    """
    Draw the dashed line y = intercept + slope * x on the current axes,
    spanning the axes' current x-limits.
    """
    x_span = np.array(plt.gca().get_xlim())
    plt.plot(x_span, intercept + slope * x_span, '--')
class MatTestThread(Thread):
    '''
    Thread for running server side matrix test operations.

    The thread presents sentences in noise, collects the participant's
    response over socketio, adapts the SNR after every trial and, once the
    track has finished, fits a logistic psychometric function to the
    word scores. Plots are sent to the GUI as base64 PNG data URLs.
    '''

    def __init__(self, listN=3, sessionFilepath=None,
                 noiseFilepath="./matrix_test/stimulus/wav/noise/noise.wav",
                 listFolder="./matrix_test/stimulus/wav/sentence-lists/",
                 socketio=None, participant=None):
        '''
        listN: number of sentence lists to load for the adaptive track
        sessionFilepath: if given, restore a previously saved state from
            this file instead of loading stimulus and noise
        noiseFilepath: path of the masking noise audio file
        listFolder: directory containing one sub-folder per sentence list
        socketio: socketio server used to talk to the GUI (its on_event and
            emit methods are used, so it must not be None in practice)
        participant: optional participant object; when given, its
            data_paths['adaptive_matrix_data'] folder is used for backups
        '''
        super(MatTestThread, self).__init__()
        self.participant = participant
        self.newResp = False
        self.foundSRT = False
        self.finalised = False
        self.pageLoaded = False
        self.clinPageLoaded = False
        self.partPageLoaded = False

        self.socketio = socketio
        # Attach messages from gui to class methods
        self.socketio.on_event('submit_mat_response', self.submitMatResponse, namespace='/main')
        self.socketio.on_event('mat_page_loaded', self.setPageLoaded, namespace='/main')
        self.socketio.on_event('save_file_dialog_resp', self.manualSave, namespace='/main')
        self.socketio.on_event('load_file_dialog_resp', self.loadStateSocketHandle, namespace='/main')
        self.socketio.on_event('repeat_stimulus', self.playStimulusSocketHandle, namespace='/main')
        self.socketio.on_event('finish_test', self.finishTestEarly, namespace='/main')
        self.socketio.on_event('finalise_results', self.finaliseResults, namespace='/main')

        self.listN = listN
        self.loadedLists = []
        self.lists = []
        self.listsRMS = []
        self.listsString = []
        self.noise = None
        self.noise_rms = None
        self.fs = None
        self.response = ['', '', '', '', '']
        self.responses = []
        self.presentedWords = []
        self.nCorrect = None
        self.snr = 0.0
        self.direction = 0
        # Defined up front so saveState() works before the first stimulus
        # has been generated
        self.y = None
        self.currentWords = None

        # Record SNRs presented with each trial of the adaptive track
        self.snrTrack = np.empty(30)
        self.wordsCorrect = np.full((30, 5), False, dtype=bool)
        self.snrTrack[:] = np.nan
        self.snrTrack[0] = 0.0
        # Count number of presented trials
        self.trialN = 1

        self.currentList = None
        self.availableSentenceInds = []
        self.usedLists = []

        # Adaptive track parameters
        self.slope = 0.15
        self.i = 0

        # Plotting parameters: in-memory buffer reused for every rendered plot
        self.img = io.BytesIO()

        self._stopevent = Event()

        if self.participant:
            folder = self.participant.data_paths['adaptive_matrix_data']
            self.backupFilepath = os.path.join(folder, 'mat_state.pkl')
        else:
            self.backupFilepath = './mat_state.pkl'

        # If loading session from file, load session variables from the file
        if sessionFilepath:
            self.loadState(sessionFilepath)
        else:
            # Preload audio at start of the test
            self.loadStimulus(listFolder, n=self.listN)
            self.loadNoise(noiseFilepath)

    def testLoop(self):
        '''
        Main loop for iteratively finding the SRT
        '''
        self.waitForPageLoad()
        while not self.foundSRT and not self._stopevent.is_set():
            self.plotSNR()
            self.playStimulus()
            self.waitForResponse()
            # Test was finished early from the GUI: skip scoring this trial
            if self.foundSRT:
                break
            # Thread is being stopped: leave without saving or fitting
            if self._stopevent.is_set():
                return
            self.calcSNR()
            self.saveState(out=self.backupFilepath)
        self.saveState(out=self.backupFilepath)
        if not self._stopevent.is_set():
            self.unsetPageLoaded()
            self.socketio.emit('processing-complete', {'data': ''}, namespace='/main')
            self.waitForPageLoad()
            self.plotSNR()
            self.fitLogistic()
            self.waitForFinalise()

    def finishTestEarly(self):
        '''
        Socket handler: end the adaptive track before the lists run out
        '''
        self.foundSRT = True

    def join(self, timeout=None):
        """ Stop the thread. """
        self._stopevent.set()
        Thread.join(self, timeout)

    def waitForResponse(self):
        '''
        Block until a response arrives, the test is finished early or the
        thread is stopped
        '''
        while not self.newResp and not self._stopevent.is_set() and not self.foundSRT:
            self._stopevent.wait(0.2)
        return

    def waitForPageLoad(self):
        '''
        Poll the GUI with "check-loaded" until both pages have reported
        loaded or the thread is stopped
        '''
        while not self.pageLoaded and not self._stopevent.is_set():
            self.socketio.emit("check-loaded", namespace='/main')
            self._stopevent.wait(0.5)
        return

    def waitForFinalise(self):
        '''
        Block until finaliseResults has run or the thread is stopped
        '''
        while not self.finalised and not self._stopevent.is_set():
            self._stopevent.wait(0.2)
        return

    def finaliseResults(self):
        '''
        Socket handler: store the final results and copy the state backup.

        With a participant, the results are merged into the participant's
        'adaptive_matrix_data' entry and saved; otherwise they are dumped to
        ./Matrix_test_results.pkl. Sets self.finalised when done.
        '''
        toSave = ['snrTrack', 'trialN', 'wordsCorrect', 'presentedWords', 'responses']
        saveDict = {k: getattr(self, k) for k in toSave}
        if self.participant:
            self.participant['adaptive_matrix_data'].update(saveDict)
            self.participant.save("adaptive_matrix_data")
            backup_path = os.path.join(self.participant.data_paths['adaptive_matrix_data'],
                                       'finalised_backup.pkl')
            copy2(self.backupFilepath, backup_path)
        else:
            copy2(self.backupFilepath, './finalised_backup.pkl')
            with open('./Matrix_test_results.pkl', 'wb') as f:
                dill.dump(saveDict, f)
        self.finalised = True

    @staticmethod
    def logisticFunction(L, L_50, s_50):
        '''
        Calculate logistic function for SNRs L, 50% SRT point L_50, and slope
        s_50
        '''
        return 1./(1.+np.exp(4.*s_50*(L_50-L)))

    def logisticFuncLiklihood(self, args):
        '''
        Calculate the negative log liklihood for given L_50 and s_50
        parameters.

        This function is designed for use with the scipy minimize optimisation
        function to find the optimal L_50 and s_50 parameters.

        args: a tuple containing (L_50, s_50)
        self.wordsCorrect: an array of shape (N, 5) containing the
            correctness of responses to each of the 5 words for N trials
        self.trackSNR: the SNRs of the N trials, in presentation order

        Returns np.inf when every per-word liklihood is zero, so that the
        optimiser treats such parameters as the worst possible fit.
        '''
        L_50, s_50 = args
        ck = self.wordsCorrect[:self.trackSNR.shape[0]]
        p_lf = self.logisticFunction(self.trackSNR, L_50, s_50)
        # Reshape array for vectorized calculation of log liklihood
        p_lf = p_lf[:, np.newaxis].repeat(5, axis=1)
        # Calculate the liklihood of every word response
        res = (p_lf**ck)*((1.-p_lf)**(1.-ck))
        a = res.ravel()
        # NOTE(review): zero liklihoods are replaced by the largest one
        # (original behaviour, kept as is) to avoid log(0); this effectively
        # ignores those responses - confirm this is intended.
        a[a == 0] = a.max()
        try:
            with np.errstate(divide='raise'):
                return -np.sum(np.log(a))
        except FloatingPointError:
            # All liklihoods were zero: log(0) would be -inf
            return np.inf

    def _emitPlot(self, event):
        '''
        Render the current pyplot figure as a PNG and emit it to the GUI as
        a base64 data URL under the socketio event name `event`
        '''
        # Empty the shared buffer first so that a shorter image cannot leave
        # bytes of the previous plot behind it
        self.img.seek(0)
        self.img.truncate(0)
        # savefig has no figsize keyword (recent matplotlib versions raise
        # TypeError for it), so only the dpi is passed
        plt.savefig(self.img, format='png', dpi=300)
        self.img.seek(0)
        plot_url = base64.b64encode(self.img.getvalue()).decode()
        plot_url = "data:image/png;base64,{}".format(plot_url)
        self.socketio.emit(event, {'data': plot_url}, namespace="/main")

    def fitLogistic(self):
        '''
        Fit the logistic psychometric function to the presented trials, plot
        it with the SRT and slope marked, and emit "mat_mle_plot_ready".

        Side effects: self.wordsCorrect is truncated to the first trialN
        rows and converted to float; self.trackSNR is set to the first
        trialN entries of self.snrTrack.
        '''
        # NOTE(review): after finishTestEarly the last (unanswered) trial is
        # still inside [:trialN] and is scored as all words wrong - confirm.
        self.wordsCorrect = self.wordsCorrect[:self.trialN].astype(float)
        self.trackSNR = self.snrTrack[:self.trialN]
        res = minimize(self.logisticFuncLiklihood, np.array([np.mean(self.trackSNR), 1.0]))
        x = np.linspace(np.min(self.trackSNR)-5, np.max(self.trackSNR)+3, 3000)
        snr_50, s_50 = res.x
        # Fitted psychometric function in percent correct
        x_y = self.logisticFunction(x, snr_50, s_50)
        x_y *= 100.

        plt.clf()
        axes = plt.gca()
        psycLine, = axes.plot(x, x_y)
        plt.title("Predicted psychometric function")
        plt.xlabel("SNR (dB)")
        plt.ylabel("% Correct")
        # Red dashed cross-hair marking the 50% point (SRT)
        srtLine, = axes.plot([snr_50, snr_50], [-50, 50.], 'r--')
        axes.plot([-50., snr_50], [50., 50.], 'r--')
        plt.xlim(x.min(), x.max())
        plt.ylim(x_y.min(), x_y.max())
        plt.yticks(np.arange(5)*25.)
        # Tangent through the SRT point with the fitted slope (in %/dB)
        x_vals = np.array(axes.get_xlim())
        s_50 *= 100.
        b = 50. - s_50 * snr_50
        y_vals = s_50 * x_vals + b
        slopeLine, = axes.plot(x_vals, y_vals, '--')
        # Ticks every 2.5 dB, with the tick nearest the SRT moved onto it
        ticks = (np.arange((x.max()-x.min())/2.5)*2.5)+(2.5 * round(float(x.min())/2.5))
        ticks[find_nearest_idx(ticks, snr_50)] = snr_50
        labels = ["{:.2f}".format(t) for t in ticks]
        plt.xticks(ticks, labels)
        plt.legend((psycLine, srtLine, slopeLine),
                   ("Psychometric function",
                    "SRT={:.2f}dB".format(snr_50),
                    "Slope={:.2f}%/dB".format(s_50)))
        self._emitPlot("mat_mle_plot_ready")

    def plotSNR(self):
        '''
        Plot the adaptive track (SNR per trial, annotated with the number of
        correct words) and emit "mat_plot_ready"
        '''
        plt.clf()
        plt.plot(self.snrTrack, 'bo-')
        plt.ylim([20.0, -20.0])
        plt.xticks(np.arange(self.snrTrack.shape[0]))
        plt.xlim([-1, self.trialN])
        plt.xlabel("Trial N")
        plt.ylabel("SNR (dB)")
        plt.title("Adaptive track")
        for i in range(len(self.snrTrack[:self.trialN])):
            plt.annotate("{0}/{1}".format(
                np.sum(self.wordsCorrect[i]).astype(int),
                self.wordsCorrect[i].size),
                (i, self.snrTrack[i]),
                xytext=(0, 13),
                va="center",
                ha="center",
                textcoords='offset points'
            )
        self._emitPlot("mat_plot_ready")

    def _ensureTrackCapacity(self):
        '''
        Grow snrTrack and wordsCorrect so that index self.trialN is writable
        (needed when more trials are presented than the initial capacity)
        '''
        capacity = self.snrTrack.shape[0]
        if self.trialN < capacity:
            return
        # At least double the capacity to keep re-allocation rare
        extra = max(capacity, self.trialN + 1 - capacity)
        self.snrTrack = np.concatenate((self.snrTrack, np.full(extra, np.nan)))
        padding = np.zeros((extra, self.wordsCorrect.shape[1]),
                           dtype=self.wordsCorrect.dtype)
        self.wordsCorrect = np.concatenate((self.wordsCorrect, padding))

    def calcSNR(self):
        '''
        Score the current response, update the SNR for the next trial and
        advance the track.

        The step size shrinks by a factor of 1.41 at every reversal of the
        track direction. When the last sentence of the last list has been
        scored, self.foundSRT is set instead of preparing another trial.
        '''
        self.presentedWords.append(self.currentWords)
        self.responses.append(self.response)
        correct = np.array([x == y for x, y in zip(self.currentWords, self.response)])
        print("Current words: {}".format(self.currentWords))
        print("Response: {}".format(self.response))
        print("Correct: {}".format(correct))
        self.nCorrect = np.sum(correct)/correct.size

        prevSNR = self.snr
        self.snr -= (((1.5*1.41**-self.i)*(self.nCorrect - 0.5))/self.slope)
        currentDirection = np.sign(self.snr - prevSNR)
        if currentDirection != 0 and self.direction != currentDirection:
            # Only count a reversal once a direction has been established
            if self.direction != 0:
                self.i += 1
            self.direction = currentDirection

        self._ensureTrackCapacity()
        # If all sentences in the current list have been presented...
        if not self.availableSentenceInds:
            # Set subsequent list as the current list
            del self.lists[0]
            del self.listsRMS[0]
            del self.listsString[0]
            if not len(self.lists):
                self.foundSRT = True
                self.wordsCorrect[self.trialN-1] = correct
                return None
            self.availableSentenceInds = list(range(len(self.lists[0])))
            random.shuffle(self.availableSentenceInds)
        self.snrTrack[self.trialN] = self.snr
        self.wordsCorrect[self.trialN-1] = correct
        self.trialN += 1

    def playStimulusSocketHandle(self):
        '''
        Socket handler: replay the current stimulus
        '''
        self.playStimulus(replay=True)

    def playStimulus(self, replay=False):
        '''
        Play a stimulus, notifying the GUI before and after playback.

        replay: if True, play the previously generated stimulus again
            instead of generating a new trial
        '''
        self.newResp = False
        self.socketio.emit("mat_stim_playing", namespace="/main")
        if not replay:
            self.y = self.generateTrial(self.snr)
        # Play audio
        sd.play(self.y, self.fs, blocking=True)
        self.socketio.emit("mat_stim_done", namespace="/main")

    def loadStimulus(self, listDir, n=3, demo=False):
        '''
        Load the audio, RMS and words of n randomly chosen sentence lists.

        listDir: directory containing one sub-folder per sentence list; the
            "demo" folder and lists named in self.loadedLists are skipped
        n: number of lists to load
        demo: currently unused
        '''
        # Get folder names of all lists in the list directory
        lists = next(os.walk(listDir))[1]
        # Don't reload any lists that have already been loaded
        lists = [d for d in lists if d != "demo" and d not in self.loadedLists]
        # Randomly select n lists
        inds = list(range(len(lists)))
        random.shuffle(inds)
        # Pick first n shuffled lists
        inds = inds[:n]
        for ind in inds:
            # Get filepaths to the audiofiles and word csv file for the current
            # list
            listAudiofiles = globDir(os.path.join(listDir, lists[ind]), "*.wav")
            listCSV = globDir(os.path.join(listDir, lists[ind]), "*.csv")
            with open(listCSV[0]) as csv_file:
                csv_reader = csv.reader(csv_file)
                # Allocate empty lists to store audio samples, RMS and words of
                # each list sentence
                self.lists.append([])
                self.listsRMS.append([])
                self.listsString.append([])
                # Get data for each sentence
                for fp, words in zip(listAudiofiles, csv_reader):
                    # Read in audio file and calculate its RMS
                    x, self.fs, _ = sndio.read(fp)
                    x_rms = np.sqrt(np.mean(x**2))
                    self.lists[-1].append(x)
                    self.listsRMS[-1].append(x_rms)
                    self.listsString[-1].append(words)
            # Remember the list so that a later call does not load it again
            self.loadedLists.append(lists[ind])
        # Shuffle order of sentence presentation
        self.availableSentenceInds = list(range(len(self.lists[0])))
        random.shuffle(self.availableSentenceInds)

    def loadNoise(self, noiseFilepath):
        '''
        Read noise samples and calculate the RMS of the signal
        '''
        x, _, _ = sndio.read(noiseFilepath)
        self.noise = x
        self.noise_rms = np.sqrt(np.mean(self.noise**2))

    def unsetPageLoaded(self):
        '''
        Mark both GUI pages as not loaded
        '''
        self.pageLoaded = False
        self.partPageLoaded = False
        self.clinPageLoaded = False

    def setPageLoaded(self, msg):
        '''
        Socket handler: record that a GUI page has loaded.

        msg['data'] == "clinician" marks the clinician page; any other value
        marks the participant page. self.pageLoaded is True once both have
        reported.
        '''
        if msg['data'] == "clinician":
            self.clinPageLoaded = True
        else:
            self.partPageLoaded = True
        self.pageLoaded = self.clinPageLoaded and self.partPageLoaded

    def generateTrial(self, snr):
        '''
        Build the stimulus for the next sentence of the current list.

        snr: speech level relative to the noise, in dB

        Returns the noise segment (one second longer than the sentence, in
        samples) with the scaled speech added. Also sets self.currentWords.
        '''
        currentSentenceInd = self.availableSentenceInds.pop(0)
        # Convert desired SNR from dB to a linear amplitude factor
        snr_fs = 10**(snr/20)
        # Get speech data
        x = self.lists[0][currentSentenceInd]
        x_rms = self.listsRMS[0][currentSentenceInd]
        self.currentWords = self.listsString[0][currentSentenceInd]
        # Get a random segment of the noise
        noiseLen = x.size + self.fs
        start = random.randint(0, self.noise.size-noiseLen)
        end = start + noiseLen
        x_noise = self.noise[start:end]
        # Calculate RMS of noise
        noise_rms = np.sqrt(np.mean(x_noise**2))
        # Scale noise to match the RMS of the speech (the multiplication
        # creates a new array, so self.noise is left untouched)
        x_noise = x_noise*(x_rms/noise_rms)
        y = x_noise
        # Set speech to start 500ms after the noise, scaled to the desired SNR
        sigStart = round(self.fs/2.)
        y[sigStart:sigStart+x.size] += x*snr_fs
        return y

    def submitMatResponse(self, msg):
        '''
        Get and store participant response for current trial
        '''
        self.response = [x.upper() for x in msg['resp']]
        self.newResp = True

    def saveState(self, out="mat_state.pkl"):
        '''
        Dump the session variables needed to resume the test to the file
        `out`. Variables that have not been set yet are stored as None.
        '''
        toSave = ['listsRMS', 'y', 'currentList', 'slope', 'snr', 'snrTrack',
                  'direction', 'noise_rms', 'i', 'currentWords', 'usedLists',
                  'availableSentenceInds', 'trialN', 'listsString', 'noise',
                  'fs', 'nCorrect', 'loadedLists', 'lists', 'listN',
                  'wordsCorrect', 'responses', 'presentedWords']
        saveDict = {k: getattr(self, k, None) for k in toSave}
        with open(out, 'wb') as f:
            dill.dump(saveDict, f)

    def manualSave(self, msg):
        '''
        Socket handler: save the session state to the filepath in msg['data']
        '''
        filepath = msg['data']
        self.saveState(out=filepath)

    def loadStateSocketHandle(self, msg):
        '''
        Socket handler: load the session state from the filepath in
        msg['data']
        '''
        filepath = msg['data']
        self.loadState(filepath)

    def loadState(self, filepath):
        '''
        Restore session variables from a file written by saveState
        '''
        # NOTE(review): dill.load executes arbitrary code from the file, and
        # the path can come from a socket message - only load trusted files.
        with open(filepath, 'rb') as f:
            self.__dict__.update(dill.load(f))

    def run(self):
        '''
        This function is called when the thread starts
        '''
        return self.testLoop()