package edu.stanford.nlp.ie.machinereading;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import edu.stanford.nlp.ie.machinereading.structure.EntityMention;
import edu.stanford.nlp.ie.machinereading.structure.MachineReadingAnnotations;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreAnnotations.TextAnnotation;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.StringUtils;
/**
 * Scores predicted entity mentions against gold entity mentions, sentence by
 * sentence, and prints a per-label and overall precision / recall / F table.
 *
 * <p>Diagnostic traces (each comparison, true positives, false positives and
 * false negatives) are written to {@code System.err}; only the final table is
 * written to the supplied {@link PrintWriter}.
 */
public class EntityExtractorResultsPrinter extends ResultsPrinter {
  /** Contains a set of labels that should be excluded from scoring */
  private final Set<String> excludedClasses;
  /** Use subtypes for scoring or just types? */
  private final boolean useSubTypes;
  /** If true, trace every sentence and every gold/predicted comparison to stderr. */
  private final boolean verbose;
  /** If true, report every true positive, false positive and false negative to stderr. */
  private final boolean verboseInstances;

  /**
   * Formats percentages with exactly one fraction digit.
   * NOTE: DecimalFormat is not synchronized, so this shared instance must not
   * be used from several threads at once.
   */
  private static final DecimalFormat FORMATTER = new DecimalFormat();
  static {
    FORMATTER.setMaximumFractionDigits(1);
    FORMATTER.setMinimumFractionDigits(1);
  }

  /** Creates a printer that excludes no labels and scores on types only. */
  public EntityExtractorResultsPrinter() {
    this(null, false);
  }

  /**
   * @param excludedClasses labels to leave out of scoring; {@code null} means none
   * @param useSubTypes whether labels (and mention matching) include the subtype
   */
  protected EntityExtractorResultsPrinter(Set<String> excludedClasses, boolean useSubTypes) {
    this.excludedClasses = excludedClasses;
    this.useSubTypes = useSubTypes;
    this.verbose = true;
    this.verboseInstances = true;
  }

  /**
   * Scores {@code extractorOutput} against {@code goldStandard} and prints the
   * result table to {@code pw}.
   *
   * <p>After {@code ResultsPrinter.align} is called, the two lists are walked
   * in parallel by index, so sentence {@code i} of one is scored against
   * sentence {@code i} of the other.
   *
   * @param pw destination of the result table
   * @param goldStandard sentences carrying the gold entity mentions
   * @param extractorOutput sentences carrying the predicted entity mentions
   */
  @Override
  public void printResults(PrintWriter pw, List<CoreMap> goldStandard,
      List<CoreMap> extractorOutput) {
    ResultsPrinter.align(goldStandard, extractorOutput);

    Counter<String> correct = new ClassicCounter<String>();
    Counter<String> predicted = new ClassicCounter<String>();
    Counter<String> gold = new ClassicCounter<String>();

    for (int i = 0; i < goldStandard.size(); i++) {
      scoreSentence(goldStandard.get(i), extractorOutput.get(i), correct, predicted, gold);
    }

    printTable(pw, correct, predicted, gold);
  }

  /** Accumulates the correct / predicted / gold counts for one aligned sentence pair. */
  private void scoreSentence(CoreMap goldSent, CoreMap sysSent, Counter<String> correct,
      Counter<String> predicted, Counter<String> gold) {
    String sysText = sysSent.get(TextAnnotation.class);
    String goldText = goldSent.get(TextAnnotation.class);

    if (verbose) {
      System.err.println("SCORING THE FOLLOWING SENTENCE:");
      System.err.println(sysSent.get(CoreAnnotations.TokensAnnotation.class));
    }

    // Object ids of gold mentions already paired with a prediction, so that
    // each gold mention can be credited at most once.
    Set<String> matchedGolds = new HashSet<String>();

    List<EntityMention> goldEntities = goldSent
        .get(MachineReadingAnnotations.EntityMentionsAnnotation.class);
    if (goldEntities == null) {
      goldEntities = new ArrayList<EntityMention>();
    }
    for (EntityMention m : goldEntities) {
      String label = makeLabel(m);
      if (isExcluded(label)) {
        continue;
      }
      gold.incrementCount(label);
    }

    List<EntityMention> sysEntities = sysSent
        .get(MachineReadingAnnotations.EntityMentionsAnnotation.class);
    if (sysEntities == null) {
      sysEntities = new ArrayList<EntityMention>();
    }
    for (EntityMention m : sysEntities) {
      String label = makeLabel(m);
      if (isExcluded(label)) {
        continue;
      }
      predicted.incrementCount(label);

      if (verbose) {
        System.err.println("COMPARING PREDICTED MENTION: " + m);
      }
      if (findMatch(m, goldEntities, matchedGolds, sysText)) {
        correct.incrementCount(label);
      } else if (verboseInstances) {
        System.err.println("FALSE POSITIVE: " + m);
        System.err.println("In sentence: " + sysText);
      }
    }

    if (verboseInstances) {
      for (EntityMention m : goldEntities) {
        if (!matchedGolds.contains(m.getObjectId()) && !isExcluded(makeLabel(m))) {
          System.err.println("FALSE NEGATIVE: " + m);
          System.err.println("In sentence: " + goldText);
        }
      }
    }
  }

