Package edu.stanford.nlp.trees

Source Code of edu.stanford.nlp.trees.fakeShortNameToGRel

package edu.stanford.nlp.trees;

import static edu.stanford.nlp.trees.GrammaticalRelation.DEPENDENT;

import java.io.LineNumberReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.trees.GrammaticalRelation.Language;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counters;

/**
* Scoring of typed dependencies
*
* @author danielcer
*
*/
public class DependencyScoring {
  public final static boolean VERBOSE = false;

  public final List<Set<TypedDependency>> goldDeps;
  public final List<Set<TypedDependency>> goldDepsUnlabeled;
  public final boolean ignorePunc;


  private static List<Set<TypedDependency>> toSets(Collection<TypedDependency> depCollection) {
    Set<TypedDependency> depSet = Generics.newHashSet();
    Set<TypedDependency> unlabeledDepSet = Generics.newHashSet();
    for (TypedDependency dep : depCollection) {
      unlabeledDepSet.add(new TypedDependencyStringEquality(null, dep.gov(), dep.dep()));
      depSet.add(new TypedDependencyStringEquality(dep.reln(), dep.gov(), dep.dep()));
    }

    List<Set<TypedDependency>> l = new ArrayList<Set<TypedDependency>>(2);
    l.add(depSet);
    l.add(unlabeledDepSet);
    return l;
  }

  public DependencyScoring(List<Collection<TypedDependency>> goldDeps, boolean ignorePunc) {
    this.goldDeps = new ArrayList<Set<TypedDependency>>(goldDeps.size());
    this.goldDepsUnlabeled = new ArrayList<Set<TypedDependency>>(goldDeps.size());
    this.ignorePunc = ignorePunc;

    for (Collection<TypedDependency> depCollection : goldDeps) {
      List<Set<TypedDependency>> sets = toSets(depCollection);
      this.goldDepsUnlabeled.add(sets.get(1));
      this.goldDeps.add(sets.get(0));
    }
    if (ignorePunc) {
      removeHeadsAssignedToPunc(this.goldDeps);
      removeHeadsAssignedToPunc(this.goldDepsUnlabeled);
    }
  }

  private static void removeHeadsAssignedToPunc(Set<TypedDependency> depSet) {
    List<TypedDependency> deps = new ArrayList<TypedDependency>(depSet);
    for (TypedDependency dep : deps) {
      if (langIndependentPuncCheck(dep.dep().word())) {
        if (VERBOSE) {
           System.err.printf("Dropping Punctuation Dependency: %s\n", dep);
        }
        depSet.remove(dep);
      }
    }
  }

  private static void removeHeadsAssignedToPunc(List<Set<TypedDependency>> depSets) {
    for (Set<TypedDependency> depSet : depSets) {
      removeHeadsAssignedToPunc(depSet);
    }
  }

  public static boolean langIndependentPuncCheck(String token) {
    boolean isNotWord = true;
    for (int offset = 0; offset < token.length(); ) {
       final int codepoint = token.codePointAt(offset);
       if (Character.isLetterOrDigit(codepoint)) {
         isNotWord = false;
       }
       offset += Character.charCount(codepoint);
    }
    return isNotWord;
  }

  public static DependencyScoring newInstanceStringEquality(List<Collection<TypedDependency>> goldDeps, boolean ignorePunc) {
    return new DependencyScoring(convertStringEquality(goldDeps), ignorePunc);
  }

  /**
   * Creates a scorer whose gold dependencies are read from a file.
   *
   * @param filename path of the gold dependency file
   * @param CoNLLX if true the file is read with {@code readDepsCoNLLX},
   *        otherwise with {@code readDeps}
   * @param ignorePunc passed through to the in-memory constructor
   * @throws IOException if reading the file fails
   */
  public DependencyScoring(String filename, boolean CoNLLX, boolean ignorePunc) throws IOException {
    this((CoNLLX ? readDepsCoNLLX(filename) : readDeps(filename)), ignorePunc);
  }

