Package zdenekdrahos.AI

Source Code of zdenekdrahos.AI.NetworkTestUtil

package zdenekdrahos.AI;

import java.util.List;
import static org.junit.Assert.assertEquals;
import zdenekdrahos.AI.FeedForward.INetworkValues;
import zdenekdrahos.AI.NeuralNetwork.Builder.Activations;
import zdenekdrahos.AI.NeuralNetwork.Builder.INetworkBuilder;
import zdenekdrahos.AI.NeuralNetwork.Builder.NetworkBuilder;
import zdenekdrahos.AI.NeuralNetwork.INeuralNetwork;
import zdenekdrahos.AI.NeuralNetwork.Layers.ILayer;
import zdenekdrahos.AI.NeuralNetwork.Weights.IWeight;

public class NetworkTestUtil {

    private static int[] topology = {1, 4, 1};
    private static Activations[] activations = {Activations.LIN, Activations.TANH, Activations.LIN};
    private static double[][][] weights = {
        {
            {0.1, -0.1},
            {-0.2, 0.2},
            {0.15, -0.15},
            {-0.05, 0.05}
        },
        {
            {0.1, 0.01, 0.15, 0.21, -0.50}
        }
    };

    public static INeuralNetwork getNetwork() {
        INetworkBuilder builder = new NetworkBuilder();
        INeuralNetwork network = builder.build(topology, activations);
        network.setWeights(weights);
        return network;
    }

    public static double[][][] getWeights() {
        return weights;
    }

    public static void assertNetworkValues(INeuralNetwork network, INetworkValues values, double[][] expectedWeights) {
        List<Double> temp;
        for (int i = 1; i < network.getLayersCount(); i++) {
            temp = values.getLayerValues(i);
            for (int j = 0; j < temp.size(); j++) {
                assertEquals(temp.get(j), expectedWeights[i - 1][j], 0.1);
            }
        }
    }

    public static void assertNetworkWeights(INeuralNetwork network, double[][][] expectedWeights) {
        ILayer layer;
        List<IWeight> currentWeights;
        int previousNeuronIndex;
        for (int layerIndex = 1; layerIndex < network.getLayersCount(); layerIndex++) {
            layer = network.getLayer(layerIndex);
            for (int neuronIndex = 0; neuronIndex < layer.getNeuronsCount(); neuronIndex++) {
                currentWeights = network.getWeights().getNeuronWeights(layerIndex, neuronIndex);
                previousNeuronIndex = 0;
                for (IWeight weight : currentWeights) {
                    assertEquals(weight.getWeight(), expectedWeights[layerIndex - 1][neuronIndex][previousNeuronIndex++], 0.1);
                }
            }
        }
    }
}
TOP

Related Classes of zdenekdrahos.AI.NetworkTestUtil

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., which is now owned by Oracle. Contact coftware#gmail.com.