Package de.jungblut.reader

Source Code of de.jungblut.reader.MNISTReader

package de.jungblut.reader;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;

/**
* MNIST CSV reader from kaggle: www.kaggle.com/c/digit-recognizer/
*
* @author thomas.jungblut
*
*/
public final class MNISTReader {

  private MNISTReader() {
    throw new IllegalAccessError();
  }

  public static Dataset readMNISTTrainImages(String path) {
    List<DoubleVector> features = new ArrayList<>();
    List<DenseDoubleVector> prediction = new ArrayList<>();
    String line = null;
    int numLine = 0;
    try (BufferedReader br = new BufferedReader(new FileReader(path))) {

      while ((line = br.readLine()) != null) {
        if (numLine == 0) {
          numLine++;
          continue;
        }

        String[] split = line.split(",");
        DenseDoubleVector featureVector = new DenseDoubleVector(
            split.length - 1);
        for (int i = 1; i < split.length; i++) {
          featureVector.set(i - 1, Integer.parseInt(split[i]));
        }

        DenseDoubleVector predVector = new DenseDoubleVector(10);
        predVector.set(Integer.parseInt(split[0]), 1.0d);

        features.add(featureVector);
        prediction.add(predVector);
        numLine++;
      }
    } catch (IOException e) {
      e.printStackTrace();
    }

    return new Dataset(features.toArray(new DoubleVector[features.size()]),
        prediction.toArray(new DenseDoubleVector[features.size()]));
  }
}
TOP

Related Classes of de.jungblut.reader.MNISTReader

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.