/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify;
import java.io.Serializable;
import java.io.ObjectOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Iterator;
import cc.mallet.classify.Classifier;
import cc.mallet.pipe.Noop;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AlphabetCarrying;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.types.Multinomial;
/**
* Class used to generate a NaiveBayes classifier from a set of training data.
* In a Bayes classifier,
* p(Classification|Data) = p(Data|Classification)p(Classification)/p(Data)
* <p>
* To compute the likelihood: <br>
* p(Data|Classification) = p(d1,d2,..dn | Classification) <br>
* Naive Bayes makes the assumption that all of the data are conditionally
* independent given the Classification: <br>
* p(d1,d2,...dn | Classification) = p(d1|Classification)p(d2|Classification)..
* <p>
* As with other classifiers in Mallet, NaiveBayes is implemented as two classes:
* a trainer and a classifier. The NaiveBayesTrainer produces estimates of the various
* p(dn|Classification) and constructs a NaiveBayes classifier from those estimates.
* <p>
* A call to train() or trainIncremental() produces a
* {@link cc.mallet.classify.NaiveBayes} classifier that
* can be used to classify instances. A call to trainIncremental() does not throw
* away the internal state of the trainer; subsequent calls to trainIncremental()
* train by extending the previous training set.
* <p>
* A NaiveBayesTrainer can be persisted using serialization.
* @see NaiveBayes
* @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*
*/
public class NaiveBayesTrainer extends ClassifierTrainer<NaiveBayes>
implements ClassifierTrainer.ByInstanceIncrements<NaiveBayes>, Boostable, AlphabetCarrying, Serializable
{
  // Prototype estimators selecting the kind of Estimator used. They are never
  // counted into directly: featureEstimator is cloned once per label and
  // priorEstimator once for the label prior when counting begins.
  Multinomial.Estimator featureEstimator = new Multinomial.LaplaceEstimator();
  Multinomial.Estimator priorEstimator = new Multinomial.LaplaceEstimator();

  // Added to support incremental training.
  // These are the counts accumulated by training. Note that these are *not*
  // the estimates passed to the NaiveBayes classifier; rather the estimates are
  // formed from these counts each time a classifier is built.
  // me[li] holds the feature counts for label index li; pe holds the label counts.
  // Both are null until counting begins, and train() resets them to null.
  // We could break these two fields out into an inner class.
  Multinomial.Estimator[] me;
  Multinomial.Estimator pe;
  // Any value <= 0 (default -1) means don't do any document length normalization
  double docLengthNormalization = -1;
  NaiveBayes classifier; // The most recently built classifier

  // If this style of incremental training is successful, the following members
  // should probably be moved up into IncrementalClassifierTrainer
  Pipe instancePipe;       // Needed to construct a new classifier
  Alphabet dataAlphabet;   // Taken from the first training data. Must be the same for all calls to trainIncremental()
  Alphabet targetAlphabet; // Taken from the first training data. Must be the same for all calls to trainIncremental()

  /**
   * Creates a trainer that shares the pipe and alphabets of an existing classifier.
   * No counts are taken from the classifier; training starts from zero.
   * @param initialClassifier classifier whose pipe and alphabets to adopt; may be null
   */
  public NaiveBayesTrainer (NaiveBayes initialClassifier) {
    if (initialClassifier != null) {
      this.instancePipe = initialClassifier.getInstancePipe();
      this.dataAlphabet = initialClassifier.getAlphabet();
      this.targetAlphabet = initialClassifier.getLabelAlphabet();
      this.classifier = initialClassifier;
    }
  }

  /**
   * Creates a trainer bound to the given pipe and its data/target alphabets.
   * @param instancePipe pipe used by every classifier this trainer builds
   */
  public NaiveBayesTrainer (Pipe instancePipe) {
    this.instancePipe = instancePipe;
    this.dataAlphabet = instancePipe.getDataAlphabet();
    this.targetAlphabet = instancePipe.getTargetAlphabet();
  }

  /** Creates a trainer whose pipe and alphabets are taken from the first training data. */
  public NaiveBayesTrainer () {
  }

  /** @return the most recently trained (or initial) classifier, or null if there is none */
  public NaiveBayes getClassifier () { return classifier; }

  /**
   * Sets the document length normalization. When positive, each instance's counts
   * are rescaled so they sum to this value (e.g. 20 makes every document count as if
   * it had 20 words). A value <= 0 disables normalization.
   * @param d the target document length, or a value <= 0 to disable
   * @return this trainer, for chaining
   */
  public NaiveBayesTrainer setDocLengthNormalization (double d) {
    docLengthNormalization = d;
    return this;
  }

  /** @return the document length normalization; a value <= 0 means none */
  public double getDocLengthNormalization () {
    return docLengthNormalization;
  }

  /**
   * Get the MultinomialEstimator instance used to specify the type of estimator
   * for features.
   *
   * @return estimator to be cloned on next call to train() or first call
   * to trainIncremental()
   */
  public Multinomial.Estimator getFeatureMultinomialEstimator ()
  {
    return featureEstimator;
  }

  /**
   * Set the Multinomial Estimator used for features. The MultinomialEstimator
   * is internally cloned and the clone is used to maintain the counts
   * that will be used to generate probability estimates.
   * Defaults to a Multinomial.LaplaceEstimator().
   * @param me to be cloned once per label when counting begins
   * @return this trainer, for chaining
   * @throws IllegalStateException if training counts already exist
   */
  public NaiveBayesTrainer setFeatureMultinomialEstimator (Multinomial.Estimator me)
  {
    // "this.me" is the count array; the parameter "me" shadows it.
    if (this.me != null)
      throw new IllegalStateException("Can't set after train() or trainIncremental() is called");
    featureEstimator = me;
    return this;
  }

  /**
   * Get the MultinomialEstimator instance used to specify the type of estimator
   * for priors.
   *
   * @return estimator to be cloned on next call to train() or first call
   * to trainIncremental()
   */
  public Multinomial.Estimator getPriorMultinomialEstimator ()
  {
    return priorEstimator;
  }

  /**
   * Set the Multinomial Estimator used for priors. The MultinomialEstimator
   * is internally cloned and the clone is used to maintain the counts
   * that will be used to generate probability estimates.
   * Defaults to a Multinomial.LaplaceEstimator().
   * @param me to be cloned when counting begins
   * @return this trainer, for chaining
   * @throws IllegalStateException if training counts already exist
   */
  public NaiveBayesTrainer setPriorMultinomialEstimator (Multinomial.Estimator me)
  {
    // "this.me" is the count array; the parameter "me" shadows it.
    if (this.me != null)
      throw new IllegalStateException("Can't set after train() or trainIncremental() is called");
    priorEstimator = me;
    return this;
  }

  /**
   * Create a NaiveBayes classifier from a set of training data.
   * The trainer uses counts of each feature in an instance's feature vector
   * to provide an estimate of p(Labeling| feature). All previously accumulated
   * counts are discarded first, so each call to train() is independent of any
   * earlier training.
   * @param trainingList The InstanceList to be used to train the classifier.
   * Within each instance the data slot is an instance of FeatureVector and the
   * target slot is an instance of Labeling
   * @return The NaiveBayes classifier as trained on the trainingList
   */
  public NaiveBayes train (InstanceList trainingList)
  {
    // Forget all the previous sufficient statistics counts
    me = null; pe = null;
    // Train a new classifier based on this data
    this.classifier = trainIncremental (trainingList);
    return classifier;
  }

  /**
   * Create a NaiveBayes classifier from a set of training data and the
   * previous state of the trainer. Subsequent calls to trainIncremental()
   * add to the state of the trainer. An incremental training session
   * should consist only of calls to trainIncremental() and have no
   * calls to train(), since train() discards the accumulated counts.
   * @param trainingInstancesToAdd The InstanceList to be added to the training data.
   * Within each instance the data slot is an instance of FeatureVector and the
   * target slot is an instance of Labeling. May be empty once the alphabets are known.
   * @return The NaiveBayes classifier as trained on this list and all the
   * lists previously passed to trainIncremental()
   * @throws IllegalArgumentException if the alphabets or pipe do not match earlier training data
   */
  public NaiveBayes trainIncremental (InstanceList trainingInstancesToAdd)
  {
    // Initialize and check instance variables as necessary...
    setup(trainingInstancesToAdd, null);
    // Incrementally add the counts of this new training data
    for (Instance instance : trainingInstancesToAdd)
      incorporateOneInstance(instance, trainingInstancesToAdd.getInstanceWeight(instance));
    // Estimate multinomials, and return a new naive Bayes classifier.
    // Note that, unlike MaxEnt, NaiveBayes is immutable, so we create a new one each time.
    classifier = new NaiveBayes (instancePipe, pe.estimate(), estimateFeatureMultinomials());
    return classifier;
  }

  /**
   * Adds a single instance (with weight 1.0) to the accumulated counts and
   * builds a new classifier from all counts so far.
   * @param instance labeled instance whose data is a FeatureVector
   * @return a new NaiveBayes classifier reflecting all training data so far
   */
  public NaiveBayes trainIncremental (Instance instance) {
    setup (null, instance);
    // Incrementally add the counts of this new training instance
    incorporateOneInstance (instance, 1.0);
    if (instancePipe == null)
      instancePipe = new Noop (dataAlphabet, targetAlphabet);
    classifier = new NaiveBayes (instancePipe, pe.estimate(), estimateFeatureMultinomials());
    return classifier;
  }

  /**
   * Prepares the trainer to receive counts. It records or checks the alphabets
   * and the pipe, then makes sure there is one feature estimator per label
   * currently in the target alphabet. Normally exactly one argument is non-null.
   */
  private void setup (InstanceList instances, Instance instance) {
    if (instances == null && instance == null)
      throw new IllegalArgumentException ("No training data supplied to NaiveBayesTrainer.");
    if (instance == null && !instances.isEmpty())
      instance = instances.get(0);

    // Initialize the alphabets, or make sure they match the ones already in use
    if (instance != null) {
      if (dataAlphabet == null) {
        this.dataAlphabet = instance.getDataAlphabet();
        this.targetAlphabet = instance.getTargetAlphabet();
      } else if (!Alphabet.alphabetsMatch(instance, this))
        throw new IllegalArgumentException ("Training set alphabets do not match those of NaiveBayesTrainer.");
    } else if (dataAlphabet == null) {
      // Empty training list and no alphabets known yet: take them from the list itself
      this.dataAlphabet = instances.getDataAlphabet();
      this.targetAlphabet = instances.getTargetAlphabet();
    }
    if (targetAlphabet == null)
      throw new IllegalArgumentException ("NaiveBayesTrainer cannot determine the target alphabet of the training data.");

    // Initialize or check the instancePipe
    if (instances != null) {
      if (instancePipe == null)
        instancePipe = instances.getPipe();
      else if (instancePipe != instances.getPipe())
        // Make sure that this pipes match. Is this really necessary??
        // I don't think so, but it could be confusing to have each returned classifier have a different pipe? -akm 1/08
        throw new IllegalArgumentException ("Training set pipe does not match that of NaiveBayesTrainer.");
    }

    // Start counting from scratch if there are no counts yet
    if (me == null) {
      me = new Multinomial.Estimator[0];
      pe = (Multinomial.Estimator) priorEstimator.clone();
    }
    // Add one fresh feature estimator per label the target alphabet has gained
    int numLabels = targetAlphabet.size();
    if (numLabels > me.length) {
      Multinomial.Estimator[] grown = new Multinomial.Estimator[numLabels];
      System.arraycopy (me, 0, grown, 0, me.length);
      for (int i = me.length; i < numLabels; i++)
        grown[i] = newFeatureEstimator();
      me = grown;
    }
  }

  /** Returns a fresh clone of the feature-estimator prototype bound to the data alphabet. */
  private Multinomial.Estimator newFeatureEstimator () {
    Multinomial.Estimator estimator = (Multinomial.Estimator) featureEstimator.clone();
    estimator.setAlphabet(dataAlphabet);
    return estimator;
  }

  /**
   * Adds one instance's feature counts to the estimator of each label it carries,
   * and its label weights to the prior estimator. Skips unlabeled, empty and
   * zero-weight instances.
   */
  private void incorporateOneInstance (Instance instance, double instanceWeight)
  {
    Labeling labeling = instance.getLabeling ();
    if (labeling == null) return; // Handle unlabeled instances by skipping them
    FeatureVector fv = (FeatureVector) instance.getData ();
    double oneNorm = fv.oneNorm();
    if (oneNorm <= 0) return; // Skip instances that have no features present
    if (instanceWeight == 0) return; // Zero-weight instances contribute no counts
    if (docLengthNormalization > 0)
      // Make the document have counts that sum to docLengthNormalization
      // I.e., if 20, it would be as if the document had 20 words.
      instanceWeight *= docLengthNormalization / oneNorm;
    assert (instanceWeight > 0 && !Double.isInfinite(instanceWeight));
    for (int lpos = 0; lpos < labeling.numLocations(); lpos++) {
      int li = labeling.indexAtLocation (lpos);
      double labelWeight = labeling.valueAtLocation (lpos);
      if (labelWeight == 0) continue;
      me[li].increment (fv, labelWeight * instanceWeight);
      // This relies on labelWeight summing to 1 over all labels
      pe.increment (li, labelWeight * instanceWeight);
    }
  }

  /** Estimates one feature multinomial per label in the target alphabet from the counts. */
  private Multinomial[] estimateFeatureMultinomials () {
    int numLabels = targetAlphabet.size();
    Multinomial[] m = new Multinomial[numLabels];
    for (int li = 0; li < numLabels; li++)
      m[li] = me[li].estimate();
    return m;
  }

  @Override
  public String toString()
  {
    return "NaiveBayesTrainer";
  }

  // AlphabetCarrying interface
  public boolean alphabetsMatch(AlphabetCarrying object) {
    return Alphabet.alphabetsMatch (this, object);
  }

  public Alphabet getAlphabet() {
    return dataAlphabet;
  }

  public Alphabet[] getAlphabets() {
    return new Alphabet[] { dataAlphabet, targetAlphabet };
  }

  // Serialization
  // serialVersionUID is overriden to prevent innocuous changes in this
  // class from making the serialization mechanism think the external
  // format has changed.
  private static final long serialVersionUID = 1;
  // Version 2 appends docLengthNormalization; version 1 streams remain readable.
  private static final int CURRENT_SERIAL_VERSION = 2;

  private void writeObject(ObjectOutputStream out) throws IOException
  {
    out.writeInt(CURRENT_SERIAL_VERSION);
    // default selections for the kind of Estimator used
    out.writeObject(featureEstimator);
    out.writeObject(priorEstimator);
    // These are the counts formed after NaiveBayes training.
    out.writeObject(me);
    out.writeObject(pe);
    // pipe and alphabets
    out.writeObject(instancePipe);
    out.writeObject(dataAlphabet);
    out.writeObject(targetAlphabet);
    // added in version 2
    out.writeDouble(docLengthNormalization);
  }

  private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
    int version = in.readInt();
    if (version < 1 || version > CURRENT_SERIAL_VERSION)
      throw new ClassNotFoundException("Mismatched NaiveBayesTrainer versions: wanted " +
                                       CURRENT_SERIAL_VERSION + ", got " +
                                       version);
    // default selections for the kind of Estimator used
    featureEstimator = (Multinomial.Estimator) in.readObject();
    priorEstimator = (Multinomial.Estimator) in.readObject();
    // These are the counts formed after NaiveBayes training.
    me = (Multinomial.Estimator []) in.readObject();
    pe = (Multinomial.Estimator) in.readObject();
    // pipe and alphabets
    instancePipe = (Pipe) in.readObject();
    dataAlphabet = (Alphabet) in.readObject();
    targetAlphabet = (Alphabet) in.readObject();
    // Version 1 streams did not store the normalization; use the "none" default
    docLengthNormalization = (version >= 2) ? in.readDouble() : -1;
  }

  /**
   * Factory producing NaiveBayesTrainers configured with this factory's
   * feature estimator, prior estimator and document length normalization.
   */
  public static class Factory extends ClassifierTrainer.Factory<NaiveBayesTrainer>
  {
    Multinomial.Estimator featureEstimator = new Multinomial.LaplaceEstimator();
    Multinomial.Estimator priorEstimator = new Multinomial.LaplaceEstimator();
    double docLengthNormalization = -1;

    /**
     * @param initialClassifier a NaiveBayes whose pipe and alphabets the trainer adopts, or null
     * @return a new trainer carrying this factory's settings
     */
    public NaiveBayesTrainer newClassifierTrainer(Classifier initialClassifier) {
      NaiveBayesTrainer trainer = new NaiveBayesTrainer ((NaiveBayes)initialClassifier);
      // A freshly constructed trainer has no counts, so these setters cannot throw.
      // The estimators act as prototypes and are cloned by the trainer, so sharing them is safe.
      trainer.setFeatureMultinomialEstimator (featureEstimator);
      trainer.setPriorMultinomialEstimator (priorEstimator);
      trainer.setDocLengthNormalization (docLengthNormalization);
      return trainer;
    }

    public NaiveBayesTrainer.Factory setDocLengthNormalization (double docLengthNormalization) {
      this.docLengthNormalization = docLengthNormalization;
      return this;
    }

    public NaiveBayesTrainer.Factory setFeatureMultinomialEstimator (Multinomial.Estimator featureEstimator) {
      this.featureEstimator = featureEstimator;
      return this;
    }

    public NaiveBayesTrainer.Factory setPriorMultinomialEstimator (Multinomial.Estimator priorEstimator) {
      this.priorEstimator = priorEstimator;
      return this;
    }
  }
}