Package ca.nengo.model.nef.impl

Source Code of ca.nengo.model.nef.impl.NEFEnsembleImplTest

/*
* Created on 24-Apr-07
*/
package ca.nengo.model.nef.impl;

import ca.nengo.math.Function;
import ca.nengo.math.impl.AbstractFunction;
//import ca.nengo.math.impl.ConstantFunction;
import ca.nengo.model.Network;
import ca.nengo.model.Node;
import ca.nengo.model.Projection;
import ca.nengo.model.SimulationException;
import ca.nengo.model.StructuralException;
import ca.nengo.model.Units;
import ca.nengo.model.impl.FunctionInput;
import ca.nengo.model.impl.NetworkImpl;
import ca.nengo.model.nef.NEFEnsemble;
import ca.nengo.model.nef.NEFEnsembleFactory;
import ca.nengo.model.nef.impl.BiasOrigin;
//import ca.nengo.model.nef.impl.DecodedOrigin;
//import ca.nengo.model.nef.impl.DecodedTermination;
import ca.nengo.model.nef.impl.NEFEnsembleFactoryImpl;
import ca.nengo.model.neuron.impl.SpikingNeuron;
import ca.nengo.plot.Plotter;
import ca.nengo.util.MU;
import ca.nengo.util.Probe;
import ca.nengo.util.TimeSeries;
import ca.nengo.util.impl.TimeSeriesImpl;
import junit.framework.TestCase;

/**
* Unit tests for NEFEnsembleImpl.
*
* TODO: this is a functional test with no failures ... convert to unit test
* TODO: make sure performance optimization works with inhibitory projections
*
* @author Bryan Tripp
*/
public class NEFEnsembleImplTest extends TestCase {

  /**
   * Fixture set-up hook. It only delegates to the superclass implementation and
   * performs no additional initialisation of its own.
   *
   * @throws Exception if the superclass set-up throws
   */
  protected void setUp() throws Exception {
    super.setUp();
  }

  public void functionalTestAddBiasOrigin() throws StructuralException, SimulationException {
    NEFEnsembleFactory ef = new NEFEnsembleFactoryImpl();

    boolean regenerate = false;
    NEFEnsemble source = ef.make("source", 300, 1, "nefeTest_source", regenerate);
    NEFEnsemble dest = ef.make("dest", 300, 1, "nefeTest_dest", regenerate);
   
    Function f = new AbstractFunction(1) {
      private static final long serialVersionUID = 1L;
      public float map(float[] from) {
        return from[0] - 1;
      }
    };
    FunctionInput input = new FunctionInput("input", new Function[]{f}, Units.UNK);
//    FunctionInput zero = new FunctionInput("zero", new Function[]{new ConstantFunction(1, 0f)}, Units.UNK);
   
    Network network = new NetworkImpl();
    network.addNode(input);
    network.addNode(source);
    network.addNode(dest);

    source.addDecodedTermination("input", MU.I(1), .005f, false); //OK
    BiasOrigin bo = source.addBiasOrigin(source.getOrigin(NEFEnsemble.X), 200, "interneurons", true); //should have -ve bias decoders
    network.addNode(bo.getInterneurons()); //should be backwards response functions
//**    bo.getInterneurons().addDecodedTermination("source", MU.I(1), .005f, false);
   
//    Plotter.plot(bo.getInterneurons());
//    Plotter.plot(bo.getInterneurons(), NEFEnsemble.X);
   
//    DecodedTermination t = (DecodedTermination) dest.addDecodedTermination("source", MU.I(1), .005f, false);
//**    BiasTermination[] bt = dest.addBiasTerminations(t, .002f, bo.getDecoders()[0][0], ((DecodedOrigin) source.getOrigin(NEFEnsemble.X)).getDecoders());
//**    bt[1].setStaticBias(-1); //creates intrinsic current needed to counteract interneuron activity at 0
   
//    float[][] weights = MU.prod(dest.getEncoders(), MU.transpose(((DecodedOrigin) source.getOrigin(NEFEnsemble.X)).getDecoders()));
//*    float[][] biasEncoders = MU.transpose(new float[][]{bt[0].getBiasEncoders()});
//*    float[][] biasDecoders = MU.transpose(bo.getDecoders());
//*    float[][] weightBiases = MU.prod(biasEncoders, biasDecoders);
//*    float[][] biasedWeights = MU.sum(weights, weightBiases);
//    Plotter.plot(weights[0], "some weights");
//    Plotter.plot(biasedWeights[0], "some biased weights");
//    Plotter.plot(weights[1], "some more weights");
   
//    Plotter.plot(bt[0].getBiasEncoders(), "bias decoders");
   
    network.addProjection(input.getOrigin(FunctionInput.ORIGIN_NAME), source.getTermination("input"));
    network.addProjection(source.getOrigin(NEFEnsemble.X), dest.getTermination("source"));
//*    network.addProjection(bo, bo.getInterneurons().getTermination("source"));
//*    network.addProjection(bo, bt[0]);
//*    network.addProjection(bo.getInterneurons().getOrigin(NEFEnsemble.X), bt[1]);
//    network.addProjection(zero.getOrigin(FunctionInput.ORIGIN_NAME), bt[1]);
   
//    Probe sourceProbe = network.getSimulator().addProbe("source", NEFEnsemble.X, true);
//    Probe destProbe = network.getSimulator().addProbe("dest", NEFEnsemble.X, true);
//    Probe interProbe = network.getSimulator().addProbe("source_X_bias_interneurons", NEFEnsemble.X, true);
   
    network.run(0, 2);
   
//    Plotter.plot(sourceProbe.getData(), "source");
//    Plotter.plot(destProbe.getData(), "dest");
//    Plotter.plot(interProbe.getData(), "interneurons");
  }
 
