// Pre-trains a DeepBeliefNetwork on MNIST one mini-batch at a time and
// accumulates per-batch timing so average batch cost can be reported.
int batchSize = 50;
boolean showNetworkStats = true;

// Iterator that serves the dataset in mini-batches of batchSize examples.
MnistDataSetIterator fetcher = new MnistDataSetIterator( batchSize, totalNumExamples );

// The first batch is pulled up front so the network's input and label
// widths can be read from its column counts before the network is built.
DataSet first = fetcher.next();
int numIns = first.getFirst().numCols();
int numLabels = first.getSecond().numCols();
int n_layers = hiddenLayerSizes.length;

// Fixed seed so repeated runs draw the same random sequence.
RandomGenerator rng = new MersenneTwister(123);

DeepBeliefNetwork dbn = new DeepBeliefNetwork( numIns, hiddenLayerSizes, numLabels, n_layers, rng );
dbn.useRegularization = false;
dbn.setSparsity(0.01);
dbn.setMomentum(0);

int recordsProcessed = 0;
int batchesProcessed = 0;
long totalBatchProcessingTime = 0;

// watch tracks total elapsed time; batchWatch is reset for every batch.
StopWatch watch = new StopWatch();
watch.start();
StopWatch batchWatch = new StopWatch();

// Train on the batch currently held in `first`, then advance. The exit test
// sits between training and fetching so that every fetched batch, including
// the final one, is trained on. (A bottom-of-loop hasNext() check after the
// fetch would drop the last batch.)
while (true) {
    // NOTE(review): counts a full batchSize per batch; assumes the iterator
    // never returns a short final batch — confirm.
    recordsProcessed += batchSize;
    batchesProcessed++;
    System.out.println( "PreTrain: Batch Mode, Processed Total " + recordsProcessed + ", Elapsed Time " + watch.toString() );

    batchWatch.reset();
    batchWatch.start();
    dbn.preTrain( first.getFirst(), 1, learningRate, preTrainEpochs);
    batchWatch.stop();

    totalBatchProcessingTime += batchWatch.getTime();
    System.out.println( "Batch Training Elapsed Time " + batchWatch.toString() );
    //System.out.println( "DBN Network Stats:\n" + dbn.generateNetworkSizeReport() );

    if (!fetcher.hasNext()) {
        break;
    }
    first = fetcher.next();
}

// Cast before dividing: long / int would truncate the average to a whole
// number before it is widened to double. batchesProcessed is at least 1
// here because the loop body always runs once.
double avgBatchTime = (double) totalBatchProcessingTime / batchesProcessed;
// Presumably getTime() reports milliseconds, hence /1000 for seconds — confirm.
double avgBatchSeconds = avgBatchTime / 1000;
double avgBatchMinutes = avgBatchSeconds / 60;