Implementing interleaved adaptive track for behavioral measure
This commit is contained in:
+161
-95
@@ -71,55 +71,39 @@ class MatTestThread(BaseThread):
|
||||
self.listDir = listFolder
|
||||
|
||||
|
||||
self.reduction_coef = np.load(red_coef)*np.load(cal_coef)
|
||||
self.adaptiveTrack = AdaptiveTrack(red_coef, cal_coef)
|
||||
self.listN = int(listN)
|
||||
self.loadedLists = []
|
||||
self.lists = []
|
||||
self.listsRMS = []
|
||||
self.listsString = []
|
||||
self.presentedWords = []
|
||||
self.noise = None
|
||||
self.noise_rms = None
|
||||
self.fs = None
|
||||
self.nCorrect = None
|
||||
|
||||
self.response = ['','','','','']
|
||||
self.responses = []
|
||||
self.presentedWords = []
|
||||
self.nCorrect = None
|
||||
self.snr = 0.0
|
||||
self.direction = 0
|
||||
# Record SNRs presented with each trial of the adaptive track
|
||||
self.snrTrack = np.empty(30)
|
||||
self.wordsCorrect = np.full((30, 5), False, dtype=bool)
|
||||
self.snrTrack[:] = np.nan
|
||||
self.snrTrack[0] = 0.0
|
||||
# Count number of presented trials
|
||||
self.trialN = 1
|
||||
self.currentWords = []
|
||||
self.srt_50 = None
|
||||
self.s_50 = None
|
||||
|
||||
self.currentList = None
|
||||
self.wordsCorrect = np.full((30, 5), False, dtype=bool)
|
||||
self.trialN = 0
|
||||
|
||||
self.availableSentenceInds = []
|
||||
self.usedLists = []
|
||||
|
||||
# Adaptive track parameters
|
||||
self.slope = 0.15
|
||||
self.i = 0
|
||||
|
||||
# Plotting parameters
|
||||
self.img = io.BytesIO()
|
||||
self.img.seek(0)
|
||||
self.img.truncate(0)
|
||||
|
||||
super(MatTestThread, self).__init__('mat_test',
|
||||
sessionFilepath=sessionFilepath,
|
||||
socketio=socketio,
|
||||
participant=participant)
|
||||
|
||||
self.toSave = ['listsRMS', 'y', 'currentList', 'slope', 'snr', 'snrTrack',
|
||||
'direction', 'noise_rms', 'i', 'currentWords', 'usedLists',
|
||||
'availableSentenceInds', 'trialN', 'listsString', 'fs',
|
||||
'nCorrect', 'loadedLists', 'lists', 'listN', 'wordsCorrect',
|
||||
'responses', 'presentedWords', 'srt_50', 's_50']
|
||||
self.toSave = ['presentedWords', 'responses', 'srt_50', 's_50',
|
||||
'wordsCorrect', 'trialN', 'test_name', 'backupFilepath',
|
||||
'currentWords', 'nCorrect', 'availableSentenceInds',
|
||||
'lists', 'listsRMS', 'listsString']
|
||||
self.toFinalise = ['snrTrack', 'trialN', 'wordsCorrect',
|
||||
'presentedWords', 'responses', 'srt_50', 's_50']
|
||||
|
||||
@@ -130,6 +114,8 @@ class MatTestThread(BaseThread):
|
||||
|
||||
self.loadNoise(noiseFilepath, noiseRMSFilepath)
|
||||
|
||||
self.dev_mode = True
|
||||
|
||||
|
||||
def displayInstructions(self):
|
||||
self.socketio.emit('display_instructions', namespace='/main')
|
||||
@@ -143,8 +129,19 @@ class MatTestThread(BaseThread):
|
||||
self.displayInstructions()
|
||||
self.waitForPartReady()
|
||||
while not self.finishTest and not self._stopevent.isSet() and len(self.availableSentenceInds):
|
||||
self.plotSNR()
|
||||
self.y = self.generateTrial(self.snr)
|
||||
# Plot SNR of current trial to the clinician screen
|
||||
plot_url = self.adaptiveTrack.plotSNR(self.wordsCorrect)
|
||||
self.socketio.emit("mat_plot_ready", {'data': plot_url}, namespace="/main")
|
||||
# Get the index of the sentence to be played for the current trial
|
||||
currentSentenceInd = self.availableSentenceInds.pop(0)
|
||||
# Generate trial audio
|
||||
self.y = self.adaptiveTrack.generateTrial(
|
||||
self.lists[0][currentSentenceInd],
|
||||
self.listsRMS[0][currentSentenceInd]
|
||||
)
|
||||
# Define words presented in the current trial
|
||||
self.currentWords = self.listsString[0][currentSentenceInd]
|
||||
|
||||
self.playStimulus(self.y, self.fs)
|
||||
self.waitForResponse()
|
||||
self.checkSentencesAvailable()
|
||||
@@ -152,14 +149,18 @@ class MatTestThread(BaseThread):
|
||||
break
|
||||
if self._stopevent.isSet():
|
||||
return
|
||||
self.calcSNR()
|
||||
self.adaptiveTrack.calcSNR(self.nCorrect)
|
||||
self.checkSentencesAvailable()
|
||||
self.saveState(out=self.backupFilepath)
|
||||
self.trialN += 1
|
||||
self.saveState(out=self.backupFilepath)
|
||||
if not self._stopevent.isSet():
|
||||
self.unsetPageLoaded()
|
||||
self.socketio.emit('processing-complete', {'data': ''}, namespace='/main')
|
||||
self.waitForPageLoad()
|
||||
self.plotSNR()
|
||||
# Plot SNR of current trial to the clinician screen
|
||||
plot_url = self.adaptiveTrack.plotSNR()
|
||||
self.socketio.emit("mat_plot_ready", {'data': plot_url}, namespace="/main")
|
||||
self.fitLogistic()
|
||||
self.waitForFinalise()
|
||||
|
||||
@@ -271,36 +272,8 @@ class MatTestThread(BaseThread):
|
||||
self.socketio.emit("mat_mle_plot_ready", {'data': plot_url}, namespace="/main")
|
||||
|
||||
|
||||
def plotSNR(self):
|
||||
'''
|
||||
'''
|
||||
plt.clf()
|
||||
plt.plot(self.snrTrack, 'bo-')
|
||||
plt.ylim([20.0, -20.0])
|
||||
plt.xticks(np.arange(30))
|
||||
plt.xlim([-1, self.trialN])
|
||||
plt.xlabel("Trial N")
|
||||
plt.ylabel("SNR (dB)")
|
||||
plt.title("Adaptive track")
|
||||
for i, txt in enumerate(self.snrTrack[:self.trialN]):
|
||||
plt.annotate("{0}/{1}".format(
|
||||
np.sum(self.wordsCorrect[i]).astype(int),
|
||||
self.wordsCorrect[i].size),
|
||||
(i, self.snrTrack[i]),
|
||||
xytext=(0, 13),
|
||||
va="center",
|
||||
ha="center",
|
||||
textcoords='offset points'
|
||||
)
|
||||
dpi = 300
|
||||
plt.savefig(self.img, format='png', figsize=(800/dpi, 800/dpi), dpi=dpi)
|
||||
self.img.seek(0)
|
||||
plot_url = base64.b64encode(self.img.getvalue()).decode()
|
||||
plot_url = "data:image/png;base64,{}".format(plot_url)
|
||||
self.socketio.emit("mat_plot_ready", {'data': plot_url}, namespace="/main")
|
||||
|
||||
def checkSentencesAvailable(self):
|
||||
correct = np.array([x == y for x, y in zip(self.currentWords, self.response)])
|
||||
# If all sentences in the current list have been presented...
|
||||
if not self.availableSentenceInds:
|
||||
# Set subsequent list as the current list
|
||||
@@ -314,30 +287,6 @@ class MatTestThread(BaseThread):
|
||||
self.availableSentenceInds = list(range(len(self.lists[0])))
|
||||
random.shuffle(self.availableSentenceInds)
|
||||
|
||||
def calcSNR(self):
|
||||
'''
|
||||
'''
|
||||
self.presentedWords.append(self.currentWords)
|
||||
self.responses.append(self.response)
|
||||
correct = np.array([x == y for x, y in zip(self.currentWords, self.response)])
|
||||
print("Current words: {}".format(self.currentWords))
|
||||
print("Response: {}".format(self.response))
|
||||
print("Correct: {}".format(correct))
|
||||
self.nCorrect = np.sum(correct)/correct.size
|
||||
prevSNR = self.snr
|
||||
self.snr -= (((1.5*1.41**-self.i)*(self.nCorrect - 0.5))/self.slope)
|
||||
currentDirection = np.sign(np.diff([prevSNR, self.snr]))
|
||||
if self.direction != currentDirection:
|
||||
if currentDirection == 0:
|
||||
pass
|
||||
else:
|
||||
if self.direction != 0:
|
||||
self.i += 1
|
||||
self.direction = currentDirection
|
||||
self.checkSentencesAvailable()
|
||||
self.snrTrack[self.trialN] = self.snr
|
||||
self.wordsCorrect[self.trialN-1] = correct
|
||||
self.trialN += 1
|
||||
|
||||
def playStimulusSocketHandle(self):
|
||||
self.playStimulus(self.y, self.fs)
|
||||
@@ -385,18 +334,87 @@ class MatTestThread(BaseThread):
|
||||
'''
|
||||
Read noise samples and calculate the RMS of the signal
|
||||
'''
|
||||
self.noise = PySndfile(noiseFilepath, 'r')
|
||||
self.noise_rms = np.load(noiseRMSFilepath)
|
||||
self.adaptiveTrack.setNoise(
|
||||
PySndfile(noiseFilepath, 'r'),
|
||||
np.load(noiseRMSFilepath)
|
||||
)
|
||||
|
||||
|
||||
def generateTrial(self, snr):
|
||||
currentSentenceInd = self.availableSentenceInds.pop(0)
|
||||
|
||||
|
||||
def submitMatResponse(self, msg):
|
||||
'''
|
||||
Get and store participant response for current trial
|
||||
'''
|
||||
self.response = [x.upper() for x in msg['resp']]
|
||||
self.responses.append(self.response)
|
||||
self.newResp = True
|
||||
correct = np.array([x == y for x, y in zip(self.currentWords, self.response)])
|
||||
self.nCorrect = np.sum(correct)/correct.size
|
||||
self.wordsCorrect[self.trialN] = correct
|
||||
self.presentedWords.append(self.currentWords)
|
||||
|
||||
|
||||
def loadState(self, filepath):
|
||||
'''
|
||||
Restore thread state from a saved session filepath
|
||||
'''
|
||||
with open(filepath, 'rb') as f:
|
||||
state = dill.load(f)
|
||||
self.adaptiveTrack.loadFromDict(state.pop('adaptiveTrack'))
|
||||
self.__dict__.update(state)
|
||||
|
||||
def saveState(self, out=None):
|
||||
'''
|
||||
Save the state of the thread to a pickle file
|
||||
'''
|
||||
if not out:
|
||||
out = "{}_state.pkl".format(self.test_name)
|
||||
saveDict = {k:self.__dict__[k] for k in self.toSave}
|
||||
atDict = self.adaptiveTrack.createSaveDict()
|
||||
saveDict['adaptiveTrack'] = atDict
|
||||
with open(out, 'wb') as f:
|
||||
dill.dump(saveDict, f)
|
||||
|
||||
class AdaptiveTrack():
|
||||
'''
|
||||
'''
|
||||
def __init__(self, red_coef, cal_coef):
|
||||
'''
|
||||
'''
|
||||
self.snr = 0.0
|
||||
self.direction = 0
|
||||
# Record SNRs presented with each trial of the adaptive track
|
||||
self.snrTrack = np.empty(30)
|
||||
self.snrTrack[:] = np.nan
|
||||
self.snrTrack[0] = 0.0
|
||||
# Count number of presented trials
|
||||
self.trialN = 1
|
||||
self.reduction_coef = np.load(red_coef)*np.load(cal_coef)
|
||||
|
||||
# Adaptive track parameters
|
||||
self.slope = 0.15
|
||||
self.i = 0
|
||||
|
||||
self.fs = 44100
|
||||
|
||||
# Plotting parameters
|
||||
self.img = io.BytesIO()
|
||||
self.img.seek(0)
|
||||
self.img.truncate(0)
|
||||
|
||||
def setNoise(self, noise, noise_rms):
|
||||
self.noise = noise
|
||||
self.noise_rms = noise_rms
|
||||
|
||||
def processResponse(resp):
|
||||
pass
|
||||
|
||||
|
||||
def generateTrial(self, x, x_rms):
|
||||
# Convert desired SNR to dB FS
|
||||
snr_fs = 10**(snr/20.)
|
||||
snr_fs = 10**(self.snr/20.)
|
||||
# Get speech data
|
||||
x = self.lists[0][currentSentenceInd]
|
||||
x_rms = self.listsRMS[0][currentSentenceInd]
|
||||
self.currentWords = self.listsString[0][currentSentenceInd]
|
||||
# Get noise data
|
||||
noiseLen = x.size + self.fs*2.5
|
||||
start = random.randint(0, self.noise.frames()-noiseLen)
|
||||
@@ -413,9 +431,57 @@ class MatTestThread(BaseThread):
|
||||
return y
|
||||
|
||||
|
||||
def submitMatResponse(self, msg):
|
||||
def calcSNR(self, nCorrect):
|
||||
'''
|
||||
Get and store participant response for current trial
|
||||
'''
|
||||
self.response = [x.upper() for x in msg['resp']]
|
||||
self.newResp = True
|
||||
prevSNR = self.snr
|
||||
self.snr -= (((1.5*1.41**-self.i)*(nCorrect - 0.5))/self.slope)
|
||||
currentDirection = np.sign(np.diff([prevSNR, self.snr]))
|
||||
if self.direction != currentDirection:
|
||||
if currentDirection == 0:
|
||||
pass
|
||||
else:
|
||||
if self.direction != 0:
|
||||
self.i += 1
|
||||
self.direction = currentDirection
|
||||
self.snrTrack[self.trialN] = self.snr
|
||||
self.trialN += 1
|
||||
|
||||
|
||||
def plotSNR(self, wordsCorrect):
|
||||
'''
|
||||
'''
|
||||
plt.clf()
|
||||
plt.plot(self.snrTrack, 'bo-')
|
||||
plt.ylim([20.0, -20.0])
|
||||
plt.xticks(np.arange(30))
|
||||
plt.xlim([-1, self.trialN])
|
||||
plt.xlabel("Trial N")
|
||||
plt.ylabel("SNR (dB)")
|
||||
plt.title("Adaptive track")
|
||||
for i, txt in enumerate(self.snrTrack[:self.trialN]):
|
||||
plt.annotate("{0}/{1}".format(
|
||||
np.sum(wordsCorrect[i]).astype(int),
|
||||
wordsCorrect[i].size),
|
||||
(i, self.snrTrack[i]),
|
||||
xytext=(0, 13),
|
||||
va="center",
|
||||
ha="center",
|
||||
textcoords='offset points'
|
||||
)
|
||||
dpi = 300
|
||||
plt.savefig(self.img, format='png', figsize=(800/dpi, 800/dpi), dpi=dpi)
|
||||
self.img.seek(0)
|
||||
plot_url = base64.b64encode(self.img.getvalue()).decode()
|
||||
plot_url = "data:image/png;base64,{}".format(plot_url)
|
||||
return plot_url
|
||||
|
||||
def createSaveDict(self):
|
||||
toSave = ['snr', 'direction', 'snrTrack', 'trialN', 'reduction_coef', 'slope',
|
||||
'i', 'fs']
|
||||
saveDict = {k:self.__dict__[k] for k in toSave}
|
||||
return saveDict
|
||||
|
||||
|
||||
def loadFromDict(self, stateDict):
|
||||
self.__dict__.update(stateDict)
|
||||
|
||||
+5
-2
@@ -177,7 +177,7 @@ class BaseThread(Thread):
|
||||
if not self.dev_mode:
|
||||
play_wav(wav_file)
|
||||
else:
|
||||
play_wav('./test.wav')
|
||||
play_wav('./da_stim/DA_170.wav')
|
||||
|
||||
self.socketio.emit("{}_stim_done".format(test_name), namespace="/main")
|
||||
|
||||
@@ -189,7 +189,10 @@ class BaseThread(Thread):
|
||||
self.newResp = False
|
||||
self.socketio.emit("stim_playing", namespace="/main")
|
||||
# Play audio
|
||||
sd.play(y, fs, blocking=True)
|
||||
if not self.dev_mode:
|
||||
sd.play(y, fs, blocking=True)
|
||||
else:
|
||||
self.play_wav('./da_stim/DA_170.wav', '')
|
||||
self.socketio.emit("stim_done", namespace="/main")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user