Package de.jungblut.reader

Source Code of de.jungblut.reader.IrisReader

package de.jungblut.reader;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;

import com.google.common.base.Preconditions;

import de.jungblut.classification.eval.EvaluationSplit;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;

/**
* Dataset vectorizer for the iris dataset. Following outcome indices are used:
* Iris-setosa = 0, Iris-versicolor = 1, Iris-virginica = 2.
*/
public final class IrisReader {

  private IrisReader() {
    throw new IllegalAccessError();
  }

  /**
   * @return a tuple, on first dimension are the features, on the second are the
   *         outcomes (0 or 1 in the first element of a vector)
   */
  public static Dataset readIrisDataset(String path) {
    DoubleVector[] features = new DoubleVector[150];
    DenseDoubleVector[] outcome = new DenseDoubleVector[150];

    try (BufferedReader br = new BufferedReader(new FileReader(path))) {
      String line = null;
      int index = 0;
      while ((line = br.readLine()) != null) {
        String[] split = line.split(",");
        Preconditions.checkArgument(split.length == 5,
            "CSV length was not 5! Given " + split.length);
        features[index] = new DenseDoubleVector(4);
        for (int i = 0; i < split.length - 1; i++) {
          features[index].set(i, Double.parseDouble(split[i]));
        }
        if (split[split.length - 1].equals("Iris-setosa")) {
          outcome[index] = new DenseDoubleVector(new double[] { 1, 0, 0 });
        } else if (split[split.length - 1].equals("Iris-versicolor")) {
          outcome[index] = new DenseDoubleVector(new double[] { 0, 1, 0 });
        } else {
          outcome[index] = new DenseDoubleVector(new double[] { 0, 0, 1 });
        }
        index++;
      }
    } catch (IOException e) {
      e.printStackTrace();
    }

    return new Dataset(features, outcome);
  }

  /**
   * @return a split that has 40 train examples for every class and 10 test
   *         items.
   */
  public static EvaluationSplit getEvaluationSplit(Dataset irisData) {

    DoubleVector[] trainFeatures = new DenseDoubleVector[3 * 40];
    DoubleVector[] trainOutcomes = new DenseDoubleVector[3 * 40];
    DoubleVector[] testFeatures = new DenseDoubleVector[3 * 10];
    DoubleVector[] testOutcomes = new DenseDoubleVector[3 * 10];

    DoubleVector[] features = irisData.getFeatures();
    DoubleVector[] outcomes = irisData.getOutcomes();

    int[] trainSplitPoints = new int[] { 40, 90, 140 };
    int[] trainDestStartPoints = new int[] { 0, 40, 80 };
    int[] testDestStartPoints = new int[] { 0, 10, 20 };

    for (int i = 0; i < 3; i++) {
      int splitPoint = trainSplitPoints[i];
      System.arraycopy(features, splitPoint - 40, trainFeatures,
          trainDestStartPoints[i], 40);
      System.arraycopy(outcomes, splitPoint - 40, trainOutcomes,
          trainDestStartPoints[i], 40);

      System.arraycopy(features, splitPoint, testFeatures,
          testDestStartPoints[i], 10);
      System.arraycopy(outcomes, splitPoint, testOutcomes,
          testDestStartPoints[i], 10);
    }

    return new EvaluationSplit(trainFeatures, trainOutcomes, testFeatures,
        testOutcomes);
  }

}
TOP

Related Classes of de.jungblut.reader.IrisReader

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.