Package edu.stanford.nlp.sempre

Source Code of edu.stanford.nlp.sempre.Formulas

package edu.stanford.nlp.sempre;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import fig.basic.LispTree;

import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
* Utilities for working with Formulas.
*
* @author Percy Liang
*/
public abstract class Formulas {
  public static Formula fromLispTree(LispTree tree) {
    // Primitive predicate
    if (tree.isLeaf()) return new ValueFormula<NameValue>(new NameValue(tree.value, null));

    String func = tree.child(0).value;

    if (func != null) {
      if (func.equals("boolean") || func.equals("number") || func.equals("string") || func.equals("date"))
        return new ValueFormula<Value>(Values.fromLispTree(tree));
      if (func.equals("var"))
        return new VariableFormula(tree.child(1).value);
      if (func.equals("lambda"))
        return new LambdaFormula(tree.child(1).value, fromLispTree(tree.child(2)));
      if (func.equals("mark"))
        return new MarkFormula(tree.child(1).value, fromLispTree(tree.child(2)));
      if (func.equals("not"))
        return new NotFormula(fromLispTree(tree.child(1)));
      if (func.equals("reverse"))
        return new ReverseFormula(fromLispTree(tree.child(1)));
      if (func.equals("call")) {
        Formula callFunc = fromLispTree(tree.child(1));
        List<Formula> args = Lists.newArrayList();
        for (int i = 2; i < tree.children.size(); i++)
          args.add(fromLispTree(tree.child(i)));
        return new CallFormula(callFunc, args);
      }
    }

    { // Merge: (intersect (fb:type.object.type fb:people.person) (fb:people.person.children fb:en.barack_obama))
      MergeFormula.Mode mode = MergeFormula.parseMode(func);
      if (mode != null)
        return new MergeFormula(mode, fromLispTree(tree.child(1)), fromLispTree(tree.child(2)));
    }

    { // Aggregate: (count (fb:type.object.type fb:people.person))
      AggregateFormula.Mode mode = AggregateFormula.parseMode(func);
      if (mode != null)
        return new AggregateFormula(mode, fromLispTree(tree.child(1)));
    }

    { // Superlative: (argmax 1 1 (fb:type.object.type fb:people.person) (lambda x (!fb:people.person.height_meters (var x))))
      SuperlativeFormula.Mode mode = SuperlativeFormula.parseMode(func);
      if (mode != null) {
        return new SuperlativeFormula(
            mode,
            Integer.parseInt(tree.child(1).value),
            Integer.parseInt(tree.child(2).value),
            fromLispTree(tree.child(3)),
            fromLispTree(tree.child(4)));
      }
    }

    // Default is join: (fb:type.object.type fb:people.person)
    if (tree.children.size() != 2)
      throw new RuntimeException("Invalid number of arguments for join (want 2): " + tree);
    return new JoinFormula(fromLispTree(tree.child(0)), fromLispTree(tree.child(1)));
  }

  // Replace occurrences of the variable reference |var| with |formula|.
  public static Formula substituteVar(Formula formula, final String var, final Formula replaceFormula) {
    return formula.map(
        new Function<Formula, Formula>() {
          public Formula apply(Formula formula) {
            if (formula instanceof VariableFormula) {  // Replace variable
              String name = ((VariableFormula) formula).name;
              return var.equals(name) ? replaceFormula : formula;
            } else if (formula instanceof LambdaFormula) {
              if (((LambdaFormula) formula).var.equals(var)) // |var| is bound, so don't substitute inside
                return formula;
            }
            return null;
          }
        });
  }

  // Beta-reduction.
  public static Formula lambdaApply(LambdaFormula func, Formula arg) {
    return substituteVar(func.body, func.var, arg);
  }

  // Apply all the nested LambdaFormula's.
  public static Formula betaReduction(Formula formula) {
    return formula.map(
        new Function<Formula, Formula>() {
          public Formula apply(Formula formula) {
            if (formula instanceof JoinFormula) {
              Formula relation = betaReduction(((JoinFormula) formula).relation);
              Formula child = ((JoinFormula) formula).child;
              if (relation instanceof LambdaFormula)
                return betaReduction(lambdaApply((LambdaFormula) relation, child));
            }
            return null;
          }
        });
  }

  // Return whether |formula| contains a free instance of |var|.
  public static boolean containsFreeVar(Formula formula, VariableFormula var) {
    if (formula instanceof PrimitiveFormula)
      return formula.equals(var);
    if (formula instanceof MergeFormula) {
      MergeFormula merge = (MergeFormula)formula;
      return containsFreeVar(merge.child1, var) || containsFreeVar(merge.child2, var);
    }
    if (formula instanceof JoinFormula) {
      JoinFormula join = (JoinFormula)formula;
      return containsFreeVar(join.relation, var) || containsFreeVar(join.child, var);
    }
    if (formula instanceof LambdaFormula) {
      LambdaFormula lambda = (LambdaFormula)formula;
      if (lambda.var.equals(var.name)) return false// Blocked by bound variable
      return containsFreeVar(lambda.body, var);
    }
    if (formula instanceof MarkFormula) {
      MarkFormula mark = (MarkFormula)formula;
      // Note: marks are transparent, unlike lambdas
      return containsFreeVar(mark.body, var);
    }
    if (formula instanceof ReverseFormula) {
      return containsFreeVar(((ReverseFormula)formula).child, var);
    }
    if (formula instanceof AggregateFormula) {
      return containsFreeVar(((AggregateFormula)formula).child, var);
    }
    if (formula instanceof SuperlativeFormula) {
      SuperlativeFormula superlative = (SuperlativeFormula)formula;
      return containsFreeVar(superlative.head, var) || containsFreeVar(superlative.relation, var);
    }
    if (formula instanceof NotFormula) {
      NotFormula notForm = (NotFormula) formula;
      return containsFreeVar(notForm.child, var);
    }
    throw new RuntimeException("Unhandled: " + formula);
  }
   
