Package com.heatonresearch.aifh.learning

Source Code of com.heatonresearch.aifh.learning.TestRBFNetwork

/*
* Artificial Intelligence for Humans
* Volume 1: Fundamental Algorithms
* Java Version
* http://www.aifh.org
* http://www.jeffheaton.com
*
* Code repository:
* https://github.com/jeffheaton/aifh

* Copyright 2013 by Jeff Heaton
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/

package com.heatonresearch.aifh.learning;

import com.heatonresearch.aifh.AIFH;
import com.heatonresearch.aifh.randomize.BasicGenerateRandom;
import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
* Test the RBF network.
*/
public class TestRBFNetwork {

    @Test
    public void testBasics() {
        final RBFNetwork network = new RBFNetwork(2, 1, 1);

        // should be 7, (2*1) + (1+(1 bias))*1 + 3 RBF params
        // 2 + 2 + 3 = 7
        assertEquals(7, network.getLongTermMemory().length);

        assertEquals("[RBFNetwork:inputCount=2,outputCount=1,RBFs=[[GaussianFunction:width=0.00,center=0.00,0.00]]]", network.toString());

    }

    @Test
    public void testResetCompute() {
        final RBFNetwork network = new RBFNetwork(2, 1, 1);
        double total = 0;
        for (int i = 0; i < network.getLongTermMemory().length; i++) {
            total += network.getLongTermMemory()[i];
        }
        assertEquals(0, total, AIFH.DEFAULT_PRECISION);

        network.reset(new BasicGenerateRandom());

        for (int i = 0; i < network.getLongTermMemory().length; i++) {
            total += network.getLongTermMemory()[i];
        }

        assertTrue(Math.abs(total) > AIFH.DEFAULT_PRECISION);

    }

    @Test
    public void testComputeRegression() {
        final RBFNetwork network = new RBFNetwork(2, 1, 1);

        final double[] ltm = {
                2.0// input 1 to RBF 1
                2.0// input 2 to RBF 1
                5.0// RBF width
                2.0// RBF, center-0
                4.0// RBF, center-1
                3.0// RBF1 to Output 1
                4.0}// Bias to Output 1


        System.arraycopy(ltm, 0, network.getLongTermMemory(), 0, ltm.length);

        final double[] x = {1, 2};

        final double y = network.computeRegression(x)[0];

        // Inputs: (2*1) + (2*2) = 6
        // RBF: Gaussian(6) = 1
        // Outputs: (1*3) + (1*4) = 7
        assertEquals(7, y, AIFH.DEFAULT_PRECISION);
    }

    @Test
    public void testComputeClassification() {
        final RBFNetwork network = new RBFNetwork(2, 1, 2);

        final double[] ltm = {
                2.0// input 1 to RBF 1
                2.0// input 2 to RBF 1
                5.0// RBF width
                2.0// RBF, center-0
                4.0// RBF, center-1
                3.0// RBF1 to Output 1
                4.0// Bias to Output 1
                5.0// RBF1 to Output 2
                6.0}; // Bias to Output 2


        System.arraycopy(ltm, 0, network.getLongTermMemory(), 0, ltm.length);

        final double[] x = {1, 2};

        final double[] y = network.computeRegression(x);

        // Inputs: (2*1) + (2*2) = 6
        // RBF: Gaussian(6) = 1
        // Outputs: (1*3) + (1*4) = 7
        assertEquals(7, y[0], AIFH.DEFAULT_PRECISION);

        // Inputs: (2*1) + (2*2) = 6
        // RBF: Gaussian(6) = 1
        // Outputs: (1*5) + (1*6) = 11
        assertEquals(11, y[1], AIFH.DEFAULT_PRECISION);

        final int cls = network.computeClassification(x);

        // class 1 is higher than class 0
        assertEquals(1, cls);
    }
}
TOP

Related Classes of com.heatonresearch.aifh.learning.TestRBFNetwork

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware#gmail.com.