k means, find centres

This commit is contained in:
joedarby
2016-12-15 16:04:14 +00:00
parent c3c01fe9e8
commit e90318a9a7
3 changed files with 82 additions and 27 deletions
+35 -7
View File
@@ -3,6 +3,7 @@
<component name="ChangeListManager">
<list default="true" id="b41a9788-25b3-4e04-923f-17cde259631b" name="Default" comment="">
<change type="MODIFICATION" beforePath="$PROJECT_DIR$/src/main/scala/KMeans.scala" afterPath="$PROJECT_DIR$/src/main/scala/KMeans.scala" />
<change type="MODIFICATION" beforePath="$PROJECT_DIR$/src/main/scala/Main.scala" afterPath="$PROJECT_DIR$/src/main/scala/Main.scala" />
</list>
<ignored path="$PROJECT_DIR$/target/" />
<option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
@@ -21,15 +22,25 @@
<file leaf-file-name="KMeans.scala" pinned="false" current-in-tab="true">
<entry file="file://$PROJECT_DIR$/src/main/scala/KMeans.scala">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="551">
<caret line="50" column="7" lean-forward="true" selection-start-line="50" selection-start-column="7" selection-end-line="50" selection-end-column="7" />
<state relative-caret-position="-450">
<caret line="5" column="0" lean-forward="true" selection-start-line="5" selection-start-column="0" selection-end-line="5" selection-end-column="0" />
<folding>
<element signature="e#23#59#0" expanded="true" />
<element signature="e#23#54#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
</file>
<file leaf-file-name="Main.scala" pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/src/main/scala/Main.scala">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="360">
<caret line="33" column="1" lean-forward="true" selection-start-line="33" selection-start-column="1" selection-end-line="33" selection-end-column="1" />
<folding />
</state>
</provider>
</entry>
</file>
<file leaf-file-name="XMLParser.scala" pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/src/main/scala/XMLParser.scala">
<provider selected="true" editor-type-id="text-editor">
@@ -53,6 +64,7 @@
<component name="IdeDocumentHistory">
<option name="CHANGED_PATHS">
<list>
<option value="$PROJECT_DIR$/src/main/scala/Main.scala" />
<option value="$PROJECT_DIR$/src/main/scala/KMeans.scala" />
</list>
</option>
@@ -441,7 +453,7 @@
<state relative-caret-position="0">
<caret line="0" column="0" lean-forward="false" selection-start-line="0" selection-start-column="0" selection-end-line="0" selection-end-column="0" />
<folding>
<element signature="e#23#59#0" expanded="true" />
<element signature="e#23#54#0" expanded="true" />
</folding>
</state>
</provider>
@@ -462,12 +474,28 @@
</state>
</provider>
</entry>
<entry file="jar://$MAVEN_REPOSITORY$/org/scala-lang/scala-library/2.10.5/scala-library-2.10.5.jar!/scala/collection/TraversableLike.class">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="168">
<caret line="15" column="6" lean-forward="false" selection-start-line="15" selection-start-column="6" selection-end-line="15" selection-end-column="6" />
<folding />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/src/main/scala/Main.scala">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="360">
<caret line="33" column="1" lean-forward="true" selection-start-line="33" selection-start-column="1" selection-end-line="33" selection-end-column="1" />
<folding />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/src/main/scala/KMeans.scala">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="551">
<caret line="50" column="7" lean-forward="true" selection-start-line="50" selection-start-column="7" selection-end-line="50" selection-end-column="7" />
<state relative-caret-position="-450">
<caret line="5" column="0" lean-forward="true" selection-start-line="5" selection-start-column="0" selection-end-line="5" selection-end-column="0" />
<folding>
<element signature="e#23#59#0" expanded="true" />
<element signature="e#23#54#0" expanded="true" />
</folding>
</state>
</provider>
+43 -19
View File
@@ -1,12 +1,7 @@
package ClusterSOData
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
object KMeans {
@@ -14,22 +9,26 @@ object KMeans {
* Run k-means clustering on an input DataFrame
*/
//Create a map to store each data row with its closest cluster index as key
var clusterMap : mutable.HashMap[Int,Row]
/**
 * Runs one k-means assignment step over `dataset` and derives new centres.
 *
 * Initial centres are drawn from the data with `takeSample` (without
 * replacement), every row is tagged with the index of its closest centre via
 * `assignCluster`, and the tagged rows are handed to `calculateNewCentres`.
 *
 * @param dataset rows to cluster
 * @param K       number of intended clusters (defaults to the previous hard-coded 5)
 * @param m       number of features passed on to `assignCluster` (defaults to the
 *                previous hard-coded 3)
 * @return the result of `calculateNewCentres` for the (cluster index, row) pairs
 */
def train(dataset : DataFrame, K : Int = 5, m : Int = 3) : RDD[(Int,Row)] = {
  require(K > 0, s"K must be positive, got $K")
  val rows = dataset.rdd

  // takeSample returns fewer than K rows when the dataset itself is smaller,
  // so downstream code is told how many centres really exist.
  val centres = rows.takeSample(withReplacement = false, num = K, seed = System.nanoTime())
  val centreCount = centres.length

  val clusterMap : RDD[(Int,Row)] =
    rows.map(row => (assignCluster(row, centres, m, centreCount), row))

  calculateNewCentres(clusterMap)
}
def calculateNorm(datapoint : Row, centre : Row, m: Int): Double = {
@@ -40,16 +39,41 @@ object KMeans {
norm = Math.pow(norm, 0.5)
}
/**
 * Finds the index of the centre closest to `row`.
 *
 * Distance is whatever `calculateNorm(row, centre, m)` returns; the first
 * centre with the strictly smallest value wins.
 *
 * @param row     the datapoint to classify
 * @param centres candidate cluster centres
 * @param m       number of features, forwarded to `calculateNorm`
 * @param K       number of centres to consider; clamped to `centres.length`
 * @return index into `centres` of the nearest centre, or 0 when there are no
 *         centres to compare against
 */
def assignCluster(row : Row, centres: Array[Row], m : Int, K :Int): Int = {
  // Never index past the centres actually supplied, even if K is larger.
  val candidateCount = math.min(K, centres.length)

  // Start from +infinity so that arbitrarily large distances still compare correctly.
  var smallestNorm = Double.PositiveInfinity
  var closestCentre = 0
  for (centreNumber <- 0 until candidateCount) {
    val norm = calculateNorm(row, centres(centreNumber), m)
    if (norm < smallestNorm) {
      smallestNorm = norm
      closestCentre = centreNumber
    }
  }
  closestCentre
}
/**
 * Computes the mean row of every cluster.
 *
 * Each row is turned into (per-feature values, 1); the reduction adds the
 * feature arrays and the counts, and the final step divides the sums by the
 * count. Summing first keeps the reduction associative, so the result does not
 * depend on the order in which Spark combines rows.
 *
 * Features are read with `getFloat`, so every column is presumably stored as
 * Float -- confirm against the DataFrame schema.
 *
 * @param clusterMap rows keyed by the index of their assigned cluster
 * @return one (cluster index, mean row) pair per key present in `clusterMap`
 */
def calculateNewCentres(clusterMap : RDD[(Int,Row)]): RDD[(Int,Row)] = {
  val sumsAndCounts = clusterMap
    .mapValues(row => (Array.tabulate(row.size)(i => row.getFloat(i)), 1L))
    .reduceByKey { case ((sumA, countA), (sumB, countB)) =>
      (sumA.zip(sumB).map { case (x, y) => x + y }, countA + countB)
    }

  sumsAndCounts.mapValues { case (sums, count) =>
    Row.fromSeq(sums.map(_ / count).toSeq)
  }
}
/*def getCentre(cluster : RDD[(Int,Row)], oldCentre : Row, clusterIndex :Int) : Row = {
val unWrappedData :RDD[Row] = cluster.map(x => x._2)
val features : Row = unWrappedData.reduce(averageRow)
return features
}*/
/**
 * Builds a row whose i-th field is the mean of the i-th fields of `a` and `b`.
 *
 * Fields are read with `getFloat`, so both rows are presumably all-Float --
 * confirm against the DataFrame schema.
 *
 * @param a first row
 * @param b second row; must have the same number of fields as `a`
 * @return a new Row of Float averages with `a.size` fields
 */
def averageRow(a :Row, b:Row) : Row = {
  require(a.size == b.size, s"Rows must have the same size: ${a.size} vs ${b.size}")
  val averaged = (0 until a.size).map(i => (a.getFloat(i) + b.getFloat(i)) / 2)
  Row.fromSeq(averaged)
}
}
+4 -1
View File
@@ -26,6 +26,9 @@ object Main {
// get the users XML file
val users = df("users")
users.show()
val centres = KMeans.train(users)
val centresArray = centres.collect()
val unwrap = centresArray.map(x => x._2)
unwrap.foreach(println)
}
}