  // TODO: use Formula.map
  public static Set<String> extractAtomicFreebaseElements(Formula formula) {
    Set<String> res = new HashSet<String>();
    LispTree formulaTree = formula.toLispTree();
    extractAtomicFreebaseElements(formulaTree, res);
    return res;
  }
  private static void extractAtomicFreebaseElements(LispTree formulaTree,
                                                    Set<String> res) {
    //base
    if (formulaTree.isLeaf()) {
      if (formulaTree.value.startsWith("fb:"))
        res.add(formulaTree.value);
      else if (formulaTree.value.startsWith("!fb:"))
        res.add(formulaTree.value.substring(1));
    }
    //recursion
    else {
      for (LispTree child : formulaTree.children) {
        extractAtomicFreebaseElements(child, res);
      }
    }
  }

  // TODO: remove
  public static boolean isCountFormula(Formula formula) {
    if (formula instanceof AggregateFormula)
      return ((AggregateFormula) formula).mode == AggregateFormula.Mode.count;
    if (formula instanceof JoinFormula) {
      Formula relation = ((JoinFormula) formula).relation;
      if (relation instanceof LambdaFormula) {
        Formula l = ((LambdaFormula) relation).body;
        if (l instanceof AggregateFormula)
          return ((AggregateFormula) l).mode == AggregateFormula.Mode.count;
      }
    }
    return false;
  }

  public static String getString(Formula formula) {
    if (formula instanceof ValueFormula) {
      Value value = ((ValueFormula)formula).value;
      if (value instanceof StringValue)
        return ((StringValue)value).value;
    }
    return null;
  }
  public static String getNameId(Formula formula) {
    if (formula instanceof ValueFormula) {
      Value value = ((ValueFormula)formula).value;
      if (value instanceof NameValue)
        return ((NameValue)value).id;
    }
    return null;
  }

  public static ValueFormula<NameValue> newNameFormula(String id) {
    return new ValueFormula<NameValue>(new NameValue(id));
  }
 
  /*
   * Extract all subformulas in a string format (to also have primitive values)
   */
  public static Set<String> extractSubparts(Formula f) {
    Set<String> res = new HashSet<String>();
    extractSubpartsRecursive(f,res);
    return res;
  }

  private static void extractSubpartsRecursive(Formula f, Set<String> res) {
    //base
    res.add(f.toString());
    //recurse
    if(f instanceof AggregateFormula) {
      AggregateFormula aggFormula = (AggregateFormula) f;
      extractSubpartsRecursive(aggFormula, res);
    }
    else if(f instanceof CallFormula) {
      CallFormula callFormula = (CallFormula) f;
      extractSubpartsRecursive(callFormula.func,res);
      for(Formula argFormula: callFormula.args)
        extractSubpartsRecursive(argFormula, res);    
    }
    else if(f instanceof JoinFormula) {
      JoinFormula joinFormula = (JoinFormula) f;
      extractSubpartsRecursive(joinFormula.relation,res);
      extractSubpartsRecursive(joinFormula.child,res);
    }
    else if(f instanceof LambdaFormula) {
      LambdaFormula lambdaFormula = (LambdaFormula) f;
      extractSubpartsRecursive(lambdaFormula.body, res);
    }
    else if(f instanceof MarkFormula) {
      MarkFormula markFormula = (MarkFormula) f;
      extractSubpartsRecursive(markFormula.body, res);
    }
    else if(f instanceof MergeFormula) {
      MergeFormula mergeFormula = (MergeFormula) f;
      extractSubpartsRecursive(mergeFormula.child1, res);
      extractSubpartsRecursive(mergeFormula.child2, res);
    }
    else if(f instanceof NotFormula) {
      NotFormula notFormula = (NotFormula) f;
      extractSubpartsRecursive(notFormula.child, res);
    }
    else if(f instanceof ReverseFormula) {
      ReverseFormula revFormula = (ReverseFormula) f;
      extractSubpartsRecursive(revFormula.child,res);
    }
    else if(f instanceof SuperlativeFormula) {
      SuperlativeFormula superlativeFormula = (SuperlativeFormula) f;
      extractSubpartsRecursive(superlativeFormula.head, res);
      extractSubpartsRecursive(superlativeFormula.relation, res);
    }
  }
}
TOP

Related Classes of edu.stanford.nlp.sempre.Formulas

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.