package zdenekdrahos.AI.BackPropagation;
import org.junit.Test;
import zdenekdrahos.AI.NetworkTestUtil;
import zdenekdrahos.AI.FeedForward.FeedForward;
import zdenekdrahos.AI.FeedForward.IFeedForward;
import zdenekdrahos.AI.FeedForward.INetworkValues;
import zdenekdrahos.AI.NeuralNetwork.INeuralNetwork;
/**
 * Tests {@link IBackPropagation} by performing one training step with
 * {@link BackPropagation} on the fixture network supplied by
 * {@link NetworkTestUtil#getNetwork()}, then checking every resulting weight.
 */
public class IBackPropagationTest {

    /** Learning rate configured on the trainer before the training step. */
    private static final double LEARNING_RATE = 0.75;

    /** Momentum configured on the trainer before the training step. */
    private static final double MOMENTUM = 0.9;

    /** Sample fed forward through the network before training. */
    private final double[] input = {0};

    /** Target output passed to the trainer for the sample above. */
    private final double[] output = {1};

    /** Trainer under test. */
    private final IBackPropagation training = new BackPropagation();

    // Declared after `training` on purpose: field initializers run in textual order,
    // so the trainer is still constructed before the fixture network is obtained.
    private final INeuralNetwork network = NetworkTestUtil.getNetwork();

    /**
     * Runs a single back-propagation step and verifies the updated weights.
     * The expected values are grouped per layer; presumably indexed as
     * [layer][neuron][incoming weight] — confirm against
     * {@code NetworkTestUtil.assertNetworkWeights}.
     */
    @Test
    public void testTrainNetwork() {
        final double[][][] expectedWeights = {
            {
                {0.106478, -0.1},
                {-0.105682, 0.2},
                {0.284352, -0.15},
                {-0.376320, 0.05}
            },
            {
                {0.754273, 0.075210, 0.020863, 0.307412, -0.532686}
            }
        };

        // Arrange: the trainer needs the values produced by a forward pass.
        final INetworkValues forwardValues = runForwardPass();

        // Act: configure the trainer and perform one training step.
        training.setLearningRate(LEARNING_RATE);
        training.setMomentum(MOMENTUM);
        training.trainNetwork(network, forwardValues, output);

        // Assert: the network's weights now match the precomputed expectations.
        NetworkTestUtil.assertNetworkWeights(network, expectedWeights);
    }

    /**
     * Feeds {@link #input} through {@link #network} with a fresh {@link FeedForward}.
     *
     * @return the values built by the forward pass, used as the trainer's input
     */
    private INetworkValues runForwardPass() {
        final IFeedForward forward = new FeedForward();
        return forward.buildNetwork(network, input);
    }
}