package cnslab.cnsnetwork;
import java.io.ByteArrayInputStream;
import java.io.FileOutputStream;
import java.io.PrintStream;
import java.util.Iterator;
import java.util.Stack;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import org.w3c.dom.Document;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.w3c.dom.ls.DOMImplementationLS;
import org.w3c.dom.ls.LSInput;
import org.w3c.dom.ls.LSParser;
import org.w3c.dom.bootstrap.DOMImplementationRegistry;
import jpvm.*;
import cnslab.cnsmath.*;
// TODO: Refactor in the same way NetHost was refactored.
/***********************************************************************
* Same as NetHost but deals with avalanche
*
* @version
* $Date: 2012-08-04 20:43:22 +0200 (Sat, 04 Aug 2012) $
* $Rev: 104 $
* $Author: croft $
* @author
* Yi Dong
* @author
* David Wallace Croft
***********************************************************************/
public final class ANetHost
////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////
{

  /**
   * JPVM state of this host: the environment handle, the task ids of the
   * peers, and the index of this host among them. Populated by
   * {@link #main(String[])}.
   */
  public static JpvmInfo info;

  /**
   * Entry point of an avalanche-aware network host.
   *
   * <p>The method performs, in order:</p>
   * <ol>
   * <li>enrolls in JPVM and receives the peer task ids, the seed and the
   *     simulation XML (message tag {@code sendTids});</li>
   * <li>parses the XML with {@link SimulatorParser} and builds an
   *     {@code ANetwork};</li>
   * <li>starts a daemon thread running {@code PRun} and reports
   *     readiness ({@code readySig});</li>
   * <li>dispatches incoming JPVM messages by tag until
   *     {@code testNet.stop} becomes true.</li>
   * </ol>
   *
   * <p>Any exception ends the method; its stack trace is written to
   * {@code log/<host>error.txt}, or to standard error when that file
   * cannot be written.</p>
   *
   * @param args command line arguments (not used)
   */
  public static void main(String args[])
  ////////////////////////////////////////////////////////////////////////
  {
    try {
      int seedInt;

      info = new JpvmInfo();

      // Enroll in the parallel virtual machine...
      info.jpvm = new jpvmEnvironment();
      info.myJpvmTaskId = info.jpvm.pvm_mytid();

      // Receive information about the peers: task ids, end indices, seed.
      jpvmMessage m = info.jpvm.pvm_recv(NetMessageTag.sendTids);
      info.numTasks = m.buffer.upkint();
      info.tids = new jpvmTaskId[info.numTasks];
      info.endIndex = new int[info.numTasks];
      m.buffer.unpack(info.tids, info.numTasks, 1);
      m.buffer.unpack(info.endIndex, info.numTasks, 1);
      seedInt = m.buffer.upkint();

      // Locate this task in the peer list; its position is the host index.
      int iter_int;
      for (iter_int = 0; iter_int < info.numTasks; iter_int++)
      {
        if (info.myJpvmTaskId.equals(info.tids[iter_int]))
        {
          break;
        }
      }
      if (iter_int >= info.numTasks)
      {
        // Without this check the store into info.tids below would fail
        // with an unexplained ArrayIndexOutOfBoundsException.
        throw new IllegalStateException(
          "task id " + info.myJpvmTaskId
          + " not found among the " + info.numTasks + " received peer ids");
      }
      info.idIndex = iter_int;
      info.parentJpvmTaskId = m.buffer.upktid(); // trialHost (original note)

      // The own id does not need to be stored, so the slot is reused for
      // the task returned by pvm_parent(). From here on
      // info.tids[info.idIndex] is the destination of readySig,
      // getBackData and checkTime replies.
      info.tids[info.idIndex] = info.jpvm.pvm_parent();

      // The simulation description arrives as a length-prefixed XML
      // byte array in the same message.
      byte [] ba;
      int baLength;
      baLength = m.buffer.upkint();
      ba = new byte[baLength];
      m.buffer.unpack(ba, baLength, 1);

      DOMImplementationRegistry registry
        = DOMImplementationRegistry.newInstance();
      DOMImplementationLS impl
        = (DOMImplementationLS) registry.getDOMImplementation("LS");
      LSInput input = impl.createLSInput();
      input.setByteStream(new ByteArrayInputStream(ba));
      LSParser parser
        = impl.createLSParser(DOMImplementationLS.MODE_SYNCHRONOUS, null);

      SimulatorParser pas = new SimulatorParser(
        new Seed(seedInt - info.idIndex),
        parser.parse(input));

      // The order of the parse steps is kept exactly as in the original.
      pas.parseMapCells(info.idIndex);
      pas.parseNeuronDef();
      pas.parsePopNeurons(info.idIndex);
      pas.parseScaffold(info.idIndex);
      pas.layerStructure.buildStructure(info.idIndex);
      pas.parseConnection(info.idIndex);
      pas.parseTarget(info);
      // Original note: sorting the synapses (sortSynapses) is not
      // necessary unless the parameters are being tuned.
      pas.parseExp(info.idIndex);
      pas.findMinDelay();

      // Two parties: this thread (see trialDone below) and, presumably,
      // the PRun thread that receives the barrier.
      CyclicBarrier barrier = new CyclicBarrier(2);

      Seed idum = new Seed(seedInt - info.idIndex);

      if (pas.layerStructure.axons == null)
      {
        throw new RuntimeException("no axon info");
      }

      final Network testNet = new ANetwork (
        pas.getModelFactory ( ),
        pas.getDiscreteEventQueue ( ),
        pas.getModulatedSynapseSeq ( ),
        pas.layerStructure.neurons,
        pas.layerStructure.axons,
        info,
        pas.layerStructure.base,
        pas.minDelay,
        pas,
        idum,
        pas.experiment );

      testNet.initNet();

      testNet.p.println("Host id " + info.idIndex + "\n");
      testNet.p.println("base " + testNet.base + "\n");
      testNet.p.println("idum " + idum.seed + "\n");
      testNet.p.println("NumOfSyn " + pas.layerStructure.numSYN + "\n");

      // Monitor shared with PRun; every state change made on behalf of a
      // message is followed by lock.notify().
      Object lock = new Object();

      PRun p1 = new PRun (
        testNet.getDiscreteEventQueue ( ),
        testNet,
        lock,
        barrier );

      System.gc();

      Thread run = new Thread(p1);
      run.setDaemon(true);
      run.start();

      // Barrier sync: announce that this host is ready.
      sendReady();

      jpvmBuffer buf;

      while (!testNet.stop)
      {
        m = testNet.info.jpvm.pvm_recv(); // receive info from others

        if (m.messageTag == NetMessageTag.trialDone)
        {
          testNet.trialDone = false;
        }

        // Messages with any other tag are ignored (no default branch).
        switch (m.messageTag)
        {
          case NetMessageTag.sendSpike:
            synchronized (lock)
            {
              int sourceID = m.buffer.upkint();
              int trialID = m.buffer.upkint();

              // Spikes that belong to another trial are dropped.
              if (trialID == testNet.countTrial)
              {
                // spikeState stays true only when every other host has
                // a non-zero received counter.
                testNet.spikeState = true;
                (testNet.received[sourceID])++;
                for (int iter = 0; iter < info.numTasks; iter++)
                {
                  if (iter != info.idIndex && testNet.received[iter] == 0)
                  {
                    testNet.spikeState = false;
                  }
                }

                SpikeBuffer sbuff = (SpikeBuffer) m.buffer.upkcnsobj();
                Iterator<NetMessage> iter = sbuff.buff.iterator();
                while (iter.hasNext())
                {
                  NetMessage message = iter.next();

                  // One input event per branch of the source axon,
                  // delayed by that branch's delay.
                  for (int ii = 0;
                    ii < testNet.axons.get(message.from).branches.length;
                    ii++)
                  {
                    testNet.getInputEventSlot ( ).offer (
                      new AInputEvent (
                        message.time
                          + testNet.axons.get ( message.from )
                          .branches [ ii ].delay,
                        testNet.axons.get ( message.from )
                          .branches [ ii ],
                        message.from,
                        ( ( ANetMessage ) message ).sourceId,
                        ( ( ANetMessage ) message ).avalancheId ) );
                  }
                }
                lock.notify();
              }
            }
            break;

          case NetMessageTag.syncRoot: // a message about time
            synchronized (lock)
            {
              if (!testNet.trialDone)
              {
                testNet.rootTime = m.buffer.upkdouble();
                lock.notify();
              }
            }
            break;

          case NetMessageTag.stopSig:
            // Final stop: send back the recorder and avalanche data,
            // then leave the message loop.
            testNet.p.println("get stop message root " + testNet.rootTime);
            testNet.p.flush();
            buf = new jpvmBuffer();
            buf.pack(testNet.recorderData);
            buf.pack(((ANetwork) testNet).aData);
            testNet.p.println("Size of couting "
              + ((ANetwork) testNet).aData.avalancheCounter.size());
            testNet.p.flush();
            info.jpvm.pvm_send(
              buf, info.tids[info.idIndex], NetMessageTag.getBackData);
            testNet.clearQueues ( );
            testNet.stop = true;
            testNet.p.close();
            synchronized (lock)
            {
              lock.notify();
            }
            break;

          case NetMessageTag.tempStopSig:
            // Intermediate flush: send the recorder data plus this
            // host's index, then clear the recorder.
            buf = new jpvmBuffer();
            buf.pack(testNet.recorderData);
            buf.pack(info.idIndex);
            info.jpvm.pvm_send(
              buf, info.tids[info.idIndex], NetMessageTag.getBackData);
            testNet.recorderData.clear();
            break;

          case NetMessageTag.trialDone: // a new trial begins
            synchronized (lock)
            {
              testNet.trialId = m.buffer.upkint();
              testNet.subExpId = m.buffer.upkint();
              testNet.endOfTrial
                = testNet.experiment.subExp[testNet.subExpId].trialLength;
              Thread.yield();
              Thread.yield();
              // NOTE(review): the barrier is awaited while holding the
              // lock, exactly as in the original; confirm that PRun never
              // needs the lock before it reaches the barrier.
              barrier.await();
              lock.notify();
            }
            break;

          case NetMessageTag.changeConnection:
          {
            // A new seed and a new XML document arrive; only the
            // "Connections" element of the current document is replaced.
            seedInt = m.buffer.upkint();
            baLength = m.buffer.upkint();
            ba = new byte[baLength];
            m.buffer.unpack(ba, baLength, 1);
            input = impl.createLSInput();
            input.setByteStream(new ByteArrayInputStream(ba));
            parser = impl.createLSParser(
              DOMImplementationLS.MODE_SYNCHRONOUS, null);

            NodeList conns
              = pas.rootElement.getElementsByTagName("Connections");
            testNet.p.println("connection num" + conns.getLength());
            pas.rootElement.removeChild(conns.item(0));
            Node dup = pas.document.importNode(
              parser.parse(input).getDocumentElement()
                .getElementsByTagName("Connections").item(0),
              true);
            pas.rootElement.appendChild(dup);

            double weight;
            weight = pas.parseChangeConnection(info.idIndex);
            testNet.seed = new Seed(seedInt - info.idIndex);

            // A negative return value is treated as success; any other
            // value is reported back as a bad weight.
            buf = new jpvmBuffer();
            if (weight < 0)
            {
              buf.pack("NetHost " + info.jpvm.pvm_mytid().getHost()
                + " has been changed");
            }
            else
            {
              buf.pack("badweight$" + weight);
            }
            info.jpvm.pvm_send(
              buf, info.tids[info.idIndex], NetMessageTag.readySig);

            if (weight < 0)
            {
              sendEmptyTrialDone();
            }
            break;
          }

          case NetMessageTag.netHostNotify:
            sendReady();
            sendEmptyTrialDone();
            break;

          case NetMessageTag.resetNetHost:
            synchronized (lock)
            {
              testNet.trialDone = true;
              testNet.spikeState = true;
              for (int iter = 0; iter < testNet.info.numTasks; iter++)
              {
                testNet.received[iter] = 1; // leave mark here
              }

              // Drop everything recorded or queued for the old trial and
              // advance the trial counter, so that late spikes of the old
              // trial are ignored by the sendSpike branch.
              testNet.recorderData.clear();
              testNet.clearQueues ( );
              testNet.countTrial++;
              testNet.recordBuff.buff.clear();
              testNet.intraRecBuffers.init();
              for (int i = 0; i < info.numTasks; i++)
              {
                testNet.spikeBuffers[i].buff.clear();
              }

              lock.notify();

              // Discard every message that is already pending.
              while (info.jpvm.pvm_probe())
              {
                m = testNet.info.jpvm.pvm_recv(); // clear buffer
              }

              sendReady();
              sendEmptyTrialDone();
            }
            break;

          case NetMessageTag.checkTime:
            // Reply with the int taken from the request (original note:
            // the available host id), the sub-experiment id, PRun's
            // minTime and the trial id.
            buf = new jpvmBuffer();
            buf.pack(m.buffer.upkint());
            buf.pack(testNet.subExpId);
            buf.pack(p1.minTime);
            buf.pack(testNet.trialId);
            info.jpvm.pvm_send(
              buf, info.tids[info.idIndex], NetMessageTag.checkTime);
            break;
        }
      }

      info.jpvm.pvm_exit();
    }
    catch (jpvmException jpe)
    {
      System.out.println("Error - jpvm exception");
      logError(jpe);
    }
    catch (Exception a)
    {
      logError(a);
    }
  }

  /**
   * Sends the "ready to go" text with tag {@code readySig} to the task
   * stored in {@code info.tids[info.idIndex]}.
   *
   * @throws jpvmException if packing or sending fails
   */
  private static void sendReady()
    throws jpvmException
  ////////////////////////////////////////////////////////////////////////
  {
    jpvmBuffer buf = new jpvmBuffer();
    buf.pack("NetHost " + info.jpvm.pvm_mytid().toString()
      + " is ready to go");
    info.jpvm.pvm_send(
      buf, info.tids[info.idIndex], NetMessageTag.readySig);
  }

  /**
   * Sends an empty message with tag {@code trialDone} to
   * {@code info.parentJpvmTaskId}.
   *
   * @throws jpvmException if sending fails
   */
  private static void sendEmptyTrialDone()
    throws jpvmException
  ////////////////////////////////////////////////////////////////////////
  {
    jpvmBuffer buf = new jpvmBuffer();
    info.jpvm.pvm_send(
      buf, info.parentJpvmTaskId, NetMessageTag.trialDone);
  }

  /**
   * Writes the stack trace of {@code error} to
   * {@code log/<host>error.txt}.
   *
   * <p>When the log file cannot be produced, for instance because JPVM
   * was never initialized or the {@code log} directory is missing, the
   * trace goes to standard error instead of being silently discarded.</p>
   *
   * @param error the exception that terminated {@link #main(String[])}
   */
  private static void logError(final Throwable error)
  ////////////////////////////////////////////////////////////////////////
  {
    PrintStream p = null;
    try
    {
      FileOutputStream out = new FileOutputStream(
        "log/" + info.jpvm.pvm_mytid().getHost() + "error.txt");
      p = new PrintStream(out);
      error.printStackTrace(p);
    }
    catch (Exception ex)
    {
      System.err.println(
        "ANetHost: unable to write the error log file: " + ex);
      error.printStackTrace();
    }
    finally
    {
      if (p != null)
      {
        p.close();
      }
    }
  }
}