package edu.stanford.nlp.ie.machinereading;

import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

import edu.stanford.nlp.ie.machinereading.structure.*;
import edu.stanford.nlp.ie.machinereading.structure.MachineReadingAnnotations.GenderAnnotation;
import edu.stanford.nlp.ie.machinereading.structure.MachineReadingAnnotations.TriggerAnnotation;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.ling.CoreAnnotations.TextAnnotation;
import edu.stanford.nlp.ling.CoreAnnotations.TokensAnnotation;
import edu.stanford.nlp.process.Morphology;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.trees.EnglishGrammaticalRelations;
import edu.stanford.nlp.trees.GrammaticalRelation;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeCoreAnnotations.TreeAnnotation;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations;
import edu.stanford.nlp.semgraph.SemanticGraphEdge;
import edu.stanford.nlp.semgraph.SemanticGraphFactory;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.StringUtils;

// XXX: make RelationFeatureFactory an interface (the conversion to BasicRelationFeatureFactory is done)


/**
 * Feature factory that generates the features used to classify {@link RelationMention}s.
 *
 * @author Mason Smith
 * @author Mihai Surdeanu
 */
public class BasicRelationFeatureFactory extends RelationFeatureFactory implements Serializable {
  private static final long serialVersionUID = -7376668998622546620L;

  private static final Logger logger = Logger.getLogger(BasicRelationFeatureFactory.class.getName());


  protected static final List<String> dependencyFeatures = Collections.unmodifiableList(Arrays.asList(
          "dependency_path_lowlevel","dependency_path_length","dependency_path_length_binary",
          "verb_in_dependency_path","dependency_path","dependency_path_words","dependency_paths_to_verb",
          "dependency_path_stubs_to_verb",
          "dependency_path_POS_unigrams",
          "dependency_path_word_n_grams",
          "dependency_path_POS_n_grams",
          "dependency_path_edge_n_grams","dependency_path_edge_lowlevel_n_grams",
          "dependency_path_edge-node-edge-grams","dependency_path_edge-node-edge-grams_lowlevel",
          "dependency_path_node-edge-node-grams","dependency_path_node-edge-node-grams_lowlevel",
          "dependency_path_directed_bigrams",
          "dependency_path_edge_unigrams",
          "dependency_path_trigger"
  ));
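
  // The feature names above are dispatched to addDependencyPathFeatures() below;
  // every other feature name is handled inline in addFeatures().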

  protected List<String> featureList;



  public BasicRelationFeatureFactory(String... featureList) {
    this.doNotLexicalizeFirstArg = false;
    this.dependencyType = DEPENDENCY_TYPE.COLLAPSED_CCPROCESSED;
    this.featureList = Collections.unmodifiableList(Arrays.asList(featureList));
  }

  static {
    logger.setLevel(Level.INFO);
  }


  public Datum<String,String> createDatum(RelationMention rel) {
    return createDatum(rel, (Logger) null);
  }

  public Datum<String,String> createDatum(RelationMention rel, Logger logger) {
    Counter<String> features = new ClassicCounter<String>();
    if (rel.getArgs().size() != 2) {
      return null;
    }

    addFeatures(features, rel, featureList, logger);

    String labelString = rel.getType();
    return new RVFDatum<String, String>(features, labelString);
  }
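
  // Illustrative usage sketch (not part of the original source). Assumes `rel` is a
  // RelationMention with exactly two EntityMention arguments whose sentence carries
  // TreeAnnotation and TokensAnnotation (i.e., it was fully parsed):
  //
  //   BasicRelationFeatureFactory factory =
  //       new BasicRelationFeatureFactory("arg_type", "arg_words", "dependency_path");
  //   Datum<String, String> datum = factory.createDatum(rel);
  //   if (datum != null) {
  //     // datum.label() is the relation type; feed the datum to a classifier
  //   }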

  @Override
  public Datum<String, String> createTestDatum(RelationMention rel, Logger logger) {
    return createDatum(rel, logger);
  }

  public Datum<String,String> createDatum(RelationMention rel, String positiveLabel) {
    Counter<String> features = new ClassicCounter<String>();
    if (rel.getArgs().size() != 2) {
      return null;
    }

    addFeatures(features, rel, featureList);

    String labelString = rel.getType();
    if(! labelString.equals(positiveLabel)) labelString = RelationMention.UNRELATED;
    return new RVFDatum<String, String>(features, labelString);
  }

  public boolean addFeatures(Counter<String> features, RelationMention rel, List<String> types) {
    return addFeatures(features, rel, types, null);
  }

  /**
   * Creates all features for the datum corresponding to this relation mention.
   * Note: this assumes binary relations where both arguments are EntityMention.
   * @param features Stores all features
   * @param rel The relation mention
   * @param types List of feature classes to generate (or "all")
   * @param logger Optional logger for feature debugging; may be null
   * @return true if features were generated; false if the mention fails the sanity checks
   */
  public boolean addFeatures(Counter<String> features, RelationMention rel, List<String> types, Logger logger) {
    // sanity checks: must have two arguments, and each must be an entity mention
    if(rel.getArgs().size() != 2) return false;
    if(! (rel.getArg(0) instanceof EntityMention)) return false;
    if(! (rel.getArg(1) instanceof EntityMention)) return false;

    EntityMention arg0 = (EntityMention) rel.getArg(0);
    EntityMention arg1 = (EntityMention) rel.getArg(1);

    Tree tree = rel.getSentence().get(TreeAnnotation.class);
    if(tree == null){
      throw new RuntimeException("ERROR: Relation extraction requires full syntactic analysis!");
    }
    List<Tree> leaves = tree.getLeaves();
    List<CoreLabel> tokens = rel.getSentence().get(TokensAnnotation.class);

    // this assumes that both args are in the same sentence as the relation object
    // let's check for this to be safe
    CoreMap relSentence = rel.getSentence();
    CoreMap arg0Sentence = arg0.getSentence();
    CoreMap arg1Sentence = arg1.getSentence();
    if(arg0Sentence != relSentence){
      System.err.println("WARNING: Found relation with arg0 in a different sentence: " + rel);
      System.err.println("Relation sentence: " + relSentence.get(TextAnnotation.class));
      System.err.println("Arg0 sentence: " + arg0Sentence.get(TextAnnotation.class));
      return false;
    }
    if(arg1Sentence != relSentence){
      System.err.println("WARNING: Found relation with arg1 in a different sentence: " + rel);
      System.err.println("Relation sentence: " + relSentence.get(TextAnnotation.class));
      System.err.println("Arg1 sentence: " + arg1Sentence.get(TextAnnotation.class));
      return false;
    }

    // Checklist keeps track of which features have been handled by an if clause
    // Should be empty after all the clauses have been gone through.
    List<String> checklist = new ArrayList<String>(types);

    // arg_type: concatenation of the entity types of the args, e.g.
    // "arg1type=Loc_and_arg2type=Org"
    // arg_subtype: similar, for entity subtypes
    if (usingFeature(types, checklist, "arg_type")) {
      features.setCount("arg1type=" + arg0.getType() + "_and_arg2type=" + arg1.getType(), 1.0);
    }
    if (usingFeature(types,checklist,"arg_subtype")) {
      features.setCount("arg1subtype="+arg0.getSubType()+"_and_arg2subtype="+arg1.getSubType(),1.0);
    }

    // arg_order: which arg comes first in the sentence
    if (usingFeature(types, checklist, "arg_order")) {
      if (arg0.getSyntacticHeadTokenPosition() < arg1.getSyntacticHeadTokenPosition())
        features.setCount("arg1BeforeArg2", 1.0);
    }
    // same_head: whether the two args share the same syntactic head token
    if (usingFeature(types, checklist, "same_head")) {
      if (arg0.getSyntacticHeadTokenPosition() == arg1.getSyntacticHeadTokenPosition())
        features.setCount("arguments_have_same_head",1.0);
    }

    // full_tree_path: Path from one arg to the other in the phrase structure tree,
    // e.g., NNP -> PP -> NN <- NNP
    if (usingFeature(types, checklist, "full_tree_path")) {
      //System.err.println("ARG0: " + arg0);
      //System.err.println("ARG0 HEAD: " + arg0.getSyntacticHeadTokenPosition());
      //System.err.println("TREE: " + tree);
      //System.err.println("SENTENCE: " + sentToString(arg0.getSentence()));
      if(arg0.getSyntacticHeadTokenPosition() < leaves.size() && arg1.getSyntacticHeadTokenPosition() < leaves.size()){
        Tree arg0preterm = leaves.get(arg0.getSyntacticHeadTokenPosition()).parent(tree);
        Tree arg1preterm = leaves.get(arg1.getSyntacticHeadTokenPosition()).parent(tree);
        Tree join = tree.joinNode(arg0preterm, arg1preterm);
        StringBuilder pathStringBuilder = new StringBuilder();
        List<Tree> pathUp = join.dominationPath(arg0preterm);
        Collections.reverse(pathUp);
        for (Tree node : pathUp) {
          if (node != join) {
            pathStringBuilder.append(node.label().value() + " <- ");
          }
        }

        for (Tree node : join.dominationPath(arg1preterm)) {
          pathStringBuilder.append(((node == join) ? "" : " -> ") + node.label().value());
        }
        String pathString = pathStringBuilder.toString();
        if(logger != null && ! rel.getType().equals(RelationMention.UNRELATED)) logger.info("full_tree_path: " + pathString);
        features.setCount("treepath:"+pathString, 1.0);
      } else {
        System.err.println("WARNING: found weird argument offsets. Most likely because arguments appear in different sentences than the relation:");
        System.err.println("ARG0: " + arg0);
        System.err.println("ARG0 HEAD: " + arg0.getSyntacticHeadTokenPosition());
        System.err.println("ARG0 SENTENCE: " + sentToString(arg0.getSentence()));
        System.err.println("ARG1: " + arg1);
        System.err.println("ARG1 HEAD: " + arg1.getSyntacticHeadTokenPosition());
        System.err.println("ARG1 SENTENCE: " + sentToString(arg1.getSentence()));
        System.err.println("RELATION TREE: " + tree);
      }
    }

    int pathLength = tree.pathNodeToNode(tree.getLeaves().get(arg0.getSyntacticHeadTokenPosition()),
            tree.getLeaves().get(arg1.getSyntacticHeadTokenPosition())).size();
    // path_length: Length of the path in the phrase structure parse tree, integer-valued feature
    if (usingFeature(types, checklist, "path_length")) {
      features.setCount("path_length", pathLength);
    }
    // path_length_binary: Length of the path in the phrase structure parse tree, binary features
    if (usingFeature(types, checklist, "path_length_binary")) {
      features.setCount("path_length_" + pathLength, 1.0);
    }

    /* entity_order
     * This tells you, for each of the two args, whether there are other entities
     * before or after that arg. In particular, it can tell whether an arg is the
     * first entity of its type in the sentence (which can be useful, for example,
     * for telling the gameWinner from the gameLoser in NFL).
     * TODO: restrict this feature so that it only looks for entities of the same type?
     */
    if (usingFeature(types, checklist, "entity_order")) {
      for (int i = 0; i < rel.getArgs().size(); i++) {
        // We already checked the class of the args at the beginning of the method
        EntityMention arg = (EntityMention) rel.getArgs().get(i);
        List<EntityMention> sentEntities = rel.getSentence().get(MachineReadingAnnotations.EntityMentionsAnnotation.class);
        if (sentEntities != null) { // may be null due to annotation error
          for (EntityMention otherArg : sentEntities) {
            String feature;
            if (otherArg.getSyntacticHeadTokenPosition() > arg.getSyntacticHeadTokenPosition()) {
              feature = "arg" + i + "_before_" + otherArg.getType();
              features.setCount(feature, 1.0);
            }
            if (otherArg.getSyntacticHeadTokenPosition() < arg.getSyntacticHeadTokenPosition()) {
              feature = "arg" + i + "_after_" + otherArg.getType();
              features.setCount(feature, 1.0);
            }
          }
        }
      }
    }

    // surface_distance: Number of tokens in the sentence between the two words, integer-valued feature
    int surfaceDistance = Math.abs(arg0.getSyntacticHeadTokenPosition() - arg1.getSyntacticHeadTokenPosition());
    if (usingFeature(types, checklist, "surface_distance")) {
      features.setCount("surface_distance", surfaceDistance);
    }
    // surface_distance_binary: Number of tokens in the sentence between the two words, binary features
    if (usingFeature(types, checklist, "surface_distance_binary")) {
      features.setCount("surface_distance_" + surfaceDistance, 1.0);
    }
    // surface_distance_bins: number of tokens between the two args, binned to several intervals
    if(usingFeature(types, checklist, "surface_distance_bins")) {
      if(surfaceDistance < 4){
        features.setCount("surface_distance_bin" + surfaceDistance, 1.0);
      } else if(surfaceDistance < 6){
        features.setCount("surface_distance_bin_lt6", 1.0);
      } else if(surfaceDistance < 10) {
        features.setCount("surface_distance_bin_lt10", 1.0);
      } else {
        features.setCount("surface_distance_bin_ge10", 1.0);
      }
    }
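    // (The bins above yield, e.g., "surface_distance_bin2" for distance 2,
    // "surface_distance_bin_lt10" for distance 7, and "surface_distance_bin_ge10" for 12.)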

    // separate_surface_windows: windows of 1,2,3 tokens before and after args, for each arg separately
    // Separate features are generated for windows to the left and to the right of the args.
    // Features are concatenations of words in the window (or NULL for sentence boundary).
    //
    // conjunction_surface_windows: concatenation of the windows of the two args
    //
    // separate_surface_windows_POS: windows of POS tags of size 1,2,3 for each arg
    //
    // conjunction_surface_windows_POS: concatenation of windows of the args

    List<EntityMention> args = new ArrayList<EntityMention>();
    args.add(arg0); args.add(arg1);
    for (int windowSize = 1; windowSize <= 3; windowSize++) {

      String[] leftWindow, rightWindow, leftWindowPOS, rightWindowPOS;
      leftWindow = new String[2];
      rightWindow = new String[2];
      leftWindowPOS = new String[2];
      rightWindowPOS = new String[2];

      for (int argn = 0; argn <= 1; argn++) {
        int ind = args.get(argn).getSyntacticHeadTokenPosition();
        for (int winnum = 1; winnum <= windowSize; winnum++) {
          int windex = ind - winnum;
          if (windex >= 0) { // >= 0: index 0 is a real token, not a sentence boundary
            leftWindow[argn] = leaves.get(windex).label().value() + "_" + leftWindow[argn];
            leftWindowPOS[argn] = leaves.get(windex).parent(tree).label().value() + "_" + leftWindowPOS[argn];
          } else {
            leftWindow[argn] = "NULL_" + leftWindow[argn];
            leftWindowPOS[argn] = "NULL_" + leftWindowPOS[argn];
          }
          windex = ind + winnum;
          if (windex < leaves.size()) {
            rightWindow[argn] = rightWindow[argn] + "_" + leaves.get(windex).label().value();
            rightWindowPOS[argn] = rightWindowPOS[argn] + "_" + leaves.get(windex).parent(tree).label().value();
          } else {
            rightWindow[argn] = rightWindow[argn] + "_NULL";
            rightWindowPOS[argn] = rightWindowPOS[argn] + "_NULL";
          }
        }
        if (usingFeature(types, checklist, "separate_surface_windows")) {
          features.setCount("left_window_"+windowSize+"_arg_" + argn + ": " + leftWindow[argn], 1.0);
          features.setCount("left_window_"+windowSize+"_POS_arg_" + argn + ": " + leftWindowPOS[argn], 1.0);
        }
        if (usingFeature(types, checklist, "separate_surface_windows_POS")) {
          features.setCount("right_window_"+windowSize+"_arg_" + argn + ": " + rightWindow[argn], 1.0);
          features.setCount("right_window_"+windowSize+"_POS_arg_" + argn + ": " + rightWindowPOS[argn], 1.0);
        }

      }
      if (usingFeature(types, checklist, "conjunction_surface_windows")) {
        features.setCount("left_windows_"+windowSize+": " + leftWindow[0] + "__" + leftWindow[1], 1.0);
        features.setCount("right_windows_"+windowSize+": " + rightWindow[0] + "__" + rightWindow[1], 1.0);
      }
      if (usingFeature(types, checklist, "conjunction_surface_windows_POS")) {
        features.setCount("left_windows_"+windowSize+"_POS: " + leftWindowPOS[0] + "__" + leftWindowPOS[1], 1.0);
        features.setCount("right_windows_"+windowSize+"_POS: " + rightWindowPOS[0] + "__" + rightWindowPOS[1], 1.0);
      }
    }

    // arg_words:  The actual arg tokens as separate features, and concatenated
    String word0 = leaves.get(arg0.getSyntacticHeadTokenPosition()).label().value();
    String word1 = leaves.get(arg1.getSyntacticHeadTokenPosition()).label().value();
    if (usingFeature(types, checklist, "arg_words")) {
      if (!doNotLexicalizeFirstArg)
        features.setCount("word_arg0: " + word0, 1.0);
      features.setCount("word_arg1: " + word1, 1.0);
      if (!doNotLexicalizeFirstArg)
        features.setCount("words: " + word0 + "__" + word1, 1.0);
    }

    // arg_POS:  POS tags of the args, as separate features and concatenated
    String pos0 = leaves.get(arg0.getSyntacticHeadTokenPosition()).parent(tree).label().value();
    String pos1 = leaves.get(arg1.getSyntacticHeadTokenPosition()).parent(tree).label().value();
    if (usingFeature(types, checklist, "arg_POS")) {
      features.setCount("POS_arg0: " + pos0, 1.0);
      features.setCount("POS_arg1: " + pos1, 1.0);
      features.setCount("POSs: " + pos0 + "__" + pos1, 1.0);
    }

    // adjacent_words: words immediately to the left and right of the args
    if(usingFeature(types, checklist, "adjacent_words")){
      for(int i = 0; i < rel.getArgs().size(); i ++){
        Span s = ((EntityMention) rel.getArg(i)).getHead();
        if(s.start() > 0){
          String v = tokens.get(s.start() - 1).word();
          features.setCount("leftarg" + i + "-" + v, 1.0);
        }
        if(s.end() < tokens.size()){
          String v = tokens.get(s.end()).word();
          features.setCount("rightarg" + i + "-" + v, 1.0);
        }
      }
    }

    // entities_between_args:  binary feature for each type specifying whether there is an entity of that type in the sentence
    // between the two args.
    // e.g. "entity_between_args: Loc" means there is at least one entity of type Loc between the two args
    if (usingFeature(types, checklist, "entities_between_args")) {
      CoreMap sent = rel.getSentence();
      if(sent == null) throw new RuntimeException("NULL sentence for relation " + rel);
      List<EntityMention> relArgs = sent.get(MachineReadingAnnotations.EntityMentionsAnnotation.class);
      if(relArgs != null) { // may be null due to annotation errors!
        for (EntityMention arg : relArgs) {
          if ((arg.getSyntacticHeadTokenPosition() > arg0.getSyntacticHeadTokenPosition() && arg.getSyntacticHeadTokenPosition() < arg1.getSyntacticHeadTokenPosition())
                  || (arg.getSyntacticHeadTokenPosition() > arg1.getSyntacticHeadTokenPosition() && arg.getSyntacticHeadTokenPosition() < arg0.getSyntacticHeadTokenPosition())) {
            features.setCount("entity_between_args: " + arg.getType(), 1.0);
          }
        }
      }
    }

    // entity_counts: For each type, the total number of entities of that type in the sentence (integer-valued feature)
    // entity_counts_binary: Counts of entity types as binary features.
    Counter<String> typeCounts = new ClassicCounter<String>();
    if(rel.getSentence().get(MachineReadingAnnotations.EntityMentionsAnnotation.class) != null){ // may be null due to annotation errors!
      for (EntityMention arg : rel.getSentence().get(MachineReadingAnnotations.EntityMentionsAnnotation.class))
        typeCounts.incrementCount(arg.getType());
      for (String type : typeCounts.keySet()) {
        if (usingFeature(types,checklist,"entity_counts"))
          features.setCount("entity_counts_"+type,typeCounts.getCount(type));
        if (usingFeature(types,checklist,"entity_counts_binary"))
          features.setCount("entity_counts_"+type+": "+typeCounts.getCount(type),1.0);
      }
    }

    // surface_path: concatenation of tokens between the two args
    // surface_path_POS: concatenation of POS tags between the args
    // surface_path_selective: concatenation of tokens between the args which are nouns or verbs
    StringBuilder sb = new StringBuilder();
    StringBuilder sbPOS = new StringBuilder();
    StringBuilder sbSelective = new StringBuilder();
    for (int i = Math.min(arg0.getSyntacticHeadTokenPosition(), arg1.getSyntacticHeadTokenPosition()) + 1; i < Math.max(arg0.getSyntacticHeadTokenPosition(), arg1.getSyntacticHeadTokenPosition()); i++) {
      String word = leaves.get(i).label().value();
      sb.append(word + "_");
      String pos = leaves.get(i).parent(tree).label().value();
      sbPOS.append(pos + "_");
      if (pos.equals("NN") || pos.equals("NNS") || pos.equals("NNP") || pos.equals("NNPS") || pos.equals("VB")
              || pos.equals("VBN") || pos.equals("VBD") || pos.equals("VBG") || pos.equals("VBP") || pos.equals("VBZ")) {
        sbSelective.append(word + "_");
      }
    }
    if (usingFeature(types, checklist, "surface_path")) {
      features.setCount("surface_path: " + sb, 1.0);
    }
    if (usingFeature(types, checklist, "surface_path_POS")) {
      features.setCount("surface_path_POS: " + sbPOS, 1.0);
    }
    if (usingFeature(types, checklist, "surface_path_selective")) {
      features.setCount("surface_path_selective: " + sbSelective, 1.0);
    }

    int swStart, swEnd; // must be initialized below
    if (arg0.getSyntacticHeadTokenPosition() < arg1.getSyntacticHeadTokenPosition()){
      swStart = arg0.getExtentTokenEnd();
      swEnd = arg1.getExtentTokenStart();
    } else {
      swStart = arg1.getExtentTokenEnd();
      swEnd = arg0.getExtentTokenStart();
    }

    // span_words_unigrams: words that appear in between the two arguments
    if (usingFeature(types, checklist, "span_words_unigrams")) {
      for(int i = swStart; i < swEnd; i ++){
        features.setCount("span_word:" + tokens.get(i).word(), 1.0);
      }
    }

    // span_words_bigrams: bigrams of words that appear in between the two arguments
    if (usingFeature(types, checklist, "span_words_bigrams")) {
      for(int i = swStart; i < swEnd - 1; i ++){
        features.setCount("span_bigram:" + tokens.get(i).word() + "-" + tokens.get(i + 1).word(), 1.0);
      }
    }

    if (usingFeature(types, checklist, "span_words_trigger")) {
      for (int i = swStart; i < swEnd; i++) {
        String trigger = tokens.get(i).get(TriggerAnnotation.class);
        if (trigger != null && trigger.startsWith("B-"))
          features.incrementCount("span_words_trigger=" + trigger.substring(2));
      }
    }

    if (usingFeature(types, checklist, "arg2_number")) {
      if (arg1.getType().equals("NUMBER")){
        try {
          int value = Integer.parseInt(arg1.getValue());

          if (2 <= value && value <= 100)
            features.setCount("arg2_number", 1.0);
          if (2 <= value && value <= 19)
            features.setCount("arg2_number_2", 1.0);
          if (20 <= value && value <= 59)
            features.setCount("arg2_number_20", 1.0);
          if (60 <= value && value <= 100)
            features.setCount("arg2_number_60", 1.0);
          if (value >= 100)
            features.setCount("arg2_number_100", 1.0);
        } catch (NumberFormatException e) { /* not a parseable number; skip these features */ }
      }
    }

    if (usingFeature(types, checklist, "arg2_date")) {
      if (arg1.getType().equals("DATE")){
        try {
          int value = Integer.parseInt(arg1.getValue());

          if (0 <= value && value <= 2010)
            features.setCount("arg2_date", 1.0);
          if (0 <= value && value <= 999)
            features.setCount("arg2_date_0", 1.0);
          if (1000 <= value && value <= 1599)
            features.setCount("arg2_date_1000", 1.0);
          if (1600 <= value && value <= 1799)
            features.setCount("arg2_date_1600", 1.0);
          if (1800 <= value && value <= 1899)
            features.setCount("arg2_date_1800", 1.0);
          if (1900 <= value && value <= 1999)
            features.setCount("arg2_date_1900", 1.0);
          if (value >= 2000)
            features.setCount("arg2_date_2000", 1.0);
        } catch (NumberFormatException e) { /* not a parseable year; skip these features */ }
      }
    }

    if (usingFeature(types, checklist, "arg_gender")) {
      boolean arg0Male = false, arg0Female = false;
      boolean arg1Male = false, arg1Female = false;
      System.out.println("Adding gender annotations!");

      int index = arg0.getExtentTokenStart();
      String gender = tokens.get(index).get(GenderAnnotation.class);
      System.out.println(tokens.get(index).word() + " -- " + gender);
      if (gender.equals("MALE"))
        arg0Male = true;
      else if (gender.equals("FEMALE"))
        arg0Female = true;

      index = arg1.getExtentTokenStart();
      gender = tokens.get(index).get(GenderAnnotation.class);
      if (gender.equals("MALE"))
        arg1Male = true;
      else if (gender.equals("FEMALE"))
        arg1Female = true;

      if (arg0Male) features.setCount("arg1_male", 1.0);
      if (arg0Female) features.setCount("arg1_female", 1.0);
      if (arg1Male) features.setCount("arg2_male", 1.0);
      if (arg1Female) features.setCount("arg2_female", 1.0);

      if ((arg0Male && arg1Male) || (arg0Female && arg1Female))
        features.setCount("arg_same_gender", 1.0);
      if ((arg0Male && arg1Female) || (arg0Female && arg1Male))
        features.setCount("arg_different_gender", 1.0);
    }

    List<String> tempDepFeatures = new ArrayList<String>(dependencyFeatures);
    if (tempDepFeatures.removeAll(types) || types.contains("all")) { // types requests at least one dependency feature (or "all")
      addDependencyPathFeatures(features, rel, arg0, arg1, types, checklist, logger);
    }

    if (!checklist.isEmpty() && !checklist.contains("all"))
      throw new AssertionError("BasicRelationFeatureFactory: features not handled: " + checklist);


    List<String> sortedFeatureNames = new ArrayList<String>(features.keySet());
    Collections.sort(sortedFeatureNames);

//    for (String feature : sortedFeatureNames) {
//      logger.info(feature+"\n"+"count="+features.getCount(feature));
//    }

    return true;

  }

  String sentToString(CoreMap sentence) {
    StringBuilder os = new StringBuilder();
    List<CoreLabel> tokens = sentence.get(TokensAnnotation.class);
    if(tokens != null){
      boolean first = true;
      for(CoreLabel token: tokens) {
        if(! first) os.append(" ");
        os.append(token.word());
        first = false;
      }
    }

    return os.toString();
  }

  protected void addDependencyPathFeatures(
          Counter<String> features,
          RelationMention rel,
          EntityMention arg0,
          EntityMention arg1,
          List<String> types,
          List<String> checklist,
          Logger logger) {
    SemanticGraph graph = null;
    if(dependencyType == null) dependencyType = DEPENDENCY_TYPE.COLLAPSED_CCPROCESSED; // needed for backwards compatibility. old serialized models don't have it
    if(dependencyType == DEPENDENCY_TYPE.COLLAPSED_CCPROCESSED)
      graph = rel.getSentence().get(SemanticGraphCoreAnnotations.CollapsedCCProcessedDependenciesAnnotation.class);
    else if(dependencyType == DEPENDENCY_TYPE.COLLAPSED)
      graph = rel.getSentence().get(SemanticGraphCoreAnnotations.CollapsedDependenciesAnnotation.class);
    else if(dependencyType == DEPENDENCY_TYPE.BASIC)
      graph = rel.getSentence().get(SemanticGraphCoreAnnotations.BasicDependenciesAnnotation.class);
    else
      throw new RuntimeException("ERROR: unknown dependency type: " + dependencyType);

    if (graph == null) {
      Tree tree = rel.getSentence().get(TreeAnnotation.class);
      if(tree == null){
        System.err.println("WARNING: found sentence without TreeAnnotation. Skipped dependency-path features.");
        return;
      }
      try {
        graph = SemanticGraphFactory.generateCollapsedDependencies(tree);
      } catch(Exception e){
        System.err.println("WARNING: failed to generate dependencies from tree " + tree.toString());
        e.printStackTrace();
        System.err.println("Skipped dependency-path features.");
        return;
      }
    }

    IndexedWord node0 = graph.getNodeByIndexSafe(arg0.getSyntacticHeadTokenPosition() + 1);
    IndexedWord node1 = graph.getNodeByIndexSafe(arg1.getSyntacticHeadTokenPosition() + 1);
    if (node0 == null || node1 == null) {
      // one of the argument heads is not in the dependency graph; skip all dependency features
      checklist.removeAll(dependencyFeatures);
      return;
    }

    List<SemanticGraphEdge> edgePath = graph.getShortestUndirectedPathEdges(node0, node1);
    List<IndexedWord> pathNodes = graph.getShortestUndirectedPathNodes(node0, node1);

    if (edgePath == null) {
      checklist.removeAll(dependencyFeatures);
      return;
    }

    if (pathNodes == null || pathNodes.size() <= 1) { // arguments have the same head.
      checklist.removeAll(dependencyFeatures);
      return;
    }

    // dependency_path: Concatenation of relations in the path between the args in the dependency graph, including directions
    // e.g. "subj->  <-prep_in  <-mod"
    // dependency_path_lowlevel: Same but with finer-grained syntactic relations
    // e.g. "nsubj->  <-prep_in  <-nn"
    if (usingFeature(types, checklist, "dependency_path")) {
      features.setCount("dependency_path:"+generalizedDependencyPath(edgePath, node0), 1.0);
    }
    if (usingFeature(types, checklist, "dependency_path_lowlevel")) {
      String depLowLevel = dependencyPath(edgePath, node0);
      if(logger != null && ! rel.getType().equals(RelationMention.UNRELATED)) logger.info("dependency_path_lowlevel: " + depLowLevel);
      features.setCount("dependency_path_lowlevel:" + depLowLevel, 1.0);
    }

    List<String> pathLemmas = new ArrayList<String>();
    List<String> noArgPathLemmas = new ArrayList<String>();
    // do not add to pathLemmas words that belong to one of the two args
    Set<Integer> indicesToSkip = new HashSet<Integer>();
    for(int i = arg0.getExtentTokenStart(); i < arg0.getExtentTokenEnd(); i ++) indicesToSkip.add(i + 1);
    for(int i = arg1.getExtentTokenStart(); i < arg1.getExtentTokenEnd(); i ++) indicesToSkip.add(i + 1);
    for (IndexedWord node : pathNodes){
      pathLemmas.add(Morphology.lemmaStatic(node.value(), node.tag(), true));
      if(! indicesToSkip.contains(node.index()))
        noArgPathLemmas.add(Morphology.lemmaStatic(node.value(), node.tag(), true));
    }


    // Verb-based features
    // These features were designed on the assumption that verbs are often trigger words
    // (specifically with the "Kill" relation from Roth CONLL04 in mind)
    // but they didn't end up boosting performance on Roth CONLL04, so they may not be necessary.
    //
    // dependency_paths_to_verb: for each verb in the dependency path,
    // the path to the left of the (lemmatized) verb, to the right, and both, e.g.
    // "subj-> be"
    // "be  <-prep_in  <-mod"
    // "subj->  be  <-prep_in  <-mod"
    // (Higher level relations used as opposed to "lowlevel" finer grained relations)
    if (usingFeature(types, checklist, "dependency_paths_to_verb")) {
      for (IndexedWord node : pathNodes) {
        if (node.tag().contains("VB")) {
          if (node.equals(node0) || node.equals(node1)) {
            continue;
          }
          String lemma = Morphology.lemmaStatic(node.value(), node.tag(), true);
          String node1Path = generalizedDependencyPath(graph.getShortestUndirectedPathEdges(node, node1), node);
          String node0Path = generalizedDependencyPath(graph.getShortestUndirectedPathEdges(node0, node), node0);
          features.setCount("dependency_paths_to_verb:" + node0Path + " " + lemma, 1.0);
          features.setCount("dependency_paths_to_verb:" + lemma + " " + node1Path, 1.0);
          features.setCount("dependency_paths_to_verb:" + node0Path + " " + lemma + " " + node1Path, 1.0);
        }
      }
    }
    // dependency_path_stubs_to_verb:
    // For each verb in the dependency path,
    // the verb concatenated with the first (high-level) relation in the path from arg0;
    // the verb concatenated with the first relation in the path from arg1,
    // and the verb concatenated with both relations.  E.g. (same arguments and sentence as example above)
    // "stub: subj->  be"
    // "stub: be  <-mod"
    // "stub: subj->  be  <-mod"
    if (usingFeature(types, checklist, "dependency_path_stubs_to_verb")) {
      SemanticGraphEdge edge0 = edgePath.get(0);
      SemanticGraphEdge edge1 = edgePath.get(edgePath.size() - 1);
      for (IndexedWord node : pathNodes) {
        if (node.tag().contains("VB")) {
          if (node.equals(node0) || node.equals(node1)) {
            continue;
          }
          String lemma = Morphology.lemmaStatic(node.value(), node.tag(), true);
          String edge0str, edge1str;
          if (node0.equals(edge0.getGovernor())) {
            edge0str = "<-" + generalizeRelation(edge0.getRelation());
          } else {
            edge0str = generalizeRelation(edge0.getRelation()) + "->";
          }
          if (node1.equals(edge1.getGovernor())) {
            edge1str = generalizeRelation(edge1.getRelation()) + "->";
          } else {
            edge1str = "<-" + generalizeRelation(edge1.getRelation());
          }
          features.setCount("stub: " + edge0str + " " + lemma, 1.0);
          features.setCount("stub: " + lemma + edge1str, 1.0);
          features.setCount("stub: " + edge0str + " " + lemma + " " + edge1str, 1.0);
        }
      }
    }

    if (usingFeature(types, checklist, "verb_in_dependency_path")) {
      for (IndexedWord node : pathNodes) {
        if (node.tag().contains("VB")) {
          if (node.equals(node0) || node.equals(node1)) {
            continue;
          }
          SemanticGraphEdge rightEdge = graph.getShortestUndirectedPathEdges(node, node1).get(0);
          SemanticGraphEdge leftEdge = graph.getShortestUndirectedPathEdges(node, node0).get(0);
          String rightRelation, leftRelation;
          boolean governsLeft = false, governsRight = false;
          if (node.equals(rightEdge.getGovernor())) {
            rightRelation = " <-" + generalizeRelation(rightEdge.getRelation());
            governsRight = true;
          } else {
            rightRelation = generalizeRelation(rightEdge.getRelation()) + "-> ";
          }
          if (node.equals(leftEdge.getGovernor())) {
            leftRelation = generalizeRelation(leftEdge.getRelation()) + "-> ";
            governsLeft = true;
          } else {
            leftRelation = " <-" + generalizeRelation(leftEdge.getRelation());
          }
          String lemma = Morphology.lemmaStatic(node.value(), node.tag(), true);

          if (governsLeft) {
            features.setCount("verb: " + leftRelation + lemma, 1.0);
          }
          if (governsRight) {
            features.setCount("verb: " + lemma + rightRelation, 1.0);
          }
          if (governsLeft && governsRight) {
            features.setCount("verb: " + leftRelation + lemma + rightRelation, 1.0);
          }
        }
      }
    }


    // FEATURES FROM BJORNE ET AL., BIONLP'09
    // dependency_path_words: generates a feature for each word in the dependency path (lemmatized)
    // dependency_path_POS_unigrams: generates a feature for the POS tag of each word in the dependency path
    if (usingFeature(types, checklist, "dependency_path_words")) {
      for (String lemma : noArgPathLemmas)
        features.setCount("word_in_dependency_path:" + lemma, 1.0);
    }
    if (usingFeature(types, checklist, "dependency_path_POS_unigrams")) {
      for (IndexedWord node : pathNodes)
        if (!node.equals(node0) && !node.equals(node1))
          features.setCount("POS_in_dependency_path: "+node.tag(),1.0);
    }

    // dependency_path_word_n_grams: n-grams of words (lemmatized) in the dependency path, n=2,3,4
    // dependency_path_POS_n_grams: n-grams of POS tags of words in the dependency path, n=2,3,4
    for (int start = 0; start < pathNodes.size(); start++) {
      for (int n = 2; n <= 4; n++) {
        if (start + n > pathNodes.size())
          break;
        StringBuilder sb = new StringBuilder();
        StringBuilder sbPOS = new StringBuilder();

        for (int elt = start; elt < start + n; elt++) {
          sb.append(pathLemmas.get(elt));
          sb.append("_");
          sbPOS.append(pathNodes.get(elt).tag());
          sbPOS.append("_");
        }
        if (usingFeature(types, checklist, "dependency_path_word_n_grams"))
          features.setCount("dependency_path_"+n+"-gram: "+sb,1.0);
        if (usingFeature(types,checklist, "dependency_path_POS_n_grams"))
          features.setCount("dependency_path_POS_"+n+"-gram: "+sbPOS,1.0);
      }
    }
    // dependency_path_edge_n_grams: n_grams of relations (high-level) in the dependency path, undirected, n=2,3,4
    // e.g. "subj -- prep_in -- mod"
    // dependency_path_edge_lowlevel_n_grams: similar, for fine-grained relations
    //
    // dependency_path_node-edge-node-grams: trigrams consisting of adjacent words (lemmatized) in the dependency path
    // and the relation between them (undirected)
    // dependency_path_node-edge-node-grams_lowlevel: same, using fine-grained relations
    //
    // dependency_path_edge-node-edge-grams: trigrams consisting of words (lemmatized) in the dependency path
    // and the incoming and outgoing relations (undirected)
    // e.g. "subj -- television -- mod"
    // dependency_path_edge-node-edge-grams_lowlevel: same, using fine-grained relations
    //
    // dependency_path_directed_bigrams: consecutive words in the dependency path (lemmatized) and the direction
    // of the dependency between them
    // e.g. "Theatre -> exhibit"
    //
    // dependency_path_edge_unigrams: feature for each (fine-grained) relation in the dependency path,
    // with its direction in the path and whether it's at the left end, right end, or interior of the path.
    // e.g. "prep_at ->  - leftmost"
    for (int edge = 0; edge < edgePath.size(); edge++) {
      if (usingFeature(types, checklist, "dependency_path_edge_n_grams") ||
              usingFeature(types, checklist, "dependency_path_edge_lowlevel_n_grams")) {
        for (int n = 2; n <= 4; n++) {
          if (edge+n > edgePath.size())
            break;
          StringBuilder sbRelsHi = new StringBuilder();
          StringBuilder sbRelsLo = new StringBuilder();
          for (int elt = edge; elt < edge+n; elt++) {
            GrammaticalRelation gr = edgePath.get(elt).getRelation();
            sbRelsHi.append(generalizeRelation(gr));
            sbRelsHi.append("_");
            sbRelsLo.append(gr);
            sbRelsLo.append("_");
          }
          if (usingFeature(types, checklist, "dependency_path_edge_n_grams"))
            features.setCount("dependency_path_edge_"+n+"-gram: "+sbRelsHi,1.0);
          if (usingFeature(types, checklist, "dependency_path_edge_lowlevel_n_grams"))
            features.setCount("dependency_path_edge_lowlevel_"+n+"-gram: "+sbRelsLo,1.0);
        }
      }
      if (usingFeature(types, checklist, "dependency_path_node-edge-node-grams"))
        features.setCount(
                "dependency_path_node-edge-node-gram: "+
                        pathLemmas.get(edge)+" -- "+
                        generalizeRelation(edgePath.get(edge).getRelation())+" -- "+
                        pathLemmas.get(edge+1),
                1.0);
      if (usingFeature(types, checklist, "dependency_path_node-edge-node-grams_lowlevel"))
        features.setCount(
                "dependency_path_node-edge-node-gram_lowlevel: "+
                        pathLemmas.get(edge)+" -- "+
                        edgePath.get(edge).getRelation()+" -- "+
                        pathLemmas.get(edge+1),
                1.0);
      if (usingFeature(types,checklist, "dependency_path_edge-node-edge-grams") && edge > 0)
        features.setCount(
                "dependency_path_edge-node-edge-gram: "+
                        generalizeRelation(edgePath.get(edge-1).getRelation())+" -- "+
                        pathLemmas.get(edge)+" -- "+
                        generalizeRelation(edgePath.get(edge).getRelation()),
                1.0);
      if (usingFeature(types,checklist,"dependency_path_edge-node-edge-grams_lowlevel") && edge > 0)
        features.setCount(
                "dependency_path_edge-node-edge-gram_lowlevel: "+
                        edgePath.get(edge-1).getRelation()+" -- "+
                        pathLemmas.get(edge)+" -- "+
                        edgePath.get(edge).getRelation(),
                1.0);
      String dir = pathNodes.get(edge).equals(edgePath.get(edge).getDependent()) ? " -> " : " <- ";
      if (usingFeature(types, checklist, "dependency_path_directed_bigrams"))
        features.setCount(
                "dependency_path_directed_bigram: "+
                        pathLemmas.get(edge)+
                        dir+
                        pathLemmas.get(edge+1),
                1.0);
      if (usingFeature(types, checklist, "dependency_path_edge_unigrams"))
        features.setCount(
                "dependency_path_edge_unigram: "+
                        edgePath.get(edge).getRelation() +
                        dir+
                        (edge==0 ? " - leftmost" : edge==edgePath.size()-1 ? " - rightmost" : " - interior"),1.0);
    }

    // dependency_path_length: number of edges in the path between args in the dependency graph, integer-valued
    // dependency_path_length_binary: same, as binary features
    if (usingFeature(types, checklist, "dependency_path_length")) {
      features.setCount("dependency_path_length", edgePath.size());
    }
    if (usingFeature(types, checklist, "dependency_path_length_binary")) {
      features.setCount("dependency_path_length_" + new DecimalFormat("00").format(edgePath.size()), 1.0);
    }

    if (usingFeature(types, checklist, "dependency_path_trigger")) {
      List<CoreLabel> tokens = rel.getSentence().get(TokensAnnotation.class);

      for (IndexedWord node : pathNodes) {
        int index = node.index();
        if (indicesToSkip.contains(index)) continue;

        String trigger = tokens.get(index - 1).get(TriggerAnnotation.class);
        if (trigger != null && trigger.startsWith("B-"))
          features.incrementCount("dependency_path_trigger=" + trigger.substring(2));
      }
    }
  }

  /**
   * Helper method that checks if the feature class {@code type} was requested in {@code types}
   * and removes it from {@code checklist}, marking it as handled.
   * @param types List of requested feature classes
   * @param checklist Feature classes not yet handled; {@code type} is removed as a side effect
   * @param type The feature class to check for
   * @return true if {@code types} contains {@code type} or "all"
   */
  protected static boolean usingFeature(final List<String> types, List<String> checklist, String type) {
    checklist.remove(type);
    return types.contains(type) || types.contains("all");
  }

  protected static GrammaticalRelation generalizeRelation(GrammaticalRelation gr) {
    final GrammaticalRelation[] GENERAL_RELATIONS = { EnglishGrammaticalRelations.SUBJECT,
            EnglishGrammaticalRelations.COMPLEMENT, EnglishGrammaticalRelations.CONJUNCT,
            EnglishGrammaticalRelations.MODIFIER, };
    for (GrammaticalRelation generalGR : GENERAL_RELATIONS) {
      if (generalGR.isAncestor(gr)) {
        return generalGR;
      }
    }
    return gr;
  }
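
  // For example, NOMINAL_SUBJECT (nsubj) generalizes to SUBJECT, since
  // SUBJECT.isAncestor(NOMINAL_SUBJECT) holds; a relation that falls under none of the
  // four hierarchies above is returned unchanged.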

  /*
   * Under construction
   */

  public static List<String> dependencyPathAsList(List<SemanticGraphEdge> edgePath, IndexedWord node, boolean generalize) {
    if(edgePath == null) return null;
    List<String> path = new ArrayList<String>();
    for (SemanticGraphEdge edge : edgePath) {
      IndexedWord nextNode;
      GrammaticalRelation relation;
      if (generalize) {
        relation = generalizeRelation(edge.getRelation());
      } else {
        relation = edge.getRelation();
      }

      if (node.equals(edge.getDependent())) {
        String v = (relation + "->").intern();
        path.add(v);
        nextNode = edge.getGovernor();
      } else {
        String v = ("<-" + relation).intern();
        path.add(v);
        nextNode = edge.getDependent();
      }
      node = nextNode;
    }

    return path;
  }

  public static String dependencyPath(List<SemanticGraphEdge> edgePath, IndexedWord node) {
    // the extra spaces are to maintain compatibility with existing relation extraction models
    return " " + StringUtils.join(dependencyPathAsList(edgePath, node, false), "  ") + " ";
  }

  public static String generalizedDependencyPath(List<SemanticGraphEdge> edgePath, IndexedWord node) {
    // the extra spaces are to maintain compatibility with existing relation extraction models
    return " " + StringUtils.join(dependencyPathAsList(edgePath, node, true), "  ") + " ";
  }
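
  // For instance, for the path  arg0 <-nsubj- verb -prep_in-> arg1  (arg0 and arg1 both
  // dependents of the verb), starting from arg0 these methods produce
  // " nsubj->  <-prep_in ": the first hop goes up to the governor, the second hop down
  // to the dependent, matching the examples in the dependency_path feature comments above.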

  public Set<String> getFeatures(RelationMention rel, String featureType) {
    Counter<String> features = new ClassicCounter<String>();
    List<String> singleton = new ArrayList<String>();
    singleton.add(featureType);
    addFeatures(features, rel, singleton);
    return features.keySet();
  }

  public String getFeature(RelationMention rel, String featureType) {
    Set<String> features = getFeatures(rel, featureType);
    if (features.isEmpty()) {
      return "";
    } else {
      return features.iterator().next();
    }
  }


}