// Package edu.stanford.nlp.sentiment
//
// Source code of edu.stanford.nlp.sentiment.SentimentCostAndGradient

package edu.stanford.nlp.sentiment;

import java.util.List;
import java.util.Map;

import org.ejml.simple.SimpleMatrix;

import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.neural.SimpleTensor;
import edu.stanford.nlp.neural.rnn.RNNCoreAnnotations;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.TwoDimensionalMap;

// TODO: get rid of the word Sentiment everywhere
public class SentimentCostAndGradient extends AbstractCachingDiffFunction {

  private final SentimentModel model;
  private final List<Tree> trainingBatch;

  public SentimentCostAndGradient(SentimentModel model, List<Tree> trainingBatch) {
    this.model = model;
    this.trainingBatch = trainingBatch;
  }

  @Override
  public int domainDimension() {
    // TODO: cache this for speed?
    return model.totalParamSize();
  }

  private static double sumError(Tree tree) {
    if (tree.isLeaf()) {
      return 0.0;
    } else if (tree.isPreTerminal()) {
      return RNNCoreAnnotations.getPredictionError(tree);
    } else {
      double error = 0.0;
      for (Tree child : tree.children()) {
        error += sumError(child);
      }
      return RNNCoreAnnotations.getPredictionError(tree) + error;
    }
  }

  /**
   * Returns the index with the highest value in the <code>predictions</code> matrix.
   * Indexed from 0.
   */
  private static int getPredictedClass(SimpleMatrix predictions) {
    int argmax = 0;
    for (int i = 1; i < predictions.getNumElements(); ++i) {
      if (predictions.get(i) > predictions.get(argmax)) {
        argmax = i;
      }
    }
    return argmax;
  }

