/*
RestrictedBoltzmannMachine da = new RBM.Builder().numberOfVisible(784).numHidden(400).withRandom(rand).renderWeights(1000)
.useRegularization(false)
.withMomentum(0).build();
*/
// Configure the RBM through its constructor and public fields instead of the
// builder chain kept in the commented-out snippet above.
// NOTE(review): 784 / 400 are presumably the visible / hidden unit counts (going by
// that builder snippet); the role of the third (null) argument is not visible here - confirm.
RestrictedBoltzmannMachine rbm = new RestrictedBoltzmannMachine(784, 400, null);
rbm.sparsity = 0.01;
rbm.momentum = 0;
rbm.useRegularization = false;
// TODO: investigate "render weights"
//rbm.scaleWeights( 1000 );

// The training data is whatever first.getFirst() returns.
rbm.trainingDataset = first.getFirst();
//MatrixUtils.debug_print( rbm.trainingDataset );

// Render activations, weight values and filters before any training, tagged "init".
this.renderActivationsToDisk(rbm, "init");
this.renderWeightValuesToDisk(rbm, "init");
this.renderFiltersToDisk(rbm, "init");

System.out.println(" ----- Training ------");

// epoch counts training calls across ALL thresholds; it is never reset between steps.
int epoch = 0;
// NOTE(review): the label says "Negative Log Likelhood" (sic) but the value printed is
// getReConstructionCrossEntropy(); the emitted text is intentionally left unchanged.
System.out.println("Epoch " + epoch + " Negative Log Likelhood: " + rbm.getReConstructionCrossEntropy() );

// Visit the thresholds in batchSteps in order. For each one, keep training while the
// reconstruction cross entropy is still above the threshold, then render a batch of
// reconstructions labelled with the cross entropy reached.
// NOTE(review): the inner loop has no iteration cap, so it only ends once the
// cross entropy stops exceeding the threshold.
for (int step = 0; step < batchSteps.length; step++) {
    int targetCrossEntropy = batchSteps[step];

    while (rbm.getReConstructionCrossEntropy() > targetCrossEntropy) {
        System.out.println("Epoch " + epoch + " Negative Log Likelhood: " + rbm.getReConstructionCrossEntropy() );

        // Earlier call forms, kept for reference:
        //rbm.trainTillConvergence( first.getFirst(), learningRate, new Object[]{ 1 } );
        //rbm.trainTillConvergence(learningRate, 1, first.getFirst());
        // new Object[]{1,0.01,1000}
        rbm.trainTillConvergence(first.getFirst(), learningRate, new Object[]{ 1, learningRate, 10 });
        epoch++;
    }

    System.out.println(" ----- Visualizing Reconstructions Step " + targetCrossEntropy + " CE ------");

    // The last argument is true only for the first threshold and false for every later one.
    renderBatchOfReconstructions(
            rbm,
            first,
            true,
            String.valueOf(rbm.getReConstructionCrossEntropy()),
            step == 0);
}