* @see com.github.neuralnetworks.training.rbm.CDTrainerBase#updateWeights(com.github.neuralnetworks.architecture.Matrix, com.github.neuralnetworks.architecture.Matrix, com.github.neuralnetworks.architecture.Matrix, com.github.neuralnetworks.architecture.Matrix)
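* Runs the CD weight and bias update kernels. The kernels are created lazily and rebuilt only when they
* do not exist yet or when the mini-batch size has changed; otherwise the cached kernels are reused.
* Within this method the learning rate, momentum and weight-decay values are supplied to the kernels
* only when the kernels are constructed.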
*/
@Override
protected void updateWeights() {
    final RBM network = getNeuralNetwork();
    final RBMLayerCalculator calculator = getLayerCalculator();

    // The mini-batch size is read from the last dimension of the positive-phase visible values.
    final int[] visibleDims = calculator.getPositivePhaseVisible().getDimensions();
    final int miniBatchSize = visibleDims[visibleDims.length - 1];

    // Main weights: reuse the cached kernel unless it is missing or was built for another mini-batch size.
    final boolean weightKernelUsable = weightUpdatesKernel != null
            && weightUpdatesKernel.getMiniBatchSize() == miniBatchSize;
    if (!weightKernelUsable) {
        weightUpdatesKernel = new CDWeightUpdatesKernel(
                calculator.getPositivePhaseVisible(),
                calculator.getPositivePhaseHidden(),
                calculator.getNegativePhaseVisible(),
                calculator.getNegativePhaseHidden(),
                network.getMainConnections().getWeights(),
                getLearningRate(),
                getMomentum(),
                getl1weightDecay(),
                getl2weightDecay());
    }
    // The execution range handed to the strategy is the row count of the main weight matrix.
    Environment.getInstance().getExecutionStrategy()
            .execute(weightUpdatesKernel, network.getMainConnections().getWeights().getRows());

    // Visible bias: only present on some RBMs. Note that no weight-decay values are passed to bias kernels.
    if (network.getVisibleBiasConnections() != null) {
        final boolean visibleKernelUsable = visibleBiasUpdatesKernel != null
                && visibleBiasUpdatesKernel.getMiniBatchSize() == miniBatchSize;
        if (!visibleKernelUsable) {
            visibleBiasUpdatesKernel = new CDBiasUpdatesKernel(
                    network.getVisibleBiasConnections().getWeights(),
                    calculator.getPositivePhaseVisible(),
                    calculator.getNegativePhaseVisible(),
                    getLearningRate(),
                    getMomentum());
        }
        Environment.getInstance().getExecutionStrategy()
                .execute(visibleBiasUpdatesKernel, network.getVisibleBiasConnections().getWeights().getSize());
    }

    // Hidden bias: same handling as the visible bias, driven by the hidden-layer phase values.
    if (network.getHiddenBiasConnections() != null) {
        final boolean hiddenKernelUsable = hiddenBiasUpdatesKernel != null
                && hiddenBiasUpdatesKernel.getMiniBatchSize() == miniBatchSize;
        if (!hiddenKernelUsable) {
            hiddenBiasUpdatesKernel = new CDBiasUpdatesKernel(
                    network.getHiddenBiasConnections().getWeights(),
                    calculator.getPositivePhaseHidden(),
                    calculator.getNegativePhaseHidden(),
                    getLearningRate(),
                    getMomentum());
        }
        Environment.getInstance().getExecutionStrategy()
                .execute(hiddenBiasUpdatesKernel, network.getHiddenBiasConnections().getWeights().getSize());
    }
}