Implemented K-d tree searching
This commit is contained in:
@@ -253,7 +253,7 @@ def main():
|
||||
|
||||
# Perform matching on databases using the method specified.
|
||||
matcher.match(
|
||||
matcher.brute_force_matcher,
|
||||
matcher.knn_matcher,
|
||||
grain_size=config.matcher["grain_size"],
|
||||
overlap=config.matcher["overlap"]
|
||||
)
|
||||
|
||||
+15
-15
@@ -1,21 +1,21 @@
|
||||
rms = {
|
||||
"window_size": 130,
|
||||
"overlap": 16,
|
||||
"window_size": 70,
|
||||
"overlap": 2,
|
||||
}
|
||||
|
||||
variance = {
|
||||
"window_size": 130,
|
||||
"overlap": 16
|
||||
"window_size": 70,
|
||||
"overlap": 2
|
||||
}
|
||||
|
||||
kurtosis = {
|
||||
"window_size": 130,
|
||||
"overlap": 16
|
||||
"window_size": 70,
|
||||
"overlap": 2
|
||||
}
|
||||
|
||||
skewness = {
|
||||
"window_size": 130,
|
||||
"overlap": 16
|
||||
"window_size": 70,
|
||||
"overlap": 2
|
||||
}
|
||||
|
||||
fft = {
|
||||
@@ -63,24 +63,24 @@ analysis = {
|
||||
|
||||
matcher = {
|
||||
"rematch": True,
|
||||
"grain_size": 130,
|
||||
"overlap": 16,
|
||||
"grain_size": 70,
|
||||
"overlap": 2,
|
||||
# Defines the number of matches to keep for synthesis. Note that this must
|
||||
# also be specified in the synthesis config
|
||||
"match_quantity": 20
|
||||
"match_quantity": 1
|
||||
}
|
||||
|
||||
synthesizer = {
|
||||
"enforce_rms": True,
|
||||
"enf_rms_ratio_limit": 5.,
|
||||
"enf_rms_ratio_limit": 100.,
|
||||
"enforce_f0": True,
|
||||
"enf_f0_ratio_limit": 10.,
|
||||
"grain_size": 130,
|
||||
"overlap": 16,
|
||||
"grain_size": 70,
|
||||
"overlap": 2,
|
||||
"normalize" : True,
|
||||
# Defines the number of potential grains to choose from matches when
|
||||
# synthesizing output.
|
||||
"match_quantity": 20
|
||||
"match_quantity": 1
|
||||
}
|
||||
|
||||
output_file = {
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import print_function, division
|
||||
import os
|
||||
import shutil
|
||||
import collections
|
||||
from scipy import signal
|
||||
from scipy import signal, spatial
|
||||
import numpy as np
|
||||
import pysndfile
|
||||
import pdb
|
||||
@@ -11,6 +11,7 @@ import traceback
|
||||
import logging
|
||||
import h5py
|
||||
import pitch_shift
|
||||
from sklearn.preprocessing import Imputer
|
||||
|
||||
from fileops import pathops
|
||||
from audiofile import AnalysedAudioFile, AudioFile
|
||||
@@ -354,6 +355,97 @@ class Matcher:
|
||||
grain_indexes[:, 0] = grain_indexes[:, 1] - grain_indexes[:, 0]
|
||||
return grain_indexes
|
||||
|
||||
def knn_matcher(self, grain_size, overlap):
|
||||
# Count grains of the source database
|
||||
source_sample_indexes = self.count_grains(self.source_db, grain_size, overlap)
|
||||
try:
|
||||
self.output_db.data.create_group("match")
|
||||
except ValueError:
|
||||
self.logger.debug("Match group already exists in the {0} HDF5 file.".format(self.output_db))
|
||||
|
||||
if self.rematch:
|
||||
self.output_db.data["match"].clear()
|
||||
#
|
||||
final_match_indexes = []
|
||||
|
||||
if self.config:
|
||||
weightings = self.config.matcher_weightings
|
||||
else:
|
||||
weightings = {x: 1. for x in self.matcher_analyses}
|
||||
|
||||
for tind, target_entry in enumerate(self.target_db.analysed_audio):
|
||||
# Create an array of grain times for target sample
|
||||
target_times = target_entry.generate_grain_times(grain_size, overlap, save_times=True)
|
||||
x_size = target_times.shape[0]
|
||||
match_indexes = np.empty((x_size, self.match_quantity))
|
||||
match_vals = np.empty((x_size, self.match_quantity))
|
||||
match_vals.fill(np.inf)
|
||||
|
||||
all_target_analyses = np.empty((len(self.matcher_analyses), target_times.shape[0]))
|
||||
|
||||
for i, analysis in enumerate(self.matcher_analyses):
|
||||
analysis_formatting = self.analysis_dict[analysis]
|
||||
|
||||
target_data, s = target_entry.analysis_data_grains(target_times, analysis, format=analysis_formatting)
|
||||
all_target_analyses[i] = target_data
|
||||
|
||||
imp = Imputer(axis=1)
|
||||
all_target_analyses = imp.fit_transform(all_target_analyses)
|
||||
# all_target_analyses[np.isnan(all_target_analyses)] = np.inf
|
||||
# all_target_analyses = np.nan_to_num(all_target_analyses)
|
||||
|
||||
|
||||
for sind, source_entry in enumerate(self.source_db.analysed_audio):
|
||||
# Create an array of grain times for source sample
|
||||
source_times = source_entry.generate_grain_times(grain_size, overlap, save_times=True)
|
||||
all_source_analyses = np.empty((len(self.matcher_analyses), source_times.shape[0]))
|
||||
|
||||
for i, analysis in enumerate(self.matcher_analyses):
|
||||
analysis_formatting = self.analysis_dict[analysis]
|
||||
|
||||
source_data, s = source_entry.analysis_data_grains(source_times, analysis, format=analysis_formatting)
|
||||
all_source_analyses[i] = source_data
|
||||
|
||||
self.logger.info("Matching \"{0}\" for: {1} to {2}".format(analysis, source_entry.name, target_entry.name))
|
||||
# all_source_analyses[np.isnan(all_source_analyses)] = np.inf
|
||||
# all_source_analyses = np.nan_to_num(all_source_analyses)
|
||||
|
||||
all_source_analyses = imp.fit_transform(all_source_analyses)
|
||||
|
||||
source_tree = spatial.cKDTree(all_source_analyses.T, leafsize=100)
|
||||
results_vals, results_inds = source_tree.query(all_target_analyses.T, k=self.match_quantity, p=2)
|
||||
|
||||
if len(results_vals.shape) < 2:
|
||||
results_vals = np.array([results_vals]).T
|
||||
results_inds = np.array([results_inds]).T
|
||||
|
||||
vals_append = np.append(match_vals, results_vals, axis=1)
|
||||
vals_sort = np.argsort(vals_append)
|
||||
inds_append = np.append(match_indexes, results_inds+source_sample_indexes[sind][0], axis=1)
|
||||
|
||||
m = np.arange(len(vals_append))[:, np.newaxis]
|
||||
best_match_inds = inds_append[m, vals_sort]
|
||||
match_indexes = best_match_inds[:, :self.match_quantity]
|
||||
best_match_vals = vals_append[m, vals_sort]
|
||||
match_vals = best_match_vals[:, :self.match_quantity]
|
||||
|
||||
match_grain_inds = self.calculate_db_inds(match_indexes, source_sample_indexes)
|
||||
|
||||
datafile_path = ''.join(("match/", target_entry.name))
|
||||
try:
|
||||
self.output_db.data[datafile_path] = match_grain_inds
|
||||
self.output_db.data[datafile_path].attrs["grain_size"] = grain_size
|
||||
self.output_db.data[datafile_path].attrs["overlap"] = overlap
|
||||
|
||||
except RuntimeError as err:
|
||||
raise RuntimeError("Match data couldn't be written to HDF5 "
|
||||
"file.\n Match data may already exist in the "
|
||||
"file.\n Try running with the '--rematch' flag "
|
||||
"to overwrite this data.\n Original error: "
|
||||
"{0}".format(err))
|
||||
|
||||
|
||||
|
||||
def brute_force_matcher(self, grain_size, overlap):
|
||||
'''Searches for matches to each grain by brute force comparison'''
|
||||
|
||||
@@ -388,20 +480,18 @@ class Matcher:
|
||||
y_size = int(source_sample_indexes[-1][-1])
|
||||
chunk_size = 8192
|
||||
|
||||
|
||||
self.output_db.data.create_dataset("data_distance", (x_size, y_size), dtype=np.float, chunks=True)
|
||||
|
||||
self.output_db.data.create_dataset("distance_accum", (x_size, y_size), dtype=np.float, chunks=True, fillvalue=0)
|
||||
|
||||
for analysis in self.matcher_analyses:
|
||||
self.logger.info("Current analysis: {0}".format(analysis))
|
||||
analysis_formatting = self.analysis_dict[analysis]
|
||||
# Get the analysis object for the current entry
|
||||
analysis_object = target_entry.analyses[analysis]
|
||||
|
||||
|
||||
# Get data for all target grains for each analysis
|
||||
target_data, s = target_entry.analysis_data_grains(target_times, analysis, format=analysis_formatting)
|
||||
|
||||
|
||||
data_max = 0.
|
||||
for sind, source_entry in enumerate(self.source_db.analysed_audio):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user