Package cnslab.cnsnetwork

Source Code of cnslab.cnsnetwork.FitFun$ToDo$ToDoTask

    package cnslab.cnsnetwork;
   
    import org.w3c.dom.Document;
    import org.w3c.dom.Element;
    import org.w3c.dom.NodeList;
    import cnslab.cnsmath.Seed;
    import org.w3c.dom.Node;
    import org.w3c.dom.bootstrap.DOMImplementationRegistry;
    import org.w3c.dom.ls.DOMImplementationLS;
    import org.w3c.dom.ls.LSSerializer;
    import org.w3c.dom.ls.LSOutput;
    import java.io.FileOutputStream;
    import java.io.File;
    import java.io.ByteArrayOutputStream;
    import java.io.ByteArrayInputStream;
    import java.io.InputStream;
    import org.w3c.dom.DocumentType;
    import jpvm.jpvmTaskId;
    import cnslab.cnsnetwork.JpvmInfo;
    import java.util.LinkedList;
    import java.util.HashMap;
    import java.util.Set;
    import java.util.Map;
    import jpvm.jpvmEnvironment;
    import jpvm.jpvmBuffer;
    import jpvm.jpvmMessage;
    import jpvm.jpvmException;
    import java.util.Iterator;
    import java.util.Timer;
    import java.util.TimerTask;

    /***********************************************************************
    * Deprecated.
    * 
    * @version
    *   $Date: 2012-08-04 20:43:22 +0200 (Sat, 04 Aug 2012) $
    *   $Rev: 104 $
    *   $Author: croft $
    * @author
    *   Yi Dong
    * @author
    *   David Wallace Croft
    ***********************************************************************/
