Package tv.floe.metronome.deeplearning.datasets

Source Code of tv.floe.metronome.deeplearning.datasets.DataSet

package tv.floe.metronome.deeplearning.datasets;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;

import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import tv.floe.metronome.math.MatrixUtils;
import tv.floe.metronome.types.Pair;

import com.google.common.collect.Lists;

/**
* A data set (example/outcome pairs)
* The outcomes are specifically for neural network encoding such that
* any labels that are considered true are 1s. The rest are zeros.
* @author Adam Gibson
*
*/
public class DataSet extends Pair<Matrix, Matrix> {

  private static final long serialVersionUID = 1935520764586513365L;
  private static Logger log = LoggerFactory.getLogger(DataSet.class);

  public DataSet(Pair<Matrix,Matrix> pair) {
    this(pair.getFirst(),pair.getSecond());
  }

  public DataSet(Matrix first, Matrix second) {
    super(first, second);
  }

  public DataSet copy() {
    return new DataSet(getFirst(),getSecond());
  }
 
  public static DataSet merge(List<DataSet> data) {
   
    if (data.isEmpty()) {
   
      throw new IllegalArgumentException("Unable to merge empty dataset");
     
    }
   
    DataSet first = data.iterator().next();

    Matrix inputData = new DenseMatrix( data.size(), first.getFirst().numCols() );
    Matrix outputLabels = new DenseMatrix( data.size(), first.getSecond().numCols() );

   
    // samples.assignRow(x, input.viewRow(x) );

    for (int i = 0; i < data.size(); i++) {
     
      //in.putRow(i,data.get(i).getFirst());
      inputData.assignRow( i, data.get(i).getFirst().viewRow(0) );
      //out.putRow(i,data.get(i).getSecond());
      outputLabels.assignRow( i, data.get(i).getSecond().viewRow(0) );

    }
   
    return new DataSet( inputData, outputLabels );
   
  }

  public void validate() {
    if(getFirst().numRows() != getSecond().numRows())
      throw new IllegalStateException("Invalid dataset");
  }
 
  public int outcome() {
    if(this.numExamples() > 1)
      throw new IllegalStateException("Unable to derive outcome for dataset greater than one row");
    //return SimpleBlas.iamax(getSecond());
    return MatrixUtils.iamax( getSecond().viewRow(0) );
  }

  public DataSet get(int i) {
    //return new DataSet(getFirst().getRow(i),getSecond().getRow(i));
   
    Matrix inputRow = getFirst().viewPart(i, 1, 0, getFirst().numCols());
    Matrix labelRow = getSecond().viewPart(i, 1, 0, getSecond().numCols());
   
    return new DataSet( inputRow, labelRow );
  }
 
  public List<List<DataSet>> batchBy(int num) {
   
    return Lists.partition(asList(),num);
   
  }
 
 
 

  /**
   * Sorts the dataset by label:
   * Splits the data set such that examples are sorted by their labels.
   * A ten label dataset would produce lists with batches like the following:
   * x1   y = 1
   * x2   y = 2
   * ...
   * x10  y = 10
   * @return a list of data sets partitioned by outcomes
   */
  public List<List<DataSet>> sortAndBatchByNumLabels() {
    sortByLabel();
    return Lists.partition(asList(),numOutcomes());
  }

  public List<List<DataSet>> batchByNumLabels() {
    return Lists.partition(asList(),numOutcomes());
  }


  public List<DataSet> asList() {
   
    List<DataSet> list = new ArrayList<DataSet>(numExamples());
   
    for (int i = 0; i < numExamples(); i++)  {
     
      //list.add(new DataSet(getFirst().getRow(i),getSecond().getRow(i)));
     
      //Matrix inputRow = getFirst().viewPart(i, 1, 0, getFirst().numCols());
      Matrix inputRow = MatrixUtils.viewRowAsMatrix( getFirst(), i );
      //Matrix labelRow = getSecond().viewPart(i, 1, 0, getSecond().numCols());
      Matrix labelRow = MatrixUtils.viewRowAsMatrix( getSecond(), i );
     
      list.add( new DataSet( inputRow, labelRow ) );
     
     
    }
   
    return list;
  }

  public Pair<DataSet,DataSet> splitTestAndTrain(int numHoldout) {

    if(numHoldout >= numExamples())
      throw new IllegalArgumentException("Unable to split on size larger than the number of rows");


    List<DataSet> list = asList();

    Collections.rotate(list, 3);
    Collections.shuffle(list);
    List<List<DataSet>> partition = new ArrayList<List<DataSet>>();
    partition.add(list.subList(0, numHoldout));
    partition.add(list.subList(numHoldout, list.size()));
    DataSet train = merge(partition.get(0));
    DataSet test = merge(partition.get(1));
    return new Pair<DataSet, DataSet>(train,test);
  }

  /**
   * Organizes the dataset to minimize sampling error
   * while still allowing efficient batching.
   */
  public void sortByLabel() {
    Map<Integer,Queue<DataSet>> map = new HashMap<Integer,Queue<DataSet>>();
    List<DataSet> data = asList();
    int numLabels = numOutcomes();
    int examples = numExamples();
    for(DataSet d : data) {
      int label = getLabel(d);
      Queue<DataSet> q = map.get(label);
      if(q == null) {
        q = new ArrayDeque<DataSet>();
        map.put(label, q);
      }
      q.add(d);
    }

    for(Integer label : map.keySet()) {
      log.info("Label " + label + " has " + map.get(label).size() + " elements");
    }

    //ideal input splits: 1 of each label in each batch
    //after we run out of ideal batches: fall back to a new strategy
    boolean optimal = true;
    for(int i = 0; i < examples; i++) {
      if(optimal) {
        for(int j = 0; j < numLabels; j++) {
          Queue<DataSet> q = map.get(j);
          DataSet next = q.poll();
          //add a row; go to next
          if(next != null) {
            addRow(next,i);
            i++;
          }
          else {
            optimal = false;
            break;
          }
        }
      }
      else {
        DataSet add = null;
        for(Queue<DataSet> q : map.values()) {
          if(!q.isEmpty()) {
            add = q.poll();
            break;
          }
        }

        addRow(add,i);

      }


    }


  }


