Package edu.stanford.nlp.parser.metrics

Source Code of edu.stanford.nlp.parser.metrics.TaggingEval

package edu.stanford.nlp.parser.metrics;

import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;

import edu.stanford.nlp.international.Languages;
import edu.stanford.nlp.international.Languages.Language;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.parser.lexparser.EnglishTreebankParserParams;
import edu.stanford.nlp.parser.lexparser.Lexicon;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.StringUtils;

/**
* Computes POS tagging P/R/F1 from guess/gold trees. This version assumes that the yields match. For
* trees with potentially different yields, use {@link TsarfatyEval}.
* <p>
* This implementation assumes that the guess/gold input files are of equal length, and have one tree per
* line.
*
* @author Spence Green
*
*/
public class TaggingEval extends AbstractEval {

  private final Lexicon lex;

  private static boolean doCatLevelEval = false;
  private Counter<String> precisions;
  private Counter<String> recalls;
  private Counter<String> f1s;

  private Counter<String> precisions2;
  private Counter<String> recalls2;
  private Counter<String> pnums2;
  private Counter<String> rnums2;

  private Counter<String> percentOOV;
  private Counter<String> percentOOV2;

  public TaggingEval(String str) {
    this(str, true, null);
  }

  public TaggingEval(String str, boolean runningAverages, Lexicon lex) {
    super(str, runningAverages);
    this.lex = lex;

    if(doCatLevelEval) {
      precisions = new ClassicCounter<String>();
      recalls = new ClassicCounter<String>();
      f1s = new ClassicCounter<String>();

      precisions2 = new ClassicCounter<String>();
      recalls2 = new ClassicCounter<String>();
      pnums2 = new ClassicCounter<String>();
      rnums2 = new ClassicCounter<String>();

      percentOOV = new ClassicCounter<String>();
      percentOOV2 = new ClassicCounter<String>();
    }
  }

  @Override
  protected Set<HasTag> makeObjects(Tree tree) {
    return (tree == null) ? Generics.<HasTag>newHashSet() : Generics.<HasTag>newHashSet(tree.taggedLabeledYield());
  }

  private static Map<String,Set<Label>> makeObjectsByCat(Tree t) {
    Map<String,Set<Label>> catMap = Generics.newHashMap();
    List<CoreLabel> tly = t.taggedLabeledYield();

    for(CoreLabel label : tly) {
      if(catMap.containsKey(label.value()))
        catMap.get(label.value()).add(label);
      else {
        Set<Label> catSet = Generics.newHashSet();
        catSet.add(label);
        catMap.put(label.value(), catSet);
      }
    }
    return catMap;
  }

  @Override
  public void evaluate(Tree guess, Tree gold, PrintWriter pw) {
    if(gold == null || guess == null) {
      System.err.printf("%s: Cannot compare against a null gold or guess tree!\n",this.getClass().getName());
      return;
    }

    //Do regular evaluation
    super.evaluate(guess, gold, pw);

    if(doCatLevelEval) {
      final Map<String,Set<Label>> guessCats = makeObjectsByCat(guess);
      final Map<String,Set<Label>> goldCats = makeObjectsByCat(gold);
      final Set<String> allCats = Generics.newHashSet();
      allCats.addAll(guessCats.keySet());
      allCats.addAll(goldCats.keySet());

      for(final String cat : allCats) {
        Set<Label> thisGuessCats = guessCats.get(cat);
        Set<Label> thisGoldCats = goldCats.get(cat);

        if (thisGuessCats == null)
          thisGuessCats = Generics.newHashSet();
        if (thisGoldCats == null)
          thisGoldCats = Generics.newHashSet();

        double currentPrecision = precision(thisGuessCats, thisGoldCats);
        double currentRecall = precision(thisGoldCats, thisGuessCats);

        double currentF1 = (currentPrecision > 0.0 && currentRecall > 0.0 ? 2.0 / (1.0 / currentPrecision + 1.0 / currentRecall) : 0.0);

        precisions.incrementCount(cat, currentPrecision);
        recalls.incrementCount(cat, currentRecall);
        f1s.incrementCount(cat, currentF1);

        precisions2.incrementCount(cat, thisGuessCats.size() * currentPrecision);
        pnums2.incrementCount(cat, thisGuessCats.size());

        recalls2.incrementCount(cat, thisGoldCats.size() * currentRecall);
        rnums2.incrementCount(cat, thisGoldCats.size());

        if(lex != null) measureOOV(guess,gold);

        if (pw != null && runningAverages) {
          pw.println(cat + "\tP: " + ((int) (currentPrecision * 10000)) / 100.0 + " (sent ave " + ((int) (precisions.getCount(cat) * 10000 / num)) / 100.0 + ") (evalb " + ((int) (precisions2.getCount(cat) * 10000 / pnums2.getCount(cat))) / 100.0 + ")");
          pw.println("\tR: " + ((int) (currentRecall * 10000)) / 100.0 + " (sent ave " + ((int) (recalls.getCount(cat) * 10000 / num)) / 100.0 + ") (evalb " + ((int) (recalls2.getCount(cat) * 10000 / rnums2.getCount(cat))) / 100.0 + ")");
          double cF1 = 2.0 / (rnums2.getCount(cat) / recalls2.getCount(cat) + pnums2.getCount(cat) / precisions2.getCount(cat));
          String emit = str + " F1: " + ((int) (currentF1 * 10000)) / 100.0 + " (sent ave " + ((int) (10000 * f1s.getCount(cat) / num)) / 100.0 + ", evalb " + ((int) (10000 * cF1)) / 100.0 + ")";
          pw.println(emit);
        }
      }
      if (pw != null && runningAverages) {
        pw.println("========================================");
      }
    }
  }

