Package trust.jfcm.learning.training

Source Code of trust.jfcm.learning.training.LinearRegression

package trust.jfcm.learning.training;

import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

import trust.jfcm.Concept;
import trust.jfcm.FcmConnection;
import trust.jfcm.WeightedConnection;
import trust.jfcm.learning.EntryStructure;
import trust.jfcm.learning.FcmLearning;
import trust.jfcm.learning.FcmTrainingSet;
import trust.jfcm.learning.FcmTrainingSetEntry;
import trust.jfcm.learning.InputLearningConcept;
import trust.jfcm.learning.LearningConcept;
import trust.jfcm.learning.LearningWeightedConnection;
import trust.jfcm.learning.OutputLearningConcept;
import trust.jfcm.utils.FcmRunner;
import trust.jfcm.utils.SimpleFcmRunner;

/**
* This class implement the linear regression
* for the FCM training used in the backward propagation
* @author pc
*
*/
public class LinearRegression implements FcmTrainer{

  /**
   * Squared Errors Vector
   */
  double[] error_vector;
 
 
  FcmLearning map;
 
  /******* TRAINING PARAMETERS ********/
 
  /**
   * The number of epochs to train through.
   */
  int trainingTime = 500;
 
  /**
   * Target level of accuracy on the error.
   */
  double accuracy = 0.01;
 
  /**
   * Level of tolerance in the error change.
   */
  double epsilon = 0.0000001;
 
  /**
   * Fcm runner instance
   */
  FcmRunner runner;
 
  /**
   * Current training epoch
   */
  int epoch;
 
  /**
   * Elapsed time
   */
  long start, end;
 
  /**
   * Constructor
   */
  public LinearRegression(FcmLearning map){
    this.map = map;
    error_vector =new double[trainingTime];
    runner = new SimpleFcmRunner(map, 0.000001, 1000);
    epoch = 0;
  }
 
  /**
   * Train the map and update weights on the current training data set
   */
  public void train(){
   
    start = System.currentTimeMillis();

   
    FcmTrainingSet trainingSet = map.getTrainingSet();
   
    for(epoch=0; epoch<trainingTime; epoch++){
     
      /*
      try {
        Thread.sleep(5000);
      } catch (InterruptedException e1) {
        // TODO Auto-generated catch block
        e1.printStackTrace();
      }*/
     
      //System.out.println(" - Epoch: "+(epoch+1)+"/"+trainingTime);
      for(int i=0; i<trainingSet.size(); i++){
        FcmTrainingSetEntry e = trainingSet.get(i);
       
        //System.out.println("\n- Processing "+e);
       
        /*
         * Reset concept values and errors
         */
        map.resetConcepts();
       
        /*
         * Calculate errors on the nodes
         */
        calculateErrors(e);
       
       
        /*
         * Update network weights
         */
        updateNetworkWeights();
       
        /*
         * Normalize network weights
         */
        map.normalizeNetworkWeights(true);
       
       
        //System.out.println("\n"+map);
      }
     
      /*
       * Check objective functions
       */
      if(checkTerminalConditions()){
        break;
      }
     
    }
   
    end = System.currentTimeMillis();

    printInfo();
    cleanUp();
   
  }
 
  private void printInfo(){
    System.out.println("\n****** TRAINING MAP ********");
    FcmTrainingSet trainingSet = map.getTrainingSet();
    for(int k=0; k<trainingSet.size(); k++){
     
      FcmTrainingSetEntry e = trainingSet.get(k);
      System.out.println("\n- "+e);
     
      /*
       * Reset map
       */
      map.resetConcepts();
     
      /*
       * Insert inputs
       */
      setTrainingInputs(e.getInputs());
     
     
      /*
       * Evaluate error on the output nodes
       */
      setDesirableOutputs(e.getOutputs());
     
      /*
       * Execute map
       */
      runner.converge();
     
      /*
       * compute errors on the output nodes
       */
      List<EntryStructure> outputs = e.getOutputs();
      for(int i=0; i<outputs.size(); i++){
        EntryStructure e1 = outputs.get(i);
        Concept c = e1.getConcept();
       
        if(c==null)
          System.err.println("Error: Concept "+e1.getConcept().getName()+" not found");
       
       
          LearningConcept lc = null;
          if(c instanceof LearningConcept)
            lc = (LearningConcept) c;
         
          if(lc != null && lc.isOutput()){
            OutputLearningConcept oc = (OutputLearningConcept) lc;
            System.out.println("  Output concept: "+c.getName());
            System.out.println("  Desirable output: "+oc.getDesirableOutput()+", output: "+oc.getOutput().doubleValue());
            System.out.println("  Error on output concept: "+c.getName()+": "+(oc.getDesirableOutput() - oc.getOutput().doubleValue()));
           
          }
         
      }
    }
    if(epoch==error_vector.length)
      System.out.println("\n  Squeared mean error: "+Math.sqrt(error_vector[epoch-1]));
    else
      System.out.println("\n  Squeared mean error: "+Math.sqrt(error_vector[epoch]));
    System.out.println("  Training epoch: "+(epoch)+"/"+trainingTime);
   
    System.out.println("  Elapsed time: "+(end-start)/1000F+" sec \n");
   
    map.resetConcepts();
  }
 
