/*
* Encog(tm) Core v3.3 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2014 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.neural.prune;
import org.encog.mathutil.randomize.Distort;
import org.encog.mathutil.randomize.Randomizer;
import org.encog.mathutil.randomize.RangeRandomizer;
import org.encog.neural.NeuralNetworkError;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.BasicNetwork;
import org.encog.util.EngineArray;
/**
 * Prune a neural network selectively. This class allows you to either add or
 * remove neurons from layers of a neural network. You can also randomize or
 * stimulate neurons.
 *
 * No provision is given for removing an entire layer. Removing a layer requires
 * a totally new set of weights between the layers before and after the removed
 * one. This essentially makes any remaining weights useless. At this point you
 * are better off just creating a new network of the desired dimensions.
 *
 * All structural changes are applied directly to the flat network that backs
 * the {@link BasicNetwork} passed to the constructor.
 */
public class PruneSelective {

    /**
     * The network to prune.
     */
    private final BasicNetwork network;

    /**
     * Construct an object to prune the neural network.
     *
     * @param network
     *            The network to prune.
     */
    public PruneSelective(final BasicNetwork network) {
        this.network = network;
    }

    /**
     * Change the neuron count for the network. If the count is increased then
     * zero-weighted neurons are added, which will not affect the output of the
     * neural network. If the neuron count is decreased, then the weakest
     * neurons will be removed.
     *
     * This method cannot be used to remove a bias neuron.
     *
     * @param layer
     *            The layer to adjust.
     * @param neuronCount
     *            The new neuron count for this layer. Must be positive.
     * @throws NeuralNetworkError
     *             If the requested neuron count is zero or negative.
     */
    public void changeNeuronCount(final int layer, final int neuronCount) {
        // a negative count is just as invalid as zero; reject both up front
        if (neuronCount <= 0) {
            throw new NeuralNetworkError("Can't decrease to zero neurons.");
        }

        final int currentCount = this.network.getLayerNeuronCount(layer);

        // is there anything to do?
        if (neuronCount == currentCount) {
            return;
        }

        if (neuronCount > currentCount) {
            increaseNeuronCount(layer, neuronCount);
        } else {
            decreaseNeuronCount(layer, neuronCount);
        }
    }

    /**
     * Internal function to decrease the neuron count of a layer. The weakest
     * neurons, as reported by {@link #findWeakestNeurons(int, int)}, are
     * removed.
     *
     * @param layer
     *            The layer to affect.
     * @param neuronCount
     *            The new neuron count.
     */
    private void decreaseNeuronCount(final int layer, final int neuronCount) {
        final int lostNeuronCount = this.network.getLayerNeuronCount(layer)
                - neuronCount;

        // indexes are returned in ascending order
        final int[] lostNeuron = findWeakestNeurons(layer, lostNeuronCount);

        // Prune from the highest index down. Removing a neuron only shifts
        // the indexes of the neurons after it, so the lower indexes that are
        // still waiting to be pruned stay valid without any correction.
        for (int i = lostNeuron.length - 1; i >= 0; i--) {
            prune(layer, lostNeuron[i]);
        }
    }

    /**
     * Determine the significance of the neuron. The higher the return value,
     * the more significant the neuron is.
     *
     * The significance is the absolute value of the sum of all weights coming
     * into the neuron from the previous layer and all weights leaving it
     * towards the next layer. Note that the signed weights are summed before
     * the absolute value is taken, so weights of opposite sign cancel out.
     *
     * @param layer
     *            The layer to query.
     * @param neuron
     *            The neuron to query.
     * @return How significant is this neuron.
     */
    public double determineNeuronSignificance(final int layer, final int neuron) {
        this.network.validateNeuron(layer, neuron);

        double result = 0;

        // inbound significance: every neuron of the previous layer (total
        // count) feeding this neuron
        if (layer > 0) {
            final int prevLayer = layer - 1;
            final int prevCount = this.network
                    .getLayerTotalNeuronCount(prevLayer);
            for (int i = 0; i < prevCount; i++) {
                result += this.network.getWeight(prevLayer, i, neuron);
            }
        }

        // outbound significance: this neuron feeding every neuron of the
        // next layer
        if (layer < this.network.getLayerCount() - 1) {
            final int nextLayer = layer + 1;
            final int nextCount = this.network.getLayerNeuronCount(nextLayer);
            for (int i = 0; i < nextCount; i++) {
                result += this.network.getWeight(layer, neuron, i);
            }
        }

        return Math.abs(result);
    }