  public void functionalTestBiasOriginError() throws StructuralException, SimulationException {
    float tauPSC = .01f;
   
    Network network = new NetworkImpl();
   
    Function f = new AbstractFunction(1) {
      private static final long serialVersionUID = 1L;
      public float map(float[] from) {
        return from[0] - 1;
      }
    };
   
    FunctionInput input = new FunctionInput("input", new Function[]{f}, Units.UNK);
    network.addNode(input);
   
    NEFEnsembleFactory ef = new NEFEnsembleFactoryImpl();
    NEFEnsemble pre = ef.make("pre", 400, 1, "nefe_pre", false);
    pre.addDecodedTermination("input", MU.I(1), tauPSC, false);
//    DecodedOrigin baseOrigin = (DecodedOrigin) pre.getOrigin(NEFEnsemble.X);
    network.addNode(pre);
   
    NEFEnsemble post = ef.make("post", 200, 1, "nefe_post", false);
//    DecodedTermination baseTermination = (DecodedTermination) post.addDecodedTermination("pre", MU.I(1), tauPSC, false);
    network.addNode(post);
   
    network.addProjection(input.getOrigin(FunctionInput.ORIGIN_NAME), pre.getTermination("input"));
    Projection projection = network.addProjection(pre.getOrigin(NEFEnsemble.X), post.getTermination("pre"));
   
    Probe pPost = network.getSimulator().addProbe("post", NEFEnsemble.X, true);
    network.run(0, 2);
    TimeSeries ideal = pPost.getData();
    Plotter.plot(pPost.getData(), .005f, "mixed weights result");   
   
    //remove negative weights ...
    System.out.println("Minimum weight without bias: " + MU.min(projection.getWeights()));
    projection.addBias(100, .005f, tauPSC, true, false);
    System.out.println("Minimum weight with bias: " + MU.min(projection.getWeights()));
    pPost.reset();
    network.run(0, 2);
    TimeSeries diff = new TimeSeriesImpl(ideal.getTimes(), MU.difference(ideal.getValues(), pPost.getData().getValues()), ideal.getUnits());
    Plotter.plot(diff, .01f, "positive weights");
   
    projection.removeBias();
    projection.addBias(100, tauPSC/5f, tauPSC, true, true);
    pPost.reset();
    Probe pInter = network.getSimulator().addProbe("post:pre:interneurons", NEFEnsemble.X, true);
    network.run(0, 2);
    diff = new TimeSeriesImpl(ideal.getTimes(), MU.difference(ideal.getValues(), pPost.getData().getValues()), ideal.getUnits());
    Plotter.plot(diff, .01f, "positive weights optimized");
    Plotter.plot(pInter.getData(), .01f, "interneurons");

   
   
//    //remove negative weights ...
//    BiasOrigin bo = pre.addBiasOrigin(baseOrigin, 100, "interneurons", true);
//    BiasTermination[] bt = post.addBiasTerminations(baseTermination, tauPSC, bo.getDecoders()[0][0], baseOrigin.getDecoders());
//    DecodedTermination it = (DecodedTermination) bo.getInterneurons().addDecodedTermination("bias", MU.I(1), tauPSC/5f, false);
//    network.addNode(bo.getInterneurons());
//    network.addProjection(bo, bt[0]);
//    network.addProjection(bo, bo.getInterneurons().getTermination("bias"));
//    network.addProjection(bo.getInterneurons().getOrigin(NEFEnsemble.X), bt[1]);
//    Plotter.plot(MU.transpose(bo.getDecoders())[0], "bias decoders");
////    Plotter.plot(bo.getInterneurons(), NEFEnsemble.X);
//   
//    pPost.reset();
//    network.run(0, 2);
//    TimeSeries diff = new TimeSeriesImpl(ideal.getTimes(), MU.difference(ideal.getValues(), pPost.getData().getValues()), ideal.getUnits());
//    Plotter.plot(diff, .01f, "positive weights");
////    Plotter.plot(ideal, pPost.getData(), .005f, "positive weights result");
//   
//    //narrow bias range ...
////    Plotter.plot(pre, bo.getName());
//    float[][] baseWeights = MU.prod(post.getEncoders(), MU.prod(baseTermination.getTransform(), MU.transpose(baseOrigin.getDecoders())));
//    float[] encodersBeforeTweak = findBiasEncoders(baseWeights, MU.transpose(bo.getDecoders())[0]);
//    bo.optimizeDecoders(baseWeights, bt[0].getBiasEncoders());
////    Plotter.plot(pre, bo.getName());
//    float[] encodersAfterTweak = findBiasEncoders(baseWeights, MU.transpose(bo.getDecoders())[0]);
//    TestUtil.assertClose(MU.sum(MU.difference(encodersBeforeTweak, encodersAfterTweak)), 0, .0001f);
//    Plotter.plot(MU.transpose(bo.getDecoders())[0], "narrow bias decoders");
//   
//    pPost.reset();   
//    network.run(0, 2);
//    diff = new TimeSeriesImpl(ideal.getTimes(), MU.difference(ideal.getValues(), pPost.getData().getValues()), ideal.getUnits());
//    Plotter.plot(diff, .01f, "narrowed bias");    
////    Plotter.plot(ideal, pPost.getData(), .005f, "narrowed bias result");
//   
//    //optimize interneuron range ...
//    float[] range = bo.getRange();
//    System.out.println(range[0] + " to " + range[1]);
//    range[0] = range[0] - .25f * (range[1] - range[0]); //avoid distorted area near zero in interneurons
//    it.setStaticBias(new float[]{-range[0]});
//    it.getTransform()[0][0] = 1f / (range[1] - range[0]);
//    bt[1].setStaticBias(new float[]{range[0]/(range[1] - range[0])});
//    bt[1].getTransform()[0][0] = -(range[1] - range[0]);   
//   
//    pPost.reset();
//    network.run(0, 2);
//    diff = new TimeSeriesImpl(ideal.getTimes(), MU.difference(ideal.getValues(), pPost.getData().getValues()), ideal.getUnits());
//    Plotter.plot(diff, .01f, "optimized interneuron");        
////    Plotter.plot(ideal, pPost.getData(), .005f, "optimized interneuron result");
//   
////    Probe pBias = network.getSimulator().addProbe("pre", bo.getName(), true);
////    Probe pInter = network.getSimulator().addProbe(bo.getInterneurons().getName(), NEFEnsemble.X, true);
////    Probe pBT0 = network.getSimulator().addProbe("post", bt[0].getName(), true);
////    Probe pBT1 = network.getSimulator().addProbe("post", bt[1].getName(), true);
////    Probe pT = network.getSimulator().addProbe("post", "pre", true);
////   
////    network.run(0, 2);
////    Plotter.plot(pPost.getData(), .005f, "post");
////    Plotter.plot(pBias.getData(), .005f, "bias");
////    Plotter.plot(pInter.getData(), .005f, "interneurons");
////    Plotter.plot(pBT0.getData(), .005f, "BT0");
////    Plotter.plot(pBT1.getData(), .005f, "BT1");
////    Plotter.plot(pT.getData(), .005f, "base termination");
  }
 
//  private float[] findBiasEncoders(float[][] baseWeights, float[] biasDecoders) {
//    float[] biasEncoders = new float[baseWeights.length];
//   
//    for (int j = 0; j < biasEncoders.length; j++) {
//      float max = 0;
//      for (int i = 0; i < biasDecoders.length; i++) {
//        float x = - baseWeights[j][i] / biasDecoders[i];
//        if (x > max) max = x;
//      }     
//      biasEncoders[j] = max;
//    }
//   
//    return biasEncoders;
//  }
 