  /**
   * Creates a scorer from a gold file, delegating to the three-argument
   * constructor with {@code CoNLLX = false} and {@code ignorePunc = false}.
   *
   * @param filename path of the gold dependency file
   * @throws IOException if reading the file fails
   */
  public DependencyScoring(String filename) throws IOException {
    this(filename, false, false);
  }

  public static List<Collection<TypedDependency>> convertStringEquality(List<Collection<TypedDependency>> deps){
    List<Collection<TypedDependency>> convertedDeps = new ArrayList<Collection<TypedDependency>>();
    for(Collection<TypedDependency> depSet : deps){
      Collection<TypedDependency> converted = Generics.newHashSet();
      for(TypedDependency dep : depSet){
        converted.add(new TypedDependencyStringEquality(dep.reln(), dep.gov(), dep.dep()));
      }
      convertedDeps.add(converted);
    }
    return convertedDeps;
  }

  static private class TypedDependencyStringEquality extends TypedDependency {
    private static final long serialVersionUID = 1L;

    public TypedDependencyStringEquality(GrammaticalRelation reln, IndexedWord gov, IndexedWord dep)  {
       super(reln, gov, dep);
    }

    public boolean equals(Object o) {
       // some parsers, like Relex, screw up the casing
       return o.toString().toLowerCase().equals(this.toString().toLowerCase());
    }

    @Override
    public int hashCode() {
       return toString().toLowerCase().hashCode();
    }
  }

  /** Normalize all number tokens to &lt;num&gt; in order to allow
   *  for proper scoring of MSTParser productions.
   */
  static protected String normalizeNumbers(String token) {
    String norm = token.replaceFirst("^([0-9]+)-([0-9]+)$", "<num>-$2");

    if (!norm.equals(token)) {
      System.err.printf("Normalized numbers in token: %s => %s\n", token, norm);
    }
    return token;
  }


  /**
   * Read in typed dependencies in CoNLLX format.
   *
   * @param filename
   * @throws IOException
   */
  static protected List<Collection<TypedDependency>> readDepsCoNLLX(String filename) throws IOException {
    List<GrammaticalStructure> gss = GrammaticalStructure.readCoNLLXGrammaticalStructureCollection(filename,
                  new fakeShortNameToGRel(), new GraphLessGrammaticalStructureFactory());
    List<Collection<TypedDependency>> readDeps = new ArrayList<Collection<TypedDependency>>(gss.size());
    for (GrammaticalStructure gs : gss) {
      Collection<TypedDependency> deps = gs.typedDependencies();
      readDeps.add(deps);
    }
    return readDeps;
  }

