Implemented most of multiple adaptive track code
This commit is contained in:
@@ -14,6 +14,13 @@ server_lock = Lock()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
log = logging.getLogger('werkzeug')
|
||||
log.setLevel(logging.ERROR)
|
||||
log = logging.getLogger('engineio')
|
||||
log.setLevel(logging.ERROR)
|
||||
log = logging.getLogger('socketio')
|
||||
log.setLevel(logging.ERROR)
|
||||
|
||||
|
||||
def url_ok(url, port):
|
||||
# Use httplib on Python 2
|
||||
|
||||
+74
-41
@@ -60,18 +60,21 @@ class MatTestThread(BaseThread):
|
||||
'''
|
||||
Thread for running server side matrix test operations
|
||||
'''
|
||||
def __init__(self, listN=3, sessionFilepath=None,
|
||||
def __init__(self, listN=5, sessionFilepath=None,
|
||||
noiseFilepath="./matrix_test/behavioural_stim/stimulus/wav/noise/noise_norm.wav",
|
||||
noiseRMSFilepath="./matrix_test/behavioural_stim/stimulus/rms/noise_rms.npy",
|
||||
listFolder="./matrix_test/behavioural_stim/stimulus/wav/sentence-lists/",
|
||||
red_coef="./calibration/out/reduction_coefficients/mat_red_coef.npy",
|
||||
cal_coef="./calibration/out/calibration_coefficients/mat_cal_coef.npy",
|
||||
track_targets=[0.2, 0.5, 0.8],
|
||||
socketio=None, participant=None):
|
||||
|
||||
self.listDir = listFolder
|
||||
|
||||
self.adaptiveTracks = [AdaptiveTrack(x, red_coef, cal_coef) for x in track_targets]
|
||||
self.trackOrder = []
|
||||
|
||||
|
||||
self.adaptiveTrack = AdaptiveTrack(red_coef, cal_coef)
|
||||
self.listN = int(listN)
|
||||
self.loadedLists = []
|
||||
self.lists = []
|
||||
@@ -89,7 +92,7 @@ class MatTestThread(BaseThread):
|
||||
self.srt_50 = None
|
||||
self.s_50 = None
|
||||
|
||||
self.wordsCorrect = np.full((30, 5), False, dtype=bool)
|
||||
self.wordsCorrect = np.full((90, 5), False, dtype=bool)
|
||||
self.trialN = 0
|
||||
|
||||
self.availableSentenceInds = []
|
||||
@@ -104,11 +107,11 @@ class MatTestThread(BaseThread):
|
||||
socketio=socketio,
|
||||
participant=participant)
|
||||
|
||||
self.toSave = ['presentedWords', 'responses', 'srt_50', 's_50',
|
||||
self.toSave = ['trackOrder', 'presentedWords', 'responses', 'srt_50', 's_50',
|
||||
'wordsCorrect', 'trialN', 'test_name', 'backupFilepath',
|
||||
'currentWords', 'nCorrect', 'availableSentenceInds',
|
||||
'lists', 'listsRMS', 'listsString']
|
||||
self.toFinalise = ['snrTrack', 'trialN', 'wordsCorrect',
|
||||
self.toFinalise = ['snrTrack', 'trackOrder', 'trialN', 'wordsCorrect',
|
||||
'presentedWords', 'responses', 'srt_50', 's_50']
|
||||
|
||||
# Attach messages from gui to class methods
|
||||
@@ -118,12 +121,25 @@ class MatTestThread(BaseThread):
|
||||
|
||||
self.loadNoise(noiseFilepath, noiseRMSFilepath)
|
||||
|
||||
self.dev_mode = True
|
||||
self.dev_mode = False
|
||||
|
||||
|
||||
def displayInstructions(self):
|
||||
self.socketio.emit('display_instructions', namespace='/main')
|
||||
|
||||
|
||||
def renderSNRPlot(self):
|
||||
dpi = 300
|
||||
|
||||
maxTrialN = np.max([x.trialN for x in self.adaptiveTracks])
|
||||
plt.xlim([-1, maxTrialN])
|
||||
plt.savefig(self.img, format='png', figsize=(800/dpi, 800/dpi), dpi=dpi)
|
||||
self.img.seek(0)
|
||||
plot_url = base64.b64encode(self.img.getvalue()).decode()
|
||||
plot_url = "data:image/png;base64,{}".format(plot_url)
|
||||
self.socketio.emit("mat_plot_ready", {'data': plot_url}, namespace="/main")
|
||||
plt.clf()
|
||||
|
||||
def testLoop(self):
|
||||
'''
|
||||
Main loop for iteratively finding the SRT
|
||||
@@ -134,12 +150,15 @@ class MatTestThread(BaseThread):
|
||||
self.waitForPartReady()
|
||||
while not self.finishTest and not self._stopevent.isSet() and len(self.availableSentenceInds):
|
||||
# Plot SNR of current trial to the clinician screen
|
||||
plot_url = self.adaptiveTrack.plotSNR(self.wordsCorrect)
|
||||
self.socketio.emit("mat_plot_ready", {'data': plot_url}, namespace="/main")
|
||||
for at in self.adaptiveTracks:
|
||||
at.plotSNR()
|
||||
self.renderSNRPlot()
|
||||
# Get the index of the sentence to be played for the current trial
|
||||
currentSentenceInd = self.availableSentenceInds.pop(0)
|
||||
# Generate trial audio
|
||||
self.y = self.adaptiveTrack.generateTrial(
|
||||
# Get the index of the current adaptive track to use
|
||||
self.adTrInd = self.trackOrder.pop(0)
|
||||
# Generate trial audioself.wordsCorrect
|
||||
self.y = self.adaptiveTracks[self.adTrInd].generateTrial(
|
||||
self.lists[0][currentSentenceInd],
|
||||
self.listsRMS[0][currentSentenceInd]
|
||||
)
|
||||
@@ -153,7 +172,7 @@ class MatTestThread(BaseThread):
|
||||
break
|
||||
if self._stopevent.isSet():
|
||||
return
|
||||
self.adaptiveTrack.calcSNR(self.nCorrect)
|
||||
self.adaptiveTracks[self.adTrInd].calcSNR(self.nCorrect)
|
||||
self.checkSentencesAvailable()
|
||||
self.saveState(out=self.backupFilepath)
|
||||
self.trialN += 1
|
||||
@@ -163,8 +182,9 @@ class MatTestThread(BaseThread):
|
||||
self.socketio.emit('processing-complete', {'data': ''}, namespace='/main')
|
||||
self.waitForPageLoad()
|
||||
# Plot SNR of current trial to the clinician screen
|
||||
plot_url = self.adaptiveTrack.plotSNR(self.wordsCorrect)
|
||||
self.socketio.emit("mat_plot_ready", {'data': plot_url}, namespace="/main")
|
||||
for at in self.adaptiveTracks:
|
||||
at.plotSNR()
|
||||
self.renderSNRPlot()
|
||||
self.fitLogistic()
|
||||
self.waitForFinalise()
|
||||
|
||||
@@ -226,8 +246,9 @@ class MatTestThread(BaseThread):
|
||||
def fitLogistic(self):
|
||||
'''
|
||||
'''
|
||||
self.wordsCorrect = self.wordsCorrect[:self.trialN].astype(float)
|
||||
self.trackSNR = self.adaptiveTrack.snrTrack[:self.trialN]
|
||||
self.wordsCorrect = [x.getWordsCorrect() for x in self.adaptiveTracks]
|
||||
self.trackSNR = [x.getSNRTrack() for x in self.adaptiveTracks]
|
||||
set_trace()
|
||||
res = minimize(self.logisticFuncLiklihood, np.array([np.mean(self.trackSNR),1.0]))
|
||||
percent_correct = (np.sum(self.wordsCorrect, axis=1)/self.wordsCorrect.shape[1])*100.
|
||||
sortedSNRind = np.argsort(-self.trackSNR)
|
||||
@@ -286,7 +307,6 @@ class MatTestThread(BaseThread):
|
||||
del self.listsString[0]
|
||||
if not len(self.lists):
|
||||
self.finishTest = True
|
||||
self.wordsCorrect[self.trialN-1] = correct
|
||||
return None
|
||||
self.availableSentenceInds = list(range(len(self.lists[0])))
|
||||
random.shuffle(self.availableSentenceInds)
|
||||
@@ -329,6 +349,12 @@ class MatTestThread(BaseThread):
|
||||
self.listsRMS[-1].append(x_rms)
|
||||
self.listsString[-1].append(words)
|
||||
|
||||
# Number of trials to split between adaptive tracks
|
||||
n = len(self.lists[0])*self.listN
|
||||
tn = len(self.adaptiveTracks)
|
||||
self.trackOrder = list(np.repeat(np.arange(tn), np.floor(n/tn)))
|
||||
random.shuffle(self.trackOrder)
|
||||
|
||||
# Shuffle order of sentence presentation
|
||||
self.availableSentenceInds = list(range(len(self.lists[0])))
|
||||
random.shuffle(self.availableSentenceInds)
|
||||
@@ -338,10 +364,10 @@ class MatTestThread(BaseThread):
|
||||
'''
|
||||
Read noise samples and calculate the RMS of the signal
|
||||
'''
|
||||
self.adaptiveTrack.setNoise(
|
||||
PySndfile(noiseFilepath, 'r'),
|
||||
np.load(noiseRMSFilepath)
|
||||
)
|
||||
noise = PySndfile(noiseFilepath, 'r')
|
||||
noise_rms = np.load(noiseRMSFilepath)
|
||||
for ind, _ in enumerate(self.adaptiveTracks):
|
||||
self.adaptiveTracks[ind].setNoise(noise, noise_rms)
|
||||
|
||||
|
||||
|
||||
@@ -356,6 +382,7 @@ class MatTestThread(BaseThread):
|
||||
correct = np.array([x == y for x, y in zip(self.currentWords, self.response)])
|
||||
self.nCorrect = np.sum(correct)/correct.size
|
||||
self.wordsCorrect[self.trialN] = correct
|
||||
self.adaptiveTracks[self.adTrInd].wordsCorrect[self.adaptiveTracks[self.adTrInd].trialN] = correct
|
||||
self.presentedWords.append(self.currentWords)
|
||||
|
||||
|
||||
@@ -365,7 +392,9 @@ class MatTestThread(BaseThread):
|
||||
'''
|
||||
with open(filepath, 'rb') as f:
|
||||
state = dill.load(f)
|
||||
self.adaptiveTrack.loadFromDict(state.pop('adaptiveTrack'))
|
||||
aTrack = state.pop('adaptiveTracks')
|
||||
for ind, aTrackDict in enumerate(aTrack):
|
||||
self.adaptiveTracks[ind].loadFromDict(aTrackDict)
|
||||
self.__dict__.update(state)
|
||||
|
||||
def saveState(self, out=None):
|
||||
@@ -375,30 +404,35 @@ class MatTestThread(BaseThread):
|
||||
if not out:
|
||||
out = "{}_state.pkl".format(self.test_name)
|
||||
saveDict = {k:self.__dict__[k] for k in self.toSave}
|
||||
atDict = self.adaptiveTrack.createSaveDict()
|
||||
saveDict['adaptiveTrack'] = atDict
|
||||
|
||||
saveDict['adaptiveTracks'] = []
|
||||
for ind, _ in enumerate(self.adaptiveTracks):
|
||||
atDict = self.adaptiveTracks[ind].createSaveDict()
|
||||
saveDict['adaptiveTracks'].append(atDict)
|
||||
with open(out, 'wb') as f:
|
||||
dill.dump(saveDict, f)
|
||||
|
||||
class AdaptiveTrack():
|
||||
'''
|
||||
'''
|
||||
def __init__(self, red_coef, cal_coef):
|
||||
def __init__(self, target, red_coef, cal_coef):
|
||||
'''
|
||||
'''
|
||||
self.snr = 0.0
|
||||
self.direction = 0
|
||||
# Record SNRs presented with each trial of the adaptive track
|
||||
self.snrTrack = np.empty(30)
|
||||
self.snrTrack = np.empty(90)
|
||||
self.snrTrack[:] = np.nan
|
||||
self.snrTrack[0] = 0.0
|
||||
# Count number of presented trials
|
||||
self.trialN = 1
|
||||
self.trialN = 0
|
||||
self.reduction_coef = np.load(red_coef)*np.load(cal_coef)
|
||||
self.wordsCorrect = np.full((90, 5), False, dtype=bool)
|
||||
|
||||
# Adaptive track parameters
|
||||
self.slope = 0.15
|
||||
self.i = 0
|
||||
self.target = target
|
||||
|
||||
self.fs = 44100
|
||||
|
||||
@@ -411,6 +445,12 @@ class AdaptiveTrack():
|
||||
self.noise = noise
|
||||
self.noise_rms = noise_rms
|
||||
|
||||
def getWordsCorrect(self):
|
||||
return self.wordsCorrect[:self.trialN].astype(float)
|
||||
|
||||
def getSNRTrack(self):
|
||||
return self.snrTrack[:self.trialN]
|
||||
|
||||
def processResponse(resp):
|
||||
pass
|
||||
|
||||
@@ -439,7 +479,7 @@ class AdaptiveTrack():
|
||||
'''
|
||||
'''
|
||||
prevSNR = self.snr
|
||||
self.snr -= (((1.5*1.41**-self.i)*(nCorrect - 0.5))/self.slope)
|
||||
self.snr -= (((1.5*1.41**-self.i)*(nCorrect - self.target))/self.slope)
|
||||
currentDirection = np.sign(np.diff([prevSNR, self.snr]))
|
||||
if self.direction != currentDirection:
|
||||
if currentDirection == 0:
|
||||
@@ -450,39 +490,32 @@ class AdaptiveTrack():
|
||||
self.direction = currentDirection
|
||||
self.snrTrack[self.trialN] = self.snr
|
||||
self.trialN += 1
|
||||
print("Track {0} trial N: {1}".format(self.target, self.trialN))
|
||||
|
||||
|
||||
def plotSNR(self, wordsCorrect):
|
||||
def plotSNR(self):
|
||||
'''
|
||||
'''
|
||||
plt.clf()
|
||||
plt.plot(self.snrTrack, 'bo-')
|
||||
plt.plot(self.snrTrack, 'o-')
|
||||
plt.ylim([20.0, -20.0])
|
||||
plt.xticks(np.arange(30))
|
||||
plt.xlim([-1, self.trialN])
|
||||
plt.xticks(np.arange(90))
|
||||
plt.xlabel("Trial N")
|
||||
plt.ylabel("SNR (dB)")
|
||||
plt.title("Adaptive track")
|
||||
for i, txt in enumerate(self.snrTrack[:self.trialN]):
|
||||
plt.annotate("{0}/{1}".format(
|
||||
np.sum(wordsCorrect[i]).astype(int),
|
||||
wordsCorrect[i].size),
|
||||
np.sum(self.wordsCorrect[i]).astype(int),
|
||||
self.wordsCorrect[i].size),
|
||||
(i, self.snrTrack[i]),
|
||||
xytext=(0, 13),
|
||||
va="center",
|
||||
ha="center",
|
||||
textcoords='offset points'
|
||||
)
|
||||
dpi = 300
|
||||
plt.savefig(self.img, format='png', figsize=(800/dpi, 800/dpi), dpi=dpi)
|
||||
self.img.seek(0)
|
||||
plot_url = base64.b64encode(self.img.getvalue()).decode()
|
||||
plot_url = "data:image/png;base64,{}".format(plot_url)
|
||||
return plot_url
|
||||
|
||||
def createSaveDict(self):
|
||||
toSave = ['snr', 'direction', 'snrTrack', 'trialN', 'reduction_coef', 'slope',
|
||||
'i', 'fs']
|
||||
'i', 'fs', 'wordsCorrect']
|
||||
saveDict = {k:self.__dict__[k] for k in toSave}
|
||||
return saveDict
|
||||
|
||||
|
||||
Reference in New Issue
Block a user