Package edu.stanford.nlp.parser.lexparser

Source Code of edu.stanford.nlp.parser.lexparser.SplittingGrammarExtractor

package edu.stanford.nlp.parser.lexparser;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.MapFactory;
import edu.stanford.nlp.util.MutableDouble;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ThreeDimensionalMap;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.TwoDimensionalMap;


import java.io.*;

/**
* This class is a reimplementation of Berkeley's state splitting
* grammar.  This work is experimental and still in progress.  There
* are several extremely important pieces to implement:
* <ol>
* <li> this code should use log probabilities throughout instead of
*      multiplying tiny numbers
* <li> the time efficiency of the training code is awful
* <li> there are better ways to extract parses using this grammar than
*      the method in ExhaustivePCFGParser
* <li> we should also implement cascading parsers that let us
*      short-circuit low-quality parses earlier (which could possibly
*      benefit non-split parsers as well)
* <li> when looping, we should short-circuit after too many iterations
* <li> ought to smooth as per page 436
* </ol>
*
* @author John Bauer
*/
public class SplittingGrammarExtractor {
  /**
   * Debug output is printed for iterations in the half-open range
   * [MIN_DEBUG_ITERATION, MAX_DEBUG_ITERATION); with both set to 0 it
   * is disabled.
   */
  static final int MIN_DEBUG_ITERATION=0;
  static final int MAX_DEBUG_ITERATION=0;
  /**
   * Cap on training iterations.  NOTE(review): not referenced in this
   * part of the file; presumably consulted by the training loop.
   */
  static final int MAX_ITERATIONS = Integer.MAX_VALUE;

  /** Current training iteration; consulted by DEBUG(). */
  int iteration = 0;

  boolean DEBUG() {
    return (iteration >= MIN_DEBUG_ITERATION && iteration < MAX_DEBUG_ITERATION);
  }

  /**
   * Parser options.  Supplies the language pack (for the start symbols)
   * and the factory used to build every lexicon (op.tlpParams.lex).
   */
  Options op;
  /**
   * These objects are created and filled in here.  The caller can get
   * the data from the extractor once it is finished.  wordIndex and
   * tagIndex are replaced with fresh indices on every training pass.
   * NOTE(review): stateIndex is not assigned in the visible part of
   * this file.
   */
  Index<String> stateIndex;
  Index<String> wordIndex;
  Index<String> tagIndex;
  /**
   * The start symbols from op.langpack().  These states are never split.
   */
  List<String> startSymbols;

  /**
   * A combined list of all the trees in the training set.
   */
  List<Tree> trees = new ArrayList<Tree>();

  /**
   * All of the weights associated with the trees in the training set.
   * In general, this is just the weight of the original treebank.
   * Note that this uses an identity hash map to map from tree pointer
   * to weight.
   */
  Counter<Tree> treeWeights = new ClassicCounter<Tree>(MapFactory.<Tree,MutableDouble>identityHashMapFactory());

  /**
   * The total weight of the training trees.  It is passed to
   * Lexicon.initializeTraining for every lexicon that is built.
   */
  double trainSize;

  /**
   * The original (unsplit) internal labels found in the trees.
   */
  Set<String> originalStates = Generics.newHashSet();

  /**
   * The current number of substates for each label.  Each original state
   * starts at 1 (see countOriginalStates), and splitStateCounts doubles
   * it, except for start symbols and the boundary tag, which stay at 1.
   */
  IntCounter<String> stateSplitCounts = new IntCounter<String>();

  /**
   * The binary betas are weights to go from Ax to By, Cz.  This maps
   * from (A, B, C) to (x, y, z) to beta(Ax, By, Cz).  The weights are
   * stored as logs.
   */
  ThreeDimensionalMap<String, String, String, double[][][]> binaryBetas = new ThreeDimensionalMap<String, String, String, double[][][]>();
  /**
   * The unary betas are weights to go from Ax to By.  This maps
   * from (A, B) to (x, y) to beta(Ax, By).  The weights are stored as
   * logs.
   */
  TwoDimensionalMap<String, String, double[][]> unaryBetas = new TwoDimensionalMap<String, String, double[][]>();

  /**
   * The latest lexicon we trained.  At the end of the process, this
   * is the lexicon for the parser.
   */
  Lexicon lex;

  // Word and tag indices backing tempLex.  useNewBetas installs them as
  // wordIndex/tagIndex and then clears them.
  transient Index<String> tempWordIndex;
  transient Index<String> tempTagIndex;

  /**
   * The lexicon we are in the process of building in each iteration.
   */
  transient Lexicon tempLex;

  /**
   * The latest pair of unary and binary grammars we trained.
   * NOTE(review): not assigned in the visible part of this file.
   */
  Pair<UnaryGrammar, BinaryGrammar> bgug;

  // The fixed seed makes the random perturbations in splitBetas reproducible.
  Random random = new Random(87543875943265L);

  // Fraction of each preterminal's total weight that is spread evenly over
  // its substates when training the lexicon.
  static final double LEX_SMOOTH = 0.0001;
  // NOTE(review): not referenced in the visible part of this file.
  static final double STATE_SMOOTH = 0.0;

  public SplittingGrammarExtractor(Options op) {
    this.op = op;
    startSymbols = Arrays.asList(op.langpack().startSymbols());
  }

  double[] neginfDoubles(int size) {
    double[] result = new double[size];
    for (int i = 0; i < size; ++i) {
      result[i] = Double.NEGATIVE_INFINITY;
    }
    return result;
  }

  public void outputTransitions(Tree tree,
                                IdentityHashMap<Tree, double[][]> unaryTransitions,
                                IdentityHashMap<Tree, double[][][]> binaryTransitions) {
    outputTransitions(tree, 0, unaryTransitions, binaryTransitions);
  }

  public void outputTransitions(Tree tree, int depth,
                                IdentityHashMap<Tree, double[][]> unaryTransitions,
                                IdentityHashMap<Tree, double[][][]> binaryTransitions) {
    for (int i = 0; i < depth; ++i) {
      System.out.print(" ");
    }
    if (tree.isLeaf()) {
      System.out.println(tree.label().value());
      return;
    }
    if (tree.children().length == 1) {
      System.out.println(tree.label().value() + " -> " + tree.children()[0].label().value());
      if (!tree.isPreTerminal()) {
        double[][] transitions = unaryTransitions.get(tree);
        for (int i = 0; i < transitions.length; ++i) {
          for (int j = 0; j < transitions[0].length; ++j) {
            for (int z = 0; z < depth; ++z) {
              System.out.print(" ");
            }
            System.out.println("  " + i + "," + j + ": " + transitions[i][j] + " | " + Math.exp(transitions[i][j]));
          }
        }
      }
    } else {
      System.out.println(tree.label().value() + " -> " + tree.children()[0].label().value() + " " + tree.children()[1].label().value());
      double[][][] transitions = binaryTransitions.get(tree);
      for (int i = 0; i < transitions.length; ++i) {
        for (int j = 0; j < transitions[0].length; ++j) {
          for (int k = 0; k < transitions[0][0].length; ++k) {
            for (int z = 0; z < depth; ++z) {
              System.out.print(" ");
            }
            System.out.println("  " + i + "," + j + "," + k + ": " + transitions[i][j][k] + " | " + Math.exp(transitions[i][j][k]));
          }
        }
      }
    }
    if (tree.isPreTerminal()) {
      return;
    }
    for (Tree child : tree.children()) {
      outputTransitions(child, depth + 1, unaryTransitions, binaryTransitions);
    }
  }

  public void outputBetas() {
    System.out.println("UNARY:");
    for (String parent : unaryBetas.firstKeySet()) {
      for (String child : unaryBetas.get(parent).keySet()) {
        System.out.println("  " + parent + "->" + child);
        double[][] betas = unaryBetas.get(parent).get(child);
        int parentStates = betas.length;
        int childStates = betas[0].length;
        for (int i = 0; i < parentStates; ++i) {
          for (int j = 0; j < childStates; ++j) {
            System.out.println("    " + i + "->" + j + " " + betas[i][j] + " | " + Math.exp(betas[i][j]));
          }
        }
      }
    }
    System.out.println("BINARY:");
    for (String parent : binaryBetas.firstKeySet()) {
      for (String left : binaryBetas.get(parent).firstKeySet()) {
        for (String right : binaryBetas.get(parent).get(left).keySet()) {
          System.out.println("  " + parent + "->" + left + "," + right);
          double[][][] betas = binaryBetas.get(parent).get(left).get(right);
          int parentStates = betas.length;
          int leftStates = betas[0].length;
          int rightStates = betas[0][0].length;
          for (int i = 0; i < parentStates; ++i) {
            for (int j = 0; j < leftStates; ++j) {
              for (int k = 0; k < rightStates; ++k) {
                System.out.println("    " + i + "->" + j + "," + k + " " + betas[i][j][k] + " | " + Math.exp(betas[i][j][k]));
              }
            }
          }
        }
      }
    }
  }

