Package edu.stanford.nlp.parser.metrics

Source Code of edu.stanford.nlp.parser.metrics.LeafAncestorEval

package edu.stanford.nlp.parser.metrics;

import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Stack;
import java.util.TreeMap;

import edu.stanford.nlp.international.Languages;
import edu.stanford.nlp.international.Languages.Language;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasIndex;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.StringUtils;

/**
* Implementation of the Leaf Ancestor metric first described by Sampson and Babarczy (2003) and
* later analyzed more completely by Clegg and Shepherd (2005).
* <p>
* This implementation assumes that the guess/gold input files are of equal length, and have one tree per
* line.
* <p>
* TODO (spenceg): This implementation doesn't insert the "boundary symbols" as described by both
* Sampson and Clegg. Need to add those.
*
* @author Spence Green
*
*/
public class LeafAncestorEval {

  private final String name;

  private static final boolean DEBUG = false;

  //Corpus level (macro-averaged)
  private double sentAvg = 0.0;
  private double sentNum = 0.0;
  private int sentExact = 0;

  //Sentence level (micro-averaged)
  private double corpusAvg = 0.0;
  private double corpusNum = 0.0;

  //Category level
  private final Map<List<CoreLabel>,Double> catAvg;
  private final Map<List<CoreLabel>,Double> catNum;

  public LeafAncestorEval(String str) {
    this.name = str;

    catAvg = Generics.newHashMap();
    catNum = Generics.newHashMap();
  }

  /**
   * Depth-first (post-order) search through the tree, recording the stack state as the
   * lineage every time a terminal is reached.
   *
   * This implementation uses the Index annotation to store depth. If CoreLabels are
   * not present in the trees (or at least something that implements HasIndex), an exception will result.
   *
   * @param t The tree
   * @return A list of lineages
   */
  private static List<List<CoreLabel>> makeLineages(final Tree t) {
    if(t == null) return null;

    ((HasIndex) t.label()).setIndex(0);

    final Stack<Tree> treeStack = new Stack<Tree>();
    treeStack.push(t);

    final Stack<CoreLabel> labelStack = new Stack<CoreLabel>();
    CoreLabel rootLabel = new CoreLabel(t.label());
    rootLabel.setIndex(0);
    labelStack.push(rootLabel);

    final List<List<CoreLabel>> lineages = new ArrayList<List<CoreLabel>>();

    while(!treeStack.isEmpty()) {
      Tree node = treeStack.pop();
      int nodeDepth = ((HasIndex) node.label()).index();
      while(!labelStack.isEmpty() && labelStack.peek().index() != nodeDepth - 1)
        labelStack.pop();

      if(node.isPreTerminal()) {
        List<CoreLabel> lin = new ArrayList<CoreLabel>(labelStack);
        lineages.add(lin);

      } else {
        for(Tree kid : node.children()) {
          ((HasIndex) kid.label()).setIndex(nodeDepth + 1);
          treeStack.push(kid);
        }
        CoreLabel nodeLabel = new CoreLabel(node.label());
        nodeLabel.setIndex(nodeDepth);
        labelStack.add(nodeLabel);
      }
    }

    if(DEBUG) {
      System.out.println("Lineages:");
      for(List<CoreLabel> lin : lineages) {
        for(CoreLabel cl : lin)
          System.out.print(cl.value() + " <- ");
        System.out.println();
      }
    }

    return lineages;
  }

  private void updateCatAverages(final List<CoreLabel> lineage, double score) {
    if(catAvg.get(lineage) == null) {
      catAvg.put(lineage, score);
      catNum.put(lineage, 1.0);

    } else {
      double newAvg = catAvg.get(lineage) + score;
      catAvg.put(lineage, newAvg);
      double newNum = catNum.get(lineage) + 1.0;
      catNum.put(lineage, newNum);
    }
  }

