/**
* Copyright 2010 Neuroph Project http://neuroph.sourceforge.net
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.neuroph.core;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Observable;
import java.util.Random;
import org.neuroph.core.exceptions.VectorSizeMismatchException;
import org.neuroph.core.learning.IterativeLearning;
import org.neuroph.core.learning.LearningRule;
import org.neuroph.core.learning.TrainingSet;
import org.neuroph.util.NeuralNetworkType;
import org.neuroph.util.plugins.LabelsPlugin;
import org.neuroph.util.plugins.PluginBase;
/**
 *<pre>
 * Base class for artificial neural networks. It provides generic structure and functionality
 * for neural networks. A neural network contains a collection of neuron layers and a learning rule.
 * Custom neural networks are created by deriving from this class, creating layers of interconnected
 * network-specific neurons, and setting a network-specific learning rule.
 *</pre>
 *
 * <p>The network extends {@link Observable}; {@link #notifyChange()} notifies registered observers.
 *
 * @see Layer
 * @see LearningRule
 * @author Zoran Sevarac <sevarac@gmail.com>
 */
public class NeuralNetwork extends Observable implements Runnable, Serializable {

    /**
     * The class fingerprint that is set to indicate serialization
     * compatibility with a previous version of the class.
     */
    private static final long serialVersionUID = 4L;

    /**
     * Network type id (see neuroph.util.NeuralNetworkType)
     */
    private NeuralNetworkType type;

    /**
     * Layers of this neural network, in calculation order
     */
    private List<Layer> layers;

    /**
     * Reference to network input neurons
     */
    private List<Neuron> inputNeurons;

    /**
     * Reference to network output neurons
     */
    private List<Neuron> outputNeurons;

    /**
     * Learning rule (learning algorithm) for this network
     */
    private LearningRule learningRule;

    /**
     * Thread most recently started by one of the learnInNewThread methods.
     * Transient, so it is null after deserialization.
     */
    private transient Thread learningThread;

    /**
     * Plugins collection, keyed by the plugin's runtime class
     */
    private Map<Class<?>, PluginBase> plugins;

    /**
     * Label for this network
     */
    private String label = "";

    /**
     * Creates an instance of empty neural network with a {@link LabelsPlugin} installed.
     */
    public NeuralNetwork() {
        this.layers = new ArrayList<Layer>();
        this.plugins = new HashMap<Class<?>, PluginBase>();
        this.addPlugin(new LabelsPlugin());
    }

    /**
     * Adds layer to the end of the network's layer list and sets this network as its parent.
     *
     * @param layer
     *            layer to add
     */
    public void addLayer(Layer layer) {
        layer.setParentNetwork(this);
        this.layers.add(layer);
    }

    /**
     * Adds layer to specified index position in network and sets this network as its parent.
     *
     * @param idx
     *            index position to add layer
     * @param layer
     *            layer to add
     */
    public void addLayer(int idx, Layer layer) {
        layer.setParentNetwork(this);
        this.layers.add(idx, layer);
    }

    /**
     * Removes specified layer from network
     *
     * @param layer
     *            layer to remove
     */
    public void removeLayer(Layer layer) {
        this.layers.remove(layer);
    }

    /**
     * Removes layer at specified index position from net
     *
     * @param idx
     *            index position of the layer which should be removed
     */
    public void removeLayerAt(int idx) {
        this.layers.remove(idx);
    }

    /**
     * Returns an iterator over the network's layers
     *
     * @return iterator over network layers
     */
    public Iterator<Layer> getLayersIterator() {
        return this.layers.iterator();
    }

    /**
     * Returns layers collection (the internal list, not a copy)
     *
     * @return layers collection
     */
    public List<Layer> getLayers() {
        return this.layers;
    }

    /**
     * Returns layer at specified index
     *
     * @param idx
     *            layer index position
     * @return layer at specified index position
     */
    public Layer getLayerAt(int idx) {
        return this.layers.get(idx);
    }

    /**
     * Returns index position of the specified layer
     *
     * @param layer
     *            requested Layer object
     * @return layer position index, or -1 if the layer is not in this network
     */
    public int indexOf(Layer layer) {
        return this.layers.indexOf(layer);
    }

    /**
     * Returns number of layers in network
     *
     * @return number of layers in net
     */
    public int getLayersCount() {
        return this.layers.size();
    }

