package edu.stanford.nlp.parser.dvparser;
import java.io.BufferedWriter;
import java.io.FileFilter;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.ejml.simple.SimpleMatrix;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.parser.common.ParserQuery;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.parser.lexparser.RerankingParserQuery;
import edu.stanford.nlp.trees.DeepTree;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ScoredComparator;
import edu.stanford.nlp.util.ScoredObject;
/**
* A tool which takes all of the n-grams of a certain length and looks
* for other n-grams which are close using distance between word vectors.
* Useful for coming up with interesting analysis of how the word vectors
* help the parsing task.
*
* @author John Bauer
*/
public class FindNearestNeighbors {
// TODO: parameter?
static final int numNeighbors = 5;
static final int maxLength = 8;
public static class ParseRecord {
final List<Word> sentence;
final Tree goldTree;
final Tree parse;
final SimpleMatrix rootVector;
final IdentityHashMap<Tree, SimpleMatrix> nodeVectors;
public ParseRecord(List<Word> sentence, Tree goldTree, Tree parse, SimpleMatrix rootVector, IdentityHashMap<Tree, SimpleMatrix> nodeVectors) {
this.sentence = sentence;
this.goldTree = goldTree;
this.parse = parse;
this.rootVector = rootVector;
this.nodeVectors = nodeVectors;
}
}
public static void main(String[] args) throws Exception {
String modelPath = null;
String outputPath = null;
String testTreebankPath = null;
FileFilter testTreebankFilter = null;
List<String> unusedArgs = new ArrayList<String>();
for (int argIndex = 0; argIndex < args.length; ) {
if (args[argIndex].equalsIgnoreCase("-model")) {
modelPath = args[argIndex + 1];
argIndex += 2;
} else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
testTreebankPath = treebankDescription.first();
testTreebankFilter = treebankDescription.second();
} else if (args[argIndex].equalsIgnoreCase("-output")) {
outputPath = args[argIndex + 1];
argIndex += 2;
} else {
unusedArgs.add(args[argIndex++]);
}
}
if (modelPath == null) {
throw new IllegalArgumentException("Need to specify -model");
}
if (testTreebankPath == null) {
throw new IllegalArgumentException("Need to specify -testTreebank");
}
if (outputPath == null) {
throw new IllegalArgumentException("Need to specify -output");
}
String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
LexicalizedParser lexparser = LexicalizedParser.loadModel(modelPath, newArgs);
Treebank testTreebank = null;
if (testTreebankPath != null) {
System.err.println("Reading in trees from " + testTreebankPath);
if (testTreebankFilter != null) {
System.err.println("Filtering on " + testTreebankFilter);
}
testTreebank = lexparser.getOp().tlpParams.memoryTreebank();;
testTreebank.loadPath(testTreebankPath, testTreebankFilter);
System.err.println("Read in " + testTreebank.size() + " trees for testing");
}
FileWriter out = new FileWriter(outputPath);
BufferedWriter bout = new BufferedWriter(out);
System.err.println("Parsing " + testTreebank.size() + " trees");
int count = 0;
List<ParseRecord> records = Generics.newArrayList();
for (Tree goldTree : testTreebank) {
List<Word> tokens = goldTree.yieldWords();
ParserQuery parserQuery = lexparser.parserQuery();
if (!parserQuery.parse(tokens)) {
throw new AssertionError("Could not parse: " + tokens);
}
if (!(parserQuery instanceof RerankingParserQuery)) {
throw new IllegalArgumentException("Expected a LexicalizedParser with a Reranker attached");
}
RerankingParserQuery rpq = (RerankingParserQuery) parserQuery;
if (!(rpq.rerankerQuery() instanceof DVModelReranker.Query)) {
throw new IllegalArgumentException("Expected a LexicalizedParser with a DVModel attached");
}
DeepTree tree = ((DVModelReranker.Query) rpq.rerankerQuery()).getDeepTrees().get(0);
SimpleMatrix rootVector = null;
for (Map.Entry<Tree, SimpleMatrix> entry : tree.getVectors().entrySet()) {
if (entry.getKey().label().value().equals("ROOT")) {
rootVector = entry.getValue();
break;
}
}
if (rootVector == null) {
throw new AssertionError("Could not find root nodevector");
}
out.write(tokens + "\n");
out.write(tree.getTree() + "\n");
for (int i = 0; i < rootVector.getNumElements(); ++i) {
out.write(" " + rootVector.get(i));
}
out.write("\n\n\n");
count++;
if (count % 10 == 0) {
System.err.print(" " + count);
}
records.add(new ParseRecord(tokens, goldTree, tree.getTree(), rootVector, tree.getVectors()));
}
System.err.println(" done parsing");
List<Pair<Tree, SimpleMatrix>> subtrees = Generics.newArrayList();
for (int i = 0; i < records.size(); ++i) {
ParseRecord record = records.get(i);
for (Map.Entry<Tree, SimpleMatrix> entry : record.nodeVectors.entrySet()) {
if (entry.getKey().getLeaves().size() <= maxLength) {
subtrees.add(Pair.makePair(entry.getKey(), entry.getValue()));
}
}
}
System.err.println("There are " + subtrees.size() + " subtrees in the set of trees");
PriorityQueue<ScoredObject<Pair<Tree, Tree>>> bestmatches = new PriorityQueue<ScoredObject<Pair<Tree, Tree>>>(101, ScoredComparator.DESCENDING_COMPARATOR);
for (int i = 0; i < subtrees.size(); ++i) {
System.err.println(subtrees.get(i).first().yieldWords());
System.err.println(subtrees.get(i).first());
for (int j = 0; j < subtrees.size(); ++j) {
if (i == j) {
continue;
}
// TODO: look at basic category?
double normF = subtrees.get(i).second().minus(subtrees.get(j).second()).normF();
bestmatches.add(new ScoredObject<Pair<Tree, Tree>>(Pair.makePair(subtrees.get(i).first(), subtrees.get(j).first()), normF));
if (bestmatches.size() > 100) {
bestmatches.poll();
}
}
List<ScoredObject<Pair<Tree, Tree>>> ordered = Generics.newArrayList();
while (bestmatches.size() > 0) {
ordered.add(bestmatches.poll());
}
Collections.reverse(ordered);
for (ScoredObject<Pair<Tree, Tree>> pair : ordered) {
System.err.println(" MATCHED " + pair.object().second.yieldWords() + " ... " + pair.object().second() + " with a score of " + pair.score());
}
System.err.println();
System.err.println();
bestmatches.clear();
}
/*
for (int i = 0; i < records.size(); ++i) {
if (i % 10 == 0) {
System.err.print(" " + i);
}
List<ScoredObject<ParseRecord>> scored = Generics.newArrayList();
for (int j = 0; j < records.size(); ++j) {
if (i == j) continue;
double score = 0.0;
int matches = 0;
for (Map.Entry<Tree, SimpleMatrix> first : records.get(i).nodeVectors.entrySet()) {
for (Map.Entry<Tree, SimpleMatrix> second : records.get(j).nodeVectors.entrySet()) {
String firstBasic = dvparser.dvModel.basicCategory(first.getKey().label().value());
String secondBasic = dvparser.dvModel.basicCategory(second.getKey().label().value());
if (firstBasic.equals(secondBasic)) {
++matches;
double normF = first.getValue().minus(second.getValue()).normF();
score += normF * normF;
}
}
}
if (matches == 0) {
score = Double.POSITIVE_INFINITY;
} else {
score = score / matches;
}
//double score = records.get(i).vector.minus(records.get(j).vector).normF();
scored.add(new ScoredObject<ParseRecord>(records.get(j), score));
}
Collections.sort(scored, ScoredComparator.ASCENDING_COMPARATOR);
out.write(records.get(i).sentence.toString() + "\n");
for (int j = 0; j < numNeighbors; ++j) {
out.write(" " + scored.get(j).score() + ": " + scored.get(j).object().sentence + "\n");
}
out.write("\n\n");
}
System.err.println();
*/
bout.flush();
out.flush();
out.close();
}
}