  /**
   * Measures the percentage of incorrect taggings that can be attributed to OOV words.
   *
   * @param guess
   * @param gold
   */
  private void measureOOV(Tree guess, Tree gold) {
    List<CoreLabel> goldTagging = gold.taggedLabeledYield();
    List<CoreLabel> guessTagging = guess.taggedLabeledYield();

    assert goldTagging.size() == guessTagging.size();

    for(int i = 0; i < goldTagging.size(); i++) {
      if(!(goldTagging.get(i) == guessTagging.get(i))) {
        percentOOV2.incrementCount(goldTagging.get(i).tag());
        if(!lex.isKnown(goldTagging.get(i).word()))
          percentOOV.incrementCount(goldTagging.get(i).tag());
      }
    }
  }

  @Override
  public void display(boolean verbose, PrintWriter pw) {
    super.display(verbose, pw);

    if(doCatLevelEval) {
      final NumberFormat nf = new DecimalFormat("0.00");
      final Set<String> cats = Generics.newHashSet();
      final Random rand = new Random();
      cats.addAll(precisions.keySet());
      cats.addAll(recalls.keySet());

      Map<Double,String> f1Map = new TreeMap<Double,String>();
      for (String cat : cats) {
        double pnum2 = pnums2.getCount(cat);
        double rnum2 = rnums2.getCount(cat);
        double prec = precisions2.getCount(cat) / pnum2;
        double rec = recalls2.getCount(cat) / rnum2;
        double f1 = 2.0 / (1.0 / prec + 1.0 / rec);

        if(new Double(f1).equals(Double.NaN)) f1 = -1.0;
        if(f1Map.containsKey(f1))
          f1Map.put(f1 + (rand.nextDouble()/1000.0), cat);
        else
          f1Map.put(f1, cat);
      }

      pw.println("============================================================");
      pw.println("Tagging Performance by Category -- final statistics");
      pw.println("============================================================");

      for (String cat : f1Map.values()) {
        double pnum2 = pnums2.getCount(cat);
        double rnum2 = rnums2.getCount(cat);
        double prec = precisions2.getCount(cat) / pnum2;
        prec *= 100.0;
        double rec = recalls2.getCount(cat) / rnum2;
        rec *= 100.0;
        double f1 = 2.0 / (1.0 / prec + 1.0 / rec);

        double oovRate = (lex == null) ? -1.0 : percentOOV.getCount(cat) / percentOOV2.getCount(cat);

        pw.println(cat + "\tLP: " + ((pnum2 == 0.0) ? " N/A": nf.format(prec)) + "\tguessed: " + (int) pnum2 +
            "\tLR: " + ((rnum2 == 0.0) ? " N/A": nf.format(rec)) + "\tgold:  " + (int) rnum2 +
            "\tF1: " + ((pnum2 == 0.0 || rnum2 == 0.0) ? " N/A": nf.format(f1)) +
            "\tOOV: " + ((lex == null) ? " N/A" : nf.format(oovRate)));
      }

      pw.println("============================================================");
    }
  }

  private static final int minArgs = 2;
  private static final StringBuilder usage = new StringBuilder();
  static {
    usage.append(String.format("Usage: java %s [OPTS] gold guess\n\n",TaggingEval.class.getName()));
    usage.append("Options:\n");
    usage.append("  -v         : Verbose mode.\n");
    usage.append("  -l lang    : Select language settings from " + Languages.listOfLanguages() + "\n");
    usage.append("  -y num     : Skip gold trees with yields longer than num.\n");
    usage.append("  -c         : Compute LP/LR/F1 by category.\n");
    usage.append("  -e         : Input encoding.\n");
  }

