package org.neuroph.nnet.learning;
import java.util.Iterator;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.learning.SupervisedLearning;
import org.neuroph.core.learning.SupervisedTrainingElement;
import org.neuroph.core.learning.TrainingElement;
import org.neuroph.core.learning.TrainingSet;
import org.neuroph.util.NeuralNetworkCODEC;
/**
* This class implements a simulated annealing learning rule for supervised
* neural networks. It is based on the generic SimulatedAnnealing class. It is
* used in the same manner as any other training class that implements the
* SupervisedLearning interface.
*
* Simulated annealing is a common training method. It is often used in
* conjunction with a propagation training method. Simulated annealing can be
* very good when propagation training has reached a local minimum.
*
* The name and inspiration come from annealing in metallurgy, a technique
* involving heating and controlled cooling of a material to increase the size
* of its crystals and reduce their defects. The heat causes the atoms to become
* unstuck from their initial positions (a local minimum of the internal energy)
* and wander randomly through states of higher energy; the slow cooling gives
* them more chances of finding configurations with lower internal energy than
* the initial one.
*
* @author Jeff Heaton (http://www.jeffheaton.com)
*/
public class SimulatedAnnealingLearning extends SupervisedLearning {
@Override
protected void updatePatternError(double[] patternError) {
throw new UnsupportedOperationException("Not supported yet.");
}
@Override
protected void updateTotalNetworkError() {
throw new UnsupportedOperationException("Not supported yet.");
}
/**
* The serial id.
*/
private static final long serialVersionUID = 1L;
/**
* The neural network that is to be trained.
*/
protected NeuralNetwork network;
/**
* The starting temperature.
*/
private double startTemperature;
/**
* The ending temperature.
*/
private double stopTemperature;
/**
* The number of cycles that will be used.
*/
private int cycles;
/**
* The current temperature.
*/
protected double temperature;
/**
* Current weights from the neural network.
*/
private double[] weights;
/**
* Best weights so far.
*/
private double[] bestWeights;
/**
* Construct a simulated annleaing trainer for a feedforward neural network.
*
* @param network
* The neural network to be trained.
* @param startTemp
* The starting temperature.
* @param stopTemp
* The ending temperature.
* @param cycles
* The number of cycles in a training iteration.
*/
public SimulatedAnnealingLearning(final NeuralNetwork network,
final double startTemp, final double stopTemp, final int cycles) {
this.network = network;
this.temperature = startTemp;
this.startTemperature = startTemp;
this.stopTemperature = stopTemp;
this.cycles = cycles;
this.weights = new double[NeuralNetworkCODEC
.determineArraySize(network)];
this.bestWeights = new double[NeuralNetworkCODEC
.determineArraySize(network)];
NeuralNetworkCODEC.network2array(network, this.weights);
NeuralNetworkCODEC.network2array(network, this.bestWeights);
}
public SimulatedAnnealingLearning(final NeuralNetwork network) {
this(network, 10, 2, 1000);
}
/**
* Get the best network from the training.
*
* @return The best network.
*/
public NeuralNetwork getNetwork() {
return this.network;
}
/**
* Randomize the weights and thresholds. This function does most of the work
* of the class. Each call to this class will randomize the data according
* to the current temperature. The higher the temperature the more
* randomness.
*/
public void randomize() {
for (int i = 0; i < this.weights.length; i++) {
double add = 0.5 - (Math.random());
add /= this.startTemperature;
add *= this.temperature;
this.weights[i] = this.weights[i] + add;
}
NeuralNetworkCODEC.array2network(this.weights, this.network);
}
/**
* Used internally to calculate the error for a training set.
* @param trainingSet The training set to calculate for.
* @return The error value.
*/
private double determineError(TrainingSet trainingSet) {
double result = 0d;
Iterator<TrainingElement> iterator = trainingSet.iterator();
while (iterator.hasNext() && !isStopped()) {
SupervisedTrainingElement supervisedTrainingElement = (SupervisedTrainingElement) iterator
.next();
double[] input = supervisedTrainingElement.getInput();
this.neuralNetwork.setInput(input);
this.neuralNetwork.calculate();
double[] output = this.neuralNetwork.getOutput();
double[] desiredOutput = supervisedTrainingElement
.getDesiredOutput();
double[] patternError = this.getPatternError(output, desiredOutput);
this.updateTotalNetworkError(patternError);
double sqrErrorSum = 0;
for (double error : patternError) {
sqrErrorSum += (error * error);
}
result += sqrErrorSum / (2 * patternError.length);
}
return result;
}
/**
* Perform one simulated annealing epoch.
*/
@Override
public void doLearningEpoch(TrainingSet trainingSet) {
System.arraycopy(this.weights, 0, this.bestWeights, 0,
this.weights.length);
double bestError = determineError(trainingSet);
this.temperature = this.startTemperature;
for (int i = 0; i < this.cycles; i++) {
randomize();
double currentError = determineError(trainingSet);
if (currentError < bestError) {
System.arraycopy(this.weights, 0, this.bestWeights, 0,
this.weights.length);
bestError = currentError;
} else
System.arraycopy(this.bestWeights, 0, this.weights, 0,
this.weights.length);
NeuralNetworkCODEC.array2network(this.bestWeights, network);
final double ratio = Math.exp(Math.log(this.stopTemperature
/ this.startTemperature)
/ (this.cycles - 1));
this.temperature *= ratio;
}
this.previousEpochError = this.totalNetworkError;
this.totalNetworkError = bestError;
// moved stopping condition to separate method hasReachedStopCondition()
// so it can be overriden / customized in subclasses
if (hasReachedStopCondition()) {
stopLearning();
}
}
/**
* Update the total error.
*/
@Override
protected void updateTotalNetworkError(double[] patternError) {
double sqrErrorSum = 0;
for (double error : patternError) {
sqrErrorSum += (error * error);
}
this.totalNetworkError += sqrErrorSum / (2 * patternError.length);
}
/**
* Not used.
*/
@Override
protected void updateNetworkWeights(double[] patternError) {
}
}