final commit
This commit is contained in:
@@ -11,13 +11,13 @@ object KMeans {
|
|||||||
|
|
||||||
|
|
||||||
def train(dataset : DataFrame, iterations:Int) : Unit = {
|
def train(dataset : DataFrame, iterations:Int) : Unit = {
|
||||||
val K = 4 // Number of desired clusters
|
val K = 8 // Number of desired clusters
|
||||||
val relevantData = dataset.select("Reputation", "LastAccessDate")
|
val relevantData = dataset.select("Reputation", "UpVotes", "DownVotes", "CreationDate", "LastAccessDate")
|
||||||
val m = relevantData.columns.length //number of features
|
val m = relevantData.columns.length //number of features
|
||||||
val rows = relevantData.rdd
|
val rows = relevantData.rdd
|
||||||
val rowsAsArray = rows.map(row => convertRow(row, m)).persist()
|
val rowsAsArray = rows.map(row => convertRow(row, m)).persist()
|
||||||
|
|
||||||
|
//Initialise the centres by taking a random sample
|
||||||
var centres: Array[Array[Float]] = rowsAsArray.takeSample(false, K, System.nanoTime().toInt)
|
var centres: Array[Array[Float]] = rowsAsArray.takeSample(false, K, System.nanoTime().toInt)
|
||||||
//To reduce chance of two random centres being the same, add a changing value to each
|
//To reduce chance of two random centres being the same, add a changing value to each
|
||||||
println("centres initialised")
|
println("centres initialised")
|
||||||
@@ -30,6 +30,7 @@ object KMeans {
|
|||||||
|
|
||||||
var counts = Array[Int](K)
|
var counts = Array[Int](K)
|
||||||
|
|
||||||
|
//Iterate through the clustering algorithm
|
||||||
for (i <- 0 until iterations) {
|
for (i <- 0 until iterations) {
|
||||||
val clusterMap = clustering(centres, rowsAsArray, m, K).persist()
|
val clusterMap = clustering(centres, rowsAsArray, m, K).persist()
|
||||||
centres = getCentres(clusterMap, m, K)
|
centres = getCentres(clusterMap, m, K)
|
||||||
@@ -92,9 +93,6 @@ object KMeans {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def calculateNorm(datapoint : Array[Float], centre : Array[Float], m: Int): Double = {
|
def calculateNorm(datapoint : Array[Float], centre : Array[Float], m: Int): Double = {
|
||||||
var norm : Double = 0.0
|
var norm : Double = 0.0
|
||||||
for (a <- 0 until m) {
|
for (a <- 0 until m) {
|
||||||
|
|||||||
@@ -30,28 +30,6 @@ object Main {
|
|||||||
// get the users XML file
|
// get the users XML file
|
||||||
|
|
||||||
val users = df("users")
|
val users = df("users")
|
||||||
val centres = KMeans.train(users, 25)
|
val centres = KMeans.train(users, 50)
|
||||||
//val centresArray = centres.collect()
|
|
||||||
//val unwrap = centresArray.map(x => x._2)
|
|
||||||
//unwrap.foreach(println)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//val users = dataFrames("users")
|
|
||||||
|
|
||||||
/*val dataFrames = DataParser.ParseData()
|
|
||||||
|
|
||||||
// get the users XML file
|
|
||||||
val users = dataFrames("users")
|
|
||||||
users.persist()
|
|
||||||
// Show 20 entries from the user dataset
|
|
||||||
users.show()
|
|
||||||
// Show types for the user dataset
|
|
||||||
users.printSchema()
|
|
||||||
users.show()
|
|
||||||
|
|
||||||
// create new dataframe with only the reputation of the users
|
|
||||||
users.select("CreationDate").show()
|
|
||||||
*/
|
|
||||||
// Info on using DataFrames here: https://www.mapr.com/blog/using-apache-spark-dataframes-processing-tabular-data
|
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ object XMLParser {
|
|||||||
("postHistory", "/data/stackoverflow/PostHistory","Id PostHistoryTypeId PostId RevisionGUID CreationDate UserId UserDisplayName Comment Text CloseReasonId", Array[DataType](IntegerType, IntegerType, IntegerType,IntegerType, DateType, IntegerType, StringType, StringType, StringType, IntegerType)),
|
("postHistory", "/data/stackoverflow/PostHistory","Id PostHistoryTypeId PostId RevisionGUID CreationDate UserId UserDisplayName Comment Text CloseReasonId", Array[DataType](IntegerType, IntegerType, IntegerType,IntegerType, DateType, IntegerType, StringType, StringType, StringType, IntegerType)),
|
||||||
("postLinks", "data/stackoverflow/PostLinks", "Id CreationDate PostId RelatedPostId PostLinkTypeId", Array[DataType](IntegerType, DateType, IntegerType, IntegerType, IntegerType)),
|
("postLinks", "data/stackoverflow/PostLinks", "Id CreationDate PostId RelatedPostId PostLinkTypeId", Array[DataType](IntegerType, DateType, IntegerType, IntegerType, IntegerType)),
|
||||||
*/
|
*/
|
||||||
("users", "stackoverflow_dataset/users.txt", "Reputation CreationDate DisplayName EmailHash LastAccessDate WebsiteUrl Location Age AboutMe Views UpVotes DownVotes", Array[DataType](IntegerType, DateType, StringType, StringType, DateType, StringType, StringType, IntegerType, StringType, IntegerType, IntegerType, IntegerType))
|
("users", "data/stackoverflow/Users", "Reputation CreationDate DisplayName EmailHash LastAccessDate WebsiteUrl Location Age AboutMe Views UpVotes DownVotes", Array[DataType](IntegerType, DateType, StringType, StringType, DateType, StringType, StringType, IntegerType, StringType, IntegerType, IntegerType, IntegerType))
|
||||||
/*
|
/*
|
||||||
("votes", "/data/stackoverflow/Votes", "Id PostId VoteTypeId UserId CreationDate", Array[DataType](IntegerType, IntegerType, IntegerType, IntegerType, DateType))
|
("votes", "/data/stackoverflow/Votes", "Id PostId VoteTypeId UserId CreationDate", Array[DataType](IntegerType, IntegerType, IntegerType, IntegerType, DateType))
|
||||||
*/
|
*/
|
||||||
|
|||||||
Reference in New Issue
Block a user