Package edu.stanford.nlp.sempre

Source Code of edu.stanford.nlp.sempre.BridgeFn

package edu.stanford.nlp.sempre;

import edu.stanford.nlp.sempre.FbFormulasInfo.BinaryFormulaInfo;
import edu.stanford.nlp.sempre.MergeFormula.Mode;
import fig.basic.LispTree;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.IOUtils;

import java.io.IOException;
import java.util.*;

/**
* Bridge between two derivations by type-raising one of them.
*
* @author jonathanberant
*/
public class BridgeFn extends SemanticFn {

  private static final Formula intFormula = Formulas.fromLispTree(LispTree.proto.parseFromString("(fb:type.object.type fb:type.int)"));
  private static final Formula floatFormula = Formulas.fromLispTree(LispTree.proto.parseFromString("(fb:type.object.type fb:type.float)"));

  public static class Options {
    @Option(gloss = "Verbose") public int verbose = 0;
    @Option(gloss = "Whether to allow entity bridging with no binary string match") public boolean looseEntBridge = false;
    @Option(gloss = "List of binary formulas to use during bridging") public String binariesFile = "";
  }

  public static Options opts = new Options();

  private FbFormulasInfo fbFormulaInfo = null;
  private String description;
  private boolean headFirst;
  private TextToTextMatcher textToTextMatcher;
  private static HashSet<Formula> binariesToUse;

  public void init(LispTree tree) {
    super.init(tree);
    if (tree.children.size() != 3)
      throw new RuntimeException("Number of children is: " + tree.children.size());
    if (!tree.child(2).value.equals("headFirst") && !tree.child(2).value.equals("headLast"))
      throw new RuntimeException("Bad argument for head position: " + tree.child(2).value);
    if (!tree.child(1).value.equals("unary") && !tree.child(1).value.equals("inject") && !tree.child(1).value.equals("entity"))
      throw new RuntimeException("Bad description: " + tree.child(1).value);

    this.description = tree.child(1).value;
    headFirst = tree.child(2).value.equals("headFirst");
  }

  public BridgeFn() {
    fbFormulaInfo = FbFormulasInfo.getSingleton();
    textToTextMatcher = TextToTextMatcher.getSingleton();
    binariesToUse = new HashSet<Formula>();
    if (!opts.binariesFile.equals("")) {
      readBinaries();
    }
  }

