// Regression check: a single-conv-layer network is trained with backpropagation
// from fully deterministic initial weights, and the resulting conv weights and
// bias element are compared against fixed expected values (tolerance 1e-5).

// NOTE(review): presumably makes all connection weights live in one shared backing
// array. That would explain why elements are addressed below through
// getStartIndex()/iterator indices rather than from 0 — confirm in Environment.
Environment.getInstance().setUseWeightsSharedMemory(true);
// Layer spec: the input rows in L13 have 18 values (= 3*3*2) and the targets have 4
// values, so {3, 3, 2} presumably describes a 3x3x2 input, and {2, 2, 1, 1}
// presumably a conv layer with a 2x2 kernel, 1 filter, stride 1 (2x2 output).
// TODO confirm against NNFactory.convNN; the trailing `true` presumably adds bias.
NeuralNetworkImpl nn = NNFactory.convNN(new int[][] { { 3, 3, 2 }, { 2, 2, 1, 1 } }, true);
// Use sigmoid activations for the layers of this network.
nn.setLayerCalculator(NNFactory.lcSigmoid(nn, null));
// The first connection leaving the input layer is the convolution under test.
Conv2DConnection c = (Conv2DConnection) nn.getInputLayer().getConnections().get(0);
// Seed the conv weights with 0.1, 0.2, 0.3, ... in iterator order.
// The expected values asserted below depend on this exact float accumulation
// (x += 0.1f accrues float rounding), so do not rewrite it as 0.1f * (i + 1).
TensorIterator it = c.getWeights().iterator();
float x = 0.1f;
while (it.hasNext()) {
c.getWeights().getElements()[it.next()] = x;
x += 0.1f;
}
// Connection index 1 of the output layer; presumably the bias connection
// (index 0 being the conv connection `c`) — confirm. Its first element is set to -3.
Conv2DConnection b = (Conv2DConnection) nn.getOutputLayer().getConnections().get(1);
b.getWeights().getElements()[b.getWeights().getStartIndex()] = -3f;
// Two identical training samples: 18 inputs each (0.1 .. 1.8), all 4 targets = 1.
SimpleInputProvider ts = new SimpleInputProvider(new float[][] { { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f }, { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f } }, new float[][] { { 1, 1, 1, 1 }, { 1, 1, 1, 1 } });
// NOTE(review): presumably 0.5f is the learning rate, the four 0f values are
// momentum/regularization terms, and the trailing 1, 1, 1 are batch-size/epoch
// settings — confirm against TrainerFactory.backPropagation's signature.
BackPropagationTrainer<?> t = TrainerFactory.backPropagation(nn, ts, null, null, null, 0.5f, 0f, 0f, 0f, 0f, 1, 1, 1);
t.train();
// Re-walk the conv weights in the same iterator order used for seeding.
// Each float element is widened to double for assertEquals(expected, actual, delta).
it = c.getWeights().iterator();
assertEquals(0.12317, c.getWeights().getElements()[it.next()], 0.00001);
assertEquals(0.23533, c.getWeights().getElements()[it.next()], 0.00001);
assertEquals(0.35966, c.getWeights().getElements()[it.next()], 0.00001);
assertEquals(0.47182, c.getWeights().getElements()[it.next()], 0.00001);
assertEquals(0.63263, c.getWeights().getElements()[it.next()], 0.00001);
assertEquals(0.74479, c.getWeights().getElements()[it.next()], 0.00001);
assertEquals(0.86911, c.getWeights().getElements()[it.next()], 0.00001);
assertEquals(0.98127, c.getWeights().getElements()[it.next()], 0.00001);
// The element seeded to -3 must have moved to the expected post-training value.
assertEquals(-2.87839, b.getWeights().getElements()[b.getWeights().getStartIndex()], 0.00001);
}