package cnslab.cnsnetwork;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;
import cnslab.cnsmath.Seed;
import org.w3c.dom.Node;
import org.w3c.dom.bootstrap.DOMImplementationRegistry;
import org.w3c.dom.ls.DOMImplementationLS;
import org.w3c.dom.ls.LSSerializer;
import org.w3c.dom.ls.LSOutput;
import java.io.FileOutputStream;
import java.io.File;
import java.io.ByteArrayOutputStream;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import org.w3c.dom.DocumentType;
import jpvm.jpvmTaskId;
import cnslab.cnsnetwork.JpvmInfo;
import java.util.LinkedList;
import java.util.HashMap;
import java.util.Set;
import java.util.Map;
import jpvm.jpvmEnvironment;
import jpvm.jpvmBuffer;
import jpvm.jpvmMessage;
import jpvm.jpvmException;
import java.util.Iterator;
import java.util.Timer;
import java.util.TimerTask;
/***********************************************************************
* Deprecated.
*
* Master-side fitness evaluator.  The constructor parses a model file,
* spawns the TrialHost and NetHostTune tasks through jpvm and sends them
* the serialized model document.  {@link #fitFunction(double[])} then
* replaces the "Connections" element with one built from a parameter
* vector, dispatches the trials of every sub-experiment to the spawned
* hosts, merges the recordings they send back and asks the
* {@link ParToDoc} for the resulting fit value.
*
* @version
* $Date: 2012-08-04 20:43:22 +0200 (Sat, 04 Aug 2012) $
* $Rev: 104 $
* $Author: croft $
* @author
* Yi Dong
* @author
* David Wallace Croft
***********************************************************************/
@Deprecated
public class FitFun
{
  /** Random seed object handed to the parser and to the bootstrap fit. */
  public Seed idum;

  /** Standard deviation of the value computed by the last fitFunction call. */
  public double valueSd; // value's standard deviation

  /** Index of the most recently dispatched trial; -1 before any dispatch. */
  public int trialId = -1;

  /** Index of the sub-experiment currently being dispatched. */
  public int expId = 0;

  /** Copy of the initial seed value; seedInt is reset to it per evaluation. */
  public int saveInt;

  /** Turns parameter vectors into documents and scores the recordings. */
  public ParToDoc parDoc;

  /** Experiment description taken from the parser. */
  public Experiment exp;

  /** jpvm environment, task ids and node end indices. */
  public JpvmInfo info;

  /** Not read or written by this class. */
  public int [] endIndex;

  /** Not read or written by this class. */
  public double minDelay;

  /** Not read or written by this class. */
  public double backFire;

  /** Running seed value; decremented before each seed is sent out. */
  public int seedInt;

  /** netTids[i] holds the NetHostTune task ids spawned for trial host i. */
  public jpvmTaskId[][] netTids;

  /** Intracellular recordings, indexed by element + subExp * elementCount. */
  public LinkedList[] intraReceiver ;

  /** Recorder data accumulated during the current evaluation. */
  public RecorderData rdata;

  /** Parser holding the model document that is edited and re-serialized. */
  public SimulatorParser pas;

