/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/**
@author Fernando Pereira <a href="mailto:pereira@cis.upenn.edu">pereira@cis.upenn.edu</a>
@author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/
package cc.mallet.fst;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import cc.mallet.types.ArraySequence;
import cc.mallet.types.Sequence;
import cc.mallet.types.SequencePairAlignment;
import cc.mallet.fst.Transducer.State;
import cc.mallet.fst.Transducer.TransitionIterator;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.search.AStar;
import cc.mallet.util.search.AStarState;
import cc.mallet.util.search.SearchNode;
import cc.mallet.util.search.SearchState;
/** Default, full dynamic programming version of the Viterbi "Max-(Product)-Lattice" algorithm.
*
* @author Fernando Pereira
* @author Andrew McCallum
*/
public class MaxLatticeDefault implements MaxLattice
{
private static Logger logger = MalletLogger.getLogger(MaxLatticeDefault.class.getName());
//{ logger.setLevel(Level.INFO); }
private Transducer t;
private Sequence<Object> input, providedOutput;
private int latticeLength;
private ViterbiNode[][] lattice;
private WeightCache first, last;
private WeightCache[] caches;
private int numCaches, maxCaches;
public Transducer getTransducer () { return t; }
public Sequence getInput() { return input; }
public Sequence getProvidedOutput() { return providedOutput; }
private class ViterbiNode implements AStarState {
int inputPosition; // Position of input used to enter this node
State state; // Transducer state from which this node entered
Object output; // Transducer output produced on entering this node
double delta = Transducer.IMPOSSIBLE_WEIGHT;
ViterbiNode maxWeightPredecessor = null;
ViterbiNode (int inputPosition, State state) {
this.inputPosition = inputPosition;
this.state = state;
}
// The one method required by AStarState
public double completionCost () { return -delta; }
public boolean isFinal() {
return inputPosition == 0 && state.getInitialWeight() > Transducer.IMPOSSIBLE_WEIGHT;
}
private class PreviousStateIterator extends AStarState.NextStateIterator {
private int prev;
private boolean found;
private double weight;
private double[] weights;
private PreviousStateIterator() {
prev = 0;
if (inputPosition > 0) {
int j = state.getIndex();
weights = new double[t.numStates()];
WeightCache c = getCache(inputPosition-1);
for (int s = 0; s < t.numStates(); s++)
weights[s] = c.weight[s][j];
}
}
private void lookAhead() {
if (weights != null && !found) {
for (; prev < t.numStates(); prev++)
if (weights[prev] > Transducer.IMPOSSIBLE_WEIGHT) {
found = true;
return;
}
}
}
public boolean hasNext() {
lookAhead();
return weights != null && prev < t.numStates();
}
public SearchState nextState() {
lookAhead();
weight = weights[prev++];
found = false;
return getViterbiNode(inputPosition-1, prev-1);
}
// Required by SearchState, super-interface of AStarState
public double cost() {
return -weight;
}
public double weight() {
return weight;
}
}
public NextStateIterator getNextStates() {
return new PreviousStateIterator();
}
}
private class WeightCache {
private WeightCache prev, next;
private double weight[][];
private int position;
private WeightCache(int position) {
weight = new double[t.numStates()][t.numStates()];
init(position);
}
private void init(int position) {
this.position = position;
for (int i = 0; i < t.numStates(); i++)
for (int j = 0; j < t.numStates(); j++)
weight[i][j] = Transducer.IMPOSSIBLE_WEIGHT;
}
}
private WeightCache getCache(int position) {
WeightCache cache = caches[position];
if (cache == null) { // No cache for this position
// System.out.println("cache " + numCaches + "/" + maxCaches);
if (numCaches < maxCaches) { // Create another cache
cache = new WeightCache(position);
if (numCaches++ == 0)
first = last = cache;
}
else { // Steal least used cache
cache = last;
caches[cache.position] = null;
cache.init(position);
}
for (int i = 0; i < t.numStates(); i++) {
if (lattice[position][i] == null || lattice[position][i].delta == Transducer.IMPOSSIBLE_WEIGHT)
continue;
State s = t.getState(i);
TransitionIterator iter =
s.transitionIterator (input, position, providedOutput, position);
while (iter.hasNext()) {
State d = iter.next();
cache.weight[i][d.getIndex()] = iter.getWeight();
}
}
caches[position] = cache;
}
if (cache != first) { // Move to front
if (cache == last)
last = cache.prev;
if (cache.prev != null)
cache.prev.next = cache.next;
cache.next = first;
cache.prev = null;
first.prev = cache;
first = cache;
}
return cache;
}
protected ViterbiNode getViterbiNode (int ip, int stateIndex)
{
if (lattice[ip][stateIndex] == null)
lattice[ip][stateIndex] = new ViterbiNode (ip, t.getState (stateIndex));
return lattice[ip][stateIndex];
}
public MaxLatticeDefault (Transducer t, Sequence inputSequence)
{
this (t, inputSequence, null, 100000);
}
public MaxLatticeDefault (Transducer t, Sequence inputSequence, Sequence outputSequence)
{
this (t, inputSequence, outputSequence, 100000);
}
/** Initiate Viterbi decoding of the inputSequence, contrained to match non-null parts of the outputSequence.
* maxCaches indicates how much state information to memoize in n-best decoding. */
public MaxLatticeDefault (Transducer t, Sequence inputSequence, Sequence outputSequence, int maxCaches)
{
// This method initializes the forward path, but does not yet do the backward pass.
this.t = t;
if (maxCaches < 1)
maxCaches = 1;
this.maxCaches = maxCaches;
assert (inputSequence != null);
if (logger.isLoggable (Level.FINE)) {
logger.fine ("Starting ViterbiLattice");
logger.fine ("Input: ");
for (int ip = 0; ip < inputSequence.size(); ip++)
logger.fine (" " + inputSequence.get(ip));
logger.fine ("\nOutput: ");
if (outputSequence == null)
logger.fine ("null");
else
for (int op = 0; op < outputSequence.size(); op++)
logger.fine (" " + outputSequence.get(op));
logger.fine ("\n");
}
this.input = inputSequence;
this.providedOutput = outputSequence;
latticeLength = input.size()+1;
int numStates = t.numStates();
lattice = new ViterbiNode[latticeLength][numStates];
caches = new WeightCache[latticeLength-1];
// Viterbi Forward
logger.fine ("Starting Viterbi");
boolean anyInitialState = false;
for (int i = 0; i < numStates; i++) {
double initialWeight = t.getState(i).getInitialWeight();
if (initialWeight > Transducer.IMPOSSIBLE_WEIGHT) {
ViterbiNode n = getViterbiNode (0, i);
n.delta = initialWeight;
anyInitialState = true;
}
}
if (!anyInitialState) {
logger.warning ("Viterbi: No initial states!");
}
for (int ip = 0; ip < latticeLength-1; ip++)
for (int i = 0; i < numStates; i++) {
if (lattice[ip][i] == null || lattice[ip][i].delta == Transducer.IMPOSSIBLE_WEIGHT)
continue;
State s = t.getState(i);
TransitionIterator iter = s.transitionIterator (input, ip, providedOutput, ip);
if (logger.isLoggable (Level.FINE))
logger.fine (" Starting Viterbi transition iteration from state "
+ s.getName() + " on input " + input.get(ip));
while (iter.hasNext()) {
State destination = iter.next();
if (logger.isLoggable (Level.FINE))
logger.fine ("Viterbi[inputPos="+ip
+"][source="+s.getName()
+"][dest="+destination.getName()+"]");
ViterbiNode destinationNode = getViterbiNode (ip+1, destination.getIndex());
destinationNode.output = iter.getOutput();
double weight = lattice[ip][i].delta + iter.getWeight();
if (ip == latticeLength-2) {
weight += destination.getFinalWeight();
}
if (weight > destinationNode.delta) {
if (logger.isLoggable (Level.FINE))
logger.fine ("Viterbi[inputPos="+ip
+"][source][dest="+destination.getName()
+"] weight increased to "+weight+" by source="+
s.getName());
destinationNode.delta = weight;
destinationNode.maxWeightPredecessor = lattice[ip][i];
}
}
}
}
public double getDelta (int ip, int stateIndex) {
if (lattice != null) {
return getViterbiNode (ip, stateIndex).delta;
}
throw new RuntimeException ("Attempt to called getDelta() when lattice not stored.");
}
private List<SequencePairAlignment<Object,ViterbiNode>> viterbiNodeAlignmentCache = null;
/**
* Perform the backward pass of Viterbi, returning the n-best sequences of
* ViterbiNodes. Each ViterbiNode contains the state, output symbol, and other
* information. Note that the length of each ViterbiNode Sequence is
* inputLength+1, because the first element of the sequence is the start
* state, and the first input/output symbols occur on the transition from a
* start-state to the next state. These first input/output symbols are stored
* in the second ViterbiNode in the sequence. The last ViterbiNode in the
* sequence corresponds to the final state and has the last input/output
* symbols.
*/
public List<SequencePairAlignment<Object,ViterbiNode>> bestViterbiNodeSequences (int n) {
if (viterbiNodeAlignmentCache != null && viterbiNodeAlignmentCache.size() >= n)
return viterbiNodeAlignmentCache;
int numFinal = 0;
for (int i = 0; i < t.numStates(); i++) {
if (lattice[latticeLength-1][i] != null && lattice[latticeLength-1][i].delta > Transducer.IMPOSSIBLE_WEIGHT)
numFinal++;
}
ViterbiNode[] finalNodes = new ViterbiNode[numFinal];
int f = 0;
for (int i = 0; i < t.numStates(); i++) {
if (lattice[latticeLength-1][i] != null && lattice[latticeLength-1][i].delta > Transducer.IMPOSSIBLE_WEIGHT)
finalNodes[f++] = lattice[latticeLength-1][i];
}
AStar search = new AStar(finalNodes, latticeLength * t.numStates());
List<SequencePairAlignment<Object,ViterbiNode>> outputs = new ArrayList<SequencePairAlignment<Object,ViterbiNode>>(n);
for (int i = 0; i < n && search.hasNext(); i++) {
// gsc: removing unnecessary cast
SearchNode ans = search.next();
double weight = -ans.getCost();
ViterbiNode[] seq = new ViterbiNode[latticeLength];
// Commented out so we get the start state ViterbiNode -akm 12/2007
//ans = ans.getParent(); // ans now corresponds to the Viterbi node after the first transition
for (int j = 0; j < latticeLength; j++) {
ViterbiNode v = (ViterbiNode)ans.getState();
assert(v.inputPosition == j); // was == j+1
seq[j] = v;
ans = ans.getParent();
}
outputs.add(new SequencePairAlignment<Object,ViterbiNode>(input, new ArraySequence<ViterbiNode>(seq), weight));
}
viterbiNodeAlignmentCache = outputs;
return outputs;
}
private List<SequencePairAlignment<Object,State>> stateAlignmentCache = null;
/**
* Perform the backward pass of Viterbi, returning the n-best sequences of
* States. Note that the length of each State Sequence is inputLength+1,
* because the first element of the sequence is the start state, and the first
* input/output symbols occur on the transition from a start state to the next
* state. The last State in the sequence corresponds to the final state.
*/
public List<SequencePairAlignment<Object,State>> bestStateAlignments (int n) {
if (stateAlignmentCache != null && stateAlignmentCache.size() >= n)
return stateAlignmentCache;
bestViterbiNodeSequences(n); // ensure that viterbiNodeAlignmentCache has at least size n
ArrayList<SequencePairAlignment<Object,State>> ret = new ArrayList<SequencePairAlignment<Object,State>>(n);
for (int i = 0; i < n; i++) {
State[] ss = new State[latticeLength];
Sequence<ViterbiNode> vs = viterbiNodeAlignmentCache.get(i).output();
for (int j = 0; j < latticeLength; j++)
ss[j] = vs.get(j).state; // Here is where we grab the state from the ViterbiNode
ret.add(new SequencePairAlignment<Object,State>(input, new ArraySequence<State>(ss), viterbiNodeAlignmentCache.get(i).getWeight()));
}
stateAlignmentCache = ret;
return ret;
}
public SequencePairAlignment<Object,State> bestStateAlignment () {
return bestStateAlignments(1).get(0);
}
public List<Sequence<State>> bestStateSequences(int n) {
List<SequencePairAlignment<Object,State>> a = bestStateAlignments(n);
ArrayList<Sequence<State>> ret = new ArrayList<Sequence<State>>(n);
for (int i = 0; i < n; i++)
ret.add (a.get(i).output());
return ret;
}
public Sequence<State> bestStateSequence() {
return bestStateAlignments(1).get(0).output();
}
private List<SequencePairAlignment<Object,Object>> outputAlignmentCache = null;
public List<SequencePairAlignment<Object,Object>> bestOutputAlignments (int n) {
if (outputAlignmentCache != null && outputAlignmentCache.size() >= n)
return outputAlignmentCache;
bestViterbiNodeSequences(n); // ensure that viterbiNodeAlignmentCache has at least size n
ArrayList<SequencePairAlignment<Object,Object>> ret = new ArrayList<SequencePairAlignment<Object,Object>>(n);
for (int i = 0; i < n; i++) {
Object[] ss = new Object[latticeLength-1];
Sequence<ViterbiNode> vs = viterbiNodeAlignmentCache.get(i).output();
for (int j = 0; j < latticeLength-1; j++)
ss[j] = vs.get(j+1).output; // Here is where we grab the output from the ViterbiNode destination
ret.add(new SequencePairAlignment<Object,Object>(input, new ArraySequence<Object>(ss), viterbiNodeAlignmentCache.get(i).getWeight()));
}
outputAlignmentCache = ret;
return ret;
}
public SequencePairAlignment<Object,Object> bestOutputAlignment () {
return bestOutputAlignments(1).get(0);
}
public List<Sequence<Object>> bestOutputSequences (int n) {
bestOutputAlignments(n); // ensure that outputAlignmentCache has at least size n
ArrayList<Sequence<Object>> ret = new ArrayList<Sequence<Object>>(n);
for (int i = 0; i < n; i++)
ret.add (outputAlignmentCache.get(i).output());
return ret;
// TODO consider caching this result
}
public Sequence<Object> bestOutputSequence () {
return bestOutputAlignments(1).get(0).output();
}
public double bestWeight() {
return bestOutputAlignments(1).get(0).getWeight();
}
/** Increment states and transitions with a count of 1.0 along the best state sequence.
* This provides for a so-called "Viterbi training" approximation. */
public void incrementTransducer (Transducer.Incrementor incrementor)
{
// We are only going to increment along the single best path ".get(0)" below.
// We could consider having a version of this method:
// incrementTransducer(Transducer.Incrementor incrementor, double[] counts)
// where the number of n-best paths to increment would be determined by counts.length
SequencePairAlignment<Object,ViterbiNode> viterbiNodeAlignment = this.bestViterbiNodeSequences(1).get(0);
int sequenceLength = viterbiNodeAlignment.output().size();
assert (sequenceLength == viterbiNodeAlignment.input().size()); // Not sure this works for unequal input/output lengths
// Increment the initial state
incrementor.incrementInitialState(viterbiNodeAlignment.output().get(0).state, 1.0);
// Increment the final state
incrementor.incrementFinalState(viterbiNodeAlignment.output().get(sequenceLength-1).state, 1.0);
for (int ip = 0; ip < viterbiNodeAlignment.input().size()-1; ip++) {
TransitionIterator iter =
viterbiNodeAlignment.output().get(ip).state.transitionIterator (input, ip, providedOutput, ip);
// xxx This assumes that a transition is completely
// identified, and made unique by its destination state and
// output. This may not be true!
int numIncrements = 0;
while (iter.hasNext()) {
if (iter.next().equals (viterbiNodeAlignment.output().get(ip+1).state)
&& iter.getOutput().equals (viterbiNodeAlignment.output().get(ip).output)) {
incrementor.incrementTransition(iter, 1.0);
numIncrements++;
}
}
if (numIncrements > 1)
throw new IllegalStateException ("More than one satisfying transition found.");
if (numIncrements == 0)
throw new IllegalStateException ("No satisfying transition found.");
}
}
public double elementwiseAccuracy (Sequence referenceOutput)
{
int accuracy = 0;
Sequence output = bestOutputSequence();
assert (referenceOutput.size() == output.size());
for (int i = 0; i < output.size(); i++) {
//logger.fine("tokenAccuracy: ref: "+referenceOutput.get(i)+" viterbi: "+output.get(i));
if (referenceOutput.get(i).toString().equals (output.get(i).toString())) {
accuracy++;
}
}
logger.info ("Number correct: " + accuracy + " out of " + output.size());
return ((double)accuracy)/output.size();
}
public double tokenAccuracy (Sequence referenceOutput, PrintWriter out)
{
Sequence output = bestOutputSequence();
int accuracy = 0;
String testString;
assert (referenceOutput.size() == output.size());
for (int i = 0; i < output.size(); i++) {
//logger.fine("tokenAccuracy: ref: "+referenceOutput.get(i)+" viterbi: "+output.get(i));
testString = output.get(i).toString();
if (out != null) {
out.println(testString);
}
if (referenceOutput.get(i).toString().equals (testString)) {
accuracy++;
}
}
logger.info ("Number correct: " + accuracy + " out of " + output.size());
return ((double)accuracy)/output.size();
}
public static class Factory extends MaxLatticeFactory implements Serializable
{
public MaxLattice newMaxLattice (Transducer trans, Sequence inputSequence, Sequence outputSequence)
{
return new MaxLatticeDefault (trans, inputSequence, outputSequence);
}
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 1;
private void writeObject(ObjectOutputStream out) throws IOException {
out.writeInt(CURRENT_SERIAL_VERSION);
}
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
in.readInt();
}
}
}