package zdenekdrahos.Testing;
import java.io.FileNotFoundException;

import javax.swing.SwingUtilities;
import javax.swing.WindowConstants;

import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartFrame;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;

import zdenekdrahos.AI.NeuralNetwork.Builder.Activations;
import zdenekdrahos.AI.NeuralNetwork.Builder.INetworkBuilder;
import zdenekdrahos.AI.NeuralNetwork.Builder.NetworkBuilder;
import zdenekdrahos.AI.NeuralNetwork.INeuralNetwork;
import zdenekdrahos.AI.Training.Builder.ITrainingBuilder;
import zdenekdrahos.AI.Training.Builder.TrainingBuilder;
import zdenekdrahos.AI.Training.ITraining;
import zdenekdrahos.AI.Training.Output.TrainingOutput;
import zdenekdrahos.AI.Training.Simulation.SimulationIterator;
import zdenekdrahos.AI.Training.TrainingInput;
import zdenekdrahos.App.Controller.SourceDataController;
/**
 * Standalone program that trains a small feed-forward network (topology 1-2-1) on a two-column
 * CSV data set and plots the original samples against the network's predictions in a JFreeChart
 * window.
 */
public class MainTrainingData {

    /** Training inputs: one single-element row per CSV record, taken from the first column. */
    static double[][] input;

    /** Training targets: one single-element row per CSV record, taken from the second column. */
    static double[][] output;

    /**
     * Loads the training data and trains the network on it.
     *
     * <p>It then prints a training summary and shows a chart comparing the original data with
     * the network's predictions.
     *
     * @param args ignored
     * @throws FileNotFoundException if the CSV data file cannot be opened
     */
    public static void main(String[] args) throws FileNotFoundException {
        final XYSeries original = loadTrainingData();

        int[] topology = {1, 2, 1};
        Activations[] activations = {Activations.LIN, Activations.SIG, Activations.LIN};
        INetworkBuilder networkBuilder = new NetworkBuilder();
        INeuralNetwork network = networkBuilder.build(topology, activations);

        TrainingInput in = new TrainingInput();
        in.iterationsCount = 2000;
        in.learningRate = 0.1;
        in.momentum = 0;
        in.inputs = input;
        in.targets = output;

        ITrainingBuilder trainBuilder = new TrainingBuilder();
        ITraining train = trainBuilder.build();
        // NOTE(review): 0.00005 is presumably the target error at which training stops early
        // -- confirm against TrainingOutput.
        TrainingOutput out = new TrainingOutput(0.00005);
        train.train(network, in, out);
        System.out.printf("Solution in %d iteration, minError = %f, lastError: %f%n%n",
                out.lastIterationNumber, out.minError, out.lastError);

        final XYSeries prediction = predictSeries(train);

        XYSeriesCollection collection = new XYSeriesCollection();
        collection.addSeries(original);
        collection.addSeries(prediction);
        displayGraph(collection);
    }

    /**
     * Reads the CSV data file and fills {@link #input} and {@link #output}.
     *
     * @return the raw (x, y) samples as a chart series named "Original"
     * @throws FileNotFoundException if the CSV data file cannot be opened
     */
    private static XYSeries loadTrainingData() throws FileNotFoundException {
        SourceDataController control = new SourceDataController();
        // NOTE(review): the meaning of the argument 2 (column count?) is not visible here
        // -- confirm in SourceDataController.loadCsvFile.
        control.loadCsvFile("import-data/data-min.csv", 2, ';');
        Double[][] data = control.getData();

        // Each row is assigned a fresh array in the loop, so only the outer arrays are allocated.
        input = new double[data.length][];
        output = new double[data.length][];
        final XYSeries series = new XYSeries("Original");
        for (int i = 0; i < data.length; i++) {
            input[i] = new double[]{data[i][0]};
            output[i] = new double[]{data[i][1]};
            series.add(data[i][0], data[i][1]);
        }
        return series;
    }

    /**
     * Runs the trained network on every row of {@link #input} and collects its outputs.
     *
     * <p>Every value yielded by the simulation iterator is plotted at that input's x position.
     *
     * @param train the trained network wrapper used for simulation
     * @return the predicted points as a chart series named "Prediction"
     */
    private static XYSeries predictSeries(ITraining train) {
        final XYSeries series = new XYSeries("Prediction");
        for (double[] sample : input) {
            SimulationIterator iterator = train.simulate(sample);
            while (iterator.hasNext()) {
                series.add(sample[0], iterator.next());
            }
        }
        return series;
    }

    /**
     * Shows the given series as an XY line chart in a new window.
     *
     * <p>The Swing components are created on the Event Dispatch Thread. Closing the window
     * terminates the application; otherwise the hidden-but-undisposed frame would keep the JVM
     * alive.
     *
     * @param collection the series to plot
     */
    private static void displayGraph(final XYSeriesCollection collection) {
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                JFreeChart chart = ChartFactory.createXYLineChart(
                        "XY Series Demo",
                        "X",
                        "Y",
                        collection,
                        PlotOrientation.VERTICAL,
                        true,   // legend
                        true,   // tooltips
                        false); // URLs
                // ChartFrame creates its own ChartPanel, so no separate panel is needed.
                ChartFrame frame = new ChartFrame("Test", chart);
                frame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
                frame.pack();
                frame.setVisible(true);
            }
        });
    }
}