This commit is contained in:
Joe Darby
2016-12-16 14:56:48 +00:00
parent d7a10deaab
commit 543cee8c79
2 changed files with 10 additions and 5 deletions
+9 -4
View File
@@ -10,7 +10,7 @@ object KMeans {
* Run KMeans clustering on an input RDD vector
*/
//Create a map to store each data row with its closest cluster index as key
var centres : ArrayBuffer[Float] = ArrayBuffer(0.0f, 100.0f)
def train(dataset : DataFrame, iterations:Int) : Unit = {
val relevantData = dataset.select("Reputation")
@@ -18,20 +18,25 @@ object KMeans {
//val rowsAsArray = rows.map(row => List(row.getInt(0).toFloat, row.getInt(1).toFloat, row.getInt(2).toFloat) )
val rowsAsArray = rows.map(row => row.getInt(0).toFloat )
//rowsAsArray.foreach(println)
val K = 2 //number of intended clusters
val K = 4 //number of intended clusters
//val n = rows.count() //number of datapoints
val m = 1 //number of features
//var centres = new ArrayBuffer[Row]
for (i <- 0 to iterations) {
var centres : ArrayBuffer[Float] = rowsAsArray.takeSample(false, K, System.nanoTime().toInt)
for (i <- 0 until iterations) {
val clusterMap :RDD[(Int,Float)]= rowsAsArray.map(row => (assignCluster(row,centres,m,K),row))
val newCentres = clusterMap.reduceByKey((a,b) => getAverage(a,b))
println("Average reputation is:")
val results = newCentres.map(x => x._2)
val resultsOutput = results.collect()
for (i <- 0 until K) {
centres(i) = resultsOutput(i)
}
println("Average reputation is:")
centres.foreach(println)
}
//get random number generator r and use to select K centres randomly from dataset
+1 -1
View File
@@ -29,7 +29,7 @@ object Main {
// get the users XML file
val users = df("users")
val centres = KMeans.train(users, 50)
val centres = KMeans.train(users, 3)
//val centresArray = centres.collect()
//val unwrap = centresArray.map(x => x._2)
//unwrap.foreach(println)