Package zdenekdrahos.AI.BackPropagation

Source Code of zdenekdrahos.AI.BackPropagation.BackPropagation

/*
* JAVA Neural Networks (https://bitbucket.org/zdenekdrahos/java-neural-networks)
* @license New BSD License
* @author Zdenek Drahos
*/

package zdenekdrahos.AI.BackPropagation;

import java.util.List;
import zdenekdrahos.AI.ActivationFunctions.IActivationFunction;
import zdenekdrahos.AI.FeedForward.INetworkValues;
import zdenekdrahos.AI.NeuralNetwork.INeuralNetwork;
import zdenekdrahos.AI.NeuralNetwork.Weights.IWeight;
import zdenekdrahos.AI.NeuralNetwork.Weights.IWeights;

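/**
 * Backpropagation trainer with momentum for a feed-forward network.
 * Each call to trainNetwork performs one backward pass (from the output
 * layer towards the first hidden layer) and then updates every connection
 * weight, including the bias weights. Defaults: learning rate 0.75,
 * momentum 0.9.
 */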
public class BackPropagation implements IBackPropagation {

    private PreviousAdjustments previousAdjustments;
    private WeightErrors weightErrors;
    private INeuralNetwork network;
    private INetworkValues networkValues;
    private double[] target;
    private double learningRate, momentum;
    private List<Double> currentValues;

    public BackPropagation() {
        learningRate = 0.75;
        momentum = 0.9;
        weightErrors = new WeightErrors();
        previousAdjustments = new PreviousAdjustments();
    }

    @Override
    public void setLearningRate(double learningRate) {
        if (learningRate > 0) {
            this.learningRate = learningRate;
        } else {
            throw new IllegalArgumentException("Learning rate");
        }
    }

    @Override
    public void setMomentum(double momentum) {
        if (momentum >= 0 && momentum <= 1) {
            this.momentum = momentum;
        } else {
            throw new IllegalArgumentException("Momentum");
        }
    }

    @Override
    public void trainNetwork(INeuralNetwork network, INetworkValues networkValues, double[] target) {
        this.network = network;
        this.networkValues = networkValues;
        this.target = target;
        initHelperStructures();
        propagate();
        updateWeights();
    }

    private void initHelperStructures() {
        weightErrors.initDelta(network);
        previousAdjustments.initPreviousAdjustment(network);
    }

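    // Backward pass: walk the layers from the output layer down to the first
    // hidden layer (layer 0 is the input layer, so no error is computed for it).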
    private void propagate() {
        for (int layerIndex = network.getLayersCount() - 1; layerIndex > 0; layerIndex--) {
            currentValues = networkValues.getLayerValues(layerIndex);
            weightErrors.setEditedLayer(layerIndex);
            propagateLayer(layerIndex);
        }
    }

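    // Compute the error term (delta) for every neuron in the given layer and
    // store it in weightErrors for the preceding layer and for updateWeights().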
    private void propagateLayer(int layerIndex) {
        double error, delta;
        int neuronsCount = network.getLayer(layerIndex).getNeuronsCount();
        for (int neuronIndex = 0; neuronIndex < neuronsCount; neuronIndex++) {
            if (network.isOutputLayer(layerIndex)) {
                error = getErrorInOutputLayer(neuronIndex);
            } else {
                error = getErrorInHiddenLayer(layerIndex, neuronIndex);
            }
            delta = getDelta(layerIndex, neuronIndex, error);
            weightErrors.setNeuronError(neuronIndex, delta);
        }
    }

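    // Output layer error: e_k = target_k - output_k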
    private double getErrorInOutputLayer(int neuronIndex) {
        return target[neuronIndex] - currentValues.get(neuronIndex);
    }

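    // Hidden layer error: e_j = sum over next-layer neurons k of delta_k * w_kj,
    // where w_kj is the weight of the connection from neuron j to neuron k.
    // The weight index is shifted by 1 because index 0 of each neuron's
    // weight vector is the bias weight.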
    private double getErrorInHiddenLayer(int layerIndex, int neuronIndex) {
        IWeight weight;
        double delta, error = 0;
        int nextLayer = layerIndex + 1;
        int neuronsInNextLayer = network.getLayer(nextLayer).getNeuronsCount();
        for (int nextNeuronIndex = 0; nextNeuronIndex < neuronsInNextLayer; nextNeuronIndex++) {
            delta = weightErrors.get(nextLayer, nextNeuronIndex);
            weight = network.getWeights().getConnectionWeight(nextLayer, nextNeuronIndex, neuronIndex + 1); // +1 because of bias
            error += delta * weight.getWeight();
        }
        return error;
    }

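    // Delta: delta = e * f'(y), the error scaled by the derivative of the layer's
    // activation function evaluated at the neuron's output value.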
    private double getDelta(int layerIndex, int neuronIndex, double error) {
        IActivationFunction act = network.getLayer(layerIndex).getActivationFunction();
        double derivate = act.derivate(currentValues.get(neuronIndex));
        return error * derivate;
    }

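    // Weight update with momentum: for every connection (bias included),
    //   adjustment = learningRate * delta * previousNeuronValue
    //              + momentum * previousAdjustment
    // The new adjustment is stored so it can serve as the momentum term
    // in the next training step.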
    private void updateWeights() {
        double previousNeuronValue, previousAdjustment, weightAdjustment, newWeight;
        IWeight oldWeight;
        IWeights weights = network.getWeights();
        for (int layerIndex = network.getLayersCount() - 1; layerIndex > 0; layerIndex--) {
            int currentNeuronsCount = network.getLayer(layerIndex).getNeuronsCount();
            int previousNeuronsCount = network.getLayer(layerIndex - 1).getNeuronsCount() + 1;

            for (int neuronIndex = 0; neuronIndex < currentNeuronsCount; neuronIndex++) {
                double neuronDelta = weightErrors.get(layerIndex, neuronIndex);               
                for (int previousNeuronIndex = 0; previousNeuronIndex < previousNeuronsCount; previousNeuronIndex++) {
                    // weight adjustment
                    previousNeuronValue = previousNeuronIndex == 0 ? // bias
                            1 : networkValues.getNeuronValue(layerIndex - 1, previousNeuronIndex - 1);                   
                    previousAdjustment = previousAdjustments.get(layerIndex, neuronIndex, previousNeuronIndex);
                    weightAdjustment = learningRate * neuronDelta * previousNeuronValue + momentum * previousAdjustment;
                    // weights
                    oldWeight = weights.getConnectionWeight(layerIndex, neuronIndex, previousNeuronIndex);                   
                    newWeight = oldWeight.getWeight() + weightAdjustment;
                    // update helper structure and weights in network
                    previousAdjustments.set(layerIndex, neuronIndex, previousNeuronIndex, weightAdjustment);
                    weights.setConnectionWeight(layerIndex, neuronIndex, previousNeuronIndex, newWeight);
                }
            }
        }
    }
}