/**
*
*/
package com.sketchingbits.mlearning.clustering;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.log4j.Logger;
import org.ejml.simple.SimpleMatrix;
/**
 * K-means clustering over a training set given as {@code double[][]}
 * (one row per training example, one column per feature).
 *
 * <p>Initial centroids are either drawn at random from the training examples or
 * given explicitly as row indices. {@link #run()} then alternates two steps until
 * no centroid changes position:
 * <ol>
 *   <li>assign every training example to the centroid with the smallest squared
 *       Euclidean distance;</li>
 *   <li>move every centroid to the mean of the examples assigned to it (rounded to
 *       the nearest integer when {@code isClassification} is true).</li>
 * </ol>
 */
public class KMeans {
    private static final Logger LOGGER = Logger.getLogger(KMeans.class.getName());

    private final int noTrainingExamples;
    private final int noFeatures;
    private final double[][] trainingSet;
    private final boolean isClassification;

    /**
     * A cluster centre together with the training examples assigned to it by the
     * most recent call to {@link KMeans#run()}.
     */
    public class Centroid {
        /** Not written by {@link KMeans}; stays at its default value of 0.0. */
        public double error;
        /** Coordinates of the centroid, one entry per feature. */
        public double[] features;
        /** Row indices of the assigned training examples; {@code null} until {@link KMeans#run()} runs. */
        public List<Integer> trainingExamples;
    }

    private List<Centroid> centroids;

    /**
     * Shared initialisation for the public constructors.
     *
     * @param trainingSet      training examples, one row per example; the feature
     *                         count is taken from the first row
     * @param isClassification if true, centroid coordinates are rounded to the
     *                         nearest integer after every update in {@link #run()}
     * @throws IllegalArgumentException if {@code trainingSet} has no rows
     */
    private KMeans(double[][] trainingSet, boolean isClassification) {
        if (trainingSet.length == 0) {
            throw new IllegalArgumentException("trainingSet must contain at least one training example");
        }
        this.trainingSet = trainingSet;
        this.noTrainingExamples = trainingSet.length;
        this.noFeatures = trainingSet[0].length;
        this.isClassification = isClassification;
    }

    /**
     * Creates a clustering whose initial centroids are {@code noCentroids} training
     * examples picked at random. Examples whose features are identical to an
     * already-picked centroid are skipped, so the initial centroids are distinct.
     *
     * @param trainingSet      training examples, one row per example
     * @param isClassification if true, centroid coordinates are rounded to integers in {@link #run()}
     * @param noCentroids      number of centroids to use
     * @throws IllegalArgumentException if {@code noCentroids < 1} or {@code trainingSet} is empty
     * @throws java.util.InputMismatchException if the training set does not contain
     *         {@code noCentroids} distinct examples
     */
    public KMeans(double[][] trainingSet, boolean isClassification, int noCentroids) {
        this(trainingSet, isClassification);
        if (noCentroids < 1) {
            throw new IllegalArgumentException("noCentroids must be at least 1, got " + noCentroids);
        }
        this.centroids = new ArrayList<Centroid>(noCentroids);

        // Row indices not yet examined. Removing a random *position* from this list
        // examines every row at most once, so the loop ends after at most
        // noTrainingExamples draws.
        List<Integer> candidates = new ArrayList<Integer>(noTrainingExamples);
        for (int index = 0; index < noTrainingExamples; index++) {
            candidates.add(index);
        }
        // Row indices already chosen as centroids.
        List<Integer> chosen = new ArrayList<Integer>(noCentroids);
        while (chosen.size() < noCentroids && !candidates.isEmpty()) {
            // The argument is an int, so this is List.remove(int index): it removes
            // and returns the row index stored at that position.
            int trainingExample = candidates.remove((int) (Math.random() * candidates.size()));
            LOGGER.debug("\tChecking training example "+ trainingExample + " as a cluster.");
            if (!isDuplicateOfAny(trainingExample, chosen)) {
                addCentroid(trainingExample);
                chosen.add(trainingExample);
                LOGGER.debug("\tSelected training example "+ trainingExample + " as a cluster.");
            }
        }
        if (chosen.size() < noCentroids) {
            String message = "Cannot select " + noCentroids + " different centroids. Try to reduce the number of centroids.";
            LOGGER.error(message);
            throw new java.util.InputMismatchException(message);
        }
    }

    /**
     * Creates a clustering whose initial centroids are the given training examples.
     *
     * @param trainingSet      training examples, one row per example
     * @param isClassification if true, centroid coordinates are rounded to integers in {@link #run()}
     * @param centroids        row indices into {@code trainingSet} used as initial centroids
     * @throws IllegalArgumentException if no centroid index is given or {@code trainingSet} is empty
     * @throws ArrayIndexOutOfBoundsException if an index is outside {@code trainingSet}
     */
    public KMeans(double[][] trainingSet, boolean isClassification, int... centroids) {
        this(trainingSet, isClassification);
        if (centroids == null || centroids.length == 0) {
            throw new IllegalArgumentException("At least one centroid index is required");
        }
        this.centroids = new ArrayList<Centroid>(centroids.length);
        for (int row : centroids) {
            addCentroid(row);
        }
    }

    /**
     * Returns true if training row {@code candidate} has exactly the same feature
     * values as any of the training rows listed in {@code rows}.
     */
    private boolean isDuplicateOfAny(int candidate, List<Integer> rows) {
        for (int row : rows) {
            boolean identical = true;
            for (int i = 0; i < noFeatures; i++) {
                if (trainingSet[candidate][i] != trainingSet[row][i]) {
                    identical = false;
                    break;
                }
            }
            if (identical) {
                return true;
            }
        }
        return false;
    }