  /**
   * Read in typed dependencies. Warning created typed dependencies are not
   * backed by any sort of a tree structure.
   *
   * @param filename
   * @throws IOException
   */
  static protected List<Collection<TypedDependency>> readDeps(String filename) throws IOException {
    LineNumberReader breader = new LineNumberReader(new FileReader(filename));
    List<Collection<TypedDependency>> readDeps = new ArrayList<Collection<TypedDependency>>();
    Collection<TypedDependency> deps = new ArrayList<TypedDependency>();
    for (String line = breader.readLine(); line != null; line = breader.readLine()) {
      if (line.equals("null(-0,-0)") || line.equals("null(-1,-1)")) {
         readDeps.add(deps);
         deps = new ArrayList<TypedDependency>();
         continue; // relex parse error
      }
      try {
      if (line.equals("")) {
         if (deps.size() != 0) {
          //System.out.println(deps);
          readDeps.add(deps);
          deps = new ArrayList<TypedDependency>();
        }
        continue;
      }
      int firstParen = line.indexOf("(");
      int commaSpace = line.indexOf(", ");
      String depName = line.substring(0, firstParen);
      String govName = line.substring(firstParen + 1, commaSpace);
      String childName = line.substring(commaSpace+2, line.length() - 1);
      GrammaticalRelation grel = GrammaticalRelation.valueOf(depName);
      if (depName.startsWith("prep_")) {
        String prep = depName.substring(5);
        grel = EnglishGrammaticalRelations.getPrep(prep);
      }
      if (depName.startsWith("prepc_")) {
        String prepc = depName.substring(6);
        grel = EnglishGrammaticalRelations.getPrepC(prepc);
      }
      if (depName.startsWith("conj_")) {
        String conj = depName.substring(5);
        grel = EnglishGrammaticalRelations.getConj(conj);
      }
      if (grel == null) {
        throw new RuntimeException("Unknown grammatical relation '" + depName+"'");
      }

      //Word govWord = new Word(govName.substring(0, govDash));
      IndexedWord govWord = new IndexedWord();
      govWord.setValue(normalizeNumbers(govName));
      govWord.setWord(govWord.value());
      //Word childWord = new Word(childName.substring(0, childDash));
      IndexedWord childWord = new IndexedWord();
      childWord.setValue(normalizeNumbers(childName));
      childWord.setWord(childWord.value());
      TypedDependency dep = new TypedDependencyStringEquality(grel, govWord, childWord);
      deps.add(dep);
      } catch (Exception e) {
        breader.close();
        throw new RuntimeException("Error on line "+breader.getLineNumber()+":\n\n"+e);
      }
    }
    if (deps.size() != 0) {
       readDeps.add(deps);
    }

    //System.err.println("last: "+readDeps.get(readDeps.size()-1));
    breader.close();
    return readDeps;
  }

  /**
   * Score system typed dependencies
   *
   * @param system
   * @return a triple consisting of (labeled attachment, unlabeled attachment,
   *         label accuracy)
   */
  public Score score(List<Collection<TypedDependency>> system) {
    int parserCnt = 0;
    int goldCnt = 0;
    int parserUnlabeledCnt = 0;
    int goldUnlabeledCnt = 0;
    int correctAttachment = 0;
    int correctUnlabeledAttachment = 0;
    int labelCnt = 0;
    int labelCorrect = 0;

    ClassicCounter<String> unlabeledErrorCounts = new ClassicCounter<String>();
    ClassicCounter<String> labeledErrorCounts = new ClassicCounter<String>();
    //System.out.println("Gold size: "+ goldDeps.size() + " System size: "+system.size());
    for (int i = 0; i < system.size(); i++) {
      List<Set<TypedDependency>> l = toSets(system.get(i));
      if (ignorePunc) {
        removeHeadsAssignedToPunc(l.get(0));
        removeHeadsAssignedToPunc(l.get(1));
      }

      parserCnt += l.get(0).size();
      goldCnt += goldDeps.get(i).size();

      parserUnlabeledCnt += l.get(1).size();
      goldUnlabeledCnt += goldDepsUnlabeled.get(i).size();

      l.get(0).retainAll(goldDeps.get(i));
      l.get(1).retainAll(goldDepsUnlabeled.get(i));
      correctAttachment += l.get(0).size();
      correctUnlabeledAttachment += l.get(1).size();
      labelCnt += l.get(1).size();
      labelCorrect += l.get(0).size();
      //System.out.println(""+i+" Acc: "+(l.get(0).size())/(double)localCnt+" "+l.get(0).size()+"/"+localCnt);


      // identify errors
      List<Set<TypedDependency>> errl = toSets(system.get(i));
      errl.get(0).removeAll(goldDeps.get(i));
      errl.get(1).removeAll(goldDepsUnlabeled.get(i));
      Map<String,String> childCorrectWithLabel = Generics.newHashMap();
      Map<String,String> childCorrectWithOutLabel = Generics.newHashMap();

      for (TypedDependency goldDep: goldDeps.get(i)) {
          //System.out.print(goldDep);
          String sChild = goldDep.dep().toString().replaceFirst("-[^-]*$", "");
          String prefixLabeled = "";
          String prefixUnlabeled = "";
          if (childCorrectWithLabel.containsKey(sChild)) {
            prefixLabeled = childCorrectWithLabel.get(sChild)+", ";
            prefixUnlabeled = childCorrectWithOutLabel.get(sChild)+", ";
          }
          childCorrectWithLabel.put(sChild, prefixLabeled + goldDep.reln()+"("+goldDep.gov().toString().replaceFirst("-[^-]*$", "")+", "+sChild+")");
          childCorrectWithOutLabel.put(sChild, prefixUnlabeled + "dep("+goldDep.gov().toString().replaceFirst("-[^-]*$", "")+", "+sChild+")");
      }

      for (TypedDependency labeledError: errl.get(0)) {
          String sChild = labeledError.dep().toString().replaceFirst("-[^-]*$", "");
          String sGov   = labeledError.gov().toString().replaceFirst("-[^-]*$", "");
          labeledErrorCounts.incrementCount(labeledError.reln().toString()+"("+sGov+", "+sChild+") <= "+childCorrectWithLabel.get(sChild));
      }
      for (TypedDependency unlabeledError: errl.get(1)) {
          String sChild = unlabeledError.dep().toString().replaceFirst("-[^-]*$", "");
          String sGov   = unlabeledError.gov().toString().replaceFirst("-[^-]*$", "");
          unlabeledErrorCounts.incrementCount("dep("+sGov+", "+sChild+") <= "+childCorrectWithOutLabel.get(sChild));
      }
    }
    return new Score(parserCnt, goldCnt, parserUnlabeledCnt, goldUnlabeledCnt, correctAttachment, correctUnlabeledAttachment, labelCnt, labelCorrect, labeledErrorCounts, unlabeledErrorCounts);
  }