  @Override
  public void calculate(double[] theta) {
    model.vectorToParams(theta);

    // double localValue = 0.0;
    // double[] localDerivative = new double[theta.length];

    // We use TreeMap for each of these so that they stay in a
    // canonical sorted order
    // TODO: factor out the initialization routines
    // binaryTD stands for Transform Derivatives (see the SentimentModel)
    TwoDimensionalMap<String, String, SimpleMatrix> binaryTD = TwoDimensionalMap.treeMap();
    // the derivatives of the tensors for the binary nodes
    TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD = TwoDimensionalMap.treeMap();
    // binaryCD stands for Classification Derivatives
    TwoDimensionalMap<String, String, SimpleMatrix> binaryCD = TwoDimensionalMap.treeMap();

    // unaryCD stands for Classification Derivatives
    Map<String, SimpleMatrix> unaryCD = Generics.newTreeMap();

    // word vector derivatives
    Map<String, SimpleMatrix> wordVectorD = Generics.newTreeMap();

    for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : model.binaryTransform) {
      int numRows = entry.getValue().numRows();
      int numCols = entry.getValue().numCols();

      binaryTD.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols));
    }

    if (!model.op.combineClassification) {
      for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : model.binaryClassification) {
        int numRows = entry.getValue().numRows();
        int numCols = entry.getValue().numCols();

        binaryCD.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols));
      }
    }

    if (model.op.useTensors) {
      for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : model.binaryTensors) {
        int numRows = entry.getValue().numRows();
        int numCols = entry.getValue().numCols();
        int numSlices = entry.getValue().numSlices();

        binaryTensorTD.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleTensor(numRows, numCols, numSlices));
      }
    }

    for (Map.Entry<String, SimpleMatrix> entry : model.unaryClassification.entrySet()) {
      int numRows = entry.getValue().numRows();
      int numCols = entry.getValue().numCols();
      unaryCD.put(entry.getKey(), new SimpleMatrix(numRows, numCols));
    }
    for (Map.Entry<String, SimpleMatrix> entry : model.wordVectors.entrySet()) {
      int numRows = entry.getValue().numRows();
      int numCols = entry.getValue().numCols();
      wordVectorD.put(entry.getKey(), new SimpleMatrix(numRows, numCols));
    }

    // TODO: This part can easily be parallelized
    List<Tree> forwardPropTrees = Generics.newArrayList();
    for (Tree tree : trainingBatch) {
      Tree trainingTree = tree.deepCopy();
      // this will attach the error vectors and the node vectors
      // to each node in the tree
      forwardPropagateTree(trainingTree);
      forwardPropTrees.add(trainingTree);
    }

    // TODO: we may find a big speedup by separating the derivatives and then summing
    double error = 0.0;
    for (Tree tree : forwardPropTrees) {
      backpropDerivativesAndError(tree, binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD);
      error += sumError(tree);
    }

    // scale the error by the number of sentences so that the
    // regularization isn't drowned out for large training batchs
    double scale = (1.0 / trainingBatch.size());
    value = error * scale;

    value += scaleAndRegularize(binaryTD, model.binaryTransform, scale, model.op.trainOptions.regTransformMatrix);
    value += scaleAndRegularize(binaryCD, model.binaryClassification, scale, model.op.trainOptions.regClassification);
    value += scaleAndRegularizeTensor(binaryTensorTD, model.binaryTensors, scale, model.op.trainOptions.regTransformTensor);
    value += scaleAndRegularize(unaryCD, model.unaryClassification, scale, model.op.trainOptions.regClassification);
    value += scaleAndRegularize(wordVectorD, model.wordVectors, scale, model.op.trainOptions.regWordVector);

    derivative = NeuralUtils.paramsToVector(theta.length, binaryTD.valueIterator(), binaryCD.valueIterator(), SimpleTensor.iteratorSimpleMatrix(binaryTensorTD.valueIterator()), unaryCD.values().iterator(), wordVectorD.values().iterator());
  }

  static double scaleAndRegularize(TwoDimensionalMap<String, String, SimpleMatrix> derivatives,
                            TwoDimensionalMap<String, String, SimpleMatrix> currentMatrices,
                            double scale,
                            double regCost) {
    double cost = 0.0; // the regularization cost
    for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : currentMatrices) {
      SimpleMatrix D = derivatives.get(entry.getFirstKey(), entry.getSecondKey());
      D = D.scale(scale).plus(entry.getValue().scale(regCost));
      derivatives.put(entry.getFirstKey(), entry.getSecondKey(), D);
      cost += entry.getValue().elementMult(entry.getValue()).elementSum() * regCost / 2.0;
    }
    return cost;
  }

  static double scaleAndRegularize(Map<String, SimpleMatrix> derivatives,
                            Map<String, SimpleMatrix> currentMatrices,
                            double scale,
                            double regCost) {
    double cost = 0.0; // the regularization cost
    for (Map.Entry<String, SimpleMatrix> entry : currentMatrices.entrySet()) {
      SimpleMatrix D = derivatives.get(entry.getKey());
      D = D.scale(scale).plus(entry.getValue().scale(regCost));
      derivatives.put(entry.getKey(), D);
      cost += entry.getValue().elementMult(entry.getValue()).elementSum() * regCost / 2.0;
    }
    return cost;
  }

  static double scaleAndRegularizeTensor(TwoDimensionalMap<String, String, SimpleTensor> derivatives,
                                  TwoDimensionalMap<String, String, SimpleTensor> currentMatrices,
                                  double scale,
                                  double regCost) {
    double cost = 0.0; // the regularization cost
    for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : currentMatrices) {
      SimpleTensor D = derivatives.get(entry.getFirstKey(), entry.getSecondKey());
      D = D.scale(scale).plus(entry.getValue().scale(regCost));
      derivatives.put(entry.getFirstKey(), entry.getSecondKey(), D);
      cost += entry.getValue().elementMult(entry.getValue()).elementSum() * regCost / 2.0;
    }
    return cost;
  }

  private void backpropDerivativesAndError(Tree tree,
                                           TwoDimensionalMap<String, String, SimpleMatrix> binaryTD,
                                           TwoDimensionalMap<String, String, SimpleMatrix> binaryCD,
                                           TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD,
                                           Map<String, SimpleMatrix> unaryCD,
                                           Map<String, SimpleMatrix> wordVectorD) {
    SimpleMatrix delta = new SimpleMatrix(model.op.numHid, 1);
    backpropDerivativesAndError(tree, binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD, delta);
  }

  private void backpropDerivativesAndError(Tree tree,
                                           TwoDimensionalMap<String, String, SimpleMatrix> binaryTD,
                                           TwoDimensionalMap<String, String, SimpleMatrix> binaryCD,
                                           TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD,
                                           Map<String, SimpleMatrix> unaryCD,
                                           Map<String, SimpleMatrix> wordVectorD,
                                           SimpleMatrix deltaUp) {
    if (tree.isLeaf()) {
      return;
    }

    SimpleMatrix currentVector = RNNCoreAnnotations.getNodeVector(tree);
    String category = tree.label().value();
    category = model.basicCategory(category);

    // Build a vector that looks like 0,0,1,0,0 with an indicator for the correct class
    SimpleMatrix goldLabel = new SimpleMatrix(model.numClasses, 1);
    int goldClass = RNNCoreAnnotations.getGoldClass(tree);
    if (goldClass >= 0) {
      goldLabel.set(goldClass, 1.0);
    }

    double nodeWeight = model.op.trainOptions.getClassWeight(goldClass);

    SimpleMatrix predictions = RNNCoreAnnotations.getPredictions(tree);

    // If this is an unlabeled class, set deltaClass to 0.  We could
    // make this more efficient by eliminating various of the below
    // calculations, but this would be the easiest way to handle the
    // unlabeled class
    SimpleMatrix deltaClass = goldClass >= 0 ? predictions.minus(goldLabel).scale(nodeWeight) : new SimpleMatrix(predictions.numRows(), predictions.numCols());
    SimpleMatrix localCD = deltaClass.mult(NeuralUtils.concatenateWithBias(currentVector).transpose());

    double error = -(NeuralUtils.elementwiseApplyLog(predictions).elementMult(goldLabel).elementSum());
    error = error * nodeWeight;
    RNNCoreAnnotations.setPredictionError(tree, error);

    if (tree.isPreTerminal()) { // below us is a word vector
      unaryCD.put(category, unaryCD.get(category).plus(localCD));

      String word = tree.children()[0].label().value();
      word = model.getVocabWord(word);

      //SimpleMatrix currentVectorDerivative = NeuralUtils.elementwiseApplyTanhDerivative(currentVector);
      //SimpleMatrix deltaFromClass = model.getUnaryClassification(category).transpose().mult(deltaClass);
      //SimpleMatrix deltaFull = deltaFromClass.extractMatrix(0, model.op.numHid, 0, 1).plus(deltaUp);
      //SimpleMatrix wordDerivative = deltaFull.elementMult(currentVectorDerivative);
      //wordVectorD.put(word, wordVectorD.get(word).plus(wordDerivative));

      SimpleMatrix currentVectorDerivative = NeuralUtils.elementwiseApplyTanhDerivative(currentVector);
      SimpleMatrix deltaFromClass = model.getUnaryClassification(category).transpose().mult(deltaClass);
      deltaFromClass = deltaFromClass.extractMatrix(0, model.op.numHid, 0, 1).elementMult(currentVectorDerivative);
      SimpleMatrix deltaFull = deltaFromClass.plus(deltaUp);
      wordVectorD.put(word, wordVectorD.get(word).plus(deltaFull));
    } else {
      // Otherwise, this must be a binary node
      String leftCategory = model.basicCategory(tree.children()[0].label().value());
      String rightCategory = model.basicCategory(tree.children()[1].label().value());
      if (model.op.combineClassification) {
        unaryCD.put("", unaryCD.get("").plus(localCD));
      } else {
        binaryCD.put(leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).plus(localCD));
      }

      SimpleMatrix currentVectorDerivative = NeuralUtils.elementwiseApplyTanhDerivative(currentVector);
      SimpleMatrix deltaFromClass = model.getBinaryClassification(leftCategory, rightCategory).transpose().mult(deltaClass);
      deltaFromClass = deltaFromClass.extractMatrix(0, model.op.numHid, 0, 1).elementMult(currentVectorDerivative);
      SimpleMatrix deltaFull = deltaFromClass.plus(deltaUp);

      SimpleMatrix leftVector = RNNCoreAnnotations.getNodeVector(tree.children()[0]);
      SimpleMatrix rightVector = RNNCoreAnnotations.getNodeVector(tree.children()[1]);
      SimpleMatrix childrenVector = NeuralUtils.concatenateWithBias(leftVector, rightVector);
      SimpleMatrix W_df = deltaFull.mult(childrenVector.transpose());
      binaryTD.put(leftCategory, rightCategory, binaryTD.get(leftCategory, rightCategory).plus(W_df));
      SimpleMatrix deltaDown;
      if (model.op.useTensors) {
        SimpleTensor Wt_df = getTensorGradient(deltaFull, leftVector, rightVector);
        binaryTensorTD.put(leftCategory, rightCategory, binaryTensorTD.get(leftCategory, rightCategory).plus(Wt_df));
        deltaDown = computeTensorDeltaDown(deltaFull, leftVector, rightVector, model.getBinaryTransform(leftCategory, rightCategory), model.getBinaryTensor(leftCategory, rightCategory));
      } else {
        deltaDown = model.getBinaryTransform(leftCategory, rightCategory).transpose().mult(deltaFull);
      }

      SimpleMatrix leftDerivative = NeuralUtils.elementwiseApplyTanhDerivative(leftVector);
      SimpleMatrix rightDerivative = NeuralUtils.elementwiseApplyTanhDerivative(rightVector);
      SimpleMatrix leftDeltaDown = deltaDown.extractMatrix(0, deltaFull.numRows(), 0, 1);
      SimpleMatrix rightDeltaDown = deltaDown.extractMatrix(deltaFull.numRows(), deltaFull.numRows() * 2, 0, 1);
      backpropDerivativesAndError(tree.children()[0], binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD, leftDerivative.elementMult(leftDeltaDown));
      backpropDerivativesAndError(tree.children()[1], binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD, rightDerivative.elementMult(rightDeltaDown));
    }
  }

  private static SimpleMatrix computeTensorDeltaDown(SimpleMatrix deltaFull, SimpleMatrix leftVector, SimpleMatrix rightVector,
                                              SimpleMatrix W, SimpleTensor Wt) {
    SimpleMatrix WTDelta = W.transpose().mult(deltaFull);
    SimpleMatrix WTDeltaNoBias = WTDelta.extractMatrix(0, deltaFull.numRows() * 2, 0, 1);
    int size = deltaFull.getNumElements();
    SimpleMatrix deltaTensor = new SimpleMatrix(size*2, 1);
    SimpleMatrix fullVector = NeuralUtils.concatenate(leftVector, rightVector);
    for (int slice = 0; slice < size; ++slice) {
      SimpleMatrix scaledFullVector = fullVector.scale(deltaFull.get(slice));
      deltaTensor = deltaTensor.plus(Wt.getSlice(slice).plus(Wt.getSlice(slice).transpose()).mult(scaledFullVector));
    }
    return deltaTensor.plus(WTDeltaNoBias);
  }

  private static SimpleTensor getTensorGradient(SimpleMatrix deltaFull, SimpleMatrix leftVector, SimpleMatrix rightVector) {
    int size = deltaFull.getNumElements();
    SimpleTensor Wt_df = new SimpleTensor(size*2, size*2, size);
    // TODO: combine this concatenation with computeTensorDeltaDown?
    SimpleMatrix fullVector = NeuralUtils.concatenate(leftVector, rightVector);
    for (int slice = 0; slice < size; ++slice) {
      Wt_df.setSlice(slice, fullVector.scale(deltaFull.get(slice)).mult(fullVector.transpose()));
    }
    return Wt_df;
  }

  /**
   * This is the method to call for assigning labels and node vectors
   * to the Tree.  After calling this, each of the non-leaf nodes will
   * have the node vector and the predictions of their classes
   * assigned to that subtree's node.  The annotations filled in are
   * the RNNCoreAnnotations.NodeVector, Predictions, and
   * PredictedClass.  In general, PredictedClass will be the most
   * useful annotation except when training.
   */
  public void forwardPropagateTree(Tree tree) {
    SimpleMatrix nodeVector = null;
    SimpleMatrix classification = null;

    if (tree.isLeaf()) {
      // We do nothing for the leaves.  The preterminals will
      // calculate the classification for this word/tag.  In fact, the
      // recursion should not have gotten here (unless there are
      // degenerate trees of just one leaf)
      throw new AssertionError("We should not have reached leaves in forwardPropagate");
    } else if (tree.isPreTerminal()) {
      classification = model.getUnaryClassification(tree.label().value());
      String word = tree.children()[0].label().value();
      SimpleMatrix wordVector = model.getWordVector(word);
      nodeVector = NeuralUtils.elementwiseApplyTanh(wordVector);
    } else if (tree.children().length == 1) {
      throw new AssertionError("Non-preterminal nodes of size 1 should have already been collapsed");
    } else if (tree.children().length == 2) {
      forwardPropagateTree(tree.children()[0]);
      forwardPropagateTree(tree.children()[1]);

      String leftCategory = tree.children()[0].label().value();
      String rightCategory = tree.children()[1].label().value();
      SimpleMatrix W = model.getBinaryTransform(leftCategory, rightCategory);
      classification = model.getBinaryClassification(leftCategory, rightCategory);

      SimpleMatrix leftVector = RNNCoreAnnotations.getNodeVector(tree.children()[0]);
      SimpleMatrix rightVector = RNNCoreAnnotations.getNodeVector(tree.children()[1]);
      SimpleMatrix childrenVector = NeuralUtils.concatenateWithBias(leftVector, rightVector);
      if (model.op.useTensors) {
        SimpleTensor tensor = model.getBinaryTensor(leftCategory, rightCategory);
        SimpleMatrix tensorIn = NeuralUtils.concatenate(leftVector, rightVector);
        SimpleMatrix tensorOut = tensor.bilinearProducts(tensorIn);
        nodeVector = NeuralUtils.elementwiseApplyTanh(W.mult(childrenVector).plus(tensorOut));
      } else {
        nodeVector = NeuralUtils.elementwiseApplyTanh(W.mult(childrenVector));
      }
    } else {
      throw new AssertionError("Tree not correctly binarized");
    }

    SimpleMatrix predictions = NeuralUtils.softmax(classification.mult(NeuralUtils.concatenateWithBias(nodeVector)));

    int index = getPredictedClass(predictions);
    if (!(tree.label() instanceof CoreLabel)) {
      throw new AssertionError("Expected CoreLabels in the nodes");
    }
    CoreLabel label = (CoreLabel) tree.label();
    label.set(RNNCoreAnnotations.Predictions.class, predictions);
    label.set(RNNCoreAnnotations.PredictedClass.class, index);
    label.set(RNNCoreAnnotations.NodeVector.class, nodeVector);
  } // end forwardPropagateTree

}
// TOP
//
// Related classes of edu.stanford.nlp.sentiment.SentimentCostAndGradient
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.