Files
BPLabs/gen_participants.py
T

280 lines
9.7 KiB
Python
Executable File

#!/usr/bin/env python3
from pathops import dir_must_exist
import os
import dill
import numpy as np
import pdb
import json
from natsort import natsorted
import random
random.seed(42)
np.random.seed(42)
import itertools
import copy
import logging
from loggerops import create_logger, log_newline
import shutil
import os
from pathlib import Path
from datetime import datetime
import re
from copy import deepcopy
logger = logging.getLogger(__name__)
nowtime = datetime.now()
from config import server, socketio, participants
def set_trace():
import logging
log = logging.getLogger('werkzeug')
log.setLevel(logging.ERROR)
log = logging.getLogger('engineio')
log.setLevel(logging.ERROR)
pdb.set_trace()
def find_participants(folder='./participant_data/'):
'''
Returns a tuple of (participant number, participant filepath) for every
participant folder found in directory provided
'''
part_folder = [os.path.join(folder, o) for o in os.listdir(folder)
if os.path.isdir(os.path.join(folder,o))]
for path in part_folder:
part_key = os.path.basename(path)
participants[part_key] = deepcopy(Participant(participant_dir=path))
participants[part_key].load('info')
participants[part_key].load('parameters')
return participants
def gen_participant_num(participants, N = 100):
# generate array of numbers that haven't been taken between 0-100
# if list is empty increment until list isnt empty
# Choose a number
taken_nums = []
for part_key in participants.keys():
participant = participants[part_key]
taken_nums.append(int(participant['info']['number']))
inds = np.arange(N)+1
taken_inds = np.in1d(inds, taken_nums)
inds = inds[~taken_inds]
return inds
class Participant:
def __init__(self, participant_dir=None, number=None, age=None, gender=None, handedness=None, general_notes=None, parameters={}, gen_time=datetime.now()):
'''
'''
dir_must_exist(participant_dir)
self.participant_dir = participant_dir
self.data_paths = {}
self.generate_folder_hierachy()
self.parameters = parameters
self.gen_time = gen_time
self.data = {
"info": {
"number": number,
"age": age,
"gender": gender,
"handedness": handedness,
"general_notes": general_notes
},
"mat_test": {
"notes": ''
},
"eeg_story_train": {
"notes": ''
},
"eeg_mat_train": {
"notes": ''
},
"eeg_test": {
"notes": ''
},
"tone_test": {
"notes": ''
},
"click_test": {
"notes": ''
},
"pta": {
"notes": ''
}
}
self.data['parameters'] = parameters
def generate_folder_hierachy(self):
'''
'''
sub_dirs = ["mat_test", "tone_test", "pta", "click_test", "info",
"eeg_story_train", "eeg_mat_train", "eeg_test",
"eeg_test/stimulus", "parameters"]
for dir_name in sub_dirs:
dn = os.path.join(*dir_name.split('/'))
path = os.path.join(self.participant_dir, dn)
dir_must_exist(path)
self.data_paths[dir_name] = path
def __setitem__(self, key, item):
self.data[key] = item
def __getitem__(self, key):
return self.data[key]
def set_info(self, info):
self.data['info'] = info
def save(self, data_key):
'''
'''
directory = self.data_paths[data_key]
with open(os.path.join(directory, '{}.pkl'.format(data_key)), 'wb') as f:
dill.dump(self.data[data_key], f)
def load(self, data_key):
'''
'''
folder = os.path.join(self.participant_dir, data_key)
# print(f"Participant {self.data['info']['number']}: {folder}")
with open(os.path.join(folder, "{}.pkl".format(data_key)), 'rb') as f:
self.data[data_key].update(dill.load(f))
def roll_independant(A, r):
rows, column_indices = np.ogrid[:A.shape[0], :A.shape[1]]
# Use always a negative shift, so that column_indices are valid.
# (could also use module operation)
r[r < 0] += A.shape[1]
column_indices = column_indices - r[:,np.newaxis]
result = A[rows, column_indices]
return result
def main():
'''
'''
logger.warning("***REMEMBER THIS SCRIPTS WILL NOT OVERWRITE ANY EXISTING PARTICIPANT DATA. PLEASE DELETE THIS MANUALLY IF NEEDED!***")
participants = find_participants()
with open('./test_params.json') as json_file:
general_params = json.load(json_file)
# Generate all permutations of tests for couterbalancing
tests = general_params['tests']
cb_tests = list(itertools.permutations(tests))
tone_freqs = general_params['tone_freqs']
cb_tone_freqs = list(itertools.permutations(tone_freqs))
# Make sure that the number of participants is a multiple of the number of
# counterbalanced tests
cb_lcm = np.lcm.reduce([len(cb_tests), len(cb_tone_freqs), 4])
n_participants = cb_lcm * 3
part_nums = gen_participant_num(participants, N=n_participants)
n_decoder_repeats = general_params['decoder_test_SNR_repeats']
# Get all decoder test stimuli
listDir = "./matrix_test/short_concat_stim/out"
stim_dirs = natsorted([x for x in os.listdir(listDir) if os.path.isdir(os.path.join(listDir, x))])
for i in part_nums:
participant_params = {}
# Set the order of tests to be presented to the current participant,
# using previous counterbalancing above
participant_params['tests'] = list(cb_tests[(i-1) % len(cb_tests)])
# Randomy shuffle order of stimuli to be presented in decoder testing
sd_copy = copy.copy(stim_dirs)
random.shuffle(sd_copy)
participant_params['decoder_test_lists'] = sd_copy
# Generate randomised stimulus/SNR combinations for each participant
snrs = np.array(general_params['decoder_test_SNRs'], dtype=float)
np.random.shuffle(snrs)
snrs = np.repeat(snrs[np.newaxis], n_decoder_repeats, axis=0)
snrs = roll_independant(snrs, np.array(np.arange(n_decoder_repeats)+1-n_decoder_repeats))
participant_params['decoder_test_SNRs'] = snrs
# Is the hearing loss simulator active for this participant?
# Even numbers yes, odd numbers no
hl_sim_active = ((i-1) % 2) == 0
participant_params['hl_sim_active'] = hl_sim_active
# What order are the decoder stories presented?
dec_train_lists = general_params['decoder_train_lists']
if (int((i-1)/2) % 2) == 0:
# Play second story first for half the participants
participant_params['decoder_train_lists'] = dec_train_lists[4:8]+dec_train_lists[0:4]
else:
participant_params['decoder_train_lists'] = dec_train_lists
# What order are the decoder test stimuli presented?
dtl_copy = copy.copy(general_params['decoder_test_lists'])
np.random.shuffle(dtl_copy)
participant_params['decoder_test_lists'] = dtl_copy
# What order are the behavioural test stimuli presented?
participant_params['behavioural_train_lists'] = np.random.choice(
general_params['behavioural_train_lists'],
[general_params['behavioural_train_N']], replace=False)
participant_params['behavioural_test_lists'] = np.random.choice(
general_params['behavioural_test_lists'],
[general_params['behavioural_test_N']], replace=False)
# What order are the tone SNRs presented at?
n_tone_repeats = general_params['tone_repeats']
tone_snrs = np.array(general_params['tone_SNRs'], dtype=float)
np.random.shuffle(tone_snrs)
tone_snrs = np.repeat(tone_snrs[np.newaxis], n_tone_repeats, axis=1)
# Remove inf SNRs
infs = np.isinf(tone_snrs)
tone_snrs = tone_snrs[~infs]
# Prepend inf SNRS so they are always presented first
tone_snrs = np.concatenate([[np.inf]*infs.sum(), tone_snrs])
participant_params['tone_SNRs'] = tone_snrs
# What order are the tones presented at?
# Set the order of tone frequencies to be presented to the current
# participant, using previous counterbalancing above
participant_params['tone_freqs'] = list(cb_tone_freqs[int(((i-1)/4)) % len(cb_tone_freqs)])
final_params = copy.copy(general_params)
final_params.update(participant_params)
key = "participant_{}".format(i)
logger.info("{:<78}".format(f"Generating: {key}"))
participants[key] = Participant(participant_dir="./participant_data/{}".format(key),
number=i, parameters=final_params, gen_time=nowtime)
participants[key].save("info")
participants[key].save("parameters")
# Log all parameters of the current participant
for key, val in participants[key].parameters.items():
if type(val) is np.ndarray:
val = val.tolist()
trunc_str = re.sub(r'^(.{75}).*$', '\g<1>...', f"{key:<25}{val}")
logger.info(f"{trunc_str: <78}")
logger.info("-"*78)
print(f"Generated {part_nums.size} new participant databases")
if __name__ == '__main__':
logs_dir = Path('./logs/')
logs_dir.mkdir(exist_ok=True)
logfile_dir = logs_dir / __file__
logfile_dir.mkdir(exist_ok=True)
logfile_name = nowtime.strftime("%m-%d-%Y_%H-%M-%S")+'.log'
logger = create_logger(
logger_streamlevel=10,
log_filename=str(logfile_dir/logfile_name),
logger_filelevel=10
)
main()