/*
* JAVA Neural Networks (https://bitbucket.org/zdenekdrahos/java-neural-networks)
* @license New BSD License
* @author Zdenek Drahos
*/
package zdenekdrahos.AI.NeuralNetwork.Weights;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import zdenekdrahos.AI.NeuralNetwork.INeuralNetwork;
import zdenekdrahos.AI.NeuralNetwork.Layers.ILayer;
/**
 * Stores the connection weights of a layered neural network.
 * <p>
 * Weights exist only for layers with index 1 and above, because layer 0 has
 * no incoming connections. For every neuron of such a layer there is one
 * weight per neuron of the previous layer. There is also one extra weight at
 * index {@code previousLayer.getNeuronsCount()}, which this class labels as
 * the bias weight.
 * <p>
 * The accessor methods expect {@link #create(INeuralNetwork)} or
 * {@link #create(INeuralNetwork, double[][][])} to have been called first.
 * Before that, the internal map is {@code null}.
 */
public class Weights implements IWeights {

    /**
     * Maps layerIndex (starting at 1) to a list indexed by neuronIndex. Each
     * entry is that neuron's incoming weights indexed by previousNeuronIndex;
     * the last entry is the bias weight.
     */
    private Map<Integer, List<List<IWeight>>> weights;

    /**
     * Returns the weight of the connection from {@code previousNeuronIndex}
     * in layer {@code layerIndex - 1} to {@code neuronIndex} in layer
     * {@code layerIndex}.
     */
    @Override
    public IWeight getConnectionWeight(int layerIndex, int neuronIndex, int previousNeuronIndex) {
        return weights.get(layerIndex).get(neuronIndex).get(previousNeuronIndex);
    }

    /**
     * Returns all incoming weights of one neuron, including the trailing bias
     * weight. The returned list is the internal list, not a copy.
     */
    @Override
    public List<IWeight> getNeuronWeights(int layerIndex, int neuronIndex) {
        return weights.get(layerIndex).get(neuronIndex);
    }

    /** Overwrites the value of an existing connection weight. */
    @Override
    public void setConnectionWeight(int layerIndex, int neuronIndex, int previousNeuronIndex, double newWeight) {
        IWeight weight = getConnectionWeight(layerIndex, neuronIndex, previousNeuronIndex);
        weight.setWeight(newWeight);
    }

    /**
     * Returns the number of network layers these weights were created for.
     * The result is 0 if {@code create} has not been called.
     */
    @Override
    public int getLayersCount() {
        // +1 because layer 0 has no incoming weights and therefore no map entry.
        return weights == null ? 0 : (weights.size() + 1);
    }

    /**
     * Creates weights for every connection in {@code network}. No initial
     * values are supplied, so every weight is initialised through
     * {@link IWeight#generateWeight()}.
     */
    @Override
    public void create(INeuralNetwork network) {
        create(network, new double[0][0][0]);
    }

    /**
     * Creates weights for every connection in {@code network}, replacing any
     * previously created weights.
     *
     * @param network        network whose layer sizes determine the weight layout
     * @param initialWeights optional initial values indexed as
     *                       {@code [layerIndex - 1][neuronIndex][previousNeuronIndex]}.
     *                       Any weight not covered by this array falls back to
     *                       {@link IWeight#generateWeight()}. That includes the
     *                       whole array being {@code null}, a {@code null} or
     *                       too-short sub-array, or an index past the end.
     */
    @Override
    public void create(INeuralNetwork network, double[][][] initialWeights) {
        weights = new HashMap<Integer, List<List<IWeight>>>();
        for (int layerIndex = 1; layerIndex < network.getLayersCount(); layerIndex++) {
            ILayer previousLayer = network.getLayer(layerIndex - 1);
            ILayer currentLayer = network.getLayer(layerIndex);
            createListsForLayer(layerIndex, currentLayer);
            for (int neuronIndex = 0; neuronIndex < currentLayer.getNeuronsCount(); neuronIndex++) {
                createListForNeuron(layerIndex, previousLayer);
                List<IWeight> neuronWeights = weights.get(layerIndex).get(neuronIndex);
                // "<=" deliberately includes one extra slot for the bias weight.
                for (int previousNeuronIndex = 0; previousNeuronIndex <= previousLayer.getNeuronsCount(); previousNeuronIndex++) {
                    neuronWeights.add(createWeight(initialWeights, layerIndex - 1, neuronIndex, previousNeuronIndex));
                }
            }
        }
    }

    /**
     * Creates one weight. It uses the supplied initial value when one exists,
     * otherwise it calls {@link IWeight#generateWeight()}.
     */
    private static IWeight createWeight(double[][][] initialWeights, int weightLayerIndex,
            int neuronIndex, int previousNeuronIndex) {
        IWeight weight = new Weight();
        if (hasInitialWeight(initialWeights, weightLayerIndex, neuronIndex, previousNeuronIndex)) {
            weight.setWeight(initialWeights[weightLayerIndex][neuronIndex][previousNeuronIndex]);
        } else {
            weight.generateWeight();
        }
        return weight;
    }

    /**
     * Checks whether {@code initialWeights} holds a value at the given
     * (non-negative) indexes. Checking explicitly handles null arrays and
     * null sub-arrays, and avoids using exceptions for control flow.
     */
    private static boolean hasInitialWeight(double[][][] initialWeights, int weightLayerIndex,
            int neuronIndex, int previousNeuronIndex) {
        if (initialWeights == null || weightLayerIndex >= initialWeights.length) {
            return false;
        }
        double[][] layerWeights = initialWeights[weightLayerIndex];
        if (layerWeights == null || neuronIndex >= layerWeights.length) {
            return false;
        }
        double[] neuronWeights = layerWeights[neuronIndex];
        return neuronWeights != null && previousNeuronIndex < neuronWeights.length;
    }

    /** Registers an empty per-neuron list container for {@code layerIndex}. */
    private void createListsForLayer(int layerIndex, ILayer currentLayer) {
        ArrayList<List<IWeight>> layerWeights = new ArrayList<List<IWeight>>(currentLayer.getNeuronsCount());
        weights.put(layerIndex, layerWeights);
    }

    /** Appends an empty weight list for the next neuron of {@code layerIndex}. */
    private void createListForNeuron(int layerIndex, ILayer previousLayer) {
        int neuronsCount = previousLayer.getNeuronsCount() + 1; // + bias
        List<IWeight> neuronWeights = new ArrayList<IWeight>(neuronsCount);
        weights.get(layerIndex).add(neuronWeights);
    }
}