Package de.jungblut.clustering

Source Code of de.jungblut.clustering.KMeansClusteringTest

package de.jungblut.clustering;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.util.ArrayList;
import java.util.List;

import org.junit.Test;

import de.jungblut.distance.EuclidianDistance;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;

public class KMeansClusteringTest {

  @Test
  public void testKMeansClustering() {
    ArrayList<DoubleVector> lst = getClusteringInput();

    KMeansClustering clusterer = new KMeansClustering(2, lst, false);
    EuclidianDistance dist = new EuclidianDistance();
    List<Cluster> assignments = clusterer.cluster(100, dist, 0.1d, false);
    DoubleVector[] centers = clusterer.getCenters();
    assertEquals(49.5, centers[0].get(0), 1e-4);
    assertEquals(24.5, centers[0].get(1), 1e-4);
    assertEquals(49.5, centers[1].get(0), 1e-4);
    assertEquals(74.5, centers[1].get(1), 1e-4);

    // the centers should partition the space in half
    assertEquals(2, assignments.size());
    assertEquals(5000, assignments.get(0).getAssignments().size());
    assertEquals(5000, assignments.get(1).getAssignments().size());
    // now verify the assignments
    for (DoubleVector v : assignments.get(0).getAssignments()) {
      double distRightCenter = dist.measureDistance(v, centers[0]);
      double distOtherCenter = dist.measureDistance(v, centers[1]);
      assertTrue(distRightCenter < distOtherCenter);
    }
    for (DoubleVector v : assignments.get(1).getAssignments()) {
      double distRightCenter = dist.measureDistance(v, centers[1]);
      double distOtherCenter = dist.measureDistance(v, centers[0]);
      assertTrue(distRightCenter < distOtherCenter);
    }
  }

  public static ArrayList<DoubleVector> getClusteringInput() {
    // we are "sampling" a 100x100 grid to a vector space and let's do some
    // clustering.
    return getClusteringInput(100, 100);
  }

  public static ArrayList<DoubleVector> getClusteringInput(int x, int y) {
    ArrayList<DoubleVector> lst = new ArrayList<>();
    for (int i = 0; i < x; i++) {
      for (int j = 0; j < y; j++) {
        lst.add(new DenseDoubleVector(new double[] { i, j }));
      }
    }
    return lst;
  }

}
TOP

Related Classes of de.jungblut.clustering.KMeansClusteringTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by ORACLE Inc. Contact coftware#gmail.com.