  public static class Score {
    final int parserCnt;
    final int goldCnt;
    final int parserUnlabeledCnt;
    final int goldUnlabeledCnt;
    final int correctAttachment;
    final int correctUnlabeledAttachment;
    final int labelCnt;
    final int labelCorrect;
    final ClassicCounter<String> unlabeledErrorCounts;
    final ClassicCounter<String> labeledErrorCounts;

    public Score(int parserCnt, int goldCnt, int parserUnlabeledCnt, int goldUnlabeledCnt, int correctAttachment, int correctUnlabeledAttachment, int labelCnt, int labelCorrect, ClassicCounter<String> labeledErrorCounts, ClassicCounter<String> unlabeledErrorCounts) {
      this.parserCnt = parserCnt;
      this.goldCnt = goldCnt;
      this.parserUnlabeledCnt = parserUnlabeledCnt;
      this.goldUnlabeledCnt = goldUnlabeledCnt;
      this.correctAttachment = correctAttachment;
      this.correctUnlabeledAttachment = correctUnlabeledAttachment;
      this.labelCnt = labelCnt;
      this.labelCorrect = labelCorrect;
      this.unlabeledErrorCounts = new ClassicCounter<String>(unlabeledErrorCounts);
      this.labeledErrorCounts = new ClassicCounter<String>(labeledErrorCounts);
    }

    public String toString() {
      return toStringFScore(false, false);
    }

    public String toStringAttachmentScore(boolean json) {
       if (parserCnt != goldCnt) {
         throw new RuntimeException(
           String.format("AttachmentScore cannot be used when count(gold deps:%d) != count(system deps:%d)", parserCnt, goldCnt));
       }
       double las = correctAttachment/(double)goldCnt;
       double uas = correctUnlabeledAttachment/(double)goldCnt;
       StringBuilder sbuild = new StringBuilder();

       if (json) {
         sbuild.append("{");
         sbuild.append(String.format("'LAS' : %.3f, ", las));
         sbuild.append(String.format("'UAS' : %.3f, ", uas));
         sbuild.append("}");
       } else {
         sbuild.append(String.format("|| Labeled Attachment Score   ||"));
         sbuild.append(String.format(" %.3f (%d/%d) ||\n", las, correctAttachment, goldCnt));

         sbuild.append(String.format("|| Unlabeled Attachment Score ||"));
         sbuild.append(String.format(" %.3f (%d/%d) ||\n", uas, correctUnlabeledAttachment, goldCnt));
       }

       return sbuild.toString();
    }

