// Builds a CRF whose weights leave exactly one allowed path through the
// lattice (start in state0, move to state1, then stay in state1), and checks
// that SumLatticeDefault's gamma/xi posteriors are 0/1 accordingly and that
// they sum to 1.0 across states at each checked position.
Alphabet inputAlphabet = new Alphabet();
for (int i = 0; i < inputVocabSize; i++) {
    inputAlphabet.lookupIndex("feature" + i);
}
Alphabet outputAlphabet = new Alphabet();
CRF crf = new CRF(inputAlphabet, outputAlphabet);
String[] stateNames = new String[numStates];
for (int i = 0; i < numStates; i++) {
    stateNames[i] = "state" + i;
}
crf.addFullyConnectedStates(stateNames);
crf.setWeightsDimensionDensely();

// Only state0 may start; both states have final weight 0.0.
crf.getState(0).setInitialWeight(1.0);
crf.getState(1).setInitialWeight(Transducer.IMPOSSIBLE_WEIGHT);
crf.getState(0).setFinalWeight(0.0);
crf.getState(1).setFinalWeight(0.0);

// Arguments are presumably (source state, dest state, feature index, weight),
// matching the per-line comments below.
crf.setParameter(0, 0, 0, Transducer.IMPOSSIBLE_WEIGHT); // state0 self-transition
crf.setParameter(0, 1, 0, 1.0); // state0->state1
crf.setParameter(1, 1, 0, 1.0); // state1 self-transition
crf.setParameter(1, 0, 0, Transducer.IMPOSSIBLE_WEIGHT); // state1->state0

// Three identical input vectors, each built from the value array {1}.
Alphabet featureAlphabet = (Alphabet) crf.getInputAlphabet();
FeatureVector[] inputVectors = new FeatureVector[3];
for (int i = 0; i < inputVectors.length; i++) {
    inputVectors[i] = new FeatureVector(featureAlphabet, new double[] { 1 });
}
FeatureVectorSequence fvs = new FeatureVectorSequence(inputVectors);

// NOTE(review): the `true` flag presumably asks the lattice to save xis,
// which the getXiProbability checks below depend on.
SumLattice lattice = new SumLatticeDefault(crf, fvs, true);

// Tolerance shared by every floating-point comparison in this test.
final double probTolerance = 0.0001;

// We start in state0...
assertEquals(1.0, lattice.getGammaProbability(0, crf.getState(0)), probTolerance);
assertEquals(0.0, lattice.getGammaProbability(0, crf.getState(1)), probTolerance);
// ...then go to state1...
assertEquals(0.0, lattice.getGammaProbability(1, crf.getState(0)), probTolerance);
assertEquals(1.0, lattice.getGammaProbability(1, crf.getState(1)), probTolerance);
// ...and continue on through state1's self-transition.
assertEquals(1.0,
        lattice.getXiProbability(1, crf.getState(1), crf.getState(1)),
        probTolerance);
assertEquals(0.0,
        lattice.getXiProbability(1, crf.getState(1), crf.getState(0)),
        probTolerance);
// Presumably 4.0 = initial weight 1.0 plus three transitions of weight 1.0.
assertEquals("Lattice weight = " + lattice.getTotalWeight(), 4.0,
        lattice.getTotalWeight(), probTolerance);

// Gammas at positions 0 .. length()-2 sum to 1.0 across all states.
for (int time = 0; time < lattice.length() - 1; time++) {
    double gammasum = 0.0;
    for (int s = 0; s < numStates; s++) {
        gammasum += lattice.getGammaProbability(time, crf.getState(s));
    }
    assertEquals("Gammas at time step " + time + " sum to " + gammasum,
            1.0, gammasum, probTolerance);
}

// Xis at positions 0 .. length()-2 sum to 1.0 across all state pairs.
for (int time = 0; time < lattice.length() - 1; time++) {
    double xissum = 0.0;
    for (int from = 0; from < numStates; from++) {
        for (int to = 0; to < numStates; to++) {
            xissum += lattice.getXiProbability(time, crf.getState(from),
                    crf.getState(to));
        }
    }
    assertEquals("Xis at time step " + time + " sum to " + xissum, 1.0,
            xissum, probTolerance);
}
}