package cc.mallet.fst;
import java.util.logging.Level;
import java.util.logging.Logger;
import cc.mallet.fst.Transducer.State;
import cc.mallet.fst.Transducer.TransitionIterator;
import cc.mallet.types.DenseVector;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
/** Default, full dynamic programming implementation of the Forward-Backward "Sum-(Product)-Lattice" algorithm */
public class SumLatticeDefault implements SumLattice
private static Logger logger = MalletLogger.getLogger(SumLatticeDefault.class.getName());
// Static variables acting as default values for the correspondingly-named instance variables.
// Can be overridden sort of like named parameters, like this:
// SumLattice lattice = new SumLatticeDefault(transducer, input) {{ saveXis=true; }}
protected static boolean saveXis = false;
// "ip" == "input position", "op" == "output position", "i" == "state index"
Transducer t;
double totalWeight;
Sequence input, output;
LatticeNode[][] nodes; // indexed by ip,i
int latticeLength;
double[][] gammas; // indexed by ip,i
double[][][] xis; // indexed by ip,i,j; saved only if saveXis is true;
LabelVector labelings[]; // indexed by op, created only if "outputAlphabet" is non-null in constructor
/**
 * Zero-argument constructor with an empty body: no lattice is computed.
 * It is protected so that instances cannot easily be created this way.
 */
protected SumLatticeDefault ()
{
}
protected LatticeNode getLatticeNode (int ip, int stateIndex)
if (nodes[ip][stateIndex] == null)
nodes[ip][stateIndex] = new LatticeNode (ip, t.getState (stateIndex));
return nodes[ip][stateIndex];
/**
 * Builds a lattice over {@code input} with no output sequence, no incrementor and
 * no output alphabet; whether xis are kept follows the static {@code saveXis} default.
 */
public SumLatticeDefault (Transducer trans, Sequence input)
{
    this (trans, input, null, (Transducer.Incrementor) null, SumLatticeDefault.saveXis, null);
}
/**
 * Builds a lattice over {@code input} with no output sequence, no incrementor and
 * no output alphabet, keeping xis only when the {@code saveXis} argument is true.
 */
public SumLatticeDefault (Transducer trans, Sequence input, boolean saveXis)
{
    // Here saveXis is the parameter, not the static default.
    this (trans, input, null, (Transducer.Incrementor) null, saveXis, null);
}
/**
 * Builds a lattice over {@code input} with no output sequence and no output
 * alphabet, reporting to {@code incrementor}; xis follow the static default.
 */
public SumLatticeDefault (Transducer trans, Sequence input, Transducer.Incrementor incrementor)
{
    this (trans, input, null, incrementor, SumLatticeDefault.saveXis, null);
}
/**
 * Builds a lattice over {@code input} and {@code output} with no incrementor and
 * no output alphabet; xis follow the static default.
 */
public SumLatticeDefault (Transducer trans, Sequence input, Sequence output)
{
    this (trans, input, output, (Transducer.Incrementor) null, SumLatticeDefault.saveXis, null);
}
/**
 * Builds a lattice reporting to {@code incrementor}, with no output alphabet; xis
 * follow the static default. {@code output} may be null, meaning that the lattice
 * is not constrained to match the output.
 */
public SumLatticeDefault (Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor)
{
    this (trans, input, output, incrementor, SumLatticeDefault.saveXis, null);
}
/**
 * Builds a lattice reporting to {@code incrementor} and passing
 * {@code outputAlphabet} through to the full constructor; xis follow the static default.
 */
public SumLatticeDefault (Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor, LabelAlphabet outputAlphabet)
{
    this (trans, input, output, incrementor, SumLatticeDefault.saveXis, outputAlphabet);
}
/**
 * Builds a lattice reporting to {@code incrementor}, with no output alphabet,
 * keeping xis only when the {@code saveXis} argument is true. {@code output} may be
 * null, meaning that the lattice is not constrained to match the output.
 */
public SumLatticeDefault (Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor, boolean saveXis)
{
    // Here saveXis is the parameter, not the static default.
    this (trans, input, output, incrementor, saveXis, null);
}
/**
 * Builds the lattice for {@code input} under {@code trans} and runs the forward and
 * backward passes, filling in alpha/beta on the lattice nodes, {@code gammas},
 * {@code totalWeight}, and optionally {@code xis} and {@code labelings}.
 *
 * @param trans transducer supplying the states, initial/final weights and transitions
 * @param input input sequence; the lattice has {@code input.size()+1} positions
 * @param output optional output sequence handed to the transition iterators; may be
 *        null, meaning the lattice is not constrained to match an output. When
 *        non-null it is asserted to have the same size as {@code input}.
 * @param incrementor if non-null, receives incrementFinalState / incrementTransition /
 *        incrementInitialState calls with the exponentiated gamma or xi value
 * @param saveXis if true, a dense {@code xis[ip][i][j]} array is allocated and filled
 * @param outputAlphabet if non-null, this will create a LabelVector for each position
 *        in the output sequence indicating the probability distribution over
 *        possible outputs at that time index
 */
public SumLatticeDefault (Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor, boolean saveXis, LabelAlphabet outputAlphabet)
assert (output == null || input.size() == output.size());
// Debug dump of the input and output sequences. The leading constant "false"
// disables it, so this never runs.
if (false && logger.isLoggable (Level.FINE)) {
logger.fine ("Starting Lattice");
logger.fine ("Input: ");
for (int ip = 0; ip < input.size(); ip++)
logger.fine (" " + input.get(ip));
logger.fine ("\nOutput: ");
// NOTE(review): no "else" is visible between the null branch and the loop below in
// this view; as shown, the loop would dereference a null output -- confirm against
// the full file (moot while the block is disabled).
if (output == null)
logger.fine ("null");
for (int op = 0; op < output.size(); op++)
logger.fine (" " + output.get(op));
logger.fine ("\n");
// Initialize some structures
this.t = trans;
this.input = input;
this.output = output;
// xxx Not very efficient when the lattice is actually sparse,
// especially when the number of states is large and the
// sequence is long.
// One lattice position before each input element plus one after the last.
latticeLength = input.size()+1;
int numStates = t.numStates();
nodes = new LatticeNode[latticeLength][numStates];
// xxx Yipes, this could get big; something sparse might be better?
gammas = new double[latticeLength][numStates];
// The xis array is dense in (position, state, state) and is only allocated on request.
if (saveXis) xis = new double[latticeLength][numStates][numStates];
// Per-position output counts, only needed when labelings are requested.
double outputCounts[][] = null;
if (outputAlphabet != null)
outputCounts = new double[latticeLength][outputAlphabet.size()];
// Start every gamma (and xi, if saved) at IMPOSSIBLE_WEIGHT.
for (int i = 0; i < numStates; i++) {
for (int ip = 0; ip < latticeLength; ip++)
gammas[ip][i] = Transducer.IMPOSSIBLE_WEIGHT;
if (saveXis)
for (int j = 0; j < numStates; j++)
for (int ip = 0; ip < latticeLength; ip++)
xis[ip][i][j] = Transducer.IMPOSSIBLE_WEIGHT;
// Forward pass
logger.fine ("Starting Foward pass");
// Seed alpha at position 0 for every state whose initial weight is not impossible.
boolean atLeastOneInitialState = false;
for (int i = 0; i < numStates; i++) {
double initialWeight = t.getState(i).getInitialWeight();
//System.out.println ("Forward pass initialCost = "+initialCost);
if (initialWeight > Transducer.IMPOSSIBLE_WEIGHT) {
getLatticeNode(0, i).alpha = initialWeight;
//System.out.println ("nodes[0][i].alpha="+nodes[0][i].alpha);
atLeastOneInitialState = true;
if (atLeastOneInitialState == false)
logger.warning ("There are no starting states!");
// Propagate alpha forward one input position at a time.
for (int ip = 0; ip < latticeLength-1; ip++)
for (int i = 0; i < numStates; i++) {
// NOTE(review): the statement guarded by this condition is not visible in this view;
// presumably it skips states with no node or impossible alpha -- confirm.
if (nodes[ip][i] == null || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT)
// xxx if we end up doing this a lot,
// we could save a list of the non-null ones
State s = t.getState(i);
// The same index ip is used for both the input and the output position.
TransitionIterator iter = s.transitionIterator (input, ip, output, ip);
if (logger.isLoggable (Level.FINE))
logger.fine (" Starting Foward transition iteration from state "
+ s.getName() + " on input " + input.get(ip).toString()
+ " and output "
+ (output==null ? "(null)" : output.get(ip).toString()));
while (iter.hasNext()) {
State destination = iter.nextState();
if (logger.isLoggable (Level.FINE))
logger.fine ("Forward Lattice[inputPos="+ip+"][source="+s.getName()+"][dest="+destination.getName()+"]");
// Creates the destination node on first visit.
LatticeNode destinationNode = getLatticeNode (ip+1, destination.getIndex());
// Overwritten on every transition into this node, so the last one visited wins.
destinationNode.output = iter.getOutput();
double transitionWeight = iter.getWeight();
if (logger.isLoggable (Level.FINE))
logger.fine ("BEFORE update: destinationNode.alpha="+destinationNode.alpha);
// Accumulate this path's weight into the destination's alpha.
destinationNode.alpha = Transducer.sumLogProb (destinationNode.alpha, nodes[ip][i].alpha + transitionWeight);
if (logger.isLoggable (Level.FINE))
logger.fine ("transitionWeight="+transitionWeight+" nodes["+ip+"]["+i+"].alpha="+nodes[ip][i].alpha
+" destinationNode.alpha="+destinationNode.alpha);
//System.out.println ("destinationNode.alpha <- "+destinationNode.alpha);
// Debug dump of the alphas, one row per lattice position.
// NOTE(review): sb is built here but no call that emits it is visible in this view.
if (logger.isLoggable (Level.FINE)) {
logger.fine("Forward Lattice:");
for (int ip = 0; ip < latticeLength; ip++) {
StringBuffer sb = new StringBuffer();
for (int i = 0; i < numStates; i++)
sb.append (" "+(nodes[ip][i] == null ? "<null>" : nodes[ip][i].alpha));
// Calculate total weight of Lattice. This is the normalizer
totalWeight = Transducer.IMPOSSIBLE_WEIGHT;
for (int i = 0; i < numStates; i++)
if (nodes[latticeLength-1][i] != null) {
//System.out.println ("Ending alpha, state["+i+"] = "+nodes[latticeLength-1][i].alpha);
//System.out.println ("Ending beta, state["+i+"] = "+t.getState(i).getFinalWeight());
totalWeight = Transducer.sumLogProb (totalWeight, (nodes[latticeLength-1][i].alpha + t.getState(i).getFinalWeight()));
logger.fine ("totalWeight="+totalWeight);
// totalWeight is now an "unnormalized weight" of the entire Lattice
// If the sequence has -infinite weight, just return.
// Usefully this avoids calling any incrementX methods.
// It also relies on the fact that the gammas[][] and .alpha (but not .beta) values
// are already initialized to values that reflect -infinite weight
// TODO Is it important to fill in the betas before we return?
// NOTE(review): the statement guarded by this condition is not visible in this view;
// per the comment above it is presumably a return -- confirm.
if (totalWeight == Transducer.IMPOSSIBLE_WEIGHT)
// Backward pass
// Seed beta at the last position with each state's final weight and compute its gamma.
for (int i = 0; i < numStates; i++)
if (nodes[latticeLength-1][i] != null) {
State s = t.getState(i);
nodes[latticeLength-1][i].beta = s.getFinalWeight();
gammas[latticeLength-1][i] = nodes[latticeLength-1][i].alpha + nodes[latticeLength-1][i].beta - totalWeight;
if (incrementor != null) {
double p = Math.exp(gammas[latticeLength-1][i]);
// gsc: reducing from 1e-10 to 1e-6
// gsc: removing the isNaN check, range check will catch the NaN error as well
// assert (p >= 0.0 && p <= 1.0+1e-10 && !Double.isNaN(p)) : "p="+p+" gamma="+gammas[latticeLength-1][i];
assert (p >= 0.0 && p <= 1.0+1e-6) : "p="+p+", gamma="+gammas[latticeLength-1][i];
incrementor.incrementFinalState (s, p);
// Walk positions from the second-to-last down to 0, accumulating beta and xi.
for (int ip = latticeLength-2; ip >= 0; ip--) {
for (int i = 0; i < numStates; i++) {
// NOTE(review): the statement guarded by this condition is not visible in this view;
// presumably it skips this state -- confirm.
if (nodes[ip][i] == null || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT)
// Note that skipping here based on alpha means that beta values won't
// be correct, but since alpha is infinite anyway, it shouldn't matter.
State s = t.getState(i);
TransitionIterator iter = s.transitionIterator (input, ip, output, ip);
while (iter.hasNext()) {
State destination = iter.nextState();
if (logger.isLoggable (Level.FINE))
logger.fine ("Backward Lattice[inputPos="+ip+"][source="+s.getName()+"][dest="+destination.getName()+"]");
int j = destination.getIndex();
// Read directly from nodes (no lazy creation): destinations never reached in the
// forward pass stay null and are ignored.
LatticeNode destinationNode = nodes[ip+1][j];
if (destinationNode != null) {
double transitionWeight = iter.getWeight();
assert (!Double.isNaN(transitionWeight));
// Kept only for the assertion message below.
double oldBeta = nodes[ip][i].beta;
assert (!Double.isNaN(nodes[ip][i].beta));
// Accumulate this path's weight into the source's beta.
nodes[ip][i].beta = Transducer.sumLogProb (nodes[ip][i].beta, destinationNode.beta + transitionWeight);
assert (!Double.isNaN(nodes[ip][i].beta))
: "dest.beta="+destinationNode.beta+" trans="+transitionWeight+" sum="+(destinationNode.beta+transitionWeight) + " oldBeta="+oldBeta;
// Weight of taking transition i->j at position ip, normalized by totalWeight.
double xi = nodes[ip][i].alpha + transitionWeight + nodes[ip+1][j].beta - totalWeight;
if (saveXis) xis[ip][i][j] = xi;
assert (!Double.isNaN(nodes[ip][i].alpha));
assert (!Double.isNaN(transitionWeight));
assert (!Double.isNaN(nodes[ip+1][j].beta));
assert (!Double.isNaN(totalWeight));
if (incrementor != null || outputAlphabet != null) {
double p = Math.exp(xi);
// gsc: reducing from 1e-10 to 1e-6
// gsc: removing the isNaN check, range check will catch the NaN error as well
// assert (p >= 0.0 && p <= 1.0+1e-10 && !Double.isNaN(p)) : "xis["+ip+"]["+i+"]["+j+"]="+xi;
assert (p >= 0.0 && p <= 1.0+1e-6) : "p="+p+", xis["+ip+"]["+i+"]["+j+"]="+xi;
if (incrementor != null)
incrementor.incrementTransition(iter, p);
if (outputAlphabet != null) {
// Second argument false: look the output up without adding it to the alphabet
// (presumably; the assert below requires that it already be present).
int outputIndex = outputAlphabet.lookupIndex (iter.getOutput(), false);
assert (outputIndex >= 0);
// xxx This assumes that "ip" == "op"!
outputCounts[ip][outputIndex] += p;
//System.out.println ("CRF Lattice outputCounts["+ip+"]["+outputIndex+"]+="+p);
// Gamma for state i at position ip: alpha + beta, normalized by totalWeight.
gammas[ip][i] = nodes[ip][i].alpha + nodes[ip][i].beta - totalWeight;
// Report the exponentiated position-0 gamma of every state as its initial-state count.
if (incrementor != null)
for (int i = 0; i < numStates; i++) {
double p = Math.exp(gammas[0][i]);
// gsc: reducing from 1e-10 to 1e-6
// gsc: removing the isNaN check, range check will catch the NaN error as well
// assert (p >= 0.0 && p <= 1.0+1e-10 && !Double.isNaN(p)) : "p="+p;
assert (p >= 0.0 && p <= 1.0+1e-6) : "p="+p;
incrementor.incrementInitialState(t.getState(i), p);
// Wrap the accumulated output counts as one LabelVector per position. The loop
// starts at latticeLength-2, so labelings[latticeLength-1] is left null.
if (outputAlphabet != null) {
labelings = new LabelVector[latticeLength];
for (int ip = latticeLength-2; ip >= 0; ip--) {
// The counts at each position are expected to sum to 1.
assert (Math.abs(1.0-MatrixOps.sum (outputCounts[ip])) < 0.000001);;
labelings[ip] = new LabelVector (outputAlphabet, outputCounts[ip]);
// Debug dump of the gammas, one row per lattice position.
// NOTE(review): sb is built here but no call that emits it is visible in this view.
if (logger.isLoggable (Level.FINE)) {
for (int ip = 0; ip < latticeLength; ip++) {
StringBuffer sb = new StringBuffer();
for (int i = 0; i < numStates; i++)
sb.append (" "+gammas[ip][i]);
public double[][][] getXis(){
return xis;
public double[][] getGammas(){
return gammas;
/** Returns the total weight of the lattice; asserts that it is not NaN. */
public double getTotalWeight ()
{
    final double weight = totalWeight;
    assert (!Double.isNaN(weight));
    return weight;
}
/** Returns the stored gamma value for state {@code s} at {@code inputPosition}. */
public double getGammaWeight (int inputPosition, State s)
{
    final double[] gammasAtPosition = gammas[inputPosition];
    return gammasAtPosition[s.getIndex()];
}
/** Returns the stored gamma value for the state with index {@code stateIndex} at {@code inputPosition}. */
public double getGammaWeight (int inputPosition, int stateIndex)
{
    final double[] gammasAtPosition = gammas[inputPosition];
    return gammasAtPosition[stateIndex];
}
/** Returns {@code Math.exp} of the stored gamma value for state {@code s} at {@code inputPosition}. */
public double getGammaProbability (int inputPosition, State s)
{
    final double gamma = gammas[inputPosition][s.getIndex()];
    return Math.exp (gamma);
}
/** Returns {@code Math.exp} of the stored gamma value for state index {@code stateIndex} at {@code inputPosition}. */
public double getGammaProbability (int inputPosition, int stateIndex)
{
    final double gamma = gammas[inputPosition][stateIndex];
    return Math.exp (gamma);
}
public double getXiProbability (int ip, State s1, State s2) {
if (xis == null)
throw new IllegalStateException ("xis were not saved.");
int i = s1.getIndex ();
int j = s2.getIndex ();
return Math.exp (xis[ip][i][j]);
public double getXiWeight(int ip, State s1, State s2)
if (xis == null)
throw new IllegalStateException ("xis were not saved.");
int i = s1.getIndex ();
int j = s2.getIndex ();
return xis[ip][i][j];
/** Returns the number of lattice positions ({@code latticeLength}). */
public int length ()
{
    return this.latticeLength;
}
public Sequence getInput() {
return input;
public double getAlpha (int ip, State s) {
LatticeNode node = getLatticeNode (ip, s.getIndex ());
return node.alpha;
public double getBeta (int ip, State s) {
LatticeNode node = getLatticeNode (ip, s.getIndex ());
return node.beta;
public LabelVector getLabelingAtPosition (int outputPosition) {
if (labelings != null)
return labelings[outputPosition];
return null;
public Transducer getTransducer ()
return t;
// A container for some information about a particular input position and state
protected class LatticeNode
int inputPosition;
// outputPosition not really needed until we deal with asymmetric epsilon.
State state;
Object output;
double alpha = Transducer.IMPOSSIBLE_WEIGHT;
double beta = Transducer.IMPOSSIBLE_WEIGHT;
LatticeNode (int inputPosition, State state) {
this.inputPosition = inputPosition;
this.state = state;
assert (this.alpha == Transducer.IMPOSSIBLE_WEIGHT); // xxx Remove this check
public static class Factory extends SumLatticeFactory implements Serializable
public SumLattice newSumLattice (Transducer trans, Sequence input, Sequence output,
Transducer.Incrementor incrementor, boolean saveXis, LabelAlphabet outputAlphabet)
return new SumLatticeDefault (trans, input, output, incrementor, saveXis, outputAlphabet);
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 1;
private void writeObject(ObjectOutputStream out) throws IOException {
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
int version = in.readInt();