Package com.sketchingbits.mlearning.clustering

Source Code of com.sketchingbits.mlearning.clustering.KMeans

/**
*
*/
package com.sketchingbits.mlearning.clustering;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.log4j.Logger;
import org.ejml.simple.SimpleMatrix;

public class KMeans {

  private final static Logger LOGGER = Logger.getLogger(KMeansTest.class.getName());
     
  private int noTrainingExamples;
  private int noFeatures;
  private double[][] trainingSet;
  private boolean isClassification;
 
  public class Centroid {
    public double error;
    public double[] features;
    public List<Integer> trainingExamples;
  }
  private List<Centroid> centroids;
 
  /**
   * Private constructor that contains generic functionality
   * @param trainingSet Training examples
   */
  private KMeans(double[][] trainingSet, boolean isClassification) {
    this.trainingSet = trainingSet;   
    this.noTrainingExamples = trainingSet.length;
    this.noFeatures = trainingSet[0].length;
    this.isClassification = isClassification;
  }
 
  /**
   * Constructor that selects a number of random centroids
   * @param dataset Dataset that contains all the samples
   * @param noCentroids Number of centroids to be used
   */
  public KMeans(double[][] trainingSet, boolean isClassification, int noCentroids) {
   
    // Initialize the training set
    this(trainingSet,isClassification);
       
    this.centroids = new ArrayList<Centroid>(noCentroids);

    // Let's pick some initial centroids randomly from the dataset
    // and put them into a Set to make sure to get unique items
    List<Integer> centroidList = new ArrayList<Integer>();
   
    // This is the initial list of potential centroids
    Set<Integer> trainingExamplesSet = new HashSet<Integer>();
    for (int index = 0; index < noTrainingExamples; index++) {
      trainingExamplesSet.add(index);
    }
   
    while (centroidList.size() < noCentroids && trainingExamplesSet.size() > 0) {     
     
      // We pick a centroid and remove it from the Set
      int trainingExample = (int) (Math.random() * ((trainingExamplesSet.size() - 1) + 1));
      trainingExamplesSet.remove(trainingExample);
     
      LOGGER.debug("\tChecking training example "+ trainingExample + " as a cluster.");
      boolean isAllDifferent = true;
      for (Integer centroid: centroidList) {
        boolean isDifferent = false;
        for (int i = 0; i < trainingSet[0].length; i++) {
          if (trainingSet[trainingExample][i] != trainingSet[centroid][i]) {
            isDifferent = true;
            break;
          }
        }
        if (!isDifferent) {
          isAllDifferent = false;
          break;
        }
      }
      if (isAllDifferent) {
        Centroid centroid1 = new Centroid();
        centroid1.features = trainingSet[trainingExample];     
        this.centroids.add(centroid1);
        centroidList.add(trainingExample);
        LOGGER.debug("\tSelected training example "+ trainingExample + " as a cluster.");
      }
    } 
    if (trainingExamplesSet.size() < 1) {
      LOGGER.equals("Cannot selected " + noCentroids + " different centroids. Try to reduce the number of centroids.");
      throw new java.util.InputMismatchException();
    }
  }

  /**
   * Constructor with predetermined centroids
   * @param dataset
   * @param centroids
   */
  public KMeans(double[][] trainingSet, boolean isClassification, int... centroids) {

    // Initialize the training set
    this(trainingSet,isClassification);

    this.centroids = new ArrayList<Centroid>(centroids.length);

    for (int centroid : centroids) {
      Centroid centroid1 = new Centroid();
      centroid1.features = trainingSet[centroid];     
      this.centroids.add(centroid1);
    }
  }
 
