/*
 * Package zdenekdrahos.Testing
 *
 * Source code of zdenekdrahos.Testing.MainTrainingData
 */

package zdenekdrahos.Testing;

import java.io.FileNotFoundException;
import javax.swing.SwingUtilities;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartFrame;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import zdenekdrahos.AI.NeuralNetwork.Builder.Activations;
import zdenekdrahos.AI.NeuralNetwork.Builder.INetworkBuilder;
import zdenekdrahos.AI.NeuralNetwork.Builder.NetworkBuilder;
import zdenekdrahos.AI.NeuralNetwork.INeuralNetwork;
import zdenekdrahos.AI.Training.Builder.ITrainingBuilder;
import zdenekdrahos.AI.Training.Builder.TrainingBuilder;
import zdenekdrahos.AI.Training.ITraining;
import zdenekdrahos.AI.Training.Output.TrainingOutput;
import zdenekdrahos.AI.Training.Simulation.SimulationIterator;
import zdenekdrahos.AI.Training.TrainingInput;
import zdenekdrahos.App.Controller.SourceDataController;

public class MainTrainingData {

    static double[][] input, output;

    public static void main(String[] args) throws FileNotFoundException {       
        SourceDataController control = new SourceDataController();
        control.loadCsvFile("import-data/data-min.csv", 2, ';');
        Double[][] data = control.getData();
        input = new double[data.length][1];
        output = new double[data.length][1];


        final XYSeries series1 = new XYSeries("Original");
        for (int i = 0; i < data.length; i++) {
            input[i] = new double[]{data[i][0]};
            output[i] = new double[]{data[i][1]};
            series1.add(data[i][0], data[i][1]);
        }

        int[] topology = {1, 2, 1};
        Activations[] activations = {Activations.LIN, Activations.SIG, Activations.LIN};

        INetworkBuilder networkBuilder = new NetworkBuilder();
        INeuralNetwork network = networkBuilder.build(topology, activations);

        TrainingInput in = new TrainingInput();
        in.iterationsCount = 2000;
        in.learningRate = 0.1;
        in.momentum = 0;
        in.inputs = input;
        in.targets = output;

        ITrainingBuilder trainBuilder = new TrainingBuilder();
        ITraining train = trainBuilder.build();
        TrainingOutput out = new TrainingOutput(0.00005);
        train.train(network, in, out);
        System.out.printf("Solution in %d iteration, minError = %f, lastError: %f\n\n", out.lastIterationNumber, out.minError, out.lastError);

        final XYSeries series2 = new XYSeries("Prediction");

        SimulationIterator iterator;
        for (int i = 0; i < input.length; i++) {
            for (iterator = train.simulate(input[i]); iterator.hasNext();) {               
                series2.add(input[i][0], iterator.next());
            }
        }

        XYSeriesCollection collection = new XYSeriesCollection();
        collection.addSeries(series1);
        collection.addSeries(series2);
        displayGraph(collection);
    }

    private static void displayGraph(XYSeriesCollection collection) {
        final JFreeChart chart = ChartFactory.createXYLineChart(
                "XY Series Demo",
                "X",
                "Y",
                collection,
                PlotOrientation.VERTICAL,
                true,
                true,
                false);

        final ChartPanel chartPanel = new ChartPanel(chart);

        ChartFrame frame = new ChartFrame("Test", chart);
        frame.pack();
        frame.setVisible(true);
    }
}
/*
 * TOP
 *
 * Related classes of zdenekdrahos.Testing.MainTrainingData
 *
 * TOP
 * Copyright © 2018 www.massapi.com. All rights reserved.
 * All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by Oracle Inc. Contact coftware#gmail.com.
 */