Implemented K-d tree searching

This commit is contained in:
2016-03-24 23:20:24 +00:00
parent b73ec56a2c
commit 997a91ce1b
3 changed files with 111 additions and 21 deletions
+1 -1
View File
@@ -253,7 +253,7 @@ def main():
# Perform matching on databases using the method specified.
matcher.match(
matcher.brute_force_matcher,
matcher.knn_matcher,
grain_size=config.matcher["grain_size"],
overlap=config.matcher["overlap"]
)
+15 -15
View File
@@ -1,21 +1,21 @@
rms = {
"window_size": 130,
"overlap": 16,
"window_size": 70,
"overlap": 2,
}
variance = {
"window_size": 130,
"overlap": 16
"window_size": 70,
"overlap": 2
}
kurtosis = {
"window_size": 130,
"overlap": 16
"window_size": 70,
"overlap": 2
}
skewness = {
"window_size": 130,
"overlap": 16
"window_size": 70,
"overlap": 2
}
fft = {
@@ -63,24 +63,24 @@ analysis = {
matcher = {
"rematch": True,
"grain_size": 130,
"overlap": 16,
"grain_size": 70,
"overlap": 2,
# Defines the number of matches to keep for synthesis. Note that this must
# also be specified in the synthesis config
"match_quantity": 20
"match_quantity": 1
}
synthesizer = {
"enforce_rms": True,
"enf_rms_ratio_limit": 5.,
"enf_rms_ratio_limit": 100.,
"enforce_f0": True,
"enf_f0_ratio_limit": 10.,
"grain_size": 130,
"overlap": 16,
"grain_size": 70,
"overlap": 2,
"normalize" : True,
# Defines the number of potential grains to choose from matches when
# synthesizing output.
"match_quantity": 20
"match_quantity": 1
}
output_file = {
+95 -5
View File
@@ -2,7 +2,7 @@ from __future__ import print_function, division
import os
import shutil
import collections
from scipy import signal
from scipy import signal, spatial
import numpy as np
import pysndfile
import pdb
@@ -11,6 +11,7 @@ import traceback
import logging
import h5py
import pitch_shift
from sklearn.preprocessing import Imputer
from fileops import pathops
from audiofile import AnalysedAudioFile, AudioFile
@@ -354,6 +355,97 @@ class Matcher:
grain_indexes[:, 0] = grain_indexes[:, 1] - grain_indexes[:, 0]
return grain_indexes
def knn_matcher(self, grain_size, overlap):
    '''Searches for matches to each target grain using K-d tree queries.

    For every entry in the target database, the analyses listed in
    ``self.matcher_analyses`` are gathered per grain into an array with one
    row per analysis and one column per grain. Missing values are filled by
    an ``Imputer(axis=1)`` using its default settings. For every entry in the
    source database an equivalent array is built, indexed with a
    ``scipy.spatial.cKDTree`` and queried for the ``self.match_quantity``
    nearest source grains of each target grain (Euclidean distance, p=2).
    Results from successive source entries are merged so that only the
    ``self.match_quantity`` smallest distances per target grain are kept.

    The winning indexes are converted with ``self.calculate_db_inds`` and
    stored in the output database under ``match/<target entry name>``, with
    ``grain_size`` and ``overlap`` recorded as attributes of that dataset.

    Arguments:

    - grain_size: grain size handed to ``count_grains`` and
      ``generate_grain_times``; also stored with the match data.

    - overlap: grain overlap handed to the same calls; also stored with the
      match data.

    Raises RuntimeError if the match data cannot be written to the HDF5 file.
    '''
    # NOTE(review): the nesting of this method was reconstructed from a
    # listing that had lost its indentation - confirm against the repository.

    # Count grains of the source database
    source_sample_indexes = self.count_grains(self.source_db, grain_size, overlap)

    try:
        self.output_db.data.create_group("match")
    except ValueError:
        self.logger.debug("Match group already exists in the {0} HDF5 file.".format(self.output_db))
    if self.rematch:
        # Clearing a freshly created (empty) group is a no-op, so this is
        # only significant when the group already existed.
        self.output_db.data["match"].clear()

    # NOTE(review): per-analysis weightings from the config are not applied
    # by this matcher; the tree is built on the unscaled analysis values.
    analysis_count = len(self.matcher_analyses)
    match_quantity = self.match_quantity

    for target_entry in self.target_db.analysed_audio:
        # Create an array of grain times for target sample
        target_times = target_entry.generate_grain_times(
            grain_size, overlap, save_times=True
        )
        x_size = target_times.shape[0]

        # Running best matches for every target grain. Distances start at
        # infinity so that any real match replaces the placeholder indexes.
        match_indexes = np.empty((x_size, match_quantity))
        match_vals = np.full((x_size, match_quantity), np.inf)
        # Column vector of row numbers used to reorder each row independently.
        rows = np.arange(x_size)[:, np.newaxis]

        all_target_analyses = np.empty((analysis_count, x_size))
        for i, analysis in enumerate(self.matcher_analyses):
            analysis_formatting = self.analysis_dict[analysis]
            target_data, _ = target_entry.analysis_data_grains(
                target_times, analysis, format=analysis_formatting
            )
            all_target_analyses[i] = target_data

        imp = Imputer(axis=1)
        all_target_analyses = imp.fit_transform(all_target_analyses)

        for sind, source_entry in enumerate(self.source_db.analysed_audio):
            # Create an array of grain times for source sample
            source_times = source_entry.generate_grain_times(
                grain_size, overlap, save_times=True
            )

            all_source_analyses = np.empty((analysis_count, source_times.shape[0]))
            for i, analysis in enumerate(self.matcher_analyses):
                analysis_formatting = self.analysis_dict[analysis]
                source_data, _ = source_entry.analysis_data_grains(
                    source_times, analysis, format=analysis_formatting
                )
                all_source_analyses[i] = source_data
                self.logger.info("Matching \"{0}\" for: {1} to {2}".format(analysis, source_entry.name, target_entry.name))

            # The imputer is re-fitted on the source data before filling it.
            all_source_analyses = imp.fit_transform(all_source_analyses)

            # One point per grain, one dimension per analysis.
            source_tree = spatial.cKDTree(all_source_analyses.T, leafsize=100)
            results_vals, results_inds = source_tree.query(
                all_target_analyses.T, k=match_quantity, p=2
            )
            # With k == 1 the query returns 1-D arrays; promote them to a
            # single column so they can be merged with the running results.
            if results_vals.ndim < 2:
                results_vals = results_vals[:, np.newaxis]
                results_inds = results_inds[:, np.newaxis]

            # Merge this file's candidates with the best so far. Indexes are
            # shifted by the first grain index of the current source file so
            # that they address grains across the whole source database.
            vals_append = np.append(match_vals, results_vals, axis=1)
            inds_append = np.append(
                match_indexes,
                results_inds + source_sample_indexes[sind][0],
                axis=1
            )
            # Keep the match_quantity smallest distances of every row.
            best = np.argsort(vals_append)[:, :match_quantity]
            match_indexes = inds_append[rows, best]
            match_vals = vals_append[rows, best]

        match_grain_inds = self.calculate_db_inds(match_indexes, source_sample_indexes)

        datafile_path = ''.join(("match/", target_entry.name))
        try:
            self.output_db.data[datafile_path] = match_grain_inds
            self.output_db.data[datafile_path].attrs["grain_size"] = grain_size
            self.output_db.data[datafile_path].attrs["overlap"] = overlap
        except RuntimeError as err:
            raise RuntimeError("Match data couldn't be written to HDF5 "
                               "file.\n Match data may already exist in the "
                               "file.\n Try running with the '--rematch' flag "
                               "to overwrite this data.\n Original error: "
                               "{0}".format(err))
def brute_force_matcher(self, grain_size, overlap):
'''Searches for matches to each grain by brute force comparison'''
@@ -388,20 +480,18 @@ class Matcher:
y_size = int(source_sample_indexes[-1][-1])
chunk_size = 8192
self.output_db.data.create_dataset("data_distance", (x_size, y_size), dtype=np.float, chunks=True)
self.output_db.data.create_dataset("distance_accum", (x_size, y_size), dtype=np.float, chunks=True, fillvalue=0)
for analysis in self.matcher_analyses:
self.logger.info("Current analysis: {0}".format(analysis))
analysis_formatting = self.analysis_dict[analysis]
# Get the analysis object for the current entry
analysis_object = target_entry.analyses[analysis]
# Get data for all target grains for each analysis
target_data, s = target_entry.analysis_data_grains(target_times, analysis, format=analysis_formatting)
data_max = 0.
for sind, source_entry in enumerate(self.source_db.analysed_audio):