package edu.stanford.nlp.sentiment;
import java.io.Serializable;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.ejml.simple.SimpleMatrix;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.neural.SimpleTensor;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.TwoDimensionalMap;
import edu.stanford.nlp.util.TwoDimensionalSet;
public class SentimentModel implements Serializable {
/**
* Binary transform matrices, keyed by the basic categories of the
* (left, right) children.
* Nx2N+1, where N is the size of the word vectors
*/
public final TwoDimensionalMap<String, String, SimpleMatrix> binaryTransform;
/**
* Binary tensors, keyed by the basic categories of the (left, right)
* children. Only populated by the training constructor when
* op.useTensors is set.
* 2Nx2NxN, where N is the size of the word vectors
*/
public final TwoDimensionalMap<String, String, SimpleTensor> binaryTensors;
/**
* Classification matrices for binary nodes; the training constructor
* leaves this empty when op.combineClassification is set.
* CxN+1, where N = size of word vectors, C is the number of classes
*/
public final TwoDimensionalMap<String, String, SimpleMatrix> binaryClassification;
/**
* Classification matrices keyed by basic category. When
* op.combineClassification is set, the entry under "" is also used
* for binary nodes.
* CxN+1, where N = size of word vectors, C is the number of classes
*/
public final Map<String, SimpleMatrix> unaryClassification;
/**
* Word vectors keyed by word. Lookups through getVocabWord fall back
* to the UNKNOWN_WORD key for words that are not present.
*/
public Map<String, SimpleMatrix> wordVectors;
/**
* How many classes the RNN is built to test against
*/
public final int numClasses;
/**
* Dimension of hidden layers, size of word vectors, etc
*/
public final int numHid;
/**
* Number of binary transform matrices.
* Cached here for easy calculation of the model size;
* TwoDimensionalMap does not return that in O(1) time
*/
public final int numBinaryMatrices;
/** How many elements a transformation matrix has: numHid * (2 * numHid + 1) */
public final int binaryTransformSize;
/**
* How many elements the binary transformation tensors have:
* 4 * numHid^3, or 0 when op.useTensors is false
*/
public final int binaryTensorSize;
/**
* How many elements a binary classification matrix has:
* numClasses * (numHid + 1), or 0 when op.combineClassification is set
*/
public final int binaryClassificationSize;
/**
* Number of unary classification matrices.
* Cached here for easy calculation of the model size
*/
public final int numUnaryMatrices;
/** How many elements a unary classification matrix has: numClasses * (numHid + 1) */
public final int unaryClassificationSize;
/**
* numHid x numHid identity matrix; we just keep this here for convenience.
* Declared transient, so it is not written out with the model and is
* not restored by default deserialization.
*/
transient SimpleMatrix identity;
/**
* A random number generator - keeping it here lets us reproduce results
*/
final Random rand;
/** Key under which the vector for out-of-vocabulary words is stored in wordVectors */
static final String UNKNOWN_WORD = "*UNK*";
/**
* Will store various options specific to this model
*/
final RNNOptions op;
/*
// An example of how you could read in old models with readObject to fix the serialization
// You would first read in the old model, then reserialize it
private void readObject(ObjectInputStream in)
throws IOException, ClassNotFoundException
{
ObjectInputStream.GetField fields = in.readFields();
binaryTransform = ErasureUtils.uncheckedCast(fields.get("binaryTransform", null));
// transform binaryTensors
binaryTensors = TwoDimensionalMap.treeMap();
TwoDimensionalMap<String, String, edu.stanford.nlp.rnn.SimpleTensor> oldTensors = ErasureUtils.uncheckedCast(fields.get("binaryTensors", null));
for (String first : oldTensors.firstKeySet()) {
for (String second : oldTensors.get(first).keySet()) {
binaryTensors.put(first, second, new SimpleTensor(oldTensors.get(first, second).slices));
}
}
binaryClassification = ErasureUtils.uncheckedCast(fields.get("binaryClassification", null));
unaryClassification = ErasureUtils.uncheckedCast(fields.get("unaryClassification", null));
wordVectors = ErasureUtils.uncheckedCast(fields.get("wordVectors", null));
if (fields.defaulted("numClasses")) {
throw new RuntimeException();
}
numClasses = fields.get("numClasses", 0);
if (fields.defaulted("numHid")) {
throw new RuntimeException();
}
numHid = fields.get("numHid", 0);
if (fields.defaulted("numBinaryMatrices")) {
throw new RuntimeException();
}
numBinaryMatrices = fields.get("numBinaryMatrices", 0);
if (fields.defaulted("binaryTransformSize")) {
throw new RuntimeException();
}
binaryTransformSize = fields.get("binaryTransformSize", 0);
if (fields.defaulted("binaryTensorSize")) {
throw new RuntimeException();
}
binaryTensorSize = fields.get("binaryTensorSize", 0);
if (fields.defaulted("binaryClassificationSize")) {
throw new RuntimeException();
}
binaryClassificationSize = fields.get("binaryClassificationSize", 0);
if (fields.defaulted("numUnaryMatrices")) {
throw new RuntimeException();
}
numUnaryMatrices = fields.get("numUnaryMatrices", 0);
if (fields.defaulted("unaryClassificationSize")) {
throw new RuntimeException();
}
unaryClassificationSize = fields.get("unaryClassificationSize", 0);
rand = ErasureUtils.uncheckedCast(fields.get("rand", null));
op = ErasureUtils.uncheckedCast(fields.get("op", null));
op.classNames = op.DEFAULT_CLASS_NAMES;
op.equivalenceClasses = op.APPROXIMATE_EQUIVALENCE_CLASSES;
op.equivalenceClassNames = op.DEFAULT_EQUIVALENCE_CLASS_NAMES;
}
*/
/**
 * Wraps externally produced parameter matrices in a SentimentModel.
 * Handy for building a Java version of a model that was trained some
 * other way, for example with the original Matlab code.
 * <br>
 * Every matrix is stored under the empty-string category, which is
 * why this only works for a simplified model with combined
 * classification.
 *
 * @param W the binary transform matrix
 * @param Wcat the classification matrix
 * @param Wt the binary tensor
 * @param wordVectors word vectors keyed by word
 * @param op options for the model
 * @throws IllegalArgumentException if op.combineClassification or
 *         op.simplifiedModel is turned off
 */
static SentimentModel modelFromMatrices(SimpleMatrix W, SimpleMatrix Wcat, SimpleTensor Wt, Map<String, SimpleMatrix> wordVectors, RNNOptions op) {
  boolean supported = op.combineClassification && op.simplifiedModel;
  if (!supported) {
    throw new IllegalArgumentException("Can only create a model using this method if combineClassification and simplifiedModel are turned on");
  }

  TwoDimensionalMap<String, String, SimpleMatrix> transforms = TwoDimensionalMap.treeMap();
  TwoDimensionalMap<String, String, SimpleTensor> tensors = TwoDimensionalMap.treeMap();
  // Stays empty: classification is combined, so only the unary map is consulted
  TwoDimensionalMap<String, String, SimpleMatrix> binaryClassifiers = TwoDimensionalMap.treeMap();
  Map<String, SimpleMatrix> unaryClassifiers = Generics.newTreeMap();

  transforms.put("", "", W);
  tensors.put("", "", Wt);
  unaryClassifiers.put("", Wcat);

  return new SentimentModel(transforms, tensors, binaryClassifiers, unaryClassifiers, wordVectors, op);
}
/**
 * Builds a model around already-existing parameter maps. The maps are
 * stored as given (no copies are made) and the cached size fields are
 * derived from them and from the options.
 *
 * @param binaryTransform binary transform matrices by (left, right) category
 * @param binaryTensors binary tensors by (left, right) category
 * @param binaryClassification binary classification matrices by (left, right) category
 * @param unaryClassification classification matrices by category
 * @param wordVectors word vectors keyed by word
 * @param op options for the model; op.numHid &lt;= 0 means the hidden
 *        size is inferred from the word vectors
 */
private SentimentModel(TwoDimensionalMap<String, String, SimpleMatrix> binaryTransform,
                       TwoDimensionalMap<String, String, SimpleTensor> binaryTensors,
                       TwoDimensionalMap<String, String, SimpleMatrix> binaryClassification,
                       Map<String, SimpleMatrix> unaryClassification,
                       Map<String, SimpleMatrix> wordVectors,
                       RNNOptions op) {
  this.op = op;
  this.binaryTransform = binaryTransform;
  this.binaryTensors = binaryTensors;
  this.binaryClassification = binaryClassification;
  this.unaryClassification = unaryClassification;
  this.wordVectors = wordVectors;
  this.numClasses = op.numClasses;

  if (op.numHid <= 0) {
    // Infer the hidden size from the first word vector and stop there,
    // the same way the public constructor does, instead of scanning
    // the whole vocabulary. All vectors are expected to share one size.
    int nh = 0;
    for (SimpleMatrix wv : wordVectors.values()) {
      nh = wv.getNumElements();
      break;
    }
    this.numHid = nh;
  } else {
    this.numHid = op.numHid;
  }

  this.numBinaryMatrices = binaryTransform.size();
  binaryTransformSize = numHid * (2 * numHid + 1);
  if (op.useTensors) {
    // 2N x 2N x N
    binaryTensorSize = numHid * numHid * numHid * 4;
  } else {
    binaryTensorSize = 0;
  }
  // No separate binary classification matrices when classification is combined
  binaryClassificationSize = (op.combineClassification) ? 0 : numClasses * (numHid + 1);

  numUnaryMatrices = unaryClassification.size();
  unaryClassificationSize = numClasses * (numHid + 1);

  rand = new Random(op.randomSeed);
  identity = SimpleMatrix.identity(numHid);
}
/**
* The traditional way of initializing an empty model suitable for training.
* <br>
* Word vectors are either generated randomly from the words in the
* training trees (op.randomWordVectors) or read via readWordVectors().
* All other matrices are randomly initialized. The order of the
* statements below matters: the random generator must exist before
* any random initialization, and numHid, numClasses and identity must
* be assigned before the random*Matrix helpers that read them.
*
* @param op options for the model
* @param trainingTrees trees whose leaves provide the vocabulary;
*        only read when op.randomWordVectors is set
* @throws UnsupportedOperationException if op.simplifiedModel is false
*/
public SentimentModel(RNNOptions op, List<Tree> trainingTrees) {
this.op = op;
rand = new Random(op.randomSeed);
if (op.randomWordVectors) {
initRandomWordVectors(trainingTrees);
} else {
readWordVectors();
}
if (op.numHid > 0) {
this.numHid = op.numHid;
} else {
// numHid not specified: take it from the size of the first word vector
int size = 0;
for (SimpleMatrix vector : wordVectors.values()) {
size = vector.getNumElements();
break;
}
this.numHid = size;
}
TwoDimensionalSet<String, String> binaryProductions = TwoDimensionalSet.hashSet();
if (op.simplifiedModel) {
// the simplified model has a single production under the empty category
binaryProductions.add("", "");
} else {
// TODO
// figure out what binary productions we have in these trees
// Note: the current sentiment training data does not actually
// have any constituent labels
throw new UnsupportedOperationException("Not yet implemented");
}
Set<String> unaryProductions = Generics.newHashSet();
if (op.simplifiedModel) {
unaryProductions.add("");
} else {
// TODO
// figure out what unary productions we have in these trees (preterminals only, after the collapsing)
throw new UnsupportedOperationException("Not yet implemented");
}
this.numClasses = op.numClasses;
identity = SimpleMatrix.identity(numHid);
binaryTransform = TwoDimensionalMap.treeMap();
binaryTensors = TwoDimensionalMap.treeMap();
binaryClassification = TwoDimensionalMap.treeMap();
// When making a flat model (no semantic untying) the
// basicCategory function will return the same basic category for
// all labels, so all entries will map to the same matrix
for (Pair<String, String> binary : binaryProductions) {
String left = basicCategory(binary.first);
String right = basicCategory(binary.second);
if (binaryTransform.contains(left, right)) {
// already initialized for this pair of basic categories
continue;
}
binaryTransform.put(left, right, randomTransformMatrix());
if (op.useTensors) {
binaryTensors.put(left, right, randomBinaryTensor());
}
if (!op.combineClassification) {
binaryClassification.put(left, right, randomClassificationMatrix());
}
}
numBinaryMatrices = binaryTransform.size();
binaryTransformSize = numHid * (2 * numHid + 1);
if (op.useTensors) {
// 2N x 2N x N
binaryTensorSize = numHid * numHid * numHid * 4;
} else {
binaryTensorSize = 0;
}
// no separate binary classification matrices when classification is combined
binaryClassificationSize = (op.combineClassification) ? 0 : numClasses * (numHid + 1);
unaryClassification = Generics.newTreeMap();
// When making a flat model (no semantic untying) the
// basicCategory function will return the same basic category for
// all labels, so all entries will map to the same matrix
for (String unary : unaryProductions) {
unary = basicCategory(unary);
if (unaryClassification.containsKey(unary)) {
continue;
}
unaryClassification.put(unary, randomClassificationMatrix());
}
numUnaryMatrices = unaryClassification.size();
unaryClassificationSize = numClasses * (numHid + 1);
// Debugging output, left disabled:
// System.err.println("Binary transform matrices:");
// System.err.println(binaryTransform);
// System.err.println("Binary classification matrices:");
// System.err.println(binaryClassification);
// System.err.println("Unary classification matrices:");
// System.err.println(unaryClassification);
}
/**
 * Creates a randomly initialized 2N x 2N x N tensor (N = numHid) with
 * bounds of +/- 1 / (4N), scaled by op.trainOptions.scalingForInit.
 */
SimpleTensor randomBinaryTensor() {
  final int inputSize = numHid * 2;
  final double bound = 1.0 / (4.0 * numHid);
  return SimpleTensor.random(inputSize, inputSize, numHid, -bound, bound, rand)
      .scale(op.trainOptions.scalingForInit);
}
/**
 * Creates a randomly initialized N x (2N + 1) transform matrix
 * (N = numHid), scaled by op.trainOptions.scalingForInit. The first
 * two N x N blocks come from randomTransformBlock(); the final (bias)
 * column is left at zero.
 */
SimpleMatrix randomTransformMatrix() {
  final int columns = 2 * numHid + 1;
  SimpleMatrix transform = new SimpleMatrix(numHid, columns);

  // Left block is drawn before the right block so the random
  // generator is consumed in a fixed order.
  SimpleMatrix leftBlock = randomTransformBlock();
  transform.insertIntoThis(0, 0, leftBlock);
  SimpleMatrix rightBlock = randomTransformBlock();
  transform.insertIntoThis(0, numHid, rightBlock);

  return transform.scale(op.trainOptions.scalingForInit);
}
/**
 * Creates one N x N block (N = numHid) of a transform matrix: random
 * values with bounds of +/- 1 / (2 * sqrt(N)), plus the identity matrix.
 *
 * @return a new numHid x numHid matrix
 */
SimpleMatrix randomTransformBlock() {
  // identity is transient, so it is null on a model that was loaded
  // from a serialized file; rebuild it on demand rather than hitting
  // a NullPointerException in plus() below.
  if (identity == null) {
    identity = SimpleMatrix.identity(numHid);
  }
  double range = 1.0 / (Math.sqrt((double) numHid) * 2.0);
  return SimpleMatrix.random(numHid, numHid, -range, range, rand).plus(identity);
}
/**
 * Builds a randomly initialized classification matrix of the size
 * used for both binary and unary (terminal) classification:
 * numClasses x (numHid + 1). The last (bias) column stays at zero and
 * the result is scaled by op.trainOptions.scalingForInit.
 */
SimpleMatrix randomClassificationMatrix() {
  final double bound = 1.0 / Math.sqrt(numHid);
  SimpleMatrix weights = SimpleMatrix.random(numClasses, numHid, -bound, bound, rand);

  // One extra column for the bias, which keeps its initial 0 values
  SimpleMatrix classifier = new SimpleMatrix(numClasses, numHid + 1);
  classifier.insertIntoThis(0, 0, weights);
  return classifier.scale(op.trainOptions.scalingForInit);
}
/**
 * Creates a random word vector using this model's random generator.
 * The size is read from op.numHid rather than the numHid field: the
 * training constructor generates word vectors before numHid has been
 * assigned.
 */
SimpleMatrix randomWordVector() {
  final int dimension = op.numHid;
  return randomWordVector(dimension, rand);
}
/**
 * Creates a size x 1 random vector via NeuralUtils.randomGaussian.
 *
 * @param size number of rows in the vector
 * @param rand random generator to draw from
 */
static SimpleMatrix randomWordVector(int size, Random rand) {
  SimpleMatrix vector = NeuralUtils.randomGaussian(size, 1, rand);
  return vector;
}
/**
 * Replaces wordVectors with a fresh map holding one random vector for
 * every distinct leaf word in the training trees, plus one for
 * UNKNOWN_WORD. Words are lowercased first when
 * op.lowercaseWordVectors is set.
 *
 * @param trainingTrees trees whose leaves supply the vocabulary
 * @throws RuntimeException if op.numHid is not positive, since the
 *         vector size cannot be inferred here
 */
void initRandomWordVectors(List<Tree> trainingTrees) {
  // Both constructors treat any numHid <= 0 as "not specified", so
  // reject negative values here too instead of letting them reach
  // the matrix allocation.
  if (op.numHid <= 0) {
    throw new RuntimeException("Cannot create random word vectors for an unknown numHid");
  }

  Set<String> words = Generics.newHashSet();
  words.add(UNKNOWN_WORD);
  for (Tree tree : trainingTrees) {
    List<Tree> leaves = tree.getLeaves();
    for (Tree leaf : leaves) {
      String word = leaf.label().value();
      if (op.lowercaseWordVectors) {
        word = word.toLowerCase();
      }
      words.add(word);
    }
  }

  this.wordVectors = Generics.newTreeMap();
  for (String word : words) {
    SimpleMatrix vector = randomWordVector();
    wordVectors.put(word, vector);
  }
}
/**
 * Replaces wordVectors with the contents of the embedding named by
 * op.wordVectors, then registers the vector found under op.unkWord
 * under the UNKNOWN_WORD key as well.
 *
 * @throws RuntimeException if the embedding has no vector for op.unkWord
 */
void readWordVectors() {
  Embedding embedding = new Embedding(op.wordVectors, op.numHid);
  this.wordVectors = Generics.newTreeMap();
  for (String word : embedding.keySet()) {
    // TODO: factor out unknown word vector code from DVParser
    wordVectors.put(word, embedding.get(word));
  }

  String unkWord = op.unkWord;
  SimpleMatrix unknownWordVector = wordVectors.get(unkWord);
  // Validate before touching the map so that a null vector is never
  // stored under UNKNOWN_WORD.
  if (unknownWordVector == null) {
    throw new RuntimeException("Unknown word vector not specified in the word vector file");
  }
  wordVectors.put(UNKNOWN_WORD, unknownWordVector);
}
/**
 * Computes the total number of parameters in the model from the
 * cached matrix counts and sizes plus the current number of word
 * vectors.
 */
public int totalParamSize() {
  // binaryTensorSize was set to 0 if useTensors=false
  final int perBinaryProduction = binaryTransformSize + binaryClassificationSize + binaryTensorSize;
  final int binaryPart = numBinaryMatrices * perBinaryProduction;
  final int unaryPart = numUnaryMatrices * unaryClassificationSize;
  final int wordVectorPart = wordVectors.size() * numHid;
  return binaryPart + unaryPart + wordVectorPart;
}
/**
 * Flattens all model parameters into one array by handing them to
 * NeuralUtils.paramsToVector in this order: binary transforms, binary
 * classification matrices, binary tensors, unary classification
 * matrices, word vectors.
 *
 * @return the array produced by NeuralUtils.paramsToVector for a
 *         requested size of totalParamSize()
 */
public double[] paramsToVector() {
  return NeuralUtils.paramsToVector(
      totalParamSize(),
      binaryTransform.valueIterator(),
      binaryClassification.valueIterator(),
      SimpleTensor.iteratorSimpleMatrix(binaryTensors.valueIterator()),
      unaryClassification.values().iterator(),
      wordVectors.values().iterator());
}
/**
 * Hands theta to NeuralUtils.vectorToParams together with the model's
 * parameter groups, listed in the same order that paramsToVector
 * uses: binary transforms, binary classification matrices, binary
 * tensors, unary classification matrices, word vectors.
 *
 * @param theta flattened parameter values
 */
public void vectorToParams(double[] theta) {
  NeuralUtils.vectorToParams(
      theta,
      binaryTransform.valueIterator(),
      binaryClassification.valueIterator(),
      SimpleTensor.iteratorSimpleMatrix(binaryTensors.valueIterator()),
      unaryClassification.values().iterator(),
      wordVectors.values().iterator());
}
// TODO: combine this and getClassWForNode?
/**
 * Looks up the binary transform matrix for a node with exactly two
 * children, keyed by the basic categories of the children's labels.
 *
 * @param node a tree node with two children
 * @return the matrix stored for that pair of categories
 * @throws AssertionError if the node does not have exactly two children
 */
public SimpleMatrix getWForNode(Tree node) {
  Tree[] kids = node.children();
  if (kids.length == 1) {
    throw new AssertionError("No unary transform matrices, only unary classification");
  }
  if (kids.length != 2) {
    throw new AssertionError("Unexpected tree children size of " + kids.length);
  }
  String leftBasic = basicCategory(kids[0].value());
  String rightBasic = basicCategory(kids[1].value());
  return binaryTransform.get(leftBasic, rightBasic);
}
/**
 * Looks up the binary tensor for a node with exactly two children,
 * keyed by the basic categories of the children's labels.
 *
 * @param node a tree node with two children
 * @return the tensor stored for that pair of categories
 * @throws AssertionError if op.useTensors is off or the node does not
 *         have exactly two children
 */
public SimpleTensor getTensorForNode(Tree node) {
  if (!op.useTensors) {
    throw new AssertionError("Not using tensors");
  }
  Tree[] kids = node.children();
  if (kids.length == 1) {
    throw new AssertionError("No unary transform matrices, only unary classification");
  }
  if (kids.length != 2) {
    throw new AssertionError("Unexpected tree children size of " + kids.length);
  }
  String leftBasic = basicCategory(kids[0].value());
  String rightBasic = basicCategory(kids[1].value());
  return binaryTensors.get(leftBasic, rightBasic);
}
/**
 * Picks the classification matrix for a node. With combined
 * classification the single matrix stored under "" is returned
 * without looking at the node; otherwise the lookup depends on
 * whether the node has one or two children.
 *
 * @param node the tree node to classify
 * @return the matching classification matrix
 * @throws AssertionError if classification is not combined and the
 *         node has neither one nor two children
 */
public SimpleMatrix getClassWForNode(Tree node) {
  if (op.combineClassification) {
    return unaryClassification.get("");
  }

  Tree[] kids = node.children();
  if (kids.length == 2) {
    String leftBasic = basicCategory(kids[0].value());
    String rightBasic = basicCategory(kids[1].value());
    return binaryClassification.get(leftBasic, rightBasic);
  }
  if (kids.length == 1) {
    String unaryBasic = basicCategory(kids[0].value());
    return unaryClassification.get(unaryBasic);
  }
  throw new AssertionError("Unexpected tree children size of " + kids.length);
}
/**
 * Returns the vector stored for the given word, after mapping the
 * word to its vocabulary key with getVocabWord.
 */
public SimpleMatrix getWordVector(String word) {
  String vocabKey = getVocabWord(word);
  return wordVectors.get(vocabKey);
}
/**
 * Maps a word to the key used for it in wordVectors: the word itself
 * (lowercased when op.lowercaseWordVectors is set) if it has a
 * vector, otherwise UNKNOWN_WORD.
 */
public String getVocabWord(String word) {
  String candidate = op.lowercaseWordVectors ? word.toLowerCase() : word;
  // TODO: go through unknown words here
  return wordVectors.containsKey(candidate) ? candidate : UNKNOWN_WORD;
}
/**
 * Reduces a label to the category used as a key in the parameter
 * maps. A simplified model uses "" for everything; otherwise the
 * language pack's basic category is used, minus one leading '@'.
 */
public String basicCategory(String category) {
  if (op.simplifiedModel) {
    return "";
  }
  String basic = op.langpack.basicCategory(category);
  return basic.startsWith("@") ? basic.substring(1) : basic;
}
/**
 * Returns the classification matrix stored under the basic category
 * of the given label.
 */
public SimpleMatrix getUnaryClassification(String category) {
  return unaryClassification.get(basicCategory(category));
}
/**
 * Returns the classification matrix for a binary node with the given
 * child labels. With combined classification this is the single
 * matrix stored under "" in unaryClassification.
 */
public SimpleMatrix getBinaryClassification(String left, String right) {
  if (op.combineClassification) {
    return unaryClassification.get("");
  }
  String leftBasic = basicCategory(left);
  String rightBasic = basicCategory(right);
  return binaryClassification.get(leftBasic, rightBasic);
}
/**
 * Returns the binary transform matrix stored under the basic
 * categories of the given child labels.
 */
public SimpleMatrix getBinaryTransform(String left, String right) {
  String leftBasic = basicCategory(left);
  String rightBasic = basicCategory(right);
  return binaryTransform.get(leftBasic, rightBasic);
}
/**
 * Returns the binary tensor stored under the basic categories of the
 * given child labels.
 */
public SimpleTensor getBinaryTensor(String left, String right) {
  String leftBasic = basicCategory(left);
  String rightBasic = basicCategory(right);
  return binaryTensors.get(leftBasic, rightBasic);
}
/**
 * Writes this model to the given path with IOUtils.writeObjectToFile.
 *
 * @param path destination file
 * @throws RuntimeIOException wrapping any IOException from the write
 */
public void saveSerialized(String path) {
  try {
    IOUtils.writeObjectToFile(this, path);
  } catch (IOException ioe) {
    throw new RuntimeIOException(ioe);
  }
}
/**
 * Reads a serialized model using
 * IOUtils.readObjectFromURLOrClasspathOrFileSystem.
 *
 * @param path location of the serialized model
 * @return the deserialized model
 * @throws RuntimeIOException wrapping any IOException or
 *         ClassNotFoundException raised while reading
 */
public static SentimentModel loadSerialized(String path) {
  try {
    SentimentModel model = IOUtils.readObjectFromURLOrClasspathOrFileSystem(path);
    return model;
  } catch (IOException ioe) {
    throw new RuntimeIOException(ioe);
  } catch (ClassNotFoundException cnfe) {
    throw new RuntimeIOException(cnfe);
  }
}
/**
 * Prints to stderr which parameter a position in the flattened
 * parameter vector belongs to. The groups are walked in the same
 * order used by paramsToVector: binary transforms, binary
 * classification matrices, binary tensors, unary classification
 * matrices, word vectors. If the index falls in none of them, a
 * message with the total parameter size is printed instead.
 *
 * @param index position in the flattened parameter vector
 */
public void printParamInformation(int index) {
  // Start position of the entry currently being examined
  int offset = 0;

  for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : binaryTransform) {
    int size = entry.getValue().getNumElements();
    if (offset <= index && index < offset + size) {
      System.err.println("Index " + index + " is element " + (index - offset) + " of binaryTransform \"" + entry.getFirstKey() + "," + entry.getSecondKey() + "\"");
      return;
    }
    offset += size;
  }

