Package edu.stanford.nlp.parser.metrics

Source Code of edu.stanford.nlp.parser.metrics.CollinsDepEval

package edu.stanford.nlp.parser.metrics;

import java.io.File;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;

import edu.stanford.nlp.international.Languages;
import edu.stanford.nlp.international.Languages.Language;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.trees.CollinsDependency;
import edu.stanford.nlp.trees.CollinsRelation;
import edu.stanford.nlp.trees.HeadFinder;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.PropertiesUtils;
import edu.stanford.nlp.util.StringUtils;

/**
* Compute P/R/F1 for the dependency representation of Collins (1999; 2003).
*
* @author Spence Green
*
*/
public class CollinsDepEval extends AbstractEval {

  private static final boolean DEBUG = false;

  private final HeadFinder hf;
  private final String startSymbol;

  private final Counter<CollinsRelation> precisions;
  private final Counter<CollinsRelation> recalls;
  private final Counter<CollinsRelation> f1s;

  private final Counter<CollinsRelation> precisions2;
  private final Counter<CollinsRelation> recalls2;
  private final Counter<CollinsRelation> pnums2;
  private final Counter<CollinsRelation> rnums2;

  public CollinsDepEval(String str, boolean runningAverages, HeadFinder hf, String startSymbol) {
    super(str,runningAverages);

    this.hf = hf;
    this.startSymbol = startSymbol;

    precisions = new ClassicCounter<CollinsRelation>();
    recalls = new ClassicCounter<CollinsRelation>();
    f1s = new ClassicCounter<CollinsRelation>();

    precisions2 = new ClassicCounter<CollinsRelation>();
    recalls2 = new ClassicCounter<CollinsRelation>();
    pnums2 = new ClassicCounter<CollinsRelation>();
    rnums2 = new ClassicCounter<CollinsRelation>();
  }

  @Override
  protected Set<?> makeObjects(Tree tree) {
    System.err.println(this.getClass().getName() + ": Function makeObjects() not implemented");
    return null;
  }

  private Map<CollinsRelation,Set<CollinsDependency>> makeCollinsObjects(Tree t) {
    final Map<CollinsRelation,Set<CollinsDependency>> relMap = Generics.newHashMap();
    final Set<CollinsDependency> deps = CollinsDependency.extractNormalizedFromTree(t, startSymbol, hf);

    for(CollinsDependency dep : deps) {
      if(DEBUG) System.out.println(dep.toString());
      if(relMap.get(dep.getRelation()) == null)
        relMap.put(dep.getRelation(), Generics.<CollinsDependency>newHashSet());
      relMap.get(dep.getRelation()).add(dep);
    }
    if(DEBUG) System.out.println();

    return relMap;
  }

  @Override
  public void evaluate(Tree guess, Tree gold, PrintWriter pw) {
    if(gold == null || guess == null) {
      System.err.printf("%s: Cannot compare against a null gold or guess tree!\n",this.getClass().getName());
      return;
    }

    if(DEBUG) System.out.println("guess:");
    Map<CollinsRelation,Set<CollinsDependency>> guessDeps = makeCollinsObjects(guess);

    if(DEBUG) System.out.println("gold:");
    Map<CollinsRelation,Set<CollinsDependency>> goldDeps = makeCollinsObjects(gold);

    Set<CollinsRelation> relations = Generics.newHashSet();
    relations.addAll(guessDeps.keySet());
    relations.addAll(goldDeps.keySet());

    num += 1.0;

    for (CollinsRelation rel : relations) {
      Set<CollinsDependency> thisGuessDeps = guessDeps.get(rel);
      Set<CollinsDependency> thisGoldDeps = goldDeps.get(rel);

      if (thisGuessDeps == null)
        thisGuessDeps = Generics.newHashSet();
      if (thisGoldDeps == null)
        thisGoldDeps = Generics.newHashSet();

      double currentPrecision = precision(thisGuessDeps, thisGoldDeps);
      double currentRecall = precision(thisGoldDeps, thisGuessDeps);
      double currentF1 = (currentPrecision > 0.0 && currentRecall > 0.0 ? 2.0 / (1.0 / currentPrecision + 1.0 / currentRecall) : 0.0);

      precisions.incrementCount(rel, currentPrecision);
      recalls.incrementCount(rel, currentRecall);
      f1s.incrementCount(rel, currentF1);

      precisions2.incrementCount(rel, thisGuessDeps.size() * currentPrecision);
      pnums2.incrementCount(rel, thisGuessDeps.size());

      recalls2.incrementCount(rel, thisGoldDeps.size() * currentRecall);
      rnums2.incrementCount(rel, thisGoldDeps.size());

      if (pw != null && runningAverages) {
        pw.println(rel + "\tP: " + ((int) (currentPrecision * 10000)) / 100.0 + " (sent ave " + ((int) (precisions.getCount(rel) * 10000 / num)) / 100.0 + ") (evalb " + ((int) (precisions2.getCount(rel) * 10000 / pnums2.getCount(rel))) / 100.0 + ")");
        pw.println("\tR: " + ((int) (currentRecall * 10000)) / 100.0 + " (sent ave " + ((int) (recalls.getCount(rel) * 10000 / num)) / 100.0 + ") (evalb " + ((int) (recalls2.getCount(rel) * 10000 / rnums2.getCount(rel))) / 100.0 + ")");
        double cF1 = 2.0 / (rnums2.getCount(rel) / recalls2.getCount(rel) + pnums2.getCount(rel) / precisions2.getCount(rel));
        String emit = str + " F1: " + ((int) (currentF1 * 10000)) / 100.0 + " (sent ave " + ((int) (10000 * f1s.getCount(rel) / num)) / 100.0 + ", evalb " + ((int) (10000 * cF1)) / 100.0 + ")";
        pw.println(emit);
      }
    }
    if (pw != null && runningAverages) {
      pw.println("================================================================================");
    }
  }

