/*
* JAVA Neural Networks (https://bitbucket.org/zdenekdrahos/java-neural-networks)
* @license New BSD License
* @author Zdenek Drahos
*/
package zdenekdrahos.AI.NeuralNetwork.Builder;

import zdenekdrahos.AI.ActivationFunctions.ActivationFactory;
import zdenekdrahos.AI.ActivationFunctions.IActivationFactory;
import zdenekdrahos.AI.ActivationFunctions.IActivationFunction;
import zdenekdrahos.AI.NeuralNetwork.INeuralNetwork;
import zdenekdrahos.AI.NeuralNetwork.Layers.ILayer;
import zdenekdrahos.AI.NeuralNetwork.Layers.Layer;
import zdenekdrahos.AI.NeuralNetwork.NeuralNetwork;
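
/**
 * Builds {@link INeuralNetwork} instances from an integer layer topology,
 * resolving activation functions through an {@link IActivationFactory}.
 *
 * <p>Illustrative usage (the layer sizes below are arbitrary example values):</p>
 * <pre>{@code
 * INetworkBuilder builder = new NetworkBuilder();
 * INeuralNetwork network = builder.build(new int[]{2, 3, 1});
 * }</pre>
 */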
public class NetworkBuilder implements INetworkBuilder {

    private IActivationFactory factory;
    private Activations defaultHiddenLayer, defaultOutputLayer;
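
    /**
     * Creates a builder backed by the default {@link ActivationFactory}.
     */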
    public NetworkBuilder() {
        this(new ActivationFactory());
    }
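
    /**
     * Creates a builder that resolves activation functions through the given
     * factory; hidden layers default to {@code TANH} and the output layer to {@code LIN}.
     */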
    public NetworkBuilder(IActivationFactory factory) {
        this.factory = factory;
        defaultHiddenLayer = Activations.TANH;
        defaultOutputLayer = Activations.LIN;
    }
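
    /**
     * Sets the activation applied to hidden layers when {@link #build(int[])}
     * is called without explicit activation functions.
     */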
    @Override
    public void setDefaultHiddenLayer(Activations hiddenLayer) {
        defaultHiddenLayer = hiddenLayer;
    }
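
    /**
     * Sets the activation applied to the output layer when {@link #build(int[])}
     * is called without explicit activation functions.
     */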
    @Override
    public void setDefaultOutputLayer(Activations outputLayer) {
        defaultOutputLayer = outputLayer;
    }
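
    /**
     * Builds a network using the default activations: a linear input layer,
     * the default hidden-layer activation for inner layers and the default
     * output-layer activation for the last layer.
     *
     * @param topology number of neurons in each layer, input to output
     */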
    @Override
    public INeuralNetwork build(int[] topology) {
        return build(topology, buildActivationsTopology(topology.length));
    }
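
    /**
     * Builds a network with one activation function per layer.
     *
     * @param topology  number of neurons in each layer, input to output
     * @param functions activation for each layer, same length as {@code topology}
     * @throws IllegalArgumentException if the two arrays differ in length
     */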
    @Override
    public INeuralNetwork build(int[] topology, Activations[] functions) {
        if (topology.length != functions.length) {
            throw new IllegalArgumentException("Topology and activation functions must have the same length");
        }
        INeuralNetwork network = new NeuralNetwork();
        for (int i = 0; i < topology.length; i++) {
            ILayer layer = new Layer(topology[i], getActivation(functions[i]));
            network.addLayer(layer);
        }
        network.generateWeights();
        return network;
    }
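
    /**
     * Default activation layout: linear input layer, the default hidden
     * activation for inner layers and the default output activation for the
     * last layer.
     */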
    private Activations[] buildActivationsTopology(int topologySize) {
        Activations[] functions = new Activations[topologySize];
        for (int i = 0; i < topologySize; i++) {
            if (i == 0) {
                functions[i] = Activations.LIN;
            } else if (i == (topologySize - 1)) {
                functions[i] = defaultOutputLayer;
            } else {
                // inner layers use the default hidden activation
                functions[i] = defaultHiddenLayer;
            }
        }
        return functions;
    }
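
    /**
     * Resolves an {@link Activations} value to a concrete function from the
     * factory; {@code LIN} and any unmapped value fall back to the linear function.
     */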
    private IActivationFunction getActivation(Activations activation) {
        switch (activation) {
            case SIG:
                return factory.getSigmoid();
            case TANH:
                return factory.getHyperbolicTangent();
            default:
                return factory.getLinearFunction();
        }
    }
}