"New neuron count is either a decrease or no change: "
+ neuronCount);
}
// access the flat network
final FlatNetwork flat = this.network.getStructure().getFlat();
final double[] oldWeights = flat.getWeights();
// first find out how many connections there will be after this increase.
int connections = oldWeights.length;
int inBoundConnections = 0;
int outBoundConnections = 0;
// are connections added from the previous layer?
if (targetLayer > 0) {
inBoundConnections = this.network
.getLayerTotalNeuronCount(targetLayer - 1);
connections += inBoundConnections * increaseBy;
}
// are connections added from the new neurons to the next layer?
if (targetLayer < (this.network.getLayerCount() - 1)) {
outBoundConnections = this.network
.getLayerNeuronCount(targetLayer + 1);
connections += outBoundConnections * increaseBy;
}
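// Worked example (hypothetical 2-3-1 network with a bias neuron on the input
// and hidden layers; for illustration only, not from the original source):
// the network starts with (2 + 1) * 3 + (3 + 1) * 1 = 13 weights. Growing
// hidden layer 1 from 3 to 5 neurons gives increaseBy = 2,
// inBoundConnections = 2 + 1 = 3 and outBoundConnections = 1, so
// connections = 13 + 3 * 2 + 1 * 2 = 21, which agrees with a fresh count of
// (2 + 1) * 5 + (5 + 1) * 1 = 21.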
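// increase the target layer's neuron counts in the flat network
// (flat layers are stored in reverse order, output layer first)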
final int flatLayer = this.network.getLayerCount() - targetLayer - 1;
flat.getLayerCounts()[flatLayer] += increaseBy;
flat.getLayerFeedCounts()[flatLayer] += increaseBy;
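// Index mapping illustration (hypothetical 3-layer network, so
// getLayerCount() == 3; for illustration only): flatLayer = 3 - targetLayer - 1,
// so network layer 0 (input) maps to flatLayer 2, hidden layer 1 maps to
// flatLayer 1, and the output layer 2 maps to flatLayer 0.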
// allocate new weights now that we know how big the new weights will be
final double[] newWeights = new double[connections];
// construct the new weights
int weightsIndex = 0;
int oldWeightsIndex = 0;
for (int fromLayer = flat.getLayerCounts().length - 2; fromLayer >= 0; fromLayer--) {
final int fromNeuronCount = this.network
.getLayerTotalNeuronCount(fromLayer);
final int toNeuronCount = this.network
.getLayerNeuronCount(fromLayer + 1);
final int toLayer = fromLayer + 1;
for (int toNeuron = 0; toNeuron < toNeuronCount; toNeuron++) {
for (int fromNeuron = 0; fromNeuron < fromNeuronCount; fromNeuron++) {
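if ((toLayer == targetLayer)
&& (toNeuron >= oldNeuronCount)) {
// incoming weights of a newly added neuron start at zero
newWeights[weightsIndex++] = 0;
} else if ((fromLayer == targetLayer)
&& (fromNeuron >= oldNeuronCount)
&& (fromNeuron < neuronCount)) {
// outgoing weights of a newly added neuron start at zero; the new
// neurons occupy indices oldNeuronCount..neuronCount - 1, so the
// bias neuron (if any) at index neuronCount keeps its old weight
newWeights[weightsIndex++] = 0;
} else {
// existing connection: copy the old weight across in order
newWeights[weightsIndex++] = oldWeights[oldWeightsIndex++];
}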
}
}
}
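// Layout check (hypothetical 2-3-1 network with biases on the input and
// hidden layers, hidden layer grown from 3 to 5 neurons; for illustration
// only): the loop first fills the hidden-to-output row of 6 slots (5 hidden
// neurons plus the hidden bias), 2 of which are zero, and then 5 rows of 3
// input-side slots, where the rows for hidden neurons 3 and 4 are all zero.
// That is 2 + 6 = 8 zeros and 21 - 8 = 13 copied weights, so at this point
// oldWeightsIndex should equal oldWeights.length and weightsIndex should
// equal connections.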
// swap in the new weights
flat.setWeights(newWeights);
// reindex
reindexNetwork();
}