logger.fine ("-INFINITE weight or NULL...skipping");
continue;
}
State s = t.getState(i);
TransitionIterator iter = s.transitionIterator (input, ip, output, ip);
if (logger.isLoggable (Level.FINE))
logger.fine (" Starting Forward transition iteration from state "
+ s.getName() + " on input " + input.get(ip).toString()
+ " and output "
+ (output==null ? "(null)" : output.get(ip).toString()));
while (iter.hasNext()) {
State destination = iter.nextState();
boolean legalTransition = true;
// check constraints to see if node at <ip,i> can transition to destination
if (ip+1 < constraints.length && constraints[ip+1] > 0 && ((constraints[ip+1]-1) != destination.getIndex())) {
logger.fine ("Destination state does not match positive constraint. Assigning -infinite weight. position="+(ip+1)+", constraint="+(constraints[ip+1]-1)+", source ="+i+", destination="+destination.getIndex());
legalTransition = false;
}
else if (((ip+1) < constraints.length) && constraints[ip+1] < 0 && (-(constraints[ip+1]+1) == destination.getIndex())) {
logger.fine ("Destination state does not match negative constraint. Assigning -infinite weight. position="+(ip+1)+", constraint="+(constraints[ip+1]+1)+", destination="+destination.getIndex());
legalTransition = false;
}
if (logger.isLoggable (Level.FINE))
logger.fine ("Forward Lattice[inputPos="+ip
+"][source="+s.getName()
+"][dest="+destination.getName()+"]");
LatticeNode destinationNode = getLatticeNode (ip+1, destination.getIndex());
destinationNode.output = iter.getOutput();
double transitionWeight = iter.getWeight();
if (legalTransition) {
//if (logger.isLoggable (Level.FINE))
logger.fine ("transitionWeight="+transitionWeight
+" nodes["+ip+"]["+i+"].alpha="+nodes[ip][i].alpha
+" destinationNode.alpha="+destinationNode.alpha);
destinationNode.alpha = Transducer.sumLogProb (destinationNode.alpha,
nodes[ip][i].alpha + transitionWeight);
//System.out.println ("destinationNode.alpha <- "+destinationNode.alpha);
logger.fine ("Set alpha of latticeNode at ip = "+ (ip+1) + " stateIndex = " + destination.getIndex() + ", destinationNode.alpha = " + destinationNode.alpha);
}
else {
// this is an illegal transition according to our
// constraints, so set its prob to 0 . NO, alpha's are
// unnormalized weights...set to -Inf //
// destinationNode.alpha = 0.0;
// destinationNode.alpha = IMPOSSIBLE_WEIGHT;
logger.fine ("Illegal transition from state " + i + " to state " + destination.getIndex() + ". Setting alpha to -Inf");
}
}
}
// Calculate total weight of Lattice. This is the normalizer
totalWeight = Transducer.IMPOSSIBLE_WEIGHT;
for (int i = 0; i < numStates; i++)
if (nodes[latticeLength-1][i] != null) {
// Note: actually we could sum at any ip index,
// the choice of latticeLength-1 is arbitrary
//System.out.println ("Ending alpha, state["+i+"] = "+nodes[latticeLength-1][i].alpha);
//System.out.println ("Ending beta, state["+i+"] = "+getState(i).finalWeight);
if (constraints[latticeLength-1] > 0 && i != constraints[latticeLength-1]-1)
continue;
if (constraints[latticeLength-1] < 0 && -i == constraints[latticeLength-1]+1)
continue;
logger.fine ("Summing final lattice weight. state="+i+", alpha="+nodes[latticeLength-1][i].alpha + ", final weight = "+t.getState(i).getFinalWeight());
totalWeight = Transducer.sumLogProb (totalWeight,
(nodes[latticeLength-1][i].alpha + t.getState(i).getFinalWeight()));
}
// Weight is now an "unnormalized weight" of the entire Lattice
//assert (weight >= 0) : "weight = "+weight;
// If the sequence has -infinite weight, just return.
// Usefully this avoids calling any incrementX methods.
// It also relies on the fact that the gammas[][] and .alpha and .beta values
// are already initialized to values that reflect -infinite weight
// xxx Although perhaps not all (alphas,betas) exactly correctly reflecting?
if (totalWeight == Transducer.IMPOSSIBLE_WEIGHT)
return;
// Backward pass
for (int i = 0; i < numStates; i++)
if (nodes[latticeLength-1][i] != null) {
State s = t.getState(i);
nodes[latticeLength-1][i].beta = s.getFinalWeight();
gammas[latticeLength-1][i] =
nodes[latticeLength-1][i].alpha + nodes[latticeLength-1][i].beta - totalWeight;
if (incrementor != null) {
double p = Math.exp(gammas[latticeLength-1][i]);
assert (p >= 0.0 && p <= 1.0 && !Double.isNaN(p)) : "p="+p+" gamma="+gammas[latticeLength-1][i];
incrementor.incrementFinalState(s, p);
}
}
for (int ip = latticeLength-2; ip >= 0; ip--) {
for (int i = 0; i < numStates; i++) {
if (nodes[ip][i] == null || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT)
// Note that skipping here based on alpha means that beta values won't
// be correct, but since alpha is -infinite anyway, it shouldn't matter.
continue;
State s = t.getState(i);
TransitionIterator iter = s.transitionIterator (input, ip, output, ip);
while (iter.hasNext()) {
State destination = iter.nextState();
if (logger.isLoggable (Level.FINE))
logger.fine ("Backward Lattice[inputPos="+ip
+"][source="+s.getName()
+"][dest="+destination.getName()+"]");
int j = destination.getIndex();
LatticeNode destinationNode = nodes[ip+1][j];
if (destinationNode != null) {
double transitionWeight = iter.getWeight();
assert (!Double.isNaN(transitionWeight));
// assert (transitionWeight >= 0); Not necessarily
double oldBeta = nodes[ip][i].beta;
assert (!Double.isNaN(nodes[ip][i].beta));
nodes[ip][i].beta = Transducer.sumLogProb (nodes[ip][i].beta,
destinationNode.beta + transitionWeight);
assert (!Double.isNaN(nodes[ip][i].beta))
: "dest.beta="+destinationNode.beta+" trans="+transitionWeight+" sum="+(destinationNode.beta+transitionWeight)
+ " oldBeta="+oldBeta;
// xis[ip][i][j] = nodes[ip][i].alpha + transitionWeight + nodes[ip+1][j].beta - weight;
assert (!Double.isNaN(nodes[ip][i].alpha));
assert (!Double.isNaN(transitionWeight));
assert (!Double.isNaN(nodes[ip+1][j].beta));
assert (!Double.isNaN(totalWeight));
if (incrementor != null || outputAlphabet != null) {
double xi = nodes[ip][i].alpha + transitionWeight + nodes[ip+1][j].beta - totalWeight;
double p = Math.exp(xi);
assert (p > Transducer.IMPOSSIBLE_WEIGHT && !Double.isNaN(p)) : "xis["+ip+"]["+i+"]["+j+"]="+xi;
if (incrementor != null)
incrementor.incrementTransition(iter, p);
if (outputAlphabet != null) {
int outputIndex = outputAlphabet.lookupIndex (iter.getOutput(), false);
assert (outputIndex >= 0);
// xxx This assumes that "ip" == "op"!
outputCounts[ip][outputIndex] += p;
//System.out.println ("CRF Lattice outputCounts["+ip+"]["+outputIndex+"]+="+p);
}