package edu.stanford.nlp.parser.dvparser;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileFilter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Reader;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import org.ejml.simple.SimpleMatrix;
import edu.stanford.nlp.io.FileSystem;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.Sentence;
import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.parser.common.ParserQuery;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.parser.lexparser.RerankerQuery;
import edu.stanford.nlp.parser.lexparser.RerankingParserQuery;
import edu.stanford.nlp.process.DocumentPreprocessor;
import edu.stanford.nlp.trees.DeepTree;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
/**
 * Command-line tool which parses the sentences of a text file with a
 * {@link LexicalizedParser} and writes, for each sentence, a text file
 * containing the sentence, its tree, one word vector per word and one vector
 * per phrasal node of the tree.
 * <br>
 * Recognized arguments:
 * <ul>
 * <li><code>-model</code> path of the parser model to load (required)</li>
 * <li><code>-output</code> directory to create and write the
 *     <code>sentenceN.txt</code> files into (required; must not exist yet)</li>
 * <li><code>-input</code> text file with the sentences to parse</li>
 * <li><code>-testTreebank</code> is parsed and consumed, but its value is not
 *     used by this tool</li>
 * </ul>
 * Every other argument is passed on to
 * {@link LexicalizedParser#loadModel(String, String...)}.
 */
public class ParseAndPrintMatrices {

  /**
   * Writes all elements of <code>matrix</code> on one line, each element
   * preceded by a single space, and terminates the line.
   *
   * @param bout writer to print to
   * @param matrix matrix whose elements are written in linear index order
   * @throws IOException if writing fails
   */
  public static void outputMatrix(BufferedWriter bout, SimpleMatrix matrix) throws IOException {
    final int numElements = matrix.getNumElements();
    for (int i = 0; i < numElements; ++i) {
      bout.write(" " + matrix.get(i));
    }
    bout.newLine();
  }

  /**
   * Recursively writes the vectors of <code>tree</code> and its descendants,
   * one line per node.  Leaves and preterminals are skipped.  The children of
   * a node are visited from last to first, and a node's own vector is written
   * after the vectors of all of its descendants.
   *
   * @param bout writer to print to
   * @param tree root of the subtree to print
   * @param vectors map from tree node (by identity) to that node's vector
   * @throws IOException if writing fails
   */
  public static void outputTreeMatrices(BufferedWriter bout, Tree tree, IdentityHashMap<Tree, SimpleMatrix> vectors) throws IOException {
    if (tree.isPreTerminal() || tree.isLeaf()) {
      return;
    }
    final Tree[] children = tree.children();
    // Deliberately right-to-left: the output order is part of the file format.
    for (int i = children.length - 1; i >= 0; i--) {
      outputTreeMatrices(bout, children[i], vectors);
    }
    outputMatrix(bout, vectors.get(tree));
  }

  /**
   * Returns the first key of <code>vectors</code> whose label value is
   * <code>ROOT</code>.
   *
   * @param vectors map whose keys are searched
   * @return a tree labeled <code>ROOT</code>
   * @throws RuntimeException if no key is labeled <code>ROOT</code>
   */
  public static Tree findRootTree(IdentityHashMap<Tree, SimpleMatrix> vectors) {
    for (Tree tree : vectors.keySet()) {
      if (tree.label().value().equals("ROOT")) {
        return tree;
      }
    }
    throw new RuntimeException("Could not find root");
  }

  /**
   * Returns the value following the flag at <code>args[argIndex]</code>,
   * failing with a readable message if the flag is the last argument.
   */
  private static String requireArgValue(String[] args, int argIndex) {
    if (argIndex + 1 >= args.length) {
      throw new IllegalArgumentException("Missing value for argument " + args[argIndex]);
    }
    return args[argIndex + 1];
  }

