Package tv.floe.metronome.deeplearning.datasets.iterator.impl

Examples of tv.floe.metronome.deeplearning.datasets.iterator.impl.MnistDataSetIterator
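Each example below drives the iterator the same way: construct it with a mini-batch size and a cap on the total number of MNIST examples, pull DataSet batches with hasNext()/next(), and call reset() before a second pass. A minimal sketch of that pattern follows, assuming the DataSet class lives under the datasets package (that import path is a guess; every call used here appears in the examples below):

  import tv.floe.metronome.deeplearning.datasets.iterator.impl.MnistDataSetIterator;
  // NOTE: assumed location of DataSet -- adjust to wherever it lives in the Metronome tree
  import tv.floe.metronome.deeplearning.datasets.DataSet;

  public class MnistIteratorSketch {

    public static void main(String[] args) throws Exception {

      int batchSize = 100;        // records returned by each next() call
      int totalNumExamples = 200; // cap on how many MNIST records the iterator serves

      MnistDataSetIterator fetcher = new MnistDataSetIterator(batchSize, totalNumExamples);

      while (fetcher.hasNext()) {
        DataSet batch = fetcher.next();
        // getFirst() is the 784-column pixel matrix, getSecond() the one-hot label matrix
        System.out.println("inputs: " + batch.getFirst().numCols()
            + ", labels: " + batch.getSecond().numCols());
      }

      fetcher.reset(); // rewind for a second pass, e.g. fine-tuning after pre-training
    }
  }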


   
 
  @Test
  public void testMnist() throws Exception {
    MnistDataSetIterator fetcher = new MnistDataSetIterator(100, 200); // batches of 100, 200 examples total
    MersenneTwister rand = new MersenneTwister(123);

    double learningRate = 0.001;
   
    int[] batchSteps = { 250, 200, 150, 100, 50, 25, 5 };
   
    DataSet first = fetcher.next();
/*
    RestrictedBoltzmannMachine da = new RBM.Builder().numberOfVisible(784).numHidden(400).withRandom(rand).renderWeights(1000)
        .useRegularization(false)
        .withMomentum(0).build();
*/
 
View Full Code Here


    int batchSize = 100 * datasetSize;
    int totalNumExamples = 100 * datasetSize;
   
   
   
    MnistDataSetIterator fetcher = new MnistDataSetIterator( batchSize, totalNumExamples );
    DataSet recordBatch = fetcher.next();
   
   
    Map<Integer, Integer> filter = new HashMap<Integer, Integer>();
    for (int x = 0; x < classIndexes.length; x++ ) {
     
View Full Code Here

       
    int batchSize = 50;
    boolean showNetworkStats = true;
   
    // mini-batches through dataset
    MnistDataSetIterator fetcher = new MnistDataSetIterator( batchSize, totalNumExamples );
    DataSet first = fetcher.next();
    int numIns = first.getFirst().numCols();
    int numLabels = first.getSecond().numCols();

    int n_layers = hiddenLayerSizes.length;
    RandomGenerator rng = new MersenneTwister(123);
   
   
    DeepBeliefNetwork dbn = new DeepBeliefNetwork( numIns, hiddenLayerSizes, numLabels, n_layers, rng ); //, Matrix input, Matrix labels);
       
    dbn.useRegularization = false;
    dbn.setSparsity(0.01);
    dbn.setMomentum(0);
   
   
    int recordsProcessed = 0;
    int batchesProcessed = 0;
    long totalBatchProcessingTime = 0;
   
    StopWatch watch = new StopWatch();
    watch.start();
   
    StopWatch batchWatch = new StopWatch();
   
   
    do  {
     
      recordsProcessed += batchSize;
      batchesProcessed++;
     
      System.out.println( "PreTrain: Batch Mode, Processed Total " + recordsProcessed + ", Elapsed Time " + watch.toString() );
     
      batchWatch.reset();
      batchWatch.start();
      dbn.preTrain( first.getFirst(), 1, learningRate, preTrainEpochs);
      batchWatch.stop();
     
      totalBatchProcessingTime += batchWatch.getTime();
     
      System.out.println( "Batch Training Elapsed Time " + batchWatch.toString() );

      //System.out.println( "DBN Network Stats:\n" + dbn.generateNetworkSizeReport() );

     
      if (fetcher.hasNext()) {
        first = fetcher.next();
      }
     
    } while (fetcher.hasNext());

    double avgBatchTime = (double) totalBatchProcessingTime / batchesProcessed; // cast avoids integer truncation
    double avgBatchSeconds = avgBatchTime / 1000;
    double avgBatchMinutes = avgBatchSeconds / 60;
   
View Full Code Here

    int totalNumExamples = 50;
    //int rowLimit = 100;
       
    int batchSize = 10;
    // mini-batches through dataset
    MnistDataSetIterator fetcher = new MnistDataSetIterator( batchSize, totalNumExamples );
    DataSet first = fetcher.next();
    int numIns = first.getFirst().numCols();
    int numLabels = first.getSecond().numCols();

    int n_layers = hiddenLayerSizes.length;
    RandomGenerator rng = new MersenneTwister(123);
   
   
    DeepBeliefNetwork dbn = new DeepBeliefNetwork( numIns, hiddenLayerSizes, numLabels, n_layers, rng ); //, Matrix input, Matrix labels);
       
    int recordsProcessed = 0;
   
    do  {
     
      recordsProcessed += batchSize;
     
      System.out.println( "PreTrain: Batch Mode, Processed Total " + recordsProcessed );
      dbn.preTrain( first.getFirst(), 1, learningRate, preTrainEpochs);

      if (fetcher.hasNext()) {
        first = fetcher.next();
      }
     
    } while (fetcher.hasNext());

    fetcher.reset();
    first = fetcher.next();
   
    recordsProcessed = 0;
   
    do {
     
      recordsProcessed += batchSize;
     
      System.out.println( "FineTune: Batch Mode, Processed Total " + recordsProcessed );
     
     
      dbn.finetune( first.getSecond(), learningRate, fineTuneEpochs );
     
      if (fetcher.hasNext()) {
        first = fetcher.next();
      }
     
    } while (fetcher.hasNext());
   
    System.out.println("----------- Training Complete! -----------");
   
    // save model
   
View Full Code Here

  }
 
  @Test
  public void testMNISTRenderPath() throws Exception {
   
    MnistDataSetIterator fetcher = new MnistDataSetIterator(100, 200); // batches of 100, 200 examples total
    MersenneTwister rand = new MersenneTwister(123);

    double learningRate = 0.001;
   
    int[] batchSteps = { 250, 200, 150, 100, 50, 25, 5 };
   
    DataSet first = fetcher.next();
/*
    RestrictedBoltzmannMachine da = new RBM.Builder().numberOfVisible(784).numHidden(400).withRandom(rand).renderWeights(1000)
        .useRegularization(false)
        .withMomentum(0).build();
*/
 
View Full Code Here
