/*
* JAVA Neural Networks (https://bitbucket.org/zdenekdrahos/java-neural-networks)
* @license New BSD License
* @author Zdenek Drahos
*/
package zdenekdrahos.AI.BackPropagation;
import java.util.List;
import zdenekdrahos.AI.ActivationFunctions.IActivationFunction;
import zdenekdrahos.AI.FeedForward.INetworkValues;
import zdenekdrahos.AI.NeuralNetwork.INeuralNetwork;
import zdenekdrahos.AI.NeuralNetwork.Weights.IWeight;
import zdenekdrahos.AI.NeuralNetwork.Weights.IWeights;
public class BackPropagation implements IBackPropagation {
private PreviousAdjustments previousAdjustments;
private WeightErrors weightErrors;
private INeuralNetwork network;
private INetworkValues networkValues;
private double[] target;
private double learningRate, momentum;
private List<Double> currentValues;
public BackPropagation() {
learningRate = 0.75;
momentum = 0.9;
weightErrors = new WeightErrors();
previousAdjustments = new PreviousAdjustments();
}
@Override
public void setLearningRate(double learningRate) {
if (learningRate > 0) {
this.learningRate = learningRate;
} else {
throw new IllegalArgumentException("Learning rate");
}
}
@Override
public void setMomentum(double momentum) {
if (momentum >= 0 && momentum <= 1) {
this.momentum = momentum;
} else {
throw new IllegalArgumentException("Momentum");
}
}
@Override
public void trainNetwork(INeuralNetwork network, INetworkValues networkValues, double[] target) {
this.network = network;
this.networkValues = networkValues;
this.target = target;
initHelperStructures();
propagate();
updateWeights();
}
private void initHelperStructures() {
weightErrors.initDelta(network);
previousAdjustments.initPreviousAdjustment(network);
}
private void propagate() {
for (int layerIndex = network.getLayersCount() - 1; layerIndex > 0; layerIndex--) {
currentValues = networkValues.getLayerValues(layerIndex);
weightErrors.setEditedLayer(layerIndex);
propagateLayer(layerIndex);
}
}
private void propagateLayer(int layerIndex) {
double error, delta;
int neuronsCount = network.getLayer(layerIndex).getNeuronsCount();
for (int neuronIndex = 0; neuronIndex < neuronsCount; neuronIndex++) {
if (network.isOutputLayer(layerIndex)) {
error = getErrorInOutputLayer(neuronIndex);
} else {
error = getErrorInHiddenLayer(layerIndex, neuronIndex);
}
delta = getDelta(layerIndex, neuronIndex, error);
weightErrors.setNeuronError(neuronIndex, delta);
}
}
private double getErrorInOutputLayer(int neuronIndex) {
return target[neuronIndex] - currentValues.get(neuronIndex);
}
private double getErrorInHiddenLayer(int layerIndex, int neuronIndex) {
IWeight weight;
double delta, error = 0;
int nextLayer = layerIndex + 1;
int neuronsInNextLayer = network.getLayer(nextLayer).getNeuronsCount();
for (int nextNeuronIndex = 0; nextNeuronIndex < neuronsInNextLayer; nextNeuronIndex++) {
delta = weightErrors.get(nextLayer, nextNeuronIndex);
weight = network.getWeights().getConnectionWeight(nextLayer, nextNeuronIndex, neuronIndex + 1); // +1 because of bias
error += delta * weight.getWeight();
}
return error;
}
private double getDelta(int layerIndex, int neuronIndex, double error) {
IActivationFunction act = network.getLayer(layerIndex).getActivationFunction();
double derivate = act.derivate(currentValues.get(neuronIndex));
return error * derivate;
}
private void updateWeights() {
double previousNeuronValue, previousAdjustment, weightAdjustment, newWeight;
IWeight oldWeight;
IWeights weights = network.getWeights();
for (int layerIndex = network.getLayersCount() - 1; layerIndex > 0; layerIndex--) {
int currentNeuronsCount = network.getLayer(layerIndex).getNeuronsCount();
int previousNeuronsCount = network.getLayer(layerIndex - 1).getNeuronsCount() + 1;
for (int neuronIndex = 0; neuronIndex < currentNeuronsCount; neuronIndex++) {
double neuronDelta = weightErrors.get(layerIndex, neuronIndex);
for (int previousNeuronIndex = 0; previousNeuronIndex < previousNeuronsCount; previousNeuronIndex++) {
// weight adjustment
previousNeuronValue = previousNeuronIndex == 0 ? // bias
1 : networkValues.getNeuronValue(layerIndex - 1, previousNeuronIndex - 1);
previousAdjustment = previousAdjustments.get(layerIndex, neuronIndex, previousNeuronIndex);
weightAdjustment = learningRate * neuronDelta * previousNeuronValue + momentum * previousAdjustment;
// weights
oldWeight = weights.getConnectionWeight(layerIndex, neuronIndex, previousNeuronIndex);
newWeight = oldWeight.getWeight() + weightAdjustment;
// update helper structure and weights in network
previousAdjustments.set(layerIndex, neuronIndex, previousNeuronIndex, weightAdjustment);
weights.setConnectionWeight(layerIndex, neuronIndex, previousNeuronIndex, newWeight);
}
}
}
}
}