  /**
   * Parses the command line, loads the model, creates the output directory
   * and, if <code>-input</code> was given, writes one
   * <code>sentenceN.txt</code> file (numbered from 1) per input sentence.
   * Each file holds the sentence, the tree, one line per word vector and then
   * the node vectors as printed by {@link #outputTreeMatrices}.
   *
   * @param args command line arguments, see the class description
   * @throws IOException if reading the input or writing an output file fails
   * @throws IllegalArgumentException if a flag has no value, if
   *     <code>-model</code> or <code>-output</code> is missing, or if the
   *     loaded parser is not a reranking parser backed by a DVModelReranker
   * @throws RuntimeException if a sentence cannot be parsed
   */
  public static void main(String[] args) throws IOException {
    String modelPath = null;
    String outputPath = null;
    String inputPath = null;

    // Parsed so that the flag and its sub-arguments are consumed rather than
    // forwarded to the parser; the values are not used further in this method.
    String testTreebankPath = null;
    FileFilter testTreebankFilter = null;

    List<String> unusedArgs = Generics.newArrayList();

    for (int argIndex = 0; argIndex < args.length; ) {
      if (args[argIndex].equalsIgnoreCase("-model")) {
        modelPath = requireArgValue(args, argIndex);
        argIndex += 2;
      } else if (args[argIndex].equalsIgnoreCase("-output")) {
        outputPath = requireArgValue(args, argIndex);
        argIndex += 2;
      } else if (args[argIndex].equalsIgnoreCase("-input")) {
        inputPath = requireArgValue(args, argIndex);
        argIndex += 2;
      } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
        Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
        argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
        testTreebankPath = treebankDescription.first();
        testTreebankFilter = treebankDescription.second();
      } else {
        unusedArgs.add(args[argIndex++]);
      }
    }

    // Fail fast, before the (potentially slow) model load, instead of with a
    // NullPointerException further down.
    if (modelPath == null) {
      throw new IllegalArgumentException("Must specify a parser model with -model");
    }
    if (outputPath == null) {
      throw new IllegalArgumentException("Must specify an output directory with -output");
    }

    String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
    LexicalizedParser parser = LexicalizedParser.loadModel(modelPath, newArgs);
    DVModel model = DVParser.getModelFromLexicalizedParser(parser);

    File outputFile = new File(outputPath);
    FileSystem.checkNotExistsOrFail(outputFile);
    FileSystem.mkdirOrFail(outputFile);

    if (inputPath == null) {
      return;
    }

    int count = 0;
    Reader input = new BufferedReader(new FileReader(inputPath));
    try {
      DocumentPreprocessor processor = new DocumentPreprocessor(input);
      for (List<HasWord> sentence : processor) {
        count++; // index from 1

        ParserQuery pq = parser.parserQuery();
        if (!(pq instanceof RerankingParserQuery)) {
          throw new IllegalArgumentException("Expected a RerankingParserQuery");
        }
        RerankingParserQuery rpq = (RerankingParserQuery) pq;
        if (!rpq.parse(sentence)) {
          throw new RuntimeException("Unparsable sentence: " + sentence);
        }
        RerankerQuery reranker = rpq.rerankerQuery();
        if (!(reranker instanceof DVModelReranker.Query)) {
          throw new IllegalArgumentException("Expected a DVModelReranker");
        }

        // Only the first deep tree returned by the reranker is printed.
        DeepTree deepTree = ((DVModelReranker.Query) reranker).getDeepTrees().get(0);
        IdentityHashMap<Tree, SimpleMatrix> vectors = deepTree.getVectors();

        // Dump every node/vector pair to stderr.
        for (Map.Entry<Tree, SimpleMatrix> entry : vectors.entrySet()) {
          System.err.println(entry.getKey() + " " + entry.getValue());
        }

        FileWriter fout = new FileWriter(outputPath + File.separator + "sentence" + count + ".txt");
        BufferedWriter bout = new BufferedWriter(fout);
        try {
          bout.write(Sentence.listToString(sentence));
          bout.newLine();
          bout.write(deepTree.getTree().toString());
          bout.newLine();

          for (HasWord word : sentence) {
            outputMatrix(bout, model.getWordVector(word.word()));
          }

          Tree rootTree = findRootTree(vectors);
          outputTreeMatrices(bout, rootTree, vectors);
        } finally {
          // Closing the BufferedWriter flushes it and closes the underlying
          // FileWriter, also when writing this sentence failed part-way.
          bout.close();
        }
      }
    } finally {
      input.close();
    }
  }
}