* @param distanceMeasure distance measure to use
* @return the confusion matrix
*/
public static Matrix getConfusionMatrix(List<? extends Vector> rowCentroids, List<? extends Vector> columnCentroids,
Iterable<? extends Vector> datapoints, DistanceMeasure distanceMeasure) {
Searcher rowSearcher = new BruteSearch(distanceMeasure);
rowSearcher.addAll(rowCentroids);
Searcher columnSearcher = new BruteSearch(distanceMeasure);
columnSearcher.addAll(columnCentroids);
int numRows = rowCentroids.size();
int numCols = columnCentroids.size();
Matrix confusionMatrix = new DenseMatrix(numRows, numCols);
for (Vector vector : datapoints) {
WeightedThing<Vector> closestRowCentroid = rowSearcher.search(vector, 1).get(0);
WeightedThing<Vector> closestColumnCentroid = columnSearcher.search(vector, 1).get(0);
int row = ((Centroid) closestRowCentroid.getValue()).getIndex();
int column = ((Centroid) closestColumnCentroid.getValue()).getIndex();
double vectorWeight;
if (vector instanceof WeightedVector) {
vectorWeight = ((WeightedVector) vector).getWeight();