Package com.heatonresearch.aifh.aco

Source Code of com.heatonresearch.aifh.aco.ContinuousACO

/*
* Artificial Intelligence for Humans
* Volume 2: Nature Inspired Algorithms
* Java Version
* http://www.aifh.org
* http://www.jeffheaton.com
*
* Code repository:
* https://github.com/jeffheaton/aifh
*
* Copyright 2014 by Jeff Heaton
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package com.heatonresearch.aifh.aco;

import com.heatonresearch.aifh.AIFH;
import com.heatonresearch.aifh.learning.LearningMethod;
import com.heatonresearch.aifh.learning.MLMethod;
import com.heatonresearch.aifh.learning.score.ScoreFunction;
import com.heatonresearch.aifh.randomize.GenerateRandom;
import com.heatonresearch.aifh.randomize.MersenneTwisterGenerateRandom;

import java.util.Arrays;

/**
* This class implements continuous ant colony optimization (CACO)
* <p/>
* References:
* <p/>
* Training Neural Networks with Ant Colony Optimization,
* Arun Pandian, Spring, 2013
* <p/>
* Krzysztof Socha and Christian Blum. “An ant colony optimization algorithm for
* continuous optimization: application to feed-forward neural network training”, in
* Springer London (2007).
* <p/>
* M.Dorigo, V.Maniezzo, and A.Colorni. “Ant System: Optimization by a colony of
* cooperating agents”, in IEEE Transactions on Systems, Man, and Cybernetics,
* 1996.
*/
public class ContinuousACO implements LearningMethod {
    /**
     * Sigma constant. Minimum standard deviation.
     */
    public static final double CONST_SIGMA = 0.1;

    /**
     * Q constant.  Weighting exponent factor.
     */
    public static final double CONST_Q = 0.08;

    /**
     * The population of ants.
     */
    private final ContinuousAnt[] population;

    /**
     * The population size.
     */
    private final int populationSize;

    /**
     * The parameter count.
     */
    private int paramCount = 0;

    /**
     * The weighting of each ant.
     */
    private final double[] weighting;

    /**
     * The sum of the weighting.
     */
    private double sumWeighting = 0;


    /**
     * Epsilon, learning rate.
     */
    private double epsilon = .75;

    /**
     * Random number generation.
     */
    private GenerateRandom random;

    /**
     * The algorithm that we are fitting.
     */
    private MLMethod algorithm;

    /**
     * The score function.
     */
    private ScoreFunction score;

    /**
     * The constructor.
     *
     * @param theAlgorithm      The algorithm to fit.
     * @param theScore          The score function.
     * @param thePopulationSize The population size.
     */
    public ContinuousACO(final MLMethod theAlgorithm, final ScoreFunction theScore, final int thePopulationSize) {
        this.algorithm = theAlgorithm;
        this.populationSize = thePopulationSize;
        this.score = theScore;
        this.random = new MersenneTwisterGenerateRandom();
        this.paramCount = theAlgorithm.getLongTermMemory().length;

        this.population = new ContinuousAnt[thePopulationSize * 2];
        this.weighting = new double[thePopulationSize];
        for (int i = 0; i < this.population.length; i++) {
            this.population[i] = new ContinuousAnt(paramCount, score.shouldMinimize());
            for (int j = 0; j < paramCount; j++) {
                this.population[i].getParams()[j] = random.nextDouble(-1, 1);
            }
        }

        updateScore();
        Arrays.sort(this.population);
        computeWeighting();
        sampleSolutions();
        Arrays.sort(this.population);

    }

    /**
     * Update the score.
     */
    private void updateScore() {

        for (final ContinuousAnt aPopulation : this.population) {
            System.arraycopy(aPopulation.getParams(), 0, this.algorithm.getLongTermMemory(), 0, this.paramCount);
            aPopulation.setScore(this.score.calculateScore(this.algorithm));
        }
    }

    /**
     * Compute the weighting for each ant.
     */
    private void computeWeighting() {
        sumWeighting = 0;
        for (int i = 0; i < this.populationSize; i++) {
            double exponent = (i * i) / (2 * CONST_Q * CONST_Q * this.populationSize * this.populationSize);
            this.weighting[i] =
                    (1 / (0.1 * Math.sqrt(2 * Math.PI))) * Math.pow(Math.E, -exponent);
            sumWeighting += weighting[i];
        }
    }

    /**
     * Compute the standard deviation.
     *
     * @param x The parameter to compute for.
     * @param l The population member.
     * @return The standard deviation.
     */
    private double computeSD(int x, int l) {
        double sum = 0.0;
        for (int i = 0; i < this.populationSize; i++) {
            sum += Math.abs(this.population[i].getParams()[x] - this.population[l].getParams()[x]) / (this.populationSize - 1);
        }
        if (sum < AIFH.DEFAULT_PRECISION) {
            return CONST_SIGMA;
        }
        return (epsilon * sum);
    }

    /**
     * Select a probability distribution function (PDF).
     *
     * @return The PDF index.
     */
    private int selectPDF() {
        int l = 0;
        double temp = 0;

        double r = random.nextDouble();
        for (int i = 0; i < this.populationSize; i++) {
            temp += weighting[i] / sumWeighting;
            if (r < temp) {
                l = i;
                break;
            }
        }
        return l;
    }

    /**
     * Sample new parameters.
     */
    private void sampleSolutions() {
        for (int i = this.populationSize; i < this.population.length; i++) {
            int pdf = selectPDF();
            for (int j = 0; j < paramCount; j++) {
                double sigma = computeSD(j, pdf);
                double mu = this.population[pdf].getParams()[j];
                double d = (random.nextGaussian() * sigma) + mu;
                this.population[i].getParams()[j] = d;
            }
        }
    }

    /**
     * @return The value for epsilon, the learning rate.
     */
    public double getEpsilon() {
        return epsilon;
    }

    /**
     * Set epsilon, the learning rate.
     *
     * @param epsilon The epsilon value.
     */
    public void setEpsilon(final double epsilon) {
        this.epsilon = epsilon;
    }

    /**
     * @return Random number generator.
     */
    public GenerateRandom getRandom() {
        return random;
    }

    public void setRandom(final GenerateRandom random) {
        this.random = random;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void iteration() {
        computeWeighting();
        sampleSolutions();
        updateScore();
        Arrays.sort(this.population);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public double getLastError() {
        return this.population[0].getScore();
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public boolean done() {
        return false;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public String getStatus() {
        return "";
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void finishTraining() {
        System.arraycopy(this.population[0].getParams(), 0, this.algorithm.getLongTermMemory(), 0, this.algorithm.getLongTermMemory().length);
    }
}
TOP

Related Classes of com.heatonresearch.aifh.aco.ContinuousACO

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.