  public String state(String tag, int i) {
    if (startSymbols.contains(tag) || tag.equals(Lexicon.BOUNDARY_TAG)) {
      return tag;
    }
    return tag + "^" + i;
  }

  public int getStateSplitCount(Tree tree) {
    return stateSplitCounts.getIntCount(tree.label().value());
  }

  public int getStateSplitCount(String label) {
    return stateSplitCounts.getIntCount(label);
  }


  /**
   * Count all the internal labels in all the trees, and set their
   * initial state counts to 1.
   */
  public void countOriginalStates() {
    originalStates.clear();
    for (Tree tree : trees) {
      countOriginalStates(tree);
    }

    for (String state : originalStates) {
      stateSplitCounts.incrementCount(state, 1);
    }
  }

  /**
   * Counts the labels in the tree, but not the words themselves.
   */
  private void countOriginalStates(Tree tree) {
    if (tree.isLeaf()) {
      return;
    }

    originalStates.add(tree.label().value());
    for (Tree child : tree.children()) {
      if (child.isLeaf())
        continue;
      countOriginalStates(child);
    }
  }

  private void initialBetasAndLexicon() {
    wordIndex = new HashIndex<String>();
    tagIndex = new HashIndex<String>();
    lex = op.tlpParams.lex(op, wordIndex, tagIndex);
    lex.initializeTraining(trainSize);

    for (Tree tree : trees) {
      double weight = treeWeights.getCount(tree);
      lex.incrementTreesRead(weight);
      initialBetasAndLexicon(tree, 0, weight);
    }

    lex.finishTraining();
  }

  private int initialBetasAndLexicon(Tree tree, int position, double weight) {
    if (tree.isLeaf()) {
      // should never get here, unless a training tree is just one leaf
      return position;
    }

    if (tree.isPreTerminal()) {
      // fill in initial lexicon here
      String tag = tree.label().value();
      String word = tree.children()[0].label().value();
      TaggedWord tw = new TaggedWord(word, state(tag, 0));
      lex.train(tw, position, weight);
      return (position + 1);
    }

    if (tree.children().length == 2) {
      String label = tree.label().value();
      String leftLabel = tree.getChild(0).label().value();
      String rightLabel = tree.getChild(1).label().value();
      if (!binaryBetas.contains(label, leftLabel, rightLabel)) {
        double[][][] map = new double[1][1][1];
        map[0][0][0] = 0.0;
        binaryBetas.put(label, leftLabel, rightLabel, map);
      }
    } else if (tree.children().length == 1) {
      String label = tree.label().value();
      String childLabel = tree.getChild(0).label().value();
      if (!unaryBetas.contains(label, childLabel)) {
        double[][] map = new double[1][1];
        map[0][0] = 0.0;
        unaryBetas.put(label, childLabel, map);
      }
    } else {
      // should have been binarized
      throw new RuntimeException("Trees should have been binarized, expected 1 or 2 children");
    }

    for (Tree child : tree.children()) {
      position = initialBetasAndLexicon(child, position, weight);
    }
    return position;
  }


  /**
   * Splits the state counts.  Root states and the boundary tag do not
   * get their counts increased, and all others are doubled.  Betas
   * and transition weights are handled later.
   */
  private void splitStateCounts() {
    // double the count of states...
    IntCounter<String> newStateSplitCounts = new IntCounter<String>();
    newStateSplitCounts.addAll(stateSplitCounts);
    newStateSplitCounts.addAll(stateSplitCounts);

    // root states should only have 1
    for (String root : startSymbols) {
      if (newStateSplitCounts.getCount(root) > 1) {
        newStateSplitCounts.setCount(root, 1);
      }
    }

    if (newStateSplitCounts.getCount(Lexicon.BOUNDARY_TAG) > 1) {
      newStateSplitCounts.setCount(Lexicon.BOUNDARY_TAG, 1);
    }

    stateSplitCounts = newStateSplitCounts;
  }


  /**
   * Convergence tolerance for testConvergence.  The betas count as
   * converged when every new log-space beta differs from the old one by
   * no more than this amount.
   */
  static final double EPSILON = 0.0001;

  /**
   * Doubles the substates in every beta table.  Before each round of
   * splitting we have tables of log-space betas, which hold the weights
   * of the transitions between substates.  When we resplit the states,
   * each parent substate is duplicated, and both copies keep the old
   * betas.  Each child substate is split in two, and the old weight is
   * divided between the halves in the proportion w : 1 - w, with w drawn
   * uniformly from [0.45, 0.55).
   * <br>
   * Start symbols are never duplicated as parents.  The boundary tag is
   * never split as a unary or right child.  Left children are always
   * split.  The new tables replace unaryBetas and binaryBetas.
   * <br>
   * Every child split draws from {@link #random}, so for a given seed the
   * order of the loops below determines the result.
   */
  public void splitBetas() {
    TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<String, String, double[][]>();
    ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<String, String, String, double[][][]>();

    for (String parent : unaryBetas.firstKeySet()) {
      for (String child : unaryBetas.get(parent).keySet()) {
        double[][] betas = unaryBetas.get(parent, child);
        int parentStates = betas.length;
        int childStates = betas[0].length;

        double[][] newBetas;
        // Duplicate each parent substate, unless the parent is a start symbol.
        if (!startSymbols.contains(parent)) {
          newBetas = new double[parentStates * 2][childStates];
          for (int i = 0; i < parentStates; ++i) {
            for (int j = 0; j < childStates; ++j) {
              newBetas[i * 2][j] = betas[i][j];
              newBetas[i * 2 + 1][j] = betas[i][j];
            }
          }
          parentStates *= 2;
          betas = newBetas;
        }
        // Split each child substate.  Adding log(w) to a log-space beta
        // multiplies its weight by w.
        if (!child.equals(Lexicon.BOUNDARY_TAG)) {
          newBetas = new double[parentStates][childStates * 2];
          for (int i = 0; i < parentStates; ++i) {
            for (int j = 0; j < childStates; ++j) {
              double childWeight = 0.45 + random.nextDouble() * 0.1;
              newBetas[i][j * 2] = betas[i][j] + Math.log(childWeight);
              newBetas[i][j * 2 + 1] = betas[i][j] + Math.log(1.0 - childWeight);
            }
          }
          betas = newBetas;
        }
        tempUnaryBetas.put(parent, child, betas);
      }
    }

    for (String parent : binaryBetas.firstKeySet()) {
      for (String left : binaryBetas.get(parent).firstKeySet()) {
        for (String right : binaryBetas.get(parent).get(left).keySet()) {
          double[][][] betas = binaryBetas.get(parent, left, right);
          int parentStates = betas.length;
          int leftStates = betas[0].length;
          int rightStates = betas[0][0].length;

          double[][][] newBetas;
          // Duplicate each parent substate, unless the parent is a start symbol.
          if (!startSymbols.contains(parent)) {
            newBetas = new double[parentStates * 2][leftStates][rightStates];
            for (int i = 0; i < parentStates; ++i) {
              for (int j = 0; j < leftStates; ++j) {
                for (int k = 0; k < rightStates; ++k) {
                  newBetas[i * 2][j][k] = betas[i][j][k];
                  newBetas[i * 2 + 1][j][k] = betas[i][j][k];
                }
              }
            }
            parentStates *= 2;
            betas = newBetas;
          }

          // The left child is always split; there is no boundary or start check here.
          newBetas = new double[parentStates][leftStates * 2][rightStates];
          for (int i = 0; i < parentStates; ++i) {
            for (int j = 0; j < leftStates; ++j) {
              for (int k = 0; k < rightStates; ++k) {
                double leftWeight = 0.45 + random.nextDouble() * 0.1;
                newBetas[i][j * 2][k] = betas[i][j][k] + Math.log(leftWeight);
                newBetas[i][j * 2 + 1][k] = betas[i][j][k] + Math.log(1 - leftWeight);
              }
            }
          }
          leftStates *= 2;
          betas = newBetas;

          // When the right child is the boundary tag, newBetas still holds
          // the left-split table, and that table is stored below.
          if (!right.equals(Lexicon.BOUNDARY_TAG)) {
            newBetas = new double[parentStates][leftStates][rightStates * 2];
            for (int i = 0; i < parentStates; ++i) {
              for (int j = 0; j < leftStates; ++j) {
                for (int k = 0; k < rightStates; ++k) {
                  double rightWeight = 0.45 + random.nextDouble() * 0.1;
                  newBetas[i][j][k * 2] = betas[i][j][k] + Math.log(rightWeight);
                  newBetas[i][j][k * 2 + 1] = betas[i][j][k] + Math.log(1 - rightWeight);
                }
              }
            }
          }
          tempBinaryBetas.put(parent, left, right, newBetas);
        }
      }
    }
    unaryBetas = tempUnaryBetas;
    binaryBetas = tempBinaryBetas;
  }