  public static void readBinaries() {
    for (String line : IOUtils.readLinesHard(opts.binariesFile)) {
      if (line.startsWith("#")) continue;
      if (line.equals("")) continue;
      binariesToUse.add(Formula.fromString(line));
    }
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) return true;
    if (o == null || getClass() != o.getClass()) return false;
    BridgeFn bridgeFn = (BridgeFn) o;
    if (headFirst != bridgeFn.headFirst) return false;
    if (!description.equals(bridgeFn.description)) return false;
    return true;
  }

  @Override
  public List<Derivation> call(Example ex, Callable c) {
    try {
      if (description.equals("unary")) {
        return bridgeUnary(ex, c);
      } else if (description.equals("inject")) {
        return injectIntoCvt(ex, c);
      } else if (description.equals("entity")) {
        return bridgeEntity(ex, c);
      } else {
        throw new RuntimeException("Invalid (expected unary, inject, or entity): " + description);
      }
    } catch (Exception e) {
      e.printStackTrace();
      throw new RuntimeException(e);
    }
  }

  private boolean isCvt(Derivation headDeriv) {
    if (!(headDeriv.formula instanceof JoinFormula))
      return false;
    JoinFormula join = (JoinFormula) headDeriv.formula;
    if (join.relation instanceof LambdaFormula)
      return true;
    if (join.child instanceof JoinFormula || join.child instanceof MergeFormula)
      return true;
    return false;
  }

  // Return all the entity supertypes of |type|.
  // TODO: make this more efficient.
  private Set<String> getSupertypes(SemType type, Set<String> supertypes) {
    if (type instanceof EntitySemType)
      supertypes.addAll(fbFormulaInfo.getIncludedTypesInclusive(((EntitySemType) type).name));
    else if (type instanceof UnionSemType)
      for (SemType baseType : ((UnionSemType) type).baseTypes)
        getSupertypes(baseType, supertypes);
    else {
      // FIXME HACK for when passing binary into lambda formula and
      // getSuperTypes doesn't work
      getSupertypes(SemType.fromString("topic"), supertypes);
      //throw new RuntimeException("Unexpected type (must be unary): " + type);
    }
    return supertypes;
  }

  private List<Derivation> bridgeUnary(Example ex, Callable c) throws IOException {

    List<Derivation> res = new ArrayList<Derivation>();

    // Example: modifier[Hanks] head[movies]
    Derivation headDeriv = headFirst ? c.child(0) : c.child(1);
    Derivation modifierDeriv = !headFirst ? c.child(0) : c.child(1);

    Set<String> headTypes = getSupertypes(headDeriv.type, new HashSet<String>());
    Set<String> modifierTypes = getSupertypes(modifierDeriv.type, new HashSet<String>());

    HashSet<String> usedBinaries = new HashSet<String>();

    for (String modifierType : modifierTypes) { // For each head type...
      Set<Formula> binaries;
      if (binariesToUse.size() == 0)
        binaries = fbFormulaInfo.getBinariesForType2(modifierType);
      else
        binaries = binariesToUse;

      for (Formula binary : binaries) {  // For each possible binary...

        if (usedBinaries.contains(binary.toString()))
          continue;

        BinaryFormulaInfo binaryInfo = fbFormulaInfo.getBinaryInfo(binary);

        if (opts.verbose >= 3)
          LogInfo.logs("%s => %s", modifierType, binary);

        if (headTypes.contains(binaryInfo.expectedType1)) {

          Formula bridgedFormula = buildBridge(headDeriv.formula, modifierDeriv.formula, binary);
          FbFormulasInfo.touchBinaryFormula(binary);
          usedBinaries.add(binary.toString());

          Derivation newDeriv = new Derivation.Builder()
              .withCallable(c)
              .formula(bridgedFormula)
              .type(headDeriv.type)
              .createDerivation();

          if (SemanticFn.opts.trackLocalChoices) {
            newDeriv.localChoices.add(
                String.format(
                    "BridgeFn: %s %s --> %s %s --> %s %s",
                    headDeriv.startEndString(ex.getTokens()), headDeriv.formula,
                    ex.getTokens().subList(c.child(0).end, c.child(1).start),
                    binary,
                    modifierDeriv.startEndString(ex.getTokens()), modifierDeriv.formula));
          }

          // Add features
          if (ex != null) {
            newDeriv.addFeature("BridgeFn", "unary");

            // Popularity of the binary
            newDeriv.addFeatureWithBias("BridgeFn", "popularity", Math.log(binaryInfo.popularity + 1));

            // head modifier POS tags
            String headModifierOrder = headFirst ? "head-modifier" : "modifier-head";
            newDeriv.addFeature("BridgeFn",
                "order=" + headModifierOrder + "," +
                "pos=" +
                ex.languageInfo.getCanonicalPos(headDeriv.start) + "-" +
                ex.languageInfo.getCanonicalPos(modifierDeriv.start));

            addBinaryMatchFeatures(ex, modifierDeriv, binary, newDeriv); //HACKY
            List<List<String>> exampleInfo = generateExampleInfo(ex, c); //this is not done in text to text matcher so done here

            FeatureVector vector = textToTextMatcher.extractFeatures(
                exampleInfo.get(0), exampleInfo.get(1), exampleInfo.get(2),
                new HashSet<String>(binaryInfo.descriptions));
            newDeriv.addFeatures(vector);
          }
          res.add(newDeriv);
        }
      }
    }
    return res;
  }

  //see whether the binary bridge can also be retrieved with the alignment lexicon
  private void addBinaryMatchFeatures(Example ex, Derivation modifierDeriv, Formula binary, Derivation newDeriv) throws IOException {
    if (!FeatureExtractor.containsDomain("bridgeBinaryMatch")) return;

    for (int i = 0; i < ex.languageInfo.lemmaTokens.size(); ++i) {
      if (i >= modifierDeriv.start && i < modifierDeriv.end)
        continue;
      String pos = ex.languageInfo.posTags.get(i);
      String lemma = ex.languageInfo.lemmaTokens.get(i);
      if (pos.startsWith("NN") || (pos.startsWith("VB") && !pos.equals("VBD-AUX")) || pos.equals("JJ") || pos.equals("IN")) {

        LexiconFn fn = new LexiconFn();
        fn.init(LispTree.proto.parseFromString("(LexiconFn binary)"));

        Derivation child = new Derivation.Builder()
            .cat("$CompositeRel").start(i).end(i + 1)
            .children(new ArrayList<Derivation>())
            .withStringFormulaFrom(lemma)
            .createDerivation();
        CallInfo c = new CallInfo("$Binary", i, i + 1, null, Collections.singletonList(child));
        List<Derivation> derivations = fn.call(ex, c);
        for (Derivation deriv : derivations) {
          if (deriv.formula.equals(binary) || deriv.formula.equals(FbFormulasInfo.reverseFormula(binary))) {
            if (opts.verbose >= 1) {
              LogInfo.logs("BridgeFn: lemma %s matched bridged binary %s", lemma, binary);
            }
            newDeriv.addFeatures(deriv);
            return;
          }
        }
      }
    }
  }

  //bridge without a unary - simply by looking at binaries leading to the entity and string matching binary description to example tokens/lemmas/stems
  private List<Derivation> bridgeEntity(Example ex, Callable c) throws IOException {

    List<Derivation> res = new ArrayList<Derivation>();
    Derivation modifierDeriv = c.child(0);

    Set<String> modifierTypes = getSupertypes(modifierDeriv.type, new HashSet<String>());

    for (String modifierType : modifierTypes) { // For each head type...
      Set<Formula> binaries = fbFormulaInfo.getBinariesForType2(modifierType);
      for (Formula binary : binaries) {  // For each possible binary...
        BinaryFormulaInfo binaryInfo = fbFormulaInfo.getBinaryInfo(binary);

        if (opts.verbose >= 3)
          LogInfo.logs("%s => %s", modifierType, binary);

        List<List<String>> exampleInfo = generateExampleInfo(ex, c); //this is not done in text to text matcher so done here
        if (textToTextMatcher.existsTokenMatch(exampleInfo.get(0), exampleInfo.get(2), new HashSet<String>(binaryInfo.descriptions))
            || opts.looseEntBridge) {
          Formula join = new JoinFormula(binary, modifierDeriv.formula);
          FbFormulasInfo.touchBinaryFormula(binary);         

          Derivation newDeriv = new Derivation.Builder()
              .withCallable(c)
              .formula(join)
              .type(new EntitySemType(binaryInfo.expectedType1))
              .createDerivation();
          if (opts.verbose >= 2)
            LogInfo.logs("BridgeStringFn: %s", join);

          //features
          if (ex != null) {
            newDeriv.addFeature("BridgeFn", "entity");
            newDeriv.addFeatureWithBias("BridgeFn", "popularity", Math.log(binaryInfo.popularity + 1));

            addBinaryMatchFeatures(ex, modifierDeriv, binary, newDeriv); //HACKY
            FeatureVector textMatchFeatures = textToTextMatcher.extractFeatures(
                exampleInfo.get(0), exampleInfo.get(1), exampleInfo.get(2),
                new HashSet<String>(binaryInfo.descriptions));
            newDeriv.addFeatures(textMatchFeatures);

            //addTokenMatchFeatures(tokenStemFeatures.first(), newDeriv, "binary_token");
            //addTokenMatchFeatures(tokenStemFeatures.second(), newDeriv, "binary_stem");          
            //addWordSimilarityFeatures(ex, newDeriv, binaryInfo); // (1) edit distance (2) word similarity
          }
          //newDeriv.localFeatureVector.add("bridge_lex_"+binaryInfo.expectedType1+"-->"+binary); //causes overfitting with 300 training examples
          res.add(newDeriv);
        }
      }
    }
    return res;
  }

  //generate from example array of content word tokens/lemmas/stems that are not dominated by child derivations
  private List<List<String>> generateExampleInfo(Example ex, Callable c) {

    List<String> tokens = new ArrayList<String>();
    List<String> posTags = new ArrayList<String>();
    List<String> lemmas = new ArrayList<String>();
    List<List<String>> res = new ArrayList<List<String>>();
    res.add(tokens);
    res.add(posTags);
    res.add(lemmas);

    Derivation modifierDeriv = headFirst ? c.child(1) : c.child(0);

    for (int i = 0; i < ex.languageInfo.tokens.size(); ++i) {
      if (i >= modifierDeriv.start && i < modifierDeriv.end) { //do not consider the modifier words {
        continue;
      }
      tokens.add(ex.languageInfo.tokens.get(i));
      posTags.add(ex.languageInfo.posTags.get(i));
      lemmas.add(ex.languageInfo.lemmaTokens.get(i));
    }
    return res;
  }

  private List<Derivation> injectIntoCvt(Example ex, Callable c) {
    List<Derivation> res = new ArrayList<Derivation>();

    // Example: modifier[Braveheart] head[Mel Gibson plays in]
    Derivation headDeriv = headFirst ? c.child(0) : c.child(1);
    if (!isCvt(headDeriv)) //only works on cvts
      return res;

    Derivation modifierDeriv = !headFirst ? c.child(0) : c.child(1);
    JoinFormula headFormula = (JoinFormula)Formulas.betaReduction(headDeriv.formula);
    //find the type of the cvt node
    Set<String> headTypes = Collections.singleton(fbFormulaInfo.getBinaryInfo(headFormula.relation).expectedType2);
    Set<String> modifierTypes = getSupertypes(modifierDeriv.type, new HashSet<String>());

    for (String modifierType : modifierTypes) {
      Set<Formula> binaries = fbFormulaInfo.getAtomicBinariesForType2(modifierType); //here we use atomic binaries since we inject into a CVT
      for (Formula binary : binaries) {  // For each possible binary...
        BinaryFormulaInfo info = fbFormulaInfo.getBinaryInfo(binary);

        if (headTypes.contains(info.expectedType1)) {
          Formula bridgedFormula = buildBridgeFromCvt(headFormula, modifierDeriv.formula, binary);
          FbFormulasInfo.touchBinaryFormula(binary);
          Derivation newDeriv = new Derivation.Builder()
              .withCallable(c)
              .formula(bridgedFormula)
              .type(headDeriv.type)
              .createDerivation();
          if (opts.verbose >= 3)
            LogInfo.logs("BridgeFn: injecting %s to %s --> %s ", modifierDeriv.formula, headFormula, bridgedFormula);

          if (ex != null) {
            String headModifierOrder = headFirst ? "head-modifier" : "modifier-head";
            newDeriv.addFeature("BridgeFn",
                "inject_order=" + headModifierOrder + "," + "pos=" +
                ex.languageInfo.getCanonicalPos(headDeriv.start) + "-" +
                ex.languageInfo.getCanonicalPos(modifierDeriv.start));
          }
          newDeriv.addFeature("BridgeFn", "binary=" + binary);

          res.add(newDeriv);
        }
      }
    }
    return res;
  }

  // Checks whether "var" is used as a binary in "formula"
  private boolean varIsBinary(Formula formula, String var) {
    boolean isBinary = false;
    LispTree tree = formula.toLispTree();
    VariableFormula vf = new VariableFormula(var);
    for (LispTree child : tree.children) {
      if (child.isLeaf())
        continue;
      if (child.children.size() == 2 && vf.equals(Formulas.fromLispTree(child.child(0)))) {
        isBinary = true;
        break;
      }
      if (varIsBinary(Formulas.fromLispTree(child), var)) {
        isBinary = true;
        break;
      }
    }
    return isBinary;
  }

  private Formula buildBridge(Formula headFormula, Formula modifierFormula, Formula binary) {
    // Handle cases like "what state has the most cities" where "has the" is mapped
    // to "contains" predicate via bridging but "most" triggers a nested lambda w/
    // argmax on a count relation
    // (Corresponds to $MetaMetaOperator in grammar)
    if (modifierFormula instanceof LambdaFormula) {
      LambdaFormula lf = (LambdaFormula) modifierFormula;
      if (varIsBinary(lf, lf.var)) {
        Formula newBinary = Formulas.lambdaApply(lf, binary);
        if (newBinary instanceof LambdaFormula) {
          Formula result = Formulas.lambdaApply((LambdaFormula) newBinary, headFormula);
          return result;
        }
      }
    }

    Formula join = new JoinFormula(binary, modifierFormula);
    Formula merge = new MergeFormula(Mode.and, headFormula, join);
    //Don't merge on ints and floats
    return (headFormula.equals(intFormula) || headFormula.equals(floatFormula)) ? join : merge;
  }

  private Formula buildBridgeFromCvt(JoinFormula headFormula, Formula modifierFormula, Formula binary) {
    Formula join = new JoinFormula(binary, modifierFormula);
    Formula merge = new MergeFormula(Mode.and, headFormula.child, join);
    return new JoinFormula(headFormula.relation, merge);
  }
}
TOP

Related Classes of edu.stanford.nlp.sempre.BridgeFn

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.