/*
* JAVA Neural Networks (https://bitbucket.org/zdenekdrahos/java-neural-networks)
* @license New BSD License
* @author Zdenek Drahos
*/
package zdenekdrahos.AI.Training;
import zdenekdrahos.AI.Training.Output.ITrainingOutput;
import zdenekdrahos.AI.FeedForward.INetworkValues;
import zdenekdrahos.AI.Training.Simulation.SimulationIterator;
import zdenekdrahos.AI.NeuralNetwork.INeuralNetwork;
import zdenekdrahos.AI.Training.Sets.DataSet;
import zdenekdrahos.AI.Training.Validator.ITrainingValidator;
import zdenekdrahos.AI.Training.Validator.TrainingValidator;
/**
 * Trains a neural network and measures its error on a separate testing set
 * after every training iteration.
 *
 * <p>The concrete algorithms (data separation, feed forward, back propagation,
 * error and mean calculation, simulation) are all taken from the
 * {@link TrainingSettings} passed to the constructor.
 */
public class Training implements ITraining {

    /** Error computed on the testing set by the most recent call to {@link #testNetwork()}. */
    private double lastError;
    /** Network being trained; assigned by {@link #train}, {@code null} before that. */
    private INeuralNetwork network;
    private final TrainingSettings settings;
    /** Single iterator instance reused (and returned) by every {@link #simulate} call. */
    private final SimulationIterator simulations;
    /** Subsets produced by the settings' separator: index 0 = training, index 1 = testing. */
    private DataSet trainingSet, testingSet;
    private final ITrainingValidator validator = new TrainingValidator();

    /**
     * Creates a trainer for the given settings.
     *
     * @param settings training configuration; passed to the validator's
     *                 {@code checkSettings} before it is stored
     */
    public Training(TrainingSettings settings) {
        validator.checkSettings(settings);
        this.settings = settings;
        simulations = new SimulationIterator(settings.simulation);
    }

    /**
     * Trains {@code network} for at most {@code input.iterationsCount} iterations.
     *
     * <p>Before the loop, the learning rate and momentum from {@code input} are
     * applied to the back-propagation algorithm, and the input/target data are
     * split into a training set and a testing set. Each iteration then runs one
     * pass over the training set, evaluates the testing set, and hands the
     * resulting error to {@code output.processError}. {@code output.continueTraining()}
     * is consulted before every iteration (only while the iteration limit has not
     * been reached), so the output can stop training early.
     *
     * @param network network to train; also used by later {@link #simulate} calls
     * @param input   training data and parameters; checked together with
     *                {@code network} by the validator
     * @param output  receives the testing error after each iteration and
     *                decides whether training continues
     */
    @Override
    public void train(INeuralNetwork network, TrainingInput input, ITrainingOutput output) {
        validator.checkTraining(network, input);
        this.network = network;
        initBackPropagation(input);
        separateData(input);
        for (int iteration = 1;
                iteration <= input.iterationsCount && output.continueTraining();
                iteration++) {
            trainNetwork();
            testNetwork();
            output.processError(lastError);
        }
    }

    /**
     * Runs {@code inputPattern} through the network most recently passed to
     * {@link #train}.
     *
     * <p>The same {@link SimulationIterator} instance is returned on every call,
     * so a previously returned iterator is affected by later calls.
     * NOTE(review): {@code network} is only assigned in {@link #train}; calling
     * this before training passes {@code null} to the iterator — confirm how
     * {@code SimulationIterator.simulate} handles that.
     *
     * @param inputPattern values fed to the network's input layer
     * @return the shared iterator, positioned over the simulation output
     */
    @Override
    public SimulationIterator simulate(double[] inputPattern) {
        simulations.simulate(network, inputPattern);
        return simulations;
    }

    /** Copies the learning rate and momentum from {@code input} into the back-propagation algorithm. */
    private void initBackPropagation(TrainingInput input) {
        settings.backPropagation.setLearningRate(input.learningRate);
        settings.backPropagation.setMomentum(input.momentum);
    }

    /** Splits the input/target data with the configured separator into training and testing sets. */
    private void separateData(TrainingInput input) {
        DataSet[] sets = settings.separator.separate(input.inputs, input.targets);
        trainingSet = sets[0];
        testingSet = sets[1];
    }

    /** Runs one pass over the training set: feed forward, then back-propagate, for every pattern. */
    private void trainNetwork() {
        for (int patternIndex = 0; patternIndex < trainingSet.size(); patternIndex++) {
            double[] inputPattern = trainingSet.inputs.get(patternIndex);
            double[] targetPattern = trainingSet.targets.get(patternIndex);
            INetworkValues values = settings.feedForward.buildNetwork(network, inputPattern);
            settings.backPropagation.trainNetwork(network, values, targetPattern);
        }
    }

    /**
     * Simulates every testing pattern, sums the per-output errors over all
     * patterns, and stores the configured mean of that sum in {@link #lastError}.
     *
     * <p>The mean is taken over the number of testing patterns, not the number
     * of individual outputs. NOTE(review): an empty testing set passes a count
     * of 0 to {@code meanCalculator.getMean} — confirm that case is handled.
     */
    private void testNetwork() {
        double errorSum = 0;
        for (int patternIndex = 0; patternIndex < testingSet.size(); patternIndex++) {
            double[] inputPattern = testingSet.inputs.get(patternIndex);
            double[] targetPattern = testingSet.targets.get(patternIndex);
            // Each simulated output is paired with the target at the same position;
            // the counter is scoped per pattern instead of being reset in the outer loop header.
            int outputIndex = 0;
            for (SimulationIterator it = simulate(inputPattern); it.hasNext(); ) {
                errorSum += settings.errorCalculator.getError(targetPattern[outputIndex++], it.next());
            }
        }
        lastError = settings.meanCalculator.getMean(errorSum, testingSet.size());
    }
}