Package edu.stanford.nlp.ie.machinereading

Source Code of edu.stanford.nlp.ie.machinereading.EntityExtractorResultsPrinter

package edu.stanford.nlp.ie.machinereading;

import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import edu.stanford.nlp.ie.machinereading.structure.EntityMention;
import edu.stanford.nlp.ie.machinereading.structure.MachineReadingAnnotations;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreAnnotations.TextAnnotation;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.StringUtils;

public class EntityExtractorResultsPrinter extends ResultsPrinter {

  /** Contains a set of labels that should be excluded from scoring */
  private Set<String> excludedClasses;

  /** Use subtypes for scoring or just types? */
  private boolean useSubTypes;

  private boolean verbose;
 
  private boolean verboseInstances;

  private static final DecimalFormat FORMATTER = new DecimalFormat();
  static {
    FORMATTER.setMaximumFractionDigits(1);
    FORMATTER.setMinimumFractionDigits(1);
  }

  public EntityExtractorResultsPrinter() {
    this(null, false);
  }

  protected EntityExtractorResultsPrinter(Set<String> excludedClasses, boolean useSubTypes) {
    this.excludedClasses = excludedClasses;
    this.useSubTypes = useSubTypes;
    this.verbose = true;
    this.verboseInstances = true;
  }

  @Override
  public void printResults(PrintWriter pw, List<CoreMap> goldStandard,
      List<CoreMap> extractorOutput) {
    ResultsPrinter.align(goldStandard, extractorOutput);

    Counter<String> correct = new ClassicCounter<String>();
    Counter<String> predicted = new ClassicCounter<String>();
    Counter<String> gold = new ClassicCounter<String>();

    for (int i = 0; i < goldStandard.size(); i++) {
      CoreMap goldSent = goldStandard.get(i);
      CoreMap sysSent = extractorOutput.get(i);
      String sysText = sysSent.get(TextAnnotation.class);
      String goldText = goldSent.get(TextAnnotation.class);

      if (verbose) {
        System.err.println("SCORING THE FOLLOWING SENTENCE:");
        System.err.println(sysSent.get(CoreAnnotations.TokensAnnotation.class));
      }

      HashSet<String> matchedGolds = new HashSet<String>();
      List<EntityMention> goldEntities = goldSent
          .get(MachineReadingAnnotations.EntityMentionsAnnotation.class);
      if (goldEntities == null) {
        goldEntities = new ArrayList<EntityMention>();
      }

      for (EntityMention m : goldEntities) {
        String label = makeLabel(m);
        if (excludedClasses != null && excludedClasses.contains(label))
          continue;
        gold.incrementCount(label);
      }

      List<EntityMention> sysEntities = sysSent
          .get(MachineReadingAnnotations.EntityMentionsAnnotation.class);
      if (sysEntities == null) {
        sysEntities = new ArrayList<EntityMention>();
      }
      for (EntityMention m : sysEntities) {
        String label = makeLabel(m);
        if (excludedClasses != null && excludedClasses.contains(label))
          continue;
        predicted.incrementCount(label);
        if (verbose)
          System.err.println("COMPARING PREDICTED MENTION: " + m);

        boolean found = false;
        for (EntityMention gm : goldEntities) {
          if (matchedGolds.contains(gm.getObjectId()))
            continue;
          if (verbose)
            System.err.println("\tagainst: " + gm);
          if(gm.equals(m, useSubTypes)){
            if (verbose) System.err.println("\t\t\tMATCH!");
            found = true;
            matchedGolds.add(gm.getObjectId());
            if(verboseInstances){
              System.err.println("TRUE POSITIVE: " + m + " matched " + gm);
              System.err.println("In sentence: " + sysText);
            }
            break;
          }
        }

        if (found) {
          correct.incrementCount(label);
        } else if(verboseInstances){
          System.err.println("FALSE POSITIVE: " + m.toString());
          System.err.println("In sentence: " + sysText);
        }
      }
     
      if (verboseInstances) {
        for (EntityMention m : goldEntities) {
          String label = makeLabel(m);
          if (!matchedGolds.contains(m.getObjectId())
              && (excludedClasses == null || !excludedClasses.contains(label))) {
            System.err.println("FALSE NEGATIVE: " + m.toString());
            System.err.println("In sentence: " + goldText);
          }
        }
      }
    }

    double totalCount = 0;
    double totalCorrect = 0;
    double totalPredicted = 0;
    pw.println("Label\tCorrect\tPredict\tActual\tPrecn\tRecall\tF");
    List<String> labels = new ArrayList<String>(gold.keySet());
    Collections.sort(labels);
    for (String label : labels) {
      if (excludedClasses != null && excludedClasses.contains(label))
        continue;
      double numCorrect = correct.getCount(label);
      double numPredicted = predicted.getCount(label);
      double trueCount = gold.getCount(label);
      double precision = (numPredicted > 0) ? (numCorrect / numPredicted) : 0;
      double recall = numCorrect / trueCount;
      double f = (precision + recall > 0) ? 2 * precision * recall
          / (precision + recall) : 0.0;
      pw.println(StringUtils.padOrTrim(label, 21) + "\t" + numCorrect + "\t"
          + numPredicted + "\t" + trueCount + "\t"
          + FORMATTER.format(precision * 100) + "\t"
          + FORMATTER.format(100 * recall) + "\t" + FORMATTER.format(100 * f));
      totalCount += trueCount;
      totalCorrect += numCorrect;
      totalPredicted += numPredicted;
    }
    double precision = (totalPredicted > 0) ? (totalCorrect / totalPredicted)
        : 0;
    double recall = totalCorrect / totalCount;
    double f = (totalPredicted > 0 && totalCorrect > 0) ? 2 * precision
        * recall / (precision + recall) : 0.0;
    pw.println("Total\t" + totalCorrect + "\t" + totalPredicted + "\t"
        + totalCount + "\t" + FORMATTER.format(100 * precision) + "\t"
        + FORMATTER.format(100 * recall) + "\t" + FORMATTER.format(100 * f));   
  }

  private String makeLabel(EntityMention m) {
    String label = m.getType();
    if (useSubTypes && m.getSubType() != null)
      label += "-" + m.getSubType();
    return label;
  }
 
  public void printResultsUsingLabels(PrintWriter pw,
      List<String> goldStandard,
      List<String> extractorOutput) {}
}
TOP

Related Classes of edu.stanford.nlp.ie.machinereading.EntityExtractorResultsPrinter

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.