Package edu.stanford.nlp.parser.metrics

Source Code of edu.stanford.nlp.parser.metrics.EvalbByCat

package edu.stanford.nlp.parser.metrics;

import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.regex.Pattern;

import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.trees.Constituent;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;

/**
* Computes labeled precision and recall (evalb) at the constituent category level.
*
* @author Roger Levy
* @author Spence Green
*/
public class EvalbByCat extends AbstractEval {

  private final Evalb evalb;
 
  // Only evaluate categories that match this regular expression
  private Pattern pLabelFilter = null;

  private final Counter<Label> precisions;
  private final Counter<Label> recalls;
  private final Counter<Label> f1s;

  private final Counter<Label> precisions2;
  private final Counter<Label> recalls2;
  private final Counter<Label> pnums2;
  private final Counter<Label> rnums2;


  public EvalbByCat(String str, boolean runningAverages) {
    super(str, runningAverages);

    evalb = new Evalb(str, false);
    precisions = new ClassicCounter<Label>();
    recalls = new ClassicCounter<Label>();
    f1s = new ClassicCounter<Label>();

    precisions2 = new ClassicCounter<Label>();
    recalls2 = new ClassicCounter<Label>();
    pnums2 = new ClassicCounter<Label>();
    rnums2 = new ClassicCounter<Label>();
  }
 
  public EvalbByCat(String str, boolean runningAverages, String labelRegex) {
    this(str, runningAverages);
    if (labelRegex != null) {
      pLabelFilter = Pattern.compile(labelRegex.trim());
    }
  }

  @Override
  protected Set<Constituent> makeObjects(Tree tree) {
    return evalb.makeObjects(tree);
  }

  private Map<Label,Set<Constituent>> makeObjectsByCat(Tree t) {
    Map<Label,Set<Constituent>> objMap = Generics.newHashMap();
    Set<Constituent> objSet = makeObjects(t);
    for (Constituent lc : objSet) {
      Label l = lc.label();
      if (!objMap.keySet().contains(l)) {
        objMap.put(l, Generics.<Constituent>newHashSet());
      }
      objMap.get(l).add(lc);
    }
    return objMap;
  }

  @Override
  public void evaluate(Tree guess, Tree gold, PrintWriter pw) {
    if(gold == null || guess == null) {
      System.err.printf("%s: Cannot compare against a null gold or guess tree!%n",this.getClass().getName());
      return;
    }

    Map<Label,Set<Constituent>> guessDeps = makeObjectsByCat(guess);
    Map<Label,Set<Constituent>> goldDeps = makeObjectsByCat(gold);
    Set<Label> cats = Generics.newHashSet(guessDeps.keySet());
    cats.addAll(goldDeps.keySet());

    if (pw != null && runningAverages) {
      pw.println("========================================");
      pw.println("Labeled Bracketed Evaluation by Category");
      pw.println("========================================");
    }

    ++num;

    for (Label cat : cats) {
      Set<Constituent> thisGuessDeps = guessDeps.containsKey(cat) ? guessDeps.get(cat) : Generics.<Constituent>newHashSet();
      Set<Constituent> thisGoldDeps = goldDeps.containsKey(cat) ? goldDeps.get(cat) : Generics.<Constituent>newHashSet();

      double currentPrecision = precision(thisGuessDeps, thisGoldDeps);
      double currentRecall = precision(thisGoldDeps, thisGuessDeps);
      double currentF1 = (currentPrecision > 0.0 && currentRecall > 0.0 ? 2.0 / (1.0 / currentPrecision + 1.0 / currentRecall) : 0.0);

      precisions.incrementCount(cat, currentPrecision);
      recalls.incrementCount(cat, currentRecall);
      f1s.incrementCount(cat, currentF1);

      precisions2.incrementCount(cat, thisGuessDeps.size() * currentPrecision);
      pnums2.incrementCount(cat, thisGuessDeps.size());

      recalls2.incrementCount(cat, thisGoldDeps.size() * currentRecall);
      rnums2.incrementCount(cat, thisGoldDeps.size());

      if (pw != null && runningAverages) {
        pw.println(cat + "\tP: " + ((int) (currentPrecision * 10000)) / 100.0 + " (sent ave " + ((int) (precisions.getCount(cat) * 10000 / num)) / 100.0 + ") (evalb " + ((int) (precisions2.getCount(cat) * 10000 / pnums2.getCount(cat))) / 100.0 + ")");
        pw.println("\tR: " + ((int) (currentRecall * 10000)) / 100.0 + " (sent ave " + ((int) (recalls.getCount(cat) * 10000 / num)) / 100.0 + ") (evalb " + ((int) (recalls2.getCount(cat) * 10000 / rnums2.getCount(cat))) / 100.0 + ")");
        double cF1 = 2.0 / (rnums2.getCount(cat) / recalls2.getCount(cat) + pnums2.getCount(cat) / precisions2.getCount(cat));
        String emit = str + " F1: " + ((int) (currentF1 * 10000)) / 100.0 + " (sent ave " + ((int) (10000 * f1s.getCount(cat) / num)) / 100.0 + ", evalb " + ((int) (10000 * cF1)) / 100.0 + ")";
        pw.println(emit);
      }
    }
    if (pw != null && runningAverages) {
      pw.println("========================================");
    }
  }

