package zdenekdrahos.App.Controller;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Observable;
import java.util.Observer;
import javax.swing.ComboBoxModel;
import javax.swing.DefaultComboBoxModel;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import zdenekdrahos.AI.Training.Simulation.SimulationIterator;
import zdenekdrahos.AI.NeuralNetwork.Builder.Activations;
import zdenekdrahos.AI.NeuralNetwork.Builder.INetworkBuilder;
import zdenekdrahos.AI.NeuralNetwork.Builder.NetworkBuilder;
import zdenekdrahos.AI.NeuralNetwork.INeuralNetwork;
import zdenekdrahos.AI.Training.Builder.ITrainingBuilder;
import zdenekdrahos.AI.Training.Builder.TrainingBuilder;
import zdenekdrahos.AI.Training.ITraining;
import zdenekdrahos.AI.Training.Output.TrainingOutput;
import zdenekdrahos.AI.Training.TrainingInput;
/**
 * Connects data published by other controllers (source data and Kohonen map groups) with a
 * back-propagation neural network: builds the network, trains it and exposes the results.
 *
 * <p>Incoming data is stored column by column. Every value is wrapped in a one-element
 * {@code double[]} row, so a column can be passed directly as training inputs or targets.
 */
public class BackPropagationController implements IBackPropagationController, Observer {

    private static final String SOURCE_DATA = "Source data";
    private static final String KOHONEN_DATA = "Kohonen data";

    /**
     * Data columns keyed by display name. Insertion-ordered so the combo box lists columns
     * in their natural order instead of hash order.
     */
    private final Map<String, List<double[]>> columns = new LinkedHashMap<String, List<double[]>>();
    private final ITraining train;
    private final INetworkBuilder networkBuilder;
    private TrainingOutput output;
    private TrainingInput input;
    private INeuralNetwork network;

    public BackPropagationController() {
        ITrainingBuilder trainBuilder = new TrainingBuilder();
        train = trainBuilder.build();
        networkBuilder = new NetworkBuilder();
    }

    /**
     * Receives data from observed controllers and stores it as selectable columns.
     * Data from the same origin replaces the columns previously received from it.
     */
    @Override
    @SuppressWarnings("unchecked") // assumes KohonenMapController always notifies with List<List<Double[]>>
    public void update(Observable o, Object arg) {
        if (o instanceof SourceDataController) {
            setSourceData((Double[][]) arg);
        } else if (o instanceof KohonenMapController) {
            setKohonenData((List<List<Double[]>>) arg);
        }
    }

    /** Returns a combo box model listing the names of all currently stored columns. */
    @Override
    public ComboBoxModel createComboBoxModel() {
        String[] keys = columns.keySet().toArray(new String[0]);
        return new DefaultComboBoxModel(keys);
    }

    /**
     * Builds the network.
     *
     * @param topology whitespace-separated layer sizes, e.g. {@code "1 5 1"}
     * @param activationFunctions whitespace-separated activation names for every layer after
     *     the input layer; {@code TANH} and {@code SIG} are recognized, anything else is linear
     * @throws NumberFormatException if a topology token is not an integer
     */
    @Override
    public void buildNetwork(String topology, String activationFunctions) {
        network = networkBuilder.build(getTopology(topology), getActivations(activationFunctions));
    }

    /** Parses whitespace-separated layer sizes into an array. */
    private int[] getTopology(String topologyString) {
        String[] tokens = splitBySpace(topologyString);
        int[] topology = new int[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            topology[i] = Integer.parseInt(tokens[i]);
        }
        return topology;
    }

    /**
     * Splits on runs of whitespace. Leading and trailing whitespace is trimmed first, because
     * {@code String.split} would otherwise produce an empty leading token.
     */
    private String[] splitBySpace(String text) {
        return text.trim().split("\\s+");
    }

    /** Converts activation names into an array with one entry per network layer. */
    private Activations[] getActivations(String activationsString) {
        List<Activations> activations = new ArrayList<Activations>();
        // Placeholder for the input layer: it is never used, but the builder expects
        // one activation per layer (layersCount == functionsCount).
        activations.add(Activations.LIN);
        for (String name : splitBySpace(activationsString)) {
            activations.add(getActivationFunction(name));
        }
        return activations.toArray(new Activations[0]);
    }