    public String toStringFScore(boolean verbose, boolean json) {

      double lp = correctAttachment/(double)parserCnt;
      double lr = correctAttachment/(double)goldCnt;
      double lf = 2.0*(lp*lr)/(lp+lr);
      /*sbuild.append(String.format("Labeled Attachment P: %.3f (%d/%d)\n", correctAttachment/(double)parserCnt, correctAttachment, parserCnt));
      sbuild.append(String.format("Labeled Attachment R: %.3f (%d/%d)\n", correctAttachment/(double)goldCnt, correctAttachment, goldCnt));
      */
      double ulp = correctUnlabeledAttachment/(double)parserUnlabeledCnt;
      double ulr = correctUnlabeledAttachment/(double)goldUnlabeledCnt;
      double ulf = 2.0*(ulp*ulr)/(ulp+ulr);
      /*
      sbuild.append(String.format("Unlabeled Attachment P: %.3f (%d/%d)\n", correctUnlabeledAttachment/(double)parserCnt, correctUnlabeledAttachment, parserCnt));
      sbuild.append(String.format("Unlabeled Attachment R: %.3f (%d/%d)\n", correctUnlabeledAttachment/(double)goldCnt, correctUnlabeledAttachment, goldCnt));
      sbuild.append(String.format("LabelAccuracy: %.3f (%d/%d)\n", labelCorrect/(double)labelCnt, labelCorrect, labelCnt));
      */
      StringBuilder sbuild = new StringBuilder();

      if (json) {
         sbuild.append("{");
         sbuild.append(String.format("'LF1' : %.3f, ", lf));
         sbuild.append(String.format("'LP' : %.3f, ", lp));
         sbuild.append(String.format("'LR' : %.3f, ", lr));

         sbuild.append(String.format("'UF1' : %.3f, ", ulf));
         sbuild.append(String.format("'UP' : %.3f, ", ulp));
         sbuild.append(String.format("'UR' : %.3f, ", ulr));

         sbuild.append("}");
      } else {
         sbuild.append(String.format("|| Labeled Attachment   || F ||  P ||  R ||\n"));
         sbuild.append(String.format("||                      || %.3f || %.3f (%d/%d) || %.3f (%d/%d)||\n",
             lf, lp, correctAttachment, parserCnt, lr, correctAttachment, goldCnt));
         sbuild.append(String.format("|| Unlabeled Attachment || F ||  P ||  R ||\n"));
         sbuild.append(String.format("||                     || %.3f || %.3f (%d/%d) || %.3f (%d/%d)||\n",
             ulf, ulp, correctUnlabeledAttachment, parserCnt, ulr, correctUnlabeledAttachment, goldCnt));

         if (verbose) {
           sbuild.append("\nLabeled Attachment Error Counts\n");
           sbuild.append(Counters.toSortedString(labeledErrorCounts, Integer.MAX_VALUE, "\t%2$f\t%1$s", "\n"));
           sbuild.append("\n");
           sbuild.append("\nUnlabeled Attachment Error Counts\n");
           sbuild.append(Counters.toSortedString(unlabeledErrorCounts, Integer.MAX_VALUE, "\t%2$f\t%1$s", "\n"));
         }
      }
      return sbuild.toString();
    }

  } // end static class Score


