// Source listing: org.neuroph.nnet.flat.FlatNetworkLearning

/**
* Copyright 2010 Neuroph Project http://neuroph.sourceforge.net
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.neuroph.nnet.flat;

import java.io.Serializable;

import org.encog.engine.EncogEngineError;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.train.TrainFlatNetwork;
import org.encog.engine.network.train.prop.TrainFlatNetworkBackPropagation;
import org.encog.engine.network.train.prop.TrainFlatNetworkManhattan;
import org.encog.engine.network.train.prop.TrainFlatNetworkResilient;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.learning.SupervisedLearning;
import org.neuroph.core.learning.TrainingSet;

/**
* Learning rule for flat networks.  The learning type is specified by the the learningType
* property.  Currently the flat networks can be trained in the following three ways.
*
* Classic momentum based propagation.  The learning rate and momentum are specified by using the
* setLearningRate and setMomentum on the FlatNetworkLearning class.
*
* Resilient Propagation.  No parameters are needed for this learning method.  RPROP is the best
* general purpose training method.
*
* Manhattan update rule.  Not at all a good general purpose learning method, only useful in
* some situations.  A learning rate must be specified by using the setLearningRate method on
* the FlatNetworkLearning class.
*
*
* @author Jeff Heaton (http://www.jeffheaton.com)
*
*/
public class FlatNetworkLearning extends SupervisedLearning implements Serializable {

    @Override
    protected void updatePatternError(double[] patternError) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    @Override
    protected void updateTotalNetworkError() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

  /**
   * The object serial id.
   */
 
  private static final long serialVersionUID = 1L;
  /**
   * The flat network to train.
   */
  private FlatNetwork flat;
 
  /**
   * The last training set used.  This property is used to tell if the training set has
   * changed since the last iteration.
   */
  private transient EngineIndexableSet lastTrainingSet;
 
  /**
   * The last type of learning used.  This property is used to tell if the training set has
   * changed since the last iteration.
   */
  private FlatLearningType lastLearningType;
 
  /**
   * The flat trainer that is in use.
   */
  private transient TrainFlatNetwork training;
 
  /**
   * The learning rate.  This value is used for backprop and Manhattan.
   */
  private double learningRate;
 
  /**
   * The momentum.  This value is used for backpropagation training.
   */
  private double momentum;
 
  /**
   * The number of threads to use.  The default, zero, specifies to use the processor count
   * to determine an optimal number of threads.
   */
  private int numThreads;
 
  /**
   * The type of flat learning to be used.
   */
  private FlatLearningType learningType;

  /**
   * Constructor, create a flat network learning rule from the specified flat network.
   * @param flat The flat network.
   */
  public FlatNetworkLearning(FlatNetwork flat) {
    init(flat);
  }

  /**
   * Constructor, create a flat network learning rule from a Neuroph network.  This neuroph
   * network must have already had the flat network plugin installed.
   * @param network The neuroph network to use.
   */
  public FlatNetworkLearning(NeuralNetwork network) {
    FlatNetworkPlugin plugin = (FlatNetworkPlugin) network
        .getPlugin(FlatNetworkPlugin.class);
    FlatNetwork flat = plugin.getFlatNetwork();
    if (this.flat == null)
      throw new EncogEngineError(
          "This learning rule only works with a network that has a FlatNetworkPlugin attached.");
    init(flat);
  }
 
  /**
   * Internal method used to setup the learning rule.
   * @param flat The network that will be trained.
   */
  private void init(FlatNetwork flat)
  {
    this.flat = flat;
    this.learningType = FlatLearningType.ResilientPropagation;
    this.learningRate = 0.7;
    this.momentum = 0.3;
  }

  /**
   * This method is not used.
   * @param patternError Not used.
   */
  @Override
  protected void updateNetworkWeights(double[] patternError) {
    throw new EncogEngineError(
        "Method (updateNetworkWeights) is unimplemented and should not have been called.");

  }

  /**
   * This method is not used.
   * @param patternError Not used.
   */
  @Override
  protected void updateTotalNetworkError(double[] patternError) {
    throw new EncogEngineError(
        "Method (updateTotalNetworkError) is unimplemented and should not have been called.");

  }

  /**
   * Perform one learning epoch.
   * @param trainingSet The training set to use.
   */
  @Override
  public void doLearningEpoch(TrainingSet trainingSet) {

    this.previousEpochError = this.totalNetworkError;

    // have we changed learning types, or training sets?  If so, create a whole new trainer.
    if (this.lastLearningType != this.learningType || this.lastTrainingSet != trainingSet) {
     
      this.lastTrainingSet = trainingSet;
     
      switch( this.learningType )
      {
        case ResilientPropagation:
          this.training = new TrainFlatNetworkResilient(this.flat,
              this.lastTrainingSet);
          break;
         
        case ManhattanUpdateRule:
          this.training = new TrainFlatNetworkManhattan(this.flat,
              this.lastTrainingSet, this.learningRate);
          break;
         
        case BackPropagation:
          this.training = new TrainFlatNetworkBackPropagation(this.flat,
              this.lastTrainingSet, this.getLearningRate(), this.getMomentum());         
          break;
      }
     
      this.training.setNumThreads(this.numThreads);
     
      // remember the last state of the learning type and trianing set so that we can
      // tell if this changes in the next iteration.
      this.lastLearningType = this.learningType;
      this.lastTrainingSet = trainingSet;
    }
   

    // perform a training iteration
    this.training.iteration();
    this.totalNetworkError = this.training.getError();

    // should Neuroph training stop
    if (hasReachedStopCondition()) {
      stopLearning();
    }
  }

  /**
   * @return The learning rate.  This value is used for Backprop and Manhattan training.
   */
  public double getLearningRate() {
    return learningRate;
  }

  /**
   * Set the learning rate.  This value is used for Backprop and Manhattan training.
   */
  public void setLearningRate(double learningRate) {
    this.learningRate = learningRate;
  }

  /**
   * @return The momentum, this is is used for Backprop.
   */
  public double getMomentum() {
    return momentum;
  }

  /**
   * Set the momentum, this is used for Backprop.
   * @param momentum
   */
  public void setMomentum(double momentum) {
    this.momentum = momentum;
  }

  /**
   * @return The number of threads to use.  Zero requests that the learning
   * rule choose a number of processors.
   */
  public int getNumThreads() {
    return numThreads;
  }

  /**
   * Set the number of threads to use.
   * @param numThreads
   */
  public void setNumThreads(int numThreads) {
    this.numThreads = numThreads;
  }

  /**
   * @return The type of learning to use.
   */
  public FlatLearningType getLearningType() {
    return learningType;
  }

  /**
   * Set the type of learning to use.
   * @param learningType The type of learning to use.
   */
  public void setLearningType(FlatLearningType learningType) {
    this.learningType = learningType;
  }
}
// Listing footer (non-code): Related classes of org.neuroph.nnet.flat.FlatNetworkLearning.
// Copyright 2018 www.massapi.com. All source code is property of its respective owners.
// Java is a trademark of Sun Microsystems, Inc. and owned by Oracle Inc. Contact: coftware#gmail.com.