/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.logging.Logger;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.Labeling;
/**
* Stores the results of classifying a collection of Instances,
* and provides many methods for evaluating the results.
*
* If you just need one evaluation result, you may find it easier to use one
* of the corresponding methods in Classifier, which simply call the methods here.
*
* @see InstanceList
* @see Classifier
* @see Classification
*
* @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/
public class Trial extends ArrayList<Classification>
{
	private static final Logger logger = Logger.getLogger(Trial.class.getName());

	/** The classifier whose Classifications this Trial holds; every element must come from it. */
	Classifier classifier;

	/**
	 * Classifies every instance in {@code ilist} with {@code c} and stores the
	 * resulting Classifications, in iteration order.
	 *
	 * @param c     the classifier being evaluated
	 * @param ilist the instances to classify
	 */
	public Trial (Classifier c, InstanceList ilist)
	{
		super (ilist.size());
		this.classifier = c;
		for (Instance instance : ilist)
			this.add (c.classify (instance));
	}

	/**
	 * Enforces the Trial invariant: only Classifications produced by this
	 * Trial's classifier (compared by reference) may be stored.
	 *
	 * @throws IllegalArgumentException if {@code c} came from a different classifier
	 */
	private void checkClassifier (Classification c)
	{
		if (c.getClassifier() != this.classifier)
			throw new IllegalArgumentException ("Trying to add Classification from a different Classifier.");
	}

	/**
	 * Appends {@code c}.
	 *
	 * @throws IllegalArgumentException if {@code c} came from a different classifier
	 */
	@Override
	public boolean add (Classification c)
	{
		checkClassifier (c);
		return super.add (c);
	}

	/**
	 * Inserts {@code c} at {@code index}.
	 *
	 * @throws IllegalArgumentException if {@code c} came from a different classifier
	 */
	@Override
	public void add (int index, Classification c)
	{
		checkClassifier (c);
		super.add (index, c);
	}

	/**
	 * Replaces the element at {@code index} with {@code c}. Validated like
	 * {@link #add(Classification)} so the same-classifier invariant cannot be
	 * bypassed through {@code set}.
	 *
	 * @throws IllegalArgumentException if {@code c} came from a different classifier
	 */
	@Override
	public Classification set (int index, Classification c)
	{
		checkClassifier (c);
		return super.set (index, c);
	}

	/**
	 * Appends every element of {@code collection}. All elements are validated
	 * before any is added, so a rejected element leaves this Trial unchanged.
	 *
	 * @return true if this Trial changed (i.e. {@code collection} was non-empty),
	 *         per the {@link java.util.Collection#addAll} contract
	 * @throws IllegalArgumentException if any element came from a different classifier
	 */
	@Override
	public boolean addAll (Collection<? extends Classification> collection)
	{
		for (Classification c : collection)
			checkClassifier (c);
		return super.addAll (collection);
	}

	/**
	 * Inserts every element of {@code collection} starting at {@code index}.
	 * All elements are validated before any is inserted.
	 *
	 * @return true if this Trial changed (i.e. {@code collection} was non-empty)
	 * @throws IllegalArgumentException if any element came from a different classifier
	 */
	@Override
	public boolean addAll (int index, Collection<? extends Classification> collection)
	{
		for (Classification c : collection)
			checkClassifier (c);
		return super.addAll (index, collection);
	}

	/** @return the classifier that produced this Trial's Classifications */
	public Classifier getClassifier ()
	{
		return classifier;
	}

	/**
	 * Returns the fraction of instances whose best predicted label is correct.
	 * For an empty Trial this is 0/0, i.e. {@code NaN}.
	 */
	public double getAccuracy ()
	{
		int numCorrect = 0;
		for (Classification c : this)
			if (c.bestLabelIsCorrect())
				numCorrect++;
		return (double)numCorrect/this.size();
	}

	/**
	 * Resolves a label entry to its index in the classifier's label alphabet.
	 * A Labeling resolves to its best index; anything else is looked up in the
	 * alphabet without adding it.
	 *
	 * @throws IllegalArgumentException if the entry is not a known label
	 */
	private int labelIndexOf (Object labelEntry)
	{
		int index;
		if (labelEntry instanceof Labeling)
			index = ((Labeling)labelEntry).getBestIndex();
		else
			index = classifier.getLabelAlphabet().lookupIndex(labelEntry, false);
		if (index == -1) throw new IllegalArgumentException ("Label "+labelEntry.toString()+" is not a valid label.");
		return index;
	}

	/**
	 * Calculates the precision of the classifier for a particular label entry.
	 *
	 * @throws IllegalArgumentException if the entry is not a known label
	 */
	public double getPrecision (Object labelEntry)
	{
		return getPrecision (labelIndexOf (labelEntry));
	}

	/** Calculates the precision for the best label of {@code label}. */
	public double getPrecision (Labeling label)
	{
		return getPrecision (label.getBestIndex());
	}

	/**
	 * Calculates precision for label {@code index}: the fraction of instances
	 * predicted as {@code index} whose true label is also {@code index}.
	 * Returns 1 (and logs a warning) when nothing was predicted as {@code index}.
	 */
	public double getPrecision (int index)
	{
		int numCorrect = 0;
		int numInstances = 0;
		for (Classification c : this) {
			int trueLabel = c.getInstance().getLabeling().getBestIndex();
			int classLabel = c.getLabeling().getBestIndex();
			if (classLabel == index) {
				numInstances++;
				if (trueLabel == index)
					numCorrect++;
			}
		}
		// gdruck@cs.umass.edu
		// When no examples are predicted to have this label,
		// we define precision to be 1.
		if (numInstances==0) {
			logger.warning("No examples with predicted label " +
					classifier.getLabelAlphabet().lookupLabel(index) + "!");
			assert(numCorrect == 0);
			return 1;
		}
		return ((double)numCorrect/(double)numInstances);
	}

	/**
	 * Calculates the recall of the classifier for a particular label entry.
	 *
	 * @throws IllegalArgumentException if the entry is not a known label
	 */
	public double getRecall (Object labelEntry)
	{
		return getRecall (labelIndexOf (labelEntry));
	}

	/** Calculates the recall for the best label of {@code label}. */
	public double getRecall (Labeling label)
	{
		return getRecall (label.getBestIndex());
	}

	/**
	 * Calculates recall for label {@code labelIndex}: the fraction of instances
	 * whose true label is {@code labelIndex} that were also predicted as it.
	 * Returns 1 (and logs a warning) when no instance has that true label.
	 */
	public double getRecall (int labelIndex)
	{
		int numCorrect = 0;
		int numInstances = 0;
		for (Classification c : this) {
			int trueLabel = c.getInstance().getLabeling().getBestIndex();
			int classLabel = c.getLabeling().getBestIndex();
			if ( trueLabel == labelIndex ) {
				numInstances++;
				if ( classLabel == labelIndex)
					numCorrect++;
			}
		}
		// gdruck@cs.umass.edu
		// When no examples have this label,
		// we define recall to be 1.
		if (numInstances==0) {
			logger.warning("No examples with true label " +
					classifier.getLabelAlphabet().lookupLabel(labelIndex) + "!");
			assert(numCorrect == 0);
			return 1;
		}
		return ((double)numCorrect/(double)numInstances);
	}

	/**
	 * Calculates the F1-measure of the classifier for a particular label entry.
	 *
	 * @throws IllegalArgumentException if the entry is not a known label
	 */
	public double getF1 (Object labelEntry)
	{
		return getF1 (labelIndexOf (labelEntry));
	}

	/** Calculates the F1-measure for the best label of {@code label}. */
	public double getF1 (Labeling label)
	{
		return getF1 (label.getBestIndex());
	}

	/**
	 * Calculates the F1-measure (harmonic mean of precision and recall) for
	 * label {@code index}. Returns 0 when both precision and recall are 0.
	 */
	public double getF1 (int index)
	{
		double precision = getPrecision (index);
		double recall = getRecall (index);
		// gdruck@cs.umass.edu
		// When both precision and recall are 0, F1 is 0.
		if (precision==0.0 && recall==0.0) {
			return 0;
		}
		return 2*precision*recall/(precision+recall);
	}

	/**
	 * Returns the average, over all instances, of the rank of the correct label
	 * as reported by {@code Labeling.getRank(correctLabel)} on the predicted
	 * Labeling. Each instance's target is cast to {@link Label}. For an empty
	 * Trial this is {@code NaN}.
	 */
	public double getAverageRank ()
	{
		double rankSum = 0;
		for (Classification c : this) {
			Label correctLabel = (Label)c.getInstance().getTarget();
			rankSum += c.getLabeling().getRank(correctLabel);
		}
		return rankSum/this.size();
	}
}