package edu.stanford.nlp.sentiment;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.ejml.simple.SimpleMatrix;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.neural.SimpleTensor;
import edu.stanford.nlp.util.Generics;
/**
* This tool has a very limited scope: it converts a model built with
* the Matlab version of the code into the format used by the Java
* version. It is worth keeping around in case the format of the Java
* model changes, since it lets us easily recreate the converted model.
* <br>
* Another set of matrices is in <br>
* /u/nlp/data/sentiment/binary/model_binary_best_asTextFiles/
*
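* Example usage, where the path is a placeholder for wherever the
* Matlab parameter files actually live: <br>
* {@code java edu.stanford.nlp.sentiment.ConvertMatlabModel -path /path/to/matlab/params/ -slices 25} <br>
* The converted model is serialized to {@code matlab.ser.gz} in the
* current directory.
*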
* @author John Bauer
*/
public class ConvertMatlabModel {
private ConvertMatlabModel() {} // static class
/** Copies the word vector for {@code source} to {@code target}, but will not overwrite an existing vector for {@code target}. */
public static void copyWordVector(Map<String, SimpleMatrix> wordVectors, String source, String target) {
if (wordVectors.containsKey(target) || !wordVectors.containsKey(source)) {
return;
}
System.err.println("Using wordVector " + source + " for " + target);
wordVectors.put(target, new SimpleMatrix(wordVectors.get(source)));
}
/** Copies the word vector for {@code source} to {@code target}, overwriting any existing vector for {@code target}. */
public static void replaceWordVector(Map<String, SimpleMatrix> wordVectors, String source, String target) {
if (!wordVectors.containsKey(source)) {
return;
}
wordVectors.put(target, new SimpleMatrix(wordVectors.get(source)));
}
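/** Loads a matrix from the binary file if it exists, otherwise falls back to the text version; fails if neither file is present. */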
public static SimpleMatrix loadMatrix(String binaryName, String textName) throws IOException {
File matrixFile = new File(binaryName);
if (matrixFile.exists()) {
return SimpleMatrix.loadBinary(matrixFile.getPath());
}
matrixFile = new File(textName);
if (matrixFile.exists()) {
return NeuralUtils.loadTextMatrix(matrixFile);
}
throw new RuntimeException("Could not find either " + binaryName + " or " + textName);
}
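/**
* Reads the Matlab parameter files under the given path, assembles a
* {@link SentimentModel}, and serializes it to {@code matlab.ser.gz}.
*/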
public static void main(String[] args) throws IOException {
String basePath = "/user/socherr/scr/projects/semComp/RNTN/src/params/";
int numSlices = 25;
boolean useEscapedParens = false;
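// Parse the command line options: -slices, -path, and -useEscapedParens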
for (int argIndex = 0; argIndex < args.length; ) {
if (args[argIndex].equalsIgnoreCase("-slices")) {
numSlices = Integer.parseInt(args[argIndex + 1]);
argIndex += 2;
} else if (args[argIndex].equalsIgnoreCase("-path")) {
basePath = args[argIndex + 1];
argIndex += 2;
} else if (args[argIndex].equalsIgnoreCase("-useEscapedParens")) {
useEscapedParens = true;
argIndex += 1;
} else {
System.err.println("Unknown argument " + args[argIndex]);
System.exit(2);
}
}
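// Load the numSlices slices of the Wt tensor (named Wt_1 ... Wt_numSlices) and stack them into a single tensor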
SimpleMatrix[] slices = new SimpleMatrix[numSlices];
for (int i = 0; i < numSlices; ++i) {
slices[i] = loadMatrix(basePath + "bin/Wt_" + (i + 1) + ".bin", basePath + "Wt_" + (i + 1) + ".txt");
}
SimpleTensor tensor = new SimpleTensor(slices);
System.err.println("W tensor size: " + tensor.numRows() + "x" + tensor.numCols() + "x" + tensor.numSlices());
SimpleMatrix W = loadMatrix(basePath + "bin/W.bin", basePath + "W.txt");
System.err.println("W matrix size: " + W.numRows() + "x" + W.numCols());
SimpleMatrix Wcat = loadMatrix(basePath + "bin/Wcat.bin", basePath + "Wcat.txt");
System.err.println("W cat size: " + Wcat.numRows() + "x" + Wcat.numCols());
SimpleMatrix combinedWV = loadMatrix(basePath + "bin/Wv.bin", basePath + "Wv.txt");
System.err.println("Word matrix size: " + combinedWV.numRows() + "x" + combinedWV.numCols());
File vocabFile = new File(basePath + "vocab_1.txt");
if (!vocabFile.exists()) {
vocabFile = new File(basePath + "words.txt");
}
List<String> lines = Generics.newArrayList();
for (String line : IOUtils.readLines(vocabFile)) {
lines.add(line.trim());
}
System.err.println("Lines in vocab file: " + lines.size());
Map<String, SimpleMatrix> wordVectors = Generics.newTreeMap();
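// Column i of the combined word vector matrix holds the vector for the i-th word in the vocabulary file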
for (int i = 0; i < lines.size() && i < combinedWV.numCols(); ++i) {
String[] pieces = lines.get(i).split(" +");
if (pieces.length != 1) {
continue;
}
wordVectors.put(pieces[0], combinedWV.extractMatrix(0, numSlices, i, i+1));
if (pieces[0].equals("UNK")) {
wordVectors.put(SentimentModel.UNKNOWN_WORD, wordVectors.get("UNK"));
}
}
// If there is no ",", we first try to look for an HTML escaping,
// then fall back to "." as better than just a random word vector.
// Same for "``" and ";"
copyWordVector(wordVectors, ",", ",");
copyWordVector(wordVectors, ".", ",");
copyWordVector(wordVectors, ";", ";");
copyWordVector(wordVectors, ".", ";");
copyWordVector(wordVectors, "``", "``");
copyWordVector(wordVectors, "''", "``");
if (useEscapedParens) {
replaceWordVector(wordVectors, "(", "-LRB-");
replaceWordVector(wordVectors, ")", "-RRB-");
}
RNNOptions op = new RNNOptions();
op.numHid = numSlices;
op.lowercaseWordVectors = false;
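// A Wcat with two rows means a binary (Negative / Positive) model; otherwise the RNNOptions defaults are kept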
if (Wcat.numRows() == 2) {
op.classNames = new String[] { "Negative", "Positive" };
op.equivalenceClasses = new int[][] { { 0 }, { 1 } }; // TODO: set to null once old models are updated
op.numClasses = 2;
}
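// Make sure there is an unknown word vector; if the vocab did not provide one, use a tiny random vector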
if (!wordVectors.containsKey(SentimentModel.UNKNOWN_WORD)) {
wordVectors.put(SentimentModel.UNKNOWN_WORD, SimpleMatrix.random(numSlices, 1, -0.00001, 0.00001, new Random()));
}
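// Combine the loaded parameters into a SentimentModel and serialize it for the Java code to load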
SentimentModel model = SentimentModel.modelFromMatrices(W, Wcat, tensor, wordVectors, op);
model.saveSerialized("matlab.ser.gz");
}
}