  public static final Map<String,Integer> optionArgDefs = Generics.newHashMap();
  static {
    optionArgDefs.put("-v", 0);
    optionArgDefs.put("-l", 1);
    optionArgDefs.put("-y", 1);
    optionArgDefs.put("-c", 0);
    optionArgDefs.put("-e", 0);
  }

  /**
   * Run the scoring metric on guess/gold input. This method performs "Collinization."
   * The default language is English.
   *
   * @param args
   */
  public static void main(String[] args) {

    if(args.length < minArgs) {
      System.out.println(usage.toString());
      System.exit(-1);
    }

    TreebankLangParserParams tlpp = new EnglishTreebankParserParams();
    int maxGoldYield = Integer.MAX_VALUE;
    boolean VERBOSE = false;
    String encoding = "UTF-8";

    String guessFile = null;
    String goldFile = null;

    Map<String, String[]> argsMap = StringUtils.argsToMap(args, optionArgDefs);

    for(Map.Entry<String, String[]> opt : argsMap.entrySet()) {
      if(opt.getKey() == null) continue;
      if(opt.getKey().equals("-l")) {
        Language lang = Language.valueOf(opt.getValue()[0].trim());
        tlpp = Languages.getLanguageParams(lang);

      } else if(opt.getKey().equals("-y")) {
        maxGoldYield = Integer.parseInt(opt.getValue()[0].trim());

      } else if(opt.getKey().equals("-v")) {
        VERBOSE = true;

      } else if(opt.getKey().equals("-c")) {
        TaggingEval.doCatLevelEval = true;

      } else if(opt.getKey().equals("-e")) {
        encoding = opt.getValue()[0];

      } else {
        System.err.println(usage.toString());
        System.exit(-1);
      }

      //Non-option arguments located at key null
      String[] rest = argsMap.get(null);
      if(rest == null || rest.length < minArgs) {
        System.err.println(usage.toString());
        System.exit(-1);
      }
      goldFile = rest[0];
      guessFile = rest[1];
    }

    tlpp.setInputEncoding(encoding);
    final PrintWriter pwOut = tlpp.pw();

    final Treebank guessTreebank = tlpp.diskTreebank();
    guessTreebank.loadPath(guessFile);
    pwOut.println("GUESS TREEBANK:");
    pwOut.println(guessTreebank.textualSummary());

    final Treebank goldTreebank = tlpp.diskTreebank();
    goldTreebank.loadPath(goldFile);
    pwOut.println("GOLD TREEBANK:");
    pwOut.println(goldTreebank.textualSummary());

    final TaggingEval metric = new TaggingEval("Tagging LP/LR");

    final TreeTransformer tc = tlpp.collinizer();

    //The evalb ref implementation assigns status for each tree pair as follows:
    //
    //   0 - Ok (yields match)
    //   1 - length mismatch
    //   2 - null parse e.g. (()).
    //
    //In the cases of 1,2, evalb does not include the tree pair in the LP/LR computation.
    final Iterator<Tree> goldItr = goldTreebank.iterator();
    final Iterator<Tree> guessItr = guessTreebank.iterator();
    int goldLineId = 0;
    int guessLineId = 0;
    int skippedGuessTrees = 0;
    while( guessItr.hasNext() && goldItr.hasNext() ) {
      Tree guessTree = guessItr.next();
      List<Label> guessYield = guessTree.yield();
      guessLineId++;

      Tree goldTree = goldItr.next();
      List<Label> goldYield = goldTree.yield();
      goldLineId++;

      // Check that we should evaluate this tree
      if(goldYield.size() > maxGoldYield) {
        skippedGuessTrees++;
        continue;
      }

      // Only trees with equal yields can be evaluated
      if(goldYield.size() != guessYield.size()) {
        pwOut.printf("Yield mismatch gold: %d tokens vs. guess: %d tokens (lines: gold %d guess %d)%n", goldYield.size(), guessYield.size(), goldLineId, guessLineId);
        skippedGuessTrees++;
        continue;
      }

      final Tree evalGuess = tc.transformTree(guessTree);
      final Tree evalGold = tc.transformTree(goldTree);

      metric.evaluate(evalGuess, evalGold, ((VERBOSE) ? pwOut : null));
    }

    if(guessItr.hasNext() || goldItr.hasNext()) {
      System.err.printf("Guess/gold files do not have equal lengths (guess: %d gold: %d)%n.", guessLineId, goldLineId);
    }

    pwOut.println("================================================================================");
    if(skippedGuessTrees != 0) pwOut.printf("%s %d guess trees\n", "Unable to evaluate", skippedGuessTrees);
    metric.display(true, pwOut);
    pwOut.println();
    pwOut.close();
  }
}
TOP

Related Classes of edu.stanford.nlp.parser.metrics.TaggingEval

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.