Package cc.mallet.classify

Source Code of cc.mallet.classify.DecisionTree

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */





package cc.mallet.classify;

import java.io.Serializable;
import java.util.logging.Logger;

import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.InfoGain;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
import cc.mallet.util.MalletLogger;

/**
   Decision Tree classifier.
   @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/

public class DecisionTree extends Classifier implements Serializable //implements InduceFeatures
{
  private static final long serialVersionUID = 1L;
 
  private static Logger logger = MalletLogger.getLogger(DecisionTree.class.getName());
  Node root;
 
  public DecisionTree (Pipe instancePipe, DecisionTree.Node root)
  {
    super (instancePipe);
    this.root = root;
  }

  public Node getRoot ()
  {
    return root;
  }

  private Node getLeaf (Node node, FeatureVector fv)
  {
    if (node.child0 == null)
      return node;
    else if (fv.value (node.featureIndex) != 0)
      return getLeaf (node.child1, fv);
    else
      return getLeaf (node.child0, fv);
  }

  public Classification classify (Instance instance)
  {
    FeatureVector fv = (FeatureVector) instance.getData ();
    assert (instancePipe == null || fv.getAlphabet () == this.instancePipe.getDataAlphabet ());
   
    Node leaf = getLeaf (root, fv);
    return new Classification (instance, this, leaf.labeling);
  }

  // Entropy of 1.0 would say that it take "one bit" to indicate the correct class,
  // e.g. that there is a 50/50 split between two classes given a particular feature
  public double addFeaturesClassEntropyThreshold = 0.7;
 
  public void induceFeatures (InstanceList ilist, boolean withFeatureShrinkage, boolean inducePerClassFeatures)
  {
    if (inducePerClassFeatures) {
      int numClasses = ilist.getTargetAlphabet().size();
//      int numFeatures = ilist.getDataAlphabet().size();
      FeatureSelection[] pcfs = new FeatureSelection[numClasses];
      for (int j = 0; j < numClasses; j++)
        pcfs[j] = (FeatureSelection) ilist.getPerLabelFeatureSelection()[j].clone();
      for (int i = 0; i < ilist.size(); i++) {
        Object data = ilist.get(i).getData();
        AugmentableFeatureVector afv = (AugmentableFeatureVector) data;
        root.induceFeatures (afv, null, pcfs, ilist.getFeatureSelection(), ilist.getPerLabelFeatureSelection(),
                             withFeatureShrinkage, inducePerClassFeatures, addFeaturesClassEntropyThreshold);
      }
    } else {
      throw new UnsupportedOperationException ("Not yet implemented");
    }
  }

 
  public static class Node implements Serializable
  {
    private static final long serialVersionUID = 1L;
   
    int featureIndex;                    // the feature on which the children (would) distinguish
    double infoGain;                    // the information gain of splitting on this feature
    InstanceList ilist;
    Alphabet dictionary;
    double labelEntropy;                // the class label entropy of data in this (unsplit) node
    Labeling labeling;                  // the class label distribution in the node (unsplit)
    Node parent, child0, child1;
    String name;

    // xxx Also calculate some sort of inverted entropy for feature induction,
    // in order to find the one class needs a new feature with a negative weight.

    public Node (InstanceList ilist, Node parent, FeatureSelection fs)
    {
      InfoGain ig = new InfoGain (ilist);
      this.featureIndex = ig.getMaxValuedIndexIn (fs);
      this.infoGain = ig.value(featureIndex);
      this.ilist = ilist;
      this.dictionary = ilist.getDataAlphabet();
      this.parent = parent;
      this.labeling = ig.getBaseLabelDistribution();
      this.labelEntropy = ig.getBaseEntropy();
      this.child0 = this.child1 = null;
    }

    /** The root has depth zero. */
    public int depth ()
    {
      int depth = 0;
      Node p = parent;
      while (p != null) {
        p = p.parent;
        depth++;
      }
      return depth;
    }

    public boolean isLeaf ()
    {
      return (child0 == null && child1 == null);
    }

    public boolean isRoot ()
    {
      return parent == null;
    }
   
    public Node getFeatureAbsentChild () { return child0; }
    public Node getFeaturePresentChild () { return child1; }
    public double getSplitInfoGain () { return infoGain; }
    public Object getSplitFeature () { return ilist.getDataAlphabet().lookupObject(featureIndex); }

    public void split (FeatureSelection fs)
    {
      if (ilist == null)
        throw new IllegalStateException ("Frozen.  Cannot split.");
      InstanceList ilist0 = new InstanceList (ilist.getPipe());
      InstanceList ilist1 = new InstanceList (ilist.getPipe());
      for (int i = 0; i < ilist.size(); i++) {
        Instance instance = ilist.get(i);
        FeatureVector fv = (FeatureVector) instance.getData ();
        // xxx What test should this be?  What to do with negative values?
          // Whatever is decided here should also go in InfoGain.calcInfoGains()
        if (fv.value (featureIndex) != 0) {
          //System.out.println ("list1 add "+instance.getUri()+" weight="+ilist.getInstanceWeight(i));
          ilist1.add (instance, ilist.getInstanceWeight(i));
        } else {
          //System.out.println ("list0 add "+instance.getUri()+" weight="+ilist.getInstanceWeight(i));
          ilist0.add (instance, ilist.getInstanceWeight(i));
        }
      }
      logger.info("child0="+ilist0.size()+" child1="+ilist1.size());
      child0 = new Node (ilist0, this, fs);
      child1 = new Node (ilist1, this, fs);
    }

    // Saves memory by allowing ilist to be garbage collected
    public void stopGrowth ()
    {
      if (child0 != null) {
        child0.stopGrowth();
        child1.stopGrowth();
      }
      ilist = null;
    }

    public void induceFeatures (AugmentableFeatureVector afv,
                                FeatureSelection featuresAlreadyThere,
                                FeatureSelection[] perClassFeaturesAlreadyThere,
                                FeatureSelection newFeatureSelection,
                                FeatureSelection[] perClassNewFeatureSelection,
                                boolean withInteriorNodes,
                                boolean addPerClassFeatures,
                                double classEntropyThreshold)
    {
      if (!isRoot() && (isLeaf() || withInteriorNodes) && labelEntropy < classEntropyThreshold) {
        String name = getName();
        logger.info("Trying to add feature "+name);
        //int conjunctionIndex = afv.getAlphabet().lookupIndex (name, false);
        if (addPerClassFeatures) {
          int classIndex = labeling.getBestIndex();
          if (!perClassFeaturesAlreadyThere[classIndex].contains (name)) {
            afv.add (name, 1.0);
            perClassNewFeatureSelection[classIndex].add (name);
          }
        } else {
          throw new UnsupportedOperationException ("Not yet implemented.");
        }
      }
      boolean featurePresent = afv.value (featureIndex) != 0;
      if (child0 != null && !featurePresent)
        child0.induceFeatures (afv, featuresAlreadyThere, perClassFeaturesAlreadyThere,
                               newFeatureSelection, perClassNewFeatureSelection,
                               withInteriorNodes, addPerClassFeatures, classEntropyThreshold);
      if (child1 != null && featurePresent)
        child1.induceFeatures (afv, featuresAlreadyThere, perClassFeaturesAlreadyThere,
                               newFeatureSelection, perClassNewFeatureSelection,
                               withInteriorNodes, addPerClassFeatures, classEntropyThreshold);
    }

    public String getName ()
    {
//      String prefix;
      if (parent == null)
        return "root";
      else if (parent.parent == null) {
        if (parent.getFeaturePresentChild() == this)
          return dictionary.lookupObject(parent.featureIndex).toString();
        else {
          assert (dictionary != null);
          assert (dictionary.lookupObject(parent.featureIndex) != null);
          return "!" + dictionary.lookupObject(parent.featureIndex).toString();
        }
      } else {
        if (parent.getFeaturePresentChild() == this)
          return parent.getName() + "&" + dictionary.lookupObject(parent.featureIndex).toString();
        else
          return parent.getName() + "&!" + dictionary.lookupObject(parent.featureIndex).toString();
      }
    }

    public void print ()
    {
      if (child0 == null)
        System.out.println (getName() + ": " + labeling.getBestLabel());
      else {
        child0.print();
        child1.print();
      }
    }
   
  }



}
TOP

Related Classes of cc.mallet.classify.DecisionTree

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., which is owned by Oracle. Contact coftware#gmail.com.