package edu.stanford.nlp.parser.shiftreduce;
import junit.framework.TestCase;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import edu.stanford.nlp.ling.Sentence;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.parser.lexparser.Debinarizer;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.util.Generics;
/**
 * Test the results that come back when you run the ReorderingOracle
 * on various inputs.
 * <br>
 * Two kinds of checks are made: hand-written transition lists are
 * passed to the oracle and compared against the expected rewritten
 * list, and gold transition sequences derived from real trees are
 * reordered, applied to a parser state, and the resulting tree is
 * compared against an expected tree.
 *
 * @author John Bauer
 */
public class ReorderingOracleTest extends TestCase {
  // Transitions used to assemble the hand-written transition lists in the tests below.
  FinalizeTransition finalize = new FinalizeTransition(Collections.singleton("ROOT"));
  ShiftTransition shift = new ShiftTransition();

  // For each of NP, VP and S there is a RIGHT and a LEFT binary transition,
  // plus "temp" variants which use the same label prefixed with "@".
  BinaryTransition rightNP = new BinaryTransition("NP", BinaryTransition.Side.RIGHT);
  BinaryTransition tempRightNP = new BinaryTransition("@NP", BinaryTransition.Side.RIGHT);
  BinaryTransition leftNP = new BinaryTransition("NP", BinaryTransition.Side.LEFT);
  BinaryTransition tempLeftNP = new BinaryTransition("@NP", BinaryTransition.Side.LEFT);

  BinaryTransition rightVP = new BinaryTransition("VP", BinaryTransition.Side.RIGHT);
  BinaryTransition tempRightVP = new BinaryTransition("@VP", BinaryTransition.Side.RIGHT);
  BinaryTransition leftVP = new BinaryTransition("VP", BinaryTransition.Side.LEFT);
  BinaryTransition tempLeftVP = new BinaryTransition("@VP", BinaryTransition.Side.LEFT);

  BinaryTransition rightS = new BinaryTransition("S", BinaryTransition.Side.RIGHT);
  BinaryTransition tempRightS = new BinaryTransition("@S", BinaryTransition.Side.RIGHT);
  BinaryTransition leftS = new BinaryTransition("S", BinaryTransition.Side.LEFT);
  BinaryTransition tempLeftS = new BinaryTransition("@S", BinaryTransition.Side.LEFT);

  UnaryTransition unaryADVP = new UnaryTransition("ADVP", false);

  // Words and tags of the first gold tree (without its final period).
  // The tags mirror the preterminals of correctTrees[0], so "eating" is VBG.
  // Note: the tests in this class do not currently read these three fields.
  String[] WORDS = { "My", "dog", "also", "likes", "eating", "sausage" };
  String[] TAGS = { "PRP$", "NN", "RB", "VBZ", "VBG", "NN" };
  List<TaggedWord> sentence = Sentence.toTaggedList(Arrays.asList(WORDS), Arrays.asList(TAGS));

  /** Gold trees; setUp() binarizes them into {@link #binarizedTrees}. */
  Tree[] correctTrees = {
    Tree.valueOf("(ROOT (S (NP (PRP$ My) (NN dog)) (ADVP (RB also)) (VP (VBZ likes) (S (VP (VBG eating) (NP (NN sausage))))) (. .)))"),
    Tree.valueOf("(NP (NP (NN A) (NN B)) (NN C))"), // doesn't have to make sense
    Tree.valueOf("(ROOT (S (NP (PRP$ My) (JJ small) (NN dog)) (ADVP (RB also)) (VP (VBZ likes) (S (VP (VBG eating) (NP (NN sausage))))) (. .)))"),
  };

  List<Tree> binarizedTrees; // initialized in setUp

