
Source Code of org.encog.ml.model.EncogModel

/*
* Encog(tm) Core v3.3 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2014 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*  
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.ml.model;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.encog.EncogError;
import org.encog.NullStatusReportable;
import org.encog.StatusReportable;
import org.encog.mathutil.randomize.generate.MersenneTwisterGenerateRandom;
import org.encog.ml.MLClassification;
import org.encog.ml.MLMethod;
import org.encog.ml.MLRegression;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.cross.DataFold;
import org.encog.ml.data.cross.KFoldCrossvalidation;
import org.encog.ml.data.versatile.MatrixMLDataSet;
import org.encog.ml.data.versatile.VersatileMLDataSet;
import org.encog.ml.data.versatile.columns.ColumnDefinition;
import org.encog.ml.data.versatile.columns.ColumnType;
import org.encog.ml.data.versatile.division.DataDivision;
import org.encog.ml.factory.MLMethodFactory;
import org.encog.ml.factory.MLTrainFactory;
import org.encog.ml.model.config.FeedforwardConfig;
import org.encog.ml.model.config.MethodConfig;
import org.encog.ml.model.config.NEATConfig;
import org.encog.ml.model.config.PNNConfig;
import org.encog.ml.model.config.RBFNetworkConfig;
import org.encog.ml.model.config.SVMConfig;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.end.SimpleEarlyStoppingStrategy;
import org.encog.util.Format;
import org.encog.util.simple.EncogUtility;

/**
* EncogModel is designed to allow you to easily swap between different model
* types and to automatically normalize data. It works only with a
* VersatileMLDataSet. A hypothetical usage sketch follows the class listing.
*/
public class EncogModel {

  /**
   * The dataset to use.
   */
  private final VersatileMLDataSet dataset;
 
  /**
   * The input features.
   */
  private final List<ColumnDefinition> inputFeatures = new ArrayList<ColumnDefinition>();
 
  /**
   * The predicted features.
   */
  private final List<ColumnDefinition> predictedFeatures = new ArrayList<ColumnDefinition>();
 
  /**
   * The training dataset.
   */
  private MatrixMLDataSet trainingDataset;
 
  /**
   * The validation dataset.
   */
  private MatrixMLDataSet validationDataset;
 
  /**
   * The standard configurations for each method type.
   */
  private final Map<String, MethodConfig> methodConfigurations = new HashMap<String, MethodConfig>();
 
  /**
   * The current method configuration, determined by the selected model.
   */
  private MethodConfig config;
 
  /**
   * The selected method type.
   */
  private String methodType;
 
  /**
   * The method arguments for the selected method.
   */
  private String methodArgs;
 
  /**
   * The selected training type.
   */
  private String trainingType;
 
  /**
   * The training arguments for the selected training type.
   */
  private String trainingArgs;
 
  /**
   * The report.
   */
  private StatusReportable report = new NullStatusReportable();

  /**
   * Construct a model for the specified dataset.
   * @param theDataset The dataset.
   */
  public EncogModel(VersatileMLDataSet theDataset) {
    this.dataset = theDataset;
    this.methodConfigurations.put(MLMethodFactory.TYPE_FEEDFORWARD,
        new FeedforwardConfig());
    this.methodConfigurations
        .put(MLMethodFactory.TYPE_SVM, new SVMConfig());
    this.methodConfigurations.put(MLMethodFactory.TYPE_RBFNETWORK,
        new RBFNetworkConfig());
    this.methodConfigurations.put(MLMethodFactory.TYPE_NEAT,
        new NEATConfig());
    this.methodConfigurations
        .put(MLMethodFactory.TYPE_PNN, new PNNConfig());
  }

  /**
   * @return the dataset
   */
  public VersatileMLDataSet getDataset() {
    return dataset;
  }

  /**
   * @return the inputFeatures
   */
  public List<ColumnDefinition> getInputFeatures() {
    return inputFeatures;
  }

  /**
   * @return the predictedFeatures
   */
  public List<ColumnDefinition> getPredictedFeatures() {
    return predictedFeatures;
  }

  /**
   * Specify a validation set to hold back.
   * @param validationPercent The percent to use for validation.
   * @param shuffle True to shuffle.
   * @param seed The seed for random generation.
   */
  public void holdBackValidation(double validationPercent, boolean shuffle,
      int seed) {
    List<DataDivision> dataDivisionList = new ArrayList<DataDivision>();
    dataDivisionList.add(new DataDivision(1.0 - validationPercent));// Training
    dataDivisionList.add(new DataDivision(validationPercent));// Validation
    this.dataset.divide(dataDivisionList, shuffle,
        new MersenneTwisterGenerateRandom(seed));
    this.trainingDataset = dataDivisionList.get(0).getDataset();
    this.validationDataset = dataDivisionList.get(1).getDataset();
  }
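
  // Hypothetical usage sketch (not part of the original class): hold back 30%
  // of the rows for validation, shuffling with a fixed seed, then retrieve
  // the two resulting matrix datasets.
  //
  //   model.holdBackValidation(0.3, true, 1001);
  //   MatrixMLDataSet training = model.getTrainingDataset();
  //   MatrixMLDataSet validation = model.getValidationDataset();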

  /**
   * Fit one cross-validation fold.
   * @param k The total number of folds.
   * @param foldNum The current fold number (1-based).
   * @param fold The fold to fit.
   */
  private void fitFold(int k, int foldNum, DataFold fold) {
    MLMethod method = this.createMethod();
    MLTrain train = this.createTrainer(method, fold.getTraining());

    if (train.getImplementationType() == TrainingImplementationType.Iterative) {
      SimpleEarlyStoppingStrategy earlyStop = new SimpleEarlyStoppingStrategy(
          fold.getValidation());
      train.addStrategy(earlyStop);

      StringBuilder line = new StringBuilder();
      while (!train.isTrainingDone()) {
        train.iteration();
        line.setLength(0);
        line.append("Fold #");
        line.append(foldNum);
        line.append("/");
        line.append(k);
        line.append(": Iteration #");
        line.append(train.getIteration());
        line.append(", Training Error: ");
        line.append(Format.formatDouble(train.getError(), 8));
        line.append(", Validation Error: ");
        line.append(Format.formatDouble(earlyStop.getValidationError(),
            8));
        report.report(k, foldNum, line.toString());
      }
      fold.setScore(earlyStop.getValidationError());
      fold.setMethod(method);
    } else if (train.getImplementationType() == TrainingImplementationType.OnePass) {
      train.iteration();
      double validationError = calculateError(method,
          fold.getValidation());
      this.report.report(k, k,
          "Trained, Training Error: " + train.getError()
              + ", Validatoin Error: " + validationError);
      fold.setScore(validationError);
      fold.setMethod(method);
    } else {
      throw new EncogError("Unsupported training type for EncogModel: "
          + train.getImplementationType());
    }
  }

  /**
   * Calculate the error for the given method and dataset.
   * @param method The method to use.
   * @param data The data to use.
   * @return The error.
   */
  public double calculateError(MLMethod method, MLDataSet data) {
    if (this.dataset.getNormHelper().getOutputColumns().size() == 1) {
      ColumnDefinition cd = this.dataset.getNormHelper()
          .getOutputColumns().get(0);
      if (cd.getDataType() == ColumnType.nominal) {
        return EncogUtility.calculateClassificationError(
            (MLClassification) method, data);
      }
    }

    return EncogUtility.calculateRegressionError((MLRegression) method,
        data);
  }
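
  // Hypothetical usage sketch: score a trained method against the held-back
  // validation data. With a single nominal output column this yields a
  // classification error; otherwise a regression error is computed.
  //
  //   double error = model.calculateError(best, model.getValidationDataset());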

  /**
   * Create a trainer.
   * @param method The method to train.
   * @param dataset The dataset.
   * @return The trainer.
   */
  private MLTrain createTrainer(MLMethod method, MLDataSet dataset) {

    if (this.trainingType == null) {
      throw new EncogError(
          "Please call selectTraining first to choose how to train.");
    }
    MLTrainFactory trainFactory = new MLTrainFactory();
    MLTrain train = trainFactory.create(method, dataset, this.trainingType,
        this.trainingArgs);
    return train;
  }

  /**
   * Cross-validate and fit.
   * @param k The number of folds.
   * @param shuffle True if we should shuffle.
   * @return The method from the best-scoring fold.
   */
  public MLMethod crossvalidate(int k, boolean shuffle) {
    KFoldCrossvalidation cross = new KFoldCrossvalidation(
        this.trainingDataset, k);
    cross.process(shuffle);

    int foldNumber = 0;
    for (DataFold fold : cross.getFolds()) {
      foldNumber++;
      report.report(k, foldNumber, "Fold #" + foldNumber);
      fitFold(k, foldNumber, fold);
    }

    double sum = 0;
    double bestScore = Double.POSITIVE_INFINITY;
    MLMethod bestMethod = null;
    for (DataFold fold : cross.getFolds()) {
      sum += fold.getScore();
      if (fold.getScore() < bestScore) {
        bestScore = fold.getScore();
        bestMethod = fold.getMethod();
      }
    }
    sum = sum / cross.getFolds().size();
    report.report(k, k, "Cross-validated score:" + sum);
    return bestMethod;
  }
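
  // Hypothetical usage sketch: run 5-fold cross-validation over the training
  // dataset; the method trained on the best-scoring fold is returned.
  //
  //   MLRegression best = (MLRegression) model.crossvalidate(5, true);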

  /**
   * @return the trainingDataset
   */
  public MatrixMLDataSet getTrainingDataset() {
    return trainingDataset;
  }

  /**
   * @param trainingDataset
   *            the trainingDataset to set
   */
  public void setTrainingDataset(MatrixMLDataSet trainingDataset) {
    this.trainingDataset = trainingDataset;
  }

  /**
   * @return the validationDataset
   */
  public MatrixMLDataSet getValidationDataset() {
    return validationDataset;
  }

  /**
   * @param validationDataset
   *            the validationDataset to set
   */
  public void setValidationDataset(MatrixMLDataSet validationDataset) {
    this.validationDataset = validationDataset;
  }

  /**
   * Select the method to use.
   * @param dataset The dataset.
   * @param methodType The type of method.
   * @param methodArgs The method arguments.
   * @param trainingType The training type.
   * @param trainingArgs The training arguments.
   */
  public void selectMethod(VersatileMLDataSet dataset, String methodType,
      String methodArgs, String trainingType, String trainingArgs) {

    if (!this.methodConfigurations.containsKey(methodType)) {
      throw new EncogError("Don't know how to autoconfig method: "
          + methodType);
    }

    // Store the selected configuration and training choices so that
    // createMethod and createTrainer can use them later.
    this.config = this.methodConfigurations.get(methodType);
    this.methodType = methodType;
    this.methodArgs = methodArgs;
    this.trainingType = trainingType;
    this.trainingArgs = trainingArgs;
    dataset.getNormHelper().setStrategy(
        this.config.suggestNormalizationStrategy(dataset, methodArgs));
  }

  /**
   * Create the selected method.
   * @return The created method.
   */
  public MLMethod createMethod() {
    if (this.methodType == null) {
      throw new EncogError(
          "Please call selectMethod first to choose what type of method you wish to use.");
    }
    MLMethodFactory methodFactory = new MLMethodFactory();
    MLMethod method = methodFactory.create(methodType, methodArgs, dataset
        .getNormHelper().calculateNormalizedInputCount(), this.config
        .determineOutputCount(dataset));
    return method;
  }

  /**
   * Select the method to create.
   * @param dataset The dataset.
   * @param methodType The method type.
   */
  public void selectMethod(VersatileMLDataSet dataset, String methodType) {
    if (!this.methodConfigurations.containsKey(methodType)) {
      throw new EncogError("Don't know how to autoconfig method: "
          + methodType);
    }

    this.config = this.methodConfigurations.get(methodType);
    this.methodType = methodType;
    this.methodArgs = this.config.suggestModelArchitecture(dataset);
    dataset.getNormHelper().setStrategy(
        this.config.suggestNormalizationStrategy(dataset, methodArgs));

  }
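
  // Hypothetical usage sketch: let EncogModel suggest an architecture and a
  // normalization strategy for a feedforward network.
  //
  //   model.selectMethod(data, MLMethodFactory.TYPE_FEEDFORWARD);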

  /**
   * Select the training type.
   * @param dataset The dataset.
   */
  public void selectTrainingType(VersatileMLDataSet dataset) {
    if (this.methodType == null) {
      throw new EncogError(
          "Please call selectMethod first, before selecting a training type.");
    }
    MethodConfig config = this.methodConfigurations.get(methodType);
    selectTraining(dataset, config.suggestTrainingType(),
        config.suggestTrainingArgs(trainingType));
  }

  /**
   * Select the training to use.
   * @param dataset The dataset.
   * @param trainingType The type of training.
   * @param trainingArgs The training arguments.
   */
  public void selectTraining(VersatileMLDataSet dataset, String trainingType,
      String trainingArgs) {
    if (this.methodType == null) {
      throw new EncogError(
          "Please call selectMethod first, before selecting a training type.");
    }

    this.trainingType = trainingType;
    this.trainingArgs = trainingArgs;
  }
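
  // Hypothetical usage sketch: either accept the trainer suggested for the
  // selected method type, or name one explicitly (RPROP here, using the
  // standard MLTrainFactory type constant).
  //
  //   model.selectTrainingType(data);                            // suggested
  //   model.selectTraining(data, MLTrainFactory.TYPE_RPROP, ""); // explicit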

  /**
   * @return the methodConfigurations
   */
  public Map<String, MethodConfig> getMethodConfigurations() {
    return methodConfigurations;
  }

  /**
   * @return the report
   */
  public StatusReportable getReport() {
    return report;
  }

  /**
   * @param report
   *            the report to set
   */
  public void setReport(StatusReportable report) {
    this.report = report;
  }
}
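
The class javadoc describes the intended workflow: define a VersatileMLDataSet, select a method, normalize, hold back validation data, and cross-validate. Below is a minimal, hypothetical usage sketch of that workflow, loosely following the Encog 3.3 quick-start pattern. The CSV path, column names, and column indices are placeholders; it assumes the standard Encog classes CSVDataSource, CSVFormat, ConsoleStatusReportable, and MLRegression.

import java.io.File;

import org.encog.ConsoleStatusReportable;
import org.encog.Encog;
import org.encog.ml.MLRegression;
import org.encog.ml.data.versatile.VersatileMLDataSet;
import org.encog.ml.data.versatile.columns.ColumnDefinition;
import org.encog.ml.data.versatile.columns.ColumnType;
import org.encog.ml.data.versatile.sources.CSVDataSource;
import org.encog.ml.data.versatile.sources.VersatileDataSource;
import org.encog.ml.factory.MLMethodFactory;
import org.encog.ml.model.EncogModel;
import org.encog.util.csv.CSVFormat;

public class EncogModelExample {

  public static void main(String[] args) {
    // Placeholder CSV file with a header row and decimal-point formatting.
    VersatileDataSource source = new CSVDataSource(
        new File("iris.csv"), true, CSVFormat.DECIMAL_POINT);

    VersatileMLDataSet data = new VersatileMLDataSet(source);
    data.defineSourceColumn("sepal-length", 0, ColumnType.continuous);
    data.defineSourceColumn("sepal-width", 1, ColumnType.continuous);
    data.defineSourceColumn("petal-length", 2, ColumnType.continuous);
    data.defineSourceColumn("petal-width", 3, ColumnType.continuous);
    ColumnDefinition outputColumn = data.defineSourceColumn(
        "species", 4, ColumnType.nominal);

    // Collect the statistics needed for normalization, then mark the output.
    data.analyze();
    data.defineSingleOutputOthersInput(outputColumn);

    EncogModel model = new EncogModel(data);
    model.selectMethod(data, MLMethodFactory.TYPE_FEEDFORWARD);
    model.setReport(new ConsoleStatusReportable());

    // Normalize using the strategy suggested by selectMethod.
    data.normalize();

    model.holdBackValidation(0.3, true, 1001);
    model.selectTrainingType(data);

    // 5-fold cross-validation; the best fold's method is returned.
    MLRegression best = (MLRegression) model.crossvalidate(5, true);

    System.out.println("Validation error: "
        + model.calculateError(best, model.getValidationDataset()));

    Encog.getInstance().shutdown();
  }
}

Note that selectMethod is called before normalize(), so that the normalization strategy suggested by the method configuration is in place when the data is normalized.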