  // Clustering using a non optimized loop over the samples
  public List<Centroid> run() {
   
    LOGGER.info("Start the clustering for the training set");
    int noCentroids = centroids.size()
    LOGGER.info("Training examples: " + noTrainingExamples + " - Features: " + noFeatures + " - Clusters: " + noCentroids);

    // We store the values in a SimpleMatrix structure
    SimpleMatrix datasetMatrix = new SimpleMatrix(trainingSet);
   
    // Let's convert the centroids into an array of SimpleMatrix
    SimpleMatrix[] centroidMatrix = new SimpleMatrix[noCentroids];
    for (int index=0; index < noCentroids; index++) {
      double[][] centroidTemp = {centroids.get(index).features};
      centroidMatrix[index] = new SimpleMatrix(centroidTemp);
    }
   
    // Initialize variables to be used in Step 1
    int selectedCentroid = 0;
    double distanceWithCentroid = 0.0;       
    double minDistanceWithCentroid = 0.0;
    SimpleMatrix distanceMatrix = null;
   
    int iteration = 0;
    boolean isClusteringDone =false;
    while (!isClusteringDone) {
     
      LOGGER.debug("\tIteration " + iteration++);
     
      // Clear the centroid assignment     
      for (Centroid centroid : centroids) {
        if (centroid.trainingExamples != null)
          centroid.trainingExamples.clear();
        else
          centroid.trainingExamples = new ArrayList<Integer>();
      }   
     
      // Step 1: Calculate the distance between each value and the centroids
      // and assign each sample to the closest centroid
      for (int sample = 0; sample < noTrainingExamples; sample++) {
       
        SimpleMatrix vector = datasetMatrix.extractVector(true, sample);
       
        for (int centroid = 0; centroid < noCentroids; centroid++) {
          distanceMatrix = vector.minus(centroidMatrix[centroid]);
          distanceMatrix = distanceMatrix.elementMult(distanceMatrix);
          distanceWithCentroid = distanceMatrix.elementSum() / noFeatures;
         
          if (centroid == 0) {
            selectedCentroid = 0;
            minDistanceWithCentroid = distanceWithCentroid;
          } else {
            if (distanceWithCentroid < minDistanceWithCentroid) {
              minDistanceWithCentroid = distanceWithCentroid;
              selectedCentroid = centroid;
            }
          }
        }       
        centroids.get(selectedCentroid).trainingExamples.add(new Integer(sample));
      }
     
      // Step 2: We move the centroid to the average of all samples
      // assigned to it         
      isClusteringDone = true;
      for (int centroid = 0; centroid < noCentroids; centroid++) {
       
        SimpleMatrix averageVector = new SimpleMatrix(1,noFeatures);
        for (Integer sample : (ArrayList<Integer>) centroids.get(centroid).trainingExamples) {
          averageVector = averageVector.plus(datasetMatrix.extractVector(true, sample));
        }
        averageVector = averageVector.divide(centroids.get(centroid).trainingExamples.size());
       
        if (isClassification) {
          for( int i = 0; i < averageVector.getNumElements(); i++ )
            averageVector.set(i,Math.round(averageVector.get(i)));
        }
       
        if (!centroidMatrix[centroid].isIdentical(averageVector,0)) {
          centroidMatrix[centroid] = averageVector;         
          isClusteringDone = false;
        }
      }
    }
   
    double[][] errorMatrix = new double[noTrainingExamples][noFeatures];
    for (int centroid = 0; centroid < noCentroids; centroid++) {
      for (int feature = 0; feature < noFeatures; feature++) {
        centroids.get(centroid).features[feature] = centroidMatrix[centroid].get(feature);
      }
      for (int trainingExample = 0; trainingExample < centroids.get(centroid).trainingExamples.size(); trainingExample++) {
        errorMatrix[(centroids.get(centroid).trainingExamples.get(trainingExample))] = centroids.get(centroid).features;
      }
    }   
    SimpleMatrix errorSimpleMatrix = new SimpleMatrix(errorMatrix).minus(datasetMatrix);
    errorSimpleMatrix = errorSimpleMatrix.mult(errorSimpleMatrix.transpose());
    double error = errorSimpleMatrix.elementSum() / (noFeatures * noTrainingExamples);
   
    LOGGER.info("Clustering completed - Error: " + Math.floor(error * 100) + "%");     

    return centroids;
  }
}
TOP

Related Classes of com.sketchingbits.mlearning.clustering.KMeans

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.