package cc.mallet.fst;


import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.logging.Level;
import java.util.logging.Logger;

import cc.mallet.fst.Transducer.State;
import cc.mallet.fst.Transducer.TransitionIterator;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;

public class SumLatticeScaling implements SumLattice {
  private static Logger logger = MalletLogger
  protected static boolean saveXis = false;

  // "ip" == "input position", "op" == "output position", "i" == "state index"
  Sequence input, output;
  Transducer t;
  double totalWeight;
  LatticeNode[][] nodes; // indexed by ip,i
  double[] alphaLogScaling, betaLogScaling;
  double zLogScaling;
  int latticeLength;
  double[][] gammas; // indexed by ip,i
  double[][][] xis; // indexed by ip,i,j; saved only if saveXis is true;

  // Ensure that instances cannot easily be created by a zero arg constructor.
  protected SumLatticeScaling() {

  protected LatticeNode getLatticeNode(int ip, int stateIndex) {
    if (nodes[ip][stateIndex] == null)
      nodes[ip][stateIndex] = new LatticeNode(ip, t.getState(stateIndex));
    return nodes[ip][stateIndex];

  public SumLatticeScaling(Transducer trans, Sequence input) {
    this(trans, input, null, (Transducer.Incrementor) null, saveXis, null);

  public SumLatticeScaling(Transducer trans, Sequence input, boolean saveXis) {
    this(trans, input, null, (Transducer.Incrementor) null, saveXis, null);

  public SumLatticeScaling(Transducer trans, Sequence input,
      Transducer.Incrementor incrementor) {
    this(trans, input, null, incrementor, saveXis, null);

  public SumLatticeScaling(Transducer trans, Sequence input, Sequence output) {
    this(trans, input, output, (Transducer.Incrementor) null, saveXis, null);

  // You may pass null for output, meaning that the lattice
  // is not constrained to match the output
  public SumLatticeScaling(Transducer trans, Sequence input, Sequence output,
      Transducer.Incrementor incrementor) {
    this(trans, input, output, incrementor, saveXis, null);

  public SumLatticeScaling(Transducer trans, Sequence input, Sequence output,
      Transducer.Incrementor incrementor, LabelAlphabet outputAlphabet) {
    this(trans, input, output, incrementor, saveXis, outputAlphabet);

  // You may pass null for output, meaning that the lattice
  // is not constrained to match the output
  public SumLatticeScaling(Transducer trans, Sequence input, Sequence output,
      Transducer.Incrementor incrementor, boolean saveXis) {
    this(trans, input, output, incrementor, saveXis, null);

  public SumLatticeScaling(Transducer trans, Sequence input, Sequence output,
      Transducer.Incrementor incrementor, boolean saveXis,
      LabelAlphabet outputAlphabet) {
    assert (output == null || input.size() == output.size());

    // Initialize some structures
    this.t = trans;
    this.input = input;
    this.output = output;
    latticeLength = input.size() + 1;
    int numStates = t.numStates();
    nodes = new LatticeNode[latticeLength][numStates];
    alphaLogScaling = new double[latticeLength];
    betaLogScaling = new double[latticeLength];
    gammas = new double[latticeLength][numStates];
    if (saveXis)
      xis = new double[latticeLength][numStates][numStates];

    double outputCounts[][] = null;
    if (outputAlphabet != null)
      outputCounts = new double[latticeLength][outputAlphabet.size()];

    for (int ip = 0; ip < latticeLength; ip++) {
      alphaLogScaling[ip] = 0.0;
      betaLogScaling[ip] = 0.0;
      for (int i = 0; i < numStates; i++) {
        gammas[ip][i] = Transducer.IMPOSSIBLE_WEIGHT;
        if (saveXis)
          for (int j = 0; j < numStates; j++)
            xis[ip][i][j] = Transducer.IMPOSSIBLE_WEIGHT;

    // Forward pass
    logger.fine("Starting Foward pass");
    boolean atLeastOneInitialState = false;
    for (int i = 0; i < numStates; i++) {
      double initialWeight = t.getState(i).getInitialWeight();
      if (initialWeight > Transducer.IMPOSSIBLE_WEIGHT) {
        getLatticeNode(0, i).alpha = Math.exp(initialWeight);
        atLeastOneInitialState = true;
    if (atLeastOneInitialState == false)
      logger.warning("There are no starting states!");

    for (int ip = 0; ip < latticeLength - 1; ip++) {
      for (int i = 0; i < numStates; i++) {
        if (isInvalidNode(ip, i))
        State s = t.getState(i);
        TransitionIterator iter = s.transitionIterator(input, ip,
            output, ip);
        while (iter.hasNext()) {
          State destination =;
          LatticeNode destinationNode = getLatticeNode(ip + 1,
          if (Double.isNaN(destinationNode.alpha))
            destinationNode.alpha = 0;
          destinationNode.output = iter.getOutput();
          double transitionWeight = iter.getWeight();
          destinationNode.alpha += nodes[ip][i].alpha
              * Math.exp(transitionWeight);
      // re-scale alphas to so that \sum_i \alpha[ip][i] = 1
      rescaleAlphas(ip + 1);

    // Calculate total weight of Lattice. This is the normalizer
    double Z = Double.NaN;
    for (int i = 0; i < numStates; i++)
      if (nodes[latticeLength - 1][i] != null) {
        if (Double.isNaN(Z))
          Z = 0;
        Z += nodes[latticeLength - 1][i].alpha
            * Math.exp(t.getState(i).getFinalWeight());
    zLogScaling = alphaLogScaling[latticeLength - 1];

    if (Double.isNaN(Z)) {
      totalWeight = Transducer.IMPOSSIBLE_WEIGHT;
    } else
      totalWeight = Math.log(Z) + zLogScaling;

    // Backward pass
    for (int i = 0; i < numStates; i++)
      if (nodes[latticeLength - 1][i] != null) {
        State s = t.getState(i);
        nodes[latticeLength - 1][i].beta = Math.exp(s.getFinalWeight());
        double gamma = nodes[latticeLength - 1][i].alpha
            * nodes[latticeLength - 1][i].beta / Z;
        gammas[latticeLength - 1][i] = Math.log(gamma);
        if (incrementor != null) {
          double p = gamma;
          assert (p >= 0.0 && p <= 1.0 + 1e-6) : "p=" + p
              + ", gamma=" + gammas[latticeLength - 1][i];
          incrementor.incrementFinalState(s, p);
    rescaleBetas(latticeLength - 1);

    for (int ip = latticeLength - 2; ip >= 0; ip--) {
      for (int i = 0; i < numStates; i++) {
        if (isInvalidNode(ip, i))
        State s = t.getState(i);
        TransitionIterator iter = s.transitionIterator(input, ip,
            output, ip);
        double logScaling = alphaLogScaling[ip]
            + betaLogScaling[ip + 1] - zLogScaling;
        double pscaling = Math.exp(logScaling);
        while (iter.hasNext()) {
          State destination =;
          int j = destination.getIndex();
          LatticeNode destinationNode = nodes[ip + 1][j];
          if (destinationNode != null) {
            double transitionWeight = iter.getWeight();
            if (Double.isNaN(nodes[ip][i].beta))
              nodes[ip][i].beta = 0;
            double transitionProb = Math.exp(transitionWeight);
            nodes[ip][i].beta += destinationNode.beta
                * transitionProb;
            double xi = nodes[ip][i].alpha * transitionProb
                * nodes[ip + 1][j].beta / Z;
            if (saveXis)
              xis[ip][i][j] = Math.log(xi) + logScaling;
            if (incrementor != null || outputAlphabet != null) {
              double p = xi * pscaling;
              assert (p >= 0.0 && p <= 1.0 + 1e-6) : "p=" + p
                  + ", xis[" + ip + "][" + i + "][" + j
                  + "]=" + xi;
              if (incrementor != null)
                incrementor.incrementTransition(iter, p);
              if (outputAlphabet != null) {
                int outputIndex = outputAlphabet.lookupIndex(
                    iter.getOutput(), false);
                assert (outputIndex >= 0);
                outputCounts[ip][outputIndex] += p;
        gammas[ip][i] = Math.log(nodes[ip][i].alpha * nodes[ip][i].beta
            / Z)
            + logScaling;
      // re-scale betas so that they are normalized
    if (incrementor != null)
      for (int i = 0; i < numStates; i++) {
        double p = Math.exp(gammas[0][i]);
        assert (p >= 0.0 && p <= 1.0 + 1e-6) : "p=" + p;
        incrementor.incrementInitialState(t.getState(i), p);

  private boolean isInvalidNode(int ip, int i) {
    return nodes[ip][i] == null || Double.isNaN(nodes[ip][i].alpha);

  private void rescaleAlphas(int ip) {
    double sumAlpha = 0;
    for (int i = 0; i < t.numStates(); i++) {
      if (!isInvalidNode(ip, i))
        sumAlpha += nodes[ip][i].alpha;
    assert sumAlpha > 0 : "Invalid sum over alphas for ip=" + ip;
    alphaLogScaling[ip] = Math.log(sumAlpha)
        + (ip == 0 ? 0 : alphaLogScaling[ip - 1]);
    for (int i = 0; i < t.numStates(); i++) {
      if (!isInvalidNode(ip, i))
        nodes[ip][i].alpha /= sumAlpha;

  private void rescaleBetas(int ip) {
    double sumBeta = 0;
    for (int i = 0; i < t.numStates(); i++) {
      if (!isInvalidNode(ip, i))
        sumBeta += nodes[ip][i].beta;
    assert sumBeta > 0 : "Invalid sum over betas for ip=" + ip;
    betaLogScaling[ip] = Math.log(sumBeta)
        + (ip == latticeLength - 1 ? 0 : betaLogScaling[ip + 1]);
    for (int i = 0; i < t.numStates(); i++) {
      if (!isInvalidNode(ip, i))
        nodes[ip][i].beta /= sumBeta;

  public double[][][] getXis() {
    return xis;

  public double[][] getGammas() {
    return gammas;

  public double getTotalWeight() {
    return totalWeight;

  public double getGammaWeight(int inputPosition, State s) {
    return gammas[inputPosition][s.getIndex()];

  public double getGammaWeight(int inputPosition, int stateIndex) {
    return gammas[inputPosition][stateIndex];

  public double getGammaProbability(int inputPosition, State s) {
    return Math.exp(gammas[inputPosition][s.getIndex()]);

  public double getGammaProbability(int inputPosition, int stateIndex) {
    return getGammaProbability(inputPosition, t.getState(stateIndex));

  public double getXiProbability(int ip, State s1, State s2) {
    return Math.exp(getXiWeight(ip, s1, s2));

  public double getXiWeight(int ip, State s1, State s2) {
    if (xis == null)
      throw new IllegalStateException("xis were not saved.");
    int i = s1.getIndex();
    int j = s2.getIndex();
    return xis[ip][i][j];

  public int length() {
    return latticeLength;

  public double getAlpha(int ip, State s) {
    LatticeNode node = getLatticeNode(ip, s.getIndex());
    return node.alpha * Math.exp(alphaLogScaling[ip]);

  public double getBeta(int ip, State s) {
    LatticeNode node = getLatticeNode(ip, s.getIndex());
    return node.beta * Math.exp(betaLogScaling[ip]);

  public LabelVector getLabelingAtPosition(int outputPosition) {
    throw new RuntimeException("Not implemented for SumLatticeScaling!");

  public Transducer getTransducer() {
    return t;

  protected class LatticeNode {
    int inputPosition;
    State state;
    Object output;
    double alpha = Double.NaN;
    double beta = Double.NaN;

    LatticeNode(int inputPosition, State state) {
      this.inputPosition = inputPosition;
      this.state = state;

  public static class Factory extends SumLatticeFactory implements
      Serializable {
    public SumLattice newSumLattice(Transducer trans, Sequence input,
        Sequence output, Transducer.Incrementor incrementor,
        boolean saveXis, LabelAlphabet outputAlphabet) {
      return new SumLatticeScaling(trans, input, output, incrementor,
          saveXis, outputAlphabet);

    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 1;

    private void writeObject(ObjectOutputStream out) throws IOException {

    private void readObject(ObjectInputStream in) throws IOException,
        ClassNotFoundException {
      int version = in.readInt();

