Package mechanics

Source Code of mechanics.NeuralNetwork

package mechanics;

import au.com.bytecode.opencsv.CSVReader;

import java.io.FileReader;
import java.io.IOException;
import java.util.List;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.neural.data.NeuralData;
import org.encog.neural.data.NeuralDataPair;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.data.basic.BasicNeuralData;
import org.encog.neural.data.basic.BasicNeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.layers.Layer;
import org.encog.neural.networks.synapse.Synapse;
import org.encog.neural.networks.synapse.WeightedSynapse;
import org.encog.neural.networks.training.Train;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;

/**
*
* @author Conker
*/
public class NeuralNetwork {

    public BasicNetwork network;
    private NeuralDataSet trainingData;

    public NeuralNetwork() {
        network = new BasicNetwork();
    }

    public int init() {
        int success;
        Layer outputLayer = new BasicLayer(new ActivationSigmoid(), true, 6);
        Layer hiddenLayer1 = new BasicLayer(new ActivationSigmoid(), true, 6);
        Layer inputLayer = new BasicLayer(new ActivationSigmoid(), false, 4);

        Synapse synapse1 = new WeightedSynapse(hiddenLayer1, outputLayer);
        Synapse synapse2 = new WeightedSynapse(inputLayer, hiddenLayer1);

        hiddenLayer1.addSynapse(synapse1);
        inputLayer.addSynapse(synapse2);

        network.tagLayer("INPUT", inputLayer);
        network.tagLayer("OUTPUT", outputLayer);

        network.getStructure().finalizeStructure();
        network.reset();

        createTrainingData();

        success = trainNeuralNetwork();

        //testNeuralNetwork();
        return success;
    }

    private void createTrainingData() {
        String number;
        double[][] INPUT = new double[30][4];
        double[][] IDEAL = new double[30][6];
        try {
            CSVReader reader = new CSVReader(new FileReader("test.csv"), ';');
            List myEntries = reader.readAll();

            for (int i = 0; i < myEntries.size(); i++) {
                for (int j = 0; j < 4; j++) {
                    number = ((String[]) (myEntries.get(i)))[j].replace(",", ".");
                    INPUT[i][j] = Double.valueOf(number);
                }

                for (int j = 4; j < 10; j++) {
                    number = ((String[]) (myEntries.get(i)))[j].replace(",", ".");
                    IDEAL[i][j - 4] = Double.valueOf(number);
                }
            }

            /*
            for(int i = 0; i < 30; i++) {
            for(int j = 0; j < 4; j++)
            System.out.print(INPUT[i][j] + " ");
            System.out.print("\n");
            for(int j = 0; j < 6; j++)
            System.out.print(IDEAL[i][j] + " ");
            System.out.print("\n");
            }
            */

            trainingData = new BasicNeuralDataSet(INPUT, IDEAL);

        } catch (IOException e) {
            System.out.print(e + "Problem opening file");
        }
    }

    private int trainNeuralNetwork() {
        final Train train = new ResilientPropagation(network, trainingData);

        int epoch = 1;
        do {
            train.iteration();
            //System.out.println("Epoch #" + epoch + " Error: " + train.getError());
            epoch++;
            if (epoch > 500) {
                return 1;
            }
        } while (train.getError() > 0.01);
        return 0;
    }

    private void testNeuralNetwork() {
        System.out.print("Neural network results:\n");
        for (NeuralDataPair pair : trainingData) {
            final NeuralData output = network.compute(pair.getInput());
            System.out.println(pair.getInput().getData(0) + ", "
                    + pair.getInput().getData(1) + ", "
                    + pair.getInput().getData(2) + ", "
                    + pair.getInput().getData(3) + ", "
                    + ", actual = " + output.getData(0) + ", "
                    + output.getData(1) + ", "
                    + output.getData(2) + ", "
                    + output.getData(3) + ", "
                    + output.getData(4) + ", "
                    + output.getData(5) + ", " + "\n");
        }
    }

    public double [] query(double [] data) {
        BasicNeuralData neuralData = new BasicNeuralData(data);
        final NeuralData output = network.compute(neuralData);
        return output.getData();
    }
}
TOP

Related Classes of mechanics.NeuralNetwork

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by ORACLE Inc. Contact coftware#gmail.com.