Source Code of edu.stanford.nlp.sentiment.SentimentModel

package edu.stanford.nlp.sentiment;

import java.io.Serializable;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

import org.ejml.simple.SimpleMatrix;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.neural.SimpleTensor;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.TwoDimensionalMap;
import edu.stanford.nlp.util.TwoDimensionalSet;

public class SentimentModel implements Serializable {
  /**
   * N x (2N+1), where N is the size of the word vectors
   */
  public final TwoDimensionalMap<String, String, SimpleMatrix> binaryTransform;

  /**
   * 2N x 2N x N, where N is the size of the word vectors
   */
  public final TwoDimensionalMap<String, String, SimpleTensor> binaryTensors;

  /**
   * C x (N+1), where N is the size of the word vectors and C is the number of classes
   */
  public final TwoDimensionalMap<String, String, SimpleMatrix> binaryClassification;

  /**
   * C x (N+1), where N is the size of the word vectors and C is the number of classes
   */
  public final Map<String, SimpleMatrix> unaryClassification;

  public Map<String, SimpleMatrix> wordVectors;

  /**
   * How many classes the RNN is built to test against
   */
  public final int numClasses;

  /**
   * Dimension of the hidden layers, size of the word vectors, etc.
   */
  public final int numHid;

  /**
   * Cached here for easy calculation of the model size;
   * TwoDimensionalMap does not return that in O(1) time
   */
  public final int numBinaryMatrices;

  /** How many elements a transformation matrix has */
  public final int binaryTransformSize;
  /** How many elements the binary transformation tensors have */
  public final int binaryTensorSize;
  /** How many elements a classification matrix has */
  public final int binaryClassificationSize;

  /**
   * Cached here for easy calculation of the model size;
   * TwoDimensionalMap does not return that in O(1) time
   */
  public final int numUnaryMatrices;

  /** How many elements a classification matrix has */
  public final int unaryClassificationSize;

  /**
   * we just keep this here for convenience
   */
  transient SimpleMatrix identity;

  /**
   * A random number generator - keeping it here lets us reproduce results
   */
  final Random rand;

  static final String UNKNOWN_WORD = "*UNK*";

  /**
   * Will store various options specific to this model
   */
  final RNNOptions op;

  /*
  // An example of how you could read in old models with readObject to fix the serialization
  // You would first read in the old model, then reserialize it
  private void readObject(ObjectInputStream in)
    throws IOException, ClassNotFoundException
  {
    ObjectInputStream.GetField fields = in.readFields();
    binaryTransform = ErasureUtils.uncheckedCast(fields.get("binaryTransform", null));

    // transform binaryTensors
    binaryTensors = TwoDimensionalMap.treeMap();
    TwoDimensionalMap<String, String, edu.stanford.nlp.rnn.SimpleTensor> oldTensors = ErasureUtils.uncheckedCast(fields.get("binaryTensors", null));
    for (String first : oldTensors.firstKeySet()) {
      for (String second : oldTensors.get(first).keySet()) {
        binaryTensors.put(first, second, new SimpleTensor(oldTensors.get(first, second).slices));
      }
    }

    binaryClassification = ErasureUtils.uncheckedCast(fields.get("binaryClassification", null));
    unaryClassification = ErasureUtils.uncheckedCast(fields.get("unaryClassification", null));
    wordVectors = ErasureUtils.uncheckedCast(fields.get("wordVectors", null));

    if (fields.defaulted("numClasses")) {
      throw new RuntimeException();
    }
    numClasses = fields.get("numClasses", 0);

    if (fields.defaulted("numHid")) {
      throw new RuntimeException();
    }
    numHid = fields.get("numHid", 0);

    if (fields.defaulted("numBinaryMatrices")) {
      throw new RuntimeException();
    }
    numBinaryMatrices = fields.get("numBinaryMatrices", 0);

    if (fields.defaulted("binaryTransformSize")) {
      throw new RuntimeException();
    }
    binaryTransformSize = fields.get("binaryTransformSize", 0);

    if (fields.defaulted("binaryTensorSize")) {
      throw new RuntimeException();
    }
    binaryTensorSize = fields.get("binaryTensorSize", 0);

    if (fields.defaulted("binaryClassificationSize")) {
      throw new RuntimeException();
    }
    binaryClassificationSize = fields.get("binaryClassificationSize", 0);

    if (fields.defaulted("numUnaryMatrices")) {
      throw new RuntimeException();
    }
    numUnaryMatrices = fields.get("numUnaryMatrices", 0);

    if (fields.defaulted("unaryClassificationSize")) {
      throw new RuntimeException();
    }
    unaryClassificationSize = fields.get("unaryClassificationSize", 0);

    rand = ErasureUtils.uncheckedCast(fields.get("rand", null));
    op = ErasureUtils.uncheckedCast(fields.get("op", null));
    op.classNames = op.DEFAULT_CLASS_NAMES;
    op.equivalenceClasses = op.APPROXIMATE_EQUIVALENCE_CLASSES;
    op.equivalenceClassNames = op.DEFAULT_EQUIVALENCE_CLASS_NAMES;
  }
  */

  /**
   * Given single matrices and sets of options, create the
   * corresponding SentimentModel.  Useful for creating a Java version
   * of a model trained in some other manner, such as using the
   * original Matlab code.
   */
  static SentimentModel modelFromMatrices(SimpleMatrix W, SimpleMatrix Wcat, SimpleTensor Wt, Map<String, SimpleMatrix> wordVectors, RNNOptions op) {
    if (!op.combineClassification || !op.simplifiedModel) {
      throw new IllegalArgumentException("Can only create a model using this method if combineClassification and simplifiedModel are turned on");
    }
    TwoDimensionalMap<String, String, SimpleMatrix> binaryTransform = TwoDimensionalMap.treeMap();
    binaryTransform.put("", "", W);

    TwoDimensionalMap<String, String, SimpleTensor> binaryTensors = TwoDimensionalMap.treeMap();
    binaryTensors.put("", "", Wt);

    TwoDimensionalMap<String, String, SimpleMatrix> binaryClassification = TwoDimensionalMap.treeMap();

    Map<String, SimpleMatrix> unaryClassification = Generics.newTreeMap();
    unaryClassification.put("", Wcat);

    return new SentimentModel(binaryTransform, binaryTensors, binaryClassification, unaryClassification, wordVectors, op);
  }
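
  /*
  // A usage sketch, not part of the original source, showing how a model trained
  // elsewhere (e.g. with the original Matlab code) could be wrapped with
  // modelFromMatrices.  The dimensions assume numHid = 25 and numClasses = 5, and
  // the options must have simplifiedModel and combineClassification enabled.
  RNNOptions exampleOp = new RNNOptions();
  Random exampleRand = new Random(exampleOp.randomSeed);
  SimpleMatrix W = SimpleMatrix.random(25, 2 * 25 + 1, -0.001, 0.001, exampleRand);      // N x (2N+1)
  SimpleMatrix Wcat = SimpleMatrix.random(5, 25 + 1, -0.001, 0.001, exampleRand);        // C x (N+1)
  SimpleTensor Wt = SimpleTensor.random(2 * 25, 2 * 25, 25, -0.001, 0.001, exampleRand); // 2N x 2N x N
  Map<String, SimpleMatrix> exampleVectors = Generics.newTreeMap();
  exampleVectors.put(UNKNOWN_WORD, randomWordVector(25, exampleRand));
  SentimentModel converted = modelFromMatrices(W, Wcat, Wt, exampleVectors, exampleOp);
  */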

  private SentimentModel(TwoDimensionalMap<String, String, SimpleMatrix> binaryTransform,
                         TwoDimensionalMap<String, String, SimpleTensor> binaryTensors,
                         TwoDimensionalMap<String, String, SimpleMatrix> binaryClassification,
                         Map<String, SimpleMatrix> unaryClassification,
                         Map<String, SimpleMatrix> wordVectors,
                         RNNOptions op) {
    this.op = op;

    this.binaryTransform = binaryTransform;
    this.binaryTensors = binaryTensors;
    this.binaryClassification = binaryClassification;
    this.unaryClassification = unaryClassification;
    this.wordVectors = wordVectors;
    this.numClasses = op.numClasses;
    if (op.numHid <= 0) {
      int nh = 0;
      for (SimpleMatrix wv : wordVectors.values()) {
        nh = wv.getNumElements();
        break;
      }
      this.numHid = nh;
    } else {
      this.numHid = op.numHid;
    }
    this.numBinaryMatrices = binaryTransform.size();
    binaryTransformSize = numHid * (2 * numHid + 1);
    if (op.useTensors) {
      binaryTensorSize = numHid * numHid * numHid * 4;
    } else {
      binaryTensorSize = 0;
    }
    binaryClassificationSize = (op.combineClassification) ? 0 : numClasses * (numHid + 1);

    numUnaryMatrices = unaryClassification.size();
    unaryClassificationSize = numClasses * (numHid + 1);

    rand = new Random(op.randomSeed);

    identity = SimpleMatrix.identity(numHid);
  }

  /**
   * The traditional way of initializing an empty model suitable for training.
   */
  public SentimentModel(RNNOptions op, List<Tree> trainingTrees) {
    this.op = op;
    rand = new Random(op.randomSeed);

    if (op.randomWordVectors) {
      initRandomWordVectors(trainingTrees);
    } else {
      readWordVectors();
    }
    if (op.numHid > 0) {
      this.numHid = op.numHid;
    } else {
      int size = 0;
      for (SimpleMatrix vector : wordVectors.values()) {
        size = vector.getNumElements();
        break;
      }
      this.numHid = size;
    }

    TwoDimensionalSet<String, String> binaryProductions = TwoDimensionalSet.hashSet();
    if (op.simplifiedModel) {
      binaryProductions.add("", "");
    } else {
      // TODO
      // figure out what binary productions we have in these trees
      // Note: the current sentiment training data does not actually
      // have any constituent labels
      throw new UnsupportedOperationException("Not yet implemented");
    }

    Set<String> unaryProductions = Generics.newHashSet();
    if (op.simplifiedModel) {
      unaryProductions.add("");
    } else {
      // TODO
      // figure out what unary productions we have in these trees (preterminals only, after the collapsing)
      throw new UnsupportedOperationException("Not yet implemented");
    }

    this.numClasses = op.numClasses;

    identity = SimpleMatrix.identity(numHid);

    binaryTransform = TwoDimensionalMap.treeMap();
    binaryTensors = TwoDimensionalMap.treeMap();
    binaryClassification = TwoDimensionalMap.treeMap();

    // When making a flat model (no semantic untying) the
    // basicCategory function will return the same basic category for
    // all labels, so all entries will map to the same matrix
    for (Pair<String, String> binary : binaryProductions) {
      String left = basicCategory(binary.first);
      String right = basicCategory(binary.second);
      if (binaryTransform.contains(left, right)) {
        continue;
      }
      binaryTransform.put(left, right, randomTransformMatrix());
      if (op.useTensors) {
        binaryTensors.put(left, right, randomBinaryTensor());
      }
      if (!op.combineClassification) {
        binaryClassification.put(left, right, randomClassificationMatrix());
      }
    }
    numBinaryMatrices = binaryTransform.size();
    binaryTransformSize = numHid * (2 * numHid + 1);
    if (op.useTensors) {
      binaryTensorSize = numHid * numHid * numHid * 4;
    } else {
      binaryTensorSize = 0;
    }
    binaryClassificationSize = (op.combineClassification) ? 0 : numClasses * (numHid + 1);

    unaryClassification = Generics.newTreeMap();

    // When making a flat model (no semantic untying) the
    // basicCategory function will return the same basic category for
    // all labels, so all entries will map to the same matrix
    for (String unary : unaryProductions) {
      unary = basicCategory(unary);
      if (unaryClassification.containsKey(unary)) {
        continue;
      }
      unaryClassification.put(unary, randomClassificationMatrix());
    }
    numUnaryMatrices = unaryClassification.size();
    unaryClassificationSize = numClasses * (numHid + 1);

    // System.err.println("Binary transform matrices:");
    // System.err.println(binaryTransform);
    // System.err.println("Binary classification matrices:");
    // System.err.println(binaryClassification);
    // System.err.println("Unary classification matrices:");
    // System.err.println(unaryClassification);
  }
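
  /*
  // A construction sketch, not part of the original source.  Training code such
  // as SentimentTraining builds a fresh model roughly like this; how the
  // labeled, binarized training trees are read is left indicative.
  RNNOptions trainOp = new RNNOptions();
  List<Tree> trainingTrees = Generics.newArrayList();  // fill with labeled, binarized training trees
  SentimentModel model = new SentimentModel(trainOp, trainingTrees);
  model.saveSerialized("sentiment-initial.ser.gz");
  */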

  SimpleTensor randomBinaryTensor() {
    double range = 1.0 / (4.0 * numHid);
    SimpleTensor tensor = SimpleTensor.random(numHid * 2, numHid * 2, numHid, -range, range, rand);
    return tensor.scale(op.trainOptions.scalingForInit);
  }

  SimpleMatrix randomTransformMatrix() {
    SimpleMatrix binary = new SimpleMatrix(numHid, numHid * 2 + 1);
    // the bias column values are initialized to zero
    binary.insertIntoThis(0, 0, randomTransformBlock());
    binary.insertIntoThis(0, numHid, randomTransformBlock());
    return binary.scale(op.trainOptions.scalingForInit);
  }

  SimpleMatrix randomTransformBlock() {
    double range = 1.0 / (Math.sqrt((double) numHid) * 2.0);
    return SimpleMatrix.random(numHid, numHid, -range, range, rand).plus(identity);
  }

  /**
   * Returns matrices of the right size for either binary or unary (terminal) classification
   */
  SimpleMatrix randomClassificationMatrix() {
    SimpleMatrix score = new SimpleMatrix(numClasses, numHid + 1);
    // Leave the bias column with 0 values
    double range = 1.0 / (Math.sqrt((double) numHid));
    score.insertIntoThis(0, 0, SimpleMatrix.random(numClasses, numHid, -range, range, rand));
    return score.scale(op.trainOptions.scalingForInit);
  }

  SimpleMatrix randomWordVector() {
    return randomWordVector(op.numHid, rand);
  }

  static SimpleMatrix randomWordVector(int size, Random rand) {
    return NeuralUtils.randomGaussian(size, 1, rand);
  }

  void initRandomWordVectors(List<Tree> trainingTrees) {
    if (op.numHid == 0) {
      throw new RuntimeException("Cannot create random word vectors for an unknown numHid");
    }
    Set<String> words = Generics.newHashSet();
    words.add(UNKNOWN_WORD);
    for (Tree tree : trainingTrees) {
      List<Tree> leaves = tree.getLeaves();
      for (Tree leaf : leaves) {
        String word = leaf.label().value();
        if (op.lowercaseWordVectors) {
          word = word.toLowerCase();
        }
        words.add(word);
      }
    }
    this.wordVectors = Generics.newTreeMap();
    for (String word : words) {
      SimpleMatrix vector = randomWordVector();
      wordVectors.put(word, vector);
    }
  }

  void readWordVectors() {
    Embedding embedding = new Embedding(op.wordVectors, op.numHid);
    this.wordVectors = Generics.newTreeMap();
//    Map<String, SimpleMatrix> rawWordVectors = NeuralUtils.readRawWordVectors(op.wordVectors, op.numHid);
//    for (String word : rawWordVectors.keySet()) {
    for (String word : embedding.keySet()) {
      // TODO: factor out unknown word vector code from DVParser
      wordVectors.put(word, embedding.get(word));
    }

    String unkWord = op.unkWord;
    SimpleMatrix unknownWordVector = wordVectors.get(unkWord);
    if (unknownWordVector == null) {
      throw new RuntimeException("Unknown word vector not specified in the word vector file");
    }
    wordVectors.put(UNKNOWN_WORD, unknownWordVector);
  }

  public int totalParamSize() {
    int totalSize = 0;
    // binaryTensorSize was set to 0 if useTensors=false
    totalSize = numBinaryMatrices * (binaryTransformSize + binaryClassificationSize + binaryTensorSize);
    totalSize += numUnaryMatrices * unaryClassificationSize;
    totalSize += wordVectors.size() * numHid;
    return totalSize;
  }
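
  /*
  // Worked example, not part of the original source.  With the simplified model
  // (one binary and one unary matrix), numHid = 25, numClasses = 5,
  // useTensors = true and combineClassification = true:
  //   binaryTransformSize      = 25 * (2 * 25 + 1) = 1275
  //   binaryTensorSize         = 25 * 25 * 25 * 4  = 62500
  //   binaryClassificationSize = 0                   (folded into unaryClassification)
  //   unaryClassificationSize  = 5 * (25 + 1)      = 130
  // so totalParamSize() = 1275 + 62500 + 130 + 25 * wordVectors.size()
  //                     = 63905 + 25 * wordVectors.size().
  */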

  public double[] paramsToVector() {
    int totalSize = totalParamSize();
    return NeuralUtils.paramsToVector(totalSize, binaryTransform.valueIterator(), binaryClassification.valueIterator(), SimpleTensor.iteratorSimpleMatrix(binaryTensors.valueIterator()), unaryClassification.values().iterator(), wordVectors.values().iterator());
  }

  public void vectorToParams(double[] theta) {
    NeuralUtils.vectorToParams(theta, binaryTransform.valueIterator(), binaryClassification.valueIterator(), SimpleTensor.iteratorSimpleMatrix(binaryTensors.valueIterator()), unaryClassification.values().iterator(), wordVectors.values().iterator());
  }
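
  /*
  // Round-trip sketch, not part of the original source.  Gradient-based training
  // works on the flattened parameter vector, and this pair of calls leaves the
  // model unchanged.
  double[] theta = paramsToVector();   // theta.length == totalParamSize()
  vectorToParams(theta);
  */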

  // TODO: combine this and getClassWForNode?
  public SimpleMatrix getWForNode(Tree node) {
    if (node.children().length == 2) {
      String leftLabel = node.children()[0].value();
      String leftBasic = basicCategory(leftLabel);
      String rightLabel = node.children()[1].value();
      String rightBasic = basicCategory(rightLabel);
      return binaryTransform.get(leftBasic, rightBasic);
    } else if (node.children().length == 1) {
      throw new AssertionError("No unary transform matrices, only unary classification");
    } else {
      throw new AssertionError("Unexpected tree children size of " + node.children().length);
    }
  }

  public SimpleTensor getTensorForNode(Tree node) {
    if (!op.useTensors) {
      throw new AssertionError("Not using tensors");
    }
    if (node.children().length == 2) {
      String leftLabel = node.children()[0].value();
      String leftBasic = basicCategory(leftLabel);
      String rightLabel = node.children()[1].value();
      String rightBasic = basicCategory(rightLabel);
      return binaryTensors.get(leftBasic, rightBasic);
    } else if (node.children().length == 1) {
      throw new AssertionError("No unary transform matrices, only unary classification");
    } else {
      throw new AssertionError("Unexpected tree children size of " + node.children().length);
    }
  }

  public SimpleMatrix getClassWForNode(Tree node) {
    if (op.combineClassification) {
      return unaryClassification.get("");
    } else if (node.children().length == 2) {
      String leftLabel = node.children()[0].value();
      String leftBasic = basicCategory(leftLabel);
      String rightLabel = node.children()[1].value();
      String rightBasic = basicCategory(rightLabel);
      return binaryClassification.get(leftBasic, rightBasic);
    } else if (node.children().length == 1) {
      String unaryLabel = node.children()[0].value();
      String unaryBasic = basicCategory(unaryLabel);
      return unaryClassification.get(unaryBasic);
    } else {
      throw new AssertionError("Unexpected tree children size of " + node.children().length);
    }
  }

  public SimpleMatrix getWordVector(String word) {
    return wordVectors.get(getVocabWord(word));
  }

  public String getVocabWord(String word) {
    if (op.lowercaseWordVectors) {
      word = word.toLowerCase();
    }
    if (wordVectors.containsKey(word)) {
      return word;
    }
    // TODO: go through unknown words here
    return UNKNOWN_WORD;
  }

  public String basicCategory(String category) {
    if (op.simplifiedModel) {
      return "";
    }
    String basic = op.langpack.basicCategory(category);
    if (basic.length() > 0 && basic.charAt(0) == '@') {
      basic = basic.substring(1);
    }
    return basic;
  }

  public SimpleMatrix getUnaryClassification(String category) {
    category = basicCategory(category);
    return unaryClassification.get(category);
  }

  public SimpleMatrix getBinaryClassification(String left, String right) {
    if (op.combineClassification) {
      return unaryClassification.get("");
    } else {
      left = basicCategory(left);
      right = basicCategory(right);
      return binaryClassification.get(left, right);
    }
  }

  public SimpleMatrix getBinaryTransform(String left, String right) {
    left = basicCategory(left);
    right = basicCategory(right);
    return binaryTransform.get(left, right);
  }

  public SimpleTensor getBinaryTensor(String left, String right) {
    left = basicCategory(left);
    right = basicCategory(right);
    return binaryTensors.get(left, right);
  }

  public void saveSerialized(String path) {
    try {
      IOUtils.writeObjectToFile(this, path);
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    }
  }

  public static SentimentModel loadSerialized(String path) {
    try {
      return IOUtils.readObjectFromURLOrClasspathOrFileSystem(path);
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    } catch (ClassNotFoundException e) {
      throw new RuntimeIOException(e);
    }
  }
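
  /*
  // A loading sketch, not part of the original source.  The path below is only
  // illustrative; any serialized SentimentModel reachable from the classpath or
  // file system will do.
  SentimentModel loaded = loadSerialized("edu/stanford/nlp/models/sentiment/sentiment.ser.gz");
  SimpleMatrix vector = loaded.getWordVector("excellent");
  SimpleMatrix wcat = loaded.getUnaryClassification("");
  System.err.println("Loaded model with " + loaded.totalParamSize() + " parameters");
  */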

  public void printParamInformation(int index) {
    int curIndex = 0;
    for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : binaryTransform) {
      if (curIndex <= index && curIndex + entry.getValue().getNumElements() > index) {
        System.err.println("Index " + index + " is element " + (index - curIndex) + " of binaryTransform \"" + entry.getFirstKey() + "," + entry.getSecondKey() + "\"");
        return;
      } else {
        curIndex += entry.getValue().getNumElements();
      }
    }

    for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : binaryClassification) {
      if (curIndex <= index && curIndex + entry.getValue().getNumElements() > index) {
        System.err.println("Index " + index + " is element " + (index - curIndex) + " of binaryClassification \"" + entry.getFirstKey() + "," + entry.getSecondKey() + "\"");
        return;
      } else {
        curIndex += entry.getValue().getNumElements();
      }
    }

    for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : binaryTensors) {
      if (curIndex <= index && curIndex + entry.getValue().getNumElements() > index) {
        System.err.println("Index " + index + " is element " + (index - curIndex) + " of binaryTensor \"" + entry.getFirstKey() + "," + entry.getSecondKey() + "\"");
        return;
      } else {
        curIndex += entry.getValue().getNumElements();
      }
    }

    for (Map.Entry<String, SimpleMatrix> entry : unaryClassification.entrySet()) {
      if (curIndex <= index && curIndex + entry.getValue().getNumElements() > index) {
        System.err.println("Index " + index + " is element " + (index - curIndex) + " of unaryClassification \"" + entry.getKey() + "\"");
        return;
      } else {
        curIndex += entry.getValue().getNumElements();
      }
    }

    for (Map.Entry<String, SimpleMatrix> entry : wordVectors.entrySet()) {
      if (curIndex <= index && curIndex + entry.getValue().getNumElements() > index) {
        System.err.println("Index " + index + " is element " + (index - curIndex) + " of wordVector \"" + entry.getKey() + "\"");
        return;
      } else {
        curIndex += entry.getValue().getNumElements();
      }
    }

    System.err.println("Index " + index + " is beyond the length of the parameters; total parameter space was " + totalParamSize());
  }

  private static final long serialVersionUID = 1;
}