Fixed psychometric function plotting for behavioural test, trial 1 behavioural results saving bug still exists
This commit is contained in:
+33
-22
@@ -6,7 +6,7 @@ import dill
|
||||
import base64
|
||||
import os
|
||||
import random
|
||||
from scipy.optimize import minimize, least_squares
|
||||
from scipy.optimize import minimize, least_squares, curve_fit
|
||||
import csv
|
||||
from shutil import copy2
|
||||
import sys
|
||||
@@ -253,7 +253,9 @@ class MatTestThread(BaseThread):
|
||||
try:
|
||||
a = np.concatenate(res)
|
||||
a[np.logical_or(a == 0.0, np.isnan(a))] = np.finfo(float).eps
|
||||
out = -np.sum(np.log(a))
|
||||
out = -np.log(np.sum(np.log(a)))
|
||||
if np.isnan(out):
|
||||
out = 999999999.0
|
||||
except:
|
||||
set_trace()
|
||||
return out
|
||||
@@ -262,25 +264,32 @@ class MatTestThread(BaseThread):
|
||||
def fitLogistic(self):
|
||||
'''
|
||||
'''
|
||||
self.wordsCorrect = np.concatenate([x.getWordsCorrect() for x in self.adaptiveTracks])
|
||||
self.trackSNR = np.concatenate([x.getSNRTrack() for x in self.adaptiveTracks])
|
||||
self.wordsCorrect = np.concatenate([x.getWordsCorrect() for x in self.adaptiveTracks])[1:]
|
||||
self.trackSNR = np.concatenate([x.getSNRTrack() for x in self.adaptiveTracks])[:-1]
|
||||
inds = np.argsort(self.trackSNR)
|
||||
wcs = self.wordsCorrect[inds]
|
||||
stsnr = self.trackSNR[inds]
|
||||
|
||||
res = least_squares(
|
||||
self.logisticFuncLiklihood,
|
||||
np.array([np.median(self.trackSNR),1.0]),
|
||||
args=()
|
||||
)
|
||||
if not res.success:
|
||||
logger.error("Logistic function fitting failed. SRT and slope estimate results will be incorrect")
|
||||
breakpoint()
|
||||
#res = minimize(self.logisticFuncLiklihood, np.array([np.mean(self.trackSNR),1.0]))
|
||||
x = stsnr
|
||||
y = wcs.sum(axis=1)/5.
|
||||
popt, pcov = curve_fit(self.logisticFunction, x, y, p0=[np.median(self.trackSNR), 0.1], bounds=([-30.0, 0.0001], [30.0, 100.0]), method='dogbox')
|
||||
srt_50, s_50 = popt
|
||||
|
||||
# res = least_squares(
|
||||
# self.logisticFuncLiklihood,
|
||||
# np.array([np.median(self.trackSNR),0.01]),
|
||||
# args=()
|
||||
# )
|
||||
# srt_50, s_50 = res.x
|
||||
# if not res.success:
|
||||
# logger.error("Logistic function fitting failed. SRT and slope estimate results will be incorrect")
|
||||
# res = minimize(self.logisticFuncLiklihood, np.array([np.mean(self.trackSNR),1.0]))
|
||||
percent_correct = (np.sum(self.wordsCorrect, axis=1)/self.wordsCorrect.shape[1])*100.
|
||||
sortedSNRind = np.argsort(-self.trackSNR)
|
||||
sortedSNRind = np.argsort(self.trackSNR)
|
||||
sortedSNR = self.trackSNR[sortedSNRind]
|
||||
sortedPC = percent_correct[sortedSNRind]
|
||||
|
||||
x = np.linspace(np.min(sortedSNR)-5, np.max(sortedSNR)+3, 3000)
|
||||
srt_50, s_50 = res.x
|
||||
x_y = self.logisticFunction(x, srt_50, s_50)
|
||||
x_y *= 100.
|
||||
# np.savez('./plot.npz', x, x_y*100., sortedSNR, sortedPC)
|
||||
@@ -292,28 +301,29 @@ class MatTestThread(BaseThread):
|
||||
|
||||
#plt.plot(sortedSNR, sortedPC, "x")
|
||||
#sbnplot = sns.relplot(data=pd.DataFrame(x_y*100., x), kind="line")
|
||||
plt.clf()
|
||||
axes = plt.gca()
|
||||
srtLine, = axes.plot([srt_50,srt_50], [-50,50.], 'r--')
|
||||
axes.plot([-50.,srt_50], [50.,50.], 'r--')
|
||||
x_vals = np.array(axes.get_xlim())
|
||||
s_50 *= 100.
|
||||
b = 50. - s_50 * srt_50
|
||||
y_vals = s_50 * x_vals + b
|
||||
|
||||
|
||||
wc = self.wordsCorrect.sum(axis=1)*(100/5.)
|
||||
#wc = words_correct*(100)
|
||||
plt.clf()
|
||||
axes = plt.gca()
|
||||
points = plt.plot(self.trackSNR, wc, marker='x', color='r',
|
||||
points = plt.plot(sortedSNR, sortedPC, marker='x', color='r',
|
||||
linestyle='None')
|
||||
psycLine, = axes.plot(x, x_y)
|
||||
plt.title("Predicted psychometric function")
|
||||
plt.xlabel("SNR (dB)")
|
||||
plt.ylabel("% Correct")
|
||||
plt.xlim(x.min(), x.max())
|
||||
plt.ylim(x_y.min(), x_y.max())
|
||||
plt.ylim(x_y.min()-5, x_y.max()+5)
|
||||
plt.yticks(np.arange(5)*25.)
|
||||
x_vals = np.array(axes.get_xlim())
|
||||
y_point = self.logisticFunction(srt_50, srt_50, s_50)*100.
|
||||
s_50 *= 100
|
||||
c = y_point - s_50 * srt_50
|
||||
y_vals = s_50 * x_vals + c
|
||||
slopeLine, = axes.plot(x_vals, y_vals, '--')
|
||||
ticks = (np.arange((x.max()-x.min())/2.5)*2.5)+(2.5 * round(float(x.min())/2.5))
|
||||
ticks[find_nearest_idx(ticks, srt_50)] = srt_50
|
||||
@@ -327,6 +337,7 @@ class MatTestThread(BaseThread):
|
||||
plot_url = "data:image/png;base64,{}".format(plot_url)
|
||||
self.srt_50, self.s_50 = srt_50, s_50
|
||||
self.socketio.emit("mat_mle_plot_ready", {'data': plot_url}, namespace="/main")
|
||||
breakpoint()
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user