Source Code of cc.mallet.classify.NaiveBayesTrainer$Factory

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */




package cc.mallet.classify;


import java.io.Serializable;
import java.io.ObjectOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Iterator;

import cc.mallet.classify.Classifier;
import cc.mallet.pipe.Noop;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AlphabetCarrying;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.types.Multinomial;

/**
* Class used to generate a NaiveBayes classifier from a set of training data.
* In a Bayes classifier,
*     p(Classification|Data) = p(Data|Classification) p(Classification) / p(Data)
* <p>
* To compute the likelihood: <br>
*     p(Data|Classification) = p(d1, d2, ..., dn | Classification) <br>
* Naive Bayes makes the assumption that all of the data are conditionally
* independent given the Classification: <br>
*     p(d1, d2, ..., dn | Classification) = p(d1|Classification) p(d2|Classification) ... <br>
* <p>
* As with other classifiers in MALLET, NaiveBayes is implemented as two classes:
* a trainer and a classifier.  The NaiveBayesTrainer produces estimates of the various
* p(dn|Classification) and constructs a NaiveBayes classifier with those estimates.
* <p>
* A call to train() or trainIncremental() produces a
* {@link cc.mallet.classify.NaiveBayes} classifier that can
* be used to classify instances.  A call to trainIncremental() does not throw
* away the internal state of the trainer; subsequent calls to trainIncremental()
* train by extending the previous training set.
* <p>
* A NaiveBayesTrainer can be persisted using serialization.
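* <p>
* A minimal usage sketch (assuming a trainingList whose instances carry
* FeatureVector data and Labeling targets, and a hypothetical testInstance):
* <pre>
*   NaiveBayesTrainer trainer = new NaiveBayesTrainer ();
*   NaiveBayes nb = trainer.train (trainingList);
*   Classification c = nb.classify (testInstance);
* </pre>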
* @see NaiveBayes
* @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*
*/
public class NaiveBayesTrainer extends ClassifierTrainer<NaiveBayes>
implements ClassifierTrainer.ByInstanceIncrements<NaiveBayes>, Boostable, AlphabetCarrying, Serializable
{
  // These function as default selections for the kind of Estimator used
  Multinomial.Estimator featureEstimator = new Multinomial.LaplaceEstimator();
  Multinomial.Estimator priorEstimator = new Multinomial.LaplaceEstimator();

  // Added to support incremental training.
  // These are the counts formed after NaiveBayes training.  Note that
  // these are *not* the estimates passed to the NaiveBayes classifier;
  // rather the estimates are formed from these counts.
  // We could break these fields out into an inner class.
  Multinomial.Estimator[] me;
  Multinomial.Estimator pe;
  double docLengthNormalization = -1;  // A value of -1 means don't do any document-length normalization
  NaiveBayes classifier;

  // If this style of incremental training is successful, the following members
  // should probably be moved up into IncrementalClassifierTrainer
  Pipe instancePipe;        // Needed to construct a new classifier
  Alphabet dataAlphabet;    // Extracted from InstanceList. Must be the same for all calls to trainIncremental()
  Alphabet targetAlphabet;  // Extracted from InstanceList. Must be the same for all calls to trainIncremental()
 
 
  public NaiveBayesTrainer (NaiveBayes initialClassifier) {
    if (initialClassifier != null) {
      this.instancePipe = initialClassifier.getInstancePipe();
      this.dataAlphabet = initialClassifier.getAlphabet();
      this.targetAlphabet = initialClassifier.getLabelAlphabet();
      this.classifier = initialClassifier;
    }
  }
 
  public NaiveBayesTrainer (Pipe instancePipe) {
    this.instancePipe = instancePipe;
    this.dataAlphabet = instancePipe.getDataAlphabet();
    this.targetAlphabet = instancePipe.getTargetAlphabet();
  }
 
  public NaiveBayesTrainer () {
  }
 
 
  public NaiveBayes getClassifier () { return classifier; }

  public NaiveBayesTrainer setDocLengthNormalization (double d) {
    docLengthNormalization = d;
    return this;
  }
 
  public double getDocLengthNormalization () {
    return docLengthNormalization;
  }
 
  /**
   *  Get the Multinomial.Estimator instance used to specify the type of estimator
   *  for features.
   *
   * @return  estimator to be cloned on the next call to train() or the first call
   * to trainIncremental()
   */
  public Multinomial.Estimator getFeatureMultinomialEstimator ()
  {
    return featureEstimator;
  }

  /**
   * Set the Multinomial.Estimator used for features.  The estimator
   * is internally cloned, and the clone is used to maintain the counts
   * that will be used to generate probability estimates
   * the next time train() or an initial trainIncremental() is run.
   * Defaults to a Multinomial.LaplaceEstimator().
   * @param me estimator to be cloned on the next call to train() or the first call
   * to trainIncremental()
   */
  public NaiveBayesTrainer setFeatureMultinomialEstimator (Multinomial.Estimator me)
  {
    if (instancePipe != null)
      throw new IllegalStateException("Can't set after train() or trainIncremental() has been called");
    featureEstimator = me;
    return this;
  }

  /**
   *  Get the Multinomial.Estimator instance used to specify the type of estimator
   *  for priors.
   *
   * @return  estimator to be cloned on the next call to train() or the first call
   * to trainIncremental()
   */
  public Multinomial.Estimator getPriorMultinomialEstimator ()
  {
    return priorEstimator;
  }

  /**
   * Set the Multinomial.Estimator used for priors.  The estimator
   * is internally cloned, and the clone is used to maintain the counts
   * that will be used to generate probability estimates
   * the next time train() or an initial trainIncremental() is run.
   * Defaults to a Multinomial.LaplaceEstimator().
   * @param me estimator to be cloned on the next call to train() or the first call
   * to trainIncremental()
   */
  public NaiveBayesTrainer setPriorMultinomialEstimator (Multinomial.Estimator me)
  {
    if (instancePipe != null)
      throw new IllegalStateException("Can't set after train() or trainIncremental() has been called");
    priorEstimator = me;
    return this;
  }
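
  // A configuration sketch (assuming this MALLET version provides
  // Multinomial.MLEstimator; Multinomial.LaplaceEstimator is the default):
  //   trainer.setFeatureMultinomialEstimator (new Multinomial.MLEstimator ());
  //   trainer.setPriorMultinomialEstimator (new Multinomial.MLEstimator ());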

 

  /**
   * Create a NaiveBayes classifier from a set of training data.
   * The trainer uses counts of each feature in an instance's feature vector
   * to provide an estimate of p(Labeling | feature).  The internal state
   * of the trainer is discarded when train() returns; each
   * call to train() is completely independent of any other.
   * @param trainingList        The InstanceList to be used to train the classifier.
   * Within each instance the data slot is an instance of FeatureVector and the
   * target slot is an instance of Labeling.
   * @return The NaiveBayes classifier trained on the trainingList
   */
  public NaiveBayes train (InstanceList trainingList)
  {
    // Forget all the previous sufficient-statistics counts.
    me = null; pe = null;
    // Train a new classifier based on this data
    this.classifier = trainIncremental (trainingList);
    return classifier;
  }
 
  public NaiveBayes trainIncremental (InstanceList trainingInstancesToAdd)
  {
    // Initialize and check instance variables as necessary...
    setup(trainingInstancesToAdd, null);

    // Incrementally add the counts of this new training data
    for (Instance instance : trainingInstancesToAdd)
      incorporateOneInstance(instance, trainingInstancesToAdd.getInstanceWeight(instance));
   
    // Estimate multinomials, and return a new naive Bayes classifier. 
    // Note that, unlike MaxEnt, NaiveBayes is immutable, so we create a new one each time.
    classifier = new NaiveBayes (instancePipe, pe.estimate(), estimateFeatureMultinomials());
    return classifier;
  }
 
  public NaiveBayes trainIncremental (Instance instance) {
    setup (null, instance);
   
    // Incrementally add the counts of this new training instance
    incorporateOneInstance (instance, 1.0);
    if (instancePipe == null)
      instancePipe = new Noop (dataAlphabet, targetAlphabet);
    classifier = new NaiveBayes (instancePipe, pe.estimate(), estimateFeatureMultinomials());
    return classifier;
  }
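
  // A minimal incremental-training sketch (assuming batch1 and batch2 are
  // InstanceLists built with the same pipe); the second classifier reflects
  // the counts from both batches:
  //   NaiveBayesTrainer trainer = new NaiveBayesTrainer ();
  //   trainer.trainIncremental (batch1);
  //   NaiveBayes nb = trainer.trainIncremental (batch2);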

 
  private void setup (InstanceList instances, Instance instance) {
    assert (instances != null || instance != null);
    if (instance == null && instances != null)
      instance = instances.get(0);
    // Initialize the alphabets
    if (dataAlphabet == null) {
      this.dataAlphabet = instance.getDataAlphabet();
      this.targetAlphabet = instance.getTargetAlphabet();
    } else if (!Alphabet.alphabetsMatch(instance, this))
      // Make sure the alphabets match
      throw new IllegalArgumentException ("Training set alphabets do not match those of NaiveBayesTrainer.");

    // Initialize or check the instancePipe
    if (instances != null) {
      if (instancePipe == null)
        instancePipe = instances.getPipe();
      else if (instancePipe != instances.getPipe())
        // Make sure that the pipes match.  Is this really necessary??
        // I don't think so, but it could be confusing to have each returned classifier have a different pipe.  -akm 1/08
        throw new IllegalArgumentException ("Training set pipe does not match that of NaiveBayesTrainer.");
    }
   
    if (me == null) {
      int numLabels = targetAlphabet.size();
      me = new Multinomial.Estimator[numLabels];
      for (int i = 0; i < numLabels; i++) {
        me[i] = (Multinomial.Estimator) featureEstimator.clone();
        me[i].setAlphabet(dataAlphabet);
      }
      pe = (Multinomial.Estimator) priorEstimator.clone();
    }
   
    if (targetAlphabet.size() > me.length) {
      // target alphabet grew. increase size of our multinomial array
      int targetAlphabetSize = targetAlphabet.size();
      // copy over old values
      Multinomial.Estimator[] newMe = new Multinomial.Estimator[targetAlphabetSize];
      System.arraycopy (me, 0, newMe, 0, me.length);
      // initialize new expanded space
      for (int i= me.length; i<targetAlphabetSize; i++){
        Multinomial.Estimator mest = (Multinomial.Estimator)featureEstimator.clone ();
        mest.setAlphabet (dataAlphabet);
        newMe[i] = mest;
      }
      me = newMe;
    }
  }

  private void incorporateOneInstance (Instance instance, double instanceWeight)
  {
    Labeling labeling = instance.getLabeling ();
    if (labeling == null) return; // Handle unlabeled instances by skipping them
    FeatureVector fv = (FeatureVector) instance.getData ();
    double oneNorm = fv.oneNorm();
    if (oneNorm <= 0) return; // Skip instances that have no features present
    if (docLengthNormalization > 0)
      // Make the document have counts that sum to docLengthNormalization
      // I.e., if 20, it would be as if the document had 20 words.
      instanceWeight *= docLengthNormalization / oneNorm;
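    // Worked example: with docLengthNormalization = 20 and oneNorm = 100,
    // instanceWeight is scaled by 20/100 = 0.2.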
    assert (instanceWeight > 0 && !Double.isInfinite(instanceWeight));
    for (int lpos = 0; lpos < labeling.numLocations(); lpos++) {
      int li = labeling.indexAtLocation (lpos);
      double labelWeight = labeling.valueAtLocation (lpos);
      if (labelWeight == 0) continue;
      //System.out.println ("NaiveBayesTrainer me.increment "+ labelWeight * instanceWeight);
      me[li].increment (fv, labelWeight * instanceWeight);
      // This relies on labelWeight summing to 1 over all labels
      pe.increment (li, labelWeight * instanceWeight);
    }
  }
 
  private Multinomial[] estimateFeatureMultinomials () {
    int numLabels = targetAlphabet.size();
    Multinomial[] m = new Multinomial[numLabels];
    for (int li = 0; li < numLabels; li++) {
      //me[li].print (); // debugging
      m[li] = me[li].estimate();
    }
    return m;
  }

  /*
   * Note on incremental training: trainIncremental() creates a NaiveBayes
   * classifier from a set of training data and the previous state of the
   * trainer.  Subsequent calls to trainIncremental() add to the state of
   * the trainer.  An incremental training session should consist only of
   * calls to trainIncremental() and should include no calls to train().
   */

  public String toString()
  {
    return "NaiveBayesTrainer";
  }

 
  // AlphabetCarrying interface
  public boolean alphabetsMatch(AlphabetCarrying object) {
    return Alphabet.alphabetsMatch (this, object);
  }

  public Alphabet getAlphabet() {
    return dataAlphabet;
  }

  public Alphabet[] getAlphabets() {
    return new Alphabet[] { dataAlphabet, targetAlphabet };
  }


  // Serialization
  // serialVersionUID is overridden to prevent innocuous changes in this
  // class from making the serialization mechanism think the external
  // format has changed.

  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 1;

  private void writeObject(ObjectOutputStream out) throws IOException
  {
    out.writeInt(CURRENT_SERIAL_VERSION);

    //default selections for the kind of Estimator used
    out.writeObject(featureEstimator);
    out.writeObject(priorEstimator);

    // These are the counts formed after NaiveBayes training.
    out.writeObject(me);
    out.writeObject(pe);

    // pipe and alphabets
    out.writeObject(instancePipe);
    out.writeObject(dataAlphabet);
    out.writeObject(targetAlphabet);
  }

  private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
    int version = in.readInt();
    if (version != CURRENT_SERIAL_VERSION)
      throw new ClassNotFoundException("Mismatched NaiveBayesTrainer versions: wanted " +
                                       CURRENT_SERIAL_VERSION + ", got " +
                                       version);

    //default selections for the kind of Estimator used
    featureEstimator = (Multinomial.Estimator) in.readObject();
    priorEstimator = (Multinomial.Estimator) in.readObject();

    // These are the counts formed after NaiveBayes training.
    me = (Multinomial.Estimator []) in.readObject();
    pe = (Multinomial.Estimator) in.readObject();

    // pipe and alphabets
    instancePipe = (Pipe) in.readObject();
    dataAlphabet = (Alphabet) in.readObject();
    targetAlphabet = (Alphabet) in.readObject();
  }
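
  // A persistence sketch (assuming a hypothetical file name of the caller's choosing):
  //   ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream ("nb.trainer"));
  //   oos.writeObject (trainer);
  //   oos.close ();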
 
  public static class Factory extends ClassifierTrainer.Factory<NaiveBayesTrainer>
  {
    Multinomial.Estimator featureEstimator = new Multinomial.LaplaceEstimator();
    Multinomial.Estimator priorEstimator = new Multinomial.LaplaceEstimator();
    double docLengthNormalization = -1;
   
    public NaiveBayesTrainer newClassifierTrainer(Classifier initialClassifier) {
      NaiveBayesTrainer trainer = new NaiveBayesTrainer ((NaiveBayes)initialClassifier);
      // Apply this factory's configuration; otherwise the setters below would have
      // no effect on the trainers this factory produces.
      trainer.featureEstimator = featureEstimator;
      trainer.priorEstimator = priorEstimator;
      trainer.docLengthNormalization = docLengthNormalization;
      return trainer;
    }
    public NaiveBayesTrainer.Factory setDocLengthNormalization (double docLengthNormalization) {
      this.docLengthNormalization = docLengthNormalization;
      return this;
    }
   
    public NaiveBayesTrainer.Factory setFeatureMultinomialEstimator (Multinomial.Estimator featureEstimator) {
      this.featureEstimator = featureEstimator;
      return this;
    }
   
    public NaiveBayesTrainer.Factory setPriorMultinomialEstimator (Multinomial.Estimator priorEstimator) {
      this.priorEstimator = priorEstimator;
      return this;
    }
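
    // A factory usage sketch (passing null to start from no initial classifier):
    //   NaiveBayesTrainer trainer = new NaiveBayesTrainer.Factory ()
    //       .setDocLengthNormalization (20)
    //       .newClassifierTrainer (null);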
   
  }

}