package de.jungblut.nlp;
import static com.google.common.base.Preconditions.checkArgument;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Iterator;
import java.util.Random;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.util.FastMath;
import org.apache.hadoop.io.Writable;
import de.jungblut.classification.AbstractClassifier;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.DoubleVector.DoubleVectorElement;
import de.jungblut.math.ViterbiUtils;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;
import de.jungblut.writable.MatrixWritable;
import de.jungblut.writable.VectorWritable;
/**
 * Hidden Markov Model implementation for multiple observations, covering the
 * three classic problems an HMM solves (decoding, likelihood estimation and
 * unsupervised/supervised learning).
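 *
 * <p>
 * A minimal usage sketch (the state counts, variable names and data shapes
 * below are illustrative assumptions, not fixed by this API):
 *
 * <pre>
 * {@code
 * HMM hmm = new HMM(4, 2); // 4 visible states, 2 hidden states
 * hmm.trainSupervised(features, outcomes);
 * DoubleVector hiddenStateDistribution = hmm.predict(observation);
 * }
 * </pre>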
*
* @author thomas.jungblut
*
*/
public final class HMM extends AbstractClassifier implements Writable {
private static final Log LOG = LogFactory.getLog(HMM.class);
private int numVisibleStates;
private int numHiddenStates;
  /**
   * transition matrix from hidden state i (row) to hidden state j (column); in
   * the literature called A. Dimension: numHiddenStates x numHiddenStates.
   */
private DoubleMatrix transitionProbabilityMatrix;
  /**
   * emission matrix of the probability that an observation/feature o_t is
   * generated from hidden state i; in the literature called B. Dimension:
   * numHiddenStates x numVisibleStates.
   */
private DoubleMatrix emissionProbabilityMatrix;
  /**
   * initial hidden state probabilities (the prior of how likely each hidden
   * state is to occur); in the literature called pi.
   */
private DoubleVector hiddenPriorProbability;
// test seed
private long seed;
// deserialization constructor for Writable types
public HMM() {
seed = System.currentTimeMillis();
}
public HMM(int numVisibleStates, int numHiddenStates) {
this(numVisibleStates, numHiddenStates, System.currentTimeMillis());
}
// test constructor
HMM(int numVisibleStates, int numHiddenStates, long seed) {
this.seed = seed;
this.numVisibleStates = numVisibleStates;
this.numHiddenStates = numHiddenStates;
this.transitionProbabilityMatrix = new DenseDoubleMatrix(numHiddenStates,
numHiddenStates);
this.emissionProbabilityMatrix = new DenseDoubleMatrix(numHiddenStates,
numVisibleStates);
this.hiddenPriorProbability = new DenseDoubleVector(numHiddenStates);
}
/**
   * Normalizes the values of all three main data structures of the HMM
   * (transitionProbabilityMatrix, emissionProbabilityMatrix and
   * hiddenPriorProbability) by summing the values and dividing each element by
   * the sum. For the matrices we use row sums over the hidden states.
*/
private void normalizeProbabilities() {
normalize(hiddenPriorProbability, transitionProbabilityMatrix,
emissionProbabilityMatrix, false);
}
private void logNormalizeProbabilities() {
normalize(hiddenPriorProbability, transitionProbabilityMatrix,
emissionProbabilityMatrix, true);
}
/**
* Likelihood estimation on the current HMM. It estimates the likelihood that
* the given observation sequence is about to happen. P( O | lambda ) where O
* is the observation sequence and lambda are the HMM's parameters. This is
* done by executing the forward algorithm with the given observations clamped
* to the visible states.
*
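   * <p>
   * A minimal sketch, assuming two visible states and one-hot encoded
   * observations:
   *
   * <pre>
   * {@code
   * DoubleVector[] sequence = new DoubleVector[] {
   *     new DenseDoubleVector(new double[] { 1d, 0d }),
   *     new DenseDoubleVector(new double[] { 0d, 1d }) };
   * double likelihood = hmm.estimateLikelihood(sequence);
   * }
   * </pre>
   *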
* @param observationSequence the given sequence of observations (features).
* @return the likelihood (not a probability!) that the given sequence is
* about to happen.
*/
public double estimateLikelihood(DoubleVector[] observationSequence) {
return estimateLikelihood(forward(new DenseDoubleMatrix(
observationSequence.length, numHiddenStates),
transitionProbabilityMatrix, emissionProbabilityMatrix,
hiddenPriorProbability, observationSequence));
}
private static double estimateLikelihood(DoubleMatrix alpha) {
    // sum the last row of the alpha matrix generated by the forward algorithm;
    // it denotes the end state of the sequence.
return alpha.getRowVector(alpha.getRowCount() - 1).sum();
}
/**
* Decodes the given observation sequence (features) with the current HMM.
* This discovers the best hidden state sequence Q that is derived by
* executing the Viterbi algorithm with the given observations and the HMM's
   * parameters lambda. This is a proxy to the decode method in
   * {@link ViterbiUtils}.
*
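   * <p>
   * A minimal sketch (the variable names are illustrative):
   *
   * <pre>
   * {@code
   * DoubleMatrix hiddenStateSequence = hmm.decode(observations,
   *     featuresPerHiddenState);
   * }
   * </pre>
   *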
   * @param observationSequence the given sequence of features.
   * @param featuresPerHiddenState the features that describe each hidden
   *          state.
   * @return a matrix containing the predicted hidden state on each row vector.
*/
public DoubleMatrix decode(DoubleVector[] observationSequence,
DoubleVector[] featuresPerHiddenState) {
return ViterbiUtils.decode(emissionProbabilityMatrix,
new SparseDoubleRowMatrix(observationSequence),
new SparseDoubleRowMatrix(featuresPerHiddenState), numHiddenStates);
}
/**
   * Trains the current model's parameters by executing the Baum-Welch
   * expectation maximization algorithm. TODO this should also be log-scaled
   * for accuracy.
*
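   * <p>
   * Example invocation (epsilon and iteration count are arbitrary illustration
   * values):
   *
   * <pre>
   * {@code
   * hmm.trainUnsupervised(features, 1e-4, 100, true);
   * }
   * </pre>
   *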
   * @param features the visible state activations (the vector will be traversed
   *          for non-zero entries, so the actual values don't matter).
   * @param epsilon the absolute difference between the trained model and the
   *          model of the previous iteration. If the difference is smaller
   *          than this value, iteration stops and training finishes.
   * @param maxIterations if the epsilon threshold is never reached, computation
   *          stops after the given number of iterations.
   * @param verbose when set to true, the model difference is logged per
   *          iteration.
*/
public void trainUnsupervised(DoubleVector[] features, double epsilon,
int maxIterations, boolean verbose) {
// initialize a random starting state
Random random = new Random(seed);
transitionProbabilityMatrix = new DenseDoubleMatrix(numHiddenStates,
numHiddenStates, random);
emissionProbabilityMatrix = new DenseDoubleMatrix(numHiddenStates,
numVisibleStates, random);
hiddenPriorProbability = new DenseDoubleVector(numHiddenStates);
for (int i = 0; i < numHiddenStates; i++) {
hiddenPriorProbability.set(i, random.nextDouble());
}
normalizeProbabilities();
DoubleMatrix alpha = new DenseDoubleMatrix(features.length, numHiddenStates);
DoubleMatrix beta = new DenseDoubleMatrix(features.length, numHiddenStates);
for (int iteration = 0; iteration < maxIterations; iteration++) {
// in every iteration we initialize a new model that is a copy of the old
DoubleMatrix transitionProbabilityMatrix = this.transitionProbabilityMatrix
.deepCopy();
DoubleMatrix emissionProbabilityMatrix = this.emissionProbabilityMatrix
.deepCopy();
DoubleVector hiddenPriorProbability = this.hiddenPriorProbability
.deepCopy();
// expectation step
alpha = forward(alpha, transitionProbabilityMatrix,
emissionProbabilityMatrix, hiddenPriorProbability, features);
beta = backward(beta, transitionProbabilityMatrix,
emissionProbabilityMatrix, hiddenPriorProbability, features);
      // maximization step of Baum-Welch: calculate the new prior from the
      // alpha and beta factors of the first time step
hiddenPriorProbability = alpha.getRowVector(0).multiply(
beta.getRowVector(0));
final double modelLikelihood = estimateLikelihood(alpha);
// compute real transition probabilities
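      // the inner loops accumulate (up to normalization) the expected number
      // of transitions from state i to state j: xi_t(i, j) is proportional to
      // alpha_t(i) * a_ij * b_j(o_(t+1)) * beta_(t+1)(j); the a_ij factor is
      // multiplied in after the time loop, and the division by the model
      // likelihood rescales the expectation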
for (int i = 0; i < numHiddenStates; i++) {
for (int j = 0; j < numHiddenStates; j++) {
double temp = 0d;
for (int t = 0; t < features.length - 1; t++) {
Iterator<DoubleVectorElement> iterateNonZero = features[t + 1]
.iterateNonZero();
while (iterateNonZero.hasNext()) {
temp += alpha.get(t, i)
* emissionProbabilityMatrix.get(j, iterateNonZero.next()
.getIndex()) * beta.get(t + 1, j);
}
}
transitionProbabilityMatrix.set(i, j,
transitionProbabilityMatrix.get(i, j) * temp / modelLikelihood);
}
}
// compute real emission probabilities
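      // the inner loops accumulate the expected number of times state i emits
      // symbol j: gamma_t(i) is proportional to alpha_t(i) * beta_t(i), summed
      // over the time steps in which feature j was active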
for (int i = 0; i < numHiddenStates; i++) {
for (int j = 0; j < numVisibleStates; j++) {
double temp = 0d;
for (int t = 0; t < features.length; t++) {
Iterator<DoubleVectorElement> iterateNonZero = features[t]
.iterateNonZero();
while (iterateNonZero.hasNext()) {
DoubleVectorElement next = iterateNonZero.next();
if (next.getIndex() == j) {
temp += alpha.get(t, i) * beta.get(t, i);
}
}
}
emissionProbabilityMatrix.set(i, j, temp / modelLikelihood);
}
}
      // normalize again after the Baum-Welch pass
normalize(hiddenPriorProbability, transitionProbabilityMatrix,
emissionProbabilityMatrix, false);
// measure the difference by taking the squared difference of every matrix
// element and every vector element in our priors
double difference = this.transitionProbabilityMatrix
.subtract(transitionProbabilityMatrix).pow(2).sum()
+ this.emissionProbabilityMatrix.subtract(emissionProbabilityMatrix)
.pow(2).sum()
+ this.getHiddenPriorProbability().subtract(hiddenPriorProbability)
.pow(2).sum();
if (verbose) {
LOG.info("Iteration " + iteration + " | Model difference: "
+ difference + "\r");
}
// set the new model to our fields
this.transitionProbabilityMatrix = transitionProbabilityMatrix;
this.emissionProbabilityMatrix = emissionProbabilityMatrix;
this.hiddenPriorProbability = hiddenPriorProbability;
if (difference < epsilon) {
break;
}
}
// normalize logarithmic for predictions
normalize(hiddenPriorProbability, transitionProbabilityMatrix,
emissionProbabilityMatrix, true);
}
/**
* Backward algorithm to compute the beta factors.
*
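   * The recursion computed here is the textbook backward pass: beta_T(i) = 1
   * and beta_t(i) = sum_j a_ij * b_j(o_(t+1)) * beta_(t+1)(j), where the
   * emission term is summed over all non-zero features of the observation at
   * time t+1.
   *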
   * @param beta the pre-allocated beta matrix (T x numHiddenStates, where T
   *          denotes the length of the observation sequence).
   * @return the mutated beta matrix that was given as parameter.
*/
private static DoubleMatrix backward(DoubleMatrix beta,
DoubleMatrix transitionProbabilityMatrix,
DoubleMatrix emissionProbabilityMatrix,
DoubleVector hiddenPriorProbability, DoubleVector[] features) {
final int numHiddenStates = beta.getColumnCount();
// set the states on the last row to 1
beta.setRowVector(features.length - 1,
DenseDoubleVector.ones(numHiddenStates));
for (int t = features.length - 2; t >= 0; t--) {
for (int i = 0; i < numHiddenStates; i++) {
double sum = 0d;
for (int j = 0; j < numHiddenStates; j++) {
Iterator<DoubleVectorElement> iterateNonZero = features[t + 1]
.iterateNonZero();
while (iterateNonZero.hasNext()) {
DoubleVectorElement next = iterateNonZero.next();
sum += beta.get(t + 1, j) * transitionProbabilityMatrix.get(i, j)
* emissionProbabilityMatrix.get(j, next.getIndex());
}
}
beta.set(t, i, sum);
}
}
return beta;
}
/**
* Forward algorithm to compute the alpha factors.
*
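   * The recursion computed here is the textbook forward pass: alpha_1(i) =
   * pi_i * b_i(o_1) and alpha_t(i) = (sum_j alpha_(t-1)(j) * a_ji) * b_i(o_t),
   * where the emission term is summed over all non-zero features of the
   * observation at time t.
   *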
   * @param alpha the pre-allocated alpha matrix (T x numHiddenStates, where T
   *          denotes the length of the observation sequence).
* @return the mutated alpha matrix that was given as parameter.
*/
private static DoubleMatrix forward(DoubleMatrix alpha,
DoubleMatrix transitionProbabilityMatrix,
DoubleMatrix emissionProbabilityMatrix,
DoubleVector hiddenPriorProbability, DoubleVector[] features) {
final int numHiddenStates = alpha.getColumnCount();
for (int i = 0; i < numHiddenStates; i++) {
// for each feature state of our first features
Iterator<DoubleVectorElement> firstFeatures = features[0]
.iterateNonZero();
double emissionSum = 0d;
while (firstFeatures.hasNext()) {
emissionSum += emissionProbabilityMatrix.get(i, firstFeatures.next()
.getIndex());
}
alpha.set(0, i, hiddenPriorProbability.get(i) * emissionSum);
}
// loop over every time step
for (int t = 1; t < features.length; t++) {
for (int i = 0; i < numHiddenStates; i++) {
double sum = 0.0d;
for (int j = 0; j < numHiddenStates; j++) {
sum += alpha.get(t - 1, j) * transitionProbabilityMatrix.get(j, i);
}
Iterator<DoubleVectorElement> featureIterator = features[t]
.iterateNonZero();
double emissionSum = 0d;
while (featureIterator.hasNext()) {
// for each feature in the current features compute the emission from
// feature state q to state i
emissionSum += emissionProbabilityMatrix.get(i, featureIterator
.next().getIndex());
}
alpha.set(t, i, sum * emissionSum);
}
}
return alpha;
}
private static void normalize(DoubleVector hiddenPriorProbability,
DoubleMatrix transitionProbabilityMatrix,
      DoubleMatrix emissionProbabilityMatrix, boolean log) {
double sum = hiddenPriorProbability.sum();
if (sum != 0d) {
for (int i = 0; i < hiddenPriorProbability.getDimension(); i++) {
hiddenPriorProbability.set(i, hiddenPriorProbability.get(i) / sum);
}
}
for (int row = 0; row < transitionProbabilityMatrix.getRowCount(); row++) {
      // note that we operate on row vectors and write the normalized (and
      // optionally log-scaled) result back via setRowVector
DoubleVector rowVector = transitionProbabilityMatrix.getRowVector(row);
rowVector = rowVector.divide(rowVector.sum());
if (log) {
rowVector = rowVector.log();
}
transitionProbabilityMatrix.setRowVector(row, rowVector);
      rowVector = emissionProbabilityMatrix.getRowVector(row);
rowVector = rowVector.divide(rowVector.sum());
if (log) {
rowVector = rowVector.log();
}
      emissionProbabilityMatrix.setRowVector(row, rowVector);
}
}
/**
   * Trains the current model's parameters by executing a forward pass over the
   * given observations (hidden and visible states). Probabilities are +1
   * smoothed while counting, in case there would be zero probability
   * somewhere. This method is compatible with the Classifier#train method, so
   * this model can be used as a simple classifier.
*
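   * <p>
   * A minimal sketch with one-hot outcomes, assuming three hidden states (the
   * data values are illustrative):
   *
   * <pre>
   * {@code
   * DoubleVector[] outcome = new DoubleVector[] {
   *     new DenseDoubleVector(new double[] { 1d, 0d, 0d }),
   *     new DenseDoubleVector(new double[] { 0d, 0d, 1d }) };
   * hmm.trainSupervised(features, outcome);
   * }
   * </pre>
   *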
   * @param features the visible state activations (the vector will be traversed
   *          for non-zero entries, so the actual values don't matter).
   * @param outcome the outcome that was assigned to the given features. In the
   *          binary case this is a single element vector (0d or 1d); in the
   *          multi-class case it is a vector whose index denotes the class
   *          (from zero to numHiddenStates, activation again 0d or 1d). Note
   *          that in the multi-class case only a single state can be turned
   *          on, so the classes are mutually exclusive.
*/
public void trainSupervised(DoubleVector[] features, DoubleVector[] outcome) {
// first check both have the same length, then sanity check with the
// parameters
checkArgument(features.length == outcome.length,
"Feature array length must match outcome array length: "
+ features.length + " != " + outcome.length);
// check if we have enough examples (at least 1)
    checkArgument(features.length > 0,
        "Feature array length must be at least 1! Given: " + features.length);
// check if the feature vectors dimension matches the number of visible
// states
checkArgument(features[0].getDimension() == numVisibleStates,
"Feature vector's dimension must match the number of visible states! Given: "
+ features[0].getDimension() + ", but expected " + numVisibleStates);
// now check if the outcome is sane
int outcomeDimension = outcome[0].getDimension();
// this checks whether the outcome dimension is 1, if so we expect binary
// outcomes, else the number of hidden states
int expectedDimension = outcomeDimension == 1 ? 2 : numHiddenStates;
checkArgument(outcomeDimension == expectedDimension,
"Outcome dimension didn't match the given number of hidden states: "
+ outcomeDimension + " != " + expectedDimension);
// +1 smooth first
hiddenPriorProbability = hiddenPriorProbability.add(1d);
for (int rowIndex = 0; rowIndex < numHiddenStates; rowIndex++) {
transitionProbabilityMatrix.setRowVector(rowIndex,
DenseDoubleVector.ones(numHiddenStates));
emissionProbabilityMatrix.setRowVector(rowIndex,
DenseDoubleVector.ones(numVisibleStates));
}
for (int i = 0; i < features.length; i++) {
DoubleVector feat = features[i];
DoubleVector out = outcome[i];
int index = getOutcomeState(out);
hiddenPriorProbability.set(index, hiddenPriorProbability.get(index) + 1);
// count the emissions from feature layer to the hidden layer
Iterator<DoubleVectorElement> iterateNonZero = feat.iterateNonZero();
while (iterateNonZero.hasNext()) {
DoubleVectorElement next = iterateNonZero.next();
emissionProbabilityMatrix.set(index, next.getIndex(),
emissionProbabilityMatrix.get(index, next.getIndex()) + 1);
}
// now handle the feature by counting the transitions between the hidden
// states with the next feature
if (i + 1 < features.length) {
DoubleVector nextOut = outcome[i + 1];
int nextIndex = getOutcomeState(nextOut);
transitionProbabilityMatrix.set(index, nextIndex,
transitionProbabilityMatrix.get(index, nextIndex) + 1);
}
}
// now we can divide by the counts and normalize the probabilities so they
// sum to 1 and log them
logNormalizeProbabilities();
}
@Override
public void train(DoubleVector[] features, DoubleVector[] outcome) {
trainSupervised(features, outcome);
}
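  /**
   * Predicts the hidden state distribution for a single observation by
   * clamping the features to the visible units. Assumes the model was trained
   * (and thus log-normalized) beforehand; the returned vector is normalized to
   * sum to one.
   *
   * <p>
   * A minimal sketch (the observation encoding is illustrative):
   *
   * <pre>
   * {@code
   * DoubleVector observation = new DenseDoubleVector(new double[] { 1d, 0d });
   * DoubleVector stateDistribution = hmm.predict(observation);
   * int likeliestState = stateDistribution.maxIndex();
   * }
   * </pre>
   */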
@Override
public DoubleVector predict(DoubleVector features) {
// clamp the features to the visible units, calculate the joint
// probability for each hidden state and put it into the vector
DoubleVector probabilities = emissionProbabilityMatrix
.multiplyVectorRow(features);
double max = probabilities.max();
for (int state = 0; state < probabilities.getDimension(); state++) {
probabilities.set(state, FastMath.exp(probabilities.get(state) - max)
* hiddenPriorProbability.get(state));
}
// normalize again
return probabilities.divide(probabilities.sum());
}
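  /**
   * Predicts the hidden state distribution like {@link #predict(DoubleVector)},
   * but additionally conditions on the prediction for the previous observation
   * by adding the (logarithmic) transition probabilities to the (logarithmic)
   * emission probabilities.
   *
   * @param features the observed feature activations.
   * @param previousOutcome the hidden state distribution predicted for the
   *          previous observation.
   */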
public DoubleVector predict(DoubleVector features,
DoubleVector previousOutcome) {
// clamp the features to the visible units, calculate the joint
// probability for each hidden state and put it into the vector
DoubleVector probabilities = emissionProbabilityMatrix
.multiplyVectorRow(features);
// we can add here, both are logarithms
    probabilities = probabilities.add(transitionProbabilityMatrix
        .multiplyVectorRow(previousOutcome));
double max = probabilities.max();
for (int state = 0; state < probabilities.getDimension(); state++) {
probabilities.set(state, FastMath.exp(probabilities.get(state) - max)
* hiddenPriorProbability.get(state));
}
// normalize again
return probabilities.divide(probabilities.sum());
}
public int getNumHiddenStates() {
return this.numHiddenStates;
}
public int getNumVisibleStates() {
return this.numVisibleStates;
}
  public DoubleMatrix getEmissionProbabilityMatrix() {
return this.emissionProbabilityMatrix;
}
public DoubleVector getHiddenPriorProbability() {
return this.hiddenPriorProbability;
}
public DoubleMatrix getTransitionProbabilityMatrix() {
return this.transitionProbabilityMatrix;
}
/**
   * @param out the outcome vector.
   * @return the outcome state as integer that can be treated as an index.
*/
private int getOutcomeState(DoubleVector out) {
int index;
if (out.getDimension() == 2) {
index = (int) out.get(0); // simple cast is enough here
} else {
      // assume that exactly one state is set to 1 and take its index;
      // no other state may be turned on
index = out.maxIndex();
}
return index;
}
@Override
public void write(DataOutput out) throws IOException {
out.writeInt(numVisibleStates);
out.writeInt(numHiddenStates);
VectorWritable.writeVector(hiddenPriorProbability, out);
MatrixWritable.writeDenseMatrix(
(DenseDoubleMatrix) transitionProbabilityMatrix, out);
MatrixWritable.writeDenseMatrix(
(DenseDoubleMatrix) emissionProbabilityMatrix, out);
}
@Override
public void readFields(DataInput in) throws IOException {
numVisibleStates = in.readInt();
numHiddenStates = in.readInt();
hiddenPriorProbability = VectorWritable.readVector(in);
transitionProbabilityMatrix = MatrixWritable.readDenseMatrix(in);
emissionProbabilityMatrix = MatrixWritable.readDenseMatrix(in);
}
}