diff --git a/src/main/scala/KMeans.scala b/src/main/scala/KMeans.scala index 5cff70e..5fa1c47 100644 --- a/src/main/scala/KMeans.scala +++ b/src/main/scala/KMeans.scala @@ -11,13 +11,13 @@ object KMeans { def train(dataset : DataFrame, iterations:Int) : Unit = { - val K = 4 // Number of desired clusters - val relevantData = dataset.select("Reputation", "LastAccessDate") + val K = 8 // Number of desired clusters + val relevantData = dataset.select("Reputation", "UpVotes", "DownVotes", "CreationDate", "LastAccessDate") val m = relevantData.columns.length //number of features val rows = relevantData.rdd val rowsAsArray = rows.map(row => convertRow(row, m)).persist() - + //Initialise the centres by taking a random sample var centres: Array[Array[Float]] = rowsAsArray.takeSample(false, K, System.nanoTime().toInt) //To reduce chance of two random centres being the same, add a changing value to each println("centres initialised") @@ -30,6 +30,7 @@ object KMeans { var counts = Array[Int](K) + //Iterate through the clustering algorithm for (i <- 0 until iterations) { val clusterMap = clustering(centres, rowsAsArray, m, K).persist() centres = getCentres(clusterMap, m, K) @@ -92,9 +93,6 @@ object KMeans { } - - - def calculateNorm(datapoint : Array[Float], centre : Array[Float], m: Int): Double = { var norm : Double = 0.0 for (a <- 0 until m) { diff --git a/src/main/scala/Main.scala b/src/main/scala/Main.scala index 9713d6d..76adc6b 100644 --- a/src/main/scala/Main.scala +++ b/src/main/scala/Main.scala @@ -30,28 +30,6 @@ object Main { // get the users XML file val users = df("users") - val centres = KMeans.train(users, 25) - //val centresArray = centres.collect() - //val unwrap = centresArray.map(x => x._2) - //unwrap.foreach(println) + val centres = KMeans.train(users, 50) } } - - //val users = dataFrames("users") - - /*val dataFrames = DataParser.ParseData() - - // get the users XML file - val users = dataFrames("users") - users.persist() - // Show 20 entries from the user dataset - users.show() - // Show types for the user dataset - users.printSchema() - users.show() - - // create new dataframe with only the reputation of the users - users.select("CreationDate").show() -*/ - // Info on using DataFrames here: https://www.mapr.com/blog/using-apache-spark-dataframes-processing-tabular-data - diff --git a/src/main/scala/XMLParser.scala b/src/main/scala/XMLParser.scala index 4efdce6..f01f231 100644 --- a/src/main/scala/XMLParser.scala +++ b/src/main/scala/XMLParser.scala @@ -30,7 +30,7 @@ object XMLParser { ("postHistory", "/data/stackoverflow/PostHistory","Id PostHistoryTypeId PostId RevisionGUID CreationDate UserId UserDisplayName Comment Text CloseReasonId", Array[DataType](IntegerType, IntegerType, IntegerType,IntegerType, DateType, IntegerType, StringType, StringType, StringType, IntegerType)), ("postLinks", "data/stackoverflow/PostLinks", "Id CreationDate PostId RelatedPostId PostLinkTypeId", Array[DataType](IntegerType, DateType, IntegerType, IntegerType, IntegerType)), */ - ("users", "stackoverflow_dataset/users.txt", "Reputation CreationDate DisplayName EmailHash LastAccessDate WebsiteUrl Location Age AboutMe Views UpVotes DownVotes", Array[DataType](IntegerType, DateType, StringType, StringType, DateType, StringType, StringType, IntegerType, StringType, IntegerType, IntegerType, IntegerType)) + ("users", "data/stackoverflow/Users", "Reputation CreationDate DisplayName EmailHash LastAccessDate WebsiteUrl Location Age AboutMe Views UpVotes DownVotes", Array[DataType](IntegerType, DateType, StringType, StringType, DateType, StringType, StringType, IntegerType, StringType, IntegerType, IntegerType, IntegerType)) /* ("votes", "/data/stackoverflow/Votes", "Id PostId VoteTypeId UserId CreationDate", Array[DataType](IntegerType, IntegerType, IntegerType, IntegerType, DateType)) */