    /**
     * Find the weakest neurons on a layer, i.e. the {@code count} neurons with
     * the lowest value of {@link #determineNeuronSignificance(int, int)}. When
     * two neurons are equally significant the one with the lower index is
     * chosen first.
     *
     * @param layer
     *            The layer to search.
     * @param count
     *            The number of neurons to find.
     * @return An array of the indexes of the weakest neurons, in ascending
     *         index order.
     * @throws NeuralNetworkError
     *             If count is negative or larger than the neuron count of the
     *             layer.
     */
    private int[] findWeakestNeurons(final int layer, final int count) {
        final int layerNeuronCount = this.network.getLayerNeuronCount(layer);
        if ((count < 0) || (count > layerNeuronCount)) {
            throw new NeuralNetworkError("Can't find " + count
                    + " weak neurons on layer " + layer + ", which has "
                    + layerNeuronCount + " neurons.");
        }

        // evaluate every neuron on the layer exactly once
        final double[] significance = new double[layerNeuronCount];
        for (int i = 0; i < layerNeuronCount; i++) {
            significance[i] = determineNeuronSignificance(layer, i);
        }

        // repeatedly select the least significant neuron not yet chosen
        final boolean[] chosen = new boolean[layerNeuronCount];
        for (int pick = 0; pick < count; pick++) {
            int weakest = -1;
            for (int i = 0; i < layerNeuronCount; i++) {
                if (!chosen[i]
                        && ((weakest < 0) || (significance[i] < significance[weakest]))) {
                    weakest = i;
                }
            }
            chosen[weakest] = true;
        }

        // collect the chosen indexes; walking the flags yields ascending order
        final int[] lostNeuron = new int[count];
        int next = 0;
        for (int i = 0; i < layerNeuronCount; i++) {
            if (chosen[i]) {
                lostNeuron[next++] = i;
            }
        }
        return lostNeuron;
    }

    /**
     * @return The network that is being processed.
     */
    public BasicNetwork getNetwork() {
        return this.network;
    }

