// Run one training epoch: stream every cached record through the network,
// then handle post-epoch bookkeeping (stall detection, progress logging,
// and packaging the resulting weights/RMSE for the update message).

// Reusable vector pulled from the local read-through cache; sized once from
// the network's input count and the record factory's output vector size so
// no per-record allocation happens inside the loop.
CachedVector cv = new CachedVector(this.nn.getInputsCount(), this.rec_factory.getOutputVectorSize());
cachedVecReader.Reset();
BackPropogationLearningAlgorithm bp = (BackPropogationLearningAlgorithm) this.nn.getLearningRule();
bp.clearTotalSquaredError();
try {
	while (cachedVecReader.next(cv)) {
		bp.getMetrics().startTrainingRecordTimer();
		this.nn.train(cv.vec_output, cv.vec_input);
		bp.getMetrics().stopTrainingRecordTimer();
	}
} catch (IOException e) {
	// A read failure means this epoch only saw part of the training set.
	// Fail fast with the cause preserved rather than silently completing
	// the epoch over partial data (the old code printed the stack trace
	// and carried on, which corrupted the reported RMSE).
	throw new RuntimeException("IO failure while reading cached training vectors", e);
}
// TODO: post-epoch cleanup -- should this be handled via the nn interface?
bp.completeTrainingEpoch();
String marker = "";
if (hitErrThreshold) {
	marker += ", Hit Err Threshold at: " + this.trainingCompleteEpoch;
}
if (bp.checkForLearningStallOut() && !bp.hasHitMinErrorThreshold()) {
	marker += " [ --- STALL ---]";
	// NOTE(review): weights are randomized on every stall, even when
	// stallBustingOn is false (only the tracking reset is gated) -- confirm
	// this asymmetry is intentional.
	this.nn.randomizeWeights();
	if (this.stallBustingOn) {
		bp.resetStallTracking();
		System.out.println("[ --- STALL WORKER RESET --- ]: " + bp.getSetMaxStalledEpochs());
	}
}
String alr_debug = bp.DebugAdagrad();
this.metrics.printProgressiveStepDebugMsg(this.CurrentIteration, "Epoch: " + this.CurrentIteration + " > RMSE: " + bp.calcRMSError() + ", Records Trained: " + this.cachedVecReader.recordsInCache() + marker + ", ALR: " + alr_debug);
if (this.metricsOn) {
	bp.getMetrics().PrintMetrics();
}
// Package the post-epoch network state and error for the weight-update message.
NeuralNetworkWeightsDelta nnwd = new NeuralNetworkWeightsDelta();
nnwd.network = this.nn;
nnwd.RMSE = bp.calcRMSError();
this.lastRMSE = nnwd.RMSE;
NetworkWeightsUpdateable nwu = new NetworkWeightsUpdateable();
nwu.networkUpdate = nnwd;
nwu.networkUpdate.CurrentIteration = this.CurrentIteration;