// --- One compute iteration of a distributed DBN training worker ---
// NOTE(review): this is the interior of a larger method; the enclosing signature,
// the closing braces of the FINE_TUNE branch, and members such as `watch`,
// `lineParser`, `this.dbn`, and `this.hdfs_fetcher` are declared outside this view.
// Flow: depending on currentTrainingState, pull the next mini-batch from HDFS and
// either pre-train (unsupervised, layer-wise) or fine-tune (supervised) the DBN.
int recordsProcessed = 0;
StopWatch batchWatch = new StopWatch();
// Batch is fetched lazily inside each state branch, not up front.
DataSet hdfs_recordBatch = null; //this.hdfs_fetcher.next();
System.out.println("Iteration: " + this.currentIteration );
// if (hdfs_recordBatch.getFirst().numRows() > 0) {
// do {
if ( TrainingState.PRE_TRAIN == this.currentTrainingState ) {
System.out.println("Worker > PRE TRAIN! " );
if ( this.hdfs_fetcher.hasNext() ) {
hdfs_recordBatch = this.hdfs_fetcher.next();
// getFirst() presumably holds the feature/input matrix of the batch — TODO confirm against DataSet pair semantics.
System.out.println("Worker > Has Next! > Recs: " + hdfs_recordBatch.getFirst().numRows() );
// check for the straggler batch condition: this worker's HDFS split is smaller
// than one full batch. Detected only on the very first iteration (iteration 0).
if (0 == this.currentIteration && hdfs_recordBatch.getFirst().numRows() > 0 && hdfs_recordBatch.getFirst().numRows() < this.batchSize) {
// System.out.println( "Worker > Straggler Batch Condition!" );
// ok, only in this situation do we lower the batch size
this.batchSize = hdfs_recordBatch.getFirst().numRows();
// re-setup the dataset iterator so subsequent next() calls use the reduced batch size
try {
this.hdfs_fetcher = new MnistHDFSDataSetIterator( this.batchSize, this.totalTrainingDatasetSize, (TextRecordParser)lineParser );
} catch (IOException e) {
// TODO Auto-generated catch block
// NOTE(review): the IOException is swallowed (printStackTrace) and execution
// continues with a stale fetcher reference — consider propagating or failing fast.
e.printStackTrace();
}
// System.out.println( "Worker > PreTrain: Setting up for a straggler split... (sub batch size)" );
// System.out.println( "New batch size: " + this.batchSize );
} else {
// System.out.println( "Worker > NO Straggler Batch Condition!" );
}
if (hdfs_recordBatch.getFirst().numRows() > 0) {
if (hdfs_recordBatch.getFirst().numRows() < this.batchSize) {
// Jagged tail of the split (partial batch after iteration 0): intentionally
// skipped rather than trained on — these records are never processed.
// System.out.println( "Worker > PreTrain: [Jagged End of Split: Skipped] Processed Total " + recordsProcessed + " Total Time " + watch.toString() );
} else {
// System.out.println( "Worker > Normal Processing!" );
// calc stats on number records processed
recordsProcessed += hdfs_recordBatch.getFirst().numRows();
//System.out.println( "PreTrain: Batch Size: " + hdfs_recordBatch.getFirst().numRows() );
// Time the pre-train call for this batch only; `watch` (outer scope) tracks total time.
batchWatch.reset();
batchWatch.start();
this.dbn.preTrain( hdfs_recordBatch.getFirst(), 1, this.learningRate, this.preTrainEpochs);
batchWatch.stop();
System.out.println( "Worker > PreTrain: Batch Mode, Processed Total " + recordsProcessed + ", Batch Time " + batchWatch.toString() + " Total Time " + watch.toString() );
} // if
} else {
// in case we get a blank line (fetcher returned an empty batch)
System.out.println( "Worker > PreTrain > Idle pass, no records left to process in phase" );
}
} else {
// Fetcher exhausted for this pass over the split.
// NOTE(review): duplicates the empty-batch message above — consider distinct messages.
System.out.println( "Worker > PreTrain > Idle pass, no records left to process in phase" );
}
// System.out.println( "Worker > Check PreTrain completion > completedEpochs: " + this.completedDatasetEpochs + ", preTrainDatasetPasses: " + this.preTrainDatasetPasses );
// check for completion of split, to signal master on state change:
// split exhausted AND the pass about to complete reaches the configured pass count.
if (false == this.hdfs_fetcher.hasNext() && this.completedDatasetEpochs + 1 >= this.preTrainDatasetPasses ) {
this.preTrainPhaseComplete = true;
// System.out.println( "Worker > Completion of pre-train phase" );
}
} else if ( TrainingState.FINE_TUNE == this.currentTrainingState) {
//System.out.println( "DBN Network Stats:\n" + dbn.generateNetworkSizeReport() );
if ( this.hdfs_fetcher.hasNext() ) {
hdfs_recordBatch = this.hdfs_fetcher.next();
if (hdfs_recordBatch.getFirst().numRows() > 0) {
if (hdfs_recordBatch.getFirst().numRows() < this.batchSize) {
// Jagged tail of the split: skipped here as well (same policy as pre-train).
// System.out.println( "Worker > FineTune: [Jagged End of Split: Skipped] Processed Total " + recordsProcessed + " Total Time " + watch.toString() );
} else {
batchWatch.reset();
batchWatch.start();
// NOTE(review): fine-tune consumes getSecond() while pre-train used getFirst();
// presumably the pair is (features, labeled-dataset) — confirm the intended element.
// Also: recordsProcessed is NOT incremented on this path, unlike pre-train,
// so the "Processed Total" logged below stays 0 — verify whether intentional.
this.dbn.finetune( hdfs_recordBatch.getSecond(), learningRate, fineTuneEpochs );
batchWatch.stop();
System.out.println( "Worker > FineTune > Batch Mode, Processed Total " + recordsProcessed + ", Batch Time " + batchWatch.toString() + " Total Time " + watch.toString() );