Package edu.stanford.nlp.parser.lexparser

Source Code of edu.stanford.nlp.parser.lexparser.ParentAnnotationStats

package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.io.NumberRangeFileFilter;
import edu.stanford.nlp.ling.StringLabelFactory;
import edu.stanford.nlp.trees.*;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import java.io.Reader;
import java.text.NumberFormat;
import java.util.*;

/**
* See what parent annotation helps in treebank, based on support and
* KL divergence.
*
* @author Christopher Manning
* @version 2003/01/04
*/
public class ParentAnnotationStats implements TreeVisitor {

  private final TreebankLanguagePack tlp;

  private ParentAnnotationStats(TreebankLanguagePack tlp, boolean doTags) {
    this.tlp = tlp;
    this.doTags = doTags;
  }

  private final boolean doTags;

  private Map<String,ClassicCounter<List<String>>> nodeRules = Generics.newHashMap();
  private Map<List<String>,ClassicCounter<List<String>>> pRules = Generics.newHashMap();
  private Map<List<String>,ClassicCounter<List<String>>> gPRules = Generics.newHashMap();

  // corresponding ones for tags
  private Map<String,ClassicCounter<List<String>>> tagNodeRules = Generics.newHashMap();
  private Map<List<String>,ClassicCounter<List<String>>> tagPRules = Generics.newHashMap();
  private Map<List<String>,ClassicCounter<List<String>>> tagGPRules = Generics.newHashMap();

  /**
   * Minimum support * KL to be included in output and as feature
   */
  public static final double[] CUTOFFS = {100.0, 200.0, 500.0, 1000.0};

  /**
   * Minimum support of parent annotated node for grandparent to be
   * studied.  Just there to reduce runtime and printout size.
   */
  public static final double SUPPCUTOFF = 100.0;

  /**
   * Does whatever one needs to do to a particular parse tree
   */
  public void visitTree(Tree t) {
    processTreeHelper("TOP", "TOP", t);
  }

  public static List<String> kidLabels(Tree t) {
    Tree[] kids = t.children();
    List<String> l = new ArrayList<String>(kids.length);
    for (Tree kid : kids) {
      l.add(kid.label().value());
    }
    return l;
  }

  public void processTreeHelper(String gP, String p, Tree t) {
    if (!t.isLeaf() && (doTags || !t.isPreTerminal())) { // stop at words/tags
      Map<String,ClassicCounter<List<String>>> nr;
      Map<List<String>,ClassicCounter<List<String>>> pr;
      Map<List<String>,ClassicCounter<List<String>>> gpr;
      if (t.isPreTerminal()) {
        nr = tagNodeRules;
        pr = tagPRules;
        gpr = tagGPRules;
      } else {
        nr = nodeRules;
        pr = pRules;
        gpr = gPRules;
      }
      String n = t.label().value();
      if (tlp != null) {
        p = tlp.basicCategory(p);
        gP = tlp.basicCategory(gP);
      }
      List<String> kidn = kidLabels(t);
      ClassicCounter<List<String>> cntr = nr.get(n);
      if (cntr == null) {
        cntr = new ClassicCounter<List<String>>();
        nr.put(n, cntr);
      }
      cntr.incrementCount(kidn);
      List<String> pairStr = new ArrayList<String>(2);
      pairStr.add(n);
      pairStr.add(p);
      cntr = pr.get(pairStr);
      if (cntr == null) {
        cntr = new ClassicCounter<List<String>>();
        pr.put(pairStr, cntr);
      }
      cntr.incrementCount(kidn);
      List<String> tripleStr = new ArrayList<String>(3);
      tripleStr.add(n);
      tripleStr.add(p);
      tripleStr.add(gP);
      cntr = gpr.get(tripleStr);
      if (cntr == null) {
        cntr = new ClassicCounter<List<String>>();
        gpr.put(tripleStr, cntr);
      }
      cntr.incrementCount(kidn);
      Tree[] kids = t.children();
      for (Tree kid : kids) {
        processTreeHelper(p, n, kid);
      }
    }
  }


