Package zdenekdrahos.AI.NeuralNetwork.Weights

Source Code of zdenekdrahos.AI.NeuralNetwork.Weights.Weights

/*
* JAVA Neural Networks (https://bitbucket.org/zdenekdrahos/java-neural-networks)
* @license New BSD License
* @author Zdenek Drahos
*/

package zdenekdrahos.AI.NeuralNetwork.Weights;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import zdenekdrahos.AI.NeuralNetwork.INeuralNetwork;
import zdenekdrahos.AI.NeuralNetwork.Layers.ILayer;

public class Weights implements IWeights {
   
    private Map<Integer, List<List<IWeight>>> weights;

    @Override
    public IWeight getConnectionWeight(int layerIndex, int neuronIndex, int previousNeuronIndex) {
        return weights.get(layerIndex).get(neuronIndex).get(previousNeuronIndex);
    }

    @Override
    public List<IWeight> getNeuronWeights(int layerIndex, int neuronIndex) {       
        return weights.get(layerIndex).get(neuronIndex);
    }

    @Override
    public void setConnectionWeight(int layerIndex, int neuronIndex, int previousNeuronIndex, double newWeight) {
        IWeight weight = getConnectionWeight(layerIndex, neuronIndex, previousNeuronIndex);
        weight.setWeight(newWeight);
    }

    @Override
    public int getLayersCount() {
        return weights == null ? 0 : (weights.size() + 1);
    }
   
    @Override
    public void create(INeuralNetwork network) {
        create(network, new double[0][0][0]);
    }

    @Override
    public void create(INeuralNetwork network, double[][][] initialWeights) {
        weights = new HashMap<Integer, List<List<IWeight>>>();
        ILayer currentLayer, previousLayer;
        IWeight weight;
        for (int layerIndex = 1; layerIndex < network.getLayersCount(); layerIndex++) {           
            previousLayer = network.getLayer(layerIndex - 1);
            currentLayer = network.getLayer(layerIndex);
            createListsForLayer(layerIndex, currentLayer);
            for (int neuronIndex = 0; neuronIndex < currentLayer.getNeuronsCount(); neuronIndex++) {
                createListForNeuron(layerIndex, previousLayer);
                for (int previousNeuronIndex = 0; previousNeuronIndex <= previousLayer.getNeuronsCount(); previousNeuronIndex++) {                   
                    weight = new Weight();
                    try {
                        weight.setWeight(initialWeights[layerIndex - 1][neuronIndex][previousNeuronIndex]);
                    } catch (IndexOutOfBoundsException e) {
                        weight.generateWeight();
                    }
                    weights.get(layerIndex).get(neuronIndex).add(weight);
                }
            }
        }
    }

    private void createListsForLayer(int layerIndex, ILayer currentLayer) {
        ArrayList<List<IWeight>> layerWeights = new ArrayList<List<IWeight>>(currentLayer.getNeuronsCount());
        weights.put(layerIndex, layerWeights);
    }

    private void createListForNeuron(int layerIndex, ILayer previousLayer) {
        int neuronsCount = previousLayer.getNeuronsCount() + 1; // + bias
        List<IWeight> neuronWeights = new ArrayList<IWeight>(neuronsCount);
        weights.get(layerIndex).add(neuronWeights);
    }

}
TOP

Related Classes of zdenekdrahos.AI.NeuralNetwork.Weights.Weights

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware#gmail.com.