    /**
     * Internal function to increase the neuron count. This will add
     * zero-weight neurons to this layer. The new neurons take the indexes
     * from the old neuron count up to (but not including) the new neuron
     * count.
     *
     * @param targetLayer
     *            The layer to increase.
     * @param neuronCount
     *            The new neuron count.
     * @throws NeuralNetworkError
     *             If the layer is out of range, or the neuron count is not an
     *             increase.
     */
    private void increaseNeuronCount(final int targetLayer,
            final int neuronCount) {
        // check for errors; valid layers are 0 .. layerCount - 1
        if ((targetLayer < 0)
                || (targetLayer >= this.network.getLayerCount())) {
            throw new NeuralNetworkError("Invalid layer " + targetLayer);
        }

        if (neuronCount <= 0) {
            throw new NeuralNetworkError("Invalid neuron count " + neuronCount);
        }

        final int oldNeuronCount = this.network
                .getLayerNeuronCount(targetLayer);
        final int increaseBy = neuronCount - oldNeuronCount;

        if (increaseBy <= 0) {
            throw new NeuralNetworkError(
                    "New neuron count is either a decrease or no change: "
                            + neuronCount);
        }

        // access the flat network
        final FlatNetwork flat = this.network.getStructure().getFlat();
        final double[] oldWeights = flat.getWeights();

        // first find out how many connections there will be after the
        // increase
        int connections = oldWeights.length;

        // are connections added from the previous layer?
        if (targetLayer > 0) {
            connections += this.network
                    .getLayerTotalNeuronCount(targetLayer - 1) * increaseBy;
        }

        // are there connections added to the next layer?
        if (targetLayer < (this.network.getLayerCount() - 1)) {
            connections += this.network.getLayerNeuronCount(targetLayer + 1)
                    * increaseBy;
        }

        // Increase the layer counts now, so the count queries in the loop
        // below already describe the enlarged layer. The flat network indexes
        // its layers in the reverse order of the network.
        final int flatLayer = this.network.getLayerCount() - targetLayer - 1;
        flat.getLayerCounts()[flatLayer] += increaseBy;
        flat.getLayerFeedCounts()[flatLayer] += increaseBy;

        // allocate new weights now that we know how big the new weights will be
        final double[] newWeights = new double[connections];

        // construct the new weights, walking the connections in the same
        // order in which the old weights are read back sequentially
        int weightsIndex = 0;
        int oldWeightsIndex = 0;

        for (int fromLayer = flat.getLayerCounts().length - 2; fromLayer >= 0; fromLayer--) {
            final int fromNeuronCount = this.network
                    .getLayerTotalNeuronCount(fromLayer);
            final int toNeuronCount = this.network
                    .getLayerNeuronCount(fromLayer + 1);
            final int toLayer = fromLayer + 1;

            for (int toNeuron = 0; toNeuron < toNeuronCount; toNeuron++) {
                for (int fromNeuron = 0; fromNeuron < fromNeuronCount; fromNeuron++) {
                    // inbound weights of a newly added neuron
                    final boolean newInbound = (toLayer == targetLayer)
                            && (toNeuron >= oldNeuronCount);
                    // Outbound weights of a newly added neuron. Only the
                    // indexes [oldNeuronCount, neuronCount) are new; anything
                    // at or beyond neuronCount already existed and keeps its
                    // old weight.
                    final boolean newOutbound = (fromLayer == targetLayer)
                            && (fromNeuron >= oldNeuronCount)
                            && (fromNeuron < neuronCount);

                    if (newInbound || newOutbound) {
                        newWeights[weightsIndex++] = 0;
                    } else {
                        newWeights[weightsIndex++] = oldWeights[oldWeightsIndex++];
                    }
                }
            }
        }

        // swap in the new weights
        flat.setWeights(newWeights);

        // reindex
        reindexNetwork();
    }

