/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.logging.Logger;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.Labeling;
/**
 * Stores the results of classifying a collection of Instances,
 * and provides many methods for evaluating the results.
 * If you just need one evaluation result, you may find it easier to use one
 * of the corresponding methods in Classifier, which simply call the methods here.
 * @see InstanceList
 * @see Classifier
 * @see Classification
 * @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
 */
public class Trial extends ArrayList<Classification>
// Used by getPrecision(int) / getRecall(int) to warn when no instance was predicted
// with (resp. truly has) the requested label and the metric falls back to 1.
private static Logger logger = Logger.getLogger(Trial.class.getName());
// The Classifier whose Classifications this Trial holds; the add methods reject any
// Classification whose getClassifier() is not this same object (reference comparison).
Classifier classifier;
/**
 * Classify every Instance of {@code ilist} with {@code c} and store the resulting
 * Classifications, in iteration order.
 *
 * @param c     the Classifier to evaluate; becomes this Trial's classifier
 * @param ilist the Instances to classify
 */
public Trial (Classifier c, InstanceList ilist)
{
    // Pre-size the backing list: exactly one Classification per Instance.
    super (ilist.size());
    classifier = c;
    for (Instance inst : ilist) {
        Classification result = c.classify (inst);
        add (result);
    }
}
public boolean add (Classification c)
if (c.getClassifier() != this.classifier)
throw new IllegalArgumentException ("Trying to add Classification from a different Classifier.");
return super.add (c);
public void add (int index, Classification c)
if (c.getClassifier() != this.classifier)
throw new IllegalArgumentException ("Trying to add Classification from a different Classifier.");
super.add (index, c);
public boolean addAll(Collection<? extends Classification> collection) {
boolean ret = true;
for (Classification c : collection)
if (!this.add(c))
ret = false;
return ret;
public boolean addAll (int index, Collection<? extends Classification> collection) {
throw new IllegalStateException ("Not implemented.");
public Classifier getClassifier ()
return classifier;
/** Return the fraction of instances that have the correct label as their best predicted label. */
public double getAccuracy ()
int numCorrect = 0;
for (int i = 0; i < this.size(); i++)
if (this.get(i).bestLabelIsCorrect())
return (double)numCorrect/this.size();
/** Calculate the precision of the classifier on an instance list for a
particular target entry */
public double getPrecision (Object labelEntry)
int index;
if (labelEntry instanceof Labeling)
index = ((Labeling)labelEntry).getBestIndex();
index = classifier.getLabelAlphabet().lookupIndex(labelEntry, false);
if (index == -1) throw new IllegalArgumentException ("Label "+labelEntry.toString()+" is not a valid label.");
return getPrecision (index);
public double getPrecision (Labeling label)
return getPrecision (label.getBestIndex());
/** Calculate the precision for a particular target index from an
array list of classifications */
public double getPrecision (int index)
int numCorrect = 0;
int numInstances = 0;
int trueLabel, classLabel;
for (int i = 0; i<this.size(); i++) {
trueLabel = this.get(i).getInstance().getLabeling().getBestIndex();
classLabel = this.get(i).getLabeling().getBestIndex();
if (classLabel == index) {
if (trueLabel == index)
// gdruck@cs.umass.edu
// When no examples are predicted to have this label,
// we define precision to be 1.
if (numInstances==0) {
logger.warning("No examples with predicted label " +
classifier.getLabelAlphabet().lookupLabel(index) + "!");
assert(numCorrect == 0);
return 1;
return ((double)numCorrect/(double)numInstances);
/** Calculate the recall of the classifier on an instance list for a
particular target entry */
public double getRecall (Object labelEntry)
int index;
if (labelEntry instanceof Labeling)
index = ((Labeling)labelEntry).getBestIndex();
index = classifier.getLabelAlphabet().lookupIndex(labelEntry, false);
if (index == -1) throw new IllegalArgumentException ("Label "+labelEntry.toString()+" is not a valid label.");
return getRecall (index);
public double getRecall (Labeling label)
return getRecall (label.getBestIndex());
/** Calculate the recall for a particular target index from an
array list of classifications */
public double getRecall (int labelIndex)
int numCorrect = 0;
int numInstances = 0;
int trueLabel, classLabel;
for (int i = 0; i<this.size(); i++) {
trueLabel = this.get(i).getInstance().getLabeling().getBestIndex();
classLabel = this.get(i).getLabeling().getBestIndex();
if ( trueLabel == labelIndex ) {
if ( classLabel == labelIndex)
// gdruck@cs.umass.edu
// When no examples have this label,
// we define recall to be 1.
if (numInstances==0) {
logger.warning("No examples with true label " +
classifier.getLabelAlphabet().lookupLabel(labelIndex) + "!");
assert(numCorrect == 0);
return 1;
return ((double)numCorrect/(double)numInstances);
/** Calculate the F1-measure of the classifier on an instance list for a
particular target entry */
public double getF1 (Object labelEntry)
int index;
if (labelEntry instanceof Labeling)
index = ((Labeling)labelEntry).getBestIndex();
index = classifier.getLabelAlphabet().lookupIndex(labelEntry, false);
if (index == -1) throw new IllegalArgumentException ("Label "+labelEntry.toString()+" is not a valid label.");
return getF1 (index);
public double getF1 (Labeling label)
return getF1 (label.getBestIndex());
/** Calculate the F1-measure for a particular target index from an
array list of classifications */
public double getF1 (int index)
double precision = getPrecision (index);
double recall = getRecall (index);
// gdruck@cs.umass.edu
// When both precision and recall are 0, F1 is 0.
if (precision==0.0 && recall==0.0) {
return 0;
return 2*precision*recall/(precision+recall);
/** Return the average rank of the correct class label as returned by Labeling.getRank(correctLabel) on the predicted Labeling. */
public double getAverageRank ()
double rsum = 0;
Labeling tmpL;
Classification tmpC;
Instance tmpI;
Label tmpLbl, tmpLbl2;
int tmpInt;
for(int i = 0; i < this.size(); i++) {
tmpC = this.get(i);
tmpI = tmpC.getInstance();
tmpL = tmpC.getLabeling();
tmpLbl = (Label)tmpI.getTarget();
tmpInt = tmpL.getRank(tmpLbl);
tmpLbl2 = tmpL.getLabelAtRank(0);
rsum = rsum + tmpInt;
return rsum/this.size();