@Override
protected void updateNeuronWeights(Neuron neuron) {
double neuronError = neuron.getError();
double lrTemp = 0;
AdagradLearningRate alr = null;
for (Connection connection : neuron.getInConnections()) {
if (this.adagradLearningOn) {
alr = (AdagradLearningRate)connection.getWeight().trainingMetaData.get("adagrad");
lrTemp = alr.compute();
} else {
lrTemp = this.learningRate;
}
double input = connection.getInput();
//double weightChange = this.learningRate * neuronError * input;
double weightChange = lrTemp * neuronError * input;
Weight weight = connection.getWeight();
if (this.isInBatchMode() == false) {
weight.weightChange = weightChange;
weight.value += weightChange;
} else {
weight.weightChange += weightChange;
}
if (this.adagradLearningOn) {
alr = (AdagradLearningRate)connection.getWeight().trainingMetaData.get("adagrad");
alr.addLastIterationGradient(weightChange);
}
if (this.isMetricCollectionOn()) {
this.metrics.incWeightOpCount();
}