  public static void main(String[] args) throws IOException {
    Properties props = StringUtils.argsToProperties(args);
    boolean verbose = Boolean.parseBoolean(props.getProperty("v", "False"));
    boolean conllx = Boolean.parseBoolean(props.getProperty("conllx", "False"));
    boolean jsonOutput = Boolean.parseBoolean(props.getProperty("jsonOutput", "False"));
    boolean ignorePunc = Boolean.parseBoolean(props.getProperty("nopunc", "False"));

    String goldFilename = props.getProperty("g");
    String systemFilename = props.getProperty("s");
    if (goldFilename == null || systemFilename == null) {
      System.err.println("Usage:\n\tjava ...DependencyScoring [-v True/False] [-conllx True/False] [-jsonOutput True/False] [-ignorePunc True/False] -g goldFile -s systemFile\n");
      System.err.println("\nOptions:\n\t-v verbose output");
      System.exit(-1);
    }

    DependencyScoring goldScorer = new DependencyScoring(goldFilename, conllx, ignorePunc);

    List<Collection<TypedDependency>> systemDeps;

    if (conllx) {
      systemDeps = DependencyScoring.readDepsCoNLLX(systemFilename);
    } else {
      systemDeps = DependencyScoring.readDeps(systemFilename);
    }

    Score score = goldScorer.score(systemDeps);
    if (conllx) {
      System.out.println(score.toStringAttachmentScore(jsonOutput));
    } else {
      System.out.println(score.toStringFScore(verbose,jsonOutput));
    }
  }
}

class GraphLessGrammaticalStructureFactory implements GrammaticalStructureFromDependenciesFactory {

  public GrammaticalStructure build(
      List<TypedDependency> projectiveDependencies, TreeGraphNode root) {
    return new GraphLessGrammaticalStructure(projectiveDependencies, root);
  }
}

/**
 * A {@link GrammaticalStructure} subclass that adds nothing of its own: it only
 * forwards a dependency list and a root node to the superclass constructor.
 */
class GraphLessGrammaticalStructure extends GrammaticalStructure {
  private static final long serialVersionUID = 1L;

  /**
   * @param projectiveDependencies the dependencies passed to the superclass
   * @param root the root node passed to the superclass
   */
  public GraphLessGrammaticalStructure(
      List<TypedDependency> projectiveDependencies, TreeGraphNode root) {
    super(projectiveDependencies, root);
   }
}

class fakeShortNameToGRel implements Map<String, GrammaticalRelation>{

  public void clear() {
    throw new UnsupportedOperationException();
  }

  public boolean containsKey(Object o) {
    // since we generate grammatical relations dynamically, this "map" technically contains any String key
    if (o instanceof String) {
      return true;
    } else {
      return false;
    }
  }

  public boolean containsValue(Object o) {
    throw new UnsupportedOperationException();
  }

  public Set<java.util.Map.Entry<String, GrammaticalRelation>> entrySet() {
    throw new UnsupportedOperationException();
  }

  public GrammaticalRelation get(Object key) {
    if (!(key instanceof String)) {
      throw new UnsupportedOperationException();
    }
    String strkey = (String)key;
    return new GrammaticalRelation(Language.Any, strkey, null, DEPENDENT) {
      private static final long serialVersionUID = 1L;

      @Override
      public boolean equals(Object o) {
        if (o instanceof GrammaticalRelation) {
           return this.getShortName().equals(((GrammaticalRelation)o).getShortName());
        }
        return false;
      }

      @Override
      public int hashCode() {
        return this.getShortName().hashCode();
      }
    };
  }

  public boolean isEmpty() {
    return false;
  }

  public Set<String> keySet() {
    throw new UnsupportedOperationException();
  }

  public GrammaticalRelation put(String key, GrammaticalRelation value) {
    throw new UnsupportedOperationException();
  }

  public void putAll(Map<? extends String, ? extends GrammaticalRelation> m) {
    throw new UnsupportedOperationException();
  }

  public GrammaticalRelation remove(Object key) {
    throw new UnsupportedOperationException();
  }

  public int size() {
    throw new UnsupportedOperationException();
  }

  public Collection<GrammaticalRelation> values() {
    throw new UnsupportedOperationException();
  }
}
TOP

Related Classes of edu.stanford.nlp.trees.fakeShortNameToGRel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.