    /**
     * Sets network input. Each value is passed to the input neuron at the same position.
     *
     * @param inputVector
     *            network input values, one per input neuron
     * @throws VectorSizeMismatchException
     *             if the number of values differs from the number of input neurons
     */
    public void setInput(double ... inputVector) throws VectorSizeMismatchException {
        if (inputVector.length != inputNeurons.size()) {
            throw new VectorSizeMismatchException("Input vector size does not match network input dimension!");
        }
        int i = 0;
        for (Neuron neuron : this.inputNeurons) {
            neuron.setInput(inputVector[i]); // set input to the corresponding neuron
            i++;
        }
    }

    /**
     * Returns the current outputs of the output neurons as a new array,
     * in output-neuron order.
     *
     * @return network output values
     */
    public double[] getOutput() {
        double[] outputVector = new double[outputNeurons.size()];
        int i = 0;
        for (Neuron neuron : this.outputNeurons) {
            outputVector[i] = neuron.getOutput();
            i++;
        }
        return outputVector;
    }

    /**
     * Performs calculation on whole network, layer by layer in list order
     */
    public void calculate() {
        for (Layer layer : this.layers) {
            layer.calculate();
        }
    }

    /**
     * Resets the activation levels for whole network
     */
    public void reset() {
        for (Layer layer : this.layers) {
            layer.reset();
        }
    }

    /**
     * Implementation of Runnable interface for calculating network in a
     * separate thread. Equivalent to {@link #calculate()}.
     */
    @Override
    public void run() {
        this.calculate();
    }

    /**
     * Starts learning in a new thread to learn the specified training set,
     * and immediately returns from method to the current thread execution
     *
     * @param trainingSetToLearn
     *            set of training elements to learn
     */
    public void learnInNewThread(TrainingSet trainingSetToLearn) {
        learningRule.setTrainingSet(trainingSetToLearn);
        learningThread = new Thread(learningRule);
        learningRule.setStarted();
        learningThread.start();
    }

    /**
     * Starts learning with specified learning rule in new thread to learn the
     * specified training set, and immediately returns from method to the current thread execution.
     * The given rule also becomes this network's learning rule.
     *
     * @param trainingSetToLearn
     *            set of training elements to learn
     * @param learningRule
     *            learning algorithm
     */
    public void learnInNewThread(TrainingSet trainingSetToLearn, LearningRule learningRule) {
        setLearningRule(learningRule);
        learningRule.setTrainingSet(trainingSetToLearn);
        learningThread = new Thread(learningRule);
        learningRule.setStarted();
        learningThread.start();
    }

    /**
     * Starts the learning in the current running thread to learn the specified
     * training set, and returns from method when network is done learning
     *
     * @param trainingSetToLearn
     *            set of training elements to learn
     */
    public void learnInSameThread(TrainingSet trainingSetToLearn) {
        learningRule.setTrainingSet(trainingSetToLearn);
        learningRule.setStarted();
        learningRule.run();
    }

    /**
     * Starts the learning with specified learning rule in the current running
     * thread to learn the specified training set, and returns from method when network is done learning.
     * The given rule also becomes this network's learning rule.
     *
     * @param trainingSetToLearn
     *            set of training elements to learn
     * @param learningRule
     *            learning algorithm
     */
    public void learnInSameThread(TrainingSet trainingSetToLearn, LearningRule learningRule) {
        setLearningRule(learningRule);
        learningRule.setTrainingSet(trainingSetToLearn);
        learningRule.setStarted();
        learningRule.run();
    }

    /**
     * Stops learning. Does nothing if no learning rule has been set.
     */
    public void stopLearning() {
        // Consistent with pauseLearning/resumeLearning, which are no-ops when
        // there is nothing to act on, instead of failing with a NullPointerException.
        if (learningRule != null) {
            learningRule.stopLearning();
        }
    }

    /**
     * Pauses the learning. Only has an effect when the learning rule is an
     * {@link IterativeLearning}; makes sense only when learning is done in a new
     * thread with a learnInNewThread() method.
     */
    public void pauseLearning() {
        if (learningRule instanceof IterativeLearning) {
            ((IterativeLearning) learningRule).pause();
        }
    }

    /**
     * Resumes paused learning. Only has an effect when the learning rule is an
     * {@link IterativeLearning}.
     */
    public void resumeLearning() {
        if (learningRule instanceof IterativeLearning) {
            ((IterativeLearning) learningRule).resume();
        }
    }

    /**
     * Randomizes connection weights for the whole network
     */
    public void randomizeWeights() {
        for (Layer layer : this.layers) {
            layer.randomizeWeights();
        }
    }

