// Package edu.stanford.nlp.sempre.paraphrase
//
// Source code of edu.stanford.nlp.sempre.paraphrase.Aligner

package edu.stanford.nlp.sempre.paraphrase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import edu.stanford.nlp.sempre.BooleanValue;
import edu.stanford.nlp.sempre.FeatureVector;
import edu.stanford.nlp.sempre.LanguageInfo;
import edu.stanford.nlp.sempre.LanguageInfo.WordInfo;
import edu.stanford.nlp.sempre.Params;
import edu.stanford.nlp.sempre.LanguageInfo.LanguageUtils;
import edu.stanford.nlp.sempre.fbalignment.utils.WnExpander;
import edu.stanford.nlp.sempre.paraphrase.paralex.PhraseTable;
import edu.stanford.nlp.sempre.paraphrase.rules.RuleApplication;
import edu.stanford.nlp.sempre.paraphrase.rules.RuleApplier;
import edu.stanford.nlp.sempre.paraphrase.rules.Rulebase;
import fig.basic.IntPair;
import fig.basic.LispTree;
import fig.basic.LogInfo;
import fig.basic.Option;

/**
* Exhaustive alignment between two paraphrases
*/
public class Aligner {

  public static class Options {
    @Option(gloss="Path to file with derivations") public String derivationPath="lib/derivation.txt";
    @Option(gloss="Maximum distortion") public double distortionParam=1;
    @Option(gloss="Whether to use syntax rules") public boolean useSyntax=false;
    @Option(gloss="Whether to use wordnet synsets") public boolean useWordnet=false;
    @Option public int verbose=0;
  }
  /** Shared option values read by the constructor and by the alignment steps. */
  public static Options opts = new Options();
  /** Lazily created instance handed out by {@link #getSingleton()}; null until first requested. */
  private static Aligner aligner;

  /** Phrase table consulted by the phrase-table alignment step. */
  private final PhraseTable phraseTable;
  /** Maps a lemma to the set of lemmas it is related to in the derivations file. */
  private final Map<String,Set<String>> derivations;
  /** Syntactic rules; only assigned when {@code opts.useSyntax} is set at construction time. */
  private Rulebase ruleBase;
  /** Synonym lookup used for the "synonym" substitution feature. */
  private WnExpander wnExpander;


  private Aligner() throws IOException {
    phraseTable = PhraseTable.getSingleton();
    derivations = ParaphraseUtils.loadDerivations(opts.derivationPath);
    wnExpander = new WnExpander();
    if(opts.useSyntax)
      ruleBase = new Rulebase();
  }

  public static Aligner getSingleton() throws IOException {
    if(aligner==null)
      aligner = new Aligner();
    return aligner;
  }

  public Alignment align(ParaphraseExample example, Params params) {
    example.ensureAnnotated();
    Alignment alignment = new Alignment(example);
    alignment.buildAlignment(example, params);
    example.setAlignment(alignment);
    return alignment;
  }

  public class Alignment {

    private List<AlignmentIntervalPair> substitutions = new ArrayList<Aligner.AlignmentIntervalPair>();
    private List<AlignmentInterval> deletions = new ArrayList<Aligner.AlignmentInterval>();
    FeatureVector featureVector = new FeatureVector();
    double score = Double.NaN;
    private final String source;
    private final String target;

    public Alignment(ParaphraseExample example) {
      source = example.source;
      target = example.target;
    }

    public void buildAlignment(ParaphraseExample example, Params params) {
      computeIdentityAlignments(example);
      computePhraseTableAlignments(example);
      computeSubstitutionsAlignment(example);
      computeDerivationsAlignment(example);
      //this needs to be done last
      markDeletions(example);
      if(opts.useSyntax)
        computeSyntacticAlignment(example);
      if(opts.verbose>=1) {
        printFeaturesAndWeights(params);
      }
      score = featureVector.dotProduct(params);
    }

    private void printFeaturesAndWeights(Params params) { 
      for(String key: featureVector.toMap().keySet()) {
        double value = featureVector.toMap().get(key);
        LogInfo.logs("Printing features: feautre=%s, value=%s, weight=%s,product=\t%s",key,value,
            params.getWeight(key),value * params.getWeight(key));
      }  
    }