  /**
   * Store the squared error on the output nodes in the error vector
   */
  private void measureOutputError(List<EntryStructure> outputs) {
   
     /*
     * compute errors on the output nodes
     */
    double error = 0;
    for(int i=0; i<outputs.size(); i++){
      EntryStructure e1 = outputs.get(i);
      Concept c = e1.getConcept();
     
      if(c==null)
        System.err.println("Error: Concept "+e1.getConcept().getName()+" not found");
     
     
        LearningConcept lc = null;
        if(c instanceof LearningConcept)
          lc = (LearningConcept) c;
       
        if(lc != null && lc.isOutput()){
          OutputLearningConcept oc = (OutputLearningConcept) lc;
          error += Math.pow(oc.getOutput() - e1.getValue(), 2);
          //System.out.println("  OutputNode: "+c.getName()+": "+c.getOutput()+", desirable value: "+e1.getValue());
        }
       
    }
   
    /* Update squared error vector*/
    //System.out.println("epoch:"+epoch);
    if(error_vector[epoch]==0){
      //System.out.println("  error is: "+error);
      error_vector[epoch] = (error);
    }
    else{
      double prev_error = error_vector[epoch];
      //System.out.println("  update error:  prev_err:"+prev_error+" error: "+error);
      error_vector[epoch] = prev_error + error;
      //System.out.println(error_vector[epoch]);
    }
   
  }

  /**
   * Calculate the overall squared error of the current map
   * on the list of given outputs
   * @param outputs: training set entry of the output data
   */
  private void setDesirableOutputs(List<EntryStructure> outputs){
   
    //System.out.println("- Calculating errors:");
    /*
     * Set desirable outputs
     */
    for(int i=0; i<outputs.size(); i++){
      EntryStructure e1 = outputs.get(i);
      Concept c = e1.getConcept();
     
      if(c==null)
        System.err.println("Error: Concept "+e1.getConcept().getName()+" not found");
     
     
        LearningConcept lc = null;
        if(c instanceof LearningConcept)
          lc = (LearningConcept) c;
       
        if(lc != null && lc.isOutput()){
          OutputLearningConcept oc = (OutputLearningConcept) lc;
          oc.setDesirableOutput(e1.getValue());
          //System.out.println("  OutputNode: "+c.getName()+": "+c.getOutput()+", desirable value: "+e1.getValue());
        }
       
    }
  }