    /**
     * Randomizes connection weights for the whole network within specified value range
     *
     * @param minWeight lower bound of the weight range
     * @param maxWeight upper bound of the weight range
     */
    public void randomizeWeights(double minWeight, double maxWeight) {
        for (Layer layer : this.layers) {
            layer.randomizeWeights(minWeight, maxWeight);
        }
    }

    /**
     * Initialize connection weights for the whole network to a value
     *
     * @param value the weight value
     */
    public void initializeWeights(double value) {
        for (Layer layer : this.layers) {
            layer.initializeWeights(value);
        }
    }

    /**
     * Initialize connection weights for the whole network using a
     * random number generator
     *
     * @param generator the random number generator
     */
    public void initializeWeights(Random generator) {
        for (Layer layer : this.layers) {
            layer.initializeWeights(generator);
        }
    }

    /**
     * Initialize connection weights for the whole network using the given range,
     * delegating to each layer's initializeWeights(min, max).
     *
     * @param min lower bound of the range
     * @param max upper bound of the range
     */
    public void initializeWeights(double min, double max) {
        for (Layer layer : this.layers) {
            layer.initializeWeights(min, max);
        }
    }

    /**
     * Returns type of this network
     *
     * @return network type
     */
    public NeuralNetworkType getNetworkType() {
        return type;
    }

    /**
     * Sets type for this network
     *
     * @param type network type
     */
    public void setNetworkType(NeuralNetworkType type) {
        this.type = type;
    }

    /**
     * Gets reference to input neurons list.
     *
     * @return input neurons list
     */
    public List<Neuron> getInputNeurons() {
        return this.inputNeurons;
    }

    /**
     * Sets reference to input neurons list
     *
     * @param inputNeurons
     *            input neurons collection
     */
    public void setInputNeurons(List<Neuron> inputNeurons) {
        this.inputNeurons = inputNeurons;
    }

    /**
     * Returns reference to output neurons list.
     *
     * @return output neurons list
     */
    public List<Neuron> getOutputNeurons() {
        return this.outputNeurons;
    }

    /**
     * Sets reference to output neurons list.
     *
     * @param outputNeurons
     *            output neurons collection
     */
    public void setOutputNeurons(List<Neuron> outputNeurons) {
        this.outputNeurons = outputNeurons;
    }

    /**
     * Returns the learning algorithm of this network
     *
     * @return algorithm for network training
     */
    public LearningRule getLearningRule() {
        return this.learningRule;
    }

    /**
     * Sets learning algorithm for this network and links the rule back to this network
     *
     * @param learningRule learning algorithm for this network
     */
    public void setLearningRule(LearningRule learningRule) {
        learningRule.setNeuralNetwork(this);
        this.learningRule = learningRule;
    }

    /**
     * Returns the thread most recently started by one of the learnInNewThread
     * methods, or null if none has been started. The learnInSameThread methods
     * do not update this reference.
     *
     * @return the learning thread, or null
     */
    public Thread getLearningThread() {
        return learningThread;
    }

    /**
     * Notifies observers about some change
     */
    public void notifyChange() {
        setChanged();
        notifyObservers();
        clearChanged();
    }

    /**
     * Creates connection with specified weight value between specified neurons
     * and adds it to the input connections of toNeuron.
     *
     * @param fromNeuron neuron to connect
     * @param toNeuron neuron to connect to
     * @param weightVal connection weight value
     */
    public void createConnection(Neuron fromNeuron, Neuron toNeuron, double weightVal) {
        Connection connection = new Connection(fromNeuron, toNeuron, weightVal);
        toNeuron.addInputConnection(connection);
    }

    /**
     * Returns the label provided by the installed {@link LabelsPlugin}, if any,
     * otherwise the default Object string representation.
     */
    @Override
    public String toString() {
        // Plugins are keyed by Class (see addPlugin), so look up by class, not by name.
        PluginBase plugin = plugins.get(LabelsPlugin.class);
        if (plugin instanceof LabelsPlugin) {
            String networkLabel = ((LabelsPlugin) plugin).getLabel(this);
            if (networkLabel != null) {
                return networkLabel;
            }
        }
        return super.toString();
    }