    /** Maps an activation name to its enum value; unknown names fall back to linear. */
    private Activations getActivationFunction(String text) {
        if ("TANH".equals(text)) {
            return Activations.TANH;
        } else if ("SIG".equals(text)) {
            return Activations.SIG;
        } else {
            return Activations.LIN;
        }
    }

    /**
     * Trains the previously built network.
     *
     * <p>The given {@code input} object is modified: its {@code inputs} and {@code targets}
     * are replaced by the selected columns.
     *
     * @param input training parameters
     * @param errorToleration error tolerance passed to the training output
     * @param inputs name of the column used as network input
     * @param targets name of the column used as expected output
     * @throws IllegalArgumentException if either column name is unknown
     */
    @Override
    public void train(TrainingInput input, double errorToleration, String inputs, String targets) {
        // Resolve both columns before touching any state, so a bad name leaves the controller unchanged.
        List<double[]> inputColumn = getColumn(inputs);
        List<double[]> targetColumn = getColumn(targets);
        this.input = input;
        this.input.inputs = inputColumn.toArray(new double[0][0]);
        this.input.targets = targetColumn.toArray(new double[0][0]);
        output = new TrainingOutput(errorToleration);
        train.train(network, this.input, output);
    }

    /** Returns the output of the last training run, or {@code null} before the first run. */
    @Override
    public TrainingOutput getTrainingOutput() {
        return output;
    }

    /**
     * Builds a chart model with the original targets and the network's predictions, both
     * plotted against the first input value. Both series are empty if no training has been
     * run yet.
     */
    @Override
    public XYSeriesCollection getGraphModel() {
        XYSeries original = new XYSeries("Original");
        XYSeries prediction = new XYSeries("Prediction");
        if (input != null) {
            for (int i = 0; i < input.inputs.length; i++) {
                // The original point is added once per simulated value, so it is repeated if
                // the simulation yields more than one value for a single input row.
                for (SimulationIterator iterator = train.simulate(input.inputs[i]); iterator.hasNext();) {
                    original.add(input.inputs[i][0], input.targets[i][0]);
                    prediction.add(input.inputs[i][0], iterator.next());
                }
            }
        }
        XYSeriesCollection collection = new XYSeriesCollection();
        collection.addSeries(original);
        collection.addSeries(prediction);
        return collection;
    }

    /** Looks up a stored column by name, failing with a descriptive message if it is missing. */
    private List<double[]> getColumn(String name) {
        List<double[]> column = columns.get(name);
        if (column == null) {
            throw new IllegalArgumentException("Unknown data column: " + name);
        }
        return column;
    }

    /** Replaces all source-data columns with the columns of the given row-major matrix. */
    private void setSourceData(Double[][] data) {
        removeColumns(SOURCE_DATA);
        for (Double[] row : data) {
            for (int j = 0; j < row.length; j++) {
                addNumberToColumn(j, row[j], SOURCE_DATA);
            }
        }
    }

    /**
     * Replaces all Kohonen columns. Every column of every group becomes its own column, and
     * columns are numbered consecutively across groups. Empty groups contribute no columns.
     */
    private void setKohonenData(List<List<Double[]>> groups) {
        removeColumns(KOHONEN_DATA);
        int columnNumber = 0;
        for (List<Double[]> group : groups) {
            if (group.isEmpty()) {
                continue;
            }
            // Column count is taken from the first row of the group.
            int groupColumns = group.get(0).length;
            for (int j = 0; j < groupColumns; j++) {
                for (Double[] row : group) {
                    addNumberToColumn(columnNumber, row[j], KOHONEN_DATA);
                }
                columnNumber++;
            }
        }
    }

    /** Removes every column previously created for the given data origin. */
    private void removeColumns(String text) {
        String prefix = text + " - column #"; // must match the format used by getColumnText
        for (String key : new ArrayList<String>(columns.keySet())) {
            if (key.startsWith(prefix)) {
                columns.remove(key);
            }
        }
    }

    /** Appends one value, as a single-element row, to the column identified by number and origin. */
    private void addNumberToColumn(int columnNumber, Double number, String text) {
        String key = getColumnText(columnNumber, text);
        List<double[]> column = columns.get(key);
        if (column == null) {
            column = new ArrayList<double[]>();
            columns.put(key, column);
        }
        column.add(new double[]{number});
    }

    /** Formats the display name of a column; {@code columnNumber} is zero-based, the name is one-based. */
    private String getColumnText(int columnNumber, String text) {
        return String.format("%s - column #%d", text, columnNumber + 1);
    }
}