  /**
   * Recalculates the betas for all known transitions.  The current
   * betas are used to produce probabilities, which then are used to
   * compute new betas.  If splitStates is true, then the
   * probabilities produced are as if the states were split again from
   * the last time betas were calculated.
   * <br>
   * The return value is whether or not the betas have mostly
   * converged from the last time this method was called.  Obviously
   * if splitStates was true, the betas will be entirely different, so
   * this is false.  Otherwise, the new betas are compared against the
   * old values, and convergence means they differ by less than
   * EPSILON.
   */
  public boolean recalculateBetas(boolean splitStates) {
    if (splitStates) {
      if (DEBUG()) {
        System.out.println("Pre-split betas");
        outputBetas();
      }
      splitBetas();
      if (DEBUG()) {
        System.out.println("Post-split betas");
        outputBetas();
      }
    }

    TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<String, String, double[][]>();
    ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<String, String, String, double[][][]>();

    recalculateTemporaryBetas(splitStates, null, tempUnaryBetas, tempBinaryBetas);
    boolean converged = useNewBetas(!splitStates, tempUnaryBetas, tempBinaryBetas);

    if (DEBUG()) {
      outputBetas();
    }

    return converged;
  }

  public boolean useNewBetas(boolean testConverged,
                             TwoDimensionalMap<String, String, double[][]> tempUnaryBetas,
                             ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
    rescaleTemporaryBetas(tempUnaryBetas, tempBinaryBetas);

    // if we just split states, we have obviously not converged
    boolean converged = testConverged && testConvergence(tempUnaryBetas, tempBinaryBetas);

    unaryBetas = tempUnaryBetas;
    binaryBetas = tempBinaryBetas;

    wordIndex = tempWordIndex;
    tagIndex = tempTagIndex;
    lex = tempLex;
    if (DEBUG()) {
      System.out.println("LEXICON");
      try {
        OutputStreamWriter osw = new OutputStreamWriter(System.out, "utf-8");
        lex.writeData(osw);
        osw.flush();
      } catch (IOException e) {
        throw new RuntimeIOException(e);
      }
    }
    tempWordIndex = null;
    tempTagIndex = null;
    tempLex = null;

    return converged;
  }

  /**
   * Creates temporary beta data structures and fills them in by
   * iterating over the trees.
   */
  public void recalculateTemporaryBetas(boolean splitStates, Map<String, double[]> totalStateMass,
                                        TwoDimensionalMap<String, String, double[][]> tempUnaryBetas,
                                        ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
    tempWordIndex = new HashIndex<String>();
    tempTagIndex = new HashIndex<String>();
    tempLex = op.tlpParams.lex(op, tempWordIndex, tempTagIndex);
    tempLex.initializeTraining(trainSize);

    for (Tree tree : trees) {
      double weight = treeWeights.getCount(tree);
      if (DEBUG()) {
        System.out.println("Incrementing trees read: " + weight);
      }
      tempLex.incrementTreesRead(weight);
      recalculateTemporaryBetas(tree, splitStates, totalStateMass, tempUnaryBetas, tempBinaryBetas);
    }

    tempLex.finishTraining();
  }