  public void printStats() {
    NumberFormat nf = NumberFormat.getNumberInstance();
    nf.setMaximumFractionDigits(2);
    // System.out.println("Node rules");
    // System.out.println(nodeRules);
    // System.out.println("Parent rules");
    // System.out.println(pRules);
    // System.out.println("Grandparent rules");
    // System.out.println(gPRules);

    // Store java code for selSplit
    StringBuffer[] javaSB = new StringBuffer[CUTOFFS.length];
    for (int i = 0; i < CUTOFFS.length; i++) {
      javaSB[i] = new StringBuffer("  private static String[] splitters" + (i + 1) + " = new String[] {");
    }

    ClassicCounter<List<String>> allScores = new ClassicCounter<List<String>>();
    // do value of parent
    for (String node : nodeRules.keySet()) {
      ArrayList<Pair<List<String>,Double>> answers = Generics.newArrayList();
      ClassicCounter<List<String>> cntr = nodeRules.get(node);
      double support = (cntr.totalCount());
      System.out.println("Node " + node + " support is " + support);
      for (Iterator<List<String>> it2 = pRules.keySet().iterator(); it2.hasNext();) {
        List<String> key = it2.next();
        if (key.get(0).equals(node)) {   // only do it if they match
          ClassicCounter<List<String>> cntr2 = pRules.get(key);
          double support2 = (cntr2.totalCount());
          double kl = Counters.klDivergence(cntr2, cntr);
          System.out.println("KL(" + key + "||" + node + ") = " + nf.format(kl) + "\t" + "support(" + key + ") = " + support2);
          double score = kl * support2;
          answers.add(new Pair<List<String>,Double>(key, new Double(score)));
          allScores.setCount(key, score);
        }
      }
      System.out.println("----");
      System.out.println("Sorted descending support * KL");
      Collections.sort(answers, (o1, o2) -> o2.second().compareTo(o1.second()));
      for (int i = 0, size = answers.size(); i < size; i++) {
        Pair p = (Pair) answers.get(i);
        double psd = ((Double) p.second()).doubleValue();
        System.out.println(p.first() + ": " + nf.format(psd));
        if (psd >= CUTOFFS[0]) {
          List lst = (List) p.first();
          String nd = (String) lst.get(0);
          String par = (String) lst.get(1);
          for (int j = 0; j < CUTOFFS.length; j++) {
            if (psd >= CUTOFFS[j]) {
              javaSB[j].append("\"").append(nd).append("^");
              javaSB[j].append(par).append("\", ");
            }
          }
        }
      }
      System.out.println();
    }

    /*
          // do value of parent with info gain -- yet to finish this
          for (Iterator it = nodeRules.entrySet().iterator(); it.hasNext(); ) {
              Map.Entry pair = (Map.Entry) it.next();
              String node = (String) pair.getKey();
              Counter cntr = (Counter) pair.getValue();
              double support = (cntr.totalCount());
              System.out.println("Node " + node + " support is " + support);
              ArrayList dtrs = new ArrayList();
              for (Iterator it2 = pRules.entrySet().iterator(); it2.hasNext();) {
                  HashMap annotated = new HashMap();
                  Map.Entry pair2 = (Map.Entry) it2.next();
                  List node2 = (List) pair2.getKey();
                  Counter cntr2 = (Counter) pair2.getValue();
                  if (node2.get(0).equals(node)) {   // only do it if they match
                      annotated.put(node2, cntr2);
                  }
              }

              // upto

              List answers = new ArrayList();
              System.out.println("----");
              System.out.println("Sorted descending support * KL");
              Collections.sort(answers,
                               new Comparator() {
                                   public int compare(Object o1, Object o2) {
                                       Pair p1 = (Pair) o1;
                                       Pair p2 = (Pair) o2;
                                       Double p12 = (Double) p1.second();
                                       Double p22 = (Double) p2.second();
                                       return p22.compareTo(p12);
                                   }
                               });
              for (int i = 0, size = answers.size(); i < size; i++) {
                  Pair p = (Pair) answers.get(i);
                  double psd = ((Double) p.second()).doubleValue();
                  System.out.println(p.first() + ": " + nf.format(psd));
                  if (psd >= CUTOFFS[0]) {
                      List lst = (List) p.first();
                      String nd = (String) lst.get(0);
                      String par = (String) lst.get(1);
                      for (int j=0; j < CUTOFFS.length; j++) {
                          if (psd >= CUTOFFS[j]) {
                              javaSB[j].append("\"").append(nd).append("^");
                              javaSB[j].append(par).append("\", ");
                          }
                      }
                  }
              }
              System.out.println();
          }
    */

    // do value of grandparent
    for (List<String> node : pRules.keySet()) {
      ArrayList<Pair<List<String>, Double>> answers = Generics.newArrayList();
      ClassicCounter<List<String>> cntr = pRules.get(node);
      double support = (cntr.totalCount());
      if (support < SUPPCUTOFF) {
        continue;
      }
      System.out.println("Node " + node + " support is " + support);
      for (List<String> key : gPRules.keySet()) {
        if (key.get(0).equals(node.get(0)) && key.get(1).equals(node.get(1))) {  // only do it if they match
          ClassicCounter<List<String>> cntr2 = gPRules.get(key);
          double support2 = (cntr2.totalCount());
          double kl = Counters.klDivergence(cntr2, cntr);
          System.out.println("KL(" + key + "||" + node + ") = " + nf.format(kl) + "\t" + "support(" + key + ") = " + support2);
          double score = kl * support2;
          answers.add(Pair.makePair(key, new Double(score)));
          allScores.setCount(key,score);
        }
      }
      System.out.println("----");
      System.out.println("Sorted descending support * KL");
      Collections.sort(answers, (o1, o2) -> o2.second().compareTo(o1.second()));
      for (int i = 0, size = answers.size(); i < size; i++) {
        Pair p = (Pair) answers.get(i);
        double psd = ((Double) p.second()).doubleValue();
        System.out.println(p.first() + ": " + nf.format(psd));
        if (psd >= CUTOFFS[0]) {
          List lst = (List) p.first();
          String nd = (String) lst.get(0);
          String par = (String) lst.get(1);
          String gpar = (String) lst.get(2);
          for (int j = 0; j < CUTOFFS.length; j++) {
            if (psd >= CUTOFFS[j]) {
              javaSB[j].append("\"").append(nd).append("^");
              javaSB[j].append(par).append("~");
              javaSB[j].append(gpar).append("\", ");
            }
          }
        }
      }
      System.out.println();
    }
    System.out.println();

    System.out.println("All scores:");
    edu.stanford.nlp.util.PriorityQueue<List<String>> pq = Counters.toPriorityQueue(allScores);
    while (! pq.isEmpty()) {
      List<String> key = pq.getFirst();
      double score = pq.getPriority(key);
      pq.removeFirst();
      System.out.println(key + "\t" + score);
    }

    System.out.println("  // Automatically generated by ParentAnnotationStats -- preferably don't edit");
    for (int i = 0; i < CUTOFFS.length; i++) {
      int len = javaSB[i].length();
      javaSB[i].replace(len - 2, len, "};");
      System.out.println(javaSB[i]);
    }
    System.out.print("  public static HashSet splitters = new HashSet(Arrays.asList(");
    for (int i = CUTOFFS.length; i > 0; i--) {
      if (i == 1) {
        System.out.print("splitters1");
      } else {
        System.out.print("selectiveSplit" + i + " ? splitters" + i + " : (");
      }
    }
    // need to print extra one to close other things open
    for (int i = CUTOFFS.length; i >= 0; i--) {
      System.out.print(")");
    }
    System.out.println(";");
  }


