package cnslab.cnsnetwork;
import java.io.*;
import java.util.*;
import java.util.concurrent.*;
import org.w3c.dom.*;
import org.w3c.dom.ls.*;
import org.w3c.dom.bootstrap.*;
import jpvm.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import cnslab.cnsmath.*;
import edu.jhu.mb.ernst.net.PeerInfo;
/***********************************************************************
* NetHost class.
*
* It first initializes itself and prepares the jpvm environment,
* populates the neurons, builds up the network structure,
* starts the executing thread, and then listens for messages from the
* trial host and the main host.
*
* @version
* $Date: 2012-08-04 20:43:22 +0200 (Sat, 04 Aug 2012) $
* $Rev: 104 $
* $Author: croft $
* @author
* Yi Dong
* @author
* David Wallace Croft
***********************************************************************/
public class NetHost
////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////
{

private static final Class<NetHost>
  CLASS = NetHost.class;

private static final Logger
  LOGGER = LoggerFactory.getLogger ( CLASS );

////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////

/***********************************************************************
* Process entry point for a NetHost.
*
* Enrolls in the jpvm and blocks until a message tagged
* {@code NetMessageTag.sendTids} arrives. It builds a {@link PeerInfo}
* from that message, runs {@link #launch} and finally calls
* {@code pvm_exit}.
*
* Any exception is logged. When the jpvm environment was created, the
* stack trace is also written to {@code log/<host>error.txt}. In that
* case {@code pvm_exit} is not called, as before.
*
* @param args
*   ignored
***********************************************************************/
public static void main ( final String [ ] args )
////////////////////////////////////////////////////////////////////////
{
  jpvmEnvironment jpvmEnvironmentInstance = null;

  try
  {
    // Enroll in the parallel virtual machine.

    jpvmEnvironmentInstance = new jpvmEnvironment ( );

    // Receive information about the peers.
    // This call blocks until the network message arrives.

    final jpvmMessage message
      = jpvmEnvironmentInstance.pvm_recv ( NetMessageTag.sendTids );

    final PeerInfo peerInfo = new PeerInfo ( message );

    launch ( jpvmEnvironmentInstance, peerInfo );

    jpvmEnvironmentInstance.pvm_exit ( );
  }
  catch ( final jpvmException jpe )
  {
    System.out.println ( "Error - jpvm exception" );

    reportFatalError ( jpvmEnvironmentInstance, jpe );
  }
  catch ( final Exception ex )
  {
    reportFatalError ( jpvmEnvironmentInstance, ex );
  }
}

/***********************************************************************
* Sets up this host's part of the network and serves messages until
* stopped.
*
* The steps are:
*   parse the model and build the {@link Network};
*   start the {@link PRun} computation on a daemon thread;
*   send a readySig message;
*   dispatch incoming messages by tag until {@code testNet.stop}
*   becomes true.
*
* @param jpvmEnvironmentInstance
*   the enrolled jpvm environment
* @param peerInfo
*   peer data received from the spawning host (seed, model bytes, tids)
* @throws Exception
*   on parse, network, barrier or messaging failure
***********************************************************************/
public static void launch (
  final jpvmEnvironment jpvmEnvironmentInstance,
  final PeerInfo        peerInfo )
  throws Exception, jpvmException
////////////////////////////////////////////////////////////////////////
{
  final JpvmInfo jpvmInfo = new JpvmInfo ( );

  jpvmInfo.jpvm = jpvmEnvironmentInstance;

  jpvmInfo.myJpvmTaskId = jpvmEnvironmentInstance.pvm_mytid ( );

  // NOTE(review): idIndex, tids, numTasks and parentJpvmTaskId are read
  // below and presumably filled in here -- confirm in PeerInfo.

  peerInfo.populateJpvmInfo ( jpvmInfo );

  final int seedInt = peerInfo.seedInt;

  final DOMImplementationRegistry
    registry = DOMImplementationRegistry.newInstance ( );

  final DOMImplementationLS impl
    = ( DOMImplementationLS ) registry.getDOMImplementation ( "LS" );

  final SimulatorParser simulatorParser = createSimulatorParser (
    impl,
    peerInfo.byteArray,
    seedInt,
    jpvmInfo );

  // Two parties: this thread (in processTrialDone) and the PRun thread.

  final CyclicBarrier barrier = new CyclicBarrier ( 2 );

  final Seed idum = new Seed ( seedInt - jpvmInfo.idIndex );

  final Network testNet = new Network (
    simulatorParser.getModelFactory ( ),
    simulatorParser.getDiscreteEventQueue ( ),
    simulatorParser.getModulatedSynapseSeq ( ),
    simulatorParser.layerStructure.neurons,
    simulatorParser.layerStructure.axons,
    jpvmInfo,
    simulatorParser.layerStructure.base,
    simulatorParser.minDelay,
    simulatorParser,
    idum,
    simulatorParser.experiment );

  testNet.initNet ( );

  testNet.p.println ( "Host id " + jpvmInfo.idIndex + "\n" );

  testNet.p.println ( "base " + testNet.base + "\n" );

  testNet.p.println ( "idum " + idum.seed + "\n" );

  testNet.p.println (
    "NumOfSyn " + simulatorParser.layerStructure.numSYN + "\n" );

  // Monitor shared with the PRun thread. The handlers below mutate
  // testNet while holding it and then notify it.

  final Object lock = new Object ( );

  final PRun pRun = new PRun (
    testNet.getDiscreteEventQueue ( ),
    testNet,
    lock,
    barrier );

  System.gc ( );

  final Thread runThread = new Thread ( pRun );

  runThread.setDaemon ( true );

  runThread.start ( );

  // Barrier sync: announce that this host is ready.

  final jpvmBuffer buf = new jpvmBuffer ( );

  final String messageString = createReadyMessage ( jpvmInfo );

  buf.pack ( messageString );

  LOGGER.info ( messageString );

  jpvmTaskId
    jpvmTaskIdInstance = jpvmInfo.tids [ jpvmInfo.idIndex ];

  if ( jpvmTaskIdInstance == null )
  {
    LOGGER.warn ( "jpvmTaskIdInstance == null" );

    // TODO: This is a temporary hack for JUnit testing.

    jpvmTaskIdInstance = jpvmInfo.jpvm.pvm_mytid ( );
  }

  jpvmInfo.jpvm.pvm_send (
    buf,
    jpvmTaskIdInstance,
    NetMessageTag.readySig );

  while ( !testNet.stop )
  {
    // Block until a message of any tag arrives.

    final jpvmMessage m = testNet.info.jpvm.pvm_recv ( );

    if ( m.messageTag == NetMessageTag.trialDone )
    {
      // Clear the flag in its own critical section, before
      // processTrialDone re-acquires the lock and waits at the barrier.

      synchronized ( lock )
      {
        testNet.trialDone = false;
      }
    }

    switch ( m.messageTag )
    {
      case NetMessageTag.sendSpike:

        processSendSpike ( lock, m, testNet, jpvmInfo );

        break;

      case NetMessageTag.syncRoot:

        processSyncRoot ( lock, testNet, m );

        break;

      case NetMessageTag.stopSig:

        processStopSig ( testNet, jpvmInfo, lock );

        break;

      case NetMessageTag.tempStopSig:

        processTempStopSig ( testNet, jpvmInfo );

        break;

      case NetMessageTag.trialDone:

        // A new trial begins.

        processTrialDone ( lock, barrier, m, testNet );

        break;

      case NetMessageTag.changeConnection:

        processChangeConnection (
          m,
          impl,
          simulatorParser,
          testNet,
          jpvmInfo );

        break;

      case NetMessageTag.netHostNotify:

        processNetHostNotify ( jpvmInfo );

        break;

      case NetMessageTag.resetNetHost:

        processResetNetHost ( lock, testNet, jpvmInfo );

        break;

      case NetMessageTag.checkTime:

        processCheckTime ( m, testNet, pRun, jpvmInfo );

        break;

      default:

        // Other tags (e.g. a readySig addressed to ourselves by the
        // JUnit hack above) are intentionally ignored.

        LOGGER.debug ( "ignoring message with tag {}", m.messageTag );
    }
  }
}

/***********************************************************************
* Parses the model XML and builds this host's part of the layer
* structure.
*
* The parser is seeded with {@code seedInt - jpvmInfo.idIndex}.
*
* @param impl
*   DOM Load and Save implementation used to create the parser
* @param byteArray
*   serialized model XML
* @param seedInt
*   base random seed shared by all hosts
* @param jpvmInfo
*   jpvm data for this host; {@code idIndex} selects its partition
* @return
*   the populated parser
* @throws IllegalStateException
*   if no axon information was produced
***********************************************************************/
public static SimulatorParser createSimulatorParser (
  final DOMImplementationLS impl,
  final byte [ ]            byteArray,
  final int                 seedInt,
  final JpvmInfo            jpvmInfo )
  throws Exception, jpvmException
////////////////////////////////////////////////////////////////////////
{
  final LSInput input = impl.createLSInput ( );

  input.setByteStream ( new ByteArrayInputStream ( byteArray ) );

  final LSParser parser = impl.createLSParser (
    DOMImplementationLS.MODE_SYNCHRONOUS,
    null );

  final SimulatorParser simulatorParser = new SimulatorParser (
    new Seed ( seedInt - jpvmInfo.idIndex ),
    parser.parse ( input ) );

  final int idIndex = jpvmInfo.idIndex;

  simulatorParser.parseMapCells ( idIndex );

  simulatorParser.parseNeuronDef ( );

  simulatorParser.parsePopNeurons ( idIndex );

  simulatorParser.parseScaffold ( idIndex );

  simulatorParser.layerStructure.buildStructure ( idIndex );

  simulatorParser.parseConnection ( idIndex );

  simulatorParser.parseTarget ( jpvmInfo );

  // Sorting the synapses is only needed when tuning parameters:
  // simulatorParser.layerStructure.sortSynapses ( idIndex );

  simulatorParser.parseExp ( idIndex );

  simulatorParser.findMinDelay ( );

  if ( simulatorParser.layerStructure.axons == null )
  {
    throw new IllegalStateException ( "no axon info" );
  }

  return simulatorParser;
}

/***********************************************************************
* Answers a checkTime query.
*
* Replies to {@code jpvmInfo.tids[jpvmInfo.idIndex]} with tag checkTime.
* The reply contains, in order:
*   the int read from the query;
*   {@code testNet.subExpId};
*   {@code pRun.minTime};
*   {@code testNet.trialId}.
*
* @param m
*   the checkTime message; its next int is echoed back
* @param testNet
*   source of the sub-experiment and trial ids
* @param pRun
*   source of the minimum time
* @param jpvmInfo
*   jpvm data used to address the reply
***********************************************************************/
public static void processCheckTime (
  final jpvmMessage m,
  final Network     testNet,
  final PRun        pRun,
  final JpvmInfo    jpvmInfo )
  throws jpvmException
////////////////////////////////////////////////////////////////////////
{
  final jpvmBuffer buf = new jpvmBuffer ( );

  // Echo back the host id carried by the query.

  buf.pack ( m.buffer.upkint ( ) );

  buf.pack ( testNet.subExpId );

  buf.pack ( pRun.minTime );

  buf.pack ( testNet.trialId );

  jpvmInfo.jpvm.pvm_send (
    buf,
    jpvmInfo.tids [ jpvmInfo.idIndex ],
    NetMessageTag.checkTime );
}

/***********************************************************************
* Replaces the model's Connections element and re-parses the
* connections.
*
* Message payload: int seed, int byte length, then XML bytes.
*
* The replacement is parsed and validated before the live document is
* modified. A malformed message therefore leaves the current
* connections intact.
*
* After re-parsing, this host replies on readySig with either a
* "has been changed" message (weight &lt; 0) or {@code "badweight$<w>"}.
* On success it also sends trialDone to the parent task.
*
* @throws IllegalStateException
*   if either the current model or the new XML has no Connections
*   element
***********************************************************************/
public static void processChangeConnection (
  final jpvmMessage         m,
  final DOMImplementationLS impl,
  final SimulatorParser     simulatorParser,
  final Network             testNet,
  final JpvmInfo            jpvmInfo )
  throws jpvmException, Exception
////////////////////////////////////////////////////////////////////////
{
  final int seedInt = m.buffer.upkint ( );

  final int baLength = m.buffer.upkint ( );

  final byte [ ] ba = new byte [ baLength ];

  m.buffer.unpack ( ba, baLength, 1 );

  final LSInput input = impl.createLSInput ( );

  input.setByteStream ( new ByteArrayInputStream ( ba ) );

  final LSParser parser = impl.createLSParser (
    DOMImplementationLS.MODE_SYNCHRONOUS,
    null );

  final NodeList conns
    = simulatorParser.rootElement.getElementsByTagName ( "Connections" );

  testNet.p.println ( "connection num" + conns.getLength ( ) );

  final Node oldConnections = conns.item ( 0 );

  if ( oldConnections == null )
  {
    throw new IllegalStateException (
      "current model has no Connections element" );
  }

  final Node newConnections = parser.parse ( input ).getDocumentElement ( )
    .getElementsByTagName ( "Connections" ).item ( 0 );

  if ( newConnections == null )
  {
    throw new IllegalStateException (
      "changeConnection message has no Connections element" );
  }

  final Node dup
    = simulatorParser.document.importNode ( newConnections, true );

  simulatorParser.rootElement.removeChild ( oldConnections );

  simulatorParser.rootElement.appendChild ( dup );

  // Re-parse the connections. A negative weight signals success.

  final double weight
    = simulatorParser.parseChangeConnection ( jpvmInfo.idIndex );

  testNet.seed = new Seed ( seedInt - jpvmInfo.idIndex );

  final jpvmBuffer buf1 = new jpvmBuffer ( );

  if ( weight < 0 )
  {
    buf1.pack (
      "NetHost " + jpvmInfo.jpvm.pvm_mytid ( ).getHost ( )
      + " has been changed" );
  }
  else
  {
    buf1.pack ( "badweight$" + weight );
  }

  jpvmInfo.jpvm.pvm_send (
    buf1,
    jpvmInfo.tids [ jpvmInfo.idIndex ],
    NetMessageTag.readySig );

  if ( weight < 0 )
  {
    jpvmInfo.jpvm.pvm_send (
      new jpvmBuffer ( ),
      jpvmInfo.parentJpvmTaskId,
      NetMessageTag.trialDone );
  }
}

/***********************************************************************
* Sends a readySig message, then an empty trialDone message to the
* parent task.
***********************************************************************/
public static void processNetHostNotify (
  final JpvmInfo jpvmInfo )
  throws jpvmException
////////////////////////////////////////////////////////////////////////
{
  sendReadyAndTrialDone ( jpvmInfo );
}

/***********************************************************************
* Resets the network state for a new trial.
*
* While holding the lock, it:
*   marks the trial done;
*   marks every host as having reported;
*   clears recorded data, queues and buffers, and increments
*   countTrial;
*   notifies the lock;
*   discards all pending messages;
*   sends readySig and trialDone.
***********************************************************************/
public static void processResetNetHost (
  final Object   lock,
  final Network  testNet,
  final JpvmInfo jpvmInfo )
  throws jpvmException
////////////////////////////////////////////////////////////////////////
{
  synchronized ( lock )
  {
    testNet.trialDone = true;

    testNet.spikeState = true;

    // Mark every host as having reported so the zero-count check in
    // processSendSpike does not clear spikeState.

    for ( int iter = 0; iter < testNet.info.numTasks; iter++ )
    {
      testNet.received [ iter ] = 1;
    }

    testNet.recorderData.clear ( );

    testNet.clearQueues ( );

    testNet.countTrial++;

    testNet.recordBuff.buff.clear ( );

    testNet.intraRecBuffers.init ( );

    for ( int i = 0; i < jpvmInfo.numTasks; i++ )
    {
      testNet.spikeBuffers [ i ].buff.clear ( );
    }

    lock.notify ( );

    // Drain and discard every pending message.
    // TODO: Are these messages supposed to be ignored?

    while ( jpvmInfo.jpvm.pvm_probe ( ) )
    {
      testNet.info.jpvm.pvm_recv ( );
    }

    sendReadyAndTrialDone ( jpvmInfo );
  }
}

/***********************************************************************
* Queues the spikes in a sendSpike message as input events.
*
* Message payload: int source host id, int trial id, then a
* {@link SpikeBuffer}.
*
* Batches whose trial id differs from {@code testNet.countTrial} are
* ignored and their SpikeBuffer is left unread.
*
* Otherwise the handler:
*   increments the received count for the source host;
*   recomputes spikeState (false if any other host's count is 0);
*   offers one InputEvent per axon branch of every spike;
*   notifies the lock.
*
* @throws IllegalStateException
*   if a spike's source neuron has no axon on this host
***********************************************************************/
public static void processSendSpike (
  final Object      lock,
  final jpvmMessage m,
  final Network     testNet,
  final JpvmInfo    jpvmInfo )
  throws jpvmException
////////////////////////////////////////////////////////////////////////
{
  synchronized ( lock )
  {
    final int sourceID = m.buffer.upkint ( );

    final int trialID = m.buffer.upkint ( );

    if ( trialID != testNet.countTrial )
    {
      // The batch belongs to another trial; ignore it.

      return;
    }

    testNet.spikeState = true;

    testNet.received [ sourceID ]++;

    for ( int iter = 0; iter < jpvmInfo.numTasks; iter++ )
    {
      if ( iter != jpvmInfo.idIndex
        && testNet.received [ iter ] == 0 )
      {
        testNet.spikeState = false;

        break;
      }
    }

    final SpikeBuffer
      sbuff = ( SpikeBuffer ) m.buffer.upkcnsobj ( );

    final Iterator<NetMessage> iterator = sbuff.buff.iterator ( );

    while ( iterator.hasNext ( ) )
    {
      final NetMessage message = iterator.next ( );

      final Integer fromInteger = Integer.valueOf ( message.from );

      if ( testNet.axons.get ( fromInteger ) == null )
      {
        throw new IllegalStateException (
          "no axon for spike source " + message.from
          + " on host " + jpvmInfo.idIndex );
      }

      for ( final Branch branch
        : testNet.axons.get ( fromInteger ).branches )
      {
        testNet.getInputEventSlot ( ).offer (
          new InputEvent (
            message.time + branch.delay,
            branch,
            message.from ) );
      }
    }

    lock.notify ( );
  }
}

/***********************************************************************
* Handles the final stop signal.
*
* In order, it:
*   logs the root time to {@code testNet.p};
*   sends the recorded data and host index with tag getBackData;
*   clears the queues;
*   sets {@code testNet.stop};
*   closes {@code testNet.p};
*   notifies the lock.
***********************************************************************/
public static void processStopSig (
  final Network  testNet,
  final JpvmInfo jpvmInfo,
  final Object   lock )
  throws jpvmException
////////////////////////////////////////////////////////////////////////
{
  testNet.p.println (
    "get stop message root " + testNet.rootTime );

  testNet.p.flush ( );

  final jpvmBuffer buf = new jpvmBuffer ( );

  buf.pack ( testNet.recorderData );

  buf.pack ( jpvmInfo.idIndex );

  jpvmInfo.jpvm.pvm_send (
    buf,
    jpvmInfo.tids [ jpvmInfo.idIndex ],
    NetMessageTag.getBackData );

  testNet.clearQueues ( );

  testNet.stop = true;

  testNet.p.close ( );

  synchronized ( lock )
  {
    lock.notify ( );
  }
}

/***********************************************************************
* Updates {@code testNet.rootTime} from the message and notifies the
* lock.
*
* If the trial is already marked done, nothing is updated and the
* payload is left unread.
***********************************************************************/
public static void processSyncRoot (
  final Object      lock,
  final Network     testNet,
  final jpvmMessage m )
  throws jpvmException
////////////////////////////////////////////////////////////////////////
{
  synchronized ( lock )
  {
    if ( !testNet.trialDone )
    {
      testNet.rootTime = m.buffer.upkdouble ( );

      lock.notify ( );
    }
  }
}

/***********************************************************************
* Sends the recorded data collected so far, then clears it.
*
* The data and host index are sent with tag getBackData.
***********************************************************************/
public static void processTempStopSig (
  final Network  testNet,
  final JpvmInfo jpvmInfo )
  throws jpvmException
////////////////////////////////////////////////////////////////////////
{
  final jpvmBuffer buf = new jpvmBuffer ( );

  buf.pack ( testNet.recorderData );

  buf.pack ( jpvmInfo.idIndex );

  jpvmInfo.jpvm.pvm_send (
    buf,
    jpvmInfo.tids [ jpvmInfo.idIndex ],
    NetMessageTag.getBackData );

  testNet.recorderData.clear ( );
}

/***********************************************************************
* Starts a new trial.
*
* While holding the lock, it:
*   waits at the barrier shared with the PRun thread;
*   reads the trial id and sub-experiment id from the message;
*   sets endOfTrial from that sub-experiment;
*   notifies the lock.
*
* NOTE(review): barrier.await() runs while holding the lock. This
* presumably relies on PRun reaching the barrier without needing the
* lock -- confirm before changing either side.
***********************************************************************/
public static void processTrialDone (
  final Object        lock,
  final CyclicBarrier barrier,
  final jpvmMessage   m,
  final Network       testNet )
  throws jpvmException, InterruptedException, BrokenBarrierException
////////////////////////////////////////////////////////////////////////
{
  synchronized ( lock )
  {
    barrier.await ( );

    testNet.trialId = m.buffer.upkint ( );

    testNet.subExpId = m.buffer.upkint ( );

    testNet.endOfTrial
      = testNet.experiment.subExp [ testNet.subExpId ].trialLength;

    Thread.yield ( );

    Thread.yield ( );

    lock.notify ( );
  }
}

////////////////////////////////////////////////////////////////////////
// private helper methods
////////////////////////////////////////////////////////////////////////

/***********************************************************************
* Builds the readiness text "NetHost &lt;tid&gt; is ready to go".
***********************************************************************/
private static String createReadyMessage ( final JpvmInfo jpvmInfo )
  throws jpvmException
////////////////////////////////////////////////////////////////////////
{
  return "NetHost " + jpvmInfo.jpvm.pvm_mytid ( ).toString ( )
    + " is ready to go";
}

/***********************************************************************
* Sends the readiness text to {@code tids[idIndex]} with tag readySig,
* then an empty trialDone message to the parent task.
***********************************************************************/
private static void sendReadyAndTrialDone ( final JpvmInfo jpvmInfo )
  throws jpvmException
////////////////////////////////////////////////////////////////////////
{
  final jpvmBuffer readyBuffer = new jpvmBuffer ( );

  readyBuffer.pack ( createReadyMessage ( jpvmInfo ) );

  jpvmInfo.jpvm.pvm_send (
    readyBuffer,
    jpvmInfo.tids [ jpvmInfo.idIndex ],
    NetMessageTag.readySig );

  final jpvmBuffer doneBuffer = new jpvmBuffer ( );

  jpvmInfo.jpvm.pvm_send (
    doneBuffer,
    jpvmInfo.parentJpvmTaskId,
    NetMessageTag.trialDone );
}

/***********************************************************************
* Logs a fatal error and, if possible, writes its stack trace to
* {@code log/<host>error.txt}.
*
* When {@code jpvmEnvironmentInstance} is null (enrollment failed),
* only the log entry is written. Failure to write the file is logged
* and otherwise ignored.
***********************************************************************/
private static void reportFatalError (
  final jpvmEnvironment jpvmEnvironmentInstance,
  final Throwable       throwable )
////////////////////////////////////////////////////////////////////////
{
  LOGGER.error ( "NetHost terminated abnormally", throwable );

  if ( jpvmEnvironmentInstance == null )
  {
    return;
  }

  PrintStream printStream = null;

  try
  {
    final String fileName = "log/"
      + jpvmEnvironmentInstance.pvm_mytid ( ).getHost ( )
      + "error.txt";

    printStream = new PrintStream ( new FileOutputStream ( fileName ) );

    throwable.printStackTrace ( printStream );
  }
  catch ( final Exception ex )
  {
    // Best effort: the original failure has already been logged above.

    LOGGER.warn ( "could not write NetHost error file", ex );
  }
  finally
  {
    if ( printStream != null )
    {
      printStream.close ( );
    }
  }
}

////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////
}