    /**
     * Saves neural network into the specified file using Java serialization.
     * I/O errors are printed to the standard error stream and not rethrown.
     *
     * @param filePath
     *            file path to save network into
     */
    public void save(String filePath) {
        FileOutputStream fileOut = null;
        ObjectOutputStream out = null;
        try {
            fileOut = new FileOutputStream(new File(filePath));
            // The ObjectOutputStream constructor writes a header and may throw;
            // fileOut is kept separately so it is still closed in that case.
            out = new ObjectOutputStream(new BufferedOutputStream(fileOut));
            out.writeObject(this);
            out.flush();
        } catch (IOException ioe) {
            ioe.printStackTrace();
        } finally {
            // Closing the outermost stream closes the whole chain.
            closeQuietly(out != null ? out : fileOut);
        }
    }

    /**
     * Loads neural network from the specified file.
     * Uses Java deserialization: only load files from trusted sources.
     *
     * @param filePath
     *            file path to load network from
     * @return loaded neural network as NeuralNetwork object, or null if the file
     *         does not exist or cannot be read/deserialized (the error is printed)
     */
    public static NeuralNetwork load(String filePath) {
        FileInputStream fileIn = null;
        ObjectInputStream oistream = null;
        try {
            File file = new File(filePath);
            if (!file.exists()) {
                throw new FileNotFoundException("Cannot find file: " + filePath);
            }
            fileIn = new FileInputStream(file);
            // The ObjectInputStream constructor reads the stream header and may throw;
            // fileIn is kept separately so it is still closed in that case.
            oistream = new ObjectInputStream(new BufferedInputStream(fileIn));
            return (NeuralNetwork) oistream.readObject();
        } catch (IOException ioe) {
            ioe.printStackTrace();
        } catch (ClassNotFoundException cnfe) {
            cnfe.printStackTrace();
        } finally {
            closeQuietly(oistream != null ? oistream : fileIn);
        }
        return null;
    }

    /**
     * Loads neural network from the specified InputStream. The given stream is
     * closed by this method, whether or not loading succeeds.
     * Uses Java deserialization: only load data from trusted sources.
     *
     * @param inputStream
     *            input stream to load network from
     * @return loaded neural network as NeuralNetwork object, or null if the stream
     *         cannot be read/deserialized (the error is printed)
     */
    public static NeuralNetwork load(InputStream inputStream) {
        ObjectInputStream oistream = null;
        try {
            oistream = new ObjectInputStream(new BufferedInputStream(inputStream));
            return (NeuralNetwork) oistream.readObject();
        } catch (IOException ioe) {
            ioe.printStackTrace();
        } catch (ClassNotFoundException cnfe) {
            cnfe.printStackTrace();
        } finally {
            // On success the wrapper closes inputStream; if the header read failed,
            // close inputStream directly so the behavior is the same on every path.
            closeQuietly(oistream != null ? oistream : inputStream);
        }
        return null;
    }

    /**
     * Closes the given resource, ignoring a null argument and any IOException.
     */
    private static void closeQuietly(Closeable closeable) {
        if (closeable == null) {
            return;
        }
        try {
            closeable.close();
        } catch (IOException ignored) {
            // best-effort cleanup; the primary error (if any) was already reported
        }
    }

    /**
     * Adds plugin to neural network, replacing any plugin of the same class
     *
     * @param plugin neural network plugin to add
     */
    public void addPlugin(PluginBase plugin) {
        plugin.setParentNetwork(this);
        this.plugins.put(plugin.getClass(), plugin);
    }

    /**
     * Returns the requested plugin
     *
     * @param pluginClass class of the plugin to get
     * @return plugin registered under the specified class, or null if none
     */
    public PluginBase getPlugin(Class<?> pluginClass) {
        return this.plugins.get(pluginClass);
    }

    /**
     * Removes the plugin with specified name. Plugins are registered by class, so
     * the name is matched against each plugin class's simple name (e.g. "LabelsPlugin")
     * and fully-qualified name.
     *
     * @param pluginName name of the plugin to remove
     */
    public void removePlugin(String pluginName) {
        // The map is keyed by Class, so removing by a String key would never match.
        Iterator<Class<?>> keys = this.plugins.keySet().iterator();
        while (keys.hasNext()) {
            Class<?> pluginClass = keys.next();
            if (pluginClass.getSimpleName().equals(pluginName)
                    || pluginClass.getName().equals(pluginName)) {
                keys.remove();
            }
        }
    }

    /**
     * Get network label
     *
     * @return network label
     */
    public String getLabel() {
        return label;
    }

    /**
     * Set network label
     *
     * @param label network label to set
     */
    public void setLabel(String label) {
        this.label = label;
    }
}