  /**
   * Expected debinarized trees, parallel to {@link #correctTrees}: entry i
   * is what testReorderIncorrectShiftResultingTree expects after a shift is
   * applied where the gold sequence for correctTrees[i] has its first
   * binary transition and the reordered remainder is then applied.
   */
  Tree[] incorrectShiftTrees = {
    Tree.valueOf("(ROOT (S (PRP$ My) (NN dog) (ADVP (RB also)) (VP (VBZ likes) (S (VP (VBG eating) (NP (NN sausage))))) (. .)))"),
    Tree.valueOf("(NP (NN A) (NN B) (NN C))"), // doesn't have to make sense
    Tree.valueOf("(ROOT (S (PRP$ My) (JJ small) (NN dog) (ADVP (RB also)) (VP (VBZ likes) (S (VP (VBG eating) (NP (NN sausage))))) (. .)))"),
  };

  Debinarizer debinarizer = new Debinarizer(false);

  /**
   * Binarizes the gold trees with default Options so that transition
   * sequences can be created from them.
   */
  @Override
  public void setUp() {
    Options op = new Options();
    Treebank treebank = op.tlpParams.memoryTreebank();
    treebank.addAll(Arrays.asList(correctTrees));
    binarizedTrees = ShiftReduceParser.binarizeTreebank(treebank, op);
  }

  /**
   * Copies the given transitions into a new linked list.  The oracle
   * methods under test modify the list they are given, so each check
   * gets its own list.
   *
   * @param transitions the transitions, in order
   * @return a new list containing the transitions in the same order
   */
  public List<Transition> buildTransitionList(Transition ... transitions) {
    return Generics.newLinkedList(Arrays.asList(transitions));
  }

  /**
   * Checks reorderIncorrectBinaryTransition on hand-written lists: the
   * oracle must report success and leave the list equal to the expected one.
   */
  public void testReorderIncorrectBinaryTransition() {
    List<Transition> transitions = buildTransitionList(shift, rightNP, rightVP, finalize);
    assertTrue(ReorderingOracle.reorderIncorrectBinaryTransition(transitions));
    assertEquals(buildTransitionList(shift, rightVP, finalize), transitions);

    transitions = buildTransitionList(shift, unaryADVP, rightNP, rightVP, finalize);
    assertTrue(ReorderingOracle.reorderIncorrectBinaryTransition(transitions));
    assertEquals(buildTransitionList(shift, unaryADVP, rightVP, finalize), transitions);

    transitions = buildTransitionList(shift, rightNP, unaryADVP, rightVP, finalize);
    assertTrue(ReorderingOracle.reorderIncorrectBinaryTransition(transitions));
    assertEquals(buildTransitionList(shift, rightVP, finalize), transitions);
  }

  /**
   * For each gold tree: follows the gold transitions up to the first
   * binary transition, applies a shift at that point instead, lets the
   * oracle reorder the rest of the gold sequence, applies the reordered
   * transitions, and compares the debinarized top of the stack with the
   * matching entry of {@link #incorrectShiftTrees}.
   */
  public void testReorderIncorrectShiftResultingTree() {
    // The two arrays are indexed in parallel below.
    assertEquals("each gold tree needs an expected result tree",
                 correctTrees.length, incorrectShiftTrees.length);

    for (int testcase = 0; testcase < correctTrees.length; ++testcase) {
      // Included in every assertion so a failure identifies the tree involved.
      String label = "testcase " + testcase + ": " + correctTrees[testcase];

      State state = ShiftReduceParser.initialStateFromGoldTagTree(correctTrees[testcase]);
      List<Transition> gold = CreateTransitionSequence.createTransitionSequence(binarizedTrees.get(testcase));

      // Apply the gold transitions until the first binary transition is reached.
      int tnum = 0;
      for (; tnum < gold.size(); ++tnum) {
        if (gold.get(tnum) instanceof BinaryTransition) {
          break;
        }
        state = gold.get(tnum).apply(state);
      }
      assertTrue(label + " has no binary transition in " + gold, tnum < gold.size());

      // Make the "mistake": shift where the gold sequence wanted a binary transition.
      state = shift.apply(state);

      // The remaining gold transitions, starting with the binary transition
      // that was not applied, are handed to the oracle to be fixed up.
      List<Transition> reordered = Generics.newLinkedList(gold.subList(tnum, gold.size()));
      assertTrue(label + " could not be reordered: " + reordered,
                 ReorderingOracle.reorderIncorrectShiftTransition(reordered));

      for (Transition transition : reordered) {
        state = transition.apply(state);
      }

      Tree debinarized = debinarizer.transformTree(state.stack.peek());
      assertEquals(label, incorrectShiftTrees[testcase].toString(), debinarized.toString());
    }
  }