  public void evaluate(Tree guess, Tree gold, PrintWriter pw) {
    if(gold == null || guess == null) {
      System.err.printf("%s: Cannot compare against a null gold or guess tree!\n",this.getClass().getName());
      return;
    }

    final List<List<CoreLabel>> guessLineages = makeLineages(guess);
    final List<List<CoreLabel>> goldLineages = makeLineages(gold);

    if(guessLineages.size() == goldLineages.size()) {

      double localScores = 0.0;
      for(int i = 0; i < guessLineages.size(); i++) {
        List<CoreLabel> guessLin = guessLineages.get(i);
        List<CoreLabel> goldLin = goldLineages.get(i);

        double levDist = editDistance(guessLin, goldLin);
        double la = 1.0 - (levDist / (double) (guessLin.size() + goldLin.size()));

        localScores += la;

        updateCatAverages(goldLin, la);
      }

      corpusAvg += localScores;
      corpusNum += goldLineages.size();

      double localSentAvg = localScores / goldLineages.size();
      if(localSentAvg == 1.0) sentExact++;
      sentAvg += localSentAvg;
      sentNum++;

    } else {
      System.err.printf("%s: Number of guess (%d) gold (%d) don't match!\n",this.getClass().getName(),guessLineages.size(),goldLineages.size());
      System.err.println("Cannot evaluate!");
      System.err.printf("GUESS tree:\n%s\n", guess.toString());
      System.err.printf("GOLD tree:\n%s\n", gold.toString());
    }
  }

  /**
   * Computes Levenshtein edit distance between two lists of labels;
   *
   * @param l1
   * @param l2
   */
  private static int editDistance(final List<CoreLabel> l1, final List<CoreLabel> l2) {
    int[][] m = new int[l1.size()+1][l2.size()+1];
    for(int i = 1; i <= l1.size(); i++)
      m[i][0] = i;
    for(int j = 1; j <= l2.size(); j++)
      m[0][j] = j;

    for(int i = 1; i <= l1.size(); i++) {
      for(int j = 1; j <= l2.size(); j++) {
        m[i][j] = Math.min(m[i-1][j-1] + ((l1.get(i-1).equals(l2.get(j-1))) ? 0 : 1), m[i-1][j] + 1);
        m[i][j] = Math.min(m[i][j], m[i][j-1] + 1);
      }
    }

    return m[l1.size()][l2.size()];
  }

  private static String toString(final List<CoreLabel> lineage) {
    StringBuilder sb = new StringBuilder();
    for(CoreLabel cl : lineage) {
      sb.append(cl.value());
      sb.append(" <-- ");
    }

    return sb.toString();
  }

  public void display(boolean verbose, PrintWriter pw) {
    final Random rand = new Random();

    double corpusLevel = corpusAvg / corpusNum;
    double sentLevel = sentAvg / sentNum;
    double sentEx = 100.0 * sentExact / sentNum;

    if(verbose) {
      Map<Double,List<CoreLabel>> avgMap = new TreeMap<Double,List<CoreLabel>>();
      for (List<CoreLabel> lineage : catAvg.keySet()) {
        double avg = catAvg.get(lineage) / catNum.get(lineage);
        if(new Double(avg).equals(Double.NaN)) avg = -1.0;
        if(avgMap.containsKey(avg))
          avgMap.put(avg + (rand.nextDouble()/10000.0), lineage);
        else
          avgMap.put(avg, lineage);
      }

      pw.println("============================================================");
      pw.println("Leaf Ancestor Metric" + "(" + name + ") -- final statistics");
      pw.println("============================================================");
      pw.println("#Sentences: " + (int) sentNum);
      pw.println();
      pw.println("Sentence-level (macro-averaged)");
      pw.printf(" Avg: %.3f%n", sentLevel);
      pw.printf(" Exact: %.2f%%%n", sentEx);
      pw.println();
      pw.println("Corpus-level (micro-averaged)");
      pw.printf(" Avg: %.3f%n", corpusLevel);
      pw.println("============================================================");

      for (List<CoreLabel> lineage : avgMap.values()) {
        if(catNum.get(lineage) < 30.0) continue;
        double avg = catAvg.get(lineage) / catNum.get(lineage);
        pw.printf(" %.3f\t%d\t%s\n",avg, (int) ((double)catNum.get(lineage)),toString(lineage));
      }

      pw.println("============================================================");

    } else {
      pw.printf("%s summary: corpus: %.3f sent: %.3f sent-ex: %.2f%n", name,corpusLevel,sentLevel,sentEx);
    }
  }


  private static final String USAGE =
    String.format("Usage: java %s [OPTS] goldFile guessFile%n%nOptions:%n  -l lang   : Language name %s%n" +
            "  -y num    : Skip gold trees with yields longer than num.%n  -v        : Verbose output%n",
            LeafAncestorEval.class.getName(),
            Languages.listOfLanguages());

  private static final int MIN_ARGS = 2;