  private static void getSplitters(double cutOff, Map<String,ClassicCounter<List<String>>> nr,
                                   Map<List<String>,ClassicCounter<List<String>>> pr,
                                   Map<List<String>,ClassicCounter<List<String>>> gpr,
                                   Set<String> splitters) {

    // do value of parent
    for (String node : nr.keySet()) {
      List<Pair<List<String>,Double>> answers = new ArrayList<Pair<List<String>,Double>>();
      ClassicCounter<List<String>> cntr = nr.get(node);
      double support = (cntr.totalCount());
      for (List<String> key : pr.keySet()) {
        if (key.get(0).equals(node)) {   // only do it if they match
          ClassicCounter<List<String>> cntr2 = pr.get(key);
          double support2 = cntr2.totalCount();
          double kl = Counters.klDivergence(cntr2, cntr);
          answers.add(new Pair<List<String>, Double>(key, new Double(kl * support2)));
        }
      }
      Collections.sort(answers, (o1, o2) -> o2.second().compareTo(o1.second()));
      for (int i = 0, size = answers.size(); i < size; i++) {
        Pair<List<String>,Double> p = answers.get(i);
        double psd = p.second().doubleValue();
        if (psd >= cutOff) {
          List<String> lst = p.first();
          String nd = lst.get(0);
          String par = lst.get(1);
          String name = nd + "^" + par;
          splitters.add(name);
        }
      }
    }

    /*
          // do value of parent with info gain -- yet to finish this
          for (Iterator it = nr.entrySet().iterator(); it.hasNext(); ) {
              Map.Entry pair = (Map.Entry) it.next();
              String node = (String) pair.getKey();
              Counter cntr = (Counter) pair.getValue();
              double support = (cntr.totalCount());
              ArrayList dtrs = new ArrayList();
              for (Iterator it2 = pr.entrySet().iterator(); it2.hasNext();) {
                  HashMap annotated = new HashMap();
                  Map.Entry pair2 = (Map.Entry) it2.next();
                  List node2 = (List) pair2.getKey();
                  Counter cntr2 = (Counter) pair2.getValue();
                  if (node2.get(0).equals(node)) {   // only do it if they match
                      annotated.put(node2, cntr2);
                  }
              }

              // upto

              List answers = new ArrayList();
              Collections.sort(answers,
                               new Comparator() {
                                   public int compare(Object o1, Object o2) {
                                       Pair p1 = (Pair) o1;
                                       Pair p2 = (Pair) o2;
                                       Double p12 = (Double) p1.second();
                                       Double p22 = (Double) p2.second();
                                       return p22.compareTo(p12);
                                   }
                               });
              for (int i = 0, size = answers.size(); i < size; i++) {
                  Pair p = (Pair) answers.get(i);
                  double psd = ((Double) p.second()).doubleValue();
                  if (psd >= cutOff) {
                      List lst = (List) p.first();
                      String nd = (String) lst.get(0);
                      String par = (String) lst.get(1);
                      String name = nd + "^" + par;
                      splitters.add(name);
                  }
              }
          }
    */

    // do value of grandparent
    for (List<String> node : pr.keySet()) {
      ArrayList<Pair<List<String>,Double>> answers = Generics.newArrayList();
      ClassicCounter<List<String>> cntr = pr.get(node);
      double support = (cntr.totalCount());
      if (support < SUPPCUTOFF) {
        continue;
      }
      for (List<String> key : gpr.keySet()) {
        if (key.get(0).equals(node.get(0)) && key.get(1).equals(node.get(1))) {
          // only do it if they match
          ClassicCounter<List<String>> cntr2 = gpr.get(key);
          double support2 = (cntr2.totalCount());
          double kl = Counters.klDivergence(cntr2, cntr);
          answers.add(new Pair<List<String>,Double>(key, new Double(kl * support2)));
        }
      }
      Collections.sort(answers, (o1, o2) -> o2.second().compareTo(o1.second()));
      for (int i = 0, size = answers.size(); i < size; i++) {
        Pair p = (Pair) answers.get(i);
        double psd = ((Double) p.second()).doubleValue();
        if (psd >= cutOff) {
          List lst = (List) p.first();
          String nd = (String) lst.get(0);
          String par = (String) lst.get(1);
          String gpar = (String) lst.get(2);
          String name = nd + "^" + par + "~" + gpar;
          splitters.add(name);
        }
      }
    }
  }


