package edu.stanford.nlp.parser.lexparser;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.MapFactory;
import edu.stanford.nlp.util.MutableDouble;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ThreeDimensionalMap;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.TwoDimensionalMap;
import java.io.*;
/**
* This class is a reimplementation of Berkeley's state-splitting
* grammar. This work is experimental and still in progress. There
* are several extremely important pieces to implement:
* <ol>
* <li> this code should use log probabilities throughout instead of
* multiplying tiny numbers
* <li> the time efficiency of the training code is awful
* <li> there are better ways to extract parses using this grammar than
* the method in ExhaustivePCFGParser
* <li> we should also implement cascading parsers that let us
* short-circuit low quality parses earlier (which could possibly
* benefit non-split parsers as well)
* <li> when looping, we should stop early if we exceed a maximum
* number of iterations
* <li> ought to smooth as per page 436
* </ol>
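* <br>
* A rough usage sketch, not a tested recipe ({@code op} and
* {@code binarizedTrees} are assumed to come from the surrounding
* LexicalizedParser training code, and the trees must already be
* binarized):
* <pre>{@code
* SplittingGrammarExtractor extractor = new SplittingGrammarExtractor(op);
* extractor.extract(binarizedTrees);  // runs the split / EM / merge cycles
* Lexicon lexicon = extractor.lex;
* Pair<UnaryGrammar, BinaryGrammar> grammars = extractor.bgug;
* Index<String> states = extractor.stateIndex;
* }</pre>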
*
* @author John Bauer
*/
public class SplittingGrammarExtractor {
static final int MIN_DEBUG_ITERATION=0;
static final int MAX_DEBUG_ITERATION=0;
static final int MAX_ITERATIONS = Integer.MAX_VALUE;
int iteration = 0;
boolean DEBUG() {
return (iteration >= MIN_DEBUG_ITERATION && iteration < MAX_DEBUG_ITERATION);
}
Options op;
/**
* These objects are created and filled in here. The caller can get
* the data from the extractor once it is finished.
*/
Index<String> stateIndex;
Index<String> wordIndex;
Index<String> tagIndex;
/**
* The list of start symbols, taken from op.langpack()
*/
List<String> startSymbols;
/**
* A combined list of all the trees in the training set.
*/
List<Tree> trees = new ArrayList<Tree>();
/**
* All of the weights associated with the trees in the training set.
* In general, this is just the weight of the original treebank.
* Note that this uses an identity hash map to map from tree pointer
* to weight.
*/
Counter<Tree> treeWeights = new ClassicCounter<Tree>(MapFactory.<Tree,MutableDouble>identityHashMapFactory());
/**
* The total weight of all the trees in the training set
*/
double trainSize;
/**
* The original states in the trees
*/
Set<String> originalStates = Generics.newHashSet();
/**
* The current number of substates each state has (1 until a state is
* first split)
*/
IntCounter<String> stateSplitCounts = new IntCounter<String>();
/**
* The binary betas are weights to go from Ax to By, Cz. This maps
* from (A, B, C) to (x, y, z) to beta(Ax, By, Cz).
*/
ThreeDimensionalMap<String, String, String, double[][][]> binaryBetas = new ThreeDimensionalMap<String, String, String, double[][][]>();
/**
* The unary betas are weights to go from Ax to By. This maps
* from (A, B) to (x, y) to beta(Ax, By).
*/
TwoDimensionalMap<String, String, double[][]> unaryBetas = new TwoDimensionalMap<String, String, double[][]>();
/**
* The latest lexicon we trained. At the end of the process, this
* is the lexicon for the parser.
*/
Lexicon lex;
transient Index<String> tempWordIndex;
transient Index<String> tempTagIndex;
/**
* The lexicon we are in the process of building in each iteration.
*/
transient Lexicon tempLex;
/**
* The latest pair of unary and binary grammars we trained.
*/
Pair<UnaryGrammar, BinaryGrammar> bgug;
Random random = new Random(87543875943265L);
static final double LEX_SMOOTH = 0.0001;
static final double STATE_SMOOTH = 0.0;
public SplittingGrammarExtractor(Options op) {
this.op = op;
startSymbols = Arrays.asList(op.langpack().startSymbols());
}
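/**
* Returns an array of the given size filled with
* Double.NEGATIVE_INFINITY, i.e. log(0.0), the identity element for
* SloppyMath.logAdd.
*/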
double[] neginfDoubles(int size) {
double[] result = new double[size];
for (int i = 0; i < size; ++i) {
result[i] = Double.NEGATIVE_INFINITY;
}
return result;
}
public void outputTransitions(Tree tree,
IdentityHashMap<Tree, double[][]> unaryTransitions,
IdentityHashMap<Tree, double[][][]> binaryTransitions) {
outputTransitions(tree, 0, unaryTransitions, binaryTransitions);
}
public void outputTransitions(Tree tree, int depth,
IdentityHashMap<Tree, double[][]> unaryTransitions,
IdentityHashMap<Tree, double[][][]> binaryTransitions) {
for (int i = 0; i < depth; ++i) {
System.out.print(" ");
}
if (tree.isLeaf()) {
System.out.println(tree.label().value());
return;
}
if (tree.children().length == 1) {
System.out.println(tree.label().value() + " -> " + tree.children()[0].label().value());
if (!tree.isPreTerminal()) {
double[][] transitions = unaryTransitions.get(tree);
for (int i = 0; i < transitions.length; ++i) {
for (int j = 0; j < transitions[0].length; ++j) {
for (int z = 0; z < depth; ++z) {
System.out.print(" ");
}
System.out.println(" " + i + "," + j + ": " + transitions[i][j] + " | " + Math.exp(transitions[i][j]));
}
}
}
} else {
System.out.println(tree.label().value() + " -> " + tree.children()[0].label().value() + " " + tree.children()[1].label().value());
double[][][] transitions = binaryTransitions.get(tree);
for (int i = 0; i < transitions.length; ++i) {
for (int j = 0; j < transitions[0].length; ++j) {
for (int k = 0; k < transitions[0][0].length; ++k) {
for (int z = 0; z < depth; ++z) {
System.out.print(" ");
}
System.out.println(" " + i + "," + j + "," + k + ": " + transitions[i][j][k] + " | " + Math.exp(transitions[i][j][k]));
}
}
}
}
if (tree.isPreTerminal()) {
return;
}
for (Tree child : tree.children()) {
outputTransitions(child, depth + 1, unaryTransitions, binaryTransitions);
}
}
public void outputBetas() {
System.out.println("UNARY:");
for (String parent : unaryBetas.firstKeySet()) {
for (String child : unaryBetas.get(parent).keySet()) {
System.out.println(" " + parent + "->" + child);
double[][] betas = unaryBetas.get(parent).get(child);
int parentStates = betas.length;
int childStates = betas[0].length;
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < childStates; ++j) {
System.out.println(" " + i + "->" + j + " " + betas[i][j] + " | " + Math.exp(betas[i][j]));
}
}
}
}
System.out.println("BINARY:");
for (String parent : binaryBetas.firstKeySet()) {
for (String left : binaryBetas.get(parent).firstKeySet()) {
for (String right : binaryBetas.get(parent).get(left).keySet()) {
System.out.println(" " + parent + "->" + left + "," + right);
double[][][] betas = binaryBetas.get(parent).get(left).get(right);
int parentStates = betas.length;
int leftStates = betas[0].length;
int rightStates = betas[0][0].length;
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
System.out.println(" " + i + "->" + j + "," + k + " " + betas[i][j][k] + " | " + Math.exp(betas[i][j][k]));
}
}
}
}
}
}
}
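/**
* Returns the annotated name for substate i of the given tag, e.g.
* state("NP", 2) returns "NP^2". Start symbols and the boundary tag
* are never split, so their names are returned unchanged.
*/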
public String state(String tag, int i) {
if (startSymbols.contains(tag) || tag.equals(Lexicon.BOUNDARY_TAG)) {
return tag;
}
return tag + "^" + i;
}
public int getStateSplitCount(Tree tree) {
return stateSplitCounts.getIntCount(tree.label().value());
}
public int getStateSplitCount(String label) {
return stateSplitCounts.getIntCount(label);
}
/**
* Count all the internal labels in all the trees, and set their
* initial state counts to 1.
*/
public void countOriginalStates() {
originalStates.clear();
for (Tree tree : trees) {
countOriginalStates(tree);
}
for (String state : originalStates) {
stateSplitCounts.incrementCount(state, 1);
}
}
/**
* Counts the labels in the tree, but not the words themselves.
*/
private void countOriginalStates(Tree tree) {
if (tree.isLeaf()) {
return;
}
originalStates.add(tree.label().value());
for (Tree child : tree.children()) {
if (child.isLeaf())
continue;
countOriginalStates(child);
}
}
private void initialBetasAndLexicon() {
wordIndex = new HashIndex<String>();
tagIndex = new HashIndex<String>();
lex = op.tlpParams.lex(op, wordIndex, tagIndex);
lex.initializeTraining(trainSize);
for (Tree tree : trees) {
double weight = treeWeights.getCount(tree);
lex.incrementTreesRead(weight);
initialBetasAndLexicon(tree, 0, weight);
}
lex.finishTraining();
}
private int initialBetasAndLexicon(Tree tree, int position, double weight) {
if (tree.isLeaf()) {
// should never get here, unless a training tree is just one leaf
return position;
}
if (tree.isPreTerminal()) {
// fill in initial lexicon here
String tag = tree.label().value();
String word = tree.children()[0].label().value();
TaggedWord tw = new TaggedWord(word, state(tag, 0));
lex.train(tw, position, weight);
return (position + 1);
}
if (tree.children().length == 2) {
String label = tree.label().value();
String leftLabel = tree.getChild(0).label().value();
String rightLabel = tree.getChild(1).label().value();
if (!binaryBetas.contains(label, leftLabel, rightLabel)) {
double[][][] map = new double[1][1][1];
map[0][0][0] = 0.0;
binaryBetas.put(label, leftLabel, rightLabel, map);
}
} else if (tree.children().length == 1) {
String label = tree.label().value();
String childLabel = tree.getChild(0).label().value();
if (!unaryBetas.contains(label, childLabel)) {
double[][] map = new double[1][1];
map[0][0] = 0.0;
unaryBetas.put(label, childLabel, map);
}
} else {
// should have been binarized
throw new RuntimeException("Trees should have been binarized, expected 1 or 2 children");
}
for (Tree child : tree.children()) {
position = initialBetasAndLexicon(child, position, weight);
}
return position;
}
/**
* Splits the state counts. Root states and the boundary tag do not
* get their counts increased, and all others are doubled. Betas
* and transition weights are handled later.
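* <br>
* For example, a state with 2 substates will have 4 afterwards, while
* the root states and the boundary tag stay at 1.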
*/
private void splitStateCounts() {
// double the count of states...
IntCounter<String> newStateSplitCounts = new IntCounter<String>();
newStateSplitCounts.addAll(stateSplitCounts);
newStateSplitCounts.addAll(stateSplitCounts);
// root states should only have 1
for (String root : startSymbols) {
if (newStateSplitCounts.getCount(root) > 1) {
newStateSplitCounts.setCount(root, 1);
}
}
if (newStateSplitCounts.getCount(Lexicon.BOUNDARY_TAG) > 1) {
newStateSplitCounts.setCount(Lexicon.BOUNDARY_TAG, 1);
}
stateSplitCounts = newStateSplitCounts;
}
static final double EPSILON = 0.0001;
/**
* Before each iteration of splitting states, we have tables of
* betas which correspond to the transitions between different
* substates. When we resplit the states, we duplicate parent
* states and then split their transitions 50/50 with some random
* variation between child states.
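* <br>
* For example, a unary weight beta(Ax, By), stored as a log prob, is
* split into beta(Ax, By0) = beta(Ax, By) + log(w) and
* beta(Ax, By1) = beta(Ax, By) + log(1 - w), where w is drawn
* uniformly from [0.45, 0.55].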
*/
public void splitBetas() {
TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<String, String, double[][]>();
ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<String, String, String, double[][][]>();
for (String parent : unaryBetas.firstKeySet()) {
for (String child : unaryBetas.get(parent).keySet()) {
double[][] betas = unaryBetas.get(parent, child);
int parentStates = betas.length;
int childStates = betas[0].length;
double[][] newBetas;
if (!startSymbols.contains(parent)) {
newBetas = new double[parentStates * 2][childStates];
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < childStates; ++j) {
newBetas[i * 2][j] = betas[i][j];
newBetas[i * 2 + 1][j] = betas[i][j];
}
}
parentStates *= 2;
betas = newBetas;
}
if (!child.equals(Lexicon.BOUNDARY_TAG)) {
newBetas = new double[parentStates][childStates * 2];
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < childStates; ++j) {
double childWeight = 0.45 + random.nextDouble() * 0.1;
newBetas[i][j * 2] = betas[i][j] + Math.log(childWeight);
newBetas[i][j * 2 + 1] = betas[i][j] + Math.log(1.0 - childWeight);
}
}
betas = newBetas;
}
tempUnaryBetas.put(parent, child, betas);
}
}
for (String parent : binaryBetas.firstKeySet()) {
for (String left : binaryBetas.get(parent).firstKeySet()) {
for (String right : binaryBetas.get(parent).get(left).keySet()) {
double[][][] betas = binaryBetas.get(parent, left, right);
int parentStates = betas.length;
int leftStates = betas[0].length;
int rightStates = betas[0][0].length;
double[][][] newBetas;
if (!startSymbols.contains(parent)) {
newBetas = new double[parentStates * 2][leftStates][rightStates];
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
newBetas[i * 2][j][k] = betas[i][j][k];
newBetas[i * 2 + 1][j][k] = betas[i][j][k];
}
}
}
parentStates *= 2;
betas = newBetas;
}
newBetas = new double[parentStates][leftStates * 2][rightStates];
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
double leftWeight = 0.45 + random.nextDouble() * 0.1;
newBetas[i][j * 2][k] = betas[i][j][k] + Math.log(leftWeight);
newBetas[i][j * 2 + 1][k] = betas[i][j][k] + Math.log(1 - leftWeight);
}
}
}
leftStates *= 2;
betas = newBetas;
if (!right.equals(Lexicon.BOUNDARY_TAG)) {
newBetas = new double[parentStates][leftStates][rightStates * 2];
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
double rightWeight = 0.45 + random.nextDouble() * 0.1;
newBetas[i][j][k * 2] = betas[i][j][k] + Math.log(rightWeight);
newBetas[i][j][k * 2 + 1] = betas[i][j][k] + Math.log(1 - rightWeight);
}
}
}
}
tempBinaryBetas.put(parent, left, right, newBetas);
}
}
}
unaryBetas = tempUnaryBetas;
binaryBetas = tempBinaryBetas;
}
/**
* Recalculates the betas for all known transitions. The current
* betas are used to produce probabilities, which then are used to
* compute new betas. If splitStates is true, then the
* probabilities produced are as if the states were split again from
* the last time betas were calculated.
* <br>
* The return value is whether or not the betas have mostly
* converged from the last time this method was called. Obviously
* if splitStates was true, the betas will be entirely different, so
* this is false. Otherwise, the new betas are compared against the
* old values, and convergence means they differ by less than
* EPSILON.
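* <br>
* Each call is in effect one EM iteration: recounting the trees is
* the E step (posterior transition counts under the current betas and
* lexicon), and summing and renormalizing those counts into new betas
* and a new lexicon is the M step.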
*/
public boolean recalculateBetas(boolean splitStates) {
if (splitStates) {
if (DEBUG()) {
System.out.println("Pre-split betas");
outputBetas();
}
splitBetas();
if (DEBUG()) {
System.out.println("Post-split betas");
outputBetas();
}
}
TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<String, String, double[][]>();
ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<String, String, String, double[][][]>();
recalculateTemporaryBetas(splitStates, null, tempUnaryBetas, tempBinaryBetas);
boolean converged = useNewBetas(!splitStates, tempUnaryBetas, tempBinaryBetas);
if (DEBUG()) {
outputBetas();
}
return converged;
}
public boolean useNewBetas(boolean testConverged,
TwoDimensionalMap<String, String, double[][]> tempUnaryBetas,
ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
rescaleTemporaryBetas(tempUnaryBetas, tempBinaryBetas);
// if we just split states, we have obviously not converged
boolean converged = testConverged && testConvergence(tempUnaryBetas, tempBinaryBetas);
unaryBetas = tempUnaryBetas;
binaryBetas = tempBinaryBetas;
wordIndex = tempWordIndex;
tagIndex = tempTagIndex;
lex = tempLex;
if (DEBUG()) {
System.out.println("LEXICON");
try {
OutputStreamWriter osw = new OutputStreamWriter(System.out, "utf-8");
lex.writeData(osw);
osw.flush();
} catch (IOException e) {
throw new RuntimeIOException(e);
}
}
tempWordIndex = null;
tempTagIndex = null;
tempLex = null;
return converged;
}
/**
* Creates temporary beta data structures and fills them in by
* iterating over the trees.
*/
public void recalculateTemporaryBetas(boolean splitStates, Map<String, double[]> totalStateMass,
TwoDimensionalMap<String, String, double[][]> tempUnaryBetas,
ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
tempWordIndex = new HashIndex<String>();
tempTagIndex = new HashIndex<String>();
tempLex = op.tlpParams.lex(op, tempWordIndex, tempTagIndex);
tempLex.initializeTraining(trainSize);
for (Tree tree : trees) {
double weight = treeWeights.getCount(tree);
if (DEBUG()) {
System.out.println("Incrementing trees read: " + weight);
}
tempLex.incrementTreesRead(weight);
recalculateTemporaryBetas(tree, splitStates, totalStateMass, tempUnaryBetas, tempBinaryBetas);
}
tempLex.finishTraining();
}
public boolean testConvergence(TwoDimensionalMap<String, String, double[][]> tempUnaryBetas,
ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
// now, we check each of the new betas to see if it's close to the
// old value for the same transition. if not, we have not yet
// converged. if all of them are, we have converged.
for (String parentLabel : unaryBetas.firstKeySet()) {
for (String childLabel : unaryBetas.get(parentLabel).keySet()) {
double[][] betas = unaryBetas.get(parentLabel, childLabel);
double[][] newBetas = tempUnaryBetas.get(parentLabel, childLabel);
int parentStates = betas.length;
int childStates = betas[0].length;
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < childStates; ++j) {
double oldValue = betas[i][j];
double newValue = newBetas[i][j];
if (Math.abs(newValue - oldValue) > EPSILON) {
return false;
}
}
}
}
}
for (String parentLabel : binaryBetas.firstKeySet()) {
for (String leftLabel : binaryBetas.get(parentLabel).firstKeySet()) {
for (String rightLabel : binaryBetas.get(parentLabel).get(leftLabel).keySet()) {
double[][][] betas = binaryBetas.get(parentLabel, leftLabel, rightLabel);
double[][][] newBetas = tempBinaryBetas.get(parentLabel, leftLabel, rightLabel);
int parentStates = betas.length;
int leftStates = betas[0].length;
int rightStates = betas[0][0].length;
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
double oldValue = betas[i][j][k];
double newValue = newBetas[i][j][k];
if (Math.abs(newValue - oldValue) > EPSILON) {
return false;
}
}
}
}
}
}
}
return true;
}
public void recalculateTemporaryBetas(Tree tree, boolean splitStates,
Map<String, double[]> totalStateMass,
TwoDimensionalMap<String, String, double[][]> tempUnaryBetas,
ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
if (DEBUG()) {
System.out.println("Recalculating temporary betas for tree " + tree);
}
double[] stateWeights = { Math.log(treeWeights.getCount(tree)) };
IdentityHashMap<Tree, double[][]> unaryTransitions = new IdentityHashMap<Tree, double[][]>();
IdentityHashMap<Tree, double[][][]> binaryTransitions = new IdentityHashMap<Tree, double[][][]>();
recountTree(tree, splitStates, unaryTransitions, binaryTransitions);
if (DEBUG()) {
System.out.println(" Transitions:");
outputTransitions(tree, unaryTransitions, binaryTransitions);
}
recalculateTemporaryBetas(tree, stateWeights, 0, unaryTransitions, binaryTransitions,
totalStateMass, tempUnaryBetas, tempBinaryBetas);
}
public int recalculateTemporaryBetas(Tree tree, double[] stateWeights, int position,
IdentityHashMap<Tree, double[][]> unaryTransitions,
IdentityHashMap<Tree, double[][][]> binaryTransitions,
Map<String, double[]> totalStateMass,
TwoDimensionalMap<String, String, double[][]> tempUnaryBetas,
ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
if (tree.isLeaf()) {
// possible to get here if we have a tree with no structure
return position;
}
if (totalStateMass != null) {
double[] stateTotal = totalStateMass.get(tree.label().value());
if (stateTotal == null) {
stateTotal = new double[stateWeights.length];
totalStateMass.put(tree.label().value(), stateTotal);
}
for (int i = 0; i < stateWeights.length; ++i) {
stateTotal[i] += Math.exp(stateWeights[i]);
}
}
if (tree.isPreTerminal()) {
// fill in our new lexicon here.
String tag = tree.label().value();
String word = tree.children()[0].label().value();
// We smooth by LEX_SMOOTH, if relevant. We rescale so that sum
// of the weights being added to the lexicon stays the same.
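// Concretely: if the state weights w_0 .. w_{n-1} sum to T, each
// state is trained with weight (w_s + T * LEX_SMOOTH / n) / (1 + LEX_SMOOTH),
// and those n values again sum to T.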
double total = 0.0;
for (int state = 0; state < stateWeights.length; ++state) {
total += Math.exp(stateWeights[state]);
}
if (total <= 0.0) {
return position + 1;
}
double scale = 1.0 / (1.0 + LEX_SMOOTH);
double smoothing = total * LEX_SMOOTH / stateWeights.length;
for (int state = 0; state < stateWeights.length; ++state) {
// TODO: maybe optimize all this TaggedWord creation
TaggedWord tw = new TaggedWord(word, state(tag, state));
tempLex.train(tw, position, (Math.exp(stateWeights[state]) + smoothing) * scale);
}
return position + 1;
}
if (tree.children().length == 1) {
String parentLabel = tree.label().value();
String childLabel = tree.children()[0].label().value();
double[][] transitions = unaryTransitions.get(tree);
int parentStates = transitions.length;
int childStates = transitions[0].length;
double[][] betas = tempUnaryBetas.get(parentLabel, childLabel);
if (betas == null) {
betas = new double[parentStates][childStates];
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < childStates; ++j) {
betas[i][j] = Double.NEGATIVE_INFINITY;
}
}
tempUnaryBetas.put(parentLabel, childLabel, betas);
}
double[] childWeights = neginfDoubles(childStates);
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < childStates; ++j) {
double weight = transitions[i][j];
betas[i][j] = SloppyMath.logAdd(betas[i][j], weight + stateWeights[i]);
childWeights[j] = SloppyMath.logAdd(childWeights[j], weight + stateWeights[i]);
}
}
position = recalculateTemporaryBetas(tree.children()[0], childWeights, position, unaryTransitions, binaryTransitions, totalStateMass, tempUnaryBetas, tempBinaryBetas);
} else { // length == 2
String parentLabel = tree.label().value();
String leftLabel = tree.children()[0].label().value();
String rightLabel = tree.children()[1].label().value();
double[][][] transitions = binaryTransitions.get(tree);
int parentStates = transitions.length;
int leftStates = transitions[0].length;
int rightStates = transitions[0][0].length;
double[][][] betas = tempBinaryBetas.get(parentLabel, leftLabel, rightLabel);
if (betas == null) {
betas = new double[parentStates][leftStates][rightStates];
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
betas[i][j][k] = Double.NEGATIVE_INFINITY;
}
}
}
tempBinaryBetas.put(parentLabel, leftLabel, rightLabel, betas);
}
double[] leftWeights = neginfDoubles(leftStates);
double[] rightWeights = neginfDoubles(rightStates);
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
double weight = transitions[i][j][k];
betas[i][j][k] = SloppyMath.logAdd(betas[i][j][k], weight + stateWeights[i]);
leftWeights[j] = SloppyMath.logAdd(leftWeights[j], weight + stateWeights[i]);
rightWeights[k] = SloppyMath.logAdd(rightWeights[k], weight + stateWeights[i]);
}
}
}
position = recalculateTemporaryBetas(tree.children()[0], leftWeights, position, unaryTransitions, binaryTransitions, totalStateMass, tempUnaryBetas, tempBinaryBetas);
position = recalculateTemporaryBetas(tree.children()[1], rightWeights, position, unaryTransitions, binaryTransitions, totalStateMass, tempUnaryBetas, tempBinaryBetas);
}
return position;
}
public void rescaleTemporaryBetas(TwoDimensionalMap<String, String, double[][]> tempUnaryBetas,
ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
for (String parent : tempUnaryBetas.firstKeySet()) {
for (String child : tempUnaryBetas.get(parent).keySet()) {
double[][] betas = tempUnaryBetas.get(parent).get(child);
int parentStates = betas.length;
int childStates = betas[0].length;
for (int i = 0; i < parentStates; ++i) {
double sum = Double.NEGATIVE_INFINITY;
for (int j = 0; j < childStates; ++j) {
sum = SloppyMath.logAdd(sum, betas[i][j]);
}
if (Double.isInfinite(sum)) {
for (int j = 0; j < childStates; ++j) {
betas[i][j] = -Math.log(childStates);
}
} else {
for (int j = 0; j < childStates; ++j) {
betas[i][j] -= sum;
}
}
}
}
}
for (String parent : tempBinaryBetas.firstKeySet()) {
for (String left : tempBinaryBetas.get(parent).firstKeySet()) {
for (String right : tempBinaryBetas.get(parent).get(left).keySet()) {
double[][][] betas = tempBinaryBetas.get(parent).get(left).get(right);
int parentStates = betas.length;
int leftStates = betas[0].length;
int rightStates = betas[0][0].length;
for (int i = 0; i < parentStates; ++i) {
double sum = Double.NEGATIVE_INFINITY;
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
sum = SloppyMath.logAdd(sum, betas[i][j][k]);
}
}
if (Double.isInfinite(sum)) {
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
betas[i][j][k] = -Math.log(leftStates * rightStates);
}
}
} else {
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
betas[i][j][k] -= sum;
}
}
}
}
}
}
}
}
public void recountTree(Tree tree, boolean splitStates,
IdentityHashMap<Tree, double[][]> unaryTransitions,
IdentityHashMap<Tree, double[][][]> binaryTransitions) {
IdentityHashMap<Tree, double[]> probIn = new IdentityHashMap<Tree, double[]>();
IdentityHashMap<Tree, double[]> probOut = new IdentityHashMap<Tree, double[]>();
recountTree(tree, splitStates, probIn, probOut, unaryTransitions, binaryTransitions);
}
public void recountTree(Tree tree, boolean splitStates,
IdentityHashMap<Tree, double[]> probIn,
IdentityHashMap<Tree, double[]> probOut,
IdentityHashMap<Tree, double[][]> unaryTransitions,
IdentityHashMap<Tree, double[][][]> binaryTransitions) {
recountInside(tree, splitStates, 0, probIn);
if (DEBUG()) {
System.out.println("ROOT PROBABILITY: " + probIn.get(tree)[0]);
}
recountOutside(tree, probIn, probOut);
recountWeights(tree, probIn, probOut, unaryTransitions, binaryTransitions);
}
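/**
* Computes the (log) transition weights between substates at each
* interior node. For a unary node, for example,
* transitions[i][j] = out(parent, i) + beta(i, j) + in(child, j),
* then each parent substate's transitions are renormalized to sum to
* 1 in real space; the binary case is the same with two children.
*/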
public void recountWeights(Tree tree,
IdentityHashMap<Tree, double[]> probIn,
IdentityHashMap<Tree, double[]> probOut,
IdentityHashMap<Tree, double[][]> unaryTransitions,
IdentityHashMap<Tree, double[][][]> binaryTransitions) {
if (tree.isLeaf() || tree.isPreTerminal()) {
return;
}
if (tree.children().length == 1) {
Tree child = tree.children()[0];
String parentLabel = tree.label().value();
String childLabel = child.label().value();
double[][] betas = unaryBetas.get(parentLabel, childLabel);
double[] childInside = probIn.get(child);
double[] parentOutside = probOut.get(tree);
int parentStates = betas.length;
int childStates = betas[0].length;
double[][] transitions = new double[parentStates][childStates];
unaryTransitions.put(tree, transitions);
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < childStates; ++j) {
transitions[i][j] = parentOutside[i] + childInside[j] + betas[i][j];
}
}
// Renormalize. Note that we renormalize to 1, regardless of
// the original total.
// TODO: smoothing?
for (int i = 0; i < parentStates; ++i) {
double total = Double.NEGATIVE_INFINITY;
for (int j = 0; j < childStates; ++j) {
total = SloppyMath.logAdd(total, transitions[i][j]);
}
// By subtracting off the log total, we make it so the log sum
// of the transitions is 0, meaning the sum of the actual
// transitions is 1. It works if you do the math...
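// Explicitly: exp(transitions[i][j] - total)
// = exp(transitions[i][j]) / sum_j' exp(transitions[i][j']),
// and those quotients sum to 1 over j.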
if (Double.isInfinite(total)) {
double transition = -Math.log(childStates);
for (int j = 0; j < childStates; ++j) {
transitions[i][j] = transition;
}
} else {
for (int j = 0; j < childStates; ++j) {
transitions[i][j] = transitions[i][j] - total;
}
}
}
recountWeights(child, probIn, probOut, unaryTransitions, binaryTransitions);
} else { // length == 2
Tree left = tree.children()[0];
Tree right = tree.children()[1];
String parentLabel = tree.label().value();
String leftLabel = left.label().value();
String rightLabel = right.label().value();
double[][][] betas = binaryBetas.get(parentLabel, leftLabel, rightLabel);
double[] leftInside = probIn.get(left);
double[] rightInside = probIn.get(right);
double[] parentOutside = probOut.get(tree);
int parentStates = betas.length;
int leftStates = betas[0].length;
int rightStates = betas[0][0].length;
double[][][] transitions = new double[parentStates][leftStates][rightStates];
binaryTransitions.put(tree, transitions);
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
transitions[i][j][k] = parentOutside[i] + leftInside[j] + rightInside[k] + betas[i][j][k];
}
}
}
// Renormalize. Note that we renormalize to 1, regardless of
// the original total.
// TODO: smoothing?
for (int i = 0; i < parentStates; ++i) {
double total = Double.NEGATIVE_INFINITY;
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
total = SloppyMath.logAdd(total, transitions[i][j][k]);
}
}
// By subtracting off the log total, we make it so the log sum
// of the transitions is 0, meaning the sum of the actual
// transitions is 1. It works if you do the math...
if (Double.isInfinite(total)) {
double transition = -Math.log(leftStates * rightStates);
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
transitions[i][j][k] = transition;
}
}
} else {
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
transitions[i][j][k] = transitions[i][j][k] - total;
}
}
}
}
recountWeights(left, probIn, probOut, unaryTransitions, binaryTransitions);
recountWeights(right, probIn, probOut, unaryTransitions, binaryTransitions);
}
}
public void recountOutside(Tree tree,
IdentityHashMap<Tree, double[]> probIn,
IdentityHashMap<Tree, double[]> probOut) {
double[] rootScores = { 0.0 };
probOut.put(tree, rootScores);
recurseOutside(tree, probIn, probOut);
}
public void recurseOutside(Tree tree,
IdentityHashMap<Tree, double[]> probIn,
IdentityHashMap<Tree, double[]> probOut) {
if (tree.isLeaf() || tree.isPreTerminal()) {
return;
}
if (tree.children().length == 1) {
recountOutside(tree.children()[0], tree, probIn, probOut);
} else { // length == 2
recountOutside(tree.children()[0], tree.children()[1], tree,
probIn, probOut);
}
}
public void recountOutside(Tree child, Tree parent,
IdentityHashMap<Tree, double[]> probIn,
IdentityHashMap<Tree, double[]> probOut) {
String parentLabel = parent.label().value();
String childLabel = child.label().value();
double[] parentScores = probOut.get(parent);
double[][] betas = unaryBetas.get(parentLabel, childLabel);
int parentStates = betas.length;
int childStates = betas[0].length;
double[] scores = neginfDoubles(childStates);
probOut.put(child, scores);
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < childStates; ++j) {
// TODO: no inside scores here, right?
scores[j] = SloppyMath.logAdd(scores[j], betas[i][j] + parentScores[i]);
}
}
recurseOutside(child, probIn, probOut);
}
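/**
* Outside scores through a binary rule, in log space:
* out(left, j) = logSum over i,k of beta(i,j,k) + out(parent, i) + in(right, k),
* out(right, k) = logSum over i,j of beta(i,j,k) + out(parent, i) + in(left, j).
*/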
public void recountOutside(Tree left, Tree right, Tree parent,
IdentityHashMap<Tree, double[]> probIn,
IdentityHashMap<Tree, double[]> probOut) {
String parentLabel = parent.label().value();
String leftLabel = left.label().value();
String rightLabel = right.label().value();
double[] leftInsideScores = probIn.get(left);
double[] rightInsideScores = probIn.get(right);
double[] parentScores = probOut.get(parent);
double[][][] betas = binaryBetas.get(parentLabel, leftLabel, rightLabel);
int parentStates = betas.length;
int leftStates = betas[0].length;
int rightStates = betas[0][0].length;
double[] leftScores = neginfDoubles(leftStates);
probOut.put(left, leftScores);
double[] rightScores = neginfDoubles(rightStates);
probOut.put(right, rightScores);
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
leftScores[j] = SloppyMath.logAdd(leftScores[j], betas[i][j][k] + parentScores[i] + rightInsideScores[k]);
rightScores[k] = SloppyMath.logAdd(rightScores[k], betas[i][j][k] + parentScores[i] + leftInsideScores[j]);
}
}
}
recurseOutside(left, probIn, probOut);
recurseOutside(right, probIn, probOut);
}
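/**
* Computes inside scores bottom-up in log space. Preterminals are
* scored by the lexicon (split 50/50 with random noise when
* splitStates is true); a unary node scores
* in(parent, i) = logSum over j of in(child, j) + beta(i, j),
* and a binary node sums over both children's substates.
*/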
public int recountInside(Tree tree, boolean splitStates, int loc,
IdentityHashMap<Tree, double[]> probIn) {
if (tree.isLeaf()) {
throw new RuntimeException("recountInside should not be called on a leaf");
} else if (tree.isPreTerminal()) {
int stateCount = getStateSplitCount(tree);
String word = tree.children()[0].label().value();
String tag = tree.label().value();
double[] scores = new double[stateCount];
probIn.put(tree, scores);
if (splitStates && !tag.equals(Lexicon.BOUNDARY_TAG)) {
for (int i = 0; i < stateCount / 2; ++i) {
IntTaggedWord tw = new IntTaggedWord(word, state(tag, i), wordIndex, tagIndex);
double logProb = lex.score(tw, loc, word, null);
double wordWeight = 0.45 + random.nextDouble() * 0.1;
scores[i * 2] = logProb + Math.log(wordWeight);
scores[i * 2 + 1] = logProb + Math.log(1.0 - wordWeight);
if (DEBUG()) {
System.out.println("Lexicon log prob " + state(tag, i) + "-" + word + ": " + logProb);
System.out.println(" Log Split -> " + scores[i * 2] + "," + scores[i * 2 + 1]);
}
}
} else {
for (int i = 0; i < stateCount; ++i) {
IntTaggedWord tw = new IntTaggedWord(word, state(tag, i), wordIndex, tagIndex);
double prob = lex.score(tw, loc, word, null);
if (DEBUG()) {
System.out.println("Lexicon log prob " + state(tag, i) + "-" + word + ": " + prob);
}
scores[i] = prob;
}
}
loc = loc + 1;
} else if (tree.children().length == 1) {
loc = recountInside(tree.children()[0], splitStates, loc, probIn);
double[] childScores = probIn.get(tree.children()[0]);
String parentLabel = tree.label().value();
String childLabel = tree.children()[0].label().value();
double[][] betas = unaryBetas.get(parentLabel, childLabel);
int parentStates = betas.length; // size of the first key
int childStates = betas[0].length;
double[] scores = neginfDoubles(parentStates);
probIn.put(tree, scores);
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < childStates; ++j) {
scores[i] = SloppyMath.logAdd(scores[i], childScores[j] + betas[i][j]);
}
}
if (DEBUG()) {
System.out.println(parentLabel + " -> " + childLabel);
for (int i = 0; i < parentStates; ++i) {
System.out.println(" " + i + ":" + scores[i]);
for (int j = 0; j < childStates; ++j) {
System.out.println(" " + i + "," + j + ": " + betas[i][j] + " | " + Math.exp(betas[i][j]));
}
}
}
} else { // length == 2
loc = recountInside(tree.children()[0], splitStates, loc, probIn);
loc = recountInside(tree.children()[1], splitStates, loc, probIn);
double[] leftScores = probIn.get(tree.children()[0]);
double[] rightScores = probIn.get(tree.children()[1]);
String parentLabel = tree.label().value();
String leftLabel = tree.children()[0].label().value();
String rightLabel = tree.children()[1].label().value();
double[][][] betas = binaryBetas.get(parentLabel, leftLabel, rightLabel);
int parentStates = betas.length;
int leftStates = betas[0].length;
int rightStates = betas[0][0].length;
double[] scores = neginfDoubles(parentStates);
probIn.put(tree, scores);
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
scores[i] = SloppyMath.logAdd(scores[i], leftScores[j] + rightScores[k] + betas[i][j][k]);
}
}
}
if (DEBUG()) {
System.out.println(parentLabel + " -> " + leftLabel + "," + rightLabel);
for (int i = 0; i < parentStates; ++i) {
System.out.println(" " + i + ":" + scores[i]);
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
System.out.println(" " + i + "," + j + "," + k + ": " + betas[i][j][k] + " | " + Math.exp(betas[i][j][k]));
}
}
}
}
}
return loc;
}
public void mergeStates() {
if (op.trainOptions.splitRecombineRate <= 0.0) {
return;
}
// we go through the machinery to sum up the temporary betas,
// counting the total mass
TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<String, String, double[][]>();
ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<String, String, String, double[][][]>();
Map<String, double[]> totalStateMass = Generics.newHashMap();
recalculateTemporaryBetas(false, totalStateMass, tempUnaryBetas, tempBinaryBetas);
// Next, for each tree we count the effect of merging its
// annotations. We only consider the most recently split
// annotations as candidates for merging.
Map<String, double[]> deltaAnnotations = Generics.newHashMap();
for (Tree tree : trees) {
countMergeEffects(tree, totalStateMass, deltaAnnotations);
}
// Now we have a map of the (approximate) likelihood loss from
// merging each state. We merge the ones that provide the least
// benefit, up to the splitRecombineRate
List<Triple<String, Integer, Double>> sortedDeltas =
new ArrayList<Triple<String, Integer, Double>>();
for (String state : deltaAnnotations.keySet()) {
double[] scores = deltaAnnotations.get(state);
for (int i = 0; i < scores.length; ++i) {
sortedDeltas.add(new Triple<String, Integer, Double>(state, i * 2, scores[i]));
}
}
Collections.sort(sortedDeltas, new Comparator<Triple<String, Integer, Double>>() {
public int compare(Triple<String, Integer, Double> first,
Triple<String, Integer, Double> second) {
// The most useful splits will have a large loss in
// likelihood if they are merged. Thus, we want those at
// the end of the list. This means we make the comparison
// "backwards", sorting from high to low.
return Double.compare(second.third(), first.third());
}
public boolean equals(Object o) { return o == this; }
});
// for (Triple<String, Integer, Double> delta : sortedDeltas) {
// System.out.println(delta.first() + "-" + delta.second() + ": " + delta.third());
// }
// System.out.println("-------------");
// Only merge a fraction of the splits based on what the user
// originally asked for
int splitsToMerge = (int) (sortedDeltas.size() * op.trainOptions.splitRecombineRate);
splitsToMerge = Math.max(0, splitsToMerge);
splitsToMerge = Math.min(sortedDeltas.size() - 1, splitsToMerge);
sortedDeltas = sortedDeltas.subList(0, splitsToMerge);
System.out.println();
System.out.println(sortedDeltas);
Map<String, int[]> mergeCorrespondence = buildMergeCorrespondence(sortedDeltas);
recalculateMergedBetas(mergeCorrespondence);
for (Triple<String, Integer, Double> delta : sortedDeltas) {
stateSplitCounts.decrementCount(delta.first(), 1);
}
}
public void recalculateMergedBetas(Map<String, int[]> mergeCorrespondence) {
TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<String, String, double[][]>();
ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<String, String, String, double[][][]>();
tempWordIndex = new HashIndex<String>();
tempTagIndex = new HashIndex<String>();
tempLex = op.tlpParams.lex(op, tempWordIndex, tempTagIndex);
tempLex.initializeTraining(trainSize);
for (Tree tree : trees) {
double treeWeight = treeWeights.getCount(tree);
double[] stateWeights = { Math.log(treeWeight) };
tempLex.incrementTreesRead(treeWeight);
IdentityHashMap<Tree, double[][]> oldUnaryTransitions = new IdentityHashMap<Tree, double[][]>();
IdentityHashMap<Tree, double[][][]> oldBinaryTransitions = new IdentityHashMap<Tree, double[][][]>();
recountTree(tree, false, oldUnaryTransitions, oldBinaryTransitions);
IdentityHashMap<Tree, double[][]> unaryTransitions = new IdentityHashMap<Tree, double[][]>();
IdentityHashMap<Tree, double[][][]> binaryTransitions = new IdentityHashMap<Tree, double[][][]>();
mergeTransitions(tree, oldUnaryTransitions, oldBinaryTransitions, unaryTransitions, binaryTransitions, stateWeights, mergeCorrespondence);
recalculateTemporaryBetas(tree, stateWeights, 0, unaryTransitions, binaryTransitions,
null, tempUnaryBetas, tempBinaryBetas);
}
tempLex.finishTraining();
useNewBetas(false, tempUnaryBetas, tempBinaryBetas);
}
/**
* Given a tree and the original set of transition probabilities
* from one state to the next in the tree, along with a list of the
* weights in the tree and a count of the mass in each substate at
* the current node, this method merges the probabilities as
* necessary. The results go into newUnaryTransitions and
* newBinaryTransitions.
*/
public void mergeTransitions(Tree parent,
IdentityHashMap<Tree, double[][]> oldUnaryTransitions,
IdentityHashMap<Tree, double[][][]> oldBinaryTransitions,
IdentityHashMap<Tree, double[][]> newUnaryTransitions,
IdentityHashMap<Tree, double[][][]> newBinaryTransitions,
double[] stateWeights,
Map<String, int[]> mergeCorrespondence) {
if (parent.isPreTerminal() || parent.isLeaf()) {
return;
}
if (parent.children().length == 1) {
double[][] oldTransitions = oldUnaryTransitions.get(parent);
String parentLabel = parent.label().value();
int[] parentCorrespondence = mergeCorrespondence.get(parentLabel);
int parentStates = parentCorrespondence[parentCorrespondence.length - 1] + 1;
String childLabel = parent.children()[0].label().value();
int[] childCorrespondence = mergeCorrespondence.get(childLabel);
int childStates = childCorrespondence[childCorrespondence.length - 1] + 1;
// System.out.println("P: " + parentLabel + " " + parentStates +
// " C: " + childLabel + " " + childStates);
// Add up the probabilities of transitioning to each state,
// scaled by the probability of being in a given state to begin
// with. This accounts for when two states in the parent are
// collapsed into one state.
double[][] newTransitions = new double[parentStates][childStates];
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < childStates; ++j) {
newTransitions[i][j] = Double.NEGATIVE_INFINITY;
}
}
newUnaryTransitions.put(parent, newTransitions);
for (int i = 0; i < oldTransitions.length; ++i) {
int ti = parentCorrespondence[i];
for (int j = 0; j < oldTransitions[0].length; ++j) {
int tj = childCorrespondence[j];
// System.out.println(i + " " + ti + " " + j + " " + tj);
newTransitions[ti][tj] = SloppyMath.logAdd(newTransitions[ti][tj], oldTransitions[i][j] + stateWeights[i]);
}
}
// renormalize
for (int i = 0; i < parentStates; ++i) {
double total = Double.NEGATIVE_INFINITY;
for (int j = 0; j < childStates; ++j) {
total = SloppyMath.logAdd(total, newTransitions[i][j]);
}
if (Double.isInfinite(total)) {
for (int j = 0; j < childStates; ++j) {
newTransitions[i][j] = -Math.log(childStates);
}
} else {
for (int j = 0; j < childStates; ++j) {
newTransitions[i][j] -= total;
}
}
}
double[] childWeights = neginfDoubles(oldTransitions[0].length);
for (int i = 0; i < oldTransitions.length; ++i) {
for (int j = 0; j < oldTransitions[0].length; ++j) {
double weight = oldTransitions[i][j];
childWeights[j] = SloppyMath.logAdd(childWeights[j], weight + stateWeights[i]);
}
}
mergeTransitions(parent.children()[0], oldUnaryTransitions, oldBinaryTransitions, newUnaryTransitions, newBinaryTransitions, childWeights, mergeCorrespondence);
} else {
double[][][] oldTransitions = oldBinaryTransitions.get(parent);
String parentLabel = parent.label().value();
int[] parentCorrespondence = mergeCorrespondence.get(parentLabel);
int parentStates = parentCorrespondence[parentCorrespondence.length - 1] + 1;
String leftLabel = parent.children()[0].label().value();
int[] leftCorrespondence = mergeCorrespondence.get(leftLabel);
int leftStates = leftCorrespondence[leftCorrespondence.length - 1] + 1;
String rightLabel = parent.children()[1].label().value();
int[] rightCorrespondence = mergeCorrespondence.get(rightLabel);
int rightStates = rightCorrespondence[rightCorrespondence.length - 1] + 1;
// System.out.println("P: " + parentLabel + " " + parentStates +
// " L: " + leftLabel + " " + leftStates +
// " R: " + rightLabel + " " + rightStates);
double[][][] newTransitions = new double[parentStates][leftStates][rightStates];
for (int i = 0; i < parentStates; ++i) {
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
newTransitions[i][j][k] = Double.NEGATIVE_INFINITY;
}
}
}
newBinaryTransitions.put(parent, newTransitions);
for (int i = 0; i < oldTransitions.length; ++i) {
int ti = parentCorrespondence[i];
for (int j = 0; j < oldTransitions[0].length; ++j) {
int tj = leftCorrespondence[j];
for (int k = 0; k < oldTransitions[0][0].length; ++k) {
int tk = rightCorrespondence[k];
// System.out.println(i + " " + ti + " " + j + " " + tj + " " + k + " " + tk);
newTransitions[ti][tj][tk] = SloppyMath.logAdd(newTransitions[ti][tj][tk], oldTransitions[i][j][k] + stateWeights[i]);
}
}
}
// renormalize
for (int i = 0; i < parentStates; ++i) {
double total = Double.NEGATIVE_INFINITY;
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
total = SloppyMath.logAdd(total, newTransitions[i][j][k]);
}
}
if (Double.isInfinite(total)) {
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
newTransitions[i][j][k] = -Math.log(leftStates * rightStates);
}
}
} else {
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
newTransitions[i][j][k] -= total;
}
}
}
}
double[] leftWeights = neginfDoubles(oldTransitions[0].length);
double[] rightWeights = neginfDoubles(oldTransitions[0][0].length);
for (int i = 0; i < oldTransitions.length; ++i) {
for (int j = 0; j < oldTransitions[0].length; ++j) {
for (int k = 0; k < oldTransitions[0][0].length; ++k) {
double weight = oldTransitions[i][j][k];
leftWeights[j] = SloppyMath.logAdd(leftWeights[j], weight + stateWeights[i]);
rightWeights[k] = SloppyMath.logAdd(rightWeights[k], weight + stateWeights[i]);
}
}
}
mergeTransitions(parent.children()[0], oldUnaryTransitions, oldBinaryTransitions, newUnaryTransitions, newBinaryTransitions, leftWeights, mergeCorrespondence);
mergeTransitions(parent.children()[1], oldUnaryTransitions, oldBinaryTransitions, newUnaryTransitions, newBinaryTransitions, rightWeights, mergeCorrespondence);
}
}
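/**
* Builds, for each state, an array mapping old substate indices to
* their post-merge indices. For example, if NP has 4 substates and
* the deltas say to merge the pair starting at substate 2, the NP
* correspondence goes from {0, 1, 2, 3} to {0, 1, 2, 2}, so substates
* 2 and 3 now map to the same merged substate.
*/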
Map<String, int[]> buildMergeCorrespondence(List<Triple<String, Integer, Double>> deltas) {
Map<String, int[]> mergeCorrespondence = Generics.newHashMap();
for (String state : originalStates) {
int states = getStateSplitCount(state);
int[] correspondence = new int[states];
for (int i = 0; i < states; ++i) {
correspondence[i] = i;
}
mergeCorrespondence.put(state, correspondence);
}
for (Triple<String, Integer, Double> merge : deltas) {
int states = getStateSplitCount(merge.first());
int split = merge.second();
int[] correspondence = mergeCorrespondence.get(merge.first());
for (int i = split + 1; i < states; ++i) {
correspondence[i] = correspondence[i] - 1;
}
}
return mergeCorrespondence;
}
public void countMergeEffects(Tree tree, Map<String, double[]> totalStateMass,
Map<String, double[]> deltaAnnotations) {
IdentityHashMap<Tree, double[]> probIn = new IdentityHashMap<Tree, double[]>();
IdentityHashMap<Tree, double[]> probOut = new IdentityHashMap<Tree, double[]>();
IdentityHashMap<Tree, double[][]> unaryTransitions = new IdentityHashMap<Tree, double[][]>();
IdentityHashMap<Tree, double[][][]> binaryTransitions = new IdentityHashMap<Tree, double[][][]>();
recountTree(tree, false, probIn, probOut, unaryTransitions, binaryTransitions);
// no need to count the root
for (Tree child : tree.children()) {
countMergeEffects(child, totalStateMass, deltaAnnotations, probIn, probOut);
}
}
public void countMergeEffects(Tree tree, Map<String, double[]> totalStateMass,
Map<String, double[]> deltaAnnotations,
IdentityHashMap<Tree, double[]> probIn,
IdentityHashMap<Tree, double[]> probOut) {
if (tree.isLeaf()) {
return;
}
if (tree.label().value().equals(Lexicon.BOUNDARY_TAG)) {
return;
}
String label = tree.label().value();
double totalMass = 0.0;
double[] stateMass = totalStateMass.get(label);
for (double mass : stateMass) {
totalMass += mass;
}
double[] nodeProbIn = probIn.get(tree);
double[] nodeProbOut = probOut.get(tree);
double[] nodeDelta = deltaAnnotations.get(label);
if (nodeDelta == null) {
nodeDelta = new double[nodeProbIn.length / 2];
deltaAnnotations.put(label, nodeDelta);
}
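// For each split pair (2i, 2i+1), estimate the (log) likelihood
// change from merging it: the merged inside score mixes the pair's
// inside scores weighted by each substate's share of the state's
// total mass, the merged outside score adds the two outside scores,
// and the delta compares merged in*out against the unmerged
// in(2i)*out(2i) + in(2i+1)*out(2i+1).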
for (int i = 0; i < nodeProbIn.length / 2; ++i) {
double probInMerged = SloppyMath.logAdd(Math.log(stateMass[i * 2] / totalMass) + nodeProbIn[i * 2],
Math.log(stateMass[i * 2 + 1] / totalMass) + nodeProbIn[i * 2 + 1]);
double probOutMerged = SloppyMath.logAdd(nodeProbOut[i * 2], nodeProbOut[i * 2 + 1]);
double probMerged = probInMerged + probOutMerged;
double probUnmerged = SloppyMath.logAdd(nodeProbIn[i * 2] + nodeProbOut[i * 2],
nodeProbIn[i * 2 + 1] + nodeProbOut[i * 2 + 1]);
nodeDelta[i] = nodeDelta[i] + probMerged - probUnmerged;
}
if (tree.isPreTerminal()) {
return;
}
for (Tree child : tree.children()) {
countMergeEffects(child, totalStateMass, deltaAnnotations, probIn, probOut);
}
}
public void buildStateIndex() {
stateIndex = new HashIndex<String>();
for (String key : stateSplitCounts.keySet()) {
for (int i = 0; i < stateSplitCounts.getIntCount(key); ++i) {
stateIndex.addToIndex(state(key, i));
}
}
}
public void buildGrammars() {
// In order to build the grammars, we first need to fill in the
// temp betas with the sums of the transitions from Ax to By or Ax
// to By,Cz. We also need the sum total of the mass in each state
// Ax over all the trees.
// we go through the machinery to sum up the temporary betas,
// counting the total mass...
TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<String, String, double[][]>();
ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<String, String, String, double[][][]>();
Map<String, double[]> totalStateMass = Generics.newHashMap();
recalculateTemporaryBetas(false, totalStateMass, tempUnaryBetas, tempBinaryBetas);
// ... but note we don't actually rescale the betas.
// instead we use the temporary betas and the total mass in each
// state to calculate the grammars
// First build up a BinaryGrammar.
// The score for each rule will be the Beta scores found earlier,
// scaled by the total weight of a transition between unsplit states
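// In log space, score(Ax -> By,Cz) = beta(Ax, By, Cz) - log(mass(Ax)),
// giving each rule its share of the total weighted mass observed in
// substate Ax.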
BinaryGrammar bg = new BinaryGrammar(stateIndex);
for (String parent : tempBinaryBetas.firstKeySet()) {
int parentStates = getStateSplitCount(parent);
double[] stateTotal = totalStateMass.get(parent);
for (String left : tempBinaryBetas.get(parent).firstKeySet()) {
int leftStates = getStateSplitCount(left);
for (String right : tempBinaryBetas.get(parent).get(left).keySet()) {
int rightStates = getStateSplitCount(right);
double[][][] betas = tempBinaryBetas.get(parent, left, right);
for (int i = 0; i < parentStates; ++i) {
if (stateTotal[i] < EPSILON) {
continue;
}
for (int j = 0; j < leftStates; ++j) {
for (int k = 0; k < rightStates; ++k) {
int parentIndex = stateIndex.indexOf(state(parent, i));
int leftIndex = stateIndex.indexOf(state(left, j));
int rightIndex = stateIndex.indexOf(state(right, k));
double score = betas[i][j][k] - Math.log(stateTotal[i]);
BinaryRule br = new BinaryRule(parentIndex, leftIndex, rightIndex, score);
bg.addRule(br);
}
}
}
}
}
}
// Now build up a UnaryGrammar
UnaryGrammar ug = new UnaryGrammar(stateIndex);
for (String parent : tempUnaryBetas.firstKeySet()) {
int parentStates = getStateSplitCount(parent);
double[] stateTotal = totalStateMass.get(parent);
for (String child : tempUnaryBetas.get(parent).keySet()) {
int childStates = getStateSplitCount(child);
double[][] betas = tempUnaryBetas.get(parent, child);
for (int i = 0; i < parentStates; ++i) {
if (stateTotal[i] < EPSILON) {
continue;
}
for (int j = 0; j < childStates; ++j) {
int parentIndex = stateIndex.indexOf(state(parent, i));
int childIndex = stateIndex.indexOf(state(child, j));
double score = betas[i][j] - Math.log(stateTotal[i]);
UnaryRule ur = new UnaryRule(parentIndex, childIndex, score);
ug.addRule(ur);
}
}
}
}
bgug = new Pair<UnaryGrammar, BinaryGrammar>(ug, bg);
}
public void saveTrees(Collection<Tree> trees1, double weight1,
Collection<Tree> trees2, double weight2) {
trainSize = 0.0;
int treeCount = 0;
trees.clear();
treeWeights.clear();
for (Tree tree : trees1) {
trees.add(tree);
treeWeights.incrementCount(tree, weight1);
trainSize += weight1;
}
treeCount += trees1.size();
if (trees2 != null && weight2 >= 0.0) {
for (Tree tree : trees2) {
trees.add(tree);
treeWeights.incrementCount(tree, weight2);
trainSize += weight2;
}
treeCount += trees2.size();
}
System.err.println("Found " + treeCount +
" trees with total weight " + trainSize);
}
public void extract(Collection<Tree> treeList) {
extract(treeList, 1.0, null, 0.0);
}
/**
* First, we do a few setup steps. We read in all the trees, which
* is necessary because we continually reprocess them and use the
* object pointers as hash keys rather than hashing the trees
* themselves. We then count the initial states in the treebank.
* <br>
* Having done that, we then assign initial probabilities to the
* trees. At first, each transition Ax -> By,Cz or Ax -> By carries
* probability mass 1.0. We then split the number of states and the
* probabilities on each tree.
* <br>
* We then repeatedly recalculate the betas and reannotate the
* weights, going until we converge, which is defined as no beta
* moving by more than EPSILON.
* <br>
* java -mx4g edu.stanford.nlp.parser.lexparser.LexicalizedParser -PCFG -saveToSerializedFile englishSplit.ser.gz -saveToTextFile englishSplit.txt -maxLength 40 -train ../data/wsj/wsjtwentytrees.mrg -testTreebank ../data/wsj/wsjtwentytrees.mrg -evals "factDA,tsv" -uwm 0 -hMarkov 0 -vMarkov 0 -simpleBinarizedLabels -noRebinarization -predictSplits -splitTrainingThreads 1 -splitCount 1 -splitRecombineRate 0.5
* <br>
* may also need
* <br>
* -smoothTagsThresh 0
* <br>
* java -mx8g edu.stanford.nlp.parser.lexparser.LexicalizedParser -evals "factDA,tsv" -PCFG -vMarkov 0 -hMarkov 0 -uwm 0 -saveToSerializedFile wsjS1.ser.gz -maxLength 40 -train /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj 200-2199 -testTreebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj 2200-2219 -compactGrammar 0 -simpleBinarizedLabels -predictSplits -smoothTagsThresh 0 -splitCount 1 -noRebinarization
*/
public void extract(Collection<Tree> trees1, double weight1,
Collection<Tree> trees2, double weight2) {
saveTrees(trees1, weight1, trees2, weight2);
countOriginalStates();
// Initial betas are probability 1 (log prob 0.0) for every unary and
// binary transition observed in our treebank
initialBetasAndLexicon();
for (int cycle = 0; cycle < op.trainOptions.splitCount; ++cycle) {
// All states except the root states and the boundary tag get split in two
splitStateCounts();
// first, recalculate the betas and the lexicon for having split
// the transitions
recalculateBetas(true);
// now, loop until we converge while recalculating betas
// TODO: MAX_ITERATIONS is currently Integer.MAX_VALUE; set a practical cap
iteration = 0;
boolean converged = false;
while (!converged && iteration < MAX_ITERATIONS) {
if (DEBUG()) {
System.out.println();
System.out.println();
System.out.println("-------------------");
System.out.println("Iteration " + iteration);
}
converged = recalculateBetas(false);
++iteration;
}
System.err.println("Converged for cycle " + cycle +
" in " + iteration + " iterations");
mergeStates();
}
// Build up the state index. The BG & UG both expect a set count
// of states.
buildStateIndex();
buildGrammars();
}
}