Package cnslab.cnsnetwork

Source Code of cnslab.cnsnetwork.NetHost

    package cnslab.cnsnetwork;
   
    import java.io.*;
    import java.util.*;
    import java.util.concurrent.*;
    import org.w3c.dom.*;
    import org.w3c.dom.ls.*;
    import org.w3c.dom.bootstrap.*;

    import jpvm.*;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;

    import cnslab.cnsmath.*;
    import edu.jhu.mb.ernst.net.PeerInfo;

    /***********************************************************************
    * NetHost class.
    *
    * It first initialize itself, prepare the jpvm environment.
    * populate the neurons, build up the network structure,
    * start the executing thread, then listen to message from trial host
    * and mainhost.
    *
    * @version
    *   $Date: 2012-08-04 20:43:22 +0200 (Sat, 04 Aug 2012) $
    *   $Rev: 104 $
    *   $Author: croft $
    * @author
    *   Yi Dong
    * @author
    *   David Wallace Croft
    ***********************************************************************/
    public class  NetHost
    ////////////////////////////////////////////////////////////////////////
    ////////////////////////////////////////////////////////////////////////
    {
     
    private static final Class<NetHost>
      CLASS = NetHost.class;

//    private static final String
//      CLASS_NAME = CLASS.getName ( );
   
    private static final Logger
      LOGGER = LoggerFactory.getLogger ( CLASS );

    //public static JpvmInfo  info;
   
    ////////////////////////////////////////////////////////////////////////
    ////////////////////////////////////////////////////////////////////////
   
    public static void  main ( final String [ ]  args )
    ////////////////////////////////////////////////////////////////////////
    {
      jpvmEnvironment  jpvmEnvironmentInstance = null;
       
      try
      {
        // double  minDelay;
       
        // double backFire;
       
        // Enroll in the parallel virtual machine...
       
        jpvmEnvironmentInstance = new jpvmEnvironment ( );
       
        // receive information about its peers
       
        // It will hang here until it receives a network message.
       
        final jpvmMessage  message
          = jpvmEnvironmentInstance.pvm_recv ( NetMessageTag.sendTids );
       
        final PeerInfo  peerInfo = new PeerInfo ( message );
       
        launch ( jpvmEnvironmentInstance, peerInfo );

        jpvmEnvironmentInstance.pvm_exit ( );
      }
      catch (jpvmException jpe)
      {
        System.out.println("Error - jpvm exception");
       
        try
        {
          if ( jpvmEnvironmentInstance != null )
          {
            FileOutputStream out = new FileOutputStream(
              "log/"+jpvmEnvironmentInstance.pvm_mytid().getHost()
              +"error.txt");

            PrintStream p = new  PrintStream(out);

            jpe.printStackTrace(p);

            p.close();
          }
        }
        catch(Exception ex)
        {
          //
        }
      }
      catch (Exception a)
      {
        try
        {
          if ( jpvmEnvironmentInstance != null )
          {
            FileOutputStream out = new FileOutputStream(
              "log/"+ jpvmEnvironmentInstance.pvm_mytid().getHost()
              +"error.txt");

            PrintStream p = new  PrintStream(out);

            a.printStackTrace(p);

            p.close();
          }
        }
        catch(Exception ex)
        {
          //
        }
      }
    }
   
    public static void  launch (
      final jpvmEnvironment  jpvmEnvironmentInstance,
      final PeerInfo         peerInfo )
      throws Exception, jpvmException
    ////////////////////////////////////////////////////////////////////////
    {
      final JpvmInfo  jpvmInfo = new JpvmInfo ( );
       
      jpvmInfo.jpvm = jpvmEnvironmentInstance;

      jpvmInfo.myJpvmTaskId = jpvmEnvironmentInstance.pvm_mytid ( );

      // Get my parent's task id...
      //    info.parent = info.jpvm.pvm_parent();
      // actually it is grand parent;
     
      peerInfo.populateJpvmInfo ( jpvmInfo );

      int  seedInt = peerInfo.seedInt;

      // NetHosts tid

      //FileOutputStream outt = new FileOutputStream(
      //  "log/"+info.parent.getHost()+"_preinfo"+iter_int+".txt");
      //PrintStream p = new  PrintStream(outt);

      //    String out="Host id "+info.idIndex+"\n";

      final DOMImplementationRegistry
        registry = DOMImplementationRegistry.newInstance ( );

      final DOMImplementationLS  impl
        = ( DOMImplementationLS ) registry.getDOMImplementation ( "LS" );

      final SimulatorParser  simulatorParser = createSimulatorParser (
        impl,
        peerInfo.byteArray,
        seedInt,
        jpvmInfo );

      final CyclicBarrier  barrier = new CyclicBarrier ( 2 );

      /*
        //receiving neurons
        m = info.jpvm.pvm_recv(NetMessageTag.sendNeurons);
        int numNeurons;
        numNeurons = m.buffer.upkint();
        int base = m.buffer.upkint();
        // the last neuron is a sync neuron
        Neuron neurons[] = new Neuron[numNeurons+1];
        m.buffer.unpack(neurons,numNeurons,1);
       */

      /*
        out=out+"\n";
        for(int i=0; i< numNeurons; i++)
        {
          out=out+"Neurons are:"+neurons[i].toString();
        }
       */

      final Seed  idum = new Seed ( seedInt - jpvmInfo.idIndex );

      // neurons, pvminfo, base index, mini delay, background freq,
      // seed number       

      final Network  testNet = new Network (
        simulatorParser.getModelFactory ( ),
        simulatorParser.getDiscreteEventQueue ( ),
        simulatorParser.getModulatedSynapseSeq ( ),
        simulatorParser.layerStructure.neurons,
        simulatorParser.layerStructure.axons,
        jpvmInfo,
        simulatorParser.layerStructure.base,
        simulatorParser.minDelay,
        simulatorParser,             
        idum,
        simulatorParser.experiment );

      testNet.initNet ( );

      //    pas.p = testNet.p;


      // if(info.idIndex==0) pas.ls.connectFrom("O,0,0,E", testNet.p);
      // if(info.idIndex==4) pas.ls.connectFrom("O,0,16,0,E", testNet.p);
      // pas.ls.connectFrom("T,27,9,L", testNet.p);

      testNet.p.println("Host id "+jpvmInfo.idIndex+"\n");

      testNet.p.println("base "+testNet.base+"\n");

      testNet.p.println("idum "+idum.seed+"\n");

      testNet.p.println("NumOfSyn "+simulatorParser.layerStructure.numSYN+"\n");


      //    testNet.p.flush();

      Object lock = new Object();

      // Object synLock = new Object();
      // ListenInput listen = new ListenInput(testNet,lock);
      // PCommunicationThread
      //   p1 = new PCommunicationThread(testNet, lock);
      // PComputationThread p2 = new PComputationThread(testNet, lock);

      final PRun  p1 = new PRun (
        testNet.getDiscreteEventQueue ( ),
        testNet,
        lock,
        barrier );

      //    pas.ls.cellmap=null;

      System.gc ( );

      final Thread  run = new Thread ( p1 );

      run.setDaemon ( true );

      run.start ( );

      // Barrier Sync

      jpvmBuffer buf = new jpvmBuffer();

      // send out ready info;
     
      final String
        messageString
          = "NetHost " + jpvmInfo.jpvm.pvm_mytid ( ).toString ( )
          + " is ready to go";

      buf.pack ( messageString );
     
      LOGGER.info ( messageString );
     
      jpvmTaskId
        jpvmTaskIdInstance = jpvmInfo.tids [ jpvmInfo.idIndex ];
     
      if ( jpvmTaskIdInstance == null )
      {
        LOGGER.warn ( "jpvmTaskIdInstance == null" );
     
        // TODO:  This is a temporary hack for JUnit testing.
     
        jpvmTaskIdInstance = jpvmInfo.jpvm.pvm_mytid ( );
      }

      jpvmInfo.jpvm.pvm_send (
        buf,
        jpvmTaskIdInstance,
        NetMessageTag.readySig );

      // m = info.jpvm.pvm_recv(NetMessageTag.readySig);

      // Barrier Sync

      while ( !testNet.stop )
      {
        // receive info from others

        final jpvmMessage  m = testNet.info.jpvm.pvm_recv ( );

        if ( m.messageTag == NetMessageTag.trialDone )
        {
          synchronized(lock)
          {
            testNet.trialDone = false;
          }
        }

        // synchronized(lock)
        {
          // testNet.p.println("message "+m.messageTag);

          // lock.notify();

          switch ( m.messageTag )
          {
            case NetMessageTag.sendSpike:

              processSendSpike (
                lock,
                m,
                testNet,
                jpvmInfo );

              break;

            case NetMessageTag.syncRoot:

              // if its a message about time

              processSyncRoot (
                lock,
                testNet,
                m );

              break;

            case NetMessageTag.stopSig:

              // if its a message about time

              processStopSig (
                testNet,
                jpvmInfo,
                lock );

              break;

            case NetMessageTag.tempStopSig:

              //if its a message about time

              processTempStopSig (
                testNet,
                jpvmInfo );

              break;

            case NetMessageTag.trialDone:

              // new trial begins

              processTrialDone (
                lock,
                barrier,
                m,
                testNet );

              break;

            case NetMessageTag.changeConnection:

              processChangeConnection (
                m,
                impl,
                simulatorParser,
                testNet,
                jpvmInfo );

              break;

            case NetMessageTag.netHostNotify:

              processNetHostNotify ( jpvmInfo );

              break;

            case NetMessageTag.resetNetHost:

              processResetNetHost (
                lock,
                testNet,
                jpvmInfo );

              break;

            case NetMessageTag.checkTime:

              processCheckTime (
                m,
                testNet,
                p1,
                jpvmInfo );

              break;
          }
        }
      }
    }

    ////////////////////////////////////////////////////////////////////////
    ////////////////////////////////////////////////////////////////////////
   
    public static SimulatorParser  createSimulatorParser (
      final DOMImplementationLS  impl,
      final byte [ ]             byteArray,
      final int                  seedInt,
      final JpvmInfo             jpvmInfo )
      throws Exception, jpvmException
    ////////////////////////////////////////////////////////////////////////
    {
      final LSInput  input = impl.createLSInput ( );

      input.setByteStream ( new ByteArrayInputStream ( byteArray ) );

      final LSParser  parser = impl.createLSParser (
        DOMImplementationLS.MODE_SYNCHRONOUS,
        null );

      final SimulatorParser  simulatorParser = new SimulatorParser (
        new Seed ( seedInt - jpvmInfo.idIndex ),
        parser.parse ( input ) );

      simulatorParser.parseMapCells ( jpvmInfo.idIndex );

      // p.println("cell maped");
      // System.out.println(pas.ls.base+ " "+pas.ls.neuron_end);

      simulatorParser.parseNeuronDef ( );

      simulatorParser.parsePopNeurons ( jpvmInfo.idIndex );

      // p.println("cell poped");

      simulatorParser.parseScaffold ( jpvmInfo.idIndex );

      simulatorParser.layerStructure.buildStructure ( jpvmInfo.idIndex );

      simulatorParser.parseConnection ( jpvmInfo.idIndex );

      simulatorParser.parseTarget ( jpvmInfo );
       
      // sort the  synapses, not necessary unless tune the parameters       

      //    pas.ls.sortSynapses(info.idIndex);
      //    p.println("connected");
       
      simulatorParser.parseExp ( jpvmInfo.idIndex );
       
      //    p.println("exp");
       
      simulatorParser.findMinDelay ( );
       
      //    p.close();

      if ( simulatorParser.layerStructure.axons == null )
      {
        throw new RuntimeException ( "no axon info" );
      }
     
      return simulatorParser;
    }
   
    public static void  processCheckTime (
      final jpvmMessage  m,
      final Network      testNet,
      final PRun         pRun,
      final JpvmInfo     jpvmInfo )
      throws jpvmException
    ////////////////////////////////////////////////////////////////////////
    {
      // testNet.p.println("check time now");
     
      // testNet.p.flush();
     
      final jpvmBuffer  buf = new jpvmBuffer ( );
     
      // send the available Host id

      buf.pack ( m.buffer.upkint ( ) );

      buf.pack ( testNet.subExpId );

      buf.pack ( pRun.minTime );

      buf.pack ( testNet.trialId );

      jpvmInfo.jpvm.pvm_send (
        buf,
        jpvmInfo.tids [ jpvmInfo.idIndex ],
        NetMessageTag.checkTime );
    }
   
    public static void  processChangeConnection (
      final jpvmMessage          m,
      final DOMImplementationLS  impl,
      final SimulatorParser      simulatorParser,
      final Network              testNet,
      final JpvmInfo             jpvmInfo )
      throws jpvmException, Exception
    ////////////////////////////////////////////////////////////////////////
    {
      final int  seedInt = m.buffer.upkint ( );
     
      final int  baLength = m.buffer.upkint ( );
     
      final byte [ ]  ba = new byte [ baLength ];
     
      m.buffer.unpack ( ba, baLength, 1 );
               
      // DOMImplementationRegistry
      //   registry = DOMImplementationRegistry.newInstance();

      // DOMImplementationLS impl = (DOMImplementationLS)
      //   registry.getDOMImplementation("LS");

      final LSInput  input = impl.createLSInput();

      input.setByteStream ( new ByteArrayInputStream ( ba ) );

      final LSParser  parser = impl.createLSParser (
        DOMImplementationLS.MODE_SYNCHRONOUS,
        null );

      final NodeList
        conns = simulatorParser.rootElement.getElementsByTagName ( "Connections" );

      testNet.p.println("connection num"+conns.getLength());

      simulatorParser.rootElement.removeChild ( conns.item ( 0 ) );

      final Node  dup = simulatorParser.document.importNode (
        parser.parse(input).getDocumentElement()
        .getElementsByTagName("Connections").item(0),
        true );

      simulatorParser.rootElement.appendChild ( dup );

      double weight;

      // change the connections;

      weight = simulatorParser.parseChangeConnection (
        jpvmInfo.idIndex );

      // synchronize
      // testNet.p.println(
      //   "connection change done with weight "+weight);

      testNet.seed = new Seed ( seedInt - jpvmInfo.idIndex);

      // testNet.p.println("new seed"+testNet.idum.seed);

      final jpvmBuffer  buf1 = new jpvmBuffer ( );

      if ( weight < 0 )
      {
        buf1.pack (
          "NetHost "+ jpvmInfo.jpvm.pvm_mytid().getHost()
          +" has been changed"); //send out ready info;
      }
      else
      {
        // send out ready info

        buf1.pack ( "badweight$" + weight );
      }

      jpvmInfo.jpvm.pvm_send (
        buf1,
        jpvmInfo.tids[jpvmInfo.idIndex],
        NetMessageTag.readySig );

      //end of sync

      if(weight < 0 )
      {
        final jpvmBuffer  buf2 = new jpvmBuffer();

        jpvmInfo.jpvm.pvm_send (
          buf2,
          jpvmInfo.parentJpvmTaskId,
          NetMessageTag.trialDone );
      }
    }
   
    public static void  processNetHostNotify (
      final JpvmInfo  jpvmInfo )
      throws jpvmException
    ////////////////////////////////////////////////////////////////////////
    {
      // testNet.p.println("nethost notify");

      final jpvmBuffer  buf1 = new jpvmBuffer ( );
     
      // send out ready info

      buf1.pack (
        "NetHost " + jpvmInfo.jpvm.pvm_mytid ( ).toString ( )
        + " is ready to go" );

      jpvmInfo.jpvm.pvm_send (
        buf1,
        jpvmInfo.tids [ jpvmInfo.idIndex ],
        NetMessageTag.readySig );

      // restore saved seed for comparison;
      // testNet.idum.seed = testNet.saveSeed;

      final jpvmBuffer  buf2 = new jpvmBuffer ( );

      jpvmInfo.jpvm.pvm_send (
        buf2,
        jpvmInfo.parentJpvmTaskId,
        NetMessageTag.trialDone );
    }
               
    public static void  processResetNetHost (
      final Object    lock,
      final Network   testNet,
      final JpvmInfo  jpvmInfo )
      throws jpvmException
    ////////////////////////////////////////////////////////////////////////
    {
      synchronized(lock)
      {
        // break;

        testNet.trialDone = true;

        // added;

        testNet.spikeState = true;

        for ( int  iter = 0; iter < testNet.info.numTasks; iter++ )
        {      
          /*
          while(!testNet.received[iter].empty())
          {
            testNet.received[iter].pop();
          }
          */

          // leave mark here
         
          testNet.received [ iter ] = 1;
        }

        //added over;

        // initilization
        // testNet.rootTime=0.0;

        testNet.recorderData.clear();
       
        testNet.clearQueues ( );

        testNet.countTrial++;

        //  testNet.p.println("reset the trial now");

        testNet.recordBuff.buff.clear();

        testNet.intraRecBuffers.init();

        for ( int  i = 0; i < jpvmInfo.numTasks; i++ )
        {
          testNet.spikeBuffers [ i ].buff.clear ( );
        }

        // testNet.fireQueue.insertItem(
        //   new FireEvent(
        //     testNet.neurons.length-1+testNet.base,
        //     testNet.endOfTrial-testNet.minDelay/2.0));

        // for(int i=0; i<testNet.neurons.length-1; i++)
        // {
        //   if(testNet.neurons[i].isSensory())
        //   {
        //     // sensory neuron send spikes to the other neurons
        //
        //     fireQueue.insertItem(
        //       new FireEvent(
        //         i+base,
        //         neurons[i].updateFire() ));
        //   }
        //   else
        //   {
        //     // nonsensory neurons will be initiazed;
        //
        //     testNet.neurons[i].init(idum);
        //   }
        // }

        lock.notify();

        while ( jpvmInfo.jpvm.pvm_probe ( ) )
        {
          // clear buffer
         
          final jpvmMessage  m = testNet.info.jpvm.pvm_recv ( );
         
// TODO:  Are this messages supposed to be ignored?         
        }     

        final jpvmBuffer  buf1 = new jpvmBuffer ( );

        buf1.pack ( "NetHost "+jpvmInfo.jpvm.pvm_mytid().toString()
          +" is ready to go"); //send out ready info;

        jpvmInfo.jpvm.pvm_send (
          buf1,
          jpvmInfo.tids[jpvmInfo.idIndex],
          NetMessageTag.readySig );

        // restore saved seed for comparison;                 
        // testNet.idum.seed = testNet.saveSeed;

        final jpvmBuffer  buf2 = new jpvmBuffer ( );

        jpvmInfo.jpvm.pvm_send (
          buf2,
          jpvmInfo.parentJpvmTaskId,
          NetMessageTag.trialDone);

        // testNet.p.println("reset done");
      }
    }

    public static void  processSendSpike (
      final Object       lock,
      final jpvmMessage  m,
      final Network      testNet,
      final JpvmInfo     jpvmInfo )
      throws jpvmException
    ////////////////////////////////////////////////////////////////////////
    {
      synchronized ( lock )
      {
        int  sourceID = m.buffer.upkint ( );

        int  trialID = m.buffer.upkint ( );

        if ( trialID == testNet.countTrial )
        {
          // testNet.p.println ( "received and processed" );

          testNet.spikeState = true;

          ( testNet.received [ sourceID ] )++;

          for ( int iter = 0; iter < jpvmInfo.numTasks; iter++ )
          {
            if ( iter != jpvmInfo.idIndex
              && testNet.received [ iter ] == 0 )
            {
              testNet.spikeState = false;
            }
          }

          final SpikeBuffer
            sbuff = ( SpikeBuffer ) m.buffer.upkcnsobj ( );

          final Iterator<NetMessage>  iter = sbuff.buff.iterator ( );

          while ( iter.hasNext ( ) )
          {
            final NetMessage  message = iter.next ( );
           
            final Integer  fromInteger = Integer.valueOf ( message.from );

            // try {
           
            for ( final Branch  branch
              : testNet.axons.get ( fromInteger ).branches )
            {
              testNet.getInputEventSlot ( ).offer (
                new InputEvent (
                  message.time + branch.delay,
                  branch,
                  message.from ) ); //new input events
            }
           
            // }
            // catch(Exception ex)
            // {
            //   throw new RuntimeException (
            //     ex.getMessage()+"\n from:"+message.from
            //       +" host id:"+info.idIndex+" axons"
            //       +testNet.axons.size());
            // }
          }

          lock.notify ( );
        }
        /*
        else
        {
          // testNet.p.println("received and ignored");
        }
        */
      }
    }
   
    public static void  processStopSig (
      final Network   testNet,
      final JpvmInfo  jpvmInfo,
      final Object    lock )
      throws jpvmException
    ////////////////////////////////////////////////////////////////////////
    {
      testNet.p.println (
        "get stop message root " + testNet.rootTime );

      testNet.p.flush ( );
     
      final jpvmBuffer  buf = new jpvmBuffer ( );

      buf.pack ( testNet.recorderData );

      buf.pack ( jpvmInfo.idIndex );

      jpvmInfo.jpvm.pvm_send (
        buf,
        jpvmInfo.tids [ jpvmInfo.idIndex ],
        NetMessageTag.getBackData );
     
      testNet.clearQueues ( );

      testNet.stop = true;

      testNet.p.close ( );

      synchronized ( lock )
      {
        lock.notify ( );
      }
    }
   
    public static void  processSyncRoot (
      final Object       lock,
      final Network      testNet,
      final jpvmMessage  m )
      throws jpvmException
    ////////////////////////////////////////////////////////////////////////
    {
      synchronized ( lock )
      {
        if ( !testNet.trialDone )
        {
          testNet.rootTime = m.buffer.upkdouble ( );

          lock.notify ( );
        }
      }
    }
               
    public static void  processTempStopSig (
      final Network  testNet,
      final JpvmInfo  jpvmInfo )
      throws jpvmException
    ////////////////////////////////////////////////////////////////////////
    {
      final jpvmBuffer  buf = new jpvmBuffer ( );
               
      buf.pack ( testNet.recorderData );

      buf.pack ( jpvmInfo.idIndex );

      jpvmInfo.jpvm.pvm_send (
        buf,
        jpvmInfo.tids [ jpvmInfo.idIndex ],
        NetMessageTag.getBackData );

      testNet.recorderData.clear ( );
    }
   
    public static void  processTrialDone (
      final Object         lock,
      final CyclicBarrier  barrier,
      final jpvmMessage    m,
      final Network        testNet )
      throws jpvmException, InterruptedException, BrokenBarrierException
    ////////////////////////////////////////////////////////////////////////
    {

      // testNet.p.println("begin notifying");
      //
      // System.out.println(
      //   "start"+testNet.subExpId+" "+testNet.trialId);

      synchronized ( lock )
      {
        barrier.await ( );

        testNet.trialId = m.buffer.upkint ( );

        testNet.subExpId = m.buffer.upkint ( );

        testNet.endOfTrial
          = testNet.experiment.subExp [ testNet.subExpId ].trialLength;

        Thread.yield ( );

        Thread.yield ( );

        // synchronized (synLock)
        // {
        //   synLock.notify();
        //
        //   testNet.startSig=true;
        // }

        lock.notify ( );
      }
    }
   
    ////////////////////////////////////////////////////////////////////////
    ////////////////////////////////////////////////////////////////////////
    }
TOP

Related Classes of cnslab.cnsnetwork.NetHost

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.