    private void computeSyntacticAlignment(ParaphraseExample example) {
      for(RuleApplier rule: ruleBase.getRules()) {
        for(RuleApplication application: rule.apply(example.sourceInfo, example.targetInfo)) {
          List<WordInfo> wordInfos = new ArrayList<WordInfo>();
          for(int i = 0; i < application.consequent.numTokens(); ++i)
            wordInfos.add(application.consequent.getWordInfo(i));
          if(example.targetInfo.matchLemmas(wordInfos)) {
            featureVector.add("SyntAlign",rule.toString());
          }
        }
      }
    }

    private void computeDerivationsAlignment(ParaphraseExample example) {
      for(int i = 0 ; i < example.sourceInfo.numTokens(); ++i) {
        for(int j = 0; j < example.targetInfo.numTokens(); ++j) {
          String sourceLemma = example.sourceInfo.lemmaTokens.get(i);
          String targetLemma = example.targetInfo.lemmaTokens.get(j);
          if(!sourceLemma.equals(targetLemma)) {
            if(derivations.containsKey(sourceLemma) &&
                derivations.get(sourceLemma).contains(targetLemma)) {
              AlignmentInterval sourceInterval =
                  new AlignmentInterval(sourceLemma, new Interval(i, i+1));
              AlignmentInterval targetInterval =
                  new AlignmentInterval(targetLemma, new Interval(j, j+1));
              substitutions.add(new AlignmentIntervalPair(sourceInterval, targetInterval));
              featureVector.add("Subst", "Deriv");
              featureVector.add("Subst", "l="+sourceLemma+",r="+targetLemma);
            }
          }
        }
      }
    }

    private void markDeletions(ParaphraseExample example) {
      markDeletion(example.sourceInfo,true);
      markDeletion(example.targetInfo,false);
    }

    private void markDeletion(LanguageInfo lInfo, boolean isSource) {

      for(int i = 0; i < lInfo.numTokens(); ++i) {
        boolean aligned=false;
        for(AlignmentIntervalPair alignmentIntervalPair: substitutions) {
          AlignmentInterval alignmentInterval = isSource ? alignmentIntervalPair.sourceInterval
              : alignmentIntervalPair.targetInterval;
          if(alignmentInterval.interval.contains(i)) {
            aligned=true;
            break;
          }
        }
        if(!aligned) {
          AlignmentInterval deletedToken = new AlignmentInterval(lInfo.lemmaTokens.get(i), new Interval(i, i+1));
          deletions.add(deletedToken);
          if(ParaphraseFeatureMatcher.containsDomain("Del")) {
            featureVector.add("Del", "lemma="+deletedToken.phrase);
            featureVector.add("Del", "pos="+LanguageUtils.getCanonicalPos(lInfo.posTags.get(i)));
          }
        }
      } 
    }

    private void computeIdentityAlignments(ParaphraseExample example) {
      for(int i = 0 ; i < example.sourceInfo.numTokens(); ++i) {
        for(int j = 0; j < example.targetInfo.numTokens(); ++j) {
          if(example.sourceInfo.lemmaTokens.get(i).equals(example.targetInfo.lemmaTokens.get(j))) {

            AlignmentInterval sourceInterval =
                new AlignmentInterval(example.sourceInfo.lemmaTokens.get(i), new Interval(i, i+1));
            AlignmentInterval targetInterval =
                new AlignmentInterval(example.targetInfo.lemmaTokens.get(j), new Interval(j, j+1));
            substitutions.add(new AlignmentIntervalPair(sourceInterval, targetInterval));
            featureVector.add("Subst", "Identity", 1);          
          }
        }
      }
    }

