Package tv.floe.metronome.classification.neuralnetworks.iterativereduce

Source Code of tv.floe.metronome.classification.neuralnetworks.iterativereduce.MasterNode

package tv.floe.metronome.classification.neuralnetworks.iterativereduce;

import java.io.ByteArrayOutputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.Collection;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.util.ToolRunner;

import tv.floe.metronome.classification.logisticregression.iterativereduce.POLRMasterNode;
import tv.floe.metronome.classification.logisticregression.iterativereduce.ParameterVectorUpdatable;
import tv.floe.metronome.classification.neuralnetworks.core.NeuralNetwork;
import tv.floe.metronome.classification.neuralnetworks.learning.BackPropogationLearningAlgorithm;
import tv.floe.metronome.classification.neuralnetworks.networks.MultiLayerPerceptronNetwork;
import tv.floe.metronome.io.records.RecordFactory;
import tv.floe.metronome.linearregression.iterativereduce.NodeBase;

import com.cloudera.iterativereduce.ComputableMaster;
import com.cloudera.iterativereduce.yarn.appmaster.ApplicationMaster;

public class MasterNode  extends NodeBase implements ComputableMaster<NetworkWeightsUpdateable> {

  public NeuralNetwork master_nn = null;
  double trainingErrorThreshold = 0;
  boolean hasHitThreshold = false;
 
  @Override
  public void complete(DataOutputStream ds) throws IOException {

      //System.out.println("master::complete (Iterations: " + this.NumberIterations + ")");
     
     ds.write(this.master_nn.Serialize());
   
    //System.out.println("Master: Model Writing is disabled for testing");
     
   
  }

  @Override
  public NetworkWeightsUpdateable compute(Collection<NetworkWeightsUpdateable> workerUpdates,
      Collection<NetworkWeightsUpdateable> masterUpdates) {

    NetworkWeightsUpdateable return_msg = new NetworkWeightsUpdateable();
       
    double avg_rmse = 0;
   
   
    NetworkWeightsUpdateable first = workerUpdates.iterator().next();
    NetworkAccumulator accumNet = null;
   
    if (null == first) {
      System.out.println("Can't seem to get the first network weights updateable");
    } else {
     
      if (null == this.master_nn) {
       
        System.out.println("Building base master MLP network");
        this.master_nn = new MultiLayerPerceptronNetwork();
            try {
          this.master_nn.buildFromConf(first.networkUpdate.network.getConfig());
        } catch (Exception e) {
          // TODO Auto-generated catch block
          e.printStackTrace();
        }
           
      }
     
    }
/*   
    if (null == this.first_worker_copy) {
      this.first_worker_copy = first.networkUpdate.network;
    }
  */ 
    try {
      accumNet = NetworkAccumulator.buildAveragingNetworkFromConf(first.networkUpdate.network.getConfig());
    } catch (Exception e) {
      // TODO Auto-generated catch block
      e.printStackTrace();
    }
   
    if (null == accumNet) {
      System.out.println("Master: Network Accumulator is null! [Error]");
      return null;
    }
   
      for (NetworkWeightsUpdateable nn_worker : workerUpdates) {

        accumNet.AccumulateWorkerNetwork(nn_worker.networkUpdate.network);
        avg_rmse += nn_worker.networkUpdate.RMSE;
       
      }
     
      avg_rmse = avg_rmse / workerUpdates.size();
      BackPropogationLearningAlgorithm bp = ((BackPropogationLearningAlgorithm)this.master_nn.getLearningRule());
      bp.getMetrics().setLastRMSE(avg_rmse);
     
      if (avg_rmse <= this.trainingErrorThreshold && !hasHitThreshold && first.networkUpdate.CurrentIteration > 10) {
       
        System.out.println("\nMaster hit avg rmse threshold at epoch: " + first.networkUpdate.CurrentIteration + "\n");
       
        for (NetworkWeightsUpdateable nn_worker : workerUpdates) {

          System.out.println("worker.rmse: " + nn_worker.networkUpdate.RMSE );
        }
       
       
        this.hasHitThreshold = true;
      } else {
       
        //System.out.println("rmse debug > " + avg_rmse + " <= " + this.trainingErrorThreshold + ", iterations: " + first.networkUpdate.CurrentIteration);
       
      }
       
     
     
      try {
      accumNet.AverageNetworkWeights();
    } catch (Exception e) {
      // TODO Auto-generated catch block
      e.printStackTrace();
    }
     
      //accumNet.
     
      this.master_nn.copyWeightsAndConf(accumNet);
   
      NeuralNetworkWeightsDelta nnwd = new NeuralNetworkWeightsDelta();
      nnwd.network = this.master_nn;
     
      return_msg.set(nnwd);
     
      //this.master_nn = nnwd.network;
   
      // THIS NEEDS TO BE DONE, probably automated!
      workerUpdates.clear();
      masterUpdates.clear();

      //System.out.println("DEBUG > Master Completes Iteration");
   
    return return_msg;
  }

  @Override
  public NetworkWeightsUpdateable getResults() {
    System.out.println(">>> getResults() - null!!!");
    return null;
  }

  @Override
  public void setup(Configuration c) {

   

      this.conf = c;
     
      try {

        // this is the target to get the avg rmse under for testing purposes
        this.trainingErrorThreshold = Double.parseDouble(this.conf.get(
                "tv.floe.metronome.neuralnetwork.conf.TrainingErrorThreshold", "0.2"));
       

      } catch (Exception e) {
        // TODO Auto-generated catch block
        e.printStackTrace();
        System.out.println(">> Error loading conf!");
      }
     
      System.out.println( "-----------------------------------------" );
      System.out.println( "# Master Conf #" );
      //System.out.println( "Number Iterations: " + this.NumberIterations );
      System.out.println( "-----------------------------------------\n\n" );
         
   
  }

    public static void main(String[] args) throws Exception {
        MasterNode pmn = new MasterNode();
        ApplicationMaster<NetworkWeightsUpdateable> am = new ApplicationMaster<NetworkWeightsUpdateable>(
            pmn, NetworkWeightsUpdateable.class);
       
        ToolRunner.run(am, args);
    }



}
TOP

Related Classes of tv.floe.metronome.classification.neuralnetworks.iterativereduce.MasterNode

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.