Package zdenekdrahos.App.Controller

Source Code of zdenekdrahos.App.Controller.BackPropagationController

package zdenekdrahos.App.Controller;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Observable;
import java.util.Observer;
import javax.swing.ComboBoxModel;
import javax.swing.DefaultComboBoxModel;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import zdenekdrahos.AI.NeuralNetwork.Builder.Activations;
import zdenekdrahos.AI.NeuralNetwork.Builder.INetworkBuilder;
import zdenekdrahos.AI.NeuralNetwork.Builder.NetworkBuilder;
import zdenekdrahos.AI.NeuralNetwork.INeuralNetwork;
import zdenekdrahos.AI.Training.Builder.ITrainingBuilder;
import zdenekdrahos.AI.Training.Builder.TrainingBuilder;
import zdenekdrahos.AI.Training.ITraining;
import zdenekdrahos.AI.Training.Output.TrainingOutput;
import zdenekdrahos.AI.Training.Simulation.SimulationIterator;
import zdenekdrahos.AI.Training.TrainingInput;

public class BackPropagationController implements IBackPropagationController, Observer {

    private Map<String, List<double[]>> columns = new HashMap<String, List<double[]>>();
    private ITraining train;
    private TrainingOutput output;
    private TrainingInput input;
    private INeuralNetwork network;
    private INetworkBuilder networkBuilder;

    public BackPropagationController() {
        ITrainingBuilder trainBuilder = new TrainingBuilder();
        train = trainBuilder.build();
        networkBuilder = new NetworkBuilder();
    }

    @Override
    public void update(Observable o, Object arg) {
        if (o instanceof SourceDataController) {
            setSourceData((Double[][]) arg);
        } else if (o instanceof KohonenMapController) {
            setKohonenData((List<List<Double[]>>) arg);
        }
    }

    @Override
    public ComboBoxModel createComboBoxModel() {
        String[] keys = columns.keySet().toArray(new String[0]);
        return new DefaultComboBoxModel(keys);
    }

    @Override
    public void buildNetwork(String topology, String activationFunctions) {
        network = networkBuilder.build(getTopology(topology), getActivations(activationFunctions));
    }

    private int[] getTopology(String topologyString) {
        List<Integer> topology = new ArrayList<Integer>();
        for (String string : splitBySpace(topologyString)) {
            topology.add(Integer.parseInt(string));
        }
        int[] array = new int[topology.size()];
        for (int i = 0; i < array.length; i++) {
            array[i] = topology.get(i);
        }
        return array;
    }

    private String[] splitBySpace(String text) {
        return text.split("\\s+");
    }

    private Activations[] getActivations(String activationsString) {
        List<Activations> activations = new ArrayList<Activations>();
        activations.add(Activations.LIN);// for input layer, it's not used but layersCount == functionsCount
        for (String string : splitBySpace(activationsString)) {
            activations.add(getActivationFunction(string));
        }
        return activations.toArray(new Activations[0]);
    }

    private Activations getActivationFunction(String text) {
        if (text.equals("TANH")) {
            return Activations.TANH;
        } else if (text.equals("SIG")) {
            return Activations.SIG;
        } else {
            return Activations.LIN;
        }
    }

    @Override
    public void train(TrainingInput input, double errorToleration, String inputs, String targets) {
        this.input = input;
        this.input.inputs = columns.get(inputs).toArray(new double[0][0]);
        this.input.targets = columns.get(targets).toArray(new double[0][0]);
        output = new TrainingOutput(errorToleration);
        train.train(network, this.input, output);
    }

    @Override
    public TrainingOutput getTrainingOutput() {
        return output;
    }

    @Override
    public XYSeriesCollection getGraphModel() {
        XYSeries original = new XYSeries("Original"), prediction = new XYSeries("Prediction");
        SimulationIterator iterator;
        for (int i = 0; i < input.inputs.length; i++) {
            for (iterator = train.simulate(input.inputs[i]); iterator.hasNext();) {
                original.add(input.inputs[i][0], input.targets[i][0]);
                prediction.add(input.inputs[i][0], iterator.next());
            }
        }
        XYSeriesCollection collection = new XYSeriesCollection();
        collection.addSeries(original);
        collection.addSeries(prediction);
        return collection;

    }

    private void setSourceData(Double[][] data) {
        for (int i = 0; i < data.length; i++) {
            for (int j = 0; j < data[i].length; j++) {
                addNumberToColumn(j, data[i][j], "Source data");
            }
        }
    }

    private void setKohonenData(List<List<Double[]>> groups) {
        List<Double[]> currentGroup;
        int columnNumber = 0;
        for (int groupIndex = 0; groupIndex < groups.size(); groupIndex++) {           
            currentGroup = groups.get(groupIndex);
            for (int j = 0; j < currentGroup.get(0).length; j++) {
                for (int i = 0; i < currentGroup.size(); i++) {
                    addNumberToColumn(columnNumber, currentGroup.get(i)[j], "Kohonen data");
                }
                columnNumber++;
            }
        }
    }

    private void addNumberToColumn(int columnNumber, Double number, String text) {
        String column = getColumnText(columnNumber, text);
        if (!columns.containsKey(column)) {
            columns.put(column, new ArrayList<double[]>());
        }
        columns.get(column).add(new double[]{number});
    }

    private String getColumnText(int columnNumber, String text) {
        return String.format("%s - column #%d", text, columnNumber + 1);
    }
}
TOP

Related Classes of zdenekdrahos.App.Controller.BackPropagationController

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.