package edu.stanford.nlp.sempre.paraphrase.rules;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import com.google.common.base.Joiner;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.sempre.FeatureVector;
import edu.stanford.nlp.sempre.LanguageInfo;
import edu.stanford.nlp.sempre.LanguageInfo.LanguageUtils;
import edu.stanford.nlp.sempre.LanguageInfo.WordInfo;
import edu.stanford.nlp.sempre.paraphrase.rules.LanguageExp.LangExpMatch;
import edu.stanford.nlp.sempre.paraphrase.rules.RuleApplication.ApplicationInfo;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Pair;

/**
 * Stores a large number of syntactic rules and matches them efficiently.
 * We do not use the regexp mechanism: in the trie we want to use a map to look up
 * the relevant next edges directly, rather than going over all outgoing edges and
 * matching a regexp against each of them.
 * @author jonathanberant
 */

public class SyntacticRuleSet extends RuleApplier {
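  // A hypothetical example (not taken from the rule file), consistent with the test in main()
  // below: a rule such as "NN_1 NN_2 --> NN_2 of||IN||O NN_1" would rewrite the antecedent
  // span "birth place" into "place of birth", turning "where was obama birth place" into
  // "where was obama place of birth".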

  public static class Options {
    @Option(gloss= "Path to syntactic rule set") public String rulesetPath="lib/paralex/syntactic-rules.retagged.sorted.txt";
    @Option(gloss= "Verbosity") public int verbose=0;
  }
  public static Options opts = new Options();

  Trie trie = new Trie();

  public SyntacticRuleSet() {
    LogInfo.begin_track("Loading syntactic rules");
    loadRuleset();
    LogInfo.end_track();
  }

  /**
   * While loading the rules we filter out:
   * (a) rules containing punctuation
   * (b) rules containing derivations (DER)
   * (c) rules whose count is below a threshold
   * (d) rules without at least two variables on each side, or whose
   *     right-hand side does not reorder the variables
   */
  private void loadRuleset() {
    int count=0;
    for(String line: IOUtils.readLines(opts.rulesetPath)) {
      if(validRule(line)) {
        count++;
        SubstitutableSyntacticRule rule = parseRule(line);
        if(opts.verbose>=3)
          LogInfo.logs("loadRuleSet: loaded rule=%s",rule);
        trie.add(rule);
      }
    }
    LogInfo.logs("Number of valid rules=%s",count);
  }

  private boolean validRule(String line) {
    if(line.contains("DER"))
      return false;
    if(line.contains("._"))
      return false;
    String[] tokens = line.split("\t");
    double count = Double.parseDouble(tokens[2]);
    if(count<20.0)
      return false;
    String[] lhsTokens = tokens[0].split("\\s+");
    String[] rhsTokens = tokens[1].split("\\s+");
    Set<Integer> lhsVars = new HashSet<Integer>();
    Set<Integer> rhsVars = new HashSet<Integer>();
    List<Integer> rhsVarList = new ArrayList<Integer>();
    for(String lhsToken: lhsTokens) {
      if(Character.isDigit(lhsToken.charAt(0)))
        return false;
      if(lhsToken.indexOf('_')!=-1)
        lhsVars.add(Integer.parseInt(lhsToken.split("_")[1]));
    }
    for(String rhsToken: rhsTokens) {
      if(Character.isDigit(rhsToken.charAt(0)))
        return false;
      if(rhsToken.indexOf('_')!=-1) {
        rhsVars.add(Integer.parseInt(rhsToken.split("_")[1]));
        rhsVarList.add(Integer.parseInt(rhsToken.split("_")[1]));
      }
    }
//    if(!SetUtils.isEqualSet(lhsVars, rhsVars))
//      return false;
    if(lhsVars.size()<2 || rhsVars.size()<2) //extreme rule
      return false;
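    // Keep only rules whose right-hand side reorders the variables: accept as soon as
    // some variable index is immediately followed by a smaller one.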
    for(int i = 0; i < rhsVarList.size()-1;++i) {
      if(rhsVarList.get(i)>rhsVarList.get(i+1))
        return true;
    }
    return false;
    //    return true;
  }

  //TODO very hacky
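  // Assumed line format (inferred from the parsing below): lhs <TAB> rhs <TAB> count, where a
  // variable token looks like "NN_1" (POS tag, '_', variable index) and a lexical token looks
  // like "word||pos||ner"; the exact order of the "||"-separated fields is an assumption based
  // on how WordInfo is constructed here.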
  private SubstitutableSyntacticRule parseRule(String line) {

    String[] tokens = line.split("\t");
    double count = Double.parseDouble(tokens[2]);
    List<LanguageExpToken> lhs = new ArrayList<LanguageExpToken>();

    String[] lhsTokens = tokens[0].split("\\s+");
    int[] varNameToPositionMap =new int[lhsTokens.length];
    String[] rhsTokens = tokens[1].split("\\s+");
    int[] mapping = new int[rhsTokens.length];
    Arrays.fill(mapping, -1);
    Arrays.fill(varNameToPositionMap, -1);
    WordInfo[] rhs = new WordInfo[rhsTokens.length];
    //generate lhs
    for(int i = 0; i < lhsTokens.length; ++i) {
      String lhsToken = lhsTokens[i];
      if(lhsToken.indexOf('_')!=-1) {
        String[] posAndVarName = lhsToken.split("_");
        lhs.add(new LanguageExpToken("pos", "["+posAndVarName[0]+"]"));
        varNameToPositionMap[Integer.parseInt(posAndVarName[1])]=i;
      }
      else {
        String[] wordInfoParts = lhsToken.split("\\|\\|");
        lhs.add(new LanguageExpToken("lemma","["+wordInfoParts[0]+"]"));
      }
    }
    //generate mapping and rhs
    for(int i = 0; i < rhsTokens.length; ++i) {
      String rhsToken = rhsTokens[i];
      if(rhsToken.indexOf('_')!=-1) {
        String[] posAndVarName = rhsToken.split("_");
        mapping[i]=varNameToPositionMap[Integer.parseInt(posAndVarName[1])];
      }
      else {
        String[] wordInfoParts = rhsToken.split("\\|\\|");
        rhs[i]=new WordInfo(wordInfoParts[0], wordInfoParts[0], wordInfoParts[1], wordInfoParts[2], "O");
      }
    }
    return new SubstitutableSyntacticRule(lhs, rhs, mapping, count);
  }

  @Override
  public List<RuleApplication> apply(LanguageInfo antecedent,
      LanguageInfo target) {

    List<RuleApplication> res = new ArrayList<RuleApplication>();
    //go over all spans
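    // Starting at each token i, we walk the trie over spans of up to 5 tokens. A token may be
    // consumed by a lemma edge, a (canonical) POS edge, or an NER edge, so we keep a frontier
    // of trie nodes (currentNodes) rather than a single node.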
    for(int i = 0; i < antecedent.numTokens(); ++i) {
      List<Trie> currentNodes = new ArrayList<Trie>();
      currentNodes.add(trie);
      for(int span=1; span<=5 && i+span <= antecedent.numTokens(); ++span) {
        List<Trie> nextNodes = new ArrayList<Trie>();
        Pair<String,String> lemmaPair = new Pair<String, String>("lemma",antecedent.lemmaTokens.get(i+span-1));
        Pair<String,String> posPair = new Pair<String, String>("pos",LanguageUtils.getCanonicalPos(antecedent.posTags.get(i+span-1)));
        Pair<String,String> nerPair = new Pair<String, String>("ner",antecedent.nerTags.get(i+span-1));
        //add to next nodes all the tries we can reach with the current word
        addReachableTries(currentNodes, nextNodes, lemmaPair, posPair, nerPair);
        //now we can apply all of the rules
        generateMatchingRhsApplications(antecedent, target, res, i, span, nextNodes);
        //we set current nodes to next nodes for the next round
        currentNodes = nextNodes;
      }
    }
    return res;
  }

  private void generateMatchingRhsApplications(LanguageInfo antecedent, LanguageInfo target,
      List<RuleApplication> res, int i, int span, List<Trie> nextNodes) {

    Set<String> generatedRhsMatches = new HashSet<String>();
    for(Trie nextNode: nextNodes) {
      for(SubstitutableSyntacticRule rule: nextNode.rules) {
        List<WordInfo> rhsMatch = rule.generateRhsLemmas(antecedent, i);
        if(target.matchLemmas(rhsMatch)) {
          String rhsMatchPhrase = LanguageUtils.getLemmaPhrase(rhsMatch);
          if(!generatedRhsMatches.contains(rhsMatchPhrase)) {
            generatedRhsMatches.add(rhsMatchPhrase);
            res.add(generateApplications(antecedent, i, i+span, rhsMatch, rule));
          }
        }
      }
    }
  }

  private void addReachableTries(List<Trie> currentNodes, List<Trie> nextNodes,
      Pair<String, String> lemmaPair, Pair<String, String> posPair,
      Pair<String, String> nerPair) {
    for(Trie currentNode: currentNodes) {
      addNextTrie(nextNodes, lemmaPair, currentNode);
      addNextTrie(nextNodes, posPair, currentNode);
      addNextTrie(nextNodes, nerPair, currentNode);
    }
  }

  private void addNextTrie(List<Trie> nextNodes, Pair<String, String> pair,
      Trie currentNode) {
    Trie nextTrie = currentNode.next(pair);
    if(nextTrie!=null)
      nextNodes.add(nextTrie);
  }

  private RuleApplication generateApplications(LanguageInfo antecedent, int i, int j, List<WordInfo> rhsMatch, SubstitutableSyntacticRule rule) {
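    // Build the consequent by splicing: antecedent[0,i) + rhsMatch + antecedent[j,numTokens).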
    LanguageInfo consequent = new LanguageInfo();
    consequent.addSpan(antecedent, 0, i);
    consequent.addWordInfos(rhsMatch);
    consequent.addSpan(antecedent, j, antecedent.numTokens());
    RuleApplication application = new RuleApplication(antecedent, consequent, new ApplicationInfo(SYNT_SUBST, rule.toString()));
    FeatureVector fv = new FeatureVector();
    fv.add(SYNT_SUBST, rule.toString());
    fv.add(SYNT_SUBST, "score" ,rule.count);
    application.addFeatures(fv);
    if(opts.verbose>0)
      LogInfo.logs("antecedent=%s, consequent=%s, rule=%s",antecedent.tokens,consequent.tokens,rule);
    return application;
  }

  class SubstitutableSyntacticRule {
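    // lhs is the pattern to match (lemma/POS tokens); rhs is the replacement, where a null
    // entry marks a variable position; mapping[k] gives, for a variable at rhs position k,
    // the offset within the matched lhs span of the word to copy (-1 for lexical rhs tokens).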

    final List<LanguageExpToken> lhs;
    final WordInfo[] rhs;
    final int[] mapping;
    final double count;

    public SubstitutableSyntacticRule(List<LanguageExpToken> lhs, WordInfo[] rhs,
        int[] mapping, double count) {
      this.lhs=lhs;
      this.rhs=rhs;
      this.mapping=mapping;
      this.count = count;
    }

    public List<WordInfo> generateRhsLemmas(LanguageInfo antecedent, int start) {
      List<WordInfo> res = new ArrayList<WordInfo>();
      for(int i = 0; i < rhs.length; ++i) {
        if(rhs[i]!=null)
          res.add(rhs[i]);
        else
          res.add(antecedent.getWordInfo(start+mapping[i]));
      }
      return res;
    }

    @Override
    public String toString() {
      StringBuilder sb = new StringBuilder();
      for(int i = 0; i < rhs.length; ++i) {
        if(rhs[i]==null)
          sb.append(mapping[i]).append(' ');
        else
          sb.append(rhs[i].toString()).append(' ');
      }
      return Joiner.on(' ').join(lhs)+"-->"+sb.toString()+"("+count+")";
    }
  }


  class Trie {
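    // Each edge is keyed by a (type, value) pair, e.g. ("lemma", "birth") or ("pos", "NN");
    // a rule is stored at the node reached after consuming its entire lhs.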
    ArrayList<SubstitutableSyntacticRule> rules = new ArrayList<SubstitutableSyntacticRule>();
    HashMap<Pair<String,String>, Trie> children = new HashMap<Pair<String,String>, Trie>();

    Trie next(Pair<String,String> pair) {
      return children.get(pair);
    }

    void add(SubstitutableSyntacticRule rule) { add(rule, 0); }
    private void add(SubstitutableSyntacticRule rule, int i) {

      if(i == rule.lhs.size()) {
        rules.add(rule);
        return;
      }

      LanguageExpToken langToken = rule.lhs.get(i);
      Trie child = children.get(convertLangItemToPair(langToken));
      if (child == null) {
        children.put(convertLangItemToPair(langToken),
            child = new Trie());
      }
      child.add(rule, i + 1);
    }

    private Pair<String,String> convertLangItemToPair(LanguageExpToken langToken) {
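      // Strip the brackets added in parseRule, e.g. "[NN]" -> "NN", and key the edge
      // by the (type, value) pair.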
      String value = langToken.value.substring(1,langToken.value.lastIndexOf(']'));
      return Pair.newPair(langToken.type.toString(),value);
    }
  }

  public static void main(String[] args) {
    opts.verbose=3;
    SyntacticRuleSet.opts.rulesetPath="/Users/jonathanberant/Research/temp/syntactic-rules.retagged.sorted.txt";
    SyntacticRuleSet srt = new SyntacticRuleSet();
    LanguageInfo antecedent = new LanguageInfo();
    antecedent.addWordInfo(new WordInfo("where", "where", "WDT", "O", "O"));
    antecedent.addWordInfo(new WordInfo("was", "be", "VBD", "O", "O"));
    antecedent.addWordInfo(new WordInfo("obama", "obama", "NNP", "PERSON", "O"));
    antecedent.addWordInfo(new WordInfo("birth", "birth", "NN", "O", "O"));
    antecedent.addWordInfo(new WordInfo("place", "place", "NN", "O", "O"));

    LanguageInfo target = new LanguageInfo();
    target.addWordInfo(new WordInfo("where", "where", "WDT", "O", "O"));
    target.addWordInfo(new WordInfo("was", "be", "VBD", "O", "O"));
    target.addWordInfo(new WordInfo("obama", "obama", "NNP", "PERSON", "O"));
    target.addWordInfo(new WordInfo("'s", "'s", "POS", "O", "O"));
    target.addWordInfo(new WordInfo("place", "place", "NN", "O", "O"));
    target.addWordInfo(new WordInfo("of", "of", "IN", "O", "O"));
    target.addWordInfo(new WordInfo("birth", "birth", "NN", "O", "O"));
    srt.apply(antecedent, target);
  }

  @Override
  public List<LangExpMatch> match(LanguageInfo lInfo) {
    throw new RuntimeException("Unsupported method");
  }
}