    /**
     * Prune one of the neurons from this layer. Remove all entries in this
     * weight matrix and other layers. This method cannot be used to remove a
     * bias neuron.
     *
     * @param targetLayer
     *            The layer that holds the neuron to prune.
     * @param neuron
     *            The neuron to prune. Zero specifies the first neuron.
     * @throws NeuralNetworkError
     *             If the neuron is not a regular neuron of the layer, or if
     *             removing it would leave the layer empty.
     */
    public void prune(final int targetLayer, final int neuron) {
        // check for errors
        this.network.validateNeuron(targetLayer, neuron);

        final int layerNeuronCount = this.network
                .getLayerNeuronCount(targetLayer);

        // Only regular neurons can be pruned. The bookkeeping below removes
        // one inbound row and decrements the feed count, which is only
        // correct for a neuron inside the layer's regular neuron count.
        if (neuron >= layerNeuronCount) {
            throw new NeuralNetworkError("Neuron " + neuron
                    + " is not a regular neuron of layer " + targetLayer
                    + "; bias neurons can't be pruned.");
        }

        // don't empty a layer
        if (layerNeuronCount <= 1) {
            throw new NeuralNetworkError(
                    "A layer must have at least a single neuron.  If you want to remove the entire layer you must create a new network.");
        }

        // access the flat network
        final FlatNetwork flat = this.network.getStructure().getFlat();
        final double[] oldWeights = flat.getWeights();

        // first find out how many connections there will be after this prune.
        int connections = oldWeights.length;

        // are connections removed from the previous layer?
        if (targetLayer > 0) {
            connections -= this.network
                    .getLayerTotalNeuronCount(targetLayer - 1);
        }

        // are there connections removed to the next layer?
        if (targetLayer < (this.network.getLayerCount() - 1)) {
            connections -= this.network.getLayerNeuronCount(targetLayer + 1);
        }

        // allocate new weights now that we know how big the new weights will be
        final double[] newWeights = new double[connections];

        // construct the new weights; the layer counts have not been changed
        // yet, so getWeight still addresses the old weight layout
        int weightsIndex = 0;

        for (int fromLayer = flat.getLayerCounts().length - 2; fromLayer >= 0; fromLayer--) {
            final int fromNeuronCount = this.network
                    .getLayerTotalNeuronCount(fromLayer);
            final int toNeuronCount = this.network
                    .getLayerNeuronCount(fromLayer + 1);
            final int toLayer = fromLayer + 1;

            for (int toNeuron = 0; toNeuron < toNeuronCount; toNeuron++) {
                for (int fromNeuron = 0; fromNeuron < fromNeuronCount; fromNeuron++) {
                    // drop every weight that enters or leaves the pruned neuron
                    final boolean inbound = (toLayer == targetLayer)
                            && (toNeuron == neuron);
                    final boolean outbound = (fromLayer == targetLayer)
                            && (fromNeuron == neuron);

                    if (!inbound && !outbound) {
                        newWeights[weightsIndex++] = this.network.getWeight(
                                fromLayer, fromNeuron, toNeuron);
                    }
                }
            }
        }

        // swap in the new weights
        flat.setWeights(newWeights);

        // decrease layer count; the flat network indexes its layers in the
        // reverse order of the network
        final int flatLayer = this.network.getLayerCount() - targetLayer - 1;
        flat.getLayerCounts()[flatLayer]--;
        flat.getLayerFeedCounts()[flatLayer]--;

        // reindex
        reindexNetwork();
    }

    /**
     * Assign random values to the weights of a neuron. The random values are
     * drawn from the specified range.
     *
     * @param low
     *            The low-end of the range.
     * @param high
     *            The high-end of the range.
     * @param targetLayer
     *            The target layer.
     * @param neuron
     *            The target neuron.
     */
    public void randomizeNeuron(final double low, final double high,
            final int targetLayer, final int neuron) {
        randomizeNeuron(targetLayer, neuron, true, low, high, false, 0.0);
    }

    /**
     * Assign random values to the weights of a neuron. The range will be the
     * min/max of the existing weights of the whole network.
     *
     * @param targetLayer
     *            The target layer.
     * @param neuron
     *            The target neuron.
     */
    public void randomizeNeuron(final int targetLayer, final int neuron) {
        final FlatNetwork flat = this.network.getStructure().getFlat();
        final double low = EngineArray.min(flat.getWeights());
        final double high = EngineArray.max(flat.getWeights());
        randomizeNeuron(targetLayer, neuron, true, low, high, false, 0.0);
    }

