Package tv.floe.metronome.classification.neuralnetworks.learning

Source Code of tv.floe.metronome.classification.neuralnetworks.learning.BackPropogationLearningAlgorithm

package tv.floe.metronome.classification.neuralnetworks.learning;

import java.util.ArrayList;

import tv.floe.metronome.classification.neuralnetworks.core.Connection;
import tv.floe.metronome.classification.neuralnetworks.core.Layer;
import tv.floe.metronome.classification.neuralnetworks.core.Weight;
import tv.floe.metronome.classification.neuralnetworks.core.neurons.Neuron;
import tv.floe.metronome.classification.neuralnetworks.learning.adagrad.AdagradLearningRate;
import tv.floe.metronome.classification.neuralnetworks.activation.ActivationFunction;

/**
* Basic backprop implementation
* - adagrad and momentum learning are added in a very basic way, no elaborate class system
*
* @author josh
*
*/
public class BackPropogationLearningAlgorithm extends SigmoidDeltaLearningAlgorithm {

  boolean adagradLearningOn = false;
  boolean momentumLearningOn = false;
  double adagradInitLearningRate = 10;
 
  public BackPropogationLearningAlgorithm() {
    super();
  }
 
  public void turnOnAdagradLearning(double adagradLearningRate) {
    this.adagradLearningOn = true;
    this.adagradInitLearningRate = adagradLearningRate;
  }
 
  public void turnOnMomentumLearning() {
    this.momentumLearningOn = true;
  }
 
    @Override
    public void setup() {
     
      if (this.adagradLearningOn) {
       
        // add stuff to the weights
       
        ArrayList<Layer> layers = nn.getLayers();
       
       
        /*
        for (int l = layers.size() - 2; l > 0; l--) {
                     
          for ( Neuron neuron : layers.get( l ).getNeurons() ) { 
                                   
            //double neuronError = this.calculateHiddenNeuronError( neuron );
           
            //neuron.setError( neuronError );
           
            //this.updateNeuronWeights( neuron );
                   
          } // for
         
        } // for 
        */
       
        for ( int x = 1; x < this.nn.getLayersCount(); x++ ) {
           
                for (Neuron neuron : this.nn.getLayerByIndex(x).getNeurons()) {
                 
                    for (Connection connection : neuron.getInConnections()) {
                     
                        connection.getWeight().trainingMetaData.put("adagrad", new AdagradLearningRate(this.adagradInitLearningRate));
                       
                    }
                   
                }
                 
         
        }       
       
       
       
      }
     
      if (this.momentumLearningOn) {
       
        // add stuff to the weights
       
      }
     
     
    }
 


  /**
   * This is the main driver method for the learning algorithms
   * train( ... )
   * - calculate output of current instance
   * - calculate error of output vs desired output
   *
   */
  @Override
  protected void updateNetworkWeights(double[] outputError) {
       
    this.calculateErrorAndUpdateOutputNeurons(outputError); // via SigmoidDelta
    this.calculateErrorAndUpdateHiddenNeurons();            // implemented in this class
   
  }

  protected void calculateErrorAndUpdateHiddenNeurons() {
               
    ArrayList<Layer> layers = nn.getLayers();
   
 
   
    for (int l = layers.size() - 2; l > 0; l--) {
                 
      for ( Neuron neuron : layers.get( l ).getNeurons() ) { 
                               
        double neuronError = this.calculateHiddenNeuronError( neuron );
       
        neuron.setError( neuronError );
       
        this.updateNeuronWeights( neuron );
               
      } // for
     
    } // for
   
  }

  protected double calculateHiddenNeuronError(Neuron neuron) { 
   
    double deltaSum = 0d;
   
   
   
    for (Connection connection : neuron.getOutConnections()) { 
     
      double delta = connection.getToNeuron().getError() * connection.getWeight().value;
      deltaSum += delta; // weighted delta sum from the next layer
   
      if (this.isMetricCollectionOn()) {
        this.metrics.incErrCalcOpCount();
      }
     
    } // for

   
   
    ActivationFunction transferFunction = neuron.getActivationFunction();
   
   
    double netInput = neuron.getNetInput(); // should we use input of this or other neuron?
    double fnDerv = transferFunction.getDerivative(netInput);
    double neuronError = fnDerv * deltaSum;
    return neuronError;
 
 
 
  /**
   * Updated for adagrad
   *
   * so with adagrad engaged, we are tracking/accumulating the gradient change
   * into the AdagradLearningRate object on each weight
   * - over time this drives the learning rate downwards
   *
   */
  @Override
    protected void updateNeuronWeights(Neuron neuron) {

      double neuronError = neuron.getError();
        double lrTemp = 0;
        AdagradLearningRate alr = null;
       
        for (Connection connection : neuron.getInConnections()) {

          if (this.adagradLearningOn) {
            alr = (AdagradLearningRate)connection.getWeight().trainingMetaData.get("adagrad");
            lrTemp = alr.compute();
          } else {
            lrTemp = this.learningRate;
          }
         
          double input = connection.getInput();
            //double weightChange = this.learningRate * neuronError * input;
          double weightChange = lrTemp * neuronError * input;

            Weight weight = connection.getWeight();

            if (this.isInBatchMode() == false) {            
                weight.weightChange = weightChange;               
                weight.value += weightChange;
            } else {
                weight.weightChange += weightChange;
            }
           
          if (this.adagradLearningOn) {
            alr = (AdagradLearningRate)connection.getWeight().trainingMetaData.get("adagrad");
            alr.addLastIterationGradient(weightChange);
          }           
           
            if (this.isMetricCollectionOn()) {
              this.metrics.incWeightOpCount();
            }
        }
    } 

  @Override
  public String Debug() {
   
    String out = "";
   
    if (this.adagradLearningOn) {
    for (int x = 1; x < this.nn.getLayersCount(); x++ ) {
     
      out += "L" + x + ":";
     
      int n  = 0;
     
      for ( Neuron neuron : this.nn.getLayerByIndex(x).getNeurons() ) { 
       
        out += "n" + n + "=";
       
            for (Connection connection : neuron.getInConnections()) {

              if (this.adagradLearningOn) {
                AdagradLearningRate alr = (AdagradLearningRate)connection.getWeight().trainingMetaData.get("adagrad");
                //lrTemp = alr.compute();
                out += "" + alr.compute() +",";
              }

            }
           
            n++;
                 
      } // for
     
    } // for 
    }
   
    return out;
   
   
  }

  public String DebugAdagrad() {
   
    String out = "";
   
    if (this.adagradLearningOn) {
     
      Neuron neuron = this.nn.getLayerByIndex(1).getNeurons().get(1);
       

       
      Connection c = neuron.getInConnections().get(0);
               
      AdagradLearningRate alr = (AdagradLearningRate)c.getWeight().trainingMetaData.get("adagrad");
                //lrTemp = alr.compute();
       
      out += "[Ada: " + alr.compute() +" ]";

      c = neuron.getInConnections().get(1);
       
      alr = (AdagradLearningRate)c.getWeight().trainingMetaData.get("adagrad");

      out += "[Ada: " + alr.compute() +" ]";
     
    }
   
    return out;
   
   
  }
 
 
}
TOP

Related Classes of tv.floe.metronome.classification.neuralnetworks.learning.BackPropogationLearningAlgorithm

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.