Package de.jungblut.classification.knn

Source Code of de.jungblut.classification.knn.AbstractKNearestNeighbours

package de.jungblut.classification.knn;

import java.util.List;

import de.jungblut.classification.AbstractClassifier;
import de.jungblut.datastructure.KDTree.VectorDistanceTuple;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.dense.SingleEntryDoubleVector;

/**
* K nearest neighbour classification algorithm that is seeded with a "database"
* of known examples and predicts based on the k-nearest neighbours majority
* vote for a class.
*
*/
public abstract class AbstractKNearestNeighbours extends AbstractClassifier {

  protected final int numOutcomes;
  protected final int k;

  /**
   * Constructs a new knn classifier.
   *
   * @param numOutcomes the number of different outcomes that can be predicted.
   * @param k the number of neighbours to analyse to get a prediction (it does
   *          so by majority voting).
   */
  public AbstractKNearestNeighbours(int numOutcomes, int k) {
    this.numOutcomes = numOutcomes;
    this.k = k;
  }

  /**
   * @return If the number of outcomes is 2 (binary prediction) the returned
   *         vector contains the class id (0 or 1) at the first index. If not, a
   *         histogram of the classes that were predicted.
   */
  @Override
  public DoubleVector predict(DoubleVector features) {
    List<VectorDistanceTuple<DoubleVector>> nearestNeighbours = getNearestNeighbours(
        features, k);

    DenseDoubleVector outcomeHistogram = new DenseDoubleVector(numOutcomes);
    for (VectorDistanceTuple<DoubleVector> tuple : nearestNeighbours) {
      int classIndex = 0;
      if (numOutcomes == 2) {
        classIndex = (int) tuple.getValue().get(0);
      } else {
        classIndex = tuple.getValue().maxIndex();
      }

      outcomeHistogram.set(classIndex, outcomeHistogram.get(classIndex) + 1);
    }
    if (numOutcomes == 2) {
      return new SingleEntryDoubleVector(outcomeHistogram.maxIndex());
    } else {
      return outcomeHistogram;
    }
  }

  @Override
  public DoubleVector predictProbability(DoubleVector features) {
    DoubleVector prediction = predict(features);
    if (numOutcomes != 2) {
      prediction = prediction.divide(prediction.sum());
    }
    return prediction;
  }

  /**
   * Find the k nearest neighbours for the given feature.
   *
   * @param feature the feature to find neighbours for.
   * @param k the number of neighbours to find.
   * @return a list of {@link VectorDistanceTuple}'s that contain the outcome of
   *         the retrieved vectors.
   */
  protected abstract List<VectorDistanceTuple<DoubleVector>> getNearestNeighbours(
      DoubleVector feature, int k);

}
TOP

Related Classes of de.jungblut.classification.knn.AbstractKNearestNeighbours

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.