  @Override
  public void display(boolean verbose, PrintWriter pw) {
    final NumberFormat nf = new DecimalFormat("0.00");
    final Set<CollinsRelation> cats = Generics.newHashSet();
    final Random rand = new Random();
    cats.addAll(precisions.keySet());
    cats.addAll(recalls.keySet());

    Map<Double,CollinsRelation> f1Map = new TreeMap<Double,CollinsRelation>();
    for (CollinsRelation cat : cats) {
      double pnum2 = pnums2.getCount(cat);
      double rnum2 = rnums2.getCount(cat);
      double prec = precisions2.getCount(cat) / pnum2;//(num > 0.0 ? precision/num : 0.0);
      double rec = recalls2.getCount(cat) / rnum2;//(num > 0.0 ? recall/num : 0.0);
      double f1 = 2.0 / (1.0 / prec + 1.0 / rec);//(num > 0.0 ? f1/num : 0.0);

      if(new Double(f1).equals(Double.NaN)) f1 = -1.0;
      if(f1Map.containsKey(f1))
        f1Map.put(f1 + (rand.nextDouble()/1000.0), cat);
      else
        f1Map.put(f1, cat);
    }

    pw.println(" Abstract Collins Dependencies -- final statistics");
    pw.println("================================================================================");

    for (CollinsRelation cat : f1Map.values()) {
      double pnum2 = pnums2.getCount(cat);
      double rnum2 = rnums2.getCount(cat);
      double prec = precisions2.getCount(cat) / pnum2;//(num > 0.0 ? precision/num : 0.0);
      double rec = recalls2.getCount(cat) / rnum2;//(num > 0.0 ? recall/num : 0.0);
      double f1 = 2.0 / (1.0 / prec + 1.0 / rec);//(num > 0.0 ? f1/num : 0.0);

      pw.println(cat + "\tLP: " + ((pnum2 == 0.0) ? " N/A": nf.format(prec)) + "\tguessed: " + (int) pnum2 +
          "\tLR: " + ((rnum2 == 0.0) ? " N/A": nf.format(rec)) + "\tgold:  " + (int) rnum2 +
          "\tF1: " + ((pnum2 == 0.0 || rnum2 == 0.0) ? " N/A": nf.format(f1)));
    }

    pw.println("================================================================================");
  }

