End of night commit
This commit is contained in:
+32
-14
@@ -6,7 +6,7 @@ import dill
|
||||
import base64
|
||||
import os
|
||||
import random
|
||||
from scipy.optimize import minimize
|
||||
from scipy.optimize import minimize, least_squares
|
||||
import csv
|
||||
from shutil import copy2
|
||||
import sys
|
||||
@@ -171,6 +171,7 @@ class MatTestThread(BaseThread):
|
||||
|
||||
while not self.finishTest and not self._stopevent.isSet() and len(self.availableSentenceInds) and len(self.trackOrder):
|
||||
# Plot SNR of current trial to the clinician screen
|
||||
plt.clf()
|
||||
for at in self.adaptiveTracks:
|
||||
at.plotSNR()
|
||||
self.renderSNRPlot()
|
||||
@@ -212,6 +213,7 @@ class MatTestThread(BaseThread):
|
||||
self.socketio.emit('processing-complete', {'data': ''}, namespace='/main')
|
||||
self.waitForPageLoad()
|
||||
# Plot SNR of current trial to the clinician screen
|
||||
plt.clf()
|
||||
for at in self.adaptiveTracks:
|
||||
at.plotSNR()
|
||||
self.renderSNRPlot()
|
||||
@@ -219,14 +221,13 @@ class MatTestThread(BaseThread):
|
||||
self.waitForFinalise()
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def logisticFunction(L, L_50, s_50):
|
||||
def logisticFunction(L, L_50, s_50, minima=0.0, maxima=1.0):
|
||||
'''
|
||||
Calculate logistic function for SNRs L, 50% SRT point L_50, and slope
|
||||
s_50
|
||||
'''
|
||||
return 1./(1.+np.exp(4.*s_50*(L_50-L)))
|
||||
return (minima+((maxima-minima))*(1./(1.+np.exp(4*s_50*(L_50-L)))))
|
||||
|
||||
|
||||
def logisticFuncLiklihood(self, args):
|
||||
@@ -251,7 +252,7 @@ class MatTestThread(BaseThread):
|
||||
with np.errstate(divide='raise'):
|
||||
try:
|
||||
a = np.concatenate(res)
|
||||
a[a == 0] = np.finfo(float).eps
|
||||
a[np.logical_or(a == 0.0, np.isnan(a))] = np.finfo(float).eps
|
||||
out = -np.sum(np.log(a))
|
||||
except:
|
||||
set_trace()
|
||||
@@ -263,11 +264,21 @@ class MatTestThread(BaseThread):
|
||||
'''
|
||||
self.wordsCorrect = np.concatenate([x.getWordsCorrect() for x in self.adaptiveTracks])
|
||||
self.trackSNR = np.concatenate([x.getSNRTrack() for x in self.adaptiveTracks])
|
||||
res = minimize(self.logisticFuncLiklihood, np.array([np.mean(self.trackSNR),1.0]))
|
||||
|
||||
res = least_squares(
|
||||
self.logisticFuncLiklihood,
|
||||
np.array([np.median(self.trackSNR),1.0]),
|
||||
args=()
|
||||
)
|
||||
if not res.success:
|
||||
logger.error("Logistic function fitting failed. SRT and slope estimate results will be incorrect")
|
||||
breakpoint()
|
||||
#res = minimize(self.logisticFuncLiklihood, np.array([np.mean(self.trackSNR),1.0]))
|
||||
percent_correct = (np.sum(self.wordsCorrect, axis=1)/self.wordsCorrect.shape[1])*100.
|
||||
sortedSNRind = np.argsort(-self.trackSNR)
|
||||
sortedSNR = self.trackSNR[sortedSNRind]
|
||||
sortedPC = percent_correct[sortedSNRind]
|
||||
|
||||
x = np.linspace(np.min(sortedSNR)-5, np.max(sortedSNR)+3, 3000)
|
||||
srt_50, s_50 = res.x
|
||||
x_y = self.logisticFunction(x, srt_50, s_50)
|
||||
@@ -281,21 +292,28 @@ class MatTestThread(BaseThread):
|
||||
|
||||
#plt.plot(sortedSNR, sortedPC, "x")
|
||||
#sbnplot = sns.relplot(data=pd.DataFrame(x_y*100., x), kind="line")
|
||||
plt.clf()
|
||||
axes = plt.gca()
|
||||
psycLine, = axes.plot(x, x_y)
|
||||
plt.title("Predicted psychometric function")
|
||||
plt.xlabel("SNR (dB)")
|
||||
plt.ylabel("% Correct")
|
||||
srtLine, = axes.plot([srt_50,srt_50], [-50,50.], 'r--')
|
||||
axes.plot([-50.,srt_50], [50.,50.], 'r--')
|
||||
plt.xlim(x.min(), x.max())
|
||||
plt.ylim(x_y.min(), x_y.max())
|
||||
plt.yticks(np.arange(5)*25.)
|
||||
x_vals = np.array(axes.get_xlim())
|
||||
s_50 *= 100.
|
||||
b = 50. - s_50 * srt_50
|
||||
y_vals = s_50 * x_vals + b
|
||||
|
||||
|
||||
wc = self.wordsCorrect.sum(axis=1)*(100/5.)
|
||||
#wc = words_correct*(100)
|
||||
plt.clf()
|
||||
axes = plt.gca()
|
||||
points = plt.plot(self.trackSNR, wc, marker='x', color='r',
|
||||
linestyle='None')
|
||||
psycLine, = axes.plot(x, x_y)
|
||||
plt.title("Predicted psychometric function")
|
||||
plt.xlabel("SNR (dB)")
|
||||
plt.ylabel("% Correct")
|
||||
plt.xlim(x.min(), x.max())
|
||||
plt.ylim(x_y.min(), x_y.max())
|
||||
plt.yticks(np.arange(5)*25.)
|
||||
slopeLine, = axes.plot(x_vals, y_vals, '--')
|
||||
ticks = (np.arange((x.max()-x.min())/2.5)*2.5)+(2.5 * round(float(x.min())/2.5))
|
||||
ticks[find_nearest_idx(ticks, srt_50)] = srt_50
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
<select class="form-control" name="mode" id="mode">
|
||||
<option>Familiarisation</option>
|
||||
<option>Testing</option>
|
||||
<option>Testingssjks</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-group d-flex justify-content-center">
|
||||
|
||||
Reference in New Issue
Block a user