  public boolean testConvergence(TwoDimensionalMap<String, String, double[][]> tempUnaryBetas,
                                 ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {

    // now, we check each of the new betas to see if it's close to the
    // old value for the same transition.  if not, we have not yet
    // converged.  if all of them are, we have converged.
    for (String parentLabel : unaryBetas.firstKeySet()) {
      for (String childLabel : unaryBetas.get(parentLabel).keySet()) {
        double[][] betas = unaryBetas.get(parentLabel, childLabel);
        double[][] newBetas = tempUnaryBetas.get(parentLabel, childLabel);
        int parentStates = betas.length;
        int childStates = betas[0].length;
        for (int i = 0; i < parentStates; ++i) {
          for (int j = 0; j < childStates; ++j) {
            double oldValue = betas[i][j];
            double newValue = newBetas[i][j];
            if (Math.abs(newValue - oldValue) > EPSILON) {
              return false;
            }
          }
        }
      }
    }
    for (String parentLabel : binaryBetas.firstKeySet()) {
      for (String leftLabel : binaryBetas.get(parentLabel).firstKeySet()) {
        for (String rightLabel : binaryBetas.get(parentLabel).get(leftLabel).keySet()) {
          double[][][] betas = binaryBetas.get(parentLabel, leftLabel, rightLabel);
          double[][][] newBetas = tempBinaryBetas.get(parentLabel, leftLabel, rightLabel);
          int parentStates = betas.length;
          int leftStates = betas[0].length;
          int rightStates = betas[0][0].length;
          for (int i = 0; i < parentStates; ++i) {
            for (int j = 0; j < leftStates; ++j) {
              for (int k = 0; k < rightStates; ++k) {
                double oldValue = betas[i][j][k];
                double newValue = newBetas[i][j][k];
                if (Math.abs(newValue - oldValue) > EPSILON) {
                  return false;
                }
              }
            }
          }
        }
      }
    }

    return true;
  }

  public void recalculateTemporaryBetas(Tree tree, boolean splitStates,
                                        Map<String, double[]> totalStateMass,
                                        TwoDimensionalMap<String, String, double[][]> tempUnaryBetas,
                                        ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
    if (DEBUG()) {
      System.out.println("Recalculating temporary betas for tree " + tree);
    }
    double[] stateWeights = { Math.log(treeWeights.getCount(tree)) };

    IdentityHashMap<Tree, double[][]> unaryTransitions = new IdentityHashMap<Tree, double[][]>();
    IdentityHashMap<Tree, double[][][]> binaryTransitions = new IdentityHashMap<Tree, double[][][]>();
    recountTree(tree, splitStates, unaryTransitions, binaryTransitions);

    if (DEBUG()) {
      System.out.println("  Transitions:");
      outputTransitions(tree, unaryTransitions, binaryTransitions);
    }

    recalculateTemporaryBetas(tree, stateWeights, 0, unaryTransitions, binaryTransitions,
                              totalStateMass, tempUnaryBetas, tempBinaryBetas);
  }

  /**
   * Recursively accumulates one tree's expected transition weights into
   * the temporary beta tables and the temporary lexicon.  All weights
   * are handled as logs and combined with SloppyMath.logAdd.
   *
   * @param tree the current subtree
   * @param stateWeights log weight of each substate of this node's label,
   *        as propagated down from the parent
   * @param position index of the next word in the sentence; it is passed
   *        to the lexicon when training preterminals
   * @param unaryTransitions per-node log transition weights for unary nodes
   * @param binaryTransitions per-node log transition weights for binary nodes
   * @param totalStateMass if non-null, accumulates for each label the total
   *        weight of each substate (as exp of the log weight, not as a log)
   * @param tempUnaryBetas unary betas being accumulated; a table is created,
   *        filled with log(0), the first time a transition is seen
   * @param tempBinaryBetas binary betas being accumulated, in the same way
   * @return the position after the last word of this subtree
   */
  public int recalculateTemporaryBetas(Tree tree, double[] stateWeights, int position,
                                       IdentityHashMap<Tree, double[][]> unaryTransitions,
                                       IdentityHashMap<Tree, double[][][]> binaryTransitions,
                                       Map<String, double[]> totalStateMass,
                                       TwoDimensionalMap<String, String, double[][]> tempUnaryBetas,
                                       ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
    if (tree.isLeaf()) {
      // possible to get here if we have a tree with no structure
      return position;
    }

    // Mass is summed in probability space (Math.exp), unlike the betas.
    if (totalStateMass != null) {
      double[] stateTotal = totalStateMass.get(tree.label().value());
      if (stateTotal == null) {
        stateTotal = new double[stateWeights.length];
        totalStateMass.put(tree.label().value(), stateTotal);
      }
      for (int i = 0; i < stateWeights.length; ++i) {
        stateTotal[i] += Math.exp(stateWeights[i]);
      }
    }

    if (tree.isPreTerminal()) {
      // fill in our new lexicon here.
      String tag = tree.label().value();
      String word = tree.children()[0].label().value();
      // We smooth by LEX_SMOOTH, if relevant.  We rescale so that sum
      // of the weights being added to the lexicon stays the same.
      double total = 0.0;
      for (int state = 0; state < stateWeights.length; ++state) {
        total += Math.exp(stateWeights[state]);
      }
      // No weight reaches this word: skip training, but still advance the position.
      if (total <= 0.0) {
        return position + 1;
      }
      // sum over states of (w + total*LEX_SMOOTH/n) / (1 + LEX_SMOOTH) == total
      double scale = 1.0 / (1.0 + LEX_SMOOTH);
      double smoothing = total * LEX_SMOOTH / stateWeights.length;
      for (int state = 0; state < stateWeights.length; ++state) {
        // TODO: maybe optimize all this TaggedWord creation
        TaggedWord tw = new TaggedWord(word, state(tag, state));
        tempLex.train(tw, position, (Math.exp(stateWeights[state]) + smoothing) * scale);
      }
      return position + 1;
    }

    if (tree.children().length == 1) {
      String parentLabel = tree.label().value();
      String childLabel = tree.children()[0].label().value();
      double[][] transitions = unaryTransitions.get(tree);
      int parentStates = transitions.length;
      int childStates = transitions[0].length;
      double[][] betas = tempUnaryBetas.get(parentLabel, childLabel);
      if (betas == null) {
        betas = new double[parentStates][childStates];
        for (int i = 0; i < parentStates; ++i) {
          for (int j = 0; j < childStates; ++j) {
            betas[i][j] = Double.NEGATIVE_INFINITY;
          }
        }
        tempUnaryBetas.put(parentLabel, childLabel, betas);
      }
      // Each child substate's weight is the log-sum, over parent substates,
      // of (parent weight + transition weight).
      double[] childWeights = neginfDoubles(childStates);
      for (int i = 0; i < parentStates; ++i) {
        for (int j = 0; j < childStates; ++j) {
          double weight = transitions[i][j];
          betas[i][j] = SloppyMath.logAdd(betas[i][j], weight + stateWeights[i]);
          childWeights[j] = SloppyMath.logAdd(childWeights[j], weight + stateWeights[i]);
        }
      }
      position = recalculateTemporaryBetas(tree.children()[0], childWeights, position, unaryTransitions, binaryTransitions, totalStateMass, tempUnaryBetas, tempBinaryBetas);
    } else { // length == 2
      // NOTE(review): assumes a binarized tree; only the first two children are used.
      String parentLabel = tree.label().value();
      String leftLabel = tree.children()[0].label().value();
      String rightLabel = tree.children()[1].label().value();
      double[][][] transitions = binaryTransitions.get(tree);
      int parentStates = transitions.length;
      int leftStates = transitions[0].length;
      int rightStates = transitions[0][0].length;

      double[][][] betas = tempBinaryBetas.get(parentLabel, leftLabel, rightLabel);
      if (betas == null) {
        betas = new double[parentStates][leftStates][rightStates];
        for (int i = 0; i < parentStates; ++i) {
          for (int j = 0; j < leftStates; ++j) {
            for (int k = 0; k < rightStates; ++k) {
              betas[i][j][k] = Double.NEGATIVE_INFINITY;
            }
          }
        }
        tempBinaryBetas.put(parentLabel, leftLabel, rightLabel, betas);
      }
      // Each child gets the log-sum marginal over the parent's substates
      // and the sibling's substates.
      double[] leftWeights = neginfDoubles(leftStates);
      double[] rightWeights = neginfDoubles(rightStates);
      for (int i = 0; i < parentStates; ++i) {
        for (int j = 0; j < leftStates; ++j) {
          for (int k = 0; k < rightStates; ++k) {
            double weight = transitions[i][j][k];
            betas[i][j][k] = SloppyMath.logAdd(betas[i][j][k], weight + stateWeights[i]);
            leftWeights[j] = SloppyMath.logAdd(leftWeights[j], weight + stateWeights[i]);
            rightWeights[k] = SloppyMath.logAdd(rightWeights[k], weight + stateWeights[i]);
          }
        }
      }
      position = recalculateTemporaryBetas(tree.children()[0], leftWeights, position, unaryTransitions, binaryTransitions, totalStateMass, tempUnaryBetas, tempBinaryBetas);
      position = recalculateTemporaryBetas(tree.children()[1], rightWeights, position, unaryTransitions, binaryTransitions, totalStateMass, tempUnaryBetas, tempBinaryBetas);
    }
    return position;
  }

  public void rescaleTemporaryBetas(TwoDimensionalMap<String, String, double[][]> tempUnaryBetas,
                                    ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
    for (String parent : tempUnaryBetas.firstKeySet()) {
      for (String child : tempUnaryBetas.get(parent).keySet()) {
        double[][] betas = tempUnaryBetas.get(parent).get(child);
        int parentStates = betas.length;
        int childStates = betas[0].length;
        for (int i = 0; i < parentStates; ++i) {
          double sum = Double.NEGATIVE_INFINITY;
          for (int j = 0; j < childStates; ++j) {
            sum = SloppyMath.logAdd(sum, betas[i][j]);
          }
          if (Double.isInfinite(sum)) {
            for (int j = 0; j < childStates; ++j) {
              betas[i][j] = -Math.log(childStates);
            }
          } else {
            for (int j = 0; j < childStates; ++j) {
              betas[i][j] -= sum;
            }
          }
        }
      }
    }

    for (String parent : tempBinaryBetas.firstKeySet()) {
      for (String left : tempBinaryBetas.get(parent).firstKeySet()) {
        for (String right : tempBinaryBetas.get(parent).get(left).keySet()) {
          double[][][] betas = tempBinaryBetas.get(parent).get(left).get(right);
          int parentStates = betas.length;
          int leftStates = betas[0].length;
          int rightStates = betas[0][0].length;
          for (int i = 0; i < parentStates; ++i) {
            double sum = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < leftStates; ++j) {
              for (int k = 0; k < rightStates; ++k) {
                sum = SloppyMath.logAdd(sum, betas[i][j][k]);
              }
            }
            if (Double.isInfinite(sum)) {
              for (int j = 0; j < leftStates; ++j) {
                for (int k = 0; k < rightStates; ++k) {
                  betas[i][j][k] = -Math.log(leftStates * rightStates);
                }
              }
            } else {
              for (int j = 0; j < leftStates; ++j) {
                for (int k = 0; k < rightStates; ++k) {
                  betas[i][j][k] -= sum;
                }
              }
            }
          }
        }
      }
    }
  }

  public void recountTree(Tree tree, boolean splitStates,
                          IdentityHashMap<Tree, double[][]> unaryTransitions,
                          IdentityHashMap<Tree, double[][][]> binaryTransitions) {
    IdentityHashMap<Tree, double[]> probIn = new IdentityHashMap<Tree, double[]>();
    IdentityHashMap<Tree, double[]> probOut = new IdentityHashMap<Tree, double[]>();
    recountTree(tree, splitStates, probIn, probOut, unaryTransitions, binaryTransitions);
  }

  public void recountTree(Tree tree, boolean splitStates,
                          IdentityHashMap<Tree, double[]> probIn,
                          IdentityHashMap<Tree, double[]> probOut,
                          IdentityHashMap<Tree, double[][]> unaryTransitions,
                          IdentityHashMap<Tree, double[][][]> binaryTransitions) {
    recountInside(tree, splitStates, 0, probIn);
    if (DEBUG()) {
      System.out.println("ROOT PROBABILITY: " + probIn.get(tree)[0]);
    }
    recountOutside(tree, probIn, probOut);
    recountWeights(tree, probIn, probOut, unaryTransitions, binaryTransitions);
  }