    /**
     * We consider all tokens and phrases with two nouns and allow substitutions if canonical pos matches.
     * @param example
     */
    private void computeSubstitutionsAlignment(ParaphraseExample example) {

      List<Interval> sourceIntervals = getLanguageInfoCandidates(example.sourceInfo);
      List<Interval> targetIntervals = getLanguageInfoCandidates(example.targetInfo);
      for(Interval sourceInterval: sourceIntervals) {
        for(Interval targetInterval: targetIntervals) {
          String sourcePhrase = example.sourceInfo.lemmaPhrase(sourceInterval.start, sourceInterval.end);
          String targetPhrase = example.targetInfo.lemmaPhrase(targetInterval.start, targetInterval.end);
          if(!sourcePhrase.equals(targetPhrase) &&
              ParaphraseUtils.posCompatible(example.sourceInfo.posTags.get(sourceInterval.start),
                      example.targetInfo.posTags.get(targetInterval.start))) {

            AlignmentInterval sourceAlignmentInterval = new AlignmentInterval(sourcePhrase, sourceInterval);
            AlignmentInterval targetAlignmentInterval = new AlignmentInterval(targetPhrase, targetInterval);
            if(validDistortion(sourceAlignmentInterval, example.sourceInfo,
                targetAlignmentInterval,example.targetInfo)) {
              substitutions.add(new AlignmentIntervalPair(sourceAlignmentInterval, targetAlignmentInterval));
              //features
              String lPos = example.sourceInfo.canonicalPosSeq(sourceInterval.start, sourceInterval.end);
              String rPos = example.targetInfo.posSeq(targetInterval.start, targetInterval.end);
              if(!lPos.equals(rPos))
                featureVector.add("Subst", "l_pos="+lPos+",r_pos="+rPos);
              else
                featureVector.add("Subst", "pos_identity");
              featureVector.add("Subst", "l="+sourcePhrase+",r="+targetPhrase);
              if(opts.useWordnet) {
                if(wnExpander.getSynonyms(sourcePhrase).contains(targetPhrase))
                  featureVector.add("Subst", "synonym");
              }
            }
          }
        }
      }
    }

    private boolean validDistortion(AlignmentInterval sourceAlignmentInterval,
        LanguageInfo sourceInfo, AlignmentInterval targetAlignmentInterval, LanguageInfo targetInfo) {

      double sourcePosition = sourceAlignmentInterval.interval.middle() / sourceInfo.numTokens();
      double targetPosition = targetAlignmentInterval.interval.middle() / targetInfo.numTokens();
      int min = Math.min(sourceInfo.numTokens(), targetInfo.numTokens());
      return (Math.abs(sourcePosition-targetPosition)<Math.max(opts.distortionParam,
          (double) 2*(1/min)));

    }

    private List<Interval> getLanguageInfoCandidates(LanguageInfo lInfo) {
      List<Interval> res = new ArrayList<Interval>();
      for(int i = 0; i < lInfo.numTokens(); ++i) {
        res.add(new Interval(i, i+1));
        if(i < lInfo.numTokens()-1 &&
            LanguageUtils.isNN(lInfo.posTags.get(i)) &&
            LanguageUtils.isNN(lInfo.posTags.get(i+1))) {
          res.add(new Interval(i,i+2));
        }
      }
      return res;
    }

