/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.logging.Logger;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Multinomial;
import cc.mallet.util.MalletLogger;
/**
* @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/
public class NaiveBayesEMTrainer extends ClassifierTrainer<NaiveBayes> {
private static Logger logger = MalletLogger.getLogger(MCMaxEntTrainer.class.getName());
Multinomial.Estimator featureEstimator = new Multinomial.LaplaceEstimator();
Multinomial.Estimator priorEstimator = new Multinomial.LaplaceEstimator();
double docLengthNormalization = -1;
double unlabeledDataWeight = 1.0;
int iteration = 0;
NaiveBayesTrainer.Factory nbTrainer;
NaiveBayes classifier;
public NaiveBayesEMTrainer () {
nbTrainer = new NaiveBayesTrainer.Factory ();
nbTrainer.setDocLengthNormalization(docLengthNormalization);
nbTrainer.setFeatureMultinomialEstimator(featureEstimator);
nbTrainer.setPriorMultinomialEstimator (priorEstimator);
}
public Multinomial.Estimator getFeatureMultinomialEstimator () {
return featureEstimator;
}
public void setFeatureMultinomialEstimator (Multinomial.Estimator me) {
featureEstimator = me;
nbTrainer.setFeatureMultinomialEstimator(featureEstimator);
}
public Multinomial.Estimator getPriorMultinomialEstimator () {
return priorEstimator;
}
public void setPriorMultinomialEstimator (Multinomial.Estimator me) {
priorEstimator = me;
nbTrainer.setPriorMultinomialEstimator(priorEstimator);
}
public void setDocLengthNormalization (double d) {
docLengthNormalization = d;
nbTrainer.setDocLengthNormalization(docLengthNormalization);
}
public double getDocLengthNormalization () {
return docLengthNormalization;
}
public double getUnlabeledDataWeight () {
return unlabeledDataWeight;
}
public void setUnlabeledDataWeight (double unlabeledDataWeight) {
this.unlabeledDataWeight = unlabeledDataWeight;
}
public int getIteration() { return iteration; }
public boolean isFinishedTraining() { return false; }
public NaiveBayes getClassifier() { return classifier; }
public NaiveBayes train (InstanceList trainingSet)
{
// Get a classifier trained on the labeled examples only
NaiveBayes c = (NaiveBayes) nbTrainer.newClassifierTrainer().train (trainingSet);
double prevLogLikelihood = 0, logLikelihood = 0;
boolean converged = false;
int iteration = 0;
while (!converged) {
// Make a new trainingSet that has some labels set
InstanceList trainingSet2 = new InstanceList (trainingSet.getPipe());
for (int ii = 0; ii < trainingSet.size(); ii++) {
Instance inst = trainingSet.get(ii);
if (inst.getLabeling() != null)
trainingSet2.add(inst, 1.0);
else {
Instance inst2 = inst.shallowCopy();
inst2.unLock();
inst2.setLabeling(c.classify(inst).getLabeling());
inst2.lock();
trainingSet2.add(inst2, unlabeledDataWeight);
}
}
c = (NaiveBayes) nbTrainer.newClassifierTrainer().train (trainingSet2);
logLikelihood = c.dataLogLikelihood (trainingSet2);
System.err.println ("Loglikelihood = "+logLikelihood);
// Wait for a change in log-likelihood of less than 0.01% and at least 10 iterations
if (Math.abs((logLikelihood - prevLogLikelihood)/logLikelihood) < 0.0001)
converged = true;
prevLogLikelihood = logLikelihood;
iteration++;
}
return c;
}
public String toString()
{
String ret = "NaiveBayesEMTrainer";
if (docLengthNormalization != 1.0) ret += ",docLengthNormalization="+docLengthNormalization;
if (unlabeledDataWeight != 1.0) ret += ",unlabeledDataWeight="+unlabeledDataWeight;
return ret;
}
// Serialization
// serialVersionUID is overriden to prevent innocuous changes in this
// class from making the serialization mechanism think the external
// format has changed.
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 1;
private void writeObject(ObjectOutputStream out) throws IOException
{
out.writeInt(CURRENT_SERIAL_VERSION);
//default selections for the kind of Estimator used
out.writeObject(featureEstimator);
out.writeObject(priorEstimator);
}
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
int version = in.readInt();
if (version != CURRENT_SERIAL_VERSION)
throw new ClassNotFoundException("Mismatched NaiveBayesTrainer versions: wanted " +
CURRENT_SERIAL_VERSION + ", got " +
version);
//default selections for the kind of Estimator used
featureEstimator = (Multinomial.Estimator) in.readObject();
priorEstimator = (Multinomial.Estimator) in.readObject();
}
}