  /**
   * Parses the model, spawns all hosts and waits until they report ready.
   *
   * Any exception raised during set-up is printed and swallowed (as in
   * the original implementation), so the instance may be only partially
   * initialized afterwards.
   *
   * @param parDoc        converter used later by fitFunction
   * @param idum          seed; its {@code seed} field initializes seedInt
   * @param modelFilename file name resolved below the "model/" directory
   * @param heapSize      passed to pvm_spawn for each NetHostTune task
   */
  public FitFun(ParToDoc parDoc, Seed idum, String modelFilename, int heapSize )
  {
    try {
      this.parDoc = parDoc;
      this.seedInt = idum.seed;
      this.saveInt = idum.seed;
      this.idum = idum;

      // Parse the model and serialize it for shipping to the net hosts.
      pas = new SimulatorParser (
        idum, new File ( "model/" + modelFilename ) );
      pas.parseMapCells();
      pas.parseExperiment();
      exp = pas.experiment; // experiment information
      pas.findMinDelay();
      byte [] ba = serializeModel();

      info = new JpvmInfo();
      // Enroll in the parallel virtual machine...
      info.jpvm = new jpvmEnvironment();
      // Get my task id...
      info.myJpvmTaskId = info.jpvm.pvm_mytid();
      System.out.println("Task Id: "+info.myJpvmTaskId.toString());
      info.numTasks = pas.parallelHost; // total number of trial hosts
      info.idIndex = info.numTasks;     // root id; not used at all
      info.tids = new jpvmTaskId[info.numTasks];

      // Spawn the trial hosts and tell each of them about all the others.
      info.jpvm.pvm_spawn("cnslab.cnsnetwork.TrialHost",info.numTasks,info.tids,48);
      System.out.println("spawn successfully");
      jpvmBuffer tidBuf = new jpvmBuffer();
      tidBuf.pack(info.numTasks);
      tidBuf.pack(info.tids, info.numTasks, 1);
      tidBuf.pack(pas.minDelay);
      info.jpvm.pvm_mcast(tidBuf,info.tids,info.numTasks,NetMessageTag.sendTids);
      System.out.println("All sent");

      // Generate the net hosts: one group per trial host, one task per
      // entry of the node end index array.
      info.endIndex = pas.layerStructure.nodeEndIndices;
      int groupSize = info.endIndex.length;
      netTids = new jpvmTaskId [info.numTasks][groupSize] ;
      for (int i = 0; i < info.numTasks; i++)
      {
        System.out.println("generate child for trialHost "+i);
        info.jpvm.pvm_spawn("cnslab.cnsnetwork.NetHostTune",groupSize,netTids[i],heapSize);
        jpvmBuffer buf = new jpvmBuffer();
        buf.pack(groupSize);
        buf.pack(netTids[i],groupSize,1);
        buf.pack(info.endIndex,groupSize,1);
        seedInt = seedInt - groupSize;
        buf.pack(seedInt);
        buf.pack(info.tids[i]); // parent's tid
        buf.pack(ba.length);
        buf.pack(ba, ba.length, 1);
        info.jpvm.pvm_mcast(buf,netTids[i],groupSize,NetMessageTag.sendTids);
      }

      // Send every trial host the task ids of its children.
      for (int i = 0; i < info.numTasks; i++)
      {
        jpvmBuffer buf = new jpvmBuffer();
        buf.pack(groupSize);
        buf.pack(netTids[i],groupSize,1);
        info.jpvm.pvm_send(buf,info.tids[i],NetMessageTag.sendTids2);
      }

      // Barrier sync: one ready message per net host plus one per trial host.
      awaitReadySignals(info.numTasks*groupSize+info.numTasks);
    }
    catch(jpvmException ex) {
      ex.printStackTrace();
    }
    catch(Exception ap)
    {
      ap.printStackTrace();
    }
  }