  for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : binaryClassification) {
    int size = entry.getValue().getNumElements();
    if (offset <= index && index < offset + size) {
      System.err.println("Index " + index + " is element " + (index - offset) + " of binaryClassification \"" + entry.getFirstKey() + "," + entry.getSecondKey() + "\"");
      return;
    }
    offset += size;
  }

  for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : binaryTensors) {
    int size = entry.getValue().getNumElements();
    if (offset <= index && index < offset + size) {
      System.err.println("Index " + index + " is element " + (index - offset) + " of binaryTensor \"" + entry.getFirstKey() + "," + entry.getSecondKey() + "\"");
      return;
    }
    offset += size;
  }

  for (Map.Entry<String, SimpleMatrix> entry : unaryClassification.entrySet()) {
    int size = entry.getValue().getNumElements();
    if (offset <= index && index < offset + size) {
      System.err.println("Index " + index + " is element " + (index - offset) + " of unaryClassification \"" + entry.getKey() + "\"");
      return;
    }
    offset += size;
  }

  for (Map.Entry<String, SimpleMatrix> entry : wordVectors.entrySet()) {
    int size = entry.getValue().getNumElements();
    if (offset <= index && index < offset + size) {
      System.err.println("Index " + index + " is element " + (index - offset) + " of wordVector \"" + entry.getKey() + "\"");
      return;
    }
    offset += size;
  }

  System.err.println("Index " + index + " is beyond the length of the parameters; total parameter space was " + totalParamSize());
}
/** Version identifier used by Java serialization for this class */
private static final long serialVersionUID = 1;
}