// Set the second bias value. NOTE(review): bcg is presumably the weight matrix of the bias
// connection `bc`, with 0.1 already set at (0, 0) above this chunk. The expected outputs
// below differ from the unbiased case by +0.1 for output row 0 and +0.2 for output row 1.
bcg.set(0.2f, 1, 0);

// Scenario 1: a single weighted connection (c1) feeding the output layer, with no bias.
List<Connections> connections = new ArrayList<>();
connections.add(c1);
NeuralNetworkImpl nn = new NeuralNetworkImpl();
nn.addConnections(connections.toArray(new Connections[connections.size()]));
// Tensor storage for every layer of nn; the second argument is presumably the mini-batch
// size (2), matching the two input columns filled below. TODO confirm.
ValuesProvider vp = TensorFactory.tensorProvider(nn, 2, true);
Matrix i1 = vp.get(nn.getInputLayer());
// Input values are written as set(value, row, column): column 0 holds 1,2,3 and column 1
// holds 4,5,6.
i1.set(1, 0, 0);
i1.set(2, 1, 0);
i1.set(3, 2, 0);
i1.set(4, 0, 1);
i1.set(5, 1, 1);
i1.set(6, 2, 1);
ConnectionCalculatorFullyConnected aws = new AparapiWeightedSumConnectionCalculator();
// Compute the weighted sum into the output layer `ol` (declared above this chunk).
aws.calculate(connections, vp, ol);
// Check the plain weighted sum. The expected values are consistent with c1's weights being
// rows [1,2,3] and [4,5,6]; that setup is not visible here — verify against the code above.
Matrix o = vp.get(nn.getOutputLayer());
assertEquals(14, o.get(0, 0), 0);
assertEquals(32, o.get(0, 1), 0);
assertEquals(32, o.get(1, 0), 0);
assertEquals(77, o.get(1, 1), 0);

// Scenario 2: the same weighted connection plus the bias connection `bc`.
connections = new ArrayList<>();
connections.add(c1);
connections.add(bc);
nn = new NeuralNetworkImpl();
nn.addConnections(connections.toArray(new Connections[connections.size()]));
vp = TensorFactory.tensorProvider(nn, 2, true);
i1 = vp.get(nn.getInputLayer());
// Same input values as scenario 1.
i1.set(1, 0, 0);
i1.set(2, 1, 0);
i1.set(3, 2, 0);
i1.set(4, 0, 1);
i1.set(5, 1, 1);
i1.set(6, 2, 1);
aws = new AparapiWeightedSumConnectionCalculator();
aws.calculate(connections, vp, ol);
o = vp.get(nn.getOutputLayer());
// Each output row is the scenario-1 value plus that row's bias (0.1 / 0.2). A small
// delta is used because the bias values are float literals.
assertEquals(14.1, o.get(0, 0), 0.01);
assertEquals(32.1, o.get(0, 1), 0.01);
assertEquals(32.2, o.get(1, 0), 0.01);
assertEquals(77.2, o.get(1, 1), 0.01);

// Scenario 3: two weighted input connections (c1 and c2) plus the bias, all summed into
// the same output layer.
connections = new ArrayList<>();
connections.add(c1);
connections.add(c2);
connections.add(bc);
nn = new NeuralNetworkImpl();
nn.addConnections(connections.toArray(new Connections[connections.size()]));
vp = TensorFactory.tensorProvider(nn, 2, true);
// The input layers are fetched directly as il1 / il2 (declared above this chunk) instead
// of via nn.getInputLayer(). NOTE(review): presumably because the network now has two
// input layers.
i1 = vp.get(il1);
i1.set(1, 0, 0);
i1.set(2, 1, 0);
i1.set(3, 2, 0);
i1.set(4, 0, 1);
i1.set(5, 1, 1);
i1.set(6, 2, 1);
// The second input receives exactly the same values as the first.
Matrix i2 = vp.get(il2);
i2.set(1, 0, 0);
i2.set(2, 1, 0);
i2.set(3, 2, 0);
i2.set(4, 0, 1);
i2.set(5, 1, 1);
i2.set(6, 2, 1);
aws = new AparapiWeightedSumConnectionCalculator();
aws.calculate(connections, vp, ol);
o = vp.get(nn.getOutputLayer());
// The expected values equal twice the scenario-1 sums plus the bias. This is consistent
// with c2 carrying the same weights as c1 (not visible here — verify).
assertEquals(28.1, o.get(0, 0), 0.01);
assertEquals(64.1, o.get(0, 1), 0.01);
assertEquals(64.2, o.get(1, 0), 0.01);
assertEquals(154.2, o.get(1, 1), 0.01);
}