  public void testClone() throws StructuralException, CloneNotSupportedException {
    NEFEnsembleFactory ef = new NEFEnsembleFactoryImpl();
    NEFEnsemble ensemble = ef.make("test", 100, 1);
    long startTime = System.currentTimeMillis();
    ensemble.clone();
    System.out.println(System.currentTimeMillis() - startTime);
  }
 
  public static void main(String[] args) {
    NEFEnsembleImplTest test = new NEFEnsembleImplTest();
    try {
//      test.testAddBiasOrigin();
//      test.functionalTestBiasOriginError();
      test.testClone();
    } catch (StructuralException e) {
      e.printStackTrace();
//    } catch (SimulationException e) {
//      e.printStackTrace();
    } catch (CloneNotSupportedException e) {
      e.printStackTrace();
    }
  }
 
  public void testKillNeurons() throws StructuralException
  {
    NEFEnsembleFactoryImpl ef = new NEFEnsembleFactoryImpl();
    NEFEnsembleImpl nef1 = (NEFEnsembleImpl)ef.make("nef1", 1000, 1);
   
    nef1.killNeurons(0.0f,true);
    int numDead = countDeadNeurons(nef1);
    if(numDead != 0) {
      fail("Number of dead neurons outside expected range");
    }
   
    nef1.killNeurons(0.5f,true);
    numDead = countDeadNeurons(nef1);
    if(numDead < 400 || numDead > 600) {
      fail("Number of dead neurons outside expected range");
    }
   
    nef1.killNeurons(1.0f,true);
    numDead = countDeadNeurons(nef1);
    if(numDead != 1000) {
      fail("Number of dead neurons outside expected range");
    }
   
    NEFEnsembleImpl nef2 = (NEFEnsembleImpl)ef.make("nef2", 1, 1);
    nef2.killNeurons(1.0f,true);
    numDead = countDeadNeurons(nef2);
    if(numDead != 0)
      fail("Relay protection did not work");
    nef2.killNeurons(1.0f,false);
    numDead = countDeadNeurons(nef2);
    if(numDead != 1)
      fail("Number of dead neurons outside expected range");

  }
  private int countDeadNeurons(NEFEnsembleImpl pop)
  {
    Node[] neurons = pop.getNodes();
    int numDead = 0;
   
    for(int i = 0; i < neurons.length; i++)
    {
      SpikingNeuron n = (SpikingNeuron)neurons[i];
      if(n.getBias() == 0.0f && n.getScale() == 0.0f)
        numDead++;
    }
   
    return numDead;
  }
 
