// Scenario 1: regular weights come from a Mersenne Twister range of [-0.1, 0.1],
// and the second constructor argument (0.5f) is what the bias assertions below expect.
NNRandomInitializer rand = new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.1f, 0.1f), 0.5f);
rand.initialize(nn);
for (Layer layerUnderTest : nn.getLayers()) {
    boolean biasLayer = Util.isBias(layerUnderTest);
    // Only the layer's first connection object is inspected.
    GraphConnections firstConn = (GraphConnections) layerUnderTest.getConnections().get(0);
    if (!biasLayer) {
        // Non-bias weights must fall inside the random range and must not be left at 0.
        for (float element : firstConn.getConnectionGraph().getElements()) {
            assertTrue(element >= -0.1f && element <= 0.1f && element != 0);
        }
        continue;
    }
    // Every bias weight is expected to equal exactly 0.5.
    for (float element : firstConn.getConnectionGraph().getElements()) {
        assertEquals(0.5, element, 0f);
    }
}
rand = new NNRandomInitializer(new MersenneTwisterRandomInitializer(2f, 3f), new MersenneTwisterRandomInitializer(-2f, -1f));
rand.initialize(nn);
for (Layer l : nn.getLayers()) {
if (Util.isBias(l)) {
GraphConnections gc = (GraphConnections) l.getConnections().get(0);
for (float v : gc.getConnectionGraph().getElements()) {
assertTrue(v >= -2f && v <= -1f);
}
} else {
GraphConnections gc = (GraphConnections) l.getConnections().get(0);
for (float v : gc.getConnectionGraph().getElements()) {
assertTrue(v >= 2f && v <= 3f);
}
}
}
}