package org.gd.spark.opendl.example;
import java.io.BufferedReader;
import java.io.FileReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.gd.spark.opendl.downpourSGD.SampleVector;
public class DataInput {
    /** Shared RNG used for the random train/test split; seeded once at class load. */
    private static final Random rand = new Random(System.currentTimeMillis());

    /**
     * Reads sample data from a mnist_784-style CSV text file.
     * <p>
     * Each line is expected to be {@code label,pixel1,pixel2,...}: the first column is
     * the class label (used as an index to one-hot encode the Y vector), and the
     * remaining columns are the feature values.
     *
     * @param path path to the CSV file
     * @param x_feature number of input features per sample (e.g. 784 for MNIST)
     * @param y_feature number of output classes per sample (e.g. 10 for MNIST)
     * @return list of parsed samples, one per non-empty line
     * @throws Exception if the file cannot be read or a line fails to parse
     */
    public static List<SampleVector> readMnist(String path, int x_feature, int y_feature) throws Exception {
        List<SampleVector> ret = new ArrayList<SampleVector>();
        // try-with-resources: the previous version leaked the reader whenever a
        // parse error was thrown before reaching br.close(). Explicit UTF-8 avoids
        // platform-default-charset surprises (the data is ASCII, so this is safe).
        try (BufferedReader br = Files.newBufferedReader(Paths.get(path), StandardCharsets.UTF_8)) {
            String str;
            while (null != (str = br.readLine())) {
                String[] splits = str.split(",");
                SampleVector xy = new SampleVector(x_feature, y_feature);
                // First column is the class label: one-hot encode it into Y.
                xy.getY()[Integer.parseInt(splits[0])] = 1;
                // Remaining columns are the raw feature values for X.
                for (int i = 1; i < splits.length; i++) {
                    xy.getX()[i - 1] = Double.parseDouble(splits[i]);
                }
                ret.add(xy);
            }
        }
        return ret;
    }

    /**
     * Parallelizes an in-memory sample list into a Spark RDD.
     *
     * @param context active Spark context used to create the RDD
     * @param list samples to distribute
     * @return an RDD backed by the given list
     * @throws Exception if Spark fails to parallelize the collection
     */
    public static JavaRDD<SampleVector> toRDD(JavaSparkContext context, List<SampleVector> list) throws Exception {
        return context.parallelize(list);
    }

    /**
     * Randomly splits the total sample list into train and test parts.
     * <p>
     * Each sample independently lands in {@code trainList} with probability
     * {@code trainRatio}, otherwise in {@code testList}; both output lists are
     * appended to (not cleared). The split is therefore approximate, not exact.
     *
     * @param totalList all samples read from file
     * @param trainList output list receiving the training portion
     * @param testList output list receiving the test portion
     * @param trainRatio probability in [0, 1] that a sample goes to the train list
     */
    public static void splitList(List<SampleVector> totalList, List<SampleVector> trainList, List<SampleVector> testList, double trainRatio) {
        for (SampleVector sample : totalList) {
            if (rand.nextDouble() <= trainRatio) {
                trainList.add(sample);
            } else {
                testList.add(sample);
            }
        }
    }
}