  public void recountWeights(Tree tree,
                             IdentityHashMap<Tree, double[]> probIn,
                             IdentityHashMap<Tree, double[]> probOut,
                             IdentityHashMap<Tree, double[][]> unaryTransitions,
                             IdentityHashMap<Tree, double[][][]> binaryTransitions) {
    if (tree.isLeaf() || tree.isPreTerminal()) {
      return;
    }
    if (tree.children().length == 1) {
      Tree child = tree.children()[0];
      String parentLabel = tree.label().value();
      String childLabel = child.label().value();
      double[][] betas = unaryBetas.get(parentLabel, childLabel);
      double[] childInside = probIn.get(child);
      double[] parentOutside = probOut.get(tree);
      int parentStates = betas.length;
      int childStates = betas[0].length;
      double[][] transitions = new double[parentStates][childStates];
      unaryTransitions.put(tree, transitions);
      for (int i = 0; i < parentStates; ++i) {
        for (int j = 0; j < childStates; ++j) {
          transitions[i][j] = parentOutside[i] + childInside[j] + betas[i][j];
        }
      }
      // Renormalize.  Note that we renormalize to 1, regardless of
      // the original total.
      // TODO: smoothing?
      for (int i = 0; i < parentStates; ++i) {
        double total = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < childStates; ++j) {
          total = SloppyMath.logAdd(total, transitions[i][j]);
        }
        // By subtracting off the log total, we make it so the log sum
        // of the transitions is 0, meaning the sum of the actual
        // transitions is 1.  It works if you do the math...
        if (Double.isInfinite(total)) {
          double transition = -Math.log(childStates);
          for (int j = 0; j < childStates; ++j) {
            transitions[i][j] = transition;
          }
        } else {
          for (int j = 0; j < childStates; ++j) {
            transitions[i][j] = transitions[i][j] - total;
          }
        }
      }
      recountWeights(child, probIn, probOut, unaryTransitions, binaryTransitions);
    } else { // length == 2
      Tree left = tree.children()[0];
      Tree right = tree.children()[1];
      String parentLabel = tree.label().value();
      String leftLabel = left.label().value();
      String rightLabel = right.label().value();
      double[][][] betas = binaryBetas.get(parentLabel, leftLabel, rightLabel);
      double[] leftInside = probIn.get(left);
      double[] rightInside = probIn.get(right);
      double[] parentOutside = probOut.get(tree);
      int parentStates = betas.length;
      int leftStates = betas[0].length;
      int rightStates = betas[0][0].length;
      double[][][] transitions = new double[parentStates][leftStates][rightStates];
      binaryTransitions.put(tree, transitions);
      for (int i = 0; i < parentStates; ++i) {
        for (int j = 0; j < leftStates; ++j) {
          for (int k = 0; k < rightStates; ++k) {
            transitions[i][j][k] = parentOutside[i] + leftInside[j] + rightInside[k] + betas[i][j][k];
          }
        }
      }
      // Renormalize.  Note that we renormalize to 1, regardless of
      // the original total.
      // TODO: smoothing?
      for (int i = 0; i < parentStates; ++i) {
        double total = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < leftStates; ++j) {
          for (int k = 0; k < rightStates; ++k) {
            total = SloppyMath.logAdd(total, transitions[i][j][k]);
          }
        }
        // By subtracting off the log total, we make it so the log sum
        // of the transitions is 0, meaning the sum of the actual
        // transitions is 1.  It works if you do the math...
        if (Double.isInfinite(total)) {
          double transition = -Math.log(leftStates * rightStates);
          for (int j = 0; j < leftStates; ++j) {
            for (int k = 0; k < rightStates; ++k) {
              transitions[i][j][k] = transition;
            }
          }
        } else {
          for (int j = 0; j < leftStates; ++j) {
            for (int k = 0; k < rightStates; ++k) {
              transitions[i][j][k] = transitions[i][j][k] - total;
            }
          }
        }
      }
      recountWeights(left, probIn, probOut, unaryTransitions, binaryTransitions);
      recountWeights(right, probIn, probOut, unaryTransitions, binaryTransitions);
    }
  }

  public void recountOutside(Tree tree,
                             IdentityHashMap<Tree, double[]> probIn,
                             IdentityHashMap<Tree, double[]> probOut) {
    double[] rootScores = { 0.0 };
    probOut.put(tree, rootScores);
    recurseOutside(tree, probIn, probOut);
  }

  public void recurseOutside(Tree tree,
                             IdentityHashMap<Tree, double[]> probIn,
                             IdentityHashMap<Tree, double[]> probOut) {
    if (tree.isLeaf() || tree.isPreTerminal()) {
      return;
    }
    if (tree.children().length == 1) {
      recountOutside(tree.children()[0], tree, probIn, probOut);
    } else { // length == 2
      recountOutside(tree.children()[0], tree.children()[1], tree,
                     probIn, probOut);
    }
  }

  public void recountOutside(Tree child, Tree parent,
                             IdentityHashMap<Tree, double[]> probIn,
                             IdentityHashMap<Tree, double[]> probOut) {
    String parentLabel = parent.label().value();
    String childLabel = child.label().value();
    double[] parentScores = probOut.get(parent);
    double[][] betas = unaryBetas.get(parentLabel, childLabel);
    int parentStates = betas.length;
    int childStates = betas[0].length;

    double[] scores = neginfDoubles(childStates);
    probOut.put(child, scores);

    for (int i = 0; i < parentStates; ++i) {
      for (int j = 0; j < childStates; ++j) {
        // TODO: no inside scores here, right?
        scores[j] = SloppyMath.logAdd(scores[j], betas[i][j] + parentScores[i]);
      }
    }

    recurseOutside(child, probIn, probOut);
  }

  public void recountOutside(Tree left, Tree right, Tree parent,
                             IdentityHashMap<Tree, double[]> probIn,
                             IdentityHashMap<Tree, double[]> probOut) {
    String parentLabel = parent.label().value();
    String leftLabel = left.label().value();
    String rightLabel = right.label().value();
    double[] leftInsideScores = probIn.get(left);
    double[] rightInsideScores = probIn.get(right);
    double[] parentScores = probOut.get(parent);
    double[][][] betas = binaryBetas.get(parentLabel, leftLabel, rightLabel);
    int parentStates = betas.length;
    int leftStates = betas[0].length;
    int rightStates = betas[0][0].length;

    double[] leftScores = neginfDoubles(leftStates);
    probOut.put(left, leftScores);
    double[] rightScores = neginfDoubles(rightStates);
    probOut.put(right, rightScores);

    for (int i = 0; i < parentStates; ++i) {
      for (int j = 0; j < leftStates; ++j) {
        for (int k = 0; k < rightStates; ++k) {
          leftScores[j] = SloppyMath.logAdd(leftScores[j], betas[i][j][k] + parentScores[i] + rightInsideScores[k]);
          rightScores[k] = SloppyMath.logAdd(rightScores[k], betas[i][j][k] + parentScores[i] + leftInsideScores[j]);
        }
      }
    }

    recurseOutside(left, probIn, probOut);
    recurseOutside(right, probIn, probOut);
  }

  public int recountInside(Tree tree, boolean splitStates, int loc,
                           IdentityHashMap<Tree, double[]> probIn) {
    if (tree.isLeaf()) {
      throw new RuntimeException();
    } else if (tree.isPreTerminal()) {
      int stateCount = getStateSplitCount(tree);
      String word = tree.children()[0].label().value();
      String tag = tree.label().value();

      double[] scores = new double[stateCount];
      probIn.put(tree, scores);

      if (splitStates && !tag.equals(Lexicon.BOUNDARY_TAG)) {
        for (int i = 0; i < stateCount / 2; ++i) {
          IntTaggedWord tw = new IntTaggedWord(word, state(tag, i), wordIndex, tagIndex);
          double logProb = lex.score(tw, loc, word, null);
          double wordWeight = 0.45 + random.nextDouble() * 0.1;
          scores[i * 2] = logProb + Math.log(wordWeight);
          scores[i * 2 + 1] = logProb + Math.log(1.0 - wordWeight);
          if (DEBUG()) {
            System.out.println("Lexicon log prob " + state(tag, i) + "-" + word + ": " + logProb);
            System.out.println("  Log Split -> " + scores[i * 2] + "," + scores[i * 2 + 1]);
          }
        }
      } else {
        for (int i = 0; i < stateCount; ++i) {
          IntTaggedWord tw = new IntTaggedWord(word, state(tag, i), wordIndex, tagIndex);
          double prob = lex.score(tw, loc, word, null);
          if (DEBUG()) {
            System.out.println("Lexicon log prob " + state(tag, i) + "-" + word + ": " + prob);
          }
          scores[i] = prob;
        }
      }
      loc = loc + 1;
    } else if (tree.children().length == 1) {
      loc = recountInside(tree.children()[0], splitStates, loc, probIn);
      double[] childScores = probIn.get(tree.children()[0]);
      String parentLabel = tree.label().value();
      String childLabel = tree.children()[0].label().value();
      double[][] betas = unaryBetas.get(parentLabel, childLabel);
      int parentStates = betas.length; // size of the first key
      int childStates = betas[0].length;

      double[] scores = neginfDoubles(parentStates);
      probIn.put(tree, scores);

      for (int i = 0; i < parentStates; ++i) {
        for (int j = 0; j < childStates; ++j) {
          scores[i] = SloppyMath.logAdd(scores[i], childScores[j] + betas[i][j]);
        }
      }
      if (DEBUG()) {
        System.out.println(parentLabel + " -> " + childLabel);
        for (int i = 0; i < parentStates; ++i) {
          System.out.println("  " + i + ":" + scores[i]);
          for (int j = 0; j < childStates; ++j) {
            System.out.println("    " + i + "," + j + ": " + betas[i][j] + " | " + Math.exp(betas[i][j]));
          }
        }
      }
    } else { // length == 2
      loc = recountInside(tree.children()[0], splitStates, loc, probIn);
      loc = recountInside(tree.children()[1], splitStates, loc, probIn);
      double[] leftScores = probIn.get(tree.children()[0]);
      double[] rightScores = probIn.get(tree.children()[1]);
      String parentLabel = tree.label().value();
      String leftLabel = tree.children()[0].label().value();
      String rightLabel = tree.children()[1].label().value();
      double[][][] betas = binaryBetas.get(parentLabel, leftLabel, rightLabel);
      int parentStates = betas.length;
      int leftStates = betas[0].length;
      int rightStates = betas[0][0].length;

      double[] scores = neginfDoubles(parentStates);
      probIn.put(tree, scores);

      for (int i = 0; i < parentStates; ++i) {
        for (int j = 0; j < leftStates; ++j) {
          for (int k = 0; k < rightStates; ++k) {
            scores[i] = SloppyMath.logAdd(scores[i], leftScores[j] + rightScores[k] + betas[i][j][k]);
          }
        }
      }
      if (DEBUG()) {
        System.out.println(parentLabel + " -> " + leftLabel + "," + rightLabel);
        for (int i = 0; i < parentStates; ++i) {
          System.out.println("  " + i + ":" + scores[i]);
          for (int j = 0; j < leftStates; ++j) {
            for (int k = 0; k < rightStates; ++k) {
              System.out.println("    " + i + "," + j + "," + k + ": " + betas[i][j][k] + " | " + Math.exp(betas[i][j][k]));
            }
          }
        }
      }
    }
    return loc;
  }

  public void mergeStates() {
    if (op.trainOptions.splitRecombineRate <= 0.0) {
      return;
    }

    // we go through the machinery to sum up the temporary betas,
    // counting the total mass
    TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<String, String, double[][]>();
    ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<String, String, String, double[][][]>();
    Map<String, double[]> totalStateMass = Generics.newHashMap();
    recalculateTemporaryBetas(false, totalStateMass, tempUnaryBetas, tempBinaryBetas);

    // Next, for each tree we count the effect of merging its
    // annotations.  We only consider the most recently split
    // annotations as candidates for merging.
    Map<String, double[]> deltaAnnotations = Generics.newHashMap();
    for (Tree tree : trees) {
      countMergeEffects(tree, totalStateMass, deltaAnnotations);
    }

    // Now we have a map of the (approximate) likelihood loss from
    // merging each state.  We merge the ones that provide the least
    // benefit, up to the splitRecombineRate
    List<Triple<String, Integer, Double>> sortedDeltas =
      new ArrayList<Triple<String, Integer, Double>>();
    for (String state : deltaAnnotations.keySet()) {
      double[] scores = deltaAnnotations.get(state);
      for (int i = 0; i < scores.length; ++i) {
        sortedDeltas.add(new Triple<String, Integer, Double>(state, i * 2, scores[i]));
      }
    }
    Collections.sort(sortedDeltas, new Comparator<Triple<String, Integer, Double>>() {
        public int compare(Triple<String, Integer, Double> first,
                           Triple<String, Integer, Double> second) {
          // The most useful splits will have a large loss in
          // likelihood if they are merged.  Thus, we want those at
          // the end of the list.  This means we make the comparison
          // "backwards", sorting from high to low.
          return Double.compare(second.third(), first.third());
        }
        public boolean equals(Object o) { return o == this; }
      });

    // for (Triple<String, Integer, Double> delta : sortedDeltas) {
    //   System.out.println(delta.first() + "-" + delta.second() + ": " + delta.third());
    // }
    // System.out.println("-------------");

    // Only merge a fraction of the splits based on what the user
    // originally asked for
    int splitsToMerge = (int) (sortedDeltas.size() * op.trainOptions.splitRecombineRate);
    splitsToMerge = Math.max(0, splitsToMerge);
    splitsToMerge = Math.min(sortedDeltas.size() - 1, splitsToMerge);
    sortedDeltas = sortedDeltas.subList(0, splitsToMerge);

    System.out.println();
    System.out.println(sortedDeltas);

    Map<String, int[]> mergeCorrespondence = buildMergeCorrespondence(sortedDeltas);

    recalculateMergedBetas(mergeCorrespondence);

    for (Triple<String, Integer, Double> delta : sortedDeltas) {
      stateSplitCounts.decrementCount(delta.first(), 1);
    }
  }

  public void recalculateMergedBetas(Map<String, int[]> mergeCorrespondence) {
    TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<String, String, double[][]>();
    ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<String, String, String, double[][][]>();

    tempWordIndex = new HashIndex<String>();
    tempTagIndex = new HashIndex<String>();
    tempLex = op.tlpParams.lex(op, tempWordIndex, tempTagIndex);
    tempLex.initializeTraining(trainSize);

    for (Tree tree : trees) {
      double treeWeight = treeWeights.getCount(tree);
      double[] stateWeights = { Math.log(treeWeight) };
      tempLex.incrementTreesRead(treeWeight);

      IdentityHashMap<Tree, double[][]> oldUnaryTransitions = new IdentityHashMap<Tree, double[][]>();
      IdentityHashMap<Tree, double[][][]> oldBinaryTransitions = new IdentityHashMap<Tree, double[][][]>();
      recountTree(tree, false, oldUnaryTransitions, oldBinaryTransitions);

      IdentityHashMap<Tree, double[][]> unaryTransitions = new IdentityHashMap<Tree, double[][]>();
      IdentityHashMap<Tree, double[][][]> binaryTransitions = new IdentityHashMap<Tree, double[][][]>();
      mergeTransitions(tree, oldUnaryTransitions, oldBinaryTransitions, unaryTransitions, binaryTransitions, stateWeights, mergeCorrespondence);

      recalculateTemporaryBetas(tree, stateWeights, 0, unaryTransitions, binaryTransitions,
                                null, tempUnaryBetas, tempBinaryBetas);
    }

    tempLex.finishTraining();
    useNewBetas(false, tempUnaryBetas, tempBinaryBetas);
  }

  /**
   * Rewrites the transition weights of {@code parent}'s subtree from the
   * split substates onto the merged substates.
   * <br>
   * Each old substate index is mapped to its merged index through
   * {@code mergeCorrespondence}; the merged state count for a label is the
   * last entry of its mapping plus one.  Old transitions leaving parent
   * substate i are weighted by {@code stateWeights[i]} and summed (in log
   * space) into the corresponding merged cell.  Each merged parent row is
   * then renormalized to sum to 1, falling back to a uniform distribution
   * when the row has no mass.  This accounts for two parent states being
   * collapsed into one.  The method then recurses into the children,
   * giving each child the log-space mass that flows into each of its
   * (unmerged) substates.  Preterminals and leaves are not processed.
   * The results go into newUnaryTransitions and newBinaryTransitions.
   *
   * @param parent node whose transitions are merged
   * @param oldUnaryTransitions unary transition weights before merging, keyed by node
   * @param oldBinaryTransitions binary transition weights before merging, keyed by node
   * @param newUnaryTransitions output: merged unary transition weights, keyed by node
   * @param newBinaryTransitions output: merged binary transition weights, keyed by node
   * @param stateWeights log-space mass in each unmerged substate of {@code parent}
   * @param mergeCorrespondence per label, old substate index -> merged substate index
   */
  public void mergeTransitions(Tree parent,
                               IdentityHashMap<Tree, double[][]> oldUnaryTransitions,
                               IdentityHashMap<Tree, double[][][]> oldBinaryTransitions,
                               IdentityHashMap<Tree, double[][]> newUnaryTransitions,
                               IdentityHashMap<Tree, double[][][]> newBinaryTransitions,
                               double[] stateWeights,
                               Map<String, int[]> mergeCorrespondence) {
    if (parent.isPreTerminal() || parent.isLeaf()) {
      return;
    }
    if (parent.children().length == 1) {
      double[][] oldTransitions = oldUnaryTransitions.get(parent);

      String parentLabel = parent.label().value();
      int[] parentCorrespondence = mergeCorrespondence.get(parentLabel);
      // merged indices are contiguous, so the last one + 1 is the count
      int parentStates = parentCorrespondence[parentCorrespondence.length - 1] + 1;

      String childLabel = parent.children()[0].label().value();
      int[] childCorrespondence = mergeCorrespondence.get(childLabel);
      int childStates = childCorrespondence[childCorrespondence.length - 1] + 1;

      // System.out.println("P: " + parentLabel + " " + parentStates +
      //                    " C: " + childLabel + " " + childStates);


      // Add up the probabilities of transitioning to each state,
      // scaled by the probability of being in a given state to begin
      // with.  This accounts for when two states in the parent are
      // collapsed into one state.
      double[][] newTransitions = new double[parentStates][childStates];
      for (int i = 0; i < parentStates; ++i) {
        for (int j = 0; j < childStates; ++j) {
          newTransitions[i][j] = Double.NEGATIVE_INFINITY;
        }
      }
      newUnaryTransitions.put(parent, newTransitions);
      for (int i = 0; i < oldTransitions.length; ++i) {
        int ti = parentCorrespondence[i];
        for (int j = 0; j < oldTransitions[0].length; ++j) {
          int tj = childCorrespondence[j];
          // System.out.println(i + " " + ti + " " + j + " " + tj);
          newTransitions[ti][tj] = SloppyMath.logAdd(newTransitions[ti][tj], oldTransitions[i][j] + stateWeights[i]);
        }
      }

      // renormalize each merged parent row to sum to 1 (log space);
      // a row with no mass becomes uniform
      for (int i = 0; i < parentStates; ++i) {
        double total = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < childStates; ++j) {
          total = SloppyMath.logAdd(total, newTransitions[i][j]);
        }
        if (Double.isInfinite(total)) {
          for (int j = 0; j < childStates; ++j) {
            newTransitions[i][j] = -Math.log(childStates);
          }
        } else {
          for (int j = 0; j < childStates; ++j) {
            newTransitions[i][j] -= total;
          }
        }
      }

      // log-space mass flowing into each unmerged child substate; this
      // becomes the child's stateWeights in the recursive call
      double[] childWeights = neginfDoubles(oldTransitions[0].length);
      for (int i = 0; i < oldTransitions.length; ++i) {
        for (int j = 0; j < oldTransitions[0].length; ++j) {
          double weight = oldTransitions[i][j];
          childWeights[j] = SloppyMath.logAdd(childWeights[j], weight + stateWeights[i]);
        }
      }

      mergeTransitions(parent.children()[0], oldUnaryTransitions, oldBinaryTransitions, newUnaryTransitions, newBinaryTransitions, childWeights, mergeCorrespondence);
    } else {
      double[][][] oldTransitions = oldBinaryTransitions.get(parent);

      String parentLabel = parent.label().value();
      int[] parentCorrespondence = mergeCorrespondence.get(parentLabel);
      // merged indices are contiguous, so the last one + 1 is the count
      int parentStates = parentCorrespondence[parentCorrespondence.length - 1] + 1;

      String leftLabel = parent.children()[0].label().value();
      int[] leftCorrespondence = mergeCorrespondence.get(leftLabel);
      int leftStates = leftCorrespondence[leftCorrespondence.length - 1] + 1;

      String rightLabel = parent.children()[1].label().value();
      int[] rightCorrespondence = mergeCorrespondence.get(rightLabel);
      int rightStates = rightCorrespondence[rightCorrespondence.length - 1] + 1;

      // System.out.println("P: " + parentLabel + " " + parentStates +
      //                    " L: " + leftLabel + " " + leftStates +
      //                    " R: " + rightLabel + " " + rightStates);

      // Same as the unary case: sum the old transitions, weighted by the
      // mass of the old parent substate, into the merged cells.
      double[][][] newTransitions = new double[parentStates][leftStates][rightStates];
      for (int i = 0; i < parentStates; ++i) {
        for (int j = 0; j < leftStates; ++j) {
          for (int k = 0; k < rightStates; ++k) {
            newTransitions[i][j][k] = Double.NEGATIVE_INFINITY;
          }
        }
      }
      newBinaryTransitions.put(parent, newTransitions);
      for (int i = 0; i < oldTransitions.length; ++i) {
        int ti = parentCorrespondence[i];
        for (int j = 0; j < oldTransitions[0].length; ++j) {
          int tj = leftCorrespondence[j];
          for (int k = 0; k < oldTransitions[0][0].length; ++k) {
            int tk = rightCorrespondence[k];
            // System.out.println(i + " " + ti + " " + j + " " + tj + " " + k + " " + tk);
            newTransitions[ti][tj][tk] = SloppyMath.logAdd(newTransitions[ti][tj][tk], oldTransitions[i][j][k] + stateWeights[i]);
          }
        }
      }

      // renormalize each merged parent slab to sum to 1 (log space);
      // a slab with no mass becomes uniform
      for (int i = 0; i < parentStates; ++i) {
        double total = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < leftStates; ++j) {
          for (int k = 0; k < rightStates; ++k) {
            total = SloppyMath.logAdd(total, newTransitions[i][j][k]);
          }
        }
        if (Double.isInfinite(total)) {
          for (int j = 0; j < leftStates; ++j) {
            for (int k = 0; k < rightStates; ++k) {
              newTransitions[i][j][k] = -Math.log(leftStates * rightStates);
            }
          }
        } else {
          for (int j = 0; j < leftStates; ++j) {
            for (int k = 0; k < rightStates; ++k) {
              newTransitions[i][j][k] -= total;
            }
          }
        }
      }

      // log-space mass flowing into each unmerged left / right child
      // substate; these become the children's stateWeights below
      double[] leftWeights = neginfDoubles(oldTransitions[0].length);
      double[] rightWeights = neginfDoubles(oldTransitions[0][0].length);
      for (int i = 0; i < oldTransitions.length; ++i) {
        for (int j = 0; j < oldTransitions[0].length; ++j) {
          for (int k = 0; k < oldTransitions[0][0].length; ++k) {
            double weight = oldTransitions[i][j][k];
            leftWeights[j] = SloppyMath.logAdd(leftWeights[j], weight + stateWeights[i]);
            rightWeights[k] = SloppyMath.logAdd(rightWeights[k], weight + stateWeights[i]);
          }
        }
      }

      mergeTransitions(parent.children()[0], oldUnaryTransitions, oldBinaryTransitions, newUnaryTransitions, newBinaryTransitions, leftWeights, mergeCorrespondence);
      mergeTransitions(parent.children()[1], oldUnaryTransitions, oldBinaryTransitions, newUnaryTransitions, newBinaryTransitions, rightWeights, mergeCorrespondence);
    }
  }

  Map<String, int[]> buildMergeCorrespondence(List<Triple<String, Integer, Double>> deltas) {
    Map<String, int[]> mergeCorrespondence = Generics.newHashMap();
    for (String state : originalStates) {
      int states = getStateSplitCount(state);
      int[] correspondence = new int[states];
      for (int i = 0; i < states; ++i) {
        correspondence[i] = i;
      }
      mergeCorrespondence.put(state, correspondence);
    }
    for (Triple<String, Integer, Double> merge : deltas) {
      int states = getStateSplitCount(merge.first());
      int split = merge.second();
      int[] correspondence = mergeCorrespondence.get(merge.first());
      for (int i = split + 1; i < states; ++i) {
        correspondence[i] = correspondence[i] - 1;
      }
    }
    return mergeCorrespondence;
  }

  public void countMergeEffects(Tree tree, Map<String, double[]> totalStateMass,
                                Map<String, double[]> deltaAnnotations) {
    IdentityHashMap<Tree, double[]> probIn = new IdentityHashMap<Tree, double[]>();
    IdentityHashMap<Tree, double[]> probOut = new IdentityHashMap<Tree, double[]>();
    IdentityHashMap<Tree, double[][]> unaryTransitions = new IdentityHashMap<Tree, double[][]>();
    IdentityHashMap<Tree, double[][][]> binaryTransitions = new IdentityHashMap<Tree, double[][][]>();
    recountTree(tree, false, probIn, probOut, unaryTransitions, binaryTransitions);

    // no need to count the root
    for (Tree child : tree.children()) {
      countMergeEffects(child, totalStateMass, deltaAnnotations, probIn, probOut);
    }
  }

  public void countMergeEffects(Tree tree, Map<String, double[]> totalStateMass,
                                Map<String, double[]> deltaAnnotations,
                                IdentityHashMap<Tree, double[]> probIn,
                                IdentityHashMap<Tree, double[]> probOut) {
    if (tree.isLeaf()) {
      return;
    }
    if (tree.label().value().equals(Lexicon.BOUNDARY_TAG)) {
      return;
    }

    String label = tree.label().value();
    double totalMass = 0.0;
    double[] stateMass = totalStateMass.get(label);
    for (double mass : stateMass) {
      totalMass += mass;
    }

    double[] nodeProbIn = probIn.get(tree);
    double[] nodeProbOut = probOut.get(tree);

    double[] nodeDelta = deltaAnnotations.get(label);
    if (nodeDelta == null) {
      nodeDelta = new double[nodeProbIn.length / 2];
      deltaAnnotations.put(label, nodeDelta);
    }

    for (int i = 0; i < nodeProbIn.length / 2; ++i) {
      double probInMerged = SloppyMath.logAdd(Math.log(stateMass[i * 2] / totalMass) + nodeProbIn[i * 2],
                                              Math.log(stateMass[i * 2 + 1] / totalMass) + nodeProbIn[i * 2 + 1]);
      double probOutMerged = SloppyMath.logAdd(nodeProbOut[i * 2], nodeProbOut[i * 2 + 1]);
      double probMerged = probInMerged + probOutMerged;
      double probUnmerged = SloppyMath.logAdd(nodeProbIn[i * 2] + nodeProbOut[i * 2],
                                              nodeProbIn[i * 2 + 1] + nodeProbOut[i * 2 + 1]);
      nodeDelta[i] = nodeDelta[i] + probMerged - probUnmerged;
    }

    if (tree.isPreTerminal()) {
      return;
    }
    for (Tree child : tree.children()) {
      countMergeEffects(child, totalStateMass, deltaAnnotations, probIn, probOut);
    }
  }

  public void buildStateIndex() {
    stateIndex = new HashIndex<String>();
    for (String key : stateSplitCounts.keySet()) {
      for (int i = 0; i < stateSplitCounts.getIntCount(key); ++i) {
        stateIndex.addToIndex(state(key, i));
      }
    }
  }

  public void buildGrammars() {
    // In order to build the grammars, we first need to fill in the
    // temp betas with the sums of the transitions from Ax to By or Ax
    // to By,Cz.  We also need the sum total of the mass in each state
    // Ax over all the trees.

    // we go through the machinery to sum up the temporary betas,
    // counting the total mass...
    TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<String, String, double[][]>();
    ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<String, String, String, double[][][]>();
    Map<String, double[]> totalStateMass = Generics.newHashMap();
    recalculateTemporaryBetas(false, totalStateMass, tempUnaryBetas, tempBinaryBetas);

    // ... but note we don't actually rescale the betas.
    // instead we use the temporary betas and the total mass in each
    // state to calculate the grammars

    // First build up a BinaryGrammar.
    // The score for each rule will be the Beta scores found earlier,
    // scaled by the total weight of a transition between unsplit states
    BinaryGrammar bg = new BinaryGrammar(stateIndex);
    for (String parent : tempBinaryBetas.firstKeySet()) {
      int parentStates = getStateSplitCount(parent);
      double[] stateTotal = totalStateMass.get(parent);
      for (String left : tempBinaryBetas.get(parent).firstKeySet()) {
        int leftStates = getStateSplitCount(left);
        for (String right : tempBinaryBetas.get(parent).get(left).keySet()) {
          int rightStates = getStateSplitCount(right);
          double[][][] betas = tempBinaryBetas.get(parent, left, right);
          for (int i = 0; i < parentStates; ++i) {
            if (stateTotal[i] < EPSILON) {
              continue;
            }
            for (int j = 0; j < leftStates; ++j) {
              for (int k = 0; k < rightStates; ++k) {
                int parentIndex = stateIndex.indexOf(state(parent, i));
                int leftIndex = stateIndex.indexOf(state(left, j));
                int rightIndex = stateIndex.indexOf(state(right, k));
                double score = betas[i][j][k] - Math.log(stateTotal[i]);
                BinaryRule br = new BinaryRule(parentIndex, leftIndex, rightIndex, score);
                bg.addRule(br);
              }
            }
          }
        }
      }
    }

    // Now build up a UnaryGrammar
    UnaryGrammar ug = new UnaryGrammar(stateIndex);
    for (String parent : tempUnaryBetas.firstKeySet()) {
      int parentStates = getStateSplitCount(parent);
      double[] stateTotal = totalStateMass.get(parent);
      for (String child : tempUnaryBetas.get(parent).keySet()) {
        int childStates = getStateSplitCount(child);
        double[][] betas = tempUnaryBetas.get(parent, child);
        for (int i = 0; i < parentStates; ++i) {
          if (stateTotal[i] < EPSILON) {
            continue;
          }
          for (int j = 0; j < childStates; ++j) {
            int parentIndex = stateIndex.indexOf(state(parent, i));
            int childIndex = stateIndex.indexOf(state(child, j));
            double score = betas[i][j] - Math.log(stateTotal[i]);
            UnaryRule ur = new UnaryRule(parentIndex, childIndex, score);
            ug.addRule(ur);
          }
        }
      }
    }


    bgug = new Pair<UnaryGrammar, BinaryGrammar>(ug, bg);
  }

  public void saveTrees(Collection<Tree> trees1, double weight1,
                        Collection<Tree> trees2, double weight2) {
    trainSize = 0.0;
    int treeCount = 0;
    trees.clear();
    treeWeights.clear();
    for (Tree tree : trees1) {
      trees.add(tree);
      treeWeights.incrementCount(tree, weight1);
      trainSize += weight1;
    }
    treeCount += trees1.size();
    if (trees2 != null && weight2 >= 0.0) {
      for (Tree tree : trees2) {
        trees.add(tree);
        treeWeights.incrementCount(tree, weight2);
        trainSize += weight2;
      }
      treeCount += trees2.size();
    }
    System.err.println("Found " + treeCount +
                       " trees with total weight " + trainSize);
  }

  public void extract(Collection<Tree> treeList) {
    extract(treeList, 1.0, null, 0.0);
  }

  /**
   * First, we do a few setup steps.  We read in all the trees, which
   * is necessary because we continually reprocess them and use the
   * object pointers as hash keys rather than hashing the trees
   * themselves.  We then count the initial states in the treebank.
   * <br>
   * Having done that, we then assign initial probabilities to the
   * trees.  At first, each state has 1.0 of the probability mass for
   * each Ax-ByCz and Ax-By transition.  We then split the number of
   * states and the probabilities on each tree.
   * <br>
   * We then repeatedly recalculate the betas and reannotate the
   * weights, going until we converge, which is defined as no betas
   * move more then epsilon.
   * <br>
   * java -mx4g edu.stanford.nlp.parser.lexparser.LexicalizedParser  -PCFG -saveToSerializedFile englishSplit.ser.gz -saveToTextFile englishSplit.txt -maxLength 40 -train ../data/wsj/wsjtwentytrees.mrg    -testTreebank ../data/wsj/wsjtwentytrees.mrg   -evals "factDA,tsv" -uwm 0  -hMarkov 0 -vMarkov 0 -simpleBinarizedLabels -noRebinarization -predictSplits -splitTrainingThreads 1 -splitCount 1 -splitRecombineRate 0.5
   * <br>
   * may also need
   * <br>
   *  -smoothTagsThresh 0
   * <br>
   * java -mx8g edu.stanford.nlp.parser.lexparser.LexicalizedParser -evals "factDA,tsv" -PCFG -vMarkov 0 -hMarkov 0 -uwm 0 -saveToSerializedFile wsjS1.ser.gz -maxLength 40 -train /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj 200-2199 -testTreebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj 2200-2219 -compactGrammar 0 -simpleBinarizedLabels -predictSplits -smoothTagsThresh 0 -splitCount 1 -noRebinarization
   */
  public void extract(Collection<Tree> trees1, double weight1,
                      Collection<Tree> trees2, double weight2) {
    saveTrees(trees1, weight1, trees2, weight2);

    countOriginalStates();

    // Initial betas will be 1 for all possible unary and binary
    // transitions in our treebank
    initialBetasAndLexicon();

    for (int cycle = 0; cycle < op.trainOptions.splitCount; ++cycle) {
      // All states except the root state get split into 2
      splitStateCounts();

      // first, recalculate the betas and the lexicon for having split
      // the transitions
      recalculateBetas(true);

      // now, loop until we converge while recalculating betas
      // TODO: add a loop counter, stop after X iterations
      iteration = 0;
      boolean converged = false;
      while (!converged && iteration < MAX_ITERATIONS) {
        if (DEBUG()) {
          System.out.println();
          System.out.println();
          System.out.println("-------------------");
          System.out.println("Iteration " + iteration);
        }

        converged = recalculateBetas(false);
        ++iteration;
      }

      System.err.println("Converged for cycle " + cycle +
                         " in " + iteration + " iterations");

      mergeStates();
    }

    // Build up the state index.  The BG & UG both expect a set count
    // of states.
    buildStateIndex();

    buildGrammars();
  }
}
TOP

Related Classes of edu.stanford.nlp.parser.lexparser.SplittingGrammarExtractor

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.