  /**
   * Calculate parent annotation statistics suitable for doing
   * selective parent splitting in the PCFGParser inside
   * FactoredParser.  <p>
   * Usage: java edu.stanford.nlp.parser.lexparser.ParentAnnotationStats
   * [-tags] treebankPath
   *
   * @param args One argument: path to the Treebank
   */
  public static void main(String[] args) {
    boolean doTags = false;
    if (args.length < 1) {
      System.out.println("Usage: java edu.stanford.nlp.parser.lexparser.ParentAnnotationStats [-tags] treebankPath");
    } else {
      int i = 0;
      boolean useCutOff = false;
      double cutOff = 0.0;
      while (args[i].startsWith("-")) {
        if (args[i].equals("-tags")) {
          doTags = true;
          i++;
        } else if (args[i].equals("-cutOff") && i + 1 < args.length) {
          useCutOff = true;
          cutOff = Double.parseDouble(args[i + 1]);
          i += 2;
        } else {
          System.err.println("Unknown option: " + args[i]);
          i++;
        }
      }

      Treebank treebank = new DiskTreebank(in -> new PennTreeReader(in, new LabeledScoredTreeFactory(new StringLabelFactory()), new BobChrisTreeNormalizer()));
      treebank.loadPath(args[i]);

      if (useCutOff) {
        Set<String> splitters = getSplitCategories(treebank, doTags, 0, cutOff, cutOff, null);
        System.out.println(splitters);
      } else {
        ParentAnnotationStats pas = new ParentAnnotationStats(null, doTags);
        treebank.apply(pas);
        pas.printStats();
      }
    }
  }