  private final static int MIN_ARGS = 2;
  private static String usage() {
    StringBuilder usage = new StringBuilder();
    String nl = System.getProperty("line.separator");
    usage.append(String.format("Usage: java %s [OPTS] goldFile guessFile%n%n",CollinsDepEval.class.getName()));
    usage.append("Options:").append(nl);
    usage.append("  -v        : Verbose output").append(nl);
    usage.append("  -l lang   : Language name " + Languages.listOfLanguages()).append(nl);
    usage.append("  -y num    : Max yield of gold trees").append(nl);
    usage.append("  -g num    : Max yield of guess trees").append(nl);
    return usage.toString();
  }
  private static Map<String,Integer> optionArgDefs() {
    Map<String,Integer> optionArgDefs = Generics.newHashMap();
    optionArgDefs.put("v", 0);
    optionArgDefs.put("l", 1);
    optionArgDefs.put("g", 1);
    optionArgDefs.put("y", 1);
    return optionArgDefs;
  }

  /**
   *
   * @param args
   */
  public static void main(String[] args) {
    if(args.length < MIN_ARGS) {
      System.err.println(usage());
      System.exit(-1);
    }
    Properties options = StringUtils.argsToProperties(args, optionArgDefs());
   
    boolean VERBOSE = PropertiesUtils.getBool(options, "v", false);
    Language LANGUAGE = PropertiesUtils.get(options, "l", Language.English, Language.class);
    int MAX_GOLD_YIELD = PropertiesUtils.getInt(options, "g", Integer.MAX_VALUE);
    int MAX_GUESS_YIELD = PropertiesUtils.getInt(options, "y", Integer.MAX_VALUE);
   
    String[] parsedArgs = options.getProperty("","").split("\\s+");
    if (parsedArgs.length != MIN_ARGS) {
      System.err.println(usage());
      System.exit(-1);
    }
    File goldFile = new File(parsedArgs[0]);
    File guessFile = new File(parsedArgs[1]);
   
    final TreebankLangParserParams tlpp = Languages.getLanguageParams(LANGUAGE);
    final PrintWriter pwOut = tlpp.pw();

    final Treebank guessTreebank = tlpp.diskTreebank();
    guessTreebank.loadPath(guessFile);
    pwOut.println("GUESS TREEBANK:");
    pwOut.println(guessTreebank.textualSummary());

    final Treebank goldTreebank = tlpp.diskTreebank();
    goldTreebank.loadPath(goldFile);
    pwOut.println("GOLD TREEBANK:");
    pwOut.println(goldTreebank.textualSummary());

    final CollinsDepEval depEval = new CollinsDepEval("CollinsDep", true, tlpp.headFinder(), tlpp.treebankLanguagePack().startSymbol());

    final TreeTransformer tc = tlpp.collinizer();

    //PennTreeReader skips over null/malformed parses. So when the yields of the gold/guess trees
    //don't match, we need to keep looking for the next gold tree that matches.
    //The evalb ref implementation differs slightly as it expects one tree per line. It assigns
    //status as follows:
    //
    //   0 - Ok (yields match)
    //   1 - length mismatch
    //   2 - null parse e.g. (()).
    //
    //In the cases of 1,2, evalb does not include the tree pair in the LP/LR computation.

    final Iterator<Tree> goldItr = goldTreebank.iterator();
    int goldLineId = 0;
    int skippedGuessTrees = 0;

    for(final Tree guess : guessTreebank) {
      final Tree evalGuess = tc.transformTree(guess);
      if(guess.yield().size() > MAX_GUESS_YIELD) {
        skippedGuessTrees++;
        continue;
      }

      boolean doneEval = false;
      while(goldItr.hasNext() && !doneEval) {
        final Tree gold = goldItr.next();
        final Tree evalGold = tc.transformTree(gold);
        goldLineId++;

        if(gold.yield().size() > MAX_GOLD_YIELD) {
          continue;

        } else if(evalGold.yield().size() != evalGuess.yield().size()) {
          pwOut.println("Yield mismatch at gold line " + goldLineId);
          skippedGuessTrees++;
          break; //Default evalb behavior -- skip this guess tree
        }

        depEval.evaluate(evalGuess, evalGold, ((VERBOSE) ? pwOut : null));

        doneEval = true; //Move to the next guess parse
      }
    }

    pwOut.println("================================================================================");
    if(skippedGuessTrees != 0) pwOut.printf("%s %d guess trees\n", ((MAX_GUESS_YIELD < Integer.MAX_VALUE) ? "Skipped" : "Unable to evaluate"), skippedGuessTrees);
    depEval.display(true, pwOut);
    pwOut.close();
  }
}
TOP

Related Classes of edu.stanford.nlp.parser.metrics.CollinsDepEval

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.