    private  void computePhraseTableAlignments(ParaphraseExample example) {
      //go over all lhs spans
      for(int i = 0; i < example.sourceInfo.numTokens(); ++i) {
        for(int j = i+1; j <= example.sourceInfo.numTokens(); ++j) {
          String lhsPhrase = example.sourceInfo.lemmaPhrase(i, j);
          if(phraseTable.containsKey(lhsPhrase)) {
            Map<String,AlignmentStats> rhsCandidates = phraseTable.get(lhsPhrase);
            //get rhs
            Map<String,IntPair> rhsLemmaSpans = example.targetInfo.getLemmaSpans();
            for(String rhsLemmaSpan: rhsLemmaSpans.keySet()) {
              if(rhsCandidates.containsKey(rhsLemmaSpan)) {
                int targetStart = rhsLemmaSpans.get(rhsLemmaSpan).first;
                int targetEnd = rhsLemmaSpans.get(rhsLemmaSpan).second;
                AlignmentStats aStats = rhsCandidates.get(rhsLemmaSpan);

                //now we have the source phrase the target phrase and the intervals
                AlignmentInterval sourceInterval = new AlignmentInterval(lhsPhrase, new Interval(i, j));
                AlignmentInterval targetInterval = new AlignmentInterval(rhsLemmaSpan, new Interval(targetStart, targetEnd));
                if(validDistortion(sourceInterval, example.sourceInfo,
                    targetInterval,example.targetInfo)) {
                  substitutions.add(new AlignmentIntervalPair(sourceInterval, targetInterval));
                  //features
                  featureVector.add("Subst", "l="+lhsPhrase+",r="+rhsLemmaSpan);
                  if(ParaphraseFeatureMatcher.containsDomain("Pt")) {
                    featureVector.add("Pt","cooccurCount="+binCount(aStats.cooccurrenceCount));
                    featureVector.add("Pt","phraseCount="+binCount(aStats.phrase1Count));
                    featureVector.add("Pt","phraseCount="+binCount(aStats.phrase2Count));
                  }
                  String lPos = example.sourceInfo.canonicalPosSeq(i, j);
                  String rPos = example.targetInfo.canonicalPosSeq(targetStart, targetEnd);
                  if(lPos.equals(rPos))
                    featureVector.add("Subst","pos_identity");
                  else
                    featureVector.add("Subst", "l_pos="+lPos+",r_pos="+rPos);
                }
              }
            }
          }
        }
      }
    }

    private String binCount(double count) {
      if(count<=50)
        return "<=50";
      if(count<=100)
        return "<=100";
      if(count<=500)
        return "<=500";
      if(count<=1000)
        return "<=1000";
      if(count<=5000)
        return "<=5000";
      if(count<=10000)
        return "<=10000";
      if(count<=50000)
        return "<=50000";
      if(count<=100000)
        return "<=100000";
      if(count<=1000000)
        return "<=1000000";
      return ">1000000";
    }

    public void clear() {
      featureVector.clear();
      substitutions.clear();
      deletions.clear();
    }

    public LispTree toLispTree() {
      LispTree tree = LispTree.proto.newList();
      tree.addChild("alignment");
      tree.addChild(LispTree.proto.newList("source", ""+source));
      tree.addChild(LispTree.proto.newList("target", ""+target));
      tree.addChild(LispTree.proto.newList("alignment_score", ""+score));
      return tree;
    }

    public String toString() {
      StringBuilder sb = new StringBuilder();
      for(String key: featureVector.toMap().keySet()) {
        sb.append(key+"\t"+featureVector.toMap().get(key)+"\n");
      }
      return sb.toString();
    }
  }

  public class AlignmentInterval {
    public final String phrase;
    public final Interval interval;
    public AlignmentInterval(String phrase, Interval interval) {
      super();
      this.phrase = phrase;
      this.interval = interval;
    }
    public String toString() {
      return phrase+","+interval;
    }
  }
  public class AlignmentIntervalPair {
    public final AlignmentInterval sourceInterval;
    public final AlignmentInterval targetInterval;
    public AlignmentIntervalPair(AlignmentInterval sourceInterval,
        AlignmentInterval targetInterval) {
      super();
      this.sourceInterval = sourceInterval;
      this.targetInterval = targetInterval;
    }
    public String toString() {
      return sourceInterval+"\t"+targetInterval;
    }
  }

  public static class AlignmentStats {
    public final double cooccurrenceCount;
    public final double phrase1Count;
    public final double phrase2Count;

    public AlignmentStats(double cooccurrence, double phrase1Freq, double phrase2Freq) {
      this.cooccurrenceCount=cooccurrence;
      this.phrase1Count=phrase1Freq;
      this.phrase2Count=phrase2Freq;
    }

  }

  public static void main(String[] args) throws IOException {

    ParaphraseExample paraExample =new ParaphraseExample("what type of music did richard wagner play ?",
        "what is the musical genres of richard wagner ?",new BooleanValue(true));
    Aligner aligner = new Aligner();
    Params params = new Params();
    params.read("/Users/jonathanberant/Research/temp/params");
    Alignment alignment = aligner.align(paraExample, params);
    alignment.printFeaturesAndWeights(params)
  }
}
// TOP
//
// Related classes of edu.stanford.nlp.sempre.paraphrase.Aligner
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.