Implementing interleaved adaptive track for behavioral measure

This commit is contained in:
2019-01-24 17:29:35 +00:00
parent cea97424ab
commit 0166d97d9c
2 changed files with 166 additions and 97 deletions
+161 -95
View File
@@ -71,55 +71,39 @@ class MatTestThread(BaseThread):
self.listDir = listFolder
self.reduction_coef = np.load(red_coef)*np.load(cal_coef)
self.adaptiveTrack = AdaptiveTrack(red_coef, cal_coef)
self.listN = int(listN)
self.loadedLists = []
self.lists = []
self.listsRMS = []
self.listsString = []
self.presentedWords = []
self.noise = None
self.noise_rms = None
self.fs = None
self.nCorrect = None
self.response = ['','','','','']
self.responses = []
self.presentedWords = []
self.nCorrect = None
self.snr = 0.0
self.direction = 0
# Record SNRs presented with each trial of the adaptive track
self.snrTrack = np.empty(30)
self.wordsCorrect = np.full((30, 5), False, dtype=bool)
self.snrTrack[:] = np.nan
self.snrTrack[0] = 0.0
# Count number of presented trials
self.trialN = 1
self.currentWords = []
self.srt_50 = None
self.s_50 = None
self.currentList = None
self.wordsCorrect = np.full((30, 5), False, dtype=bool)
self.trialN = 0
self.availableSentenceInds = []
self.usedLists = []
# Adaptive track parameters
self.slope = 0.15
self.i = 0
# Plotting parameters
self.img = io.BytesIO()
self.img.seek(0)
self.img.truncate(0)
super(MatTestThread, self).__init__('mat_test',
sessionFilepath=sessionFilepath,
socketio=socketio,
participant=participant)
self.toSave = ['listsRMS', 'y', 'currentList', 'slope', 'snr', 'snrTrack',
'direction', 'noise_rms', 'i', 'currentWords', 'usedLists',
'availableSentenceInds', 'trialN', 'listsString', 'fs',
'nCorrect', 'loadedLists', 'lists', 'listN', 'wordsCorrect',
'responses', 'presentedWords', 'srt_50', 's_50']
self.toSave = ['presentedWords', 'responses', 'srt_50', 's_50',
'wordsCorrect', 'trialN', 'test_name', 'backupFilepath',
'currentWords', 'nCorrect', 'availableSentenceInds',
'lists', 'listsRMS', 'listsString']
self.toFinalise = ['snrTrack', 'trialN', 'wordsCorrect',
'presentedWords', 'responses', 'srt_50', 's_50']
@@ -130,6 +114,8 @@ class MatTestThread(BaseThread):
self.loadNoise(noiseFilepath, noiseRMSFilepath)
self.dev_mode = True
def displayInstructions(self):
self.socketio.emit('display_instructions', namespace='/main')
@@ -143,8 +129,19 @@ class MatTestThread(BaseThread):
self.displayInstructions()
self.waitForPartReady()
while not self.finishTest and not self._stopevent.isSet() and len(self.availableSentenceInds):
self.plotSNR()
self.y = self.generateTrial(self.snr)
# Plot SNR of current trial to the clinician screen
plot_url = self.adaptiveTrack.plotSNR(self.wordsCorrect)
self.socketio.emit("mat_plot_ready", {'data': plot_url}, namespace="/main")
# Get the index of the sentence to be played for the current trial
currentSentenceInd = self.availableSentenceInds.pop(0)
# Generate trial audio
self.y = self.adaptiveTrack.generateTrial(
self.lists[0][currentSentenceInd],
self.listsRMS[0][currentSentenceInd]
)
# Define words presented in the current trial
self.currentWords = self.listsString[0][currentSentenceInd]
self.playStimulus(self.y, self.fs)
self.waitForResponse()
self.checkSentencesAvailable()
@@ -152,14 +149,18 @@ class MatTestThread(BaseThread):
break
if self._stopevent.isSet():
return
self.calcSNR()
self.adaptiveTrack.calcSNR(self.nCorrect)
self.checkSentencesAvailable()
self.saveState(out=self.backupFilepath)
self.trialN += 1
self.saveState(out=self.backupFilepath)
if not self._stopevent.isSet():
self.unsetPageLoaded()
self.socketio.emit('processing-complete', {'data': ''}, namespace='/main')
self.waitForPageLoad()
self.plotSNR()
# Plot SNR of current trial to the clinician screen
plot_url = self.adaptiveTrack.plotSNR()
self.socketio.emit("mat_plot_ready", {'data': plot_url}, namespace="/main")
self.fitLogistic()
self.waitForFinalise()
@@ -271,36 +272,8 @@ class MatTestThread(BaseThread):
self.socketio.emit("mat_mle_plot_ready", {'data': plot_url}, namespace="/main")
def plotSNR(self):
'''
'''
plt.clf()
plt.plot(self.snrTrack, 'bo-')
plt.ylim([20.0, -20.0])
plt.xticks(np.arange(30))
plt.xlim([-1, self.trialN])
plt.xlabel("Trial N")
plt.ylabel("SNR (dB)")
plt.title("Adaptive track")
for i, txt in enumerate(self.snrTrack[:self.trialN]):
plt.annotate("{0}/{1}".format(
np.sum(self.wordsCorrect[i]).astype(int),
self.wordsCorrect[i].size),
(i, self.snrTrack[i]),
xytext=(0, 13),
va="center",
ha="center",
textcoords='offset points'
)
dpi = 300
plt.savefig(self.img, format='png', figsize=(800/dpi, 800/dpi), dpi=dpi)
self.img.seek(0)
plot_url = base64.b64encode(self.img.getvalue()).decode()
plot_url = "data:image/png;base64,{}".format(plot_url)
self.socketio.emit("mat_plot_ready", {'data': plot_url}, namespace="/main")
def checkSentencesAvailable(self):
correct = np.array([x == y for x, y in zip(self.currentWords, self.response)])
# If all sentences in the current list have been presented...
if not self.availableSentenceInds:
# Set subsequent list as the current list
@@ -314,30 +287,6 @@ class MatTestThread(BaseThread):
self.availableSentenceInds = list(range(len(self.lists[0])))
random.shuffle(self.availableSentenceInds)
def calcSNR(self):
'''
'''
self.presentedWords.append(self.currentWords)
self.responses.append(self.response)
correct = np.array([x == y for x, y in zip(self.currentWords, self.response)])
print("Current words: {}".format(self.currentWords))
print("Response: {}".format(self.response))
print("Correct: {}".format(correct))
self.nCorrect = np.sum(correct)/correct.size
prevSNR = self.snr
self.snr -= (((1.5*1.41**-self.i)*(self.nCorrect - 0.5))/self.slope)
currentDirection = np.sign(np.diff([prevSNR, self.snr]))
if self.direction != currentDirection:
if currentDirection == 0:
pass
else:
if self.direction != 0:
self.i += 1
self.direction = currentDirection
self.checkSentencesAvailable()
self.snrTrack[self.trialN] = self.snr
self.wordsCorrect[self.trialN-1] = correct
self.trialN += 1
def playStimulusSocketHandle(self):
self.playStimulus(self.y, self.fs)
@@ -385,18 +334,87 @@ class MatTestThread(BaseThread):
'''
Read noise samples and calculate the RMS of the signal
'''
self.noise = PySndfile(noiseFilepath, 'r')
self.noise_rms = np.load(noiseRMSFilepath)
self.adaptiveTrack.setNoise(
PySndfile(noiseFilepath, 'r'),
np.load(noiseRMSFilepath)
)
def generateTrial(self, snr):
currentSentenceInd = self.availableSentenceInds.pop(0)
def submitMatResponse(self, msg):
'''
Get and store participant response for current trial
'''
self.response = [x.upper() for x in msg['resp']]
self.responses.append(self.response)
self.newResp = True
correct = np.array([x == y for x, y in zip(self.currentWords, self.response)])
self.nCorrect = np.sum(correct)/correct.size
self.wordsCorrect[self.trialN] = correct
self.presentedWords.append(self.currentWords)
def loadState(self, filepath):
'''
Restore thread state from a saved session filepath
'''
with open(filepath, 'rb') as f:
state = dill.load(f)
self.adaptiveTrack.loadFromDict(state.pop('adaptiveTrack'))
self.__dict__.update(state)
def saveState(self, out=None):
'''
Save the state of the thread to a pickle file
'''
if not out:
out = "{}_state.pkl".format(self.test_name)
saveDict = {k:self.__dict__[k] for k in self.toSave}
atDict = self.adaptiveTrack.createSaveDict()
saveDict['adaptiveTrack'] = atDict
with open(out, 'wb') as f:
dill.dump(saveDict, f)
class AdaptiveTrack():
'''
'''
def __init__(self, red_coef, cal_coef):
'''
'''
self.snr = 0.0
self.direction = 0
# Record SNRs presented with each trial of the adaptive track
self.snrTrack = np.empty(30)
self.snrTrack[:] = np.nan
self.snrTrack[0] = 0.0
# Count number of presented trials
self.trialN = 1
self.reduction_coef = np.load(red_coef)*np.load(cal_coef)
# Adaptive track parameters
self.slope = 0.15
self.i = 0
self.fs = 44100
# Plotting parameters
self.img = io.BytesIO()
self.img.seek(0)
self.img.truncate(0)
def setNoise(self, noise, noise_rms):
self.noise = noise
self.noise_rms = noise_rms
def processResponse(resp):
pass
def generateTrial(self, x, x_rms):
# Convert desired SNR to dB FS
snr_fs = 10**(snr/20.)
snr_fs = 10**(self.snr/20.)
# Get speech data
x = self.lists[0][currentSentenceInd]
x_rms = self.listsRMS[0][currentSentenceInd]
self.currentWords = self.listsString[0][currentSentenceInd]
# Get noise data
noiseLen = x.size + self.fs*2.5
start = random.randint(0, self.noise.frames()-noiseLen)
@@ -413,9 +431,57 @@ class MatTestThread(BaseThread):
return y
def submitMatResponse(self, msg):
def calcSNR(self, nCorrect):
'''
Get and store participant response for current trial
'''
self.response = [x.upper() for x in msg['resp']]
self.newResp = True
prevSNR = self.snr
self.snr -= (((1.5*1.41**-self.i)*(nCorrect - 0.5))/self.slope)
currentDirection = np.sign(np.diff([prevSNR, self.snr]))
if self.direction != currentDirection:
if currentDirection == 0:
pass
else:
if self.direction != 0:
self.i += 1
self.direction = currentDirection
self.snrTrack[self.trialN] = self.snr
self.trialN += 1
def plotSNR(self, wordsCorrect):
'''
'''
plt.clf()
plt.plot(self.snrTrack, 'bo-')
plt.ylim([20.0, -20.0])
plt.xticks(np.arange(30))
plt.xlim([-1, self.trialN])
plt.xlabel("Trial N")
plt.ylabel("SNR (dB)")
plt.title("Adaptive track")
for i, txt in enumerate(self.snrTrack[:self.trialN]):
plt.annotate("{0}/{1}".format(
np.sum(wordsCorrect[i]).astype(int),
wordsCorrect[i].size),
(i, self.snrTrack[i]),
xytext=(0, 13),
va="center",
ha="center",
textcoords='offset points'
)
dpi = 300
plt.savefig(self.img, format='png', figsize=(800/dpi, 800/dpi), dpi=dpi)
self.img.seek(0)
plot_url = base64.b64encode(self.img.getvalue()).decode()
plot_url = "data:image/png;base64,{}".format(plot_url)
return plot_url
def createSaveDict(self):
toSave = ['snr', 'direction', 'snrTrack', 'trialN', 'reduction_coef', 'slope',
'i', 'fs']
saveDict = {k:self.__dict__[k] for k in toSave}
return saveDict
def loadFromDict(self, stateDict):
self.__dict__.update(stateDict)
+5 -2
View File
@@ -177,7 +177,7 @@ class BaseThread(Thread):
if not self.dev_mode:
play_wav(wav_file)
else:
play_wav('./test.wav')
play_wav('./da_stim/DA_170.wav')
self.socketio.emit("{}_stim_done".format(test_name), namespace="/main")
@@ -189,7 +189,10 @@ class BaseThread(Thread):
self.newResp = False
self.socketio.emit("stim_playing", namespace="/main")
# Play audio
sd.play(y, fs, blocking=True)
if not self.dev_mode:
sd.play(y, fs, blocking=True)
else:
self.play_wav('./da_stim/DA_170.wav', '')
self.socketio.emit("stim_done", namespace="/main")