  /**
   * Looks for the first not-yet-matched gold mention that equals {@code m}
   * (as decided by {@code EntityMention.equals(m, useSubTypes)}) and, when
   * found, records its object id in {@code matchedGolds}.
   *
   * @return true if a gold mention was matched
   */
  private boolean findMatch(EntityMention m, List<EntityMention> goldEntities,
      Set<String> matchedGolds, String sysText) {
    for (EntityMention gm : goldEntities) {
      if (matchedGolds.contains(gm.getObjectId())) {
        continue;
      }
      if (verbose) {
        System.err.println("\tagainst: " + gm);
      }
      if (gm.equals(m, useSubTypes)) {
        if (verbose) {
          System.err.println("\t\t\tMATCH!");
        }
        matchedGolds.add(gm.getObjectId());
        if (verboseInstances) {
          System.err.println("TRUE POSITIVE: " + m + " matched " + gm);
          System.err.println("In sentence: " + sysText);
        }
        return true;
      }
    }
    return false;
  }

  /** Prints one row per label (sorted) followed by a "Total" row. */
  private void printTable(PrintWriter pw, Counter<String> correct,
      Counter<String> predicted, Counter<String> gold) {
    double totalCount = 0;
    double totalCorrect = 0;
    double totalPredicted = 0;

    pw.println("Label\tCorrect\tPredict\tActual\tPrecn\tRecall\tF");

    // Use the union of gold and predicted labels: a label that was predicted
    // but never occurs in gold consists purely of false positives, and leaving
    // it out would overstate the total precision.
    Set<String> labelSet = new HashSet<String>(gold.keySet());
    labelSet.addAll(predicted.keySet());
    List<String> labels = new ArrayList<String>(labelSet);
    Collections.sort(labels);

    for (String label : labels) {
      if (isExcluded(label)) {
        continue;
      }
      double numCorrect = correct.getCount(label);
      double numPredicted = predicted.getCount(label);
      double trueCount = gold.getCount(label);
      double precision = ratio(numCorrect, numPredicted);
      double recall = ratio(numCorrect, trueCount);
      double f = harmonicMean(precision, recall);
      pw.println(StringUtils.padOrTrim(label, 21) + "\t" + numCorrect + "\t"
          + numPredicted + "\t" + trueCount + "\t"
          + FORMATTER.format(precision * 100) + "\t"
          + FORMATTER.format(100 * recall) + "\t" + FORMATTER.format(100 * f));
      totalCount += trueCount;
      totalCorrect += numCorrect;
      totalPredicted += numPredicted;
    }

    double precision = ratio(totalCorrect, totalPredicted);
    // Guarded: with no scored gold mentions this would otherwise be 0/0 = NaN.
    double recall = ratio(totalCorrect, totalCount);
    double f = harmonicMean(precision, recall);
    pw.println("Total\t" + totalCorrect + "\t" + totalPredicted + "\t"
        + totalCount + "\t" + FORMATTER.format(100 * precision) + "\t"
        + FORMATTER.format(100 * recall) + "\t" + FORMATTER.format(100 * f));
  }

  /** Returns {@code numerator / denominator}, or 0 when the denominator is not positive. */
  private static double ratio(double numerator, double denominator) {
    return (denominator > 0) ? (numerator / denominator) : 0.0;
  }

  /** Returns the F1 score of the given precision and recall, or 0 when both are 0. */
  private static double harmonicMean(double precision, double recall) {
    return (precision + recall > 0) ? 2 * precision * recall / (precision + recall) : 0.0;
  }

  /** Returns true if {@code label} is in the configured set of excluded labels. */
  private boolean isExcluded(String label) {
    return excludedClasses != null && excludedClasses.contains(label);
  }

  /**
   * Builds the scoring label of a mention: its type, followed by
   * {@code "-" + subtype} when subtypes are in use and the subtype is non-null.
   */
  private String makeLabel(EntityMention m) {
    String label = m.getType();
    if (useSubTypes && m.getSubType() != null) {
      label += "-" + m.getSubType();
    }
    return label;
  }

  /** Intentionally does nothing: label-sequence scoring is not implemented by this printer. */
  public void printResultsUsingLabels(PrintWriter pw,
      List<String> goldStandard,
      List<String> extractorOutput) {}
}