Package edu.stanford.nlp.sentiment

Source Code of edu.stanford.nlp.sentiment.ConvertMatlabModel

package edu.stanford.nlp.sentiment;

import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Random;

import org.ejml.simple.SimpleMatrix;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.neural.SimpleTensor;
import edu.stanford.nlp.util.Generics;

/**
* This tool is of very limited scope: it converts a model built with
* the Matlab version of the code to the Java version of the code.  It
* is useful to save this tool in case the format of the Java model
* changes, in which case this will let us easily recreate it.
* <br>
* Another set of matrices is in <br>
* /u/nlp/data/sentiment/binary/model_binary_best_asTextFiles/
*
* @author John Bauer
*/
public class ConvertMatlabModel {

  private ConvertMatlabModel() {} // static class

  /** Will not overwrite an existing word vector if it is already there */
  public static void copyWordVector(Map<String, SimpleMatrix> wordVectors, String source, String target) {
    if (wordVectors.containsKey(target) || !wordVectors.containsKey(source)) {
      return;
    }

    System.err.println("Using wordVector " + source + " for " + target);

    wordVectors.put(target, new SimpleMatrix(wordVectors.get(source)));
  }

  /** <br>Will</br> overwrite an existing word vector */
  public static void replaceWordVector(Map<String, SimpleMatrix> wordVectors, String source, String target) {
    if (!wordVectors.containsKey(source)) {
      return;
    }

    wordVectors.put(target, new SimpleMatrix(wordVectors.get(source)));
  }

  public static SimpleMatrix loadMatrix(String binaryName, String textName) throws IOException {
    File matrixFile = new File(binaryName);
    if (matrixFile.exists()) {
      return SimpleMatrix.loadBinary(matrixFile.getPath());
    }

    matrixFile = new File(textName);
    if (matrixFile.exists()) {
      return NeuralUtils.loadTextMatrix(matrixFile);
    }

    throw new RuntimeException("Could not find either " + binaryName + " or " + textName);
  }

  public static void main(String[] args) throws IOException {
    String basePath = "/user/socherr/scr/projects/semComp/RNTN/src/params/";
    int numSlices = 25;

    boolean useEscapedParens = false;

    for (int argIndex = 0; argIndex < args.length; ) {
      if (args[argIndex].equalsIgnoreCase("-slices")) {
        numSlices = Integer.parseInt(args[argIndex + 1]);
        argIndex += 2;
      } else if (args[argIndex].equalsIgnoreCase("-path")) {
        basePath = args[argIndex + 1];
        argIndex += 2;
      } else if (args[argIndex].equalsIgnoreCase("-useEscapedParens")) {
        useEscapedParens = true;
        argIndex += 1;
      } else {
        System.err.println("Unknown argument " + args[argIndex]);
        System.exit(2);
      }
    }

    SimpleMatrix[] slices = new SimpleMatrix[numSlices];
    for (int i = 0; i < numSlices; ++i) {
      slices[i] = loadMatrix(basePath + "bin/Wt_" + (i + 1) + ".bin", basePath + "Wt_" + (i + 1) + ".txt");
    }
    SimpleTensor tensor = new SimpleTensor(slices);
    System.err.println("W tensor size: " + tensor.numRows() + "x" + tensor.numCols() + "x" + tensor.numSlices());

    SimpleMatrix W = loadMatrix(basePath + "bin/W.bin", basePath + "W.txt");
    System.err.println("W matrix size: " + W.numRows() + "x" + W.numCols());

    SimpleMatrix Wcat = loadMatrix(basePath + "bin/Wcat.bin", basePath + "Wcat.txt");
    System.err.println("W cat size: " + Wcat.numRows() + "x" + Wcat.numCols());

    SimpleMatrix combinedWV = loadMatrix(basePath + "bin/Wv.bin", basePath + "Wv.txt");
    System.err.println("Word matrix size: " + combinedWV.numRows() + "x" + combinedWV.numCols());

    File vocabFile = new File(basePath + "vocab_1.txt");
    if (!vocabFile.exists()) {
      vocabFile = new File(basePath + "words.txt");
    }
    List<String> lines = Generics.newArrayList();
    for (String line : IOUtils.readLines(vocabFile)) {
      lines.add(line.trim());
    }

    System.err.println("Lines in vocab file: " + lines.size());

    Map<String, SimpleMatrix> wordVectors = Generics.newTreeMap();

    for (int i = 0; i < lines.size() && i < combinedWV.numCols(); ++i) {
      String[] pieces = lines.get(i).split(" +");
      if (pieces.length == 0 || pieces.length > 1) {
        continue;
      }
      wordVectors.put(pieces[0], combinedWV.extractMatrix(0, numSlices, i, i+1));
      if (pieces[0].equals("UNK")) {
        wordVectors.put(SentimentModel.UNKNOWN_WORD, wordVectors.get("UNK"));
      }
    }

    // If there is no ",", we first try to look for an HTML escaping,
    // then fall back to "." as better than just a random word vector.
    // Same for "``" and ";"
    copyWordVector(wordVectors, "&#44", ",");
    copyWordVector(wordVectors, ".", ",");
    copyWordVector(wordVectors, "&#59", ";");
    copyWordVector(wordVectors, ".", ";");
    copyWordVector(wordVectors, "&#96&#96", "``");
    copyWordVector(wordVectors, "''", "``");

    if (useEscapedParens) {
      replaceWordVector(wordVectors, "(", "-LRB-");
      replaceWordVector(wordVectors, ")", "-RRB-");
    }

    RNNOptions op = new RNNOptions();
    op.numHid = numSlices;
    op.lowercaseWordVectors = false;

    if (Wcat.numRows() == 2) {
      op.classNames = new String[] { "Negative", "Positive" };
      op.equivalenceClasses = new int[][] { { 0 }, { 1 } }; // TODO: set to null once old models are updated
      op.numClasses = 2;
    }

    if (!wordVectors.containsKey(SentimentModel.UNKNOWN_WORD)) {
      wordVectors.put(SentimentModel.UNKNOWN_WORD, SimpleMatrix.random(numSlices, 1, -0.00001, 0.00001, new Random()));
    }

    SentimentModel model = SentimentModel.modelFromMatrices(W, Wcat, tensor, wordVectors, op);
    model.saveSerialized("matlab.ser.gz");
  }
}
TOP

Related Classes of edu.stanford.nlp.sentiment.ConvertMatlabModel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.