package edu.stanford.nlp.parser.metrics;
import java.io.File;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Iterator;
import java.util.Map;
import java.util.Properties;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import edu.stanford.nlp.international.Languages;
import edu.stanford.nlp.international.Languages.Language;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.trees.CollinsDependency;
import edu.stanford.nlp.trees.CollinsRelation;
import edu.stanford.nlp.trees.HeadFinder;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.PropertiesUtils;
import edu.stanford.nlp.util.StringUtils;
/**
* Compute P/R/F1 for the dependency representation of Collins (1999; 2003).
*
* @author Spence Green
*
*/
public class CollinsDepEval extends AbstractEval {
private static final boolean DEBUG = false;
private final HeadFinder hf;
private final String startSymbol;
private final Counter<CollinsRelation> precisions;
private final Counter<CollinsRelation> recalls;
private final Counter<CollinsRelation> f1s;
private final Counter<CollinsRelation> precisions2;
private final Counter<CollinsRelation> recalls2;
private final Counter<CollinsRelation> pnums2;
private final Counter<CollinsRelation> rnums2;
public CollinsDepEval(String str, boolean runningAverages, HeadFinder hf, String startSymbol) {
super(str,runningAverages);
this.hf = hf;
this.startSymbol = startSymbol;
precisions = new ClassicCounter<CollinsRelation>();
recalls = new ClassicCounter<CollinsRelation>();
f1s = new ClassicCounter<CollinsRelation>();
precisions2 = new ClassicCounter<CollinsRelation>();
recalls2 = new ClassicCounter<CollinsRelation>();
pnums2 = new ClassicCounter<CollinsRelation>();
rnums2 = new ClassicCounter<CollinsRelation>();
}
@Override
protected Set<?> makeObjects(Tree tree) {
System.err.println(this.getClass().getName() + ": Function makeObjects() not implemented");
return null;
}
private Map<CollinsRelation,Set<CollinsDependency>> makeCollinsObjects(Tree t) {
final Map<CollinsRelation,Set<CollinsDependency>> relMap = Generics.newHashMap();
final Set<CollinsDependency> deps = CollinsDependency.extractNormalizedFromTree(t, startSymbol, hf);
for(CollinsDependency dep : deps) {
if(DEBUG) System.out.println(dep.toString());
if(relMap.get(dep.getRelation()) == null)
relMap.put(dep.getRelation(), Generics.<CollinsDependency>newHashSet());
relMap.get(dep.getRelation()).add(dep);
}
if(DEBUG) System.out.println();
return relMap;
}
@Override
public void evaluate(Tree guess, Tree gold, PrintWriter pw) {
if(gold == null || guess == null) {
System.err.printf("%s: Cannot compare against a null gold or guess tree!\n",this.getClass().getName());
return;
}
if(DEBUG) System.out.println("guess:");
Map<CollinsRelation,Set<CollinsDependency>> guessDeps = makeCollinsObjects(guess);
if(DEBUG) System.out.println("gold:");
Map<CollinsRelation,Set<CollinsDependency>> goldDeps = makeCollinsObjects(gold);
Set<CollinsRelation> relations = Generics.newHashSet();
relations.addAll(guessDeps.keySet());
relations.addAll(goldDeps.keySet());
num += 1.0;
for (CollinsRelation rel : relations) {
Set<CollinsDependency> thisGuessDeps = guessDeps.get(rel);
Set<CollinsDependency> thisGoldDeps = goldDeps.get(rel);
if (thisGuessDeps == null)
thisGuessDeps = Generics.newHashSet();
if (thisGoldDeps == null)
thisGoldDeps = Generics.newHashSet();
double currentPrecision = precision(thisGuessDeps, thisGoldDeps);
double currentRecall = precision(thisGoldDeps, thisGuessDeps);
double currentF1 = (currentPrecision > 0.0 && currentRecall > 0.0 ? 2.0 / (1.0 / currentPrecision + 1.0 / currentRecall) : 0.0);
precisions.incrementCount(rel, currentPrecision);
recalls.incrementCount(rel, currentRecall);
f1s.incrementCount(rel, currentF1);
precisions2.incrementCount(rel, thisGuessDeps.size() * currentPrecision);
pnums2.incrementCount(rel, thisGuessDeps.size());
recalls2.incrementCount(rel, thisGoldDeps.size() * currentRecall);
rnums2.incrementCount(rel, thisGoldDeps.size());
if (pw != null && runningAverages) {
pw.println(rel + "\tP: " + ((int) (currentPrecision * 10000)) / 100.0 + " (sent ave " + ((int) (precisions.getCount(rel) * 10000 / num)) / 100.0 + ") (evalb " + ((int) (precisions2.getCount(rel) * 10000 / pnums2.getCount(rel))) / 100.0 + ")");
pw.println("\tR: " + ((int) (currentRecall * 10000)) / 100.0 + " (sent ave " + ((int) (recalls.getCount(rel) * 10000 / num)) / 100.0 + ") (evalb " + ((int) (recalls2.getCount(rel) * 10000 / rnums2.getCount(rel))) / 100.0 + ")");
double cF1 = 2.0 / (rnums2.getCount(rel) / recalls2.getCount(rel) + pnums2.getCount(rel) / precisions2.getCount(rel));
String emit = str + " F1: " + ((int) (currentF1 * 10000)) / 100.0 + " (sent ave " + ((int) (10000 * f1s.getCount(rel) / num)) / 100.0 + ", evalb " + ((int) (10000 * cF1)) / 100.0 + ")";
pw.println(emit);
}
}
if (pw != null && runningAverages) {
pw.println("================================================================================");
}
}
@Override
public void display(boolean verbose, PrintWriter pw) {
final NumberFormat nf = new DecimalFormat("0.00");
final Set<CollinsRelation> cats = Generics.newHashSet();
final Random rand = new Random();
cats.addAll(precisions.keySet());
cats.addAll(recalls.keySet());
Map<Double,CollinsRelation> f1Map = new TreeMap<Double,CollinsRelation>();
for (CollinsRelation cat : cats) {
double pnum2 = pnums2.getCount(cat);
double rnum2 = rnums2.getCount(cat);
double prec = precisions2.getCount(cat) / pnum2;//(num > 0.0 ? precision/num : 0.0);
double rec = recalls2.getCount(cat) / rnum2;//(num > 0.0 ? recall/num : 0.0);
double f1 = 2.0 / (1.0 / prec + 1.0 / rec);//(num > 0.0 ? f1/num : 0.0);
if(new Double(f1).equals(Double.NaN)) f1 = -1.0;
if(f1Map.containsKey(f1))
f1Map.put(f1 + (rand.nextDouble()/1000.0), cat);
else
f1Map.put(f1, cat);
}
pw.println(" Abstract Collins Dependencies -- final statistics");
pw.println("================================================================================");
for (CollinsRelation cat : f1Map.values()) {
double pnum2 = pnums2.getCount(cat);
double rnum2 = rnums2.getCount(cat);
double prec = precisions2.getCount(cat) / pnum2;//(num > 0.0 ? precision/num : 0.0);
double rec = recalls2.getCount(cat) / rnum2;//(num > 0.0 ? recall/num : 0.0);
double f1 = 2.0 / (1.0 / prec + 1.0 / rec);//(num > 0.0 ? f1/num : 0.0);
pw.println(cat + "\tLP: " + ((pnum2 == 0.0) ? " N/A": nf.format(prec)) + "\tguessed: " + (int) pnum2 +
"\tLR: " + ((rnum2 == 0.0) ? " N/A": nf.format(rec)) + "\tgold: " + (int) rnum2 +
"\tF1: " + ((pnum2 == 0.0 || rnum2 == 0.0) ? " N/A": nf.format(f1)));
}
pw.println("================================================================================");
}
private final static int MIN_ARGS = 2;
private static String usage() {
StringBuilder usage = new StringBuilder();
String nl = System.getProperty("line.separator");
usage.append(String.format("Usage: java %s [OPTS] goldFile guessFile%n%n",CollinsDepEval.class.getName()));
usage.append("Options:").append(nl);
usage.append(" -v : Verbose output").append(nl);
usage.append(" -l lang : Language name " + Languages.listOfLanguages()).append(nl);
usage.append(" -y num : Max yield of gold trees").append(nl);
usage.append(" -g num : Max yield of guess trees").append(nl);
return usage.toString();
}
private static Map<String,Integer> optionArgDefs() {
Map<String,Integer> optionArgDefs = Generics.newHashMap();
optionArgDefs.put("v", 0);
optionArgDefs.put("l", 1);
optionArgDefs.put("g", 1);
optionArgDefs.put("y", 1);
return optionArgDefs;
}
/**
*
* @param args
*/
public static void main(String[] args) {
if(args.length < MIN_ARGS) {
System.err.println(usage());
System.exit(-1);
}
Properties options = StringUtils.argsToProperties(args, optionArgDefs());
boolean VERBOSE = PropertiesUtils.getBool(options, "v", false);
Language LANGUAGE = PropertiesUtils.get(options, "l", Language.English, Language.class);
int MAX_GOLD_YIELD = PropertiesUtils.getInt(options, "g", Integer.MAX_VALUE);
int MAX_GUESS_YIELD = PropertiesUtils.getInt(options, "y", Integer.MAX_VALUE);
String[] parsedArgs = options.getProperty("","").split("\\s+");
if (parsedArgs.length != MIN_ARGS) {
System.err.println(usage());
System.exit(-1);
}
File goldFile = new File(parsedArgs[0]);
File guessFile = new File(parsedArgs[1]);
final TreebankLangParserParams tlpp = Languages.getLanguageParams(LANGUAGE);
final PrintWriter pwOut = tlpp.pw();
final Treebank guessTreebank = tlpp.diskTreebank();
guessTreebank.loadPath(guessFile);
pwOut.println("GUESS TREEBANK:");
pwOut.println(guessTreebank.textualSummary());
final Treebank goldTreebank = tlpp.diskTreebank();
goldTreebank.loadPath(goldFile);
pwOut.println("GOLD TREEBANK:");
pwOut.println(goldTreebank.textualSummary());
final CollinsDepEval depEval = new CollinsDepEval("CollinsDep", true, tlpp.headFinder(), tlpp.treebankLanguagePack().startSymbol());
final TreeTransformer tc = tlpp.collinizer();
//PennTreeReader skips over null/malformed parses. So when the yields of the gold/guess trees
//don't match, we need to keep looking for the next gold tree that matches.
//The evalb ref implementation differs slightly as it expects one tree per line. It assigns
//status as follows:
//
// 0 - Ok (yields match)
// 1 - length mismatch
// 2 - null parse e.g. (()).
//
//In the cases of 1,2, evalb does not include the tree pair in the LP/LR computation.
final Iterator<Tree> goldItr = goldTreebank.iterator();
int goldLineId = 0;
int skippedGuessTrees = 0;
for(final Tree guess : guessTreebank) {
final Tree evalGuess = tc.transformTree(guess);
if(guess.yield().size() > MAX_GUESS_YIELD) {
skippedGuessTrees++;
continue;
}
boolean doneEval = false;
while(goldItr.hasNext() && !doneEval) {
final Tree gold = goldItr.next();
final Tree evalGold = tc.transformTree(gold);
goldLineId++;
if(gold.yield().size() > MAX_GOLD_YIELD) {
continue;
} else if(evalGold.yield().size() != evalGuess.yield().size()) {
pwOut.println("Yield mismatch at gold line " + goldLineId);
skippedGuessTrees++;
break; //Default evalb behavior -- skip this guess tree
}
depEval.evaluate(evalGuess, evalGold, ((VERBOSE) ? pwOut : null));
doneEval = true; //Move to the next guess parse
}
}
pwOut.println("================================================================================");
if(skippedGuessTrees != 0) pwOut.printf("%s %d guess trees\n", ((MAX_GUESS_YIELD < Integer.MAX_VALUE) ? "Skipped" : "Unable to evaluate"), skippedGuessTrees);
depEval.display(true, pwOut);
pwOut.close();
}
}