@Deprecated
public class FitFun
{
  public FitFun(ParToDoc parDoc, Seed idum, String modelFilename, int heapSize )
  {
    try {
      this.parDoc = parDoc;
      this.seedInt = idum.seed;
      this.saveInt = idum.seed;
      this.idum = idum;

      //spawn all the slaves and ready for the simulation   
      pas = new SimulatorParser (
        idum, new File ( "model/" + modelFilename ) );
     
      pas.parseMapCells();
      pas.parseExperiment();
      exp = pas.experiment; // experiment infomation
      pas.findMinDelay();

      pas.document.removeChild(pas.documentType);
      DOMImplementationRegistry registry = DOMImplementationRegistry.newInstance();
      DOMImplementationLS impl = (DOMImplementationLS) registry.getDOMImplementation("LS");
      LSSerializer writer = impl.createLSSerializer();
      LSOutput output = impl.createLSOutput();
      ByteArrayOutputStream bArray = new ByteArrayOutputStream();
      output.setByteStream(bArray);
      writer.write(pas.document, output);
      pas.document.appendChild(pas.documentType);
      byte [] ba = bArray.toByteArray();


      info= new JpvmInfo();
      // Enroll in the parallel virtual machine...
      info.jpvm = new jpvmEnvironment();

      // Get my task id...
      info.myJpvmTaskId = info.jpvm.pvm_mytid();
      System.out.println("Task Id: "+info.myJpvmTaskId.toString());

      info.numTasks= pas.parallelHost; // total number of trial hosts;
      info.idIndex=info.numTasks; // root id; //not used at all
      info.tids = new jpvmTaskId[info.numTasks];

      // Spawn some  trialHosts
      info.jpvm.pvm_spawn("cnslab.cnsnetwork.TrialHost",info.numTasks,info.tids,48);
      System.out.println("spawn successfully");

      jpvmBuffer buf2 = new jpvmBuffer();

      buf2.pack(info.numTasks);
      buf2.pack(info.tids, info.numTasks, 1);
      buf2.pack(pas.minDelay);

      info.jpvm.pvm_mcast(buf2,info.tids,info.numTasks,NetMessageTag.sendTids);

      System.out.println("All sent");

      //gernerate all the nethosts :
      info.endIndex=pas.layerStructure.nodeEndIndices;

      netTids =  new jpvmTaskId [info.numTasks][info.endIndex.length] ;

      for(int i = 0 ; i < info.numTasks; i++)
      {
        System.out.println("generate child for trialHost "+i);
        info.jpvm.pvm_spawn("cnslab.cnsnetwork.NetHostTune",info.endIndex.length,netTids[i],heapSize); //Net Host is to seperate large network into small pieces;
        jpvmBuffer buf = new jpvmBuffer();
        buf.pack(info.endIndex.length);
        buf.pack(netTids[i],info.endIndex.length,1);
        buf.pack(info.endIndex,info.endIndex.length,1);
        seedInt = seedInt - info.endIndex.length;
        buf.pack(seedInt);
        buf.pack(info.tids[i]); //parent's tid;
        buf.pack(ba.length);
        buf.pack(ba, ba.length, 1);
        info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,NetMessageTag.sendTids);
      }

      for(int i = 0 ; i < info.numTasks; i++)
      {
        jpvmBuffer buf = new jpvmBuffer();
        buf.pack(info.endIndex.length);
        buf.pack(netTids[i],info.endIndex.length,1);
        info.jpvm.pvm_send(buf,info.tids[i],NetMessageTag.sendTids2); //send trial Host the child tids
      }


      //Barrier Sync
      for (int i=0;i<info.numTasks*info.endIndex.length+info.numTasks; i++) {
        // Receive a message...
        jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.readySig);
        // Unpack the message...
        String str = message.buffer.upkstr();
        System.out.println(str);
      }
      jpvmBuffer buf = new jpvmBuffer();

    }
    catch(jpvmException ex) {
      ex.printStackTrace();
    }catch(Exception ap)
    {
      ap.printStackTrace();
    }
  }

  public Seed idum;

  public double valueSd; // value's standard deviation

  public  int trialId = -1;

  public  int expId = 0;

  public int saveInt;


  public ParToDoc parDoc;

  public Experiment exp;

  public JpvmInfo info;
  public int [] endIndex;

  public double minDelay;
  public double backFire;
  public int seedInt;

  public jpvmTaskId[][] netTids;

  public LinkedList[] intraReceiver ;
  public RecorderData rdata;


  public SimulatorParser pas;

  public void closeHosts() {
    try {

      for (int i=0;i<info.numTasks; i++) {
        jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.trialReady);
        int freeId = message.buffer.upkint();
        System.out.println("Trial Host "+freeId+" is killed");
        jpvmBuffer buf = new jpvmBuffer();
        buf.pack(0);
        info.jpvm.pvm_send(buf,info.tids[freeId],NetMessageTag.stopSig);
      }
    }
    catch(jpvmException ex) {
      ex.printStackTrace();
    }
  }

  public class ToDo  {
    Timer timer;

    public ToDo ( int seconds )   {
      timer = new Timer (  ) ;
      timer.schedule ( new ToDoTask (  ) , seconds*1000) ;
    }


    class ToDoTask extends TimerTask  {
      public void run (  )   {
        if(trialId == netTids.length-1) //no new trial is finished
        {
          jpvmBuffer buf = new jpvmBuffer();
          try {
            info.jpvm.pvm_send(buf,info.tids[0],NetMessageTag.checkTime); //check first trial host time
          }
          catch(jpvmException ex) {
            ex.printStackTrace();
            System.exit(-1);
          }
        }
        timer.cancel() ;
      }
    }

    public void stop()
    {
      timer.cancel() ;
    }
  }

  /**
   * @see cnslab.cnsnetwork.FitFun#fitFunction(double[]) fitFunction
   */
  public double fitFunction(double[] para) {
    double tmpValue=0.0;
    try {
      //initialized parameters
      intraReceiver = new LinkedList[exp.recorder.intraEle.size()*exp.subExp.length];
      rdata = new RecorderData();


      //first change the connections for the netHost
      Document newDoc = parDoc.getDocument(para);
      NodeList conns = pas.rootElement.getElementsByTagName("Connections");

      System.out.println("connections num"+conns.getLength());

      pas.rootElement.removeChild(conns.item(0));
      Node dup = pas.document.importNode(newDoc.getDocumentElement() , true);
      pas.rootElement.appendChild(dup);

      pas.document.normalizeDocument(); //expand everything like save and load

      pas.document.removeChild(pas.documentType);
      DOMImplementationRegistry registry = DOMImplementationRegistry.newInstance();
      DOMImplementationLS impl = (DOMImplementationLS) registry.getDOMImplementation("LS");
      LSSerializer writer = impl.createLSSerializer();
      LSOutput output = impl.createLSOutput();
      ByteArrayOutputStream bArray = new ByteArrayOutputStream();
      output.setByteStream(bArray);
      writer.write(pas.document, output);
      pas.document.appendChild(pas.documentType);
      byte [] ba = bArray.toByteArray();



      seedInt = saveInt;
      for (int i=0;i<info.numTasks; i++) {
        // Receive a message...
        jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.trialReady);
        int freeId = message.buffer.upkint();
        System.out.println("Trial Host "+freeId+" change connection");
        seedInt = seedInt - info.endIndex.length;

        jpvmBuffer buf = new jpvmBuffer();
        buf.pack(seedInt);
        buf.pack(ba.length);
        buf.pack(ba, ba.length, 1);
        info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,NetMessageTag.changeConnection);
        // Unpack the message...
      }

      trialId = -1;
      expId = 0;
      int aliveHost = info.numTasks;
      int totalTrials=0;
      for(int j=0;j< exp.subExp.length; j++)
      {
        totalTrials += exp.subExp[j].repetition;
      }

      totalTrials = totalTrials * pas.numOfHosts;
      boolean stop = false;

      boolean badweight=false;

      //Barrier Sync
      for (int i=0;i<info.numTasks*info.endIndex.length; i++) {
        // Receive a message...
        jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.readySig);
        // Unpack the message...
        String str = message.buffer.upkstr();
        int posi;
        if((posi=str.indexOf("$")) >=0 )
        {
          badweight = true;
          double tmpval =  Double.parseDouble(str.substring(posi+1));
          if(tmpval > tmpValue) {
            tmpValue = tmpval;
            valueSd =0.0;
          }
        }
        System.out.println(str);
      }

      jpvmBuffer buf ;
      double per=0.0;
      boolean overflow=false;


      if(!badweight)
      {
        System.out.println("************ simulation is starting *********************");
        //listening and processing
        ToDo toDo = new ToDo(90); //1min's threshold
        int countHosts=0;
        while(!stop)
        {
          jpvmMessage m =  info.jpvm.pvm_recv(); //receive info from others
          switch(m.messageTag)
          {
            case NetMessageTag.checkTime: //received percentage time and reset netHosts
              int hostId = m.buffer.upkint();
              int eId = m.buffer.upkint();
              double root_time = m.buffer.upkdouble();
              per = root_time/exp.subExp[eId].trialLength*100;
              stop = true; //can stop now;
              overflow=true;
              for (int i=0;i<info.numTasks; i++) //reset all the netHosts
              {
                buf = new jpvmBuffer();
                info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,NetMessageTag.resetNetHost);
              }
              System.out.println("Overflow detected");
              break;
            case NetMessageTag.trialReady:
              int freeId = m.buffer.upkint(); //get free host id;
              trialId++;
//              System.out.println("Now trialId is "+ trialId);
              if(expId < exp.subExp.length)
              {
                if(trialId == exp.subExp[expId].repetition && expId+1 == exp.subExp.length)
                {      
                  expId++;
                }
                else if(trialId == exp.subExp[expId].repetition && expId+1 != exp.subExp.length)                                                          
                {      
                  trialId =0;
                  expId++;
                }
              }
              if(expId < exp.subExp.length// game is still on;
              {
                System.out.println("Subexp "+expId+" trial "+ trialId+" freeId "+freeId);
                buf = new jpvmBuffer();
                buf.pack(trialId);
                buf.pack(expId);
                info.jpvm.pvm_send(buf,info.tids[freeId],NetMessageTag.oneTrial);
              }
              else if( aliveHost != 0 ) // if all the work are done and some hosts are not killed
              {
                System.out.println("host "+freeId+" finished his job");
                //              buf = new jpvmBuffer();
                //              buf.pack(0);
                //              info.jpvm.pvm_send(buf,info.tids[freeId],NetMessageTag.stopSig);
                //aliveHost--;
                buf = new jpvmBuffer();
                buf.pack(0);
                info.jpvm.pvm_send(buf,info.tids[freeId],NetMessageTag.tempStopSig);                                         
              }
              break;
            case NetMessageTag.getBackData:
              countHosts++;
              if(countHosts== info.endIndex.length) {aliveHost--;countHosts=0;}
              //System.out.println("Receiving data from ");
              RecorderData spikes = (RecorderData)m.buffer.upkcnsobj();
//              System.out.print("Receiving data from "+m.buffer.upkint()+" and alive hosts"+aliveHost+"\n");
              //comibne for single unit
              Set entries = spikes.receiver.entrySet();
              Iterator entryIter = entries.iterator();
              while (entryIter.hasNext()) {
                Map.Entry<String, LinkedList<Double> > entry = (Map.Entry<String, LinkedList<Double> >)entryIter.next();
                String key = entry.getKey()// Get the key from the entry.
                LinkedList<Double> value = entry.getValue()// Get the value.
                LinkedList<Double> tmp = rdata.receiver.get(key);
                if(tmp == null)
                {
                  rdata.receiver.put(key, tmp=(new LinkedList<Double>()));
                }
                //if( spike.time !=0.0 ) tmp.add(spike.time); //put received info into memory
                tmp.addAll(value); //put received info into memory
              }
              ///comibne for multi unit
              entries = spikes.multiCounter.entrySet();
              entryIter = entries.iterator();
              while (entryIter.hasNext()) {
                Map.Entry<String, Integer> entry = (Map.Entry<String, Integer>)entryIter.next();
                String key = entry.getKey()// Get the key from the entry.
                Integer value = entry.getValue()// Get the value.
                Integer tmp = rdata.multiCounter.get(key);
                if(tmp == null)
                {
                  rdata.multiCounter.put(key, tmp=(new Integer(0)));
                }
                tmp+=value;
                rdata.multiCounter.put(key, tmp);
              }
              entries = spikes.multiCounterAll.entrySet();
              entryIter = entries.iterator();
              while (entryIter.hasNext()) {
                Map.Entry<String, Integer> entry = (Map.Entry<String, Integer>)entryIter.next();
                String key = entry.getKey()// Get the key from the entry.
                Integer value = entry.getValue()// Get the value.
                Integer tmp = rdata.multiCounterAll.get(key);
                if(tmp == null)
                {
                  rdata.multiCounterAll.put(key, tmp=(new Integer(0)));
                }
                tmp+=value;
                rdata.multiCounterAll.put(key, tmp);
              }
              //combine for field ele
              entries = spikes.fieldCounter.entrySet();
              entryIter = entries.iterator();
              while (entryIter.hasNext()) {
                Map.Entry<String, Integer> entry = (Map.Entry<String, Integer>)entryIter.next();
                String key = entry.getKey()// Get the key from the entry.
                //              System.out.println(key);
                Integer value = entry.getValue()// Get the value.
                Integer tmp = rdata.fieldCounter.get(key);
                if(tmp == null)
                {
                  rdata.fieldCounter.put(key, tmp=(new Integer(0)));
                }
                tmp+=value;
                rdata.fieldCounter.put(key, tmp);
              }
              //combine for vector ele
              entries = spikes.vectorCounterX.entrySet();
              entryIter = entries.iterator();
              while (entryIter.hasNext()) {
                Map.Entry<String, Double> entry = (Map.Entry<String, Double>)entryIter.next();
                String key = entry.getKey()// Get the key from the entry.
                Double value = entry.getValue()// Get the value.
                Double tmp = rdata.vectorCounterX.get(key);
                if(tmp == null)
                {
                  rdata.vectorCounterX.put(key, tmp=(new Double(0.0)));
                }
                tmp+=value;
                rdata.vectorCounterX.put(key, tmp);
              }
              entries = spikes.vectorCounterY.entrySet();
              entryIter = entries.iterator();
              while (entryIter.hasNext()) {
                Map.Entry<String, Double> entry = (Map.Entry<String, Double>)entryIter.next();
                String key = entry.getKey()// Get the key from the entry.
                Double value = entry.getValue()// Get the value.
                Double tmp = rdata.vectorCounterY.get(key);
                if(tmp == null)
                {
                  rdata.vectorCounterY.put(key, tmp=(new Double(0.0)));
                }
                tmp+=value;
                rdata.vectorCounterY.put(key, tmp);
              }
              break;
            case NetMessageTag.trialDone:
              totalTrials--;
              int res_trial = m.buffer.upkint();
              int res_exp = m.buffer.upkint();
              //  System.out.println("R: "+"E"+res_exp+"T"+res_trial);
              //RecordBuffer spikes = (RecordBuffer)m.buffer.upkcnsobj();
              IntraRecBuffer intra = (IntraRecBuffer)m.buffer.upkcnsobj();
              /*
              //      System.out.println(intra);

              Iterator<NetRecordSpike> iter_spike = spikes.buff.iterator();
              //          System.out.println("spikes"+spikes.buff.size());
              while(iter_spike.hasNext())
              {
              NetRecordSpike spike = iter_spike.next();

              LinkedList<Double> tmp = receiver.get("E"+res_exp+"T"+res_trial+"N"+spike.from);

              if(tmp == null)
              {
              receiver.put("E"+res_exp+"T"+res_trial+"N"+spike.from, tmp=(new LinkedList<Double>()));
              }
              //if( spike.time !=0.0 ) tmp.add(spike.time); //put received info into memory
              tmp.add(spike.time); //put received info into memory
              //System.out.println("fire: time:"+spike.time+ " index:"+spike.from);
              }
              */

              for(int i = 0; i < intra.neurons.length; i++)
              {
                int neu = intra.neurons[i];
                int eleId = exp.recorder.intraIndex(neu);
                LinkedList<IntraInfo> info = intra.buff.get(i);
                LinkedList<IntraInfo> currList;
                if((currList=(LinkedList<IntraInfo>)intraReceiver[eleId+res_exp*exp.recorder.intraEle.size()])!=null)
                {
                  Iterator<IntraInfo> intraData = info.iterator();
                  Iterator<IntraInfo> thisData = currList.iterator();
                  while(intraData.hasNext())
                  {
                    (thisData.next()).plus(intraData.next());
                  }
                }
                else
                {
                  intraReceiver[eleId+res_exp*exp.recorder.intraEle.size()]=info; 
                }
              }
              break;
          }

          if( aliveHost == 0 && totalTrials==0)
          {
            stop = true;
            break;
          }
        }
        toDo.stop();

        if(overflow)
        {
          tmpValue = (100.0-per)/100.0;
          valueSd =0.0;
          //Barrier Sync
        }
        else
        {
          tmpValue = parDoc.getFitValue(pas,intraReceiver,rdata);
          double [] bootVal = new double [200]; //200 times to get Sd.
          for(int i=0; i < 200; i++)
          {
            bootVal[i] = parDoc.getRanFitValue(pas,intraReceiver,rdata,idum);
          }
          valueSd = FunUtil.sd(bootVal);
          for (int i=0;i<info.numTasks; i++) {
            buf = new jpvmBuffer();
            info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,NetMessageTag.netHostNotify);
          }
        }

      }
      else
      {
        for (int i=0;i<info.numTasks; i++) {
          buf = new jpvmBuffer();
          info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,NetMessageTag.netHostNotify);
        }
      }

      for (int i=0;i<info.numTasks*info.endIndex.length; i++) {
        // Receive a message...
        jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.readySig);
        // Unpack the message...
        String str = message.buffer.upkstr();
        System.out.println(str);
      }

      //start working for the next fitting evaluation

      System.out.println("cost:"+tmpValue+" Sd:"+valueSd);
    }
    catch(Exception ex) {
      ex.printStackTrace();
      System.exit(-1);
    }catch(jpvmException ap)
    {
      ap.printStackTrace();
      System.exit(-1);
    }
    return tmpValue;
  }
}
TOP

Related Classes of cnslab.cnsnetwork.FitFun$ToDo$ToDoTask

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.