  public void testAddDecodedSignalOrigin() throws StructuralException
  {
    NEFEnsembleFactoryImpl ef = new NEFEnsembleFactoryImpl();
    NEFEnsembleImpl ensemble = (NEFEnsembleImpl)ef.make("test", 5, 1);
    float[][] vals = new float[2][1];
    vals[0][0] = 1;
    vals[1][0] = 1;
    TimeSeriesImpl targetSignal = new TimeSeriesImpl(new float[]{0,1}, vals, new Units[]{Units.UNK});
    TimeSeriesImpl[] evalSignals = new TimeSeriesImpl[1];
   
    //test the per-dimension eval signals
    evalSignals[0] = new TimeSeriesImpl(new float[]{0,1}, vals, new Units[]{Units.UNK});
    ensemble.addDecodedSignalOrigin("test1", targetSignal, evalSignals, "AXON");
    if(ensemble.getOrigin("test1") == null)
      fail("Error creating per-dimension signal origin");
   
    //test the per-node eval signals
    vals[0] = new float[]{1, 1, 1, 1, 1};
    vals[1] = new float[]{1, 1, 1, 1, 1};
    evalSignals[0] = new TimeSeriesImpl(new float[]{0,1}, vals, new Units[]{Units.UNK,Units.UNK,Units.UNK,Units.UNK,Units.UNK});
    ensemble.addDecodedSignalOrigin("test2", targetSignal, evalSignals, "AXON");
    if(ensemble.getOrigin("test2") == null)
      fail("Error creating per-node signal origin");
  }

}
TOP

Related Classes of ca.nengo.model.nef.impl.NEFEnsembleImplTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.