  //Command line options
  private static boolean VERBOSE = false;
  private static Language LANGUAGE = Language.English;
  private static int MAX_GOLD_YIELD = Integer.MAX_VALUE;

  private static File guessFile = null;
  private static File goldFile = null;

  public static final Map<String,Integer> optionArgDefs = Generics.newHashMap();
  static {
    optionArgDefs.put("-y", 1);
    optionArgDefs.put("-l", 1);
    optionArgDefs.put("-v", 0);
  }

  private static boolean validateCommandLine(String[] args) {
    Map<String, String[]> argsMap = StringUtils.argsToMap(args,optionArgDefs);

    for(Map.Entry<String, String[]> opt : argsMap.entrySet()) {
      String key = opt.getKey();
      if(key == null) {
        continue;

      } else if(key.equals("-y")) {
        MAX_GOLD_YIELD = Integer.valueOf(opt.getValue()[0]);

      } else if(key.equals("-l")) {
        LANGUAGE = Language.valueOf(opt.getValue()[0]);

      } else if(key.equals("-v")) {
        VERBOSE = true;

      } else {
        return false;
      }
    }

    //Regular arguments
    String[] rest = argsMap.get(null);
    if(rest == null || rest.length != MIN_ARGS) {
      return false;
    } else {
      goldFile = new File(rest[0]);
      guessFile = new File(rest[1]);
    }

    return true;
  }


  /**
   * Execute with no arguments for usage.
   */
  public static void main(String[] args) {

    if(!validateCommandLine(args)) {
      System.err.println(USAGE);
      System.exit(-1);
    }

    final TreebankLangParserParams tlpp = Languages.getLanguageParams(LANGUAGE);
    final PrintWriter pwOut = tlpp.pw();

    final Treebank guessTreebank = tlpp.diskTreebank();
    guessTreebank.loadPath(guessFile);
    pwOut.println("GUESS TREEBANK:");
    pwOut.println(guessTreebank.textualSummary());

    final Treebank goldTreebank = tlpp.diskTreebank();
    goldTreebank.loadPath(goldFile);
    pwOut.println("GOLD TREEBANK:");
    pwOut.println(goldTreebank.textualSummary());

    final LeafAncestorEval metric = new LeafAncestorEval("LeafAncestor");

    final TreeTransformer tc = tlpp.collinizer();

    //The evalb ref implementation assigns status for each tree pair as follows:
    //
    //   0 - Ok (yields match)
    //   1 - length mismatch
    //   2 - null parse e.g. (()).
    //
    //In the cases of 1,2, evalb does not include the tree pair in the LP/LR computation.
    final Iterator<Tree> goldItr = goldTreebank.iterator();
    final Iterator<Tree> guessItr = guessTreebank.iterator();
    int goldLineId = 0;
    int guessLineId = 0;
    int skippedGuessTrees = 0;
    while( guessItr.hasNext() && goldItr.hasNext() ) {
      Tree guessTree = guessItr.next();
      List<Label> guessYield = guessTree.yield();
      guessLineId++;

      Tree goldTree = goldItr.next();
      List<Label> goldYield = goldTree.yield();
      goldLineId++;

      // Check that we should evaluate this tree
      if(goldYield.size() > MAX_GOLD_YIELD) {
        skippedGuessTrees++;
        continue;
      }

      // Only trees with equal yields can be evaluated
      if(goldYield.size() != guessYield.size()) {
        pwOut.printf("Yield mismatch gold: %d tokens vs. guess: %d tokens (lines: gold %d guess %d)%n", goldYield.size(), guessYield.size(), goldLineId, guessLineId);
        skippedGuessTrees++;
        continue;
      }

      final Tree evalGuess = tc.transformTree(guessTree);
      final Tree evalGold = tc.transformTree(goldTree);

      metric.evaluate(evalGuess, evalGold, ((VERBOSE) ? pwOut : null));
    }

    if(guessItr.hasNext() || goldItr.hasNext()) {
      System.err.printf("Guess/gold files do not have equal lengths (guess: %d gold: %d)%n.", guessLineId, goldLineId);
    }

    pwOut.println("================================================================================");
    if(skippedGuessTrees != 0) pwOut.printf("%s %d guess trees\n", "Unable to evaluate", skippedGuessTrees);
    metric.display(true, pwOut);
    pwOut.close();
  }

}
TOP

Related Classes of edu.stanford.nlp.parser.metrics.LeafAncestorEval

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.