  public void addRow(DataSet d, int i) {
   
    if (i > numExamples() || d == null) {
      throw new IllegalArgumentException("Invalid index for adding a row");
    }
   
    //getFirst().putRow(i, d.getFirst());
    getFirst().assignRow( i, d.getFirst().viewRow(0) );
   
    //getSecond().putRow(i,d.getSecond());
    getSecond().assignRow( i, d.getSecond().viewRow(0) );
  }


  private int getLabel(DataSet data) {
   
    //return SimpleBlas.iamax(data.getSecond());
   
    return MatrixUtils.iamax( data.getSecond().viewRow(0) );
   
  }


  public Matrix exampleSums() {
    //return getFirst().columnSums();
    return MatrixUtils.columnSums(getFirst());
  }

/*  public Matrix exampleMaxs() {
    return getFirst().columnMaxs();
  }
*/
  public Matrix exampleMeans() {
    //return getFirst().columnMeans();
    return MatrixUtils.columnMeans(getFirst());
  }
/*
  public void saveTo(File file,boolean binary) throws IOException {
    if(file.exists())
      file.delete();
    file.createNewFile();

    if(binary) {
      DataOutputStream bos = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)));
      getFirst().out(bos);
      getSecond().out(bos);
      bos.flush();
      bos.close();

    }
    else {
      BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(file));
      for(int i = 0; i < numExamples(); i++) {
        bos.write(getFirst().getRow(i).toString("%.3f", "[", "]", ", ", ";").getBytes());
        bos.write("\t".getBytes());
        bos.write(getSecond().getRow(i).toString("%.3f", "[", "]", ", ", ";").getBytes());
        bos.write("\n".getBytes())  ;


      }
      bos.flush();
      bos.close();

    }
  }
*/
/*
  public static DataSet load(File path) throws IOException {
    DataInputStream bis = new DataInputStream(new BufferedInputStream(new FileInputStream(path)));
    Matrix x = new DenseMatrix(1,1);
    Matrix y = new DenseMatrix(1,1);
    x.in(bis);
    y.in(bis);
    bis.close();
    return new DataSet(x,y);
  }
*/
  /**
   * Sample without replacement and a random rng
   * @param numSamples the number of samples to get
   * @return a sample data set without replacement
   */
  public DataSet sample(int numSamples) {
    return sample(numSamples,new MersenneTwister(System.currentTimeMillis()));
  }

  /**
   * Sample without replacement
   * @param numSamples the number of samples to get
   * @param rng the rng to use
   * @return the sampled dataset without replacement
   */
  public DataSet sample(int numSamples,RandomGenerator rng) {
    return sample(numSamples,rng,false);
  }

  /**
   * Sample a dataset numSamples times
   * @param numSamples the number of samples to get
   * @param withReplacement the rng to use
   * @return the sampled dataset without replacement
   */
  public DataSet sample(int numSamples,boolean withReplacement) {
    return sample(numSamples,new MersenneTwister(System.currentTimeMillis()),withReplacement);
  }

  /**
   * Sample a dataset
   * @param numSamples the number of samples to get
   * @param rng the rng to use
   * @param withReplacement whether to allow duplicates (only tracked by example row number)
   * @return the sample dataset
   */
  public DataSet sample(int numSamples,RandomGenerator rng,boolean withReplacement) {
    if(numSamples >= numExamples())
      return this;
    else {
      Matrix examples = new DenseMatrix(numSamples,getFirst().numCols());
      Matrix outcomes = new DenseMatrix(numSamples,numOutcomes());
      Set<Integer> added = new HashSet<Integer>();
      for(int i = 0; i < numSamples; i++) {
        int picked = rng.nextInt(numExamples());
        while(added.contains(picked)) {
          picked = rng.nextInt(numExamples());

        }
        //examples.putRow(i,getFirst().getRow(i));
        examples.assignRow( i, getFirst().viewRow(i) );
        //outcomes.putRow(i,getSecond().getRow(i));
        outcomes.assignRow( i, getSecond().viewRow(i) );

      }
      return new DataSet(examples,outcomes);
    }
  }

  public void roundToTheNearest(int roundTo) {
   
    for (int i = 0; i < MatrixUtils.length( getFirst() ); i++) {
     
      //double curr = getFirst().get(i);
      double curr = MatrixUtils.getElement( getFirst(), i );
     
      //getFirst().put(i,MathUtils.roundDouble(curr, roundTo));
      MatrixUtils.setElement( getFirst(), i, curr );
     
    }
   
  }

  public int numOutcomes() {
    return getSecond().numCols();
  }

  public int numExamples() {
    return getFirst().numRows();
  }

 
 

  @Override
  public String toString() {
    StringBuilder builder = new StringBuilder();
    builder.append("===========INPUT===================\n")
    .append(getFirst().toString().replaceAll(";","\n"))
    .append("\n=================OUTPUT==================\n")
    .append(getSecond().toString().replaceAll(";","\n"));
    return builder.toString();
  }
/*
  public static void main(String[] args) throws IOException {
    MnistDataFetcher fetcher = new MnistDataFetcher();
    fetcher.fetch(100);
    DataSet write = new DataSet(fetcher.next());
    write.saveTo(new File(args[0]), false);


  }
*/
}
TOP

Related Classes of tv.floe.metronome.deeplearning.datasets.DataSet

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.