Package edu.stanford.nlp.parser.dvparser

Source Code of edu.stanford.nlp.parser.dvparser.DVModelReranker$Query

package edu.stanford.nlp.parser.dvparser;

import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;

import org.ejml.simple.SimpleMatrix;

import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.Reranker;
import edu.stanford.nlp.parser.lexparser.RerankerQuery;
import edu.stanford.nlp.parser.metrics.Eval;
import edu.stanford.nlp.trees.DeepTree;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Trees;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.util.Generics;

public class DVModelReranker implements Reranker {
  private final Options op;
  private final DVModel model;
 
  public DVModelReranker(DVModel model) {
    this.op = model.op;
    this.model = model;
  }

  DVModel getModel() {
    return model;
  }

  public Query process(List<? extends HasWord> sentence) {
    return new Query();
  }

  public List<Eval> getEvals() {
    Eval eval = new UnknownWordPrinter(model);
    return Collections.singletonList(eval);
  }

  public class Query implements RerankerQuery {
    private final TreeTransformer transformer;
    private final DVParserCostAndGradient scorer;

    private List<DeepTree> deepTrees;

    public Query() {
      this.transformer = LexicalizedParser.buildTrainTransformer(op);
      this.scorer = new DVParserCostAndGradient(null, null, model, op);
      this.deepTrees = Generics.newArrayList();
    }
   
    public double score(Tree tree) {
      IdentityHashMap<Tree, SimpleMatrix> nodeVectors = Generics.newIdentityHashMap();
      Tree transformedTree = transformer.transformTree(tree);
      if (op.trainOptions.useContextWords) {
        Trees.convertToCoreLabels(transformedTree);
        transformedTree.setSpans();
      }
      double score = scorer.score(transformedTree, nodeVectors);
      deepTrees.add(new DeepTree(tree, nodeVectors, score));
      return score;
    }

    public List<DeepTree> getDeepTrees() {
      return deepTrees;
    }
  } 

  private static final long serialVersionUID = 7897546308624261207L;
}
TOP

Related Classes of edu.stanford.nlp.parser.dvparser.DVModelReranker$Query

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.