// error2 = getError(reference, negativeOptimal);
// assertTrue(error2 > 1e-10 && error2 < 3.5e-4 && error2 < error);
// }
public void testAddBias2D() throws StructuralException, SimulationException {
Network network = new NetworkImpl();
FunctionInput input = new FunctionInput("input", new Function[]{new IdentityFunction(1, 0)}, Units.UNK);
network.addNode(input);
NEFEnsembleFactory ef = new NEFEnsembleFactoryImpl();
int n = 300;
NEFEnsemble pre = ef.make("pre", n, 2);
pre.addDecodedTermination("input", MU.uniform(2, 1, 1), .005f, false);
network.addNode(pre);
network.addProjection(input.getOrigin(FunctionInput.ORIGIN_NAME), pre.getTermination("input"));
NEFEnsemble post = ef.make("post", n, 2);
network.addNode(post);
post.addDecodedTermination("input", MU.I(2), .01f, false);
Projection p = network.addProjection(pre.getOrigin(NEFEnsemble.X), post.getTermination("input"));
DecodedOrigin o = (DecodedOrigin) pre.getOrigin(NEFEnsemble.X);
DecodedTermination t = (DecodedTermination) post.getTermination("input");
float[][] directWeights = MU.prod(post.getEncoders(), MU.prod(t.getTransform(), MU.transpose(o.getDecoders())));
System.out.println("Direct weights: " + MU.min(directWeights) + " to " + MU.max(directWeights));
Probe probe = network.getSimulator().addProbe(post.getName(), NEFEnsemble.X, true);
network.setMode(SimulationMode.CONSTANT_RATE);
network.run(-1.5f, 1);
network.setMode(SimulationMode.DEFAULT);
float[] reference = MU.transpose(DataUtils.filter(probe.getData(), .01f).getValues())[0];
network.run(-1.5f, 1);
// Plotter.plot(probe.getData(), "mixed weights");
float[] mixed = MU.transpose(DataUtils.filter(probe.getData(), .01f).getValues())[0];
getError(reference, mixed);
p.addBias(300, .005f, .01f, true, false);
BiasOrigin bo = (BiasOrigin) pre.getOrigin("post_input");
BiasTermination bt = (BiasTermination) post.getTermination("input (bias)");
assertTrue(MU.min(getNetWeights(directWeights, bo, bt)) > -1e-10);
network.run(-1.5f, 1);
// Plotter.plot(probe.getData(), "positive non-optimal");
// float[] positiveNonOptimal = MU.transpose(DataUtils.filter(probe.getData(), .01f).getValues())[0];
// float error = getError(reference, positiveNonOptimal);
// assertTrue(error > 1e-10 && error < 5e-4);
p.removeBias();
p.addBias(300, .005f, .01f, true, true);
bo = (BiasOrigin) pre.getOrigin("post_input");
bt = (BiasTermination) post.getTermination("input (bias)");
assertTrue(MU.min(getNetWeights(directWeights, bo, bt)) > -1e-10);
network.run(-1.5f, 1);
// Plotter.plot(probe.getData(), "positive optimal");
// float[] positiveOptimal = MU.transpose(DataUtils.filter(probe.getData(), .01f).getValues())[0];
// float error2 = getError(reference, positiveOptimal);
// assertTrue(error2 > 1e-10 && error2 < 2.5e-4 && error2 < error);
p.removeBias();