Package org.encog.ensemble

Source Code of org.encog.ensemble.Ensemble

/*
* Encog(tm) Core v3.3 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2014 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*  
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.ensemble;

import java.util.ArrayList;

import org.encog.ensemble.data.EnsembleDataSet;
import org.encog.ensemble.data.factories.EnsembleDataSetFactory;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;

public abstract class Ensemble {

  protected EnsembleDataSetFactory dataSetFactory;
  protected EnsembleTrainFactory trainFactory;
  protected EnsembleAggregator aggregator;
  protected ArrayList<EnsembleML> members;
  protected EnsembleMLMethodFactory mlFactory;
  protected MLDataSet aggregatorDataSet;

  public class NotPossibleInThisMethod extends Exception {

    /**
     * This means the current feature is not applicable in the specified method
     */
    private static final long serialVersionUID = 5118253806179408868L;

  }

  /**
   * Initialise ensemble components
   */
  abstract public void initMembers();

  public void initMembersBySplits(int splits)
  {
    if ((this.dataSetFactory != null) &&
      (splits > 0) &&
      (this.dataSetFactory.hasSource()))
    {
      for (int i = 0; i < splits; i++)
      {
        GenericEnsembleML newML = new GenericEnsembleML(mlFactory.createML(this.dataSetFactory.getInputCount(), this.dataSetFactory.getOutputCount()),mlFactory.getLabel());
        newML.setTrainingSet(dataSetFactory.getNewDataSet());
        newML.setTraining(trainFactory.getTraining(newML.getMl(), newML.getTrainingSet()));
        members.add(newML);
      }
      if(aggregator.needsTraining())
        aggregatorDataSet = dataSetFactory.getNewDataSet();
    }
  }

  /**
   * Set the training method to use for this ensemble
   * @param newTrainFactory The training factory.
   */
  public void setTrainingMethod(EnsembleTrainFactory newTrainFactory) {
    this.trainFactory = newTrainFactory;
    initMembers();
  }

  /**
   * Set which training data to base the training on
   * @param data The training data.
   */
  public void setTrainingData(MLDataSet data) {
    dataSetFactory.setInputData(data);
    initMembers();
  }

  /**
   * Set which dataSetFactory to use to create the correct tranining sets
   * @param dataSetFactory The data set factory.
   */
  public void setTrainingDataFactory(EnsembleDataSetFactory dataSetFactory) {
    this.dataSetFactory = dataSetFactory;
    initMembers();
  }

  /**
   * Train the ensemble to a target accuracy
   * @param targetError The target error.
   * @param selectionError The selection error.
   * @param testset The test set.
   * @param verbose Verbose mode?
   */
  public void train(double targetError, double selectionError, EnsembleDataSet testset, boolean verbose) {

    for (EnsembleML current : members)
    {
      do {
        mlFactory.reInit(current.getMl());
        current.train(targetError, verbose);
        if (verbose) {System.out.println("test MSE: " + current.getError(testset));};
      } while (current.getError(testset) > selectionError);
    }
    if(aggregator.needsTraining()) {
      EnsembleDataSet aggTrainingSet = new EnsembleDataSet(members.size() * aggregatorDataSet.getIdealSize(),aggregatorDataSet.getIdealSize());
      for (MLDataPair trainingInput:aggregatorDataSet) {
        BasicMLData trainingInstance = new BasicMLData(members.size() * aggregatorDataSet.getIdealSize());
        int index = 0;
        for(EnsembleML member:members){
          for(double val:member.compute(trainingInput.getInput()).getData()) {
            trainingInstance.add(index++, val);
          }
        }
        aggTrainingSet.add(trainingInstance,trainingInput.getIdeal());
      }
      aggregator.setTrainingSet(aggTrainingSet);
      aggregator.train();
    }
  }

  /**
   * Train the ensemble to a target accuracy
   * @param targetError The target error.
   * @param selectionError The selection error.
   * @param testset The test set.
   */
  public void train(double targetError, double selectionError, EnsembleDataSet testset) {
    train(targetError, selectionError, testset, false);
  }

  /**
   * Extract a specific training set from the Ensemble
   * @param setNumber
   * @return The training set.
   */
  public MLDataSet getTrainingSet(int setNumber) {
    return members.get(setNumber).getTrainingSet();
  }

  /**
   * Extract a specific MLMethod
   * @param memberNumber
   * @return The MLMethod.
   */
  public EnsembleML getMember(int memberNumber) {
    return members.get(memberNumber);
  }

  /**
   * Add a member to the ensemble
   * @param newMember
   * @throws NotPossibleInThisMethod
   */
  public void addMember(EnsembleML newMember) throws NotPossibleInThisMethod {
    members.add(newMember);
  }

  /**
   * Compute the output for a specific input
   * @param input
   * @return The data.
   */
  public MLData compute(MLData input) {
    ArrayList<MLData> outputs = new ArrayList<MLData>();
    for(EnsembleML member: members)
    {
      MLData computed = member.compute(input);
      outputs.add(computed);
    }
    return aggregator.evaluate(outputs);
  }

  /**
   * @return Returns the ensemble aggregation method
   */
  public EnsembleAggregator getAggregator() {
    return aggregator;
  }

  /**
   * Sets the ensemble aggregation method
   * @param aggregator
   */
  public void setAggregator(EnsembleAggregator aggregator) {
    this.aggregator = aggregator;
  }

  /**
   * Return what type of problem this Ensemble is solving
   * @return The problem type.
   */
  abstract public EnsembleTypes.ProblemType getProblemType();

}
TOP

Related Classes of org.encog.ensemble.Ensemble

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., and is owned by ORACLE Inc. Contact coftware#gmail.com.