Package com.heatonresearch.aifh.examples.capstone.model.milestone2

Source Code of com.heatonresearch.aifh.examples.capstone.model.milestone2.FitTitanic

/*
* Artificial Intelligence for Humans
* Volume 2: Nature Inspired Algorithms
* Java Version
* http://www.aifh.org
* http://www.jeffheaton.com
*
* Code repository:
* https://github.com/jeffheaton/aifh
*
* Copyright 2014 by Jeff Heaton
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package com.heatonresearch.aifh.examples.capstone.model.milestone2;

import com.heatonresearch.aifh.examples.capstone.model.TitanicConfig;
import com.heatonresearch.aifh.examples.capstone.model.milestone1.NormalizeTitanic;
import com.heatonresearch.aifh.examples.capstone.model.milestone1.TitanicStats;
import com.heatonresearch.aifh.general.data.BasicData;
import com.heatonresearch.aifh.learning.RBFNetwork;
import com.heatonresearch.aifh.learning.TrainPSO;
import com.heatonresearch.aifh.learning.score.ScoreFunction;
import com.heatonresearch.aifh.randomize.GenerateRandom;
import com.heatonresearch.aifh.randomize.MersenneTwisterGenerateRandom;

import java.io.File;
import java.io.IOException;
import java.util.List;

/**
* The second milestone for titanic is to fit and cross validate a model.
*/
public class FitTitanic {
    /**
     * The best RBF network.
     */
    private RBFNetwork bestNetwork;

    /**
     * The best score.
     */
    private double bestScore;

    /**
     * The cross validation folds.
     */
    private CrossValidate cross;

    /**
     * Train a fold.
     *
     * @param k    The fold number.
     * @param fold The fold.
     */
    public void trainFold(int k, CrossValidateFold fold) {
        int noImprove = 0;
        double localBest = 0;

        // Get the training and cross validation sets.
        List<BasicData> training = fold.getTrainingSet();
        List<BasicData> validation = fold.getValidationSet();

        // Create random particles for the RBF.
        GenerateRandom rnd = new MersenneTwisterGenerateRandom();
        RBFNetwork[] particles = new RBFNetwork[TitanicConfig.ParticleCount];
        for (int i = 0; i < particles.length; i++) {
            particles[i] = new RBFNetwork(TitanicConfig.InputFeatureCount, TitanicConfig.RBF_COUNT, 1);
            particles[i].reset(rnd);
        }

        /**
         * Construct a network to hold the best network.
         */
        if (bestNetwork == null) {
            bestNetwork = new RBFNetwork(TitanicConfig.InputFeatureCount, TitanicConfig.RBF_COUNT, 1);
        }

        /**
         * Setup the scoring function.
         */
        ScoreFunction score = new ScoreTitanic(training);
        ScoreFunction scoreValidate = new ScoreTitanic(validation);

        /**
         * Setup particle swarm.
         */
        boolean done = false;
        TrainPSO train = new TrainPSO(particles, score);
        int iterationNumber = 0;
        StringBuilder line = new StringBuilder();

        do {
            iterationNumber++;

            train.iteration();

            RBFNetwork best = (RBFNetwork) train.getBestParticle();

            double trainingScore = train.getLastError();
            double validationScore = scoreValidate.calculateScore(best);

            if (validationScore > bestScore) {
                System.arraycopy(best.getLongTermMemory(), 0, this.bestNetwork.getLongTermMemory(), 0, best.getLongTermMemory().length);
                this.bestScore = validationScore;
            }

            if (validationScore > localBest) {
                noImprove = 0;
                localBest = validationScore;
            } else {
                noImprove++;
            }

            line.setLength(0);
            line.append("Fold #");
            line.append(k + 1);
            line.append(", Iteration #");
            line.append(iterationNumber);
            line.append(": training correct: ");
            line.append(trainingScore);
            line.append(", validation correct: ");
            line.append(validationScore);
            line.append(", no improvement: ");
            line.append(noImprove);

            if (noImprove > TitanicConfig.AllowNoImprovement) {
                done = true;
            }

            System.out.println(line.toString());
        } while (!done);

        fold.setScore(localBest);
    }


    /**
     * Fit a RBF model to the titanic.
     *
     * @param dataPath The path that contains the data file.
     */
    public void process(File dataPath) {
        File trainingPath = new File(dataPath, TitanicConfig.TrainingFilename);
        File testPath = new File(dataPath, TitanicConfig.TestFilename);

        GenerateRandom rnd = new MersenneTwisterGenerateRandom();

        try {

            // Generate stats on the titanic.
            TitanicStats stats = new TitanicStats();
            NormalizeTitanic.analyze(stats, trainingPath);
            NormalizeTitanic.analyze(stats, testPath);

            // Get the training data for the titanic.
            List<BasicData> training = NormalizeTitanic.normalize(stats, trainingPath, null,
                    TitanicConfig.InputNormalizeLow,
                    TitanicConfig.InputNormalizeHigh,
                    TitanicConfig.PredictSurvive,
                    TitanicConfig.PredictPerish);

            // Fold the data for cross validation.
            this.cross = new CrossValidate(TitanicConfig.FoldCount, training, rnd);

            // Train each of the folds.
            for (int k = 0; k < cross.size(); k++) {
                System.out.println("Cross validation fold #" + (k + 1) + "/" + cross.size());
                trainFold(k, cross.getFolds().get(k));
            }

            // Show the cross validation summary.
            System.out.println("Crossvalidation summary:");
            int k = 1;
            for (CrossValidateFold fold : cross.getFolds()) {
                System.out.println("Fold #" + k + ": " + fold.getScore());
                k++;
            }

            System.out.print("Final, crossvalidated score:" + cross.getScore());

        } catch (IOException ex) {
            ex.printStackTrace();
        }
    }

    /**
     * @return The best network from the folds.
     */
    public RBFNetwork getBestNetwork() {
        return bestNetwork;
    }

    /**
     * @return The cross validation folds.
     */
    public CrossValidate getCrossvalidation() {
        return this.cross;
    }

    /**
     * Main entry point.
     *
     * @param args The path to the data file.
     */
    public static void main(String[] args) {
        String filename;
        if (args.length != 1) {
            filename = System.getProperty("FILENAME");
            if( filename==null ) {
                System.out.println("Please call this program with a single parameter that specifies your data directory.\n" +
                        "If you are calling with gradle, consider:\n" +
                        "gradle runCapstoneTitanic2 -Pdata_path=[path to your data directory]\n");
                System.exit(0);
            }
        } else {
            filename = args[0];
        }

        File dataPath = new File(filename);

        FitTitanic fit = new FitTitanic();
        fit.process(dataPath);
    }
}
TOP

Related Classes of com.heatonresearch.aifh.examples.capstone.model.milestone2.FitTitanic

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.