Package zdenekdrahos.AI.Training

Source Code of zdenekdrahos.AI.Training.Training

/*
* JAVA Neural Networks (https://bitbucket.org/zdenekdrahos/java-neural-networks)
* @license New BSD License
* @author Zdenek Drahos
*/

package zdenekdrahos.AI.Training;

import zdenekdrahos.AI.Training.Output.ITrainingOutput;
import zdenekdrahos.AI.FeedForward.INetworkValues;
import zdenekdrahos.AI.Training.Simulation.SimulationIterator;
import zdenekdrahos.AI.NeuralNetwork.INeuralNetwork;
import zdenekdrahos.AI.Training.Sets.DataSet;
import zdenekdrahos.AI.Training.Validator.ITrainingValidator;
import zdenekdrahos.AI.Training.Validator.TrainingValidator;

public class Training implements ITraining {

    private double lastError;
    private INeuralNetwork network;
    private TrainingSettings settings;
    private SimulationIterator simulations;
    private DataSet trainingSet, testingSet;
    private ITrainingValidator validator = new TrainingValidator();
   
    public Training(TrainingSettings settings) {
        validator.checkSettings(settings);
        this.settings = settings;
        simulations = new SimulationIterator(settings.simulation);
    }

    @Override
    public void train(INeuralNetwork network, TrainingInput input, ITrainingOutput output) {       
        validator.checkTraining(network, input);
        this.network = network;
        initBackPropagation(input);
        separateData(input);
        for (int i = 1; i <= input.iterationsCount && output.continueTraining(); i++) {
            trainNetwork();
            testNetwork();
            output.processError(lastError);
        }
    }

    @Override
    public SimulationIterator simulate(double[] inputPattern) {
        simulations.simulate(network, inputPattern);
        return simulations;
    }

    private void initBackPropagation(TrainingInput input) {
        settings.backPropagation.setLearningRate(input.learningRate);
        settings.backPropagation.setMomentum(input.momentum);
    }

    private void separateData(TrainingInput input) {
        DataSet[] sets = settings.separator.separate(input.inputs, input.targets);
        trainingSet = sets[0];
        testingSet = sets[1];
    }

    private void trainNetwork() {
        INetworkValues values;
        double[] inputPattern, targetPattern;
        for (int patternIndex = 0; patternIndex < trainingSet.size(); patternIndex++) {
            inputPattern = trainingSet.inputs.get(patternIndex);
            targetPattern = trainingSet.targets.get(patternIndex);

            values = settings.feedForward.buildNetwork(network, inputPattern);
            settings.backPropagation.trainNetwork(network, values, targetPattern);
        }
    }

    private void testNetwork() {
        double errorSum = 0;
        double[] inputPattern, targetPattern;
        for (int patternIndex = 0, i = 0; patternIndex < testingSet.size(); patternIndex++, i = 0) {
            inputPattern = testingSet.inputs.get(patternIndex);
            targetPattern = testingSet.targets.get(patternIndex);
            for (SimulationIterator it = simulate(inputPattern); it.hasNext();) {
                errorSum += settings.errorCalculator.getError(targetPattern[i++], it.next());
            }
        }
        lastError = settings.meanCalculator.getMean(errorSum, testingSet.size());
    }
}
TOP

Related Classes of zdenekdrahos.AI.Training.Training

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.