k means, find centres
This commit is contained in:
Generated
+35
-7
@@ -3,6 +3,7 @@
|
||||
<component name="ChangeListManager">
|
||||
<list default="true" id="b41a9788-25b3-4e04-923f-17cde259631b" name="Default" comment="">
|
||||
<change type="MODIFICATION" beforePath="$PROJECT_DIR$/src/main/scala/KMeans.scala" afterPath="$PROJECT_DIR$/src/main/scala/KMeans.scala" />
|
||||
<change type="MODIFICATION" beforePath="$PROJECT_DIR$/src/main/scala/Main.scala" afterPath="$PROJECT_DIR$/src/main/scala/Main.scala" />
|
||||
</list>
|
||||
<ignored path="$PROJECT_DIR$/target/" />
|
||||
<option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
|
||||
@@ -21,15 +22,25 @@
|
||||
<file leaf-file-name="KMeans.scala" pinned="false" current-in-tab="true">
|
||||
<entry file="file://$PROJECT_DIR$/src/main/scala/KMeans.scala">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="551">
|
||||
<caret line="50" column="7" lean-forward="true" selection-start-line="50" selection-start-column="7" selection-end-line="50" selection-end-column="7" />
|
||||
<state relative-caret-position="-450">
|
||||
<caret line="5" column="0" lean-forward="true" selection-start-line="5" selection-start-column="0" selection-end-line="5" selection-end-column="0" />
|
||||
<folding>
|
||||
<element signature="e#23#59#0" expanded="true" />
|
||||
<element signature="e#23#54#0" expanded="true" />
|
||||
</folding>
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
</file>
|
||||
<file leaf-file-name="Main.scala" pinned="false" current-in-tab="false">
|
||||
<entry file="file://$PROJECT_DIR$/src/main/scala/Main.scala">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="360">
|
||||
<caret line="33" column="1" lean-forward="true" selection-start-line="33" selection-start-column="1" selection-end-line="33" selection-end-column="1" />
|
||||
<folding />
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
</file>
|
||||
<file leaf-file-name="XMLParser.scala" pinned="false" current-in-tab="false">
|
||||
<entry file="file://$PROJECT_DIR$/src/main/scala/XMLParser.scala">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
@@ -53,6 +64,7 @@
|
||||
<component name="IdeDocumentHistory">
|
||||
<option name="CHANGED_PATHS">
|
||||
<list>
|
||||
<option value="$PROJECT_DIR$/src/main/scala/Main.scala" />
|
||||
<option value="$PROJECT_DIR$/src/main/scala/KMeans.scala" />
|
||||
</list>
|
||||
</option>
|
||||
@@ -441,7 +453,7 @@
|
||||
<state relative-caret-position="0">
|
||||
<caret line="0" column="0" lean-forward="false" selection-start-line="0" selection-start-column="0" selection-end-line="0" selection-end-column="0" />
|
||||
<folding>
|
||||
<element signature="e#23#59#0" expanded="true" />
|
||||
<element signature="e#23#54#0" expanded="true" />
|
||||
</folding>
|
||||
</state>
|
||||
</provider>
|
||||
@@ -462,12 +474,28 @@
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
<entry file="jar://$MAVEN_REPOSITORY$/org/scala-lang/scala-library/2.10.5/scala-library-2.10.5.jar!/scala/collection/TraversableLike.class">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="168">
|
||||
<caret line="15" column="6" lean-forward="false" selection-start-line="15" selection-start-column="6" selection-end-line="15" selection-end-column="6" />
|
||||
<folding />
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
<entry file="file://$PROJECT_DIR$/src/main/scala/Main.scala">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="360">
|
||||
<caret line="33" column="1" lean-forward="true" selection-start-line="33" selection-start-column="1" selection-end-line="33" selection-end-column="1" />
|
||||
<folding />
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
<entry file="file://$PROJECT_DIR$/src/main/scala/KMeans.scala">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="551">
|
||||
<caret line="50" column="7" lean-forward="true" selection-start-line="50" selection-start-column="7" selection-end-line="50" selection-end-column="7" />
|
||||
<state relative-caret-position="-450">
|
||||
<caret line="5" column="0" lean-forward="true" selection-start-line="5" selection-start-column="0" selection-end-line="5" selection-end-column="0" />
|
||||
<folding>
|
||||
<element signature="e#23#59#0" expanded="true" />
|
||||
<element signature="e#23#54#0" expanded="true" />
|
||||
</folding>
|
||||
</state>
|
||||
</provider>
|
||||
|
||||
+43
-19
@@ -1,12 +1,7 @@
|
||||
package ClusterSOData
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.SparkContext._
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql._
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
object KMeans {
|
||||
@@ -14,22 +9,26 @@ object KMeans {
|
||||
* Run KMeans clustering on an input RDD vector
|
||||
*/
|
||||
//Create a map to store each data row with its closest cluster index as key
|
||||
var clusterMap : mutable.HashMap[Int,Row]
|
||||
|
||||
def train(dataset : DataFrame) {
|
||||
val rows = dataset.collect()
|
||||
def train(dataset : DataFrame) : RDD[(Int,Row)] = {
|
||||
val rows = dataset.rdd
|
||||
val K = 5 //number of intended clusters
|
||||
val n = rows.length //number of datapoints
|
||||
val n = rows.count() //number of datapoints
|
||||
val m = 3 //number of features
|
||||
var centres = new ArrayBuffer[Row]
|
||||
//var centres = new ArrayBuffer[Row]
|
||||
|
||||
//get random number generator r and use to select K centres randomly from dataset
|
||||
val r = scala.util.Random
|
||||
/*val r = scala.util.Random
|
||||
val random =
|
||||
var a = 0
|
||||
for (a <- 0 until K) {
|
||||
centres(a) = rows(r.nextInt(n))
|
||||
}
|
||||
val clusterMap = rows.map(x => assignCluster(x, centres, m))
|
||||
centres(a) = rows(r.ne
|
||||
}*/
|
||||
val centres = rows.takeSample(false, K, System.nanoTime().toInt)
|
||||
val clusterMap :RDD[(Int,Row)]= rows.map(row => (assignCluster(row,centres,m,K),row))
|
||||
val newCentres = calculateNewCentres(clusterMap)
|
||||
newCentres
|
||||
|
||||
}
|
||||
|
||||
def calculateNorm(datapoint : Row, centre : Row, m: Int): Double = {
|
||||
@@ -40,16 +39,41 @@ object KMeans {
|
||||
norm = Math.pow(norm, 0.5)
|
||||
}
|
||||
|
||||
def assignCluster(row : Row, centres: ArrayBuffer[Row], m : Int): (Int,Row) = {
|
||||
def assignCluster(row : Row, centres: Array[Row], m : Int, K :Int): Int = {
|
||||
var smallestNorm = 99999999999.0
|
||||
var closestCentre = 0
|
||||
for (centreIndex <- centres.indices) {
|
||||
val norm = calculateNorm(row, centres(centreIndex), m)
|
||||
for (centreNumber <- 0 until K) {
|
||||
val norm = calculateNorm(row, centres(centreNumber), m)
|
||||
if (norm < smallestNorm) {
|
||||
smallestNorm = norm
|
||||
closestCentre = centreIndex
|
||||
closestCentre = centreNumber
|
||||
}
|
||||
}
|
||||
(closestCentre,row)
|
||||
closestCentre
|
||||
}
|
||||
|
||||
/**
  * Collapses the rows assigned to each cluster into one row per cluster index.
  *
  * Rows sharing the same key (cluster index) are combined pairwise with
  * `averageRow` via `reduceByKey`.
  *
  * NOTE(review): a pairwise "average of two" reducer is not associative, so for
  * clusters with more than two rows the result is not the arithmetic mean and
  * may depend on reduction order. A sum-and-count aggregation would give a true
  * centroid - confirm the intended behaviour.
  *
  * @param clusterMap pairs of (cluster index, data row)
  * @return one (cluster index, combined row) pair per cluster index present in `clusterMap`
  */
def calculateNewCentres(clusterMap : RDD[(Int,Row)]): RDD[(Int,Row)] = {

  /*for (a <- 0 until K) {
  var cluster = clusterMap.filter{case (a,_) => a == 0}
  var data = cluster.map((_,a) => a :Row)*/

  // The reduced RDD must be the block's final expression: previously it was only
  // bound to a local val, leaving the method body evaluating to Unit.
  clusterMap.reduceByKey((a, b) => averageRow(a, b))
}
|
||||
|
||||
/*def getCentre(cluster : RDD[(Int,Row)], oldCentre : Row, clusterIndex :Int) : Row = {
|
||||
val unWrappedData :RDD[Row] = cluster.map(x => x._2)
|
||||
val features : Row = unWrappedData.reduce(averageRow)
|
||||
return features
|
||||
}*/
|
||||
|
||||
/**
  * Builds a new row whose i-th field is the mean of the i-th fields of `a` and `b`.
  *
  * Fields are read with `getFloat`, so every column is assumed to hold a Float
  * value - TODO confirm against the DataFrame schema passed to `train`.
  * The number of fields is taken from `a`; `b` is expected to have at least as many.
  *
  * @param a first row
  * @param b second row
  * @return a row of `a.size` Float fields, each `(a(i) + b(i)) / 2`
  */
def averageRow(a :Row, b:Row) : Row = {
  // Iterate over the field indices (an Int itself is not iterable) and build the
  // values immutably; assigning by index into an empty ArrayBuffer would throw.
  val averaged = (0 until a.size).map(i => (a.getFloat(i) + b.getFloat(i)) / 2)
  Row.fromSeq(averaged)
}
|
||||
}
|
||||
|
||||
@@ -26,6 +26,9 @@ object Main {
|
||||
|
||||
// get the users XML file
|
||||
val users = df("users")
|
||||
users.show()
|
||||
val centres = KMeans.train(users)
|
||||
val centresArray = centres.collect()
|
||||
val unwrap = centresArray.map(x => x._2)
|
||||
unwrap.foreach(println)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user