  /**
   * Call this method to get a String array of categories to split on.
   * It calculates parent annotation statistics suitable for doing
   * selective parent splitting in the PCFGParser inside
   * FactoredParser.  <p>
   * If tlp is non-null tlp.basicCategory() will be called on parent and
   * grandparent nodes. <p>
   * This version just defaults some parameters.
   * <i>Implementation note:</i> This method is not designed for concurrent
   * invocation: it uses static state variables.
   */
  public static Set<String> getSplitCategories(Treebank t, double cutOff, TreebankLanguagePack tlp) {
    return getSplitCategories(t, true, 0, cutOff, cutOff, tlp);
  }


  /**
   * Call this method to get a String array of categories to split on.
   * It calculates parent annotation statistics suitable for doing
   * selective parent splitting in the PCFGParser inside
   * FactoredParser.  <p>
   * If tlp is non-null tlp.basicCategory() will be called on parent and
   * grandparent nodes. <p>
   * <i>Implementation note:</i> This method is not designed for concurrent
   * invocation: it uses static state variables.
   */
  public static Set<String> getSplitCategories(Treebank t, boolean doTags, int algorithm, double phrasalCutOff, double tagCutOff, TreebankLanguagePack tlp) {
    ParentAnnotationStats pas = new ParentAnnotationStats(tlp, doTags);
    t.apply(pas);
    Set<String> splitters = Generics.newHashSet();
    pas.getSplitters(phrasalCutOff, pas.nodeRules, pas.pRules, pas.gPRules, splitters);
    pas.getSplitters(tagCutOff, pas.tagNodeRules, pas.tagPRules, pas.tagGPRules, splitters);
    return splitters;
  }


  /**
   * This is hardwired to calculate the split categories from English
   * Penn Treebank sections 2-21 with a default cutoff of 300 (as used
   * in ACL03PCFG).  It was added to upgrading of code in cases where no
   * Treebank was available, and the pre-stored list was being used).
   */
  public static Set<String> getEnglishSplitCategories(String treebankRoot) {
    TreebankLangParserParams tlpParams = new EnglishTreebankParserParams();
    Treebank trees = tlpParams.memoryTreebank();
    trees.loadPath(treebankRoot, new NumberRangeFileFilter(200, 2199, true));
    return getSplitCategories(trees, 300.0, tlpParams.treebankLanguagePack());
  }

}
TOP

Related Classes of edu.stanford.nlp.parser.lexparser.ParentAnnotationStats

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.