  /**
   *
   * @param e
   */
  private void calculateErrors(FcmTrainingSetEntry e){
   
   
    /*
     * Insert inputs
     */
    setTrainingInputs(e.getInputs());
   
    /*
     * Evaluate error on the output nodes
     */
    setDesirableOutputs(e.getOutputs());

    /*
     * Execute map
     */
    runner.converge();
   
   
    //System.out.println("...map executed");
   
    //System.out.println(this);
   
    /*
     * **Calculate errors on the  concepts.**
     *
     * Repeatedly call the calculateError() function until errors
     * are computed on all the nodes
     *
     */
    //System.out.println("  Propagate errors...");
    //printErrorFlag();
    int num_completed = 0;
    while(num_completed < map.getNumLearningConcepts()){
      Iterator<Concept> it = map.getConceptsIterator();
      while(it.hasNext()){
        Concept c = it.next();
       
        if(c instanceof LearningConcept){
          LearningConcept lc = (LearningConcept) c;
          if(!lc.isErrorCalculated()){
           
           
           
           
            //error to be calculated
            //System.out.println(" error to be calculated for " + lc.getName());
            if(calculateErrorOnConcept(lc)){
              //calculation successful
              num_completed++;
             
              //System.out.println("  error on concept "+lc.getName()+": "+lc.getError());
            }
          } //else   System.out.println(" error already calculated for " + lc.getName() + " (completed "+num_completed+")");
       
      }
    }
   
    /*
     * Measure errors
     */
    measureOutputError(e.getOutputs());
  }
 
  /**
     * This function calculates the error on the current node in
     * the back propagation fashion. Error is the weighted sum of
     * the errors of the outgoing nodes.
     * @param node The node to calculate the error for.
     * @return true if the node is successfully updated.
     */
  private boolean calculateErrorOnConcept(LearningConcept c){
   
   
    if(c.isErrorCalculated())
      return false;
   
    /*
     * Specialization of the calculateError() method.
     * For output nodes the error is the difference with the desirable output.
     */
    if(c.isOutput()){
      OutputLearningConcept oc = (OutputLearningConcept) c;
      oc.setError(oc.getDesirableOutput() - oc.getOutput().doubleValue());

      return true;
    }
     
   
    /*
     * Weighted sum of the errors on the outgoing nodes.
     */
    Set<FcmConnection> out = c.getOutConnections();
    Iterator<FcmConnection> it = out.iterator();
    double val = 0;
    while(it.hasNext()){
      FcmConnection conn = it.next();
     
      WeightedConnection wconn = null;
      if(conn instanceof WeightedConnection)
        wconn = (WeightedConnection) conn;
     
      Concept cout = wconn.getTo();
      LearningConcept lout = null;
      if(cout instanceof LearningConcept)
        lout = (LearningConcept) cout;
     
      if(!lout.isErrorCalculated())
        /*
         * Error cannot be computed yet
         */
        return false;
     
      val += wconn.getWeight() * lout.getError();
     
    }
   
    /*
     * Set Error
     */
    c.setError(val);

    //System.out.println("Error on concept "+c.getName()+": "+c.getError());

    return true;
   
   
  }
 
  /**
   * Stop the training cycle if the terminal conditions are true.
   * @return
   */
  private boolean checkTerminalConditions() {
   
    if(Math.sqrt(error_vector[epoch])<=accuracy){
      //System.out.println("STOP accuracy");
      return true;
    }
   
    //int size = error_vector.length;
    if(epoch>1){
      double error_1 = Math.sqrt(error_vector[epoch-1]);
      double error_2 = Math.sqrt(error_vector[epoch-2]);
      double delta = Math.abs(error_1 - error_2);
      //System.out.println("Error(-1): "+error_1+" Error(-2): "+error_2+" Delta: "+delta+" epsilon: "+epsilon);
      if(delta <= epsilon){
        System.out.println("STOP epsilon");
        return true;
      }
    }
   
    return false;
  }
 
  /**
   * Reset the data structures after the training
   */
  private void cleanUp() {
    for(int i=0; i<epoch; i++) error_vector[i] = 0;
   
    epoch = 0;
    start = 0;
    end = 0;
   
    Collection<FcmConnection> cons = map.getConnections().values();
    for(FcmConnection f : cons){
      if(f instanceof LearningWeightedConnection){
        LearningWeightedConnection w = (LearningWeightedConnection) f;
        w.setChangeInWeight(0);
      }
    }
  }
 
  /**
   * Set the inputs to the map from the given training set entry
   * @param inputs: training set entry of input data
   */
  private void setTrainingInputs(List<EntryStructure> inputs){
    for(int i=0; i<inputs.size(); i++){
      EntryStructure e1 = inputs.get(i);
      //System.out.println("EntryStructure: "+e1);
     
      LearningConcept lc = (LearningConcept) e1.getConcept();
      if(lc != null && lc.isInput()){
        InputLearningConcept ilc = (InputLearningConcept) lc;
        ilc.setInput(e1.getValue());
      }

    }
  }
 
  /**
   * Update weights of the connections.
   *
   * Iteratively call the updateWeights function in all the nodes
   * until the whole network is updated
   */
  private void updateNetworkWeights(){
    //System.out.println("- Update weights...");
    Iterator<Concept> it = map.getConcepts().values().iterator();
    while(it.hasNext()){
      Concept c = it.next();
       
      if(c instanceof LearningConcept){
        LearningConcept lc = (LearningConcept) c;
        Set<FcmConnection> out = lc.getOutConnections();
        Iterator<FcmConnection> it1 = out.iterator();
       
        while(it1.hasNext()){
          FcmConnection conn = it1.next();
         
          LearningWeightedConnection wconn = null;
          if(conn instanceof LearningWeightedConnection){
            wconn = (LearningWeightedConnection) conn;
            /*
             * Update learning connection
             */
            lc.getTrainingFunction().update(wconn);
          }
         
        }
         
      }     
    }
  }

  @Override
  public void setMap(FcmLearning map) {
    this.map = map;
   
  }
 
 
}
TOP

Related Classes of trust.jfcm.learning.training.LinearRegression

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.