  /**
   * Waits for a trialReady message from every trial host and answers each
   * one with a stopSig.  A jpvmException is printed and ends the loop.
   */
  public void closeHosts() {
    try {
      for (int i = 0; i < info.numTasks; i++) {
        jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.trialReady);
        int freeId = message.buffer.upkint();
        System.out.println("Trial Host "+freeId+" is killed");
        jpvmBuffer buf = new jpvmBuffer();
        buf.pack(0);
        info.jpvm.pvm_send(buf,info.tids[freeId],NetMessageTag.stopSig);
      }
    }
    catch(jpvmException ex) {
      ex.printStackTrace();
    }
  }

  /**
   * One-shot watchdog.  After the given delay it checks whether trialId
   * still equals {@code netTids.length-1} and, if so, sends a checkTime
   * message to the first trial host.
   */
  public class ToDo {
    Timer timer;

    /**
     * @param seconds delay before the check runs
     */
    public ToDo ( int seconds ) {
      timer = new Timer ( ) ;
      // Multiply as long so that large delays cannot overflow int.
      timer.schedule ( new ToDoTask ( ) , seconds*1000L) ;
    }

    class ToDoTask extends TimerTask {
      /**
       * Runs on the timer thread; reads trialId without synchronization.
       */
      @Override
      public void run ( ) {
        if(trialId == netTids.length-1) // no new trial is finished
        {
          jpvmBuffer buf = new jpvmBuffer();
          try {
            // check first trial host time
            info.jpvm.pvm_send(buf,info.tids[0],NetMessageTag.checkTime);
          }
          catch(jpvmException ex) {
            ex.printStackTrace();
            System.exit(-1);
          }
        }
        timer.cancel() ;
      }
    }

    /** Cancels the watchdog timer. */
    public void stop()
    {
      timer.cancel() ;
    }
  }

  /**
   * Evaluates the fit value for one parameter vector.
   *
   * The "Connections" element of the model is replaced by the document
   * produced by {@code parDoc.getDocument(para)}, the new model is sent to
   * every net host group, and all trials are dispatched.  Three outcomes
   * are possible:
   * <ul>
   * <li>a ready message contains '$': the largest number following it is
   *     returned and no simulation is run;</li>
   * <li>a checkTime message arrives: the simulation is aborted and
   *     {@code (100 - percentDone) / 100} is returned;</li>
   * <li>otherwise the value of {@code parDoc.getFitValue} is returned and
   *     valueSd is set from 200 calls to {@code parDoc.getRanFitValue}.</li>
   * </ul>
   * Any exception is printed and terminates the JVM with exit code -1.
   *
   * @param para parameter vector passed to {@code parDoc.getDocument}
   * @return the fit value as described above
   */
  public double fitFunction(double[] para) {
    double tmpValue=0.0;
    try {
      // initialize the accumulators for this evaluation
      intraReceiver = new LinkedList[exp.recorder.intraEle.size()*exp.subExp.length];
      rdata = new RecorderData();

      // first change the connections for the netHost
      Document newDoc = parDoc.getDocument(para);
      NodeList conns = pas.rootElement.getElementsByTagName("Connections");
      System.out.println("connections num"+conns.getLength());
      pas.rootElement.removeChild(conns.item(0));
      Node dup = pas.document.importNode(newDoc.getDocumentElement() , true);
      pas.rootElement.appendChild(dup);
      pas.document.normalizeDocument(); // expand everything like save and load
      byte [] ba = serializeModel();

      // Hand the new model to one net host group per trialReady message.
      seedInt = saveInt;
      for (int i = 0; i < info.numTasks; i++) {
        jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.trialReady);
        int freeId = message.buffer.upkint();
        System.out.println("Trial Host "+freeId+" change connection");
        seedInt = seedInt - info.endIndex.length;
        jpvmBuffer buf = new jpvmBuffer();
        buf.pack(seedInt);
        buf.pack(ba.length);
        buf.pack(ba, ba.length, 1);
        // NOTE(review): the group is selected by loop index, not by freeId;
        // confirm that this is intended.
        info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,NetMessageTag.changeConnection);
      }

      trialId = -1;
      expId = 0;
      int aliveHost = info.numTasks;
      int totalTrials = 0;
      for (int j = 0; j < exp.subExp.length; j++)
      {
        totalTrials += exp.subExp[j].repetition;
      }
      totalTrials = totalTrials * pas.numOfHosts;
      boolean stop = false;
      boolean badweight = false;

      // Barrier sync.  A '$' in a ready message marks a bad weight and is
      // followed by a number; the largest such number becomes the result.
      for (int i = 0; i < info.numTasks*info.endIndex.length; i++) {
        jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.readySig);
        String str = message.buffer.upkstr();
        int posi = str.indexOf("$");
        if (posi >= 0)
        {
          badweight = true;
          double tmpval = Double.parseDouble(str.substring(posi+1));
          if (tmpval > tmpValue) {
            tmpValue = tmpval;
            valueSd = 0.0;
          }
        }
        System.out.println(str);
      }

      if (!badweight)
      {
        System.out.println("************ simulation is starting *********************");
        double per = 0.0;
        boolean overflow = false;
        int countHosts = 0;
        ToDo toDo = new ToDo(90); // 90 second threshold

        // listening and processing
        while (!stop)
        {
          jpvmMessage m = info.jpvm.pvm_recv(); // receive info from others
          switch (m.messageTag)
          {
            case NetMessageTag.checkTime: // received time and reset netHosts
              m.buffer.upkint(); // host id; unpacked only to advance the buffer
              int eId = m.buffer.upkint();
              double root_time = m.buffer.upkdouble();
              per = root_time/exp.subExp[eId].trialLength*100;
              stop = true; // can stop now
              overflow = true;
              multicastToAllNetHosts(NetMessageTag.resetNetHost);
              System.out.println("Overflow detected");
              break;

            case NetMessageTag.trialReady:
              int freeId = m.buffer.upkint(); // get free host id
              trialId++;
              // Move on to the next sub-experiment once all repetitions of
              // the current one are dispatched; trialId restarts at 0
              // unless this was the last sub-experiment.
              if (expId < exp.subExp.length
                && trialId == exp.subExp[expId].repetition)
              {
                if (expId+1 != exp.subExp.length)
                {
                  trialId = 0;
                }
                expId++;
              }
              if (expId < exp.subExp.length) // game is still on
              {
                System.out.println("Subexp "+expId+" trial "+ trialId+" freeId "+freeId);
                jpvmBuffer buf = new jpvmBuffer();
                buf.pack(trialId);
                buf.pack(expId);
                info.jpvm.pvm_send(buf,info.tids[freeId],NetMessageTag.oneTrial);
              }
              else if (aliveHost != 0) // all work done, some hosts not stopped
              {
                System.out.println("host "+freeId+" finished his job");
                jpvmBuffer buf = new jpvmBuffer();
                buf.pack(0);
                info.jpvm.pvm_send(buf,info.tids[freeId],NetMessageTag.tempStopSig);
              }
              break;

            case NetMessageTag.getBackData:
              // One trial host counts as done after a full group has reported.
              countHosts++;
              if (countHosts == info.endIndex.length)
              {
                aliveHost--;
                countHosts = 0;
              }
              mergeRecorderData((RecorderData) m.buffer.upkcnsobj());
              break;

            case NetMessageTag.trialDone:
              totalTrials--;
              m.buffer.upkint(); // trial id; unpacked only to advance the buffer
              int res_exp = m.buffer.upkint();
              mergeIntraRecords((IntraRecBuffer) m.buffer.upkcnsobj(), res_exp);
              break;
          }
          if (aliveHost == 0 && totalTrials == 0)
          {
            stop = true;
          }
        }
        toDo.stop();

        if (overflow)
        {
          tmpValue = (100.0-per)/100.0;
          valueSd = 0.0;
        }
        else
        {
          tmpValue = parDoc.getFitValue(pas,intraReceiver,rdata);
          double [] bootVal = new double [200]; // 200 times to get Sd.
          for (int i = 0; i < bootVal.length; i++)
          {
            bootVal[i] = parDoc.getRanFitValue(pas,intraReceiver,rdata,idum);
          }
          valueSd = FunUtil.sd(bootVal);
          multicastToAllNetHosts(NetMessageTag.netHostNotify);
        }
      }
      else
      {
        multicastToAllNetHosts(NetMessageTag.netHostNotify);
      }

      // Barrier sync before the next fitting evaluation can start.
      awaitReadySignals(info.numTasks*info.endIndex.length);
      System.out.println("cost:"+tmpValue+" Sd:"+valueSd);
    }
    // The narrower exception type must come first; the reverse order makes
    // the second clause unreachable when jpvmException extends Exception.
    catch(jpvmException ex) {
      ex.printStackTrace();
      System.exit(-1);
    }
    catch(Exception ap)
    {
      ap.printStackTrace();
      System.exit(-1);
    }
    return tmpValue;
  }

  /**
   * Serializes pas.document with the DOM "LS" serializer.  The document
   * type node is removed before writing and appended again afterwards.
   *
   * @return the serialized document bytes
   */
  private byte [] serializeModel()
    throws ClassNotFoundException, InstantiationException, IllegalAccessException
  {
    pas.document.removeChild(pas.documentType);
    DOMImplementationRegistry registry = DOMImplementationRegistry.newInstance();
    DOMImplementationLS impl = (DOMImplementationLS) registry.getDOMImplementation("LS");
    LSSerializer writer = impl.createLSSerializer();
    LSOutput output = impl.createLSOutput();
    ByteArrayOutputStream bArray = new ByteArrayOutputStream();
    output.setByteStream(bArray);
    writer.write(pas.document, output);
    pas.document.appendChild(pas.documentType);
    return bArray.toByteArray();
  }

  /** Receives {@code count} readySig messages and prints their text. */
  private void awaitReadySignals(int count) throws jpvmException
  {
    for (int i = 0; i < count; i++) {
      jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.readySig);
      String str = message.buffer.upkstr();
      System.out.println(str);
    }
  }

  /** Multicasts an empty buffer with the given tag to every net host group. */
  private void multicastToAllNetHosts(int tag) throws jpvmException
  {
    for (int i = 0; i < info.numTasks; i++) {
      jpvmBuffer buf = new jpvmBuffer();
      info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,tag);
    }
  }

  /** Folds the recorder data returned by one net host into rdata. */
  private void mergeRecorderData(RecorderData spikes)
  {
    // combine for single unit
    appendSpikeTimes(spikes.receiver, rdata.receiver);
    // combine for multi unit
    addIntCounts(spikes.multiCounter, rdata.multiCounter);
    addIntCounts(spikes.multiCounterAll, rdata.multiCounterAll);
    // combine for field ele
    addIntCounts(spikes.fieldCounter, rdata.fieldCounter);
    // combine for vector ele
    addDoubleCounts(spikes.vectorCounterX, rdata.vectorCounterX);
    addDoubleCounts(spikes.vectorCounterY, rdata.vectorCounterY);
  }

  /** Appends every list in source to the list stored under the same key. */
  private static void appendSpikeTimes(
    Map<String, LinkedList<Double>> source,
    Map<String, LinkedList<Double>> target)
  {
    for (Map.Entry<String, LinkedList<Double>> entry : source.entrySet()) {
      LinkedList<Double> merged = target.get(entry.getKey());
      if (merged == null)
      {
        merged = new LinkedList<Double>();
        target.put(entry.getKey(), merged);
      }
      merged.addAll(entry.getValue());
    }
  }

  /** Adds every count in source to target; a missing key counts as 0. */
  private static void addIntCounts(
    Map<String, Integer> source,
    Map<String, Integer> target)
  {
    for (Map.Entry<String, Integer> entry : source.entrySet()) {
      Integer current = target.get(entry.getKey());
      int base = (current == null) ? 0 : current.intValue();
      target.put(entry.getKey(), Integer.valueOf(base + entry.getValue().intValue()));
    }
  }

  /** Adds every value in source to target; a missing key counts as 0.0. */
  private static void addDoubleCounts(
    Map<String, Double> source,
    Map<String, Double> target)
  {
    for (Map.Entry<String, Double> entry : source.entrySet()) {
      Double current = target.get(entry.getKey());
      double base = (current == null) ? 0.0 : current.doubleValue();
      target.put(entry.getKey(), Double.valueOf(base + entry.getValue().doubleValue()));
    }
  }

  /**
   * Merges the intracellular records of one finished trial into
   * intraReceiver.  The first record list seen for a slot is stored as is;
   * later lists are added element by element through IntraInfo.plus.
   *
   * @param intra    records returned with a trialDone message
   * @param expIndex sub-experiment index unpacked from the same message
   */
  @SuppressWarnings("unchecked")
  private void mergeIntraRecords(IntraRecBuffer intra, int expIndex)
  {
    for (int i = 0; i < intra.neurons.length; i++)
    {
      int eleId = exp.recorder.intraIndex(intra.neurons[i]);
      LinkedList<IntraInfo> incoming = intra.buff.get(i);
      int slot = eleId+expIndex*exp.recorder.intraEle.size();
      // intraReceiver is a raw array, so the cast is unchecked.
      LinkedList<IntraInfo> accumulated = (LinkedList<IntraInfo>) intraReceiver[slot];
      if (accumulated == null)
      {
        intraReceiver[slot] = incoming;
        continue;
      }
      Iterator<IntraInfo> intraData = incoming.iterator();
      Iterator<IntraInfo> thisData = accumulated.iterator();
      while (intraData.hasNext())
      {
        (thisData.next()).plus(intraData.next());
      }
    }
  }
}