    /** Appends a new centroid initialised with a copy of training row {@code row}. */
    private void addCentroid(int row) {
        Centroid centroid = new Centroid();
        // Copy the row: run() writes the final coordinates into features, and
        // sharing the array would overwrite the caller's training data.
        centroid.features = trainingSet[row].clone();
        centroids.add(centroid);
    }

    /**
     * Runs the clustering until no centroid changes position, then writes the final
     * coordinates into each centroid's {@code features} and the final assignment
     * into its {@code trainingExamples}.
     *
     * <p>There is no iteration cap: the loop stops only when every centroid is
     * exactly (tolerance 0) unchanged by an update step.
     *
     * @return the list of centroids held by this instance (not a copy)
     */
    public List<Centroid> run() {
        LOGGER.info("Start the clustering for the training set");
        int noCentroids = centroids.size();
        LOGGER.info("Training examples: " + noTrainingExamples + " - Features: " + noFeatures + " - Clusters: " + noCentroids);

        SimpleMatrix datasetMatrix = new SimpleMatrix(trainingSet);
        // One 1 x noFeatures row matrix per centroid.
        SimpleMatrix[] centroidMatrix = new SimpleMatrix[noCentroids];
        for (int index = 0; index < noCentroids; index++) {
            double[][] centroidTemp = {centroids.get(index).features};
            centroidMatrix[index] = new SimpleMatrix(centroidTemp);
        }

        int iteration = 0;
        boolean isClusteringDone = false;
        while (!isClusteringDone) {
            LOGGER.debug("\tIteration " + iteration++);
            assignSamples(datasetMatrix, centroidMatrix);
            isClusteringDone = !moveCentroids(datasetMatrix, centroidMatrix);
        }

        for (int c = 0; c < noCentroids; c++) {
            double[] features = centroids.get(c).features;
            for (int feature = 0; feature < noFeatures; feature++) {
                features[feature] = centroidMatrix[c].get(feature);
            }
        }

        double error = meanSquaredError(datasetMatrix);
        LOGGER.info("Clustering completed - Error: " + Math.floor(error * 100) + "%");
        return centroids;
    }

    /**
     * Step 1: clears every centroid's assignment list, then appends each training
     * example to the centroid with the smallest squared Euclidean distance (ties go
     * to the lower centroid index).
     */
    private void assignSamples(SimpleMatrix datasetMatrix, SimpleMatrix[] centroidMatrix) {
        for (Centroid centroid : centroids) {
            if (centroid.trainingExamples != null) {
                centroid.trainingExamples.clear();
            } else {
                centroid.trainingExamples = new ArrayList<Integer>();
            }
        }
        for (int sample = 0; sample < noTrainingExamples; sample++) {
            SimpleMatrix vector = datasetMatrix.extractVector(true, sample);
            int selectedCentroid = 0;
            double minDistance = Double.POSITIVE_INFINITY;
            for (int c = 0; c < centroidMatrix.length; c++) {
                SimpleMatrix diff = vector.minus(centroidMatrix[c]);
                double distance = diff.elementMult(diff).elementSum();
                if (distance < minDistance) {
                    minDistance = distance;
                    selectedCentroid = c;
                }
            }
            centroids.get(selectedCentroid).trainingExamples.add(Integer.valueOf(sample));
        }
    }

    /**
     * Step 2: moves each centroid to the mean of the examples assigned to it,
     * rounded to the nearest integer when {@code isClassification} is true.
     * A centroid with no assigned examples keeps its current position.
     *
     * @return true if at least one centroid changed position
     */
    private boolean moveCentroids(SimpleMatrix datasetMatrix, SimpleMatrix[] centroidMatrix) {
        boolean moved = false;
        for (int c = 0; c < centroidMatrix.length; c++) {
            List<Integer> assigned = centroids.get(c).trainingExamples;
            if (assigned.isEmpty()) {
                // Averaging zero examples would divide by zero and turn the centroid into NaN.
                continue;
            }
            SimpleMatrix averageVector = new SimpleMatrix(1, noFeatures);
            for (Integer sample : assigned) {
                averageVector = averageVector.plus(datasetMatrix.extractVector(true, sample));
            }
            averageVector = averageVector.divide(assigned.size());
            if (isClassification) {
                for (int i = 0; i < averageVector.getNumElements(); i++) {
                    averageVector.set(i, Math.round(averageVector.get(i)));
                }
            }
            if (!centroidMatrix[c].isIdentical(averageVector, 0)) {
                centroidMatrix[c] = averageVector;
                moved = true;
            }
        }
        return moved;
    }

    /**
     * Returns the sum of squared differences between every training example and its
     * assigned centroid, divided by {@code noFeatures * noTrainingExamples}.
     * Uses the centroids' {@code features} and {@code trainingExamples}.
     */
    private double meanSquaredError(SimpleMatrix datasetMatrix) {
        // Row i holds the coordinates of the centroid that training example i is assigned to.
        double[][] assignedCentroids = new double[noTrainingExamples][];
        for (Centroid centroid : centroids) {
            for (Integer sample : centroid.trainingExamples) {
                assignedCentroids[sample] = centroid.features;
            }
        }
        SimpleMatrix residual = new SimpleMatrix(assignedCentroids).minus(datasetMatrix);
        // Element-wise square, not residual * residual^T: the latter sums the
        // cross-products of every pair of rows and allocates an n x n matrix.
        return residual.elementMult(residual).elementSum() / ((double) noFeatures * noTrainingExamples);
    }
}