/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package com.wiieditor.neuralnetwork;
import com.wiieditor.other.ProgramConfig;
import java.io.File;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
import org.encog.persist.EncogDirectoryPersistence;
import org.encog.util.simple.EncogUtility;
/**
 * Wrapper around an Encog {@link BasicNetwork} that builds or loads the
 * network, accumulates training samples across calls, trains with
 * {@link ResilientPropagation}, and persists the network to disk.
 *
 * <p>Only the {@code learn} methods are synchronized on this instance;
 * {@link #save()} and {@link #calculateAndGetOuput(NeuralData)} are not.
 *
 * @author zohaibrauf
 */
public class NeuralNetworkWrapper {

    /** The wrapped network; assigned once in every constructor. */
    private final BasicNetwork neuralNetwork;

    /** Every sample passed to {@code learn} so far; created lazily on first use. */
    private MLDataSet trainingSet = null;

    /**
     * Target path for {@link #save()}. When {@code null}, {@code save()} falls
     * back to {@code ProgramConfig.NEURAL_NETWORK_FILE_PATH}.
     */
    private final String filePath;

    /**
     * Creates a freshly initialised network that is not bound to a file.
     *
     * @param inputNeurons  neuron count of the input layer
     * @param outputNeurons neuron count of the output layer
     */
    public NeuralNetworkWrapper(int inputNeurons, int outputNeurons) {
        this.filePath = null;
        this.neuralNetwork = buildNetwork(inputNeurons, outputNeurons);
    }

    /**
     * Loads a previously persisted network from {@code filePath} and remembers
     * that path as the target for {@link #save()}.
     *
     * @param filePath location of the persisted network
     */
    public NeuralNetworkWrapper(String filePath) {
        this.filePath = filePath;
        this.neuralNetwork =
                (BasicNetwork) EncogDirectoryPersistence.loadObject(new File(filePath));
    }

    /**
     * Loads the network from {@code filePath} when that file exists; otherwise
     * builds a new network with the given layer sizes. Either way the path is
     * remembered as the target for {@link #save()}.
     *
     * @param filePath      location of the persisted network
     * @param inputNeurons  input layer size, used only when the file is absent
     * @param outputNeurons output layer size, used only when the file is absent
     */
    public NeuralNetworkWrapper(String filePath, int inputNeurons, int outputNeurons) {
        this.filePath = filePath;
        File file = new File(filePath);
        if (file.exists()) {
            this.neuralNetwork = (BasicNetwork) EncogDirectoryPersistence.loadObject(file);
        } else {
            this.neuralNetwork = buildNetwork(inputNeurons, outputNeurons);
        }
    }

    /**
     * Builds the three-layer network shared by the constructors: an input layer
     * (null activation, bias flag set), a sigmoid hidden layer of
     * {@code ProgramConfig.NEURAL_HIDDEN_LAYER_LENGTH} neurons (bias flag set),
     * and a sigmoid output layer (bias flag cleared). The structure is
     * finalised and the network reset before it is returned.
     */
    private static BasicNetwork buildNetwork(int inputNeurons, int outputNeurons) {
        BasicNetwork network = new BasicNetwork();
        network.addLayer(new BasicLayer(null, true, inputNeurons));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), true,
                ProgramConfig.NEURAL_HIDDEN_LAYER_LENGTH));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), false, outputNeurons));
        network.getStructure().finalizeStructure();
        network.reset();
        return network;
    }

    /**
     * Appends the given samples to the accumulated training set, creating the
     * set on first use. All elements are null-checked before anything is
     * appended, so a null element cannot leave the set partially updated.
     *
     * @throws NullPointerException if any element of {@code samples} is null
     */
    private void addSamples(NeuralData[] samples) {
        for (NeuralData sample : samples) {
            if (sample == null) {
                throw new NullPointerException("training sample must not be null");
            }
        }
        if (trainingSet == null) {
            trainingSet = new BasicMLDataSet();
        }
        for (NeuralData sample : samples) {
            MLData mdInput = new BasicMLData(sample.getInputVector());
            MLData mdOutput = new BasicMLData(sample.getOutputVector());
            trainingSet.add(mdInput, mdOutput);
        }
    }

    /**
     * Trains the network on the accumulated training set until the error
     * target {@code ProgramConfig.NEURAL_NETWOR_MIN_ERROR} is handed to
     * {@link EncogUtility#trainToError} and that call returns.
     */
    private void trainOnAccumulatedData() {
        // With no samples there is nothing to fit; skip instead of starting a
        // trainer on an empty data set.
        if (trainingSet.getRecordCount() == 0) {
            return;
        }
        ResilientPropagation train = new ResilientPropagation(neuralNetwork, trainingSet);
        EncogUtility.trainToError(train, ProgramConfig.NEURAL_NETWOR_MIN_ERROR);
    }

    /**
     * Adds every element of {@code ndArray} to the accumulated training set
     * and retrains the network on the whole set. Does not train when no
     * samples have been accumulated at all.
     *
     * @param ndArray samples to add; may be empty
     * @throws NullPointerException if {@code ndArray} or any element is null
     */
    public synchronized void learn(NeuralData[] ndArray) {
        if (ndArray == null) {
            throw new NullPointerException("ndArray must not be null");
        }
        addSamples(ndArray);
        trainOnAccumulatedData();
    }

    /**
     * Adds one sample to the accumulated training set and retrains the
     * network on the whole set.
     *
     * @param nd sample to add
     * @throws NullPointerException if {@code nd} is null
     */
    public synchronized void learn(NeuralData nd) {
        addSamples(new NeuralData[] {nd});
        trainOnAccumulatedData();
    }

    /**
     * Persists the network to the path given at construction, or to
     * {@code ProgramConfig.NEURAL_NETWORK_FILE_PATH} when no path was given.
     */
    public void save() {
        String saveFilePath = (filePath != null)
                ? filePath
                : ProgramConfig.NEURAL_NETWORK_FILE_PATH;
        EncogDirectoryPersistence.saveObject(new File(saveFilePath), neuralNetwork);
    }

    /**
     * Runs the network on the input vector of {@code nd}, stores the computed
     * output back into {@code nd}, and returns that same instance.
     *
     * @param nd carrier of the input vector; its output vector is overwritten
     * @return {@code nd}, now holding the network's output
     */
    public NeuralData calculateAndGetOuput(NeuralData nd) {
        MLData mdInput = new BasicMLData(nd.getInputVector());
        MLData mdOutput = neuralNetwork.compute(mdInput);
        nd.setOutputVector(mdOutput.getData());
        return nd;
    }
}