  private Set<Label> getEvalLabelSet(Set<Label> labelSet) {
    if (pLabelFilter == null) {
      return Generics.newHashSet(precisions.keySet());
    } else {
      Set<Label> evalSet = Generics.newHashSet(precisions.keySet().size());
      for (Label label : labelSet) {
        if (pLabelFilter.matcher(label.value()).matches()) {
          evalSet.add(label);
        }
      }
      return evalSet;
    }
  }
 
  @Override
  public void display(boolean verbose, PrintWriter pw) {
    if (precisions.keySet().size() != recalls.keySet().size()) {
      System.err.println("ERROR: Different counts for precisions and recalls!");
      return;
    }
    final Set<Label> cats = getEvalLabelSet(precisions.keySet());
    final Random rand = new Random();

    Map<Double,Label> f1Map = new TreeMap<Double,Label>();
    for (Label cat : cats) {
      double pnum2 = pnums2.getCount(cat);
      double rnum2 = rnums2.getCount(cat);
      double prec = precisions2.getCount(cat) / pnum2;
      double rec = recalls2.getCount(cat) / rnum2;
      double f1 = 2.0 / (1.0 / prec + 1.0 / rec);

      if(new Double(f1).equals(Double.NaN)) f1 = -1.0;
      if(f1Map.containsKey(f1)) {
        f1Map.put(f1 + (rand.nextDouble()/1000.0), cat);
      } else {
        f1Map.put(f1, cat);
      }
    }
    pw.println("============================================================");
    pw.println("Labeled Bracketed Evaluation by Category -- final statistics");
    pw.println("============================================================");
    // Per category
    double catPrecisions = 0.0;
    double catPrecisionNums = 0.0;
    double catRecalls = 0.0;
    double catRecallNums = 0.0;
    for (Label cat : f1Map.values()) {
      double pnum2 = pnums2.getCount(cat);
      double rnum2 = rnums2.getCount(cat);
      double prec = precisions2.getCount(cat) / pnum2;
      prec *= 100.0;
      double rec = recalls2.getCount(cat) / rnum2;
      rec *= 100.0;
      double f1 = 2.0 / (1.0 / prec + 1.0 / rec);

      catPrecisions += precisions2.getCount(cat);
      catPrecisionNums += pnum2;
      catRecalls += recalls2.getCount(cat);
      catRecallNums += rnum2;
     
      String LP = pnum2 == 0.0 ? "N/A" : String.format("%.2f", prec);
      String LR = rnum2 == 0.0 ? "N/A" : String.format("%.2f", rec);
      String F1 = (pnum2 == 0.0 || rnum2 == 0.0) ? "N/A": String.format("%.2f", f1);
     
      pw.printf("%s\tLP: %s\tguessed: %d\tLR: %s\tgold: %d\t F1: %s%n",
          cat.value(),
          LP,
          (int) pnum2,
          LR,
          (int) rnum2,
          F1);
    }
    pw.println("============================================================");
    // Totals
    double prec = catPrecisions / catPrecisionNums;
    double rec = catRecalls / catRecallNums;
    double f1 = (2 * prec * rec) / (prec + rec);
    pw.printf("Total\tLP: %.2f\tguessed: %d\tLR: %.2f\tgold: %d\t F1: %.2f%n",
        prec*100.0,
        (int) catPrecisionNums,
        rec*100.0,
        (int) catRecallNums,
        f1*100.0);
    pw.println("============================================================");
  }
}
TOP

Related Classes of edu.stanford.nlp.parser.metrics.EvalbByCat

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by Oracle Inc. Contact coftware#gmail.com.