Package zdenekdrahos.AI.NeuralNetwork.Builder

Source Code of zdenekdrahos.AI.NeuralNetwork.Builder.NetworkBuilder

/*
* JAVA Neural Networks (https://bitbucket.org/zdenekdrahos/java-neural-networks)
* @license New BSD License
* @author Zdenek Drahos
*/

package zdenekdrahos.AI.NeuralNetwork.Builder;

import zdenekdrahos.AI.ActivationFunctions.ActivationFactory;
import zdenekdrahos.AI.ActivationFunctions.IActivationFactory;
import zdenekdrahos.AI.ActivationFunctions.IActivationFunction;
import zdenekdrahos.AI.NeuralNetwork.INeuralNetwork;
import zdenekdrahos.AI.NeuralNetwork.Layers.ILayer;
import zdenekdrahos.AI.NeuralNetwork.Layers.Layer;
import zdenekdrahos.AI.NeuralNetwork.NeuralNetwork;

public class NetworkBuilder implements INetworkBuilder {

    private IActivationFactory factory;
    private Activations defaultHiddenLayer, defaultOutputLayer;

    public NetworkBuilder() {
        this(new ActivationFactory());
    }

    public NetworkBuilder(IActivationFactory factory) {
        this.factory = factory;
        defaultHiddenLayer = Activations.TANH;
        defaultOutputLayer = Activations.LIN;
    }

    @Override
    public void setDefaultHiddenLayer(Activations hiddenLayer) {
        defaultHiddenLayer = hiddenLayer;
    }

    @Override
    public void setDefaultOutputLayer(Activations outputLayer) {
        defaultOutputLayer = outputLayer;
    }

    @Override
    public INeuralNetwork build(int[] topology) {       
        return build(topology, buildActivationsTopology(topology.length));
    }

    @Override
    public INeuralNetwork build(int[] topology, Activations[] functions) {
        if (topology.length != functions.length) {
            throw new IllegalArgumentException("Layers & functions - different length");
        }
        INeuralNetwork network = new NeuralNetwork();
        ILayer layer;
        for (int i = 0; i < topology.length; i++) {
            layer = new Layer(topology[i], getActivation(functions[i]));
            network.addLayer(layer);
        }
        network.generateWeights();
        return network;
    }

    private Activations[] buildActivationsTopology(int topologySize) {
        Activations[] functions = new Activations[topologySize];
        for (int i = 0; i < topologySize; i++) {
            if (i == 0) {
                functions[i] = Activations.LIN;
            } else if (i == (topologySize - 1)) {
                functions[i] = defaultOutputLayer;
            } else {
                functions[i] = defaultOutputLayer;
            }
        }
        return functions;
    }

    private IActivationFunction getActivation(Activations activation) {
        switch (activation) {
            case SIG:
                return factory.getSigmoid();
            case TANH:
                return factory.getHyperbolicTangent();
            default:
                return factory.getLinearFunction();
        }
    }
}
TOP

Related Classes of zdenekdrahos.AI.NeuralNetwork.Builder.NetworkBuilder

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.