/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.pipe;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.HashMap;
import java.util.logging.Logger;
import cc.mallet.classify.BalancedWinnowTrainer;
import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.util.MalletLogger;
/**
 * This pipe uses a Classifier to label each token (i.e., using a 0-th order Markov assumption),
 * then adds the predictions as features to each token.
 * <p>
 * This pipe assumes the input Instance's data is of type FeatureVectorSequence
 * (each element an augmentable feature vector).
 * <p>
 * Example usage:<pre>
 * 1) Create and serialize a featurePipe that converts raw input to FeatureVectorSequences
 * 2) Pipe input data through featurePipe, train a TokenClassifiers via cross validation, then serialize the classifiers
 * 3) Pipe input data through featurePipe and this pipe (using the saved classifiers), and train a Transducer
 * 4) Serialize the trained Transducer
 * </pre>
 *
 * @author ghuang
 */
public class AddClassifierTokenPredictions extends Pipe implements Serializable
// Class-wide logger, named after this class.
private static Logger logger = MalletLogger.getLogger(AddClassifierTokenPredictions.class.getName());
// Specify which predictions are to be added as features.
// E.g., { 1, 2 } = add labels of the top 2 highest-scoring predictions as features.
// Each entry is passed unchanged to LabelVector.getLabelAtRank() in pipe() and is also
// embedded in the generated feature name ("TOK_PRED=<label>_@_RANK_<rank>").
// NOTE(review): confirm whether getLabelAtRank() is 0- or 1-based; if it is 0-based,
// { 1 } selects the second-best label rather than the best one.
int[] m_predRanks2add;
// The trained token classifier
// (one classifier for the full training set plus per-fold classifiers; see TokenClassifiers).
TokenClassifiers m_tokenClassifiers;
// Whether to treat each instance's feature values as binary.
// Set by the constructor; not read anywhere else in the visible part of this file.
boolean m_binary;
// Whether the pipe is currently being used at production time
// (i.e., not being used as pipeline for training a transducer).
// pipe() passes the negation of this flag as "useOutOfFold" to the token classifier.
boolean m_inProduction;
// Augmented data alphabet that includes the class predictions:
// a clone of the token classifier's alphabet extended with one "TOK_PRED=..." entry
// per (rank, label) pair.
Alphabet m_dataAlphabet;
/**
 * Builds the pipe from training data only; no test-set evaluation is performed.
 * Delegates to the two-list constructor with a {@code null} test list.
 *
 * @param trainList training instances (FeatureVectorSequence data)
 */
public AddClassifierTokenPredictions(InstanceList trainList)
{
	this(trainList, null);
}
/**
 * Trains a {@link TokenClassifiers} on the token-level view of {@code trainList}
 * and configures the pipe to add the rank-{@code 1} prediction as a binary feature.
 * Both lists are flattened with {@link #convert(InstanceList, Noop)} using the
 * training list's pipe, which must be a {@code Noop}.
 *
 * @param trainList training instances; its pipe supplies the alphabets
 * @param testList  optional evaluation instances; may be {@code null}, in which
 *                  case {@code convert} yields {@code null} and evaluation is skipped
 */
public AddClassifierTokenPredictions(InstanceList trainList, InstanceList testList)
{
	this(
		new TokenClassifiers(convert(trainList, (Noop) trainList.getPipe())),
		new int[] { 1 },
		true,
		convert(testList, (Noop) trainList.getPipe()));
}
/**
 * Main constructor: stores the configuration, builds the augmented data alphabet,
 * and optionally reports the token classifier's accuracy on a test set.
 *
 * @param tokenClassifiers trained token classifiers whose predictions become features
 * @param predRanks2add    ranks of the predictions to add as features
 * @param binary           whether feature values are to be treated as binary
 * @param testList         token-level instances used only for logging accuracy;
 *                         {@code null} skips the evaluation
 */
public AddClassifierTokenPredictions(TokenClassifiers tokenClassifiers, int[] predRanks2add,
		boolean binary, InstanceList testList)
{
	m_tokenClassifiers = tokenClassifiers;
	m_predRanks2add = predRanks2add;
	m_binary = binary;
	m_inProduction = false;

	// Start from a copy of the classifier's alphabet so the classifier's own
	// alphabet is not grown by the prediction features registered below.
	m_dataAlphabet = (Alphabet) tokenClassifiers.getAlphabet().clone();
	Alphabet labels = tokenClassifiers.getLabelAlphabet();

	// Register one feature per (rank, label) combination.
	for (int rank : m_predRanks2add) {
		for (int labelIdx = 0; labelIdx < labels.size(); labelIdx++) {
			String labelName = labels.lookupObject(labelIdx).toString();
			m_dataAlphabet.lookupIndex("TOK_PRED=" + labelName + "_@_RANK_" + rank, true);
		}
	}

	// Nothing to evaluate without a test set.
	if (testList == null) {
		return;
	}
	Trial trial = new Trial(m_tokenClassifiers, testList);
	logger.info("Token classifier accuracy on test set = " + trial.getAccuracy());
}
/**
 * Switches this pipe between training-time and production-time behavior.
 *
 * @param inProduction {@code true} when the pipe is used at production time
 */
public void setInProduction(boolean inProduction)
{
	this.m_inProduction = inProduction;
}
/**
 * @return {@code true} if this pipe is currently flagged as being used at production time
 */
public boolean getInProduction()
{
	return this.m_inProduction;
}
public static void setInProduction(Pipe p, boolean value)
if (p instanceof AddClassifierTokenPredictions)
((AddClassifierTokenPredictions) p).setInProduction(value);
else if (p instanceof SerialPipes) {
SerialPipes sp = (SerialPipes) p;
for (int i = 0; i < sp.size(); i++)
setInProduction(sp.getPipe(i), value);
/**
 * @return the augmented data alphabet, i.e. the token classifier's alphabet
 *         plus the prediction features added by this pipe
 */
public Alphabet getDataAlphabet()
{
	return this.m_dataAlphabet;
}
* Add the token classifier's predictions as features to the instance.
* This method assumes the input instance contains FeatureVectorSequence as data
public Instance pipe(Instance carrier)
FeatureVectorSequence fvs = (FeatureVectorSequence) carrier.getData();
InstanceList ilist = convert(carrier, (Noop) m_tokenClassifiers.getInstancePipe());
assert (fvs.size() == ilist.size());
// For passing instances to the token classifier, each instance's data alphabet needs to
// match that used by the token classifier at training time. For the resulting piped
// instance, each instance's data alphabet needs to contain token classifier's prediction
// as features
FeatureVector[] fva = new FeatureVector[fvs.size()];
for (int i = 0; i < ilist.size(); i++) {
Instance inst = ilist.get(i);
Classification c = m_tokenClassifiers.classify(inst, ! m_inProduction);
LabelVector lv = c.getLabelVector();
AugmentableFeatureVector afv1 = (AugmentableFeatureVector) inst.getData();
int[] indices = afv1.getIndices();
AugmentableFeatureVector afv2 = new AugmentableFeatureVector(m_dataAlphabet,
indices, afv1.getValues(), indices.length + m_predRanks2add.length);
for (int j = 0; j < m_predRanks2add.length; j++) {
Label label = lv.getLabelAtRank(m_predRanks2add[j]);
int idx = m_dataAlphabet.lookupIndex("TOK_PRED=" + label.toString() + "_@_RANK_" + m_predRanks2add[j]);
assert(idx >= 0);
afv2.add(idx, 1);
fva[i] = afv2;
carrier.setData(new FeatureVectorSequence(fva));
return carrier;
* Converts each instance containing a FeatureVectorSequence to multiple instances,
* each containing an AugmentableFeatureVector as data.
* @param ilist Instances with FeatureVectorSequence as data field
* @param alphabetsPipe a Noop pipe containing the data and target alphabets for the resulting InstanceList
* @return an InstanceList where each Instance contains one Token's AugmentableFeatureVector as data
public static InstanceList convert(InstanceList ilist, Noop alphabetsPipe)
if (ilist == null) return null;
// This monstrosity is necessary b/c Classifiers obtain the data/target alphabets via pipes
InstanceList ret = new InstanceList(alphabetsPipe);
for (Instance inst : ilist)
//for (int i = 0; i < ilist.size(); i++) ret.add(convert(ilist.get(i), alphabetsPipe));
return ret;
* @param inst input instance, with FeatureVectorSequence as data.
* @param alphabetsPipe a Noop pipe containing the data and target alphabets for
* the resulting InstanceList and AugmentableFeatureVectors
* @return list of instances, each with one AugmentableFeatureVector as data
public static InstanceList convert(Instance inst, Noop alphabetsPipe)
InstanceList ret = new InstanceList(alphabetsPipe);
Object obj = inst.getData();
assert(obj instanceof FeatureVectorSequence);
FeatureVectorSequence fvs = (FeatureVectorSequence) obj;
LabelSequence ls = (LabelSequence) inst.getTarget();
assert(fvs.size() == ls.size());
Object instName = (inst.getName() == null ? "NONAME" : inst.getName());
for (int j = 0; j < fvs.size(); j++) {
FeatureVector fv = fvs.getFeatureVector(j);
int[] indices = fv.getIndices();
FeatureVector data = new AugmentableFeatureVector (alphabetsPipe.getDataAlphabet(),
indices, fv.getValues(), indices.length);
Labeling target = ls.getLabelAtPosition(j);
String name = instName.toString() + "_@_POS_" + (j + 1);
Object source = inst.getSource();
Instance toAdd = alphabetsPipe.pipe(new Instance(data, target, name, source));
return ret;
// Serialization
// Stream version identifier for the outer pipe class.
private static final long serialVersionUID = 1;
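/**
 * This nested class represents the trained token classifiers: one classifier
 * trained on the entire training set, plus the per-fold classifiers produced
 * by cross validation.
 * @author ghuang
 */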
public static class TokenClassifiers extends Classifier implements Serializable
// number of folds in cross-validation training
// (doTraining() checks this value for 0 before starting cross validation)
int m_numCV;
// random seed to split training data for cross-validation
int m_randSeed;
// trainer for token classifier; reused for the full-data classifier and for every fold
ClassifierTrainer m_trainer;
// token classifier trained on the entirety of the training set
Classifier m_tokenClassifier;
// table storing instance name --> out-of-fold classifier
// Used to prevent overfitting to the token classifier's predictions:
// each training instance maps to the fold classifier that was trained without it.
// Raw HashMap; keys are the values returned by Instance.getName().
HashMap m_table;
/**
 * Trains token classifiers on the given instances using random seed 0
 * and 5-fold cross validation.
 *
 * @param trainList training instances
 */
public TokenClassifiers(InstanceList trainList)
{
	this(trainList, 0, 5);
}
/**
 * Trains token classifiers with the default trainer, a {@code BalancedWinnowTrainer}.
 *
 * @param trainList training instances
 * @param randSeed  random seed used to split the data for cross validation
 * @param numCV     number of cross-validation folds
 */
public TokenClassifiers(InstanceList trainList, int randSeed, int numCV)
{
	// Alternative trainers tried in the past:
	//   new AdaBoostM2Trainer(new DecisionTreeTrainer(2), 10)
	//   new NaiveBayesTrainer()
	//   new SVMTrainer()
	this(new BalancedWinnowTrainer(), trainList, randSeed, numCV);
}
/**
 * Stores the training configuration and trains the token classifiers:
 * one on the full training set and, unless {@code numCV} is 0, one per
 * cross-validation fold.
 *
 * @param trainer   trainer used for the full-data classifier and each fold classifier
 * @param trainList token-level training instances; its pipe becomes this classifier's instance pipe
 * @param randSeed  random seed used to split the data for cross validation
 * @param numCV     number of cross-validation folds
 */
public TokenClassifiers(ClassifierTrainer trainer, InstanceList trainList, int randSeed, int numCV)
{
	// Hand the training pipe to Classifier so getInstancePipe()/getAlphabet()/
	// getLabelAlphabet() work and the pipe can be serialized with this object.
	super(trainList.getPipe());

	m_trainer = trainer;
	m_randSeed = randSeed;
	m_numCV = numCV;
	m_table = new HashMap();

	// train the token classifier
	doTraining(trainList);
}
private void doTraining(InstanceList trainList)
// train a classifier on the entire training set
logger.info("Training token classifier on entire data set (size=" + trainList.size() + ")...");
m_tokenClassifier = m_trainer.train(trainList);
Trial t = new Trial(m_tokenClassifier, trainList);
logger.info("Training set accuracy = " + t.getAccuracy());
if (m_numCV == 0)
// train classifiers using cross validation
InstanceList.CrossValidationIterator cvIter = trainList.new CrossValidationIterator(m_numCV, m_randSeed);
int f = 1;
while (cvIter.hasNext()) {
InstanceList[] fold = cvIter.nextSplit();
logger.info("Training token classifier on cv fold " + f + " / " + m_numCV + " (size=" + fold[0].size() + ")...");
Classifier foldClassifier = m_trainer.train(fold[0]);
Trial t1 = new Trial(foldClassifier, fold[0]);
Trial t2 = new Trial(foldClassifier, fold[1]);
logger.info("Within-fold accuracy = " + t1.getAccuracy());
logger.info("Out-of-fold accuracy = " + t2.getAccuracy());
/*for (int x = 0; x < t2.size(); x++) {
logger.info("xxx pred:" + t2.getClassification(x).getLabeling().getBestLabel() + " true:" + t2.getClassification(x).getInstance().getLabeling());
for (int i = 0; i < fold[1].size(); i++) {
Instance inst = fold[1].get(i);
m_table.put(inst.getName(), foldClassifier);
public Classification classify(Instance instance)
return classify(instance, false);
* @param instance the instance to classify
* @param useOutOfFold whether to check the instance name and use the out-of-fold classifier
* if the instance name matches one in the training data
* @return the token classifier's output
public Classification classify(Instance instance, boolean useOutOfFold)
Object instName = instance.getName();
if (! useOutOfFold || ! m_table.containsKey(instName))
return m_tokenClassifier.classify(instance);
Classifier classifier = (Classifier) m_table.get(instName);
return classifier.classify(instance);
// serialization
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 1;
private void writeObject(ObjectOutputStream out) throws IOException
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
int version = in.readInt();
throw new ClassNotFoundException("Mismatched TokenClassifiers versions: wanted " +
instancePipe = (Pipe) in.readObject();
m_numCV = in.readInt();
m_randSeed = in.readInt();
m_table = (HashMap) in.readObject();
m_tokenClassifier = (Classifier) in.readObject();
m_trainer = (ClassifierTrainer) in.readObject();