Package de.jungblut.math

Source Code of de.jungblut.math.ViterbiUtils

package de.jungblut.math;

import java.util.Iterator;

import de.jungblut.math.DoubleVector.DoubleVectorElement;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;

/**
* Viterbi Utilities for forward backward passes and his famous decoding
* algorithm for hidden markov models.
*
* @author thomas.jungblut
*
*/
public final class ViterbiUtils {

  private ViterbiUtils() {
    throw new IllegalAccessError();
  }

  /**
   * Do a decoding pass on the given HMM weights, the features to decode and how
   * many classes to predict. The output will contain a vector that contains a 1
   * at the index of the predicted label.
   *
   * @param weights the HMM weights.
   * @param features the features to predict on.
   * @param featuresPerState the matrix containing the feature vectors,
   *          precomputed for each possible state in classes. The layout is that
   *          the same feature was computed n-times, so class 0 first, class 1
   *          next and so on and this is layed out in rows (Feature 1 | class 0,
   *          Feature 1 | class 1 ...). Feature 0 is only contained once,
   *          because it only had class zero as previous class.
   * @param classes how many classes? 2 if binary.
   * @return a n x m matrix where n is the number of featurevectors and m is the
   *         number of classes (in binary prediction this is just 1, 0 and 1 are
   *         the predicted labels at index 0 then).
   */
  public static DoubleMatrix decode(DoubleMatrix weights,
      DoubleMatrix features, DoubleMatrix featuresPerState, int classes) {
    final int m = features.getRowCount();
    int[][] backpointers = new int[m][classes];
    double[][] scores = new double[m][classes];

    // define the starting label as 0.
    int prevLabel = 0;
    double[] localScores = computeScores(classes, features.getRowVector(0),
        weights);

    int position = 0;
    for (int currLabel = 0; currLabel < localScores.length; currLabel++) {
      backpointers[position][currLabel] = prevLabel;
      scores[position][currLabel] = localScores[currLabel];
    }

    // for each position in data
    for (position = 1; position < m; position++) {
      int i = position * classes - 1;
      // for each possible previous label
      for (int j = 0; j < classes; j++) {
        prevLabel = j;
        localScores = computeScores(classes,
            featuresPerState.getRowVector(i + j), weights);
        for (int currLabel = 0; currLabel < localScores.length; currLabel++) {
          double score = localScores[currLabel]
              + scores[position - 1][prevLabel];
          if (prevLabel == 0 || score > scores[position][currLabel]) {
            backpointers[position][currLabel] = prevLabel;
            scores[position][currLabel] = score;
          }
        }
      }
    }

    int bestLabel = 0;
    double bestScore = scores[m - 1][bestLabel];
    for (int label = 1; label < scores[m - 1].length; label++) {
      if (scores[m - 1][label] > bestScore) {
        bestLabel = label;
        bestScore = scores[m - 1][label];
      }
    }

    DoubleMatrix outcome = new DenseDoubleMatrix(features.getRowCount(),
        classes == 2 ? 1 : classes);
    // follow the backpointers
    for (position = m - 1; position >= 0; position--) {
      DenseDoubleVector vec = null;
      if (classes != 2) {
        vec = new DenseDoubleVector(classes);
        vec.set(bestLabel, 1);
      } else {
        vec = new DenseDoubleVector(1);
        vec.set(0, bestLabel);
      }
      outcome.setRowVector(position, vec);
      bestLabel = backpointers[position][bestLabel];
    }

    return outcome;
  }

  // compute the scores for a featurevector and its weighs and the number of
  // classes
  static double[] computeScores(int classes, DoubleVector features,
      DoubleMatrix weights) {

    double[] scores = new double[classes];

    Iterator<DoubleVectorElement> iterateNonZero = features.iterateNonZero();
    while (iterateNonZero.hasNext()) {
      DoubleVectorElement next = iterateNonZero.next();
      for (int i = 0; i < scores.length; i++) {
        scores[i] += weights.get(i, next.getIndex());
      }
    }

    return scores;
  }

}
TOP

Related Classes of de.jungblut.math.ViterbiUtils

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.