* @see com.github.neuralnetworks.training.rbm.CDTrainerBase#updateWeights(com.github.neuralnetworks.architecture.Matrix, com.github.neuralnetworks.architecture.Matrix, com.github.neuralnetworks.architecture.Matrix, com.github.neuralnetworks.architecture.Matrix)
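* Before each update the kernel update parameters are refreshed.
* NOTE(review): in this method the kernels are only (re)created when they are missing or when the
* mini batch size changes; the learning parameters are passed only at construction - confirm.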
*/
@Override
protected void updateWeights(Matrix posPhaseVisible, Matrix posPhaseHidden, Matrix negPhaseVisible, Matrix negPhaseHidden) {
    final RBM network = getNeuralNetwork();

    // The column count of the positive-phase hidden matrix is compared against
    // the mini batch size reported by each cached kernel.
    final int batchSize = posPhaseHidden.getColumns();

    // --- main connection weights ---
    // A new kernel is built only when none is cached or the cached one
    // reports a different mini batch size.
    final boolean reuseWeightKernel =
            weightUpdatesKernel != null && weightUpdatesKernel.getMiniBatchSize() == batchSize;
    if (!reuseWeightKernel) {
        weightUpdatesKernel = new CDWeightUpdatesKernel(
                posPhaseVisible,
                posPhaseHidden,
                negPhaseVisible,
                negPhaseHidden,
                network.getMainConnections().getConnectionGraph(),
                getLearningRate(),
                getMomentum(),
                getl1weightDecay(),
                getl2weightDecay());
    }
    // Executed with the row count of the main connection graph.
    Environment.getInstance().getExecutionStrategy().execute(
            weightUpdatesKernel,
            network.getMainConnections().getConnectionGraph().getRows());

    // --- visible bias (skipped when the network has no visible bias connections) ---
    if (network.getVisibleBiasConnections() != null) {
        final boolean reuseVisibleBiasKernel =
                visibleBiasUpdatesKernel != null && visibleBiasUpdatesKernel.getMiniBatchSize() == batchSize;
        if (!reuseVisibleBiasKernel) {
            visibleBiasUpdatesKernel = new CDBiasUpdatesKernel(
                    network.getVisibleBiasConnections().getConnectionGraph().getElements(),
                    posPhaseVisible,
                    negPhaseVisible,
                    getLearningRate(),
                    getMomentum());
        }
        // Executed with the number of elements in the visible bias connection graph.
        Environment.getInstance().getExecutionStrategy().execute(
                visibleBiasUpdatesKernel,
                network.getVisibleBiasConnections().getConnectionGraph().getElements().length);
    }

    // --- hidden bias (skipped when the network has no hidden bias connections) ---
    if (network.getHiddenBiasConnections() != null) {
        final boolean reuseHiddenBiasKernel =
                hiddenBiasUpdatesKernel != null && hiddenBiasUpdatesKernel.getMiniBatchSize() == batchSize;
        if (!reuseHiddenBiasKernel) {
            hiddenBiasUpdatesKernel = new CDBiasUpdatesKernel(
                    network.getHiddenBiasConnections().getConnectionGraph().getElements(),
                    posPhaseHidden,
                    negPhaseHidden,
                    getLearningRate(),
                    getMomentum());
        }
        // Executed with the number of elements in the hidden bias connection graph.
        Environment.getInstance().getExecutionStrategy().execute(
                hiddenBiasUpdatesKernel,
                network.getHiddenBiasConnections().getConnectionGraph().getElements().length);
    }
}