  /**
   * Checks reorderIncorrectShiftTransition on hand-written lists: the
   * oracle must report success and leave the list equal to the expected one.
   */
  public void testReorderIncorrectShift() {
    List<Transition> transitions = buildTransitionList(rightNP, shift, rightVP, finalize);
    assertTrue(ReorderingOracle.reorderIncorrectShiftTransition(transitions));
    assertEquals(buildTransitionList(tempRightVP, rightVP, finalize), transitions);

    transitions = buildTransitionList(rightNP, shift, shift, leftNP, rightVP, finalize);
    assertTrue(ReorderingOracle.reorderIncorrectShiftTransition(transitions));
    assertEquals(buildTransitionList(shift, leftNP, tempRightVP, rightVP, finalize), transitions);

    transitions = buildTransitionList(rightNP, shift, unaryADVP, shift, leftNP, rightVP, finalize);
    assertTrue(ReorderingOracle.reorderIncorrectShiftTransition(transitions));
    assertEquals(buildTransitionList(unaryADVP, shift, leftNP, tempRightVP, rightVP, finalize), transitions);

    transitions = buildTransitionList(rightNP, shift, shift, unaryADVP, leftNP, rightVP, finalize);
    assertTrue(ReorderingOracle.reorderIncorrectShiftTransition(transitions));
    assertEquals(buildTransitionList(shift, unaryADVP, leftNP, tempRightVP, rightVP, finalize), transitions);

    transitions = buildTransitionList(leftNP, shift, shift, unaryADVP, leftNP, rightVP, finalize);
    assertTrue(ReorderingOracle.reorderIncorrectShiftTransition(transitions));
    assertEquals(buildTransitionList(shift, unaryADVP, leftNP, tempRightVP, rightVP, finalize), transitions);

    transitions = buildTransitionList(leftNP, shift, shift, unaryADVP, leftNP, leftVP, finalize);
    assertTrue(ReorderingOracle.reorderIncorrectShiftTransition(transitions));
    assertEquals(buildTransitionList(shift, unaryADVP, leftNP, tempLeftVP, leftVP, finalize), transitions);

    transitions = buildTransitionList(rightNP, shift, shift, unaryADVP, leftNP, leftVP, finalize);
    assertTrue(ReorderingOracle.reorderIncorrectShiftTransition(transitions));
    assertEquals(buildTransitionList(shift, unaryADVP, leftNP, tempLeftVP, rightVP, finalize), transitions);

    transitions = buildTransitionList(leftNP, leftNP, shift, shift, unaryADVP, leftNP, rightVP, finalize);
    assertTrue(ReorderingOracle.reorderIncorrectShiftTransition(transitions));
    assertEquals(buildTransitionList(shift, unaryADVP, leftNP, tempRightVP, tempRightVP, rightVP, finalize), transitions);

    transitions = buildTransitionList(leftNP, rightNP, shift, shift, unaryADVP, leftNP, leftVP, finalize);
    assertTrue(ReorderingOracle.reorderIncorrectShiftTransition(transitions));
    assertEquals(buildTransitionList(shift, unaryADVP, leftNP, tempLeftVP, tempLeftVP, rightVP, finalize), transitions);

    transitions = buildTransitionList(leftNP, leftNP, shift, shift, unaryADVP, leftNP, leftVP, finalize);
    assertTrue(ReorderingOracle.reorderIncorrectShiftTransition(transitions));
    assertEquals(buildTransitionList(shift, unaryADVP, leftNP, tempLeftVP, tempLeftVP, leftVP, finalize), transitions);
  }
}