final commit

This commit is contained in:
Joe Darby
2016-12-19 17:17:55 +00:00
parent 13399dd85f
commit a1ee8e06ab
3 changed files with 6 additions and 30 deletions
+4 -6
View File
@@ -11,13 +11,13 @@ object KMeans {
def train(dataset : DataFrame, iterations:Int) : Unit = { def train(dataset : DataFrame, iterations:Int) : Unit = {
val K = 4 // Number of desired clusters val K = 8 // Number of desired clusters
val relevantData = dataset.select("Reputation", "LastAccessDate") val relevantData = dataset.select("Reputation", "UpVotes", "DownVotes", "CreationDate", "LastAccessDate")
val m = relevantData.columns.length //number of features val m = relevantData.columns.length //number of features
val rows = relevantData.rdd val rows = relevantData.rdd
val rowsAsArray = rows.map(row => convertRow(row, m)).persist() val rowsAsArray = rows.map(row => convertRow(row, m)).persist()
//Initialise the centres by taking a random sample
var centres: Array[Array[Float]] = rowsAsArray.takeSample(false, K, System.nanoTime().toInt) var centres: Array[Array[Float]] = rowsAsArray.takeSample(false, K, System.nanoTime().toInt)
//To reduce chance of two random centres being the same, add a changing value to each //To reduce chance of two random centres being the same, add a changing value to each
println("centres initialised") println("centres initialised")
@@ -30,6 +30,7 @@ object KMeans {
var counts = Array[Int](K) var counts = Array[Int](K)
//Iterate through the clustering algorithm
for (i <- 0 until iterations) { for (i <- 0 until iterations) {
val clusterMap = clustering(centres, rowsAsArray, m, K).persist() val clusterMap = clustering(centres, rowsAsArray, m, K).persist()
centres = getCentres(clusterMap, m, K) centres = getCentres(clusterMap, m, K)
@@ -92,9 +93,6 @@ object KMeans {
} }
def calculateNorm(datapoint : Array[Float], centre : Array[Float], m: Int): Double = { def calculateNorm(datapoint : Array[Float], centre : Array[Float], m: Int): Double = {
var norm : Double = 0.0 var norm : Double = 0.0
for (a <- 0 until m) { for (a <- 0 until m) {
+1 -23
View File
@@ -30,28 +30,6 @@ object Main {
// get the users XML file // get the users XML file
val users = df("users") val users = df("users")
val centres = KMeans.train(users, 25) val centres = KMeans.train(users, 50)
//val centresArray = centres.collect()
//val unwrap = centresArray.map(x => x._2)
//unwrap.foreach(println)
} }
} }
//val users = dataFrames("users")
/*val dataFrames = DataParser.ParseData()
// get the users XML file
val users = dataFrames("users")
users.persist()
// Show 20 entries from the user dataset
users.show()
// Show types for the user dataset
users.printSchema()
users.show()
// create new dataframe with only the reputation of the users
users.select("CreationDate").show()
*/
// Info on using DataFrames here: https://www.mapr.com/blog/using-apache-spark-dataframes-processing-tabular-data
+1 -1
View File
@@ -30,7 +30,7 @@ object XMLParser {
("postHistory", "/data/stackoverflow/PostHistory","Id PostHistoryTypeId PostId RevisionGUID CreationDate UserId UserDisplayName Comment Text CloseReasonId", Array[DataType](IntegerType, IntegerType, IntegerType,IntegerType, DateType, IntegerType, StringType, StringType, StringType, IntegerType)), ("postHistory", "/data/stackoverflow/PostHistory","Id PostHistoryTypeId PostId RevisionGUID CreationDate UserId UserDisplayName Comment Text CloseReasonId", Array[DataType](IntegerType, IntegerType, IntegerType,IntegerType, DateType, IntegerType, StringType, StringType, StringType, IntegerType)),
("postLinks", "data/stackoverflow/PostLinks", "Id CreationDate PostId RelatedPostId PostLinkTypeId", Array[DataType](IntegerType, DateType, IntegerType, IntegerType, IntegerType)), ("postLinks", "data/stackoverflow/PostLinks", "Id CreationDate PostId RelatedPostId PostLinkTypeId", Array[DataType](IntegerType, DateType, IntegerType, IntegerType, IntegerType)),
*/ */
("users", "stackoverflow_dataset/users.txt", "Reputation CreationDate DisplayName EmailHash LastAccessDate WebsiteUrl Location Age AboutMe Views UpVotes DownVotes", Array[DataType](IntegerType, DateType, StringType, StringType, DateType, StringType, StringType, IntegerType, StringType, IntegerType, IntegerType, IntegerType)) ("users", "data/stackoverflow/Users", "Reputation CreationDate DisplayName EmailHash LastAccessDate WebsiteUrl Location Age AboutMe Views UpVotes DownVotes", Array[DataType](IntegerType, DateType, StringType, StringType, DateType, StringType, StringType, IntegerType, StringType, IntegerType, IntegerType, IntegerType))
/* /*
("votes", "/data/stackoverflow/Votes", "Id PostId VoteTypeId UserId CreationDate", Array[DataType](IntegerType, IntegerType, IntegerType, IntegerType, DateType)) ("votes", "/data/stackoverflow/Votes", "Id PostId VoteTypeId UserId CreationDate", Array[DataType](IntegerType, IntegerType, IntegerType, IntegerType, DateType))
*/ */