package edu.stanford.nlp.parser.dvparser;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import org.ejml.simple.SimpleMatrix;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.Reranker;
import edu.stanford.nlp.parser.lexparser.RerankerQuery;
import edu.stanford.nlp.parser.metrics.Eval;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Trees;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.util.Generics;
/**
 * A {@link Reranker} backed by several {@link DVModel}s at once.  The score it
 * assigns to a tree is the sum of the scores produced for that tree by a
 * {@link DVParserCostAndGradient} built around each of the supplied models.
 */
public class CombinedDVModelReranker implements Reranker {
  /** Parser options shared by the transformer and by every scorer. */
  private final Options op;

  /** The models whose scores are combined; stored by reference, not copied. */
  private final List<DVModel> models;

  /**
   * @param op parser options handed to the tree transformer and to each scorer
   * @param models the models to combine; the list is kept as-is (no defensive copy)
   */
  public CombinedDVModelReranker(Options op, List<DVModel> models) {
    this.op = op;
    this.models = models;
  }

  /**
   * Creates a fresh {@link Query}.  The {@code sentence} argument is not
   * consulted here; the returned query depends only on this reranker's
   * options and models.
   *
   * @param sentence the sentence about to be reranked (unused)
   * @return a new query object
   */
  public Query process(List<? extends HasWord> sentence) {
    return new Query();
  }

  /**
   * @return always an empty list; this reranker contributes no evaluations
   */
  public List<Eval> getEvals() {
    return Collections.emptyList();
  }

  /**
   * Scores candidate trees with one {@link DVParserCostAndGradient} per model
   * of the enclosing reranker.
   */
  public class Query implements RerankerQuery {
    /** Transformer obtained from {@link LexicalizedParser#buildTrainTransformer}. */
    private final TreeTransformer transformer;

    /** One scorer per model, in the iteration order of the enclosing model list. */
    private final List<DVParserCostAndGradient> scorers;

    /**
     * Builds the transformer first and then one scorer for each model.  The
     * first two constructor arguments of {@link DVParserCostAndGradient} are
     * passed as {@code null}.
     */
    public Query() {
      this.transformer = LexicalizedParser.buildTrainTransformer(op);
      List<DVParserCostAndGradient> perModelScorers = Generics.newArrayList();
      for (DVModel dvModel : models) {
        perModelScorers.add(new DVParserCostAndGradient(null, null, dvModel, op));
      }
      this.scorers = perModelScorers;
    }

    /**
     * Returns the sum, over all scorers, of the score each one assigns to the
     * transformed version of {@code tree}.  The tree is transformed anew for
     * every scorer, so no scorer sees a tree object handed to another one.
     *
     * @param tree the candidate tree to score
     * @return the summed score; {@code 0.0} when there are no scorers
     */
    public double score(Tree tree) {
      double combined = 0.0;
      java.util.Iterator<DVParserCostAndGradient> it = scorers.iterator();
      while (it.hasNext()) {
        combined += scoreWith(it.next(), tree);
      }
      return combined;
    }

    /**
     * Scores {@code tree} with a single scorer, using a freshly transformed
     * tree and a fresh, empty node-vector map.
     */
    private double scoreWith(DVParserCostAndGradient scorer, Tree tree) {
      IdentityHashMap<Tree, SimpleMatrix> vectorsByNode = Generics.newIdentityHashMap();
      Tree prepared = prepareTree(tree);
      return scorer.score(prepared, vectorsByNode);
    }

    /**
     * Runs the transformer over {@code tree}.  When the training options ask
     * for context words, the result is additionally converted to core labels
     * and has its spans set before being returned.
     */
    private Tree prepareTree(Tree tree) {
      Tree prepared = transformer.transformTree(tree);
      if (!op.trainOptions.useContextWords) {
        return prepared;
      }
      Trees.convertToCoreLabels(prepared);
      prepared.setSpans();
      return prepared;
    }
  }
}