Package tv.floe.metronome.classification.neuralnetworks.serde

Source Code of tv.floe.metronome.classification.neuralnetworks.serde.TestNetworkSerdeMechanics

package tv.floe.metronome.classification.neuralnetworks.serde;

import static org.junit.Assert.*;

import java.io.IOException;

import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapred.JobConf;
import org.junit.Test;

import tv.floe.metronome.classification.neuralnetworks.activation.Tanh;
import tv.floe.metronome.classification.neuralnetworks.conf.Config;
import tv.floe.metronome.classification.neuralnetworks.core.NeuralNetwork;
import tv.floe.metronome.classification.neuralnetworks.core.neurons.Neuron;
import tv.floe.metronome.classification.neuralnetworks.input.WeightedSum;
import tv.floe.metronome.classification.neuralnetworks.iterativereduce.MasterNode;
import tv.floe.metronome.classification.neuralnetworks.iterativereduce.NetworkWeightsUpdateable;
import tv.floe.metronome.classification.neuralnetworks.iterativereduce.NeuralNetworkWeightsDelta;
import tv.floe.metronome.classification.neuralnetworks.networks.MultiLayerPerceptronNetwork;

public class TestNetworkSerdeMechanics {

  private static JobConf defaultConf = new JobConf();
  private static FileSystem localFs = null;
  static {
    try {
      defaultConf.set("fs.defaultFS", "file:///");
      localFs = FileSystem.getLocal(defaultConf);
    } catch (IOException e) {
      throw new RuntimeException("init failure", e);
    }
  }
 
 
  @Test
  public void testBaseJavaObjectSerialization() {

    NetworkWeightsUpdateable nwu = new NetworkWeightsUpdateable();
   
    Config c = new Config();
    c.parse(null); // default layer: 2-3-2
        c.setConfValue("inputFunction", WeightedSum.class);
    c.setConfValue("transferFunction", Tanh.class);
    c.setConfValue("neuronType", Neuron.class);
    c.setConfValue("networkType", NeuralNetwork.NetworkType.MULTI_LAYER_PERCEPTRON);
    c.setConfValue("layerNeuronCounts", "2,3,1" );
    c.parse(null);
   
    NeuralNetwork nn = new MultiLayerPerceptronNetwork();
   
    try {
      nn.buildFromConf(c);
    } catch (Exception e) {
      // TODO Auto-generated catch block
      e.printStackTrace();
    }
   
    NeuralNetworkWeightsDelta nnwd = new NeuralNetworkWeightsDelta();
    nnwd.network = nn;
   
    nwu.set(nnwd);
   
    nwu.toBytes();
   
   
   
   
   
   
  }
 
  @Test
  public void testMasterNodeObjectSerialization() {

    NetworkWeightsUpdateable nwu = new NetworkWeightsUpdateable();
   
    Config c = new Config();
    c.parse(null); // default layer: 2-3-2
        c.setConfValue("inputFunction", WeightedSum.class);
    c.setConfValue("transferFunction", Tanh.class);
    c.setConfValue("neuronType", Neuron.class);
    c.setConfValue("networkType", NeuralNetwork.NetworkType.MULTI_LAYER_PERCEPTRON);
    c.setConfValue("layerNeuronCounts", "2,3,1" );
    c.parse(null);
   
    NeuralNetwork nn = new MultiLayerPerceptronNetwork();
   
    try {
      nn.buildFromConf(c);
    } catch (Exception e) {
      // TODO Auto-generated catch block
      e.printStackTrace();
    }
    /*
    NeuralNetworkWeightsDelta nnwd = new NeuralNetworkWeightsDelta();
    nnwd.network = nn;
   
    nwu.set(nnwd);
   
    nwu.toBytes();
    */
   
    MasterNode mnode = new MasterNode();
    mnode.master_nn = nn;
   
   
   
    try {

      Path out = new Path("/tmp/fooTest.model");
      FileSystem fs =
            out.getFileSystem(defaultConf);
     
      FSDataOutputStream fos;

      fos = fs.create(out);
        //LOG.info("Writing master results to " + out.toString());
        mnode.complete(fos);
       
        fos.flush();
        fos.close();

      //BufferedWriter bw = new BufferedWriter(new FileWriter(output_path));
      //master.complete( bw );
     
      //bw.close();
      /*
      FileOutputStream fs = new FileOutputStream(output_path);
        ObjectOutputStream oos = new ObjectOutputStream(fs);
       
        //oos.writeObject( master. );
        //master.complete(oos);
       
        oos.flush();
        oos.close();
*/
     
    } catch (IOException e) {
      // TODO Auto-generated catch block
      e.printStackTrace();
    }
         
   
   
   
  } 

}
TOP

Related Classes of tv.floe.metronome.classification.neuralnetworks.serde.TestNetworkSerdeMechanics

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., owned by Oracle Corporation. Contact coftware#gmail.com.