    /**
     * Used internally to randomize a neuron. Usually called from
     * randomizeNeuron or stimulateNeuron. Every weight entering or leaving the
     * neuron is passed through the selected randomizer; all other weights are
     * copied unchanged.
     *
     * @param targetLayer
     *            The target layer.
     * @param neuron
     *            The target neuron.
     * @param useRange
     *            True if range randomization should be used; otherwise the
     *            weights are distorted by the given percent.
     * @param low
     *            The low-end of the range.
     * @param high
     *            The high-end of the range.
     * @param usePercent
     *            True if percent stimulation should be used. Currently not
     *            consulted: the choice is made from useRange alone.
     * @param percent
     *            The percent to stimulate by.
     */
    private void randomizeNeuron(final int targetLayer, final int neuron,
            final boolean useRange, final double low, final double high,
            final boolean usePercent, final double percent) {

        final Randomizer d;

        if (useRange) {
            d = new RangeRandomizer(low, high);
        } else {
            d = new Distort(percent);
        }

        // check for errors
        this.network.validateNeuron(targetLayer, neuron);

        // access the flat network
        final FlatNetwork flat = this.network.getStructure().getFlat();

        // the structure does not change, so the weight count stays the same
        final double[] newWeights = new double[flat.getWeights().length];

        // construct the new weights
        int weightsIndex = 0;

        for (int fromLayer = flat.getLayerCounts().length - 2; fromLayer >= 0; fromLayer--) {
            final int fromNeuronCount = this.network
                    .getLayerTotalNeuronCount(fromLayer);
            final int toNeuronCount = this.network
                    .getLayerNeuronCount(fromLayer + 1);
            final int toLayer = fromLayer + 1;

            for (int toNeuron = 0; toNeuron < toNeuronCount; toNeuron++) {
                for (int fromNeuron = 0; fromNeuron < fromNeuronCount; fromNeuron++) {
                    // does this weight enter or leave the target neuron?
                    final boolean inbound = (toLayer == targetLayer)
                            && (toNeuron == neuron);
                    final boolean outbound = (fromLayer == targetLayer)
                            && (fromNeuron == neuron);

                    double weight = this.network.getWeight(fromLayer,
                            fromNeuron, toNeuron);

                    if (inbound || outbound) {
                        weight = d.randomize(weight);
                    }

                    newWeights[weightsIndex++] = weight;
                }
            }
        }

        // swap in the new weights
        flat.setWeights(newWeights);
    }

    /**
     * Create new index values for the network. Rebuilds the layer and weight
     * indexes of the flat network from its current layer counts, reallocates
     * the layer output and sum buffers, and refreshes the input and output
     * counts.
     */
    private void reindexNetwork() {
        final FlatNetwork flat = this.network.getStructure().getFlat();

        int neuronCount = 0;
        int weightCount = 0;
        for (int i = 0; i < flat.getLayerCounts().length; i++) {
            if (i > 0) {
                // weights between flat layer i and flat layer i - 1
                final int from = flat.getLayerFeedCounts()[i - 1];
                final int to = flat.getLayerCounts()[i];
                weightCount += from * to;
            }
            flat.getLayerIndex()[i] = neuronCount;
            flat.getWeightIndex()[i] = weightCount;
            neuronCount += flat.getLayerCounts()[i];
        }

        flat.setLayerOutput(new double[neuronCount]);
        flat.setLayerSums(new double[neuronCount]);
        flat.clearContext();

        // the last flat layer is the input layer, the first is the output
        flat.setInputCount(flat.getLayerFeedCounts()[flat.getLayerCounts().length - 1]);
        flat.setOutputCount(flat.getLayerFeedCounts()[0]);
    }

    /**
     * Stimulate the specified neuron by the specified percent. This is used to
     * randomize the weights and bias values for weak neurons.
     *
     * @param percent
     *            The percent to randomize by.
     * @param targetLayer
     *            The layer that the neuron is on.
     * @param neuron
     *            The neuron to randomize.
     */
    public void stimulateNeuron(final double percent, final int targetLayer,
            final int neuron) {
        randomizeNeuron(targetLayer, neuron, false, 0, 0, true, percent);
    }

    /**
     * Stimulate weaker neurons on a layer. Find the weakest neurons and then
     * randomize them by the specified percent.
     *
     * @param layer
     *            The layer to stimulate.
     * @param count
     *            The number of weak neurons to stimulate.
     * @param percent
     *            The percent to stimulate by.
     */
    public void stimulateWeakNeurons(final int layer, final int count,
            final double percent) {
        final int[] weak = findWeakestNeurons(layer, count);
        for (final int element : weak) {
            stimulateNeuron(percent, layer, element);
        }
    }
}