Package cnslab.cnsnetwork

Source Code of cnslab.cnsnetwork.MiNiNeuron$State

    package cnslab.cnsnetwork;

    import java.util.*;

    import cnslab.cnsmath.*;
    import edu.jhu.mb.ernst.model.Synapse;
    import edu.jhu.mb.ernst.util.slot.Slot;
   
    /***********************************************************************
    * MN model with STDP implemented.
    *
    * See Mihalas and Niebur 2009 paper.
    *
    * @version
    *   $Date: 2012-08-28 19:15:24 +0200 (Tue, 28 Aug 2012) $
    *   $Rev: 130 $
    *   $Author: jmcohen27 $
    * @author
    *   Yi Dong
    * @author
    *   David Wallace Croft, M.Sc.
    * @author
    *   Jeremy Cohen
    ***********************************************************************/
    public final class  MiNiNeuron
      implements Neuron
    ////////////////////////////////////////////////////////////////////////
    ////////////////////////////////////////////////////////////////////////
    {
   
    /** Time of the next scheduled fire event. init sets it to -1 and only
     *  overwrites it when an event is actually scheduled. */
    private double  timeOfNextFire;
   
    /** Time up to which the state was last brought forward. After a real
     *  fire event (updateFire) this is the end of the absolute refractory
     *  period, so an input or query earlier than this value falls inside
     *  that period. */
    private double  timeOfLastUpdate;
   
    /** whether the neuron is recordable or not */
    private boolean  record;

    /** Table for STDP that stores, for each presynaptic synapse, the time of
     *  its most recent input and its relative weight (starting at 1.0). */
    private Map<Synapse, TimeWeight>  histTable;
   
    /** Convergence tolerance of the Newton-Raphson fire-time search. Once a
     *  step is smaller than this in absolute value, the fire time is treated
     *  as final rather than as a preliminary estimate. */
    public static double  tEPS = 1e-12;
   
    /** The criterion for removing synapses: if a synaptic channel does not
     *  get an input for a long time, so that its current decays below rEPS,
     *  it is removed from computations of the voltage and future updates
     *  until it gets an input again.
     *  NOTE(review): not referenced in the visible code; the removal
     *  presumably happens elsewhere (hence the null checks on sta_p
     *  entries). Confirm. */
    public static double  rEPS = 1e-22;

    /** the cost associated with one neuron update (see D'Haene, 2009) */
    public static double  cost_Update = 10;
   
    /** the cost associated with one queue schedule (see D'Haene, 2009) */
    public static double  cost_Schedule = 1;
   
    /** The expected cost of inserting a preliminary fire estimate into the
     *  event queue, as opposed to performing another NR iteration. init
     *  computes it as log(cost_Schedule / cost_Update + 1). The NR search
     *  stops refining once its estimate exceeds cost * tAvg. */
    public static double  cost;

    /** Running average of the time interval between input spikes, updated
     *  as 0.8 * tAvg + 0.2 * (latest interval). init starts it at 5 ms. */
    public double  tAvg;
   
    /** last time the neuron received a spike */
    public double  lastInputTime;
   
    /** last time the neuron fired; negative means no last spike */
    public double  lastFireTime = -1.0;
   
    /** Membrane voltage to clamp to during the absolute refractory period
     *  (the reset voltage). Set to Double.MAX_VALUE once an input arrives
     *  outside a refractory period. */
    public double  clampVoltage;
   
    /** Threshold to clamp to during the absolute refractory period (the
     *  reset threshold). Set to Double.MAX_VALUE once an input arrives
     *  outside a refractory period. */
    public double  clampThreshold;
   
    /** whether the next fire event is a final calculation,
     *  as opposed to a preliminary estimate (see D'Haene, 2009) */
    public boolean  fire;

    /** Host identifier. The original note says a long limits the number of
     *  hosts to 64, which suggests a one-bit-per-host mask.
     *  NOTE(review): bit-mask interpretation inferred, confirm with callers. */
    public long  tHost;
   
    /** neuron parameters */
    public MiNiNeuronPara  para;

    /** All state variables, kept sorted by State.compareTo (descending decay
     *  rate, then type, then index). Despite the older description, this is
     *  a TreeSet, not a linked list. */
    public TreeSet<State>  state;
   
    /**
     * Two-dimensional array pointing to the state variables. The first
     * dimension corresponds to type: NG, NB, NJ_SPIKE and NJ_SYNAPSE, in
     * that order. The second dimension corresponds to index. The NG and NB
     * arrays have only one element. NJ_SPIKE and NJ_SYNAPSE entries may be
     * null (inactive); callers check for this and recreate them lazily.
     */
    public State [ ] [ ] sta_p;
   

    ////////////////////////////////////////////////////////////////////////
    // inner classes
    ////////////////////////////////////////////////////////////////////////
   
    public final class  State
      implements Comparable<State>
    ////////////////////////////////////////////////////////////////////////
    ////////////////////////////////////////////////////////////////////////
    {
    /**
     * values of <code>type</code> 
     */
    public final static int
                        NG         = 0,
                        NB         = 1,
                        NJ_SPIKE   = 2,
                        NJ_SYNAPSE = 3;
 
    /**
     * The time when this State was last updated.
     */
    public double  time;
  
    /**
     * The value of this State at the time when it was last updated.
     */
    public double  value;

    /**
     * The term in the thresholdDiff function that this State represents
     * (see Mihalas and Niebur, 2009, Equation 3.5.)
     */
    public int     type;
   
    /**
     * The index of this State among other States of the same type.
     */
    public int     index;

    ////////////////////////////////////////////////////////////////////////
    ////////////////////////////////////////////////////////////////////////

    public State (
      final double  time,
      final double  value,
      final int     type,
      final int     index)
    ////////////////////////////////////////////////////////////////////////
    {
      this.time  = time;

      this.value = value;

      this.type = type;
     
      this.index = index;
    }

    ////////////////////////////////////////////////////////////////////////
    ////////////////////////////////////////////////////////////////////////

    /**
     * Sorts in descending order of decay rate.
     */
    @Override
    public int  compareTo ( final State  arg0 )
    ////////////////////////////////////////////////////////////////////////
    {
        if ( para.allDecays [ this.type ] [ this.index ] <
                para.allDecays [ arg0.type ] [ arg0.index ])
        {
            return 1;
        }
        else if ( para.allDecays [ this.type ] [ this.index ] >
        para.allDecays [ arg0.type ] [ arg0.index ])
        {
            return -1;
        }

        if ( this.type < arg0.type )
        {
            return -1;
        }
        else if ( this.type > arg0.type )
        {
            return 1;
        }

        if (this.index < arg0.index)
        {
            return -1;
        }
        else if (this.index > arg0.index)
        {
            return 1;
        }

        else
        {
            return 0;
        }
    }

    @Override
    public String  toString ( )
    ////////////////////////////////////////////////////////////////////////
    {
      return
        "time:"   + time
        + " value:" + value
        + " type:"  + type
        + " index:" + index
        + " decay:" + para.allDecays [ type ] [ index ];
    }

    ////////////////////////////////////////////////////////////////////////
    ////////////////////////////////////////////////////////////////////////
    }

    ////////////////////////////////////////////////////////////////////////
    // constructor methods
    ////////////////////////////////////////////////////////////////////////
   
    public  MiNiNeuron ( final MiNiNeuronPara  para )
    ////////////////////////////////////////////////////////////////////////
    {
      this.para = para;
    }

    ////////////////////////////////////////////////////////////////////////
    // interface Neuron accessor methods
    ////////////////////////////////////////////////////////////////////////

    /**
     * Used mainly for intracellular recording.
     *
     * @param currTime the current time.
     * @return the synaptic and spike-induced currents in this neuron at
     * time currTime.
     */
    @Override
    public double [ ]  getCurr ( final double  currTime )
    ////////////////////////////////////////////////////////////////////////
    {
      final double [ ]  out =
          new double [ sta_p [ State.NJ_SYNAPSE ].length +
                       sta_p [ State.NJ_SPIKE ].length ];
     
      for ( int  a = 0; a < sta_p [ State.NJ_SYNAPSE ].length; a++ )
      {
        State state = sta_p [ State.NJ_SYNAPSE ] [ a ];
       
        if ( state == null )
        {
          out [ a ] = 0;
        }
        else
        {
          // Convert NJ_SYNAPSE state to synapse current, and update to
          // currTime.
         
          out [ a ] = state.value * Math.exp (
            -( currTime - state.time )
            * para.allDecays [ State.NJ_SYNAPSE ] [ a ] )
            / para.SYNAPSE_CURRENT_TO_SYNAPSE_NJ [ a ];
        }
      }
     
      for ( int  a = 0; a < sta_p [ State.NJ_SPIKE ].length; a++ )
      {
        int index = sta_p [ State.NJ_SPIKE ].length + a;
       
        State state = sta_p [ State.NJ_SPIKE ] [ a ];
       
        if ( state == null )
        {
          out [ index ] = 0;
        }
        else
        {
          // Convert NJ_SPIKE state to spike current, and update to
          // currTime.
         
          out [ index ] = state.value * Math.exp (
            -( currTime - state.time )
            * para.allDecays [ State.NJ_SPIKE ] [ a ] )
            / para.SPIKE_CURRENT_TO_SPIKE_NJ [ a ];
        }
      }
     
      return out;
    }   
   
    /**
     * Used mainly for intracellular recording.
     *
     * @param currTime the current time.
     * @return the membrane voltage in this neuron at time currTime.
     */
    @Override
    public double  getMemV ( final double  currTime )
    ////////////////////////////////////////////////////////////////////////
    {
      if ( currTime < timeOfLastUpdate )
      {
        // If input comes inside the refractory period,
      // the membrane voltage remains unchanged.
       
        return membraneVoltage ( timeOfLastUpdate );
      }
      else
      {      
        return membraneVoltage (currTime);
      }
    }

    @Override
    public boolean  getRecord ( )
    ////////////////////////////////////////////////////////////////////////
    {
      return this.record;
    }

    @Override
    public long  getTHost ( )
    ////////////////////////////////////////////////////////////////////////
    {
      return tHost;
    }

    @Override
    public double  getTimeOfNextFire ( )
    ////////////////////////////////////////////////////////////////////////
    {
      return this.timeOfNextFire;
    }

    @Override
    public boolean  isSensory ( )
    ////////////////////////////////////////////////////////////////////////
    {
      return false;
    }

    @Override
    public boolean  realFire ( )
    ////////////////////////////////////////////////////////////////////////
    {
      // Whether the next fire event is a final calculation,
      // as opposed to a preliminary estimate (see D'Haene, 2009).
     
      return fire;
    }

    ////////////////////////////////////////////////////////////////////////
    // interface Neuron mutator methods
    ////////////////////////////////////////////////////////////////////////
   
    @Override
    public void  setRecord ( final boolean  record )
    ////////////////////////////////////////////////////////////////////////
    {
      this.record = record;
    }

    @Override
    public void  setTHost ( final long  id )
    ////////////////////////////////////////////////////////////////////////
    {
      this.tHost = id;
    }

    @Override
    public void  setTimeOfNextFire ( final double  timeOfNextFire )
    ////////////////////////////////////////////////////////////////////////
    {
      this.timeOfNextFire = timeOfNextFire;
    }

    ////////////////////////////////////////////////////////////////////////
    // interface Neuron lifecycle methods
    ////////////////////////////////////////////////////////////////////////
   
    @Override
    public void  init (
      final int      expid,
      final int      trialid,
      final Seed     idum,
      final Network  net,
      final int      id )
    ////////////////////////////////////////////////////////////////////////
    {
      // initialization all the initial parameters.
     
      cost = Math.log ( cost_Schedule / cost_Update + 1.0 );
     
      this.tAvg = 0.005; // initial value is 5 ms;
     
      this.lastInputTime = 0.0;
     
      this.timeOfLastUpdate = 0.0;
     
      this.timeOfNextFire = -1;
     
      double  initialThreshold = para.ini_threshold;
     
      // Pick the initial membrane voltage from a random, uniform spread.
     
      double  initialVoltage =
          para.ini_mem + para.ini_memVar * Cnsran.ran2 ( idum );

      double [ ]  initialSpikeCurrents =
          new double [ para.ini_spike_curr.length ];
     
      for ( int  a = 0; a < initialSpikeCurrents.length; a++ )
        initialSpikeCurrents [ a ] = para.ini_spike_curr [ a ];
     
      // Set up the history table to moniter of synaptic activity for STDP
     
      histTable = new HashMap<Synapse,TimeWeight> ( );
     
      // Last time the neuron fired.
     
      this.lastFireTime = -1.0;
     
      // Set up state variables.

      // Create linked table for states.
     
      state = new TreeSet<State> ( );
     
      sta_p = new State [ 4 ] [ ];
     
      // Create the NG term.
     
      double ngTerm = para.NG_BASE;
     
      ngTerm += initialVoltage * para.VOLTAGE_TO_NG;
     
      for (int a = 0; a < initialSpikeCurrents.length; a ++)
      {
          ngTerm +=
              initialSpikeCurrents [ a ] * para.SPIKE_CURRENT_TO_NG [ a ];
      }     
     
      sta_p [ State.NG ] = new State [ ] {
                  new State ( 0, ngTerm, State.NG, 0 )
              };
     
      state.add ( sta_p [ State.NG ] [ 0 ] );   
     
      // Create the NB term.
     
      double nbTerm = para.NB_BASE;
     
      nbTerm += initialThreshold * para.THRESHOLD_TO_NB;
      nbTerm += initialVoltage * para.VOLTAGE_TO_NB;
     
      sta_p [ State.NB ] = new State [ ] {
              new State ( 0, nbTerm, State.NB , 0 )
      };
     
      state.add ( sta_p [ State.NB ] [ 0 ] );
     
      // Create the NJ_SPIKE terms.

      sta_p [ State.NJ_SPIKE ] =
          new State [ para.allDecays [ State.NJ_SPIKE ].length ];
     
      for ( int  i = 0; i < sta_p [ State.NJ_SPIKE ].length; i++ )
      {
        double njTerm =
          initialSpikeCurrents [ i ] * para.SPIKE_CURRENT_TO_SPIKE_NJ [ i ];
       
        sta_p [ State.NJ_SPIKE ] [ i ] =
                new State ( 0 , njTerm , State.NJ_SPIKE , i );
       
        state.add ( sta_p [ State.NJ_SPIKE ] [ i ] );
      }
     
      // Create the NJ_SYNAPSE terms.
     
      sta_p [ State.NJ_SYNAPSE ] =
          new State [ para.allDecays [ State.NJ_SYNAPSE ].length ];
     
      for ( int  i = 0; i < sta_p [ State.NJ_SYNAPSE ].length; i++ )
      {
        sta_p [ State.NJ_SYNAPSE ] [ i ] =
                new State ( 0 , 0 , State.NJ_SYNAPSE , i );
       
        state.add ( sta_p [ State.NJ_SYNAPSE ] [ i ] );
      }

      // Schedule the first fire event, if one should exist.
     
      // Use a modified Newton-Raphson (NR) root-finding iteration method to
      // find the first zero of the thresholdDiff function.

      double  deltaT, thresholdDiff;
     
      double  deriv = safeDerivative ( 0.0 );
     
      // Ever-increasing estimate of the absolute time of the next fire
      // event.
     
      double  nextTime = ( -( initialVoltage - initialThreshold ) / deriv );
     
      deltaT = nextTime;
           
      while ( deriv > 0 && nextTime < 1.0
        && !( deltaT > 0 ? deltaT < tEPS : -deltaT < tEPS ) )
      {
      // Safe derivative of the thresholdDiff function at time = nextTime.
       
        deriv = safeDerivative ( nextTime );
       
        // Value of the thresholdDiff function at time = nextTime.
       
        thresholdDiff = thresholdDiff ( nextTime );
       
        // As per the NR algorithm.
       
        deltaT = ( -thresholdDiff / deriv );
       
        nextTime += deltaT;
      }

      // If the derivative is decreasing or if the algorithm has gone far
      // enough, conclude that this neuron will not fire.
     
      if ( deriv < 0 || nextTime > 1.0 )
      {
        fire = false;
       
        nextTime = -1.0;
      }
     
      // If deltaT is within the proper precision, schedule an actual fire
      // event.

      else if ( ( deltaT > 0 ? deltaT < tEPS : -deltaT < tEPS ) )
      {
        fire = true;
      }
      else
      {
        // This shouldn't happen.
      }

      if ( nextTime < 0 )
      {
      // A negative value for nextTime means that this neuron will not fire.
      }
      else
      {
        final Slot<FireEvent>  fireEventSlot = net.getFireEventSlot ( );
       
        if ( net.getClass ( ).getName().equals (
          "cnslab.cnsnetwork.ANetwork" ) )
        {
          fireEventSlot.offer (
            new AFireEvent (
              id,
              nextTime,
              net.info.idIndex,
              ( ( ANetwork ) net ).aData.getId ( ( ANetwork ) net ) ) );
        }
        else if ( net.getClass ( ).getName ( ).equals (
          "cnslab.cnsnetwork.Network" ) )
        {
          fireEventSlot.offer ( new FireEvent ( id, nextTime ) );
        }
        else
        {
          throw new RuntimeException (
            "Other Network Class doesn't exist" );
        }
       
        timeOfNextFire = nextTime;
      }
    }

    @Override
    public double  updateFire ( )
    ////////////////////////////////////////////////////////////////////////
    {     
    /*
     * I.  If this is a real fire event:
     *      A.  Update spike time dependent plasticity variables.
     *      B.  Update the state variables to time at the end of the
     *          refractory period.
     *      C.  Reset the membrane voltage, membrane voltage threshold, and
     *          spike-induced currents according to the reset rules of the
     *          Mihalas-Niebur model.
     * II. Schedule the next fire event.
     */
      // The time to start looking for the next fire event.
     
      double  baselineTime;
     
      // The voltage and threshold at baselineTime.
     
      double  nowVoltage, nowThreshold;
     
      if ( fire ) // Only proceed if this is an actual fire event.
      {
       
      // Update STDP variables.
       
        for ( final Map.Entry<Synapse,TimeWeight>  entry
          : histTable.entrySet ( ) )
        {
          final Synapse  syn = entry.getKey ( );
         
          // Get the relative weight.
         
          final TimeWeight  tw = entry.getValue();

          // Channel 0 has STDP.
         
          if ( syn.getType ( ) == 0 )
          {
            // LTP only for close spikes.
           
            if ( lastFireTime < tw.time )
            {
              // Update the weight.
             
              tw.weight = tw.weight
                * ( 1 + para.Alpha_LTP
                  * Math.exp ( -para.K_LTP
                    * ( timeOfNextFire-tw.time ) ) );
            }
          }
        }

        // Store the old neuron fire time.
       
        lastFireTime = timeOfNextFire;
       
        // Back up the current threshold.
       
        final double  thresholdBackup
          = membraneVoltage ( timeOfNextFire )
          - thresholdDiff ( timeOfNextFire );
       
        // Update the G term to the time at the end of the refractory period.
                       
        sta_p [ State.NG ] [ 0 ].value *= Math.exp (
          -( timeOfNextFire - sta_p [ State.NG ] [ 0 ].time + para.ABSREF )
          * para.allDecays [ State.NG ] [ 0 ]);
       
        sta_p [ State.NG ] [ 0 ].time = timeOfNextFire + para.ABSREF;
       
        // Update the B term to the time at the end of the refractory period.
       
        sta_p [ State.NB ] [ 0 ].value *= Math.exp (
          -( timeOfNextFire - sta_p [ State.NB ] [ 0 ].time + para.ABSREF )
          * para.allDecays [ State.NB ] [ 0 ]);
       
        sta_p [ State.NB ] [ 0 ].time = timeOfNextFire + para.ABSREF;

        // Reset the membrane voltage and voltage threshold.
       
        // Calculate what the voltage and threshold will be
        // at the end of the refractory period.
       
        final double
          futureVoltage = membraneVoltage ( timeOfNextFire + para.ABSREF );
               
        final double
          futureThresholdGap = thresholdDiff (timeOfNextFire + para.ABSREF);
       
        final double futureThreshold = futureVoltage - futureThresholdGap;
       
        // Calculate the reset voltage and threshold, according
        // to the reset rules of the model.
       
        final double resetVoltage = para.VRESET;
       
        final double resetThreshold =
                Math.max(para.RRESET, thresholdBackup + para.THRESHOLDADD);
       
        // Calculate the necessary offsets.
       
        final double  voltageOffset = resetVoltage - futureVoltage;
               
        final double  thresholdOffset = resetThreshold - futureThreshold;
       
        // Incorporate the offsets into the state variables.

        sta_p [ State.NG ] [ 0 ].value +=
            voltageOffset * para.VOLTAGE_TO_NG;
       
        sta_p [ State.NB ] [ 0 ].value +=
            voltageOffset * para.VOLTAGE_TO_NB;
       
        sta_p [ State.NB ] [ 0 ].value +=
            thresholdOffset * para.THRESHOLD_TO_NB;
       
        // Clamp the voltage and threshold to the reset values.

        clampVoltage = resetVoltage;
       
        clampThreshold = resetThreshold;
       
        // Reset the spike-induced currents.
       
        for (int  a = 0; a < para.allDecays [ State.NJ_SPIKE ].length; a++)
        {
          // Reset a current only if its reset rule is not the identity.
         
          if ( para.SPIKE_RATIO [ a ] != 1.0 || para.SPIKE_ADD [ a ] != 0 )
          {           
            // Update the current to the end of the refractory period.
           
            // If the current is inactive, initialize its state variable.
           
            if ( sta_p [ State.NJ_SPIKE ] [ a ] == null )
            {             
              sta_p [ State.NJ_SPIKE ] [ a ] = new State (
                timeOfNextFire + para.ABSREF,
                0,
                State.NJ_SPIKE,
                a);
             
              state.add ( sta_p [ State.NJ_SPIKE ] [ a ]);
            }
            else
            {
              sta_p [ State.NJ_SPIKE ] [ a ].value *= Math.exp (
                -( timeOfNextFire -
                    sta_p [ State.NJ_SPIKE ] [ a ].time + para.ABSREF )
                * para.allDecays [ State.NJ_SPIKE ] [ a ] );
             
              sta_p [ State.NJ_SPIKE ] [ a ].time =
                  timeOfNextFire + para.ABSREF;
            }

            // Calculate what this current will be
            // at the end of the refractory period.
           
            final double  futureCurrent =
                sta_p [ State.NJ_SPIKE ] [ a ].value
                    / para.SPIKE_CURRENT_TO_SPIKE_NJ [ a ];
           
            // Calculate the reset current, according to the reset rule
            // of the model.
                       
            final double  resetCurrent = 
                ( para.SPIKE_RATIO [ a ]* futureCurrent
                    * Math.exp ( para.ABSREF *
                  para.allDecays [ State.NJ_SPIKE ] [ a ] )
              + para.SPIKE_ADD [ a ] )
              * Math.exp ( -para.ABSREF *
                  para.allDecays [ State.NJ_SPIKE ] [ a ] );
           
            // Calculate the necessary offset.
           
            final double  currentOffset = resetCurrent - futureCurrent;
           
            // Incorporate the offset into the state variables.
           
            sta_p [ State.NG ] [ 0 ].value +=
                currentOffset * para.SPIKE_CURRENT_TO_NG [ a ];
           
            sta_p [ State.NB ] [ 0 ].value +=
                currentOffset * para.SPIKE_CURRENT_TO_NB [ a ];
           
            sta_p [ State.NJ_SPIKE ] [ a ].value +=
                currentOffset * para.SPIKE_CURRENT_TO_SPIKE_NJ [ a ];
          }
        }
       
        timeOfLastUpdate = timeOfNextFire + para.ABSREF;
       
        baselineTime  = timeOfNextFire + para.ABSREF;
       
        nowVoltage = resetVoltage;
       
        nowThreshold = resetThreshold;
      }
      else // If this fire event was only a preliminary prediction.
      {
        baselineTime = timeOfNextFire;
       
        nowVoltage = thresholdDiff(baselineTime);
       
        nowThreshold = 0;
      }
     
      // Schedule the next fire event, if one should exist.
     
      // Use a modified Newton-Raphson (NR) root-finding iteration method to
      // find the first zero of the thresholdDiff function.
     
      double  deltaT, thresholdDiff;
     
      double  deriv = safeDerivative ( baselineTime );
     
      // Ever-increasing estimate of the time until the next fire event.
     
      double  nextTime = ( -( nowVoltage - nowThreshold ) ) / deriv;
     
      deltaT = nextTime;
     
      // Repeat the NR iteration until either the derivative turns negative
      // or the predicted fire time grows so far away that an input spike
      // is likely to arrive in the intervening time, rendering any further
      // calculations useless (see D'Haene, 2009).

      while ( deriv > 0 && nextTime < cost * tAvg )
      {
        // Safe derivative of the thresholdDiff function.
       
        deriv = safeDerivative ( baselineTime + nextTime );
       
        // Value of the thresholdDiff function.
       
        thresholdDiff = thresholdDiff ( baselineTime + nextTime );
       
        // As per the NR algorithm.
       
        deltaT = ( -thresholdDiff / deriv );
       
        // If deltaT is within the proper precision, schedule an actual
        // fire event.
       
        if ( deltaT > 0 ? deltaT < tEPS : -deltaT < tEPS )
        {
          fire = true;
                  
          return nextTime + baselineTime - timeOfNextFire;
        }
       
        nextTime += deltaT;
      }
     
      // If the derivative is decreasing or if the algorithm has gone far
      // enough, conclude that this neuron will not fire.

      if ( deriv < 0 || nextTime > 1.0 )
      {
        fire = false;
       
        return -1.0;
      }
     
      // Otherwise, schedule a preliminary estimate fire event.
     
      else
      {
        fire = false;
       
        return  nextTime + baselineTime - timeOfNextFire;
      }
    }

    @Override
    public double  updateInput (
      final double   time,
      final Synapse  input )
    ////////////////////////////////////////////////////////////////////////
    {
    /*
     * I.   Update spike time dependent plasticity variables.
     * II.  Update the appropriate state variables to the current time.
     * III. Add the current from the input synapse into the appropriate
     *      state variables.
     * IV.  Schedule the next fire event.
     */
     
      // Update STDP variables.
         
      if ( histTable.containsKey ( input ) )
      {
        // If this synapse is already in the table, just update its value.
       
        TimeWeight  tw = histTable.get ( input );
       
        if ( input.getType ( ) == 0 ) // Channel 0 has STDP.
        {
          if ( lastFireTime > tw.time ) // LTD only for close spikes.
          {
            // Update weight if neuron fired before.
           
            tw.weight = tw.weight * ( 1 - para.Alpha_LTD * Math.exp (
              -para.K_LTD * ( time - lastFireTime ) ) );
          }
        }
       
        tw.time = time; // Update the time.
      }
      else // If this synapse did not fire before, add it to the table.
      {
        TimeWeight  tw = new TimeWeight ( time, 1.0 );
       
        if ( input.getType ( ) == 0 ) // Channel 0 has STDP.
        {
          if ( lastFireTime > 0.0 )
          {
            // Update weight if neuron fired before.
           
            tw.weight = tw.weight * ( 1 - para.Alpha_LTD * Math.exp (
              -para.K_LTD * ( time - lastFireTime ) ) );
          }
        }
       
        // Put the default weight into the history table.
       
        histTable.put ( input, tw );
      }

      // Update the running average of mean time intervals in between spikes.
     
      tAvg = tAvg * 0.8 + ( time - lastInputTime ) *0.2;
     
      // Store the time of this spike.
     
      lastInputTime = time;

      // Update the G term to the current time.
     
      sta_p [ State.NG ] [ 0 ].value *= Math.exp (
        -( time - sta_p [ State.NG ] [ 0 ].time )
        * para.allDecays [ State.NG ] [ 0 ]);

      sta_p [ State.NG ] [ 0 ].time = time;
     
      // Update the B term to the current time.
     
      sta_p [ State.NB ] [ 0 ].value *= Math.exp (
        -( time - sta_p [ State.NB ] [ 0 ].time )
        * para.allDecays [ State.NB ] [ 0 ]);
     
      sta_p [ State.NB ] [ 0 ].time = time;

      // Update the appropriate current to the current time.
     
      // If the appropriate current's channel is inactive, initialize its
      // state variable.
     
      int channel = input.getType ( );
     
      if ( sta_p [ State.NJ_SYNAPSE ] [ channel ] == null )
      {  
        sta_p [ State.NJ_SYNAPSE ] [ channel ] =
            new State ( time, 0, State.NJ_SYNAPSE, channel);
       
        state.add ( sta_p [ State.NJ_SYNAPSE ] [ channel ] );
      }
      else
      {
        sta_p [ State.NJ_SYNAPSE ] [ channel ].value *=
                Math.exp ( -( time -
                    sta_p [ State.NJ_SYNAPSE ] [ channel ].time )
            * para.allDecays [ State.NJ_SYNAPSE ] [ channel ] );
       
        sta_p [ State.NJ_SYNAPSE ] [ channel ].time = time;
      }

      // Add the input spike to the G term state variable.
     
      sta_p [ State.NG ] [ 0 ].value +=
          input.getWeight ( ) * ( histTable.get ( input ).weight )
        * para.SYNAPSE_CURRENT_TO_NG [ channel ];
     
      // Add the input spike to the B term state variable.
     
      sta_p [ State.NB ] [ 0 ].value +=
          input.getWeight ( ) * ( histTable.get ( input ).weight )
        * para.SYNAPSE_CURRENT_TO_NB [ channel ];
     
      // Add the input spike to the appropriate current's state variable.
     
      sta_p [ State.NJ_SYNAPSE ] [ 0 ].value
      += input.getWeight ( )
          * histTable.get ( input ).weight
          * para.SYNAPSE_CURRENT_TO_SYNAPSE_NJ [ channel ];
     
      // The time to start looking for the next fire event.

      double baselineTime;

      // The voltage and threshold at baselineTime.
           
      double nowVoltage = membraneVoltage(time);
     
      double nowThreshold = nowVoltage - thresholdDiff(time);
     
      // If this neuron is still within a refractory period, offset the
      // membrane voltage and threshold such that by the end of the
      // refractory period, the voltage will be equal to the clamp voltage.
     
      if ( time < timeOfLastUpdate )
      {

        // Calculate what the voltage and threshold will be
        // at the end of the refractory period.
       
        double futureVoltage = membraneVoltage ( timeOfLastUpdate );
       
        double futureThresholdGap = thresholdDiff  ( timeOfLastUpdate );
       
        double futureThreshold = futureVoltage - futureThresholdGap;
       
        // Calculate the necessary offsets.
       
        double voltageOffset = clampVoltage - futureVoltage;
       
        double thresholdOffset = clampThreshold - futureThreshold;
       
        // Incorporate the offsets into the state variables.
       
        double ngDecay = Math.exp ( ( timeOfLastUpdate - time )
            * para.allDecays [ State.NG ] [ 0 ] );
       
        double nbDecay = Math.exp ( ( timeOfLastUpdate - time )
            * para.allDecays [ State.NB ] [ 0 ] );
       
        sta_p [ State.NG ] [ 0 ].value += para.VOLTAGE_TO_NG * voltageOffset
            * ngDecay;
       
        sta_p [ State.NB ] [ 0 ].value += para.VOLTAGE_TO_NB * voltageOffset
            * nbDecay;
       
        sta_p [ State.NB ] [ 0 ].value += para.THRESHOLD_TO_NB
            * thresholdOffset * nbDecay;       
       
        nowVoltage = clampVoltage;
       
        nowThreshold = clampThreshold;
       
        baselineTime = timeOfLastUpdate;
      }
      else
      {
        clampVoltage = Double.MAX_VALUE;
       
        clampThreshold = Double.MAX_VALUE;
       
        timeOfLastUpdate = time;
       
        baselineTime = time;
      }
     
      // Schedule the next fire event, if one should exist.
     
      // Use a modified Newton-Raphson (NR) root-finding iteration method to
      // find the first zero of the thresholdDiff function.

      double  deltaT, thresholdDiff;
     
      double  deriv = safeDerivative ( baselineTime );
     
      // Ever-increasing estimate of the time until the next fire event.

      double nextTime = ( -( nowVoltage - nowThreshold ) / deriv );
     
      deltaT = nextTime;
     
      // Repeat the NR iteration until either the derivative turns negative
      // or the predicted fire time grows so far away that an input spike
      // is likely to arrive in the intervening time, rendering any further
      // calculations useless (see D'Haene, 2009).
             
      while ( deriv > 0 && nextTime < cost * tAvg )
      {

          // Safe derivative of the thresholdDiff function.

          deriv = safeDerivative ( baselineTime + nextTime );

          // Value of the thresholdDiff function.

          thresholdDiff = thresholdDiff ( baselineTime + nextTime );

          // As per the NR algorithm.

          deltaT = ( -thresholdDiff / deriv );

          // If deltaT is within the proper precision, schedule an actual
          // fire event.

          if ( deltaT > 0 ? deltaT < tEPS : -deltaT < tEPS )
          {
              fire = true;

              return ( time > timeOfLastUpdate
                      ? nextTime : nextTime + timeOfLastUpdate - time );
          }

          nextTime += deltaT;
      }

      // If the derivative is decreasing or if the algorithm has gone far
      // enough, conclude that this neuron will not fire.

      if( deriv < 0 || nextTime > 1.0 )
      {       
          fire = false;

          return -1.0;
      }

      // Otherwise, schedule a preliminary estimate fire event.

      else
      {
          fire = false;

          return ( time > timeOfLastUpdate
                  ? nextTime : nextTime + timeOfLastUpdate - time );
      }
    }


    ////////////////////////////////////////////////////////////////////////
    // accessor methods
    ////////////////////////////////////////////////////////////////////////

    public double  getSensoryWeight ( )
    ////////////////////////////////////////////////////////////////////////
    {
      throw new RuntimeException ( "This neuron type doesn't use the "
        + "Sensory Weight Functions!" );
    }

    public double  getTimeOfLastUpdate ( )
    ////////////////////////////////////////////////////////////////////////
    {
      return this.timeOfLastUpdate;
    }


    ////////////////////////////////////////////////////////////////////////
    // mutator methods
    ////////////////////////////////////////////////////////////////////////
   
    /***********************************************************************
    * set the membrane voltage
    ***********************************************************************/
    public void  setMemV ( double  memV )
    ////////////////////////////////////////////////////////////////////////
    {
      return;
    }
   
    public void  setTimeOfLastUpdate ( final double  timeOfLastUpdate )
    ////////////////////////////////////////////////////////////////////////
    {
      this.timeOfLastUpdate = timeOfLastUpdate;
    }

    ////////////////////////////////////////////////////////////////////////
    // overridden Object methods
    ////////////////////////////////////////////////////////////////////////
   
    @Override
    public String  toString ( )
    ////////////////////////////////////////////////////////////////////////
    {
      String  tmp="";
     
      tmp=tmp+"Current:"+"\n";
     
      return tmp;
    }

    ////////////////////////////////////////////////////////////////////////
    // miscellaneous methods
    ////////////////////////////////////////////////////////////////////////

   
    /***********************************************************************
     * Compute the membrane voltage.
     *
     * This method transforms the thresholdDiff (Mihalas and Niebur, 2009,
     *  Equation 3.5) into the voltage ( V(t) in Equation 3.2).
     *
     * The result is a sum over the state variables, except the NB state,
     * which is skipped. Each state contributes its value divided by the
     * matching VOLTAGE_TO_* factor (NG, spike NJ or synapse NJ), scaled by
     * exp ( -decay * ( t - state.time ) ). MEMBRANE_VOLTAGE_BASE + VREST
     * is added to the sum.
     *
     * The sum is accumulated Horner-style over the iteration order of
     * <code>state</code>:
     * <ul>
     * <li>at each step the running total is rescaled from the previous
     * state's decay to the next state's decay;</li>
     * <li>the last state's decay is applied once at the end.</li>
     * </ul>
     * At the positions where it is checked, a state whose time equals t is
     * added directly to <code>constant</code>, because its decay factor is
     * exp ( 0 ) == 1.
     *
     * NOTE(review): assumes <code>state</code> is non-empty. iter.next()
     * on an empty set would throw NoSuchElementException.
     *
     * @param t
     *            absolute time t
     * @return membrane voltage at time t.
     **********************************************************************/
    public double  membraneVoltage ( final double  t )
    ////////////////////////////////////////////////////////////////////////
    {

      // Horner accumulator for the time-dependent terms.
      double  V = 0;

      // Sum of converted values of terms whose time equals t.
      double  constant = 0;

      Iterator<State>  iter = state.iterator ( );

      State  first = iter.next ( );

      State  second;

      double  firstV, secondV;

      if ( first.type == State.NB ) // Ignore the NB term.
      {
        if ( iter.hasNext ( ) )
        {
          first = iter.next ( );
        }
        else
        {
          return V + constant + para.MEMBRANE_VOLTAGE_BASE + para.VREST;
        }
      }

      if ( first.time == t ) // Ignore no time change term.
        // Its decay factor is exp ( 0 ) == 1, so the converted value is
        // added to constant directly instead of entering the recurrence.
      {
        if ( first.type == State.NG )
        {
          constant += first.value / para.VOLTAGE_TO_NG;
        }
        else if (first.type == State.NJ_SPIKE)
        {
          constant += first.value
              / para.VOLTAGE_TO_SPIKE_NJ [ first.index ];
        }
        else if (first.type == State.NJ_SYNAPSE)
        {
          constant += first.value
              / para.VOLTAGE_TO_SYNAPSE_NJ [ first.index ];
        }

        if ( iter.hasNext ( ) )
        {
          first = iter.next ( );
        }
        else
        {
          return V + constant + para.MEMBRANE_VOLTAGE_BASE + para.VREST;
        }
      }

      // The NB term may follow the same-time term skipped above.
      if ( first.type == State.NB ) // Ignore the NB term.
      {
        if ( iter.hasNext ( ) )
        {
          first = iter.next ( );
        }
        else
        {
          return V + constant + para.MEMBRANE_VOLTAGE_BASE + para.VREST;
        }
      }

      // Seed the accumulator with the first contributing term, converted
      // to a voltage. Any type other than NG or NJ_SPIKE is treated as a
      // synapse NJ channel here.
      if ( first.type == State.NG )
      {
        firstV = first.value / para.VOLTAGE_TO_NG;
      }
      else if (first.type == State.NJ_SPIKE)
      {
        firstV = first.value / para.VOLTAGE_TO_SPIKE_NJ [ first.index ];
      }
      else
      {
        firstV = first.value / para.VOLTAGE_TO_SYNAPSE_NJ [ first.index ];
      }

      V = firstV;

      // Fold each remaining state into the accumulator. Per iteration, at
      // most one same-time term and one NB term are skipped. A state
      // reached after such a skip goes through the general recurrence even
      // if its time also equals t. That gives the same contribution
      // mathematically, since exp ( 0 ) == 1.
      while ( iter.hasNext ( ) )
      {
        second = iter.next ( );

        secondV = 0.0;

        if ( second.time == t && second.type != State.NB )
        {
          if ( second.type == State.NG )
          {
            constant += second.value / para.VOLTAGE_TO_NG;
          }
          else if (second.type == State.NJ_SPIKE)
          {
            constant += second.value
                / para.VOLTAGE_TO_SPIKE_NJ [ second.index ];
          }
          else
          {
            constant += second.value
                / para.VOLTAGE_TO_SYNAPSE_NJ [ second.index ];
          }

          if ( iter.hasNext ( ) )
          {
            second = iter.next ( );
          }
          else
          {
            break;
          }
        }

        if ( second.type == State.NB )
        {
          if ( iter.hasNext ( ) )
          {
            second = iter.next ( );
          }
          else
          {
            break;
          }
        }

        if ( second.type == State.NG )
        {
          secondV = second.value / para.VOLTAGE_TO_NG;
        }
        else if (second.type == State.NJ_SPIKE)
        {
          secondV = second.value
              / para.VOLTAGE_TO_SPIKE_NJ [ second.index ];
        }
        else
        {
          secondV = second.value
              / para.VOLTAGE_TO_SYNAPSE_NJ [ second.index ];
        }

        // Rescale the accumulator by
        // exp ( -d1 * ( t - t1 ) + d2 * ( t - t2 ) ), moving it from
        // first's decay to second's decay, then add second's term.
        V = Math.exp ( -( ( para.allDecays [ first.type ] [ first.index ]
          - para.allDecays [ second.type ] [ second.index ] ) * t
          - ( para.allDecays [ first.type ] [ first.index ] * first.time
          - para.allDecays [ second.type ] [ second.index ] * second.time )
          ) ) * V + secondV;

        first = second;
      }

      // Apply the last state's decay, exp ( -d * ( t - t_last ) ), to the
      // whole accumulator.
      V = Math.exp ( -( para.allDecays [ first.type ] [ first.index ] * t
        - para.allDecays [ first.type ] [ first.index ] * first.time )) * V;

      return V + para.MEMBRANE_VOLTAGE_BASE + para.VREST + constant;
    }
   
    /***********************************************************************
     * Compute the difference between the membrane voltage and the membrane
     * voltage threshold. The <code>MNNeuron</code> class is optimized to
     * perform this function quickly, at the expense of
     * <code>membraneVoltage()</code>.
     *
     * This method is based on Equation 3.5 in Mihalas and Niebur, 2009.
     *
     * @param t
     *            absolute time t
     * @return voltage - threshold at time t
     **********************************************************************/
    public double  thresholdDiff ( double  t )
    ////////////////////////////////////////////////////////////////////////
    {
      double  V = 0, constant = 0.0;
     
      Iterator<State>  iter = state.iterator ( );
     
      State  first = iter.next ( );
      State  second;
     
      double  firstV, secondV;
     
      // If the state is a current channel and its value is too small, delete
      // it.

      if ( Math.abs ( first.value ) < rEPS
          && ( int ) first.type != State.NG
          && ( int ) first.type != State.NB )
      {       
        iter.remove ( );
       
        sta_p [ first.type ] [ first.index ] = null;
      }
     
      // ignore no time change term
     
      if ( first.time == t )
      {
        constant += first.value;
       
        if ( iter.hasNext ( ) )
        {
          first = iter.next ( );
        }
        else
        {
          return V + constant + para.THRESHOLD_DIFF_BASE;
        }
      }
     
      firstV = first.value;
     
      V = firstV;
     
      while ( iter.hasNext ( ) )
      {
        second = iter.next ( );
       
        // If the state is a current channel and its value is too small,
        // delete it.
       
        if ( Math.abs ( second.value ) < rEPS
            && second.type != State.NG
            && second.type != State.NB )
        {
          iter.remove ( );
         
          sta_p [ second.type ] [ second.index ] = null;
        }
       
        secondV = 0.0;
       
        // Ignore no time change term.
       
        if ( second.time == t )
        {
          constant += second.value;
         
          if ( iter.hasNext ( ) )
          {
            second = iter.next ( );
          }
          else
          {
            break;
          }
        }

        secondV = second.value;
       
        V = Math.exp ( -( ( para.allDecays [ first.type ] [ first.index ]
          - para.allDecays [ second.type ] [ second.index ] ) * t
          - ( para.allDecays [ first.type ] [ first.index ] * first.time
          - para.allDecays [ second.type ] [ second.index ] * second.time )
          ) ) * V + secondV;
       
        first = second;
      }

      V = Math.exp ( -( para.allDecays [ first.type ] [ first.index ] * t
        - para.allDecays [ first.type ] [ first.index ] * first.time )) * V;
           
      return V + para.THRESHOLD_DIFF_BASE + constant;
    }

    /***********************************************************************
    * Computes the "safe derivative" (see D'Haene, 2009) of the
    * thresholdDiff function.
    *
    * @param t
    *   absolute time t
    ***********************************************************************/
    public double  safeDerivative ( double  t )
    ////////////////////////////////////////////////////////////////////////
    {
      // The smallest inverse decay of all negative state variables.
     
      double  tauSafe = Double.MAX_VALUE;
     
      // Find the tauSafe.
     
      Iterator<State> iter = state.descendingIterator();
     
      while(iter.hasNext())
      {
        State tmpState = iter.next();

        if ( tmpState.value < 0 )
        {
          tauSafe = para.allDecays [ tmpState.type ] [ tmpState.index];
         
          break;
        }
      }

      double  constant = 0;
     
      double  V = 0;
     
      iter = state.iterator ( );
     
      State  first = iter.next ( );
     
      State  second;
     
      double  firstV, secondV;

      if ( first.time == t )
      {
        // Ignore no time change term.
       
        if ( first.value > 0 )
        {
          final double  maxDecay = Math.min (
            para.allDecays [ first.type ] [ first.index ],
            tauSafe );
         
          constant += -first.value * maxDecay;
        }
        else
        {
          constant +=
              -first.value * para.allDecays [ first.type ] [ first.index ];
        }

        if ( iter.hasNext ( ) )
        {
          first = iter.next ( );
        }
        else
        {
          return V + constant;
        }
      }

      if ( first.value > 0 )
      {
        final double  maxDecay = Math.min (
          para.allDecays [ first.type ] [ first.index ],
          tauSafe );
       
        firstV = -first.value * maxDecay;
      }
      else
      {
        firstV =
            -first.value * para.allDecays [ first.type ] [ first.index ];
      }

      V = firstV;
     
      while ( iter.hasNext ( ) )
      {
        second = iter.next ( );
       
        secondV = 0.0;

        if( second.time == t)
        {
          if ( second.value > 0 )
          {
            final double  maxDecay = Math.min (
              para.allDecays [ second.type ] [ second.index ],
              tauSafe );
           
            constant += -second.value * maxDecay;
          }
          else
          {
            constant
              += -second.value
              * para.allDecays [ second.type ] [ second.index ];
          }

          if ( iter.hasNext ( ) )
          {
            second = iter.next ( );
          }
          else
          {
            break;
          }
        }

        if ( second.value > 0 )
        {
          final double  maxDecay = Math.min (
            para.allDecays [ second.type ] [ second.index ],
            tauSafe );
         
          secondV = -second.value * maxDecay;
        }
        else
        {
          secondV =
              -second.value * para.allDecays [ second.type] [ second.index];
        }

        V = Math.exp ( -( ( para.allDecays [ first.type ] [ first.index]
          - para.allDecays [ second.type ] [ second.index] ) * t
          - ( para.allDecays [ first.type ] [ first.index ] * first.time
          - para.allDecays [ second.type ] [ second.index] * second.time )
          ) ) * V + secondV;
       
        first = second;
      }

      V = Math.exp ( -( para.allDecays [first.type ] [ first.index ] * t
        - para.allDecays [ first.type ] [ first.index ] * first.time )) * V;
     
      return V + constant;
    }

    ////////////////////////////////////////////////////////////////////////
    ////////////////////////////////////////////////////////////////////////
    }
TOP

Related Classes of cnslab.cnsnetwork.MiNiNeuron$State

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.