// Register a weighted-sum connection calculator for the first RBM's hidden layer
// on the DBN's layer calculator.
LayerCalculatorImpl lc = (LayerCalculatorImpl) dbn.getLayerCalculator();
RBM firstRBM = dbn.getFirstNeuralNetwork();
lc.addConnectionCalculator(firstRBM.getHiddenLayer(), new AparapiWeightedSumConnectionCalculator());
// First RBM main weights: 1 where both indices are equal, 0 elsewhere (a 3x3 identity pattern).
// Matrix.set takes the value first, followed by the two indices (see the float calls below).
// NOTE(review): presumably this makes the first RBM forward its input unchanged before
// activation - confirm against the connection calculator.
Matrix m1 = firstRBM.getMainConnections().getConnectionGraph();
m1.set(1, 0, 0);
m1.set(0, 0, 1);
m1.set(0, 0, 2);
m1.set(0, 1, 0);
m1.set(1, 1, 1);
m1.set(0, 1, 2);
m1.set(0, 2, 0);
m1.set(0, 2, 1);
m1.set(1, 2, 2);
// Second RBM main weights: six fixed starting values at indices (0..1, 0..2).
RBM secondRBM = dbn.getLastNeuralNetwork();
Matrix cg1 = secondRBM.getMainConnections().getConnectionGraph();
cg1.set(0.2f, 0, 0);
cg1.set(0.4f, 0, 1);
cg1.set(-0.5f, 0, 2);
cg1.set(-0.3f, 1, 0);
cg1.set(0.1f, 1, 1);
cg1.set(0.2f, 1, 2);
// Second RBM visible biases: three entries, all starting at zero.
Matrix cgb1 = secondRBM.getVisibleBiasConnections().getConnectionGraph();
cgb1.set(0f, 0, 0);
cgb1.set(0f, 1, 0);
cgb1.set(0f, 2, 0);
// Second RBM hidden biases: two fixed starting values.
Matrix cgb2 = secondRBM.getHiddenBiasConnections().getConnectionGraph();
cgb2.set(-0.4f, 0, 0);
cgb2.set(0.2f, 1, 0);
// A single input sample {1, 0, 1} with no target data (null).
// NOTE(review): the meaning of the trailing (1, 1) arguments is defined by SimpleInputProvider - confirm.
SimpleInputProvider inputProvider = new SimpleInputProvider(new float[][] { { 1, 0, 1 } }, null, 1, 1);
// One CD trainer per RBM. The two calls differ only in two arguments: the first trainer gets 0f and 0
// where the second gets 1f and 1.
// NOTE(review): presumably these are the learning rate and a step/epoch count, so only the second
// RBM is expected to change - verify against TrainerFactory.cdSigmoidTrainer.
AparapiCDTrainer firstTrainer = TrainerFactory.cdSigmoidTrainer(firstRBM, null, null, null, null, 0f, 0f, 0f, 0f, 0, true);
firstTrainer.setLayerCalculator(NNFactory.rbmSigmoidSigmoid(firstRBM));
AparapiCDTrainer secondTrainer = TrainerFactory.cdSigmoidTrainer(secondRBM, null, null, null, null, 1f, 0f, 0f, 0f, 1, true);
secondTrainer.setLayerCalculator(NNFactory.rbmSigmoidSigmoid(secondRBM));
// Map each RBM to its trainer, build the layer-wise DNN trainer over the DBN, and run training.
Map<NeuralNetwork, OneStepTrainer<?>> layerTrainers = new HashMap<>();
layerTrainers.put(firstRBM, firstTrainer);
layerTrainers.put(secondRBM, secondTrainer);
DNNLayerTrainer trainer = TrainerFactory.dnnLayerTrainer(dbn, layerTrainers, inputProvider, null, null);
trainer.train();
// Second RBM main weights after training: each expected value is written as
// (initial value set above) +/- (expected change), compared with a tolerance of 1e-5.
assertEquals(0.2 + 0.13203661, cg1.get(0, 0), 0.00001);
assertEquals(0.4 - 0.22863509, cg1.get(0, 1), 0.00001);
assertEquals(-0.5 + 0.12887852, cg1.get(0, 2), 0.00001);
assertEquals(-0.3 + 0.26158813, cg1.get(1, 0), 0.00001);
assertEquals(0.1 - 0.3014404, cg1.get(1, 1), 0.00001);
assertEquals(0.2 + 0.25742438, cg1.get(1, 2), 0.00001);
// Second RBM visible biases after training: they started at 0, so the expected value is the change itself.
assertEquals(0.52276707, cgb1.get(0, 0), 0.00001);
assertEquals(- 0.54617375, cgb1.get(1, 0), 0.00001);
assertEquals(0.51522285, cgb1.get(2, 0), 0.00001);
// Second RBM hidden biases after training: initial value plus expected change.
assertEquals(-0.4 - 0.08680013, cgb2.get(0, 0), 0.00001);
assertEquals(0.2 - 0.02693379, cgb2.get(1, 0), 0.00001);
}