Package de.jungblut.classification.knn

Source Code of de.jungblut.classification.knn.SparseKNearestNeighboursTest

package de.jungblut.classification.knn;

import static org.junit.Assert.assertEquals;

import org.junit.Test;

import de.jungblut.distance.CosineDistance;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.SingleEntryDoubleVector;
import de.jungblut.math.sparse.SparseDoubleVector;

public class SparseKNearestNeighboursTest {

  @Test
  public void testSparseKNN() {

    SparseKNearestNeighbours neighbours = new SparseKNearestNeighbours(2, 2,
        new CosineDistance());

    // we seperate stuff in two dimensions each
    DoubleVector left = new SingleEntryDoubleVector(0d);
    DoubleVector right = new SingleEntryDoubleVector(1d);
    DoubleVector v1 = new SparseDoubleVector(4);
    v1.set(0, 1d);
    v1.set(1, 1d);

    DoubleVector v2 = new SparseDoubleVector(4);
    v2.set(2, 1d);
    v2.set(3, 2.5);

    DoubleVector v3 = new SparseDoubleVector(4);
    v3.set(0, 2d);
    v3.set(1, 2d);

    DoubleVector v4 = new SparseDoubleVector(4);
    v4.set(2, 0.5);
    v4.set(3, 1.5);

    DoubleVector[] trainingSet = new DoubleVector[] { v1, v2, v3, v4 };
    DoubleVector[] outcomeSet = new DoubleVector[] { left, right, left, right };

    neighbours.train(trainingSet, outcomeSet);

    DoubleVector predict = neighbours.predict(v4);
    assertEquals(right, predict);

    predict = neighbours.predict(v2);
    assertEquals(right, predict);

    predict = neighbours.predict(v1);
    assertEquals(left, predict);

    predict = neighbours.predict(v3);
    assertEquals(left, predict);

    // predict between, slightly to the right
    DoubleVector vx = new SparseDoubleVector(4);
    vx.set(1, 1d);
    vx.set(3, 2.5);

    predict = neighbours.predict(vx);
    assertEquals(right, predict);
  }

}
TOP

Related Classes of de.jungblut.classification.knn.SparseKNearestNeighboursTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.