package cc.mallet.fst;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.logging.Level;
import java.util.logging.Logger;
import cc.mallet.fst.Transducer.State;
import cc.mallet.fst.Transducer.TransitionIterator;
import cc.mallet.types.DenseVector;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
/** Default, full dynamic programming implementation of the Forward-Backward "Sum-(Product)-Lattice" algorithm */
public class SumLatticeDefault implements SumLattice
{
private static Logger logger = MalletLogger.getLogger(SumLatticeDefault.class.getName());
//{logger.setLevel(Level.FINE);}
// Static variables acting as default values for the correspondingly-named instance variables.
// Can be overridden sort of like named parameters, like this:
// SumLattice lattice = new SumLatticeDefault(transducer, input) {{ saveXis=true; }}
protected static boolean saveXis = false;
// "ip" == "input position", "op" == "output position", "i" == "state index"
Transducer t;
double totalWeight;
Sequence input, output;
LatticeNode[][] nodes; // indexed by ip,i
int latticeLength;
double[][] gammas; // indexed by ip,i
double[][][] xis; // indexed by ip,i,j; saved only if saveXis is true;
LabelVector labelings[]; // indexed by op, created only if "outputAlphabet" is non-null in constructor
// Ensure that instances cannot easily be created by a zero arg constructor.
protected SumLatticeDefault() { }
protected LatticeNode getLatticeNode (int ip, int stateIndex)
{
if (nodes[ip][stateIndex] == null)
nodes[ip][stateIndex] = new LatticeNode (ip, t.getState (stateIndex));
return nodes[ip][stateIndex];
}
public SumLatticeDefault (Transducer trans, Sequence input)
{
this (trans, input, null, (Transducer.Incrementor)null, saveXis, null);
}
public SumLatticeDefault (Transducer trans, Sequence input, boolean saveXis)
{
this (trans, input, null, (Transducer.Incrementor)null, saveXis, null);
}
public SumLatticeDefault (Transducer trans, Sequence input, Transducer.Incrementor incrementor)
{
this (trans, input, null, incrementor, saveXis, null);
}
public SumLatticeDefault (Transducer trans, Sequence input, Sequence output)
{
this (trans, input, output, (Transducer.Incrementor)null, saveXis, null);
}
// You may pass null for output, meaning that the lattice
// is not constrained to match the output
public SumLatticeDefault (Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor)
{
this (trans, input, output, incrementor, saveXis, null);
}
public SumLatticeDefault (Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor, LabelAlphabet outputAlphabet)
{
this (trans, input, output, incrementor, saveXis, outputAlphabet);
}
// You may pass null for output, meaning that the lattice
// is not constrained to match the output
public SumLatticeDefault (Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor, boolean saveXis)
{
this (trans, input, output, incrementor, saveXis, null);
}
// If outputAlphabet is non-null, this will create a LabelVector
// for each position in the output sequence indicating the
// probability distribution over possible outputs at that time
// index
public SumLatticeDefault (Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor, boolean saveXis, LabelAlphabet outputAlphabet)
{
assert (output == null || input.size() == output.size());
if (false && logger.isLoggable (Level.FINE)) {
logger.fine ("Starting Lattice");
logger.fine ("Input: ");
for (int ip = 0; ip < input.size(); ip++)
logger.fine (" " + input.get(ip));
logger.fine ("\nOutput: ");
if (output == null)
logger.fine ("null");
else
for (int op = 0; op < output.size(); op++)
logger.fine (" " + output.get(op));
logger.fine ("\n");
}
// Initialize some structures
this.t = trans;
this.input = input;
this.output = output;
// xxx Not very efficient when the lattice is actually sparse,
// especially when the number of states is large and the
// sequence is long.
latticeLength = input.size()+1;
int numStates = t.numStates();
nodes = new LatticeNode[latticeLength][numStates];
// xxx Yipes, this could get big; something sparse might be better?
gammas = new double[latticeLength][numStates];
if (saveXis) xis = new double[latticeLength][numStates][numStates];
double outputCounts[][] = null;
if (outputAlphabet != null)
outputCounts = new double[latticeLength][outputAlphabet.size()];
for (int i = 0; i < numStates; i++) {
for (int ip = 0; ip < latticeLength; ip++)
gammas[ip][i] = Transducer.IMPOSSIBLE_WEIGHT;
if (saveXis)
for (int j = 0; j < numStates; j++)
for (int ip = 0; ip < latticeLength; ip++)
xis[ip][i][j] = Transducer.IMPOSSIBLE_WEIGHT;
}
// Forward pass
logger.fine ("Starting Foward pass");
boolean atLeastOneInitialState = false;
for (int i = 0; i < numStates; i++) {
double initialWeight = t.getState(i).getInitialWeight();
//System.out.println ("Forward pass initialCost = "+initialCost);
if (initialWeight > Transducer.IMPOSSIBLE_WEIGHT) {
getLatticeNode(0, i).alpha = initialWeight;
//System.out.println ("nodes[0][i].alpha="+nodes[0][i].alpha);
atLeastOneInitialState = true;
}
}
if (atLeastOneInitialState == false)
logger.warning ("There are no starting states!");
for (int ip = 0; ip < latticeLength-1; ip++)
for (int i = 0; i < numStates; i++) {
if (nodes[ip][i] == null || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT)
// xxx if we end up doing this a lot,
// we could save a list of the non-null ones
continue;
State s = t.getState(i);
TransitionIterator iter = s.transitionIterator (input, ip, output, ip);
if (logger.isLoggable (Level.FINE))
logger.fine (" Starting Foward transition iteration from state "
+ s.getName() + " on input " + input.get(ip).toString()
+ " and output "
+ (output==null ? "(null)" : output.get(ip).toString()));
while (iter.hasNext()) {
State destination = iter.nextState();
if (logger.isLoggable (Level.FINE))
logger.fine ("Forward Lattice[inputPos="+ip+"][source="+s.getName()+"][dest="+destination.getName()+"]");
LatticeNode destinationNode = getLatticeNode (ip+1, destination.getIndex());
destinationNode.output = iter.getOutput();
double transitionWeight = iter.getWeight();
if (logger.isLoggable (Level.FINE))
logger.fine ("BEFORE update: destinationNode.alpha="+destinationNode.alpha);
destinationNode.alpha = Transducer.sumLogProb (destinationNode.alpha, nodes[ip][i].alpha + transitionWeight);
if (logger.isLoggable (Level.FINE))
logger.fine ("transitionWeight="+transitionWeight+" nodes["+ip+"]["+i+"].alpha="+nodes[ip][i].alpha
+" destinationNode.alpha="+destinationNode.alpha);
//System.out.println ("destinationNode.alpha <- "+destinationNode.alpha);
}
}
if (logger.isLoggable (Level.FINE)) {
logger.fine("Forward Lattice:");
for (int ip = 0; ip < latticeLength; ip++) {
StringBuffer sb = new StringBuffer();
for (int i = 0; i < numStates; i++)
sb.append (" "+(nodes[ip][i] == null ? "<null>" : nodes[ip][i].alpha));
logger.fine(sb.toString());
}
}
// Calculate total weight of Lattice. This is the normalizer
totalWeight = Transducer.IMPOSSIBLE_WEIGHT;
for (int i = 0; i < numStates; i++)
if (nodes[latticeLength-1][i] != null) {
//System.out.println ("Ending alpha, state["+i+"] = "+nodes[latticeLength-1][i].alpha);
//System.out.println ("Ending beta, state["+i+"] = "+t.getState(i).getFinalWeight());
totalWeight = Transducer.sumLogProb (totalWeight, (nodes[latticeLength-1][i].alpha + t.getState(i).getFinalWeight()));
}
logger.fine ("totalWeight="+totalWeight);
// totalWeight is now an "unnormalized weight" of the entire Lattice
// If the sequence has -infinite weight, just return.
// Usefully this avoids calling any incrementX methods.
// It also relies on the fact that the gammas[][] and .alpha (but not .beta) values
// are already initialized to values that reflect -infinite weight
// TODO Is it important to fill in the betas before we return?
if (totalWeight == Transducer.IMPOSSIBLE_WEIGHT)
return;
// Backward pass
for (int i = 0; i < numStates; i++)
if (nodes[latticeLength-1][i] != null) {
State s = t.getState(i);
nodes[latticeLength-1][i].beta = s.getFinalWeight();
gammas[latticeLength-1][i] = nodes[latticeLength-1][i].alpha + nodes[latticeLength-1][i].beta - totalWeight;
if (incrementor != null) {
double p = Math.exp(gammas[latticeLength-1][i]);
// gsc: reducing from 1e-10 to 1e-6
// gsc: removing the isNaN check, range check will catch the NaN error as well
// assert (p >= 0.0 && p <= 1.0+1e-10 && !Double.isNaN(p)) : "p="+p+" gamma="+gammas[latticeLength-1][i];
assert (p >= 0.0 && p <= 1.0+1e-6) : "p="+p+", gamma="+gammas[latticeLength-1][i];
incrementor.incrementFinalState (s, p);
}
}
for (int ip = latticeLength-2; ip >= 0; ip--) {
for (int i = 0; i < numStates; i++) {
if (nodes[ip][i] == null || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT)
// Note that skipping here based on alpha means that beta values won't
// be correct, but since alpha is infinite anyway, it shouldn't matter.
continue;
State s = t.getState(i);
TransitionIterator iter = s.transitionIterator (input, ip, output, ip);
while (iter.hasNext()) {
State destination = iter.nextState();
if (logger.isLoggable (Level.FINE))
logger.fine ("Backward Lattice[inputPos="+ip+"][source="+s.getName()+"][dest="+destination.getName()+"]");
int j = destination.getIndex();
LatticeNode destinationNode = nodes[ip+1][j];
if (destinationNode != null) {
double transitionWeight = iter.getWeight();
assert (!Double.isNaN(transitionWeight));
double oldBeta = nodes[ip][i].beta;
assert (!Double.isNaN(nodes[ip][i].beta));
nodes[ip][i].beta = Transducer.sumLogProb (nodes[ip][i].beta, destinationNode.beta + transitionWeight);
assert (!Double.isNaN(nodes[ip][i].beta))
: "dest.beta="+destinationNode.beta+" trans="+transitionWeight+" sum="+(destinationNode.beta+transitionWeight) + " oldBeta="+oldBeta;
double xi = nodes[ip][i].alpha + transitionWeight + nodes[ip+1][j].beta - totalWeight;
if (saveXis) xis[ip][i][j] = xi;
assert (!Double.isNaN(nodes[ip][i].alpha));
assert (!Double.isNaN(transitionWeight));
assert (!Double.isNaN(nodes[ip+1][j].beta));
assert (!Double.isNaN(totalWeight));
if (incrementor != null || outputAlphabet != null) {
double p = Math.exp(xi);
// gsc: reducing from 1e-10 to 1e-6
// gsc: removing the isNaN check, range check will catch the NaN error as well
// assert (p >= 0.0 && p <= 1.0+1e-10 && !Double.isNaN(p)) : "xis["+ip+"]["+i+"]["+j+"]="+xi;
assert (p >= 0.0 && p <= 1.0+1e-6) : "p="+p+", xis["+ip+"]["+i+"]["+j+"]="+xi;
if (incrementor != null)
incrementor.incrementTransition(iter, p);
if (outputAlphabet != null) {
int outputIndex = outputAlphabet.lookupIndex (iter.getOutput(), false);
assert (outputIndex >= 0);
// xxx This assumes that "ip" == "op"!
outputCounts[ip][outputIndex] += p;
//System.out.println ("CRF Lattice outputCounts["+ip+"]["+outputIndex+"]+="+p);
}
}
}
}
gammas[ip][i] = nodes[ip][i].alpha + nodes[ip][i].beta - totalWeight;
}
}
if (incrementor != null)
for (int i = 0; i < numStates; i++) {
double p = Math.exp(gammas[0][i]);
// gsc: reducing from 1e-10 to 1e-6
// gsc: removing the isNaN check, range check will catch the NaN error as well
// assert (p >= 0.0 && p <= 1.0+1e-10 && !Double.isNaN(p)) : "p="+p;
assert (p >= 0.0 && p <= 1.0+1e-6) : "p="+p;
incrementor.incrementInitialState(t.getState(i), p);
}
if (outputAlphabet != null) {
labelings = new LabelVector[latticeLength];
for (int ip = latticeLength-2; ip >= 0; ip--) {
assert (Math.abs(1.0-MatrixOps.sum (outputCounts[ip])) < 0.000001);;
labelings[ip] = new LabelVector (outputAlphabet, outputCounts[ip]);
}
}
if (logger.isLoggable (Level.FINE)) {
logger.fine("Lattice:");
for (int ip = 0; ip < latticeLength; ip++) {
StringBuffer sb = new StringBuffer();
for (int i = 0; i < numStates; i++)
sb.append (" "+gammas[ip][i]);
logger.fine(sb.toString());
}
}
}
public double[][][] getXis(){
return xis;
}
public double[][] getGammas(){
return gammas;
}
public double getTotalWeight () {
assert (!Double.isNaN(totalWeight));
return totalWeight; }
public double getGammaWeight(int inputPosition, State s) {
return gammas[inputPosition][s.getIndex()]; }
public double getGammaWeight(int inputPosition, int stateIndex) {
return gammas[inputPosition][stateIndex]; }
public double getGammaProbability (int inputPosition, State s) {
return Math.exp (gammas[inputPosition][s.getIndex()]); }
public double getGammaProbability (int inputPosition, int stateIndex) {
return Math.exp (gammas[inputPosition][stateIndex]); }
public double getXiProbability (int ip, State s1, State s2) {
if (xis == null)
throw new IllegalStateException ("xis were not saved.");
int i = s1.getIndex ();
int j = s2.getIndex ();
return Math.exp (xis[ip][i][j]);
}
public double getXiWeight(int ip, State s1, State s2)
{
if (xis == null)
throw new IllegalStateException ("xis were not saved.");
int i = s1.getIndex ();
int j = s2.getIndex ();
return xis[ip][i][j];
}
public int length () { return latticeLength; }
public double getAlpha (int ip, State s) {
LatticeNode node = getLatticeNode (ip, s.getIndex ());
return node.alpha;
}
public double getBeta (int ip, State s) {
LatticeNode node = getLatticeNode (ip, s.getIndex ());
return node.beta;
}
public LabelVector getLabelingAtPosition (int outputPosition) {
if (labelings != null)
return labelings[outputPosition];
return null;
}
public Transducer getTransducer ()
{
return t;
}
// A container for some information about a particular input position and state
protected class LatticeNode
{
int inputPosition;
// outputPosition not really needed until we deal with asymmetric epsilon.
State state;
Object output;
double alpha = Transducer.IMPOSSIBLE_WEIGHT;
double beta = Transducer.IMPOSSIBLE_WEIGHT;
LatticeNode (int inputPosition, State state) {
this.inputPosition = inputPosition;
this.state = state;
assert (this.alpha == Transducer.IMPOSSIBLE_WEIGHT); // xxx Remove this check
}
}
public static class Factory extends SumLatticeFactory implements Serializable
{
public SumLattice newSumLattice (Transducer trans, Sequence input, Sequence output,
Transducer.Incrementor incrementor, boolean saveXis, LabelAlphabet outputAlphabet)
{
return new SumLatticeDefault (trans, input, output, incrementor, saveXis, outputAlphabet);
}
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 1;
private void writeObject(ObjectOutputStream out) throws IOException {
out.writeInt(CURRENT_SERIAL_VERSION);
}
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
int version = in.readInt();
}
}
}