Package org.neuroph.nnet.flat

Source Code of org.neuroph.nnet.flat.FlatNetworkPlugin

/**
* Copyright 2010 Neuroph Project http://neuroph.sourceforge.net
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.neuroph.nnet.flat;

import java.util.HashSet;
import java.util.Set;

import org.encog.engine.EncogEngine;
import org.encog.engine.EncogEngineError;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.engine.network.activation.ActivationTANH;
import org.encog.engine.network.flat.FlatLayer;
import org.encog.engine.network.flat.FlatNetwork;
import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.Neuron;
import org.neuroph.core.Weight;
import org.neuroph.core.transfer.Linear;
import org.neuroph.core.transfer.Sigmoid;
import org.neuroph.core.transfer.Tanh;
import org.neuroph.core.transfer.TransferFunction;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.comp.BiasNeuron;
import org.neuroph.nnet.comp.InputNeuron;
import org.neuroph.core.learning.TrainingData;
import org.neuroph.util.plugins.PluginBase;

/**
* The FlatNetworkPlugin allows a Neuroph multi-layer-perceptron to make use of the Encog Engine.
* The Encog Engine makes use of flat, array-based, networks to train and process in a very efficient
* manner. 
*
* To make use of this plug-in, call the flattenNeuralNetworkNetwork method with the desired neural
* network.  The network will now train with multithreaded RPROP.  The network can still be used as a
* normal Neuroph network, however, if you change the structure of the network, you should reflatten it.
*
* To remove the plugin, use the unFlattenNeuralNetworkNetwork method.
*
* If you wish to use OpenCL/GPU processing, call the initCL method.
*
* Finally, once you are ready to exit your application, you should call the shutdown
* method to make sure that the threadpool and OpenCL resources have been properly shutdown.
*
* @author Jeff Heaton (http://www.jeffheaton.com)
*/
public class FlatNetworkPlugin extends PluginBase {

  /**
   * The serial ID.
   */
  private static final long serialVersionUID = 1L;

  /**
   * The name of the plugin.
   */
  public static final String PLUGIN_NAME = "FlatNetworkPlugin";

  /**
   * The flat network being used.
   */
  public FlatNetwork flatNetwork;

  /**
   * Construct a flat network plugin.
   * @param network The flat network.
   */
  public FlatNetworkPlugin(FlatNetwork network) {
    super(PLUGIN_NAME);
    this.flatNetwork = network;
  }

  /**
   * @return The actual flat network being used.
   */
  public FlatNetwork getFlatNetwork() {
    return this.flatNetwork;
  }

  /**
   * This method is used to install the driver into a neural network.
   * @param network The neural network to flatten.
   * @return True if the flatening was successful.
   */
  public static boolean flattenNeuralNetworkNetwork(NeuralNetwork network) {
    if (network instanceof MultiLayerPerceptron) {
      return flattenMultiLayerPerceptron((MultiLayerPerceptron) network);
    } else
      // failed to flaten the network
      return false;
  }

  /**
   * Init the OpenCL.
   */
  public static void initCL() {
    EncogEngine.getInstance().initCL();
  }
 
  /**
   * Make sure that the threadpool and OpenCL are properly shutdown.
   */
  public static void shutdown()
  {
    EncogEngine.getInstance().shutdown();
  }

  /**
   * Flatten a multi-layer perceptron.
   * @param network The network to flatten.
   * @return True, if the network was successfully flattened.
   */
  private static boolean flattenMultiLayerPerceptron(
      MultiLayerPerceptron network) {

    FlatLayer[] flatLayers = new FlatLayer[network.getLayers().size()];

    int index = 0;

    for (Layer layer : network.getLayers()) {
      FlatLayer flatLayer = flattenLayer(layer);
      if (flatLayer == null)
        return false;
      flatLayers[index++] = flatLayer;
    }

    FlatNetwork flat = new FlatNetwork(flatLayers);
    FlatNetworkPlugin plugin = new FlatNetworkPlugin(flat);
    network.addPlugin(plugin);

    FlatNetworkLearning training = new FlatNetworkLearning(flat);
    network.setLearningRule(training);

    flattenWeights(flat, network);

    return true;
  }

  /**
   * Flatten the specified layer.
   * @param layer The layer to flatten.
   * @return The flat layer that represents the provided layer.
   */
  private static FlatLayer flattenLayer(Layer layer) {
    boolean inputLayer = false;
    Set<Class<?>> transferFunctions = new HashSet<Class<?>>();
    int neuronCount = 0;
    int biasCount = 0;
    TransferFunction transfer = null;

    for (Neuron neuron : layer.getNeurons()) {
      if (neuron.getClass() == InputNeuron.class)
        inputLayer = true;

      if (neuron.getClass() == Neuron.class
          || neuron.getClass() == InputNeuron.class) {
        neuronCount++;

        transfer = neuron.getTransferFunction();
        transferFunctions.add(transfer.getClass());
      } else if (neuron.getClass() == BiasNeuron.class)
        biasCount++;
    }

    if (transferFunctions.size() > 1)
      return null;

    Class<?> t = transferFunctions.iterator().next();

    double slope = 1;

    ActivationFunction activation = null;
   
    if (inputLayer)
      activation = new ActivationLinear();
    else if (t == Linear.class) {
      slope = ((Linear) transfer).getSlope();
      activation = new ActivationLinear();
    } else if (t == Sigmoid.class) {
      slope = ((Sigmoid) transfer).getSlope();
      activation = new ActivationSigmoid();
    } else if (t == Tanh.class) {
      slope = ((Tanh) transfer).getSlope();
      activation = new ActivationTANH();
    } else
      return null;

    if (biasCount > 1)
      return null;

    double[] params = { slope };
    return new FlatLayer(activation, neuronCount, biasCount == 1 ? 1.0:0.0, params);
  }

  /**
   * Replace all of the weights with FlatWeights.
   * @param flatNetwork The flat network.
   * @param network The neural network.
   */
  private static void flattenWeights(FlatNetwork flatNetwork,
      NeuralNetwork network) {
    double[] weights = flatNetwork.getWeights();

    int index = 0;

    for (int layerIndex = network.getLayers().size() - 1; layerIndex > 0; layerIndex--) {
      Layer layer = network.getLayers().get(layerIndex);

      for (Neuron neuron : layer.getNeurons()) {
        for (Connection connection : neuron.getInputConnections()) {
          if (index >= weights.length)
            throw new EncogEngineError("Weight size mismatch.");

          Weight weight = connection.getWeight();
          FlatWeight flatWeight = new FlatWeight(weights, index++);
          flatWeight.setValue(weight.getValue());
          connection.setWeight(flatWeight);
        }
      }
    }
  }

  /**
   * Remove the flat network plugin, and replace flat weights with regular Neuroph weights.
   *
   * @param network The network to unflatten.
   * @return True if unflattening was successful.
   */
  public static boolean unFlattenNeuralNetworkNetwork(NeuralNetwork network) {

    for (int layerIndex = network.getLayers().size() - 1; layerIndex > 0; layerIndex--) {
      Layer layer = network.getLayers().get(layerIndex);

      for (Neuron neuron : layer.getNeurons()) {
        for (Connection connection : neuron.getInputConnections()) {
          Weight weight = connection.getWeight();

          if (weight instanceof FlatWeight) {
            Weight weight2 = new Weight(weight.getValue());
            //weight2.setPreviousValue(weight.getPreviousValue());
                                                weight2.getTrainingData().set(TrainingData.PREVIOUS_WEIGHT, weight.getTrainingData().get(TrainingData.PREVIOUS_WEIGHT));
            connection.setWeight(weight2);
          }
        }
      }
    }

    network.removePlugin(FlatNetworkPlugin.PLUGIN_NAME);
    return true;
  }
}
TOP

Related Classes of org.neuroph.nnet.flat.FlatNetworkPlugin

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.