Package cc.mallet.classify

Source Code of cc.mallet.classify.C45Trainer

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org.  For further
information, see the file `LICENSE' included with this distribution. */




package cc.mallet.classify;

import java.util.logging.Logger;

import cc.mallet.classify.Boostable;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;


/**
* A C4.5 decision tree learner, approximtely.  Currently treats all
* features as continuous-valued, and has no notion of missing values.<p>
*
* This implementation uses MDL for pruning.<p>
*
* J. R. Quinlan<br>
* "Improved Use of Continuous Attributes in C4.5" <br>
* ftp://ftp.cs.cmu.edu/project/jair/volume4/quinlan96a.ps<p>
*
* J. R. Quinlan and R. L. Rivest<br>
* "Inferring Decision Trees Using Minimum Description Length Principle"
*
* @author Gary Huang <a href="mailto:ghuang@cs.umass.edu">ghuang@cs.umass.edu</a>
*/
public class C45Trainer extends ClassifierTrainer<C45> implements Boostable
{
  private static Logger logger = MalletLogger.getLogger(C45Trainer.class.getName());
  boolean m_depthLimited = false;
  int m_maxDepth = 4;
  int m_minNumInsts = 2// minimum number of instances in each node
  boolean m_doPruning = true;
  C45 classifier;
  public C45 getClassifier () { return classifier; }

  /**
   * Uses default values: not depth limited tree with
   * a minimum of 2 instances in each leaf node
   */
  public C45Trainer() {}
 
  /**
   * Construct a depth-limited tree with the given depth limit
   */ 
  public C45Trainer(int maxDepth)
  {
    m_maxDepth = maxDepth;
    m_depthLimited = true;
  }
 
  public C45Trainer(boolean doPruning)
  {
    m_doPruning = doPruning;
  }
 
  public C45Trainer(int maxDepth, boolean doPruning)
  {
    m_depthLimited = true;
    m_maxDepth = maxDepth;
    m_doPruning = doPruning;
  }
 
  public void setDoPruning(boolean doPruning)
  {
    m_doPruning = doPruning;
  }
 
  public boolean getDoPruning()
  {
    return m_doPruning;
  }
 
  public void setDepthLimited(boolean depthLimited)
  {
    m_depthLimited = depthLimited;
  }
 
  public boolean getDepthLimited()
  {
    return m_depthLimited;
  }
 
  public void setMaxDepth(int maxDepth)
  {
    m_maxDepth = maxDepth;
  }
 
  public int getMaxDepth()
  {
    return m_maxDepth;
  }
 
  public void setMinNumInsts(int minNumInsts)
  {
    m_minNumInsts = minNumInsts;
  }
 
  public int getMinNumInsts()
  {
    return m_minNumInsts;
  }
 
  protected void splitTree(C45.Node node, int depth)
  {
    // Stop growing the tree when any of the following is true:
    //   1.  We care about tree depth and maximum depth is reached
    //   2.  The entropy of the node is too small (i.e., all
    //       instances belong to the same class)
    //   3.  The gain ratio of the best split available is too small
    if (m_depthLimited && depth == m_maxDepth) {
      logger.info("Splitting stopped: maximum depth reached (" + m_maxDepth + ")");
      return;
    }      
    else if (Maths.almostEquals(node.getGainRatio().getBaseEntropy(), 0)) {
      logger.info("Splitting stopped: entropy of node too small (" + node.getGainRatio().getBaseEntropy() + ")");
      return;
    }
    else if (Maths.almostEquals(node.getGainRatio().getMaxValue(), 0)) {
      logger.info("Splitting stopped: node has insignificant gain ratio (" + node.getGainRatio().getMaxValue() + ")");
      return;
    }
    logger.info("Splitting feature \""+node.getSplitFeature()
        +"\" at threshold=" + node.getGainRatio().getMaxValuedThreshold()
        + " gain ratio="+node.getGainRatio().getMaxValue());
    node.split();
    splitTree(node.getLeftChild(), depth+1);
    splitTree(node.getRightChild(), depth+1);
  }
 
  public C45 train (InstanceList trainingList)
  {
    FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
    if (selectedFeatures != null)
      // xxx Attend to FeatureSelection!!!
      throw new UnsupportedOperationException ("FeatureSelection not yet implemented.");
    C45.Node root = new C45.Node(trainingList, null, m_minNumInsts);
    splitTree(root, 0);
    C45 tree = new C45 (trainingList.getPipe(), root);
    logger.info("C45 learned: (size=" + tree.getSize() + ")\n");
    tree.print();
    if (m_doPruning) {
      tree.prune();
      logger.info("\nPruned C45: (size=" + tree.getSize() + ")\n");
      root.print();
    }
    root.stopGrowth();
    this.classifier = tree;
    return classifier;
  }
 
}
TOP

Related Classes of cc.mallet.classify.C45Trainer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.