Package org.gd.spark.opendl.example

Source Code of org.gd.spark.opendl.example.DataInput

package org.gd.spark.opendl.example;

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.gd.spark.opendl.downpourSGD.SampleVector;

public class DataInput {
  private static Random rand = new Random(System.currentTimeMillis());
 
  /**
   * Read sample data from mnist_784 text file
   * @param path
   * @return
   * @throws Exception
   */
  public static List<SampleVector> readMnist(String path, int x_feature, int y_feature) throws Exception {
    List<SampleVector> ret = new ArrayList<SampleVector>();
    String str = null;
        BufferedReader br = new BufferedReader(new FileReader(path));
        while (null != (str = br.readLine())) {
            String[] splits = str.split(",");
            SampleVector xy = new SampleVector(x_feature, y_feature);
            xy.getY()[Integer.valueOf(splits[0])] = 1;
            for (int i = 1; i < splits.length; i++) {
                xy.getX()[i - 1] = Double.valueOf(splits[i]);
            }
            ret.add(xy);
        }
        br.close();
    return ret;
  }

  /**
   * Parallelize list to RDD
   * @param context
   * @param list
   * @return
   * @throws Exception
   */
  public static JavaRDD<SampleVector> toRDD(JavaSparkContext context, List<SampleVector> list) throws Exception {
    return context.parallelize(list);
  }
 
  /**
   * Split total list read from file to train and test part
   * @param totalList
   * @param trainList
   * @param testList
   * @param trainRatio
   */
  public static void splitList(List<SampleVector> totalList, List<SampleVector> trainList, List<SampleVector> testList, double trainRatio) {
    for (SampleVector sample : totalList) {
      if (rand.nextDouble() <= trainRatio) {
        trainList.add(sample);
      }
      else {
        testList.add(sample);
      }
    }
  }
}
TOP

Related Classes of org.gd.spark.opendl.example.DataInput

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.