Package com.wiieditor.neuralnetwork

Source Code of com.wiieditor.neuralnetwork.NeuralNetworkWrapper

/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package com.wiieditor.neuralnetwork;

import com.wiieditor.other.ProgramConfig;
import java.io.File;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
import org.encog.persist.EncogDirectoryPersistence;
import org.encog.util.simple.EncogUtility;


/**
*
* @author zohaibrauf
*/
public class NeuralNetworkWrapper {
   
    private BasicNetwork neuralNetwork=null;
    private MLDataSet trainingSet=null;
   
    private String filePath=null;
    public NeuralNetworkWrapper(int inputNeurons,int outputNeurons){
        neuralNetwork=new BasicNetwork();
        neuralNetwork.addLayer(new BasicLayer(null,true,inputNeurons));
  neuralNetwork.addLayer(new BasicLayer(new ActivationSigmoid(),true,ProgramConfig.NEURAL_HIDDEN_LAYER_LENGTH));
  neuralNetwork.addLayer(new BasicLayer(new ActivationSigmoid(),false,outputNeurons));
  neuralNetwork.getStructure().finalizeStructure();
  neuralNetwork.reset();
       
    }
   
    public NeuralNetworkWrapper(String filePath){
        this.filePath=filePath;
        neuralNetwork=(BasicNetwork)EncogDirectoryPersistence.loadObject(new File(filePath));
       
    }
   
    public NeuralNetworkWrapper(String filePath,int inputNeurons,int outputNeurons){
        this.filePath=filePath;
        File file=new File(filePath);
        if(file==null || !file.exists()){
            neuralNetwork=new BasicNetwork();
            neuralNetwork.addLayer(new BasicLayer(null,true,inputNeurons));
            neuralNetwork.addLayer(new BasicLayer(new ActivationSigmoid(),true,ProgramConfig.NEURAL_HIDDEN_LAYER_LENGTH));
            neuralNetwork.addLayer(new BasicLayer(new ActivationSigmoid(),false,outputNeurons));
            neuralNetwork.getStructure().finalizeStructure();
            neuralNetwork.reset();
        }else{
            neuralNetwork=(BasicNetwork)EncogDirectoryPersistence.loadObject(file);
        }
    }
    private Object getNeuralNetworkTrainingData(NeuralData nd){
        MLDataSet trainingSet=new BasicMLDataSet();
       
        MLData mdInput=new BasicMLData(nd.getInputVector());
        MLData mdOuput=new BasicMLData(nd.getOutputVector());
       
               
        trainingSet.add(mdInput, mdOuput);
       
        return trainingSet;
    }
   
    private Object getNeuralNetworkTrainingData(NeuralData nd,MLDataSet trainingSet){
        MLData mdInput=new BasicMLData(nd.getInputVector());
        MLData mdOuput=new BasicMLData(nd.getOutputVector());
       
               
        trainingSet.add(mdInput, mdOuput);
       
        return trainingSet;
    }
   
    private Object getNeuralNetworkTrainingData(NeuralData [] ndArray){
        MLDataSet trainingSet=new BasicMLDataSet();
       
        for(int i=0;i<ndArray.length;i++){
            MLData mdInput=new BasicMLData(ndArray[i].getInputVector());
            MLData mdOuput=new BasicMLData(ndArray[i].getOutputVector());
   
            trainingSet.add(mdInput, mdOuput);
        }
        return trainingSet;
    }
   
    private Object getNeuralNetworkTrainingData(NeuralData [] ndArray,MLDataSet trainingSet){

        for(int i=0;i<ndArray.length;i++){
            MLData mdInput=new BasicMLData(ndArray[i].getInputVector());
            MLData mdOuput=new BasicMLData(ndArray[i].getOutputVector());
   
            trainingSet.add(mdInput, mdOuput);
        }
        return trainingSet;
    }
   
    public synchronized void learn(NeuralData []ndArray){
        if(trainingSet==null){
            trainingSet=(MLDataSet)getNeuralNetworkTrainingData(ndArray);
        }else{
            trainingSet=(MLDataSet)getNeuralNetworkTrainingData(ndArray, trainingSet);
        }
        ResilientPropagation train = new ResilientPropagation(neuralNetwork, trainingSet);
      
        EncogUtility.trainToError(train, ProgramConfig.NEURAL_NETWOR_MIN_ERROR);
    }
   
    public synchronized void learn(NeuralData nd){
        if(trainingSet==null){
            trainingSet=(MLDataSet)getNeuralNetworkTrainingData(nd);
        }else{
            trainingSet=(MLDataSet)getNeuralNetworkTrainingData(nd, trainingSet);
        }
        ResilientPropagation train = new ResilientPropagation (neuralNetwork, trainingSet);
      
       
        EncogUtility.trainToError(train, ProgramConfig.NEURAL_NETWOR_MIN_ERROR);
    }
   
    public void save(){
        String saveFilePath=ProgramConfig.NEURAL_NETWORK_FILE_PATH;
       
        if(filePath!=null){
            saveFilePath=filePath;
        }
        EncogDirectoryPersistence.saveObject(new File(saveFilePath), neuralNetwork);
    }
   
    public NeuralData calculateAndGetOuput(NeuralData nd){
       
        MLData mdInput=new BasicMLData(nd.getInputVector());
        MLData mdOutput=neuralNetwork.compute(mdInput);
       
       
        nd.setOutputVector(mdOutput.getData());
        return nd;
    }
   
   
   
   
}
TOP

Related Classes of com.wiieditor.neuralnetwork.NeuralNetworkWrapper

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Oracle Corporation (originally Sun Microsystems, Inc.). Contact coftware#gmail.com.