/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.topics;
import cc.mallet.types.*;
import cc.mallet.util.Randoms;
import java.util.Arrays;
import java.util.List;
import java.util.ArrayList;
import java.util.zip.*;
import java.io.*;
import java.text.NumberFormat;
import gnu.trove.*;
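/**
 * A hidden Markov model over sequences of documents. Each document is
 * represented by its counts of LDA topic assignments, and each hidden state
 * has its own Dirichlet-smoothed distribution over those topics.
 * @author David Mimno, Andrew McCallum
 */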
public class MultinomialHMM {
/** Number of LDA topics whose assignment counts describe each document. */
int numTopics;
/** Number of hidden HMM states. */
int numStates;
/** One past the largest document index seen in the topic file. */
int numDocs;
/** Number of document sequences (runs of consecutive documents sharing a sequence ID). */
int numSequences;

/** Per-topic Dirichlet prior on each state's distribution over topics. */
double[] alpha;
/** Sum of {@link #alpha}. */
double alphaSum;

/** Prior on per-topic multinomial distribution over words. NOTE(review): not read by visible code. */
double beta;
/** Sum counterpart of {@link #beta}. NOTE(review): not read by visible code. */
double betaSum;

/** Symmetric prior added to every state-to-state transition count. */
double gamma;
/** gamma * numStates, computed in initialize(). */
double gammaSum;

/** Prior pseudo-count added to each initial-state count. */
double pi;
/** numStates * pi. */
double sumPi;

/** Document index -> (topic -> number of tokens assigned to that topic). */
TIntObjectHashMap<TIntIntHashMap> documentTopics;
/** Sequence ID of each document, in document order. */
int[] documentSequenceIDs;
/** Current hidden state of each document. */
int[] documentStates;

/** stateTopicCounts[s][t]: tokens with topic t in documents currently assigned state s. */
int[][] stateTopicCounts;
/** Row totals of {@link #stateTopicCounts}. */
int[] stateTopicTotals;

/** stateStateTransitions[s][t]: number of consecutive document pairs in states s then t. */
int[][] stateStateTransitions;
/** Row totals of {@link #stateStateTransitions}. */
int[] stateTransitionTotals;

/** Number of sequences whose first document is in each state. */
int[] initialStateCounts;

/** Largest number of tokens any single document assigns to each topic; sizes the topic caches. */
int[] maxTokensPerTopic;

/** Largest document length in tokens; sizes the document-length cache. */
int maxDocLength;

/**
 * Cached log predictive terms, so log-gamma values need not be recomputed
 * for every state and topic: topicLogGammaCache[state][topic][count].
 */
double[][][] topicLogGammaCache;
/** Cached per-state log normalizers indexed by document length. */
double[][] docLogGammaCache;

/** Number of Gibbs sweeps performed by sample(). */
int numIterations = 1000;
int burninPeriod = 200;
int saveSampleInterval = 10;
int optimizeInterval = 0;
int showTopicsInterval = 50;

/** Display string for each topic, read from the alpha file. */
String[] topicKeys;

/** Source of randomness for sampling. */
Randoms random;

NumberFormat formatter;
/**
 * Creates a model with {@code numStates} hidden states over {@code numberOfTopics}
 * LDA topics and loads per-document topic counts from {@code topicsFilename}.
 *
 * <p>Loading the topics sets {@link #numDocs}, which sizes the per-document arrays.
 * The largest per-topic and per-document token counts found in the file then size
 * the log-gamma caches.
 *
 * @param numberOfTopics number of LDA topics; alpha is initialized to 1.0 per topic
 * @param topicsFilename LDA state file, optionally gzipped (see loadTopicsFromFile)
 * @param numStates number of hidden HMM states
 * @throws IOException if the topic file cannot be read or is malformed
 */
public MultinomialHMM (int numberOfTopics, String topicsFilename, int numStates) throws IOException {
    formatter = NumberFormat.getInstance();

    System.out.println("LDA HMM: " + numberOfTopics);

    documentTopics = new TIntObjectHashMap<TIntIntHashMap>();

    this.numTopics = numberOfTopics;
    this.alphaSum = numberOfTopics;
    this.alpha = new double[numberOfTopics];
    Arrays.fill(alpha, alphaSum / numTopics);

    topicKeys = new String[numTopics];

    // This fills documentTopics and initializes numDocs, which the arrays below depend on.
    loadTopicsFromFile(topicsFilename);

    documentStates = new int[ numDocs ];
    documentSequenceIDs = new int[ numDocs ];

    maxTokensPerTopic = new int[ numTopics ];
    maxDocLength = 0;

    for (int doc = 0; doc < numDocs; doc++) {
        TIntIntHashMap topicCounts = documentTopics.get(doc);
        if (topicCounts == null) {
            // Documents with no tokens have no entry.
            continue;
        }

        int docLength = 0;
        for (int topic : topicCounts.keys()) {
            int topicCount = topicCounts.get(topic);
            if (topicCount > maxTokensPerTopic[topic]) {
                maxTokensPerTopic[topic] = topicCount;
            }
            docLength += topicCount;
        }
        if (docLength > maxDocLength) {
            maxDocLength = docLength;
        }
    }

    this.numStates = numStates;
    this.initialStateCounts = new int[numStates];

    // Index c of each cache row holds the term for a document with c tokens of
    // that topic, so each row needs max + 1 entries.
    topicLogGammaCache = new double[numStates][numTopics][];
    for (int state = 0; state < numStates; state++) {
        for (int topic = 0; topic < numTopics; topic++) {
            topicLogGammaCache[state][topic] = new double[ maxTokensPerTopic[topic] + 1 ];
        }
    }
    System.out.println( maxDocLength );
    docLogGammaCache = new double[numStates][ maxDocLength + 1 ];
}
public void setGamma(double g) {
this.gamma = g;
public void setNumIterations (int numIterations) {
this.numIterations = numIterations;
public void setBurninPeriod (int burninPeriod) {
this.burninPeriod = burninPeriod;
public void setTopicDisplayInterval(int interval) {
this.showTopicsInterval = interval;
public void setRandomSeed(int seed) {
random = new Randoms(seed);
public void setOptimizeInterval(int interval) {
this.optimizeInterval = interval;
public void initialize () {
if (random == null) {
random = new Randoms();
gammaSum = gamma * numStates;
stateTopicCounts = new int[numStates][numTopics];
stateTopicTotals = new int[numStates];
stateStateTransitions = new int[numStates][numStates];
stateTransitionTotals = new int[numStates];
pi = 1000.0;
sumPi = numStates * pi;
int maxTokens = 0;
int totalTokens = 0;
numSequences = 0;
int sequenceID;
int currentSequenceID = -1;
// The code to cache topic distributions
// takes an int-int hashmap as a mask to only update
// the distributions for topics that have actually changed.
// Here we create a dummy count hash that has all the topics.
TIntIntHashMap allTopicsDummy = new TIntIntHashMap();
for (int topic = 0; topic < numTopics; topic++) {
allTopicsDummy.put(topic, 1);
for (int state=0; state < numStates; state++) {
recacheStateTopicDistribution(state, allTopicsDummy);
for (int doc = 0; doc < numDocs; doc++) {
sampleState(doc, random, true);
private void recacheStateTopicDistribution(int state, TIntIntHashMap topicCounts) {
int[] currentStateTopicCounts = stateTopicCounts[state];
double[][] currentStateCache = topicLogGammaCache[state];
double[] cache;
for (int topic: topicCounts.keys()) {
cache = currentStateCache[topic];
cache[0] = 0.0;
for (int i=1; i < cache.length; i++) {
cache[i] =
cache[ i-1 ] +
Math.log( alpha[topic] + i - 1 +
currentStateTopicCounts[topic] );
docLogGammaCache[state][0] = 0.0;
for (int i=1; i < docLogGammaCache[state].length; i++) {
docLogGammaCache[state][i] =
docLogGammaCache[state][ i-1 ] +
Math.log( alphaSum + i - 1 +
stateTopicTotals[state] );
public void sample() throws IOException {
long startTime = System.currentTimeMillis();
for (int iterations = 1; iterations <= numIterations; iterations++) {
long iterationStart = System.currentTimeMillis();
//System.out.println (printStateTransitions());
for (int doc = 0; doc < numDocs; doc++) {
sampleState (doc, random, false);
//if (doc % 10000 == 0) { System.out.println (printStateTransitions()); }
System.out.print((System.currentTimeMillis() - iterationStart) + " ");
if (iterations % 10 == 0) {
System.out.println ("<" + iterations + "> ");
PrintWriter out =
new PrintWriter(new BufferedWriter(new FileWriter("state_state_matrix." + iterations)));
out = new PrintWriter(new BufferedWriter(new FileWriter("state_topics." + iterations)));
if (iterations % 10 == 0) {
out = new PrintWriter(new BufferedWriter(new FileWriter("states." + iterations)));
for (int doc = 0; doc < documentStates.length; doc++) {
long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
long minutes = seconds / 60; seconds %= 60;
long hours = minutes / 60; minutes %= 60;
long days = hours / 24; hours %= 24;
System.out.print ("\nTotal time: ");
if (days != 0) { System.out.print(days); System.out.print(" days "); }
if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
System.out.print(seconds); System.out.println(" seconds");
public void loadTopicsFromFile(String stateFilename) throws IOException {
BufferedReader in;
if (stateFilename.endsWith(".gz")) {
in = new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(stateFilename))));
else {
in = new BufferedReader(new FileReader(new File(stateFilename)));
numDocs = 0;
String line = null;
while ((line = in.readLine()) != null) {
if (line.startsWith("#")) {
String[] fields = line.split(" ");
int doc = Integer.parseInt(fields[0]);
int token = Integer.parseInt(fields[1]);
int type = Integer.parseInt(fields[2]);
int topic = Integer.parseInt(fields[4]);
// Now add the new topic
if (! documentTopics.containsKey(doc)) {
documentTopics.put(doc, new TIntIntHashMap());
if (documentTopics.get(doc).containsKey(topic)) {
else {
documentTopics.get(doc).put(topic, 1);
if (doc >= numDocs) { numDocs = doc + 1; }
System.out.println("loaded topics, " + numDocs + " documents");
public void loadAlphaFromFile(String alphaFilename) throws IOException {
// Now restore the saved alpha parameters
alphaSum = 0.0;
BufferedReader in = new BufferedReader(new FileReader(new File(alphaFilename)));
String line = null;
while ((line = in.readLine()) != null) {
if (line.equals("")) { continue; }
String[] fields = line.split("\\s+");
int topic = Integer.parseInt(fields[0]);
alpha[topic] = 1.0; // Double.parseDouble(fields[1]);
alphaSum += alpha[topic];
StringBuffer topicKey = new StringBuffer();
for (int i=2; i<fields.length; i++) {
topicKey.append(fields[i] + " ");
topicKeys[topic] = topicKey.toString();
System.out.println("loaded alpha");
public void loadStatesFromFile(String stateFilename) throws IOException {
int doc = 0;
int state;
BufferedReader in = new BufferedReader(new FileReader(new File(stateFilename)));
String line = null;
while ((line = in.readLine()) != null) {
// We assume that the sequences are in the instance list
// in order.
state = Integer.parseInt(line);
documentStates[doc] = state;
// Additional bookkeeping will be performed when we load sequence IDs,
// so states MUST be loaded before sequences.
System.out.println("loaded states");
public void loadSequenceIDsFromFile(String sequenceFilename) throws IOException {
int doc = 0;
int sequenceID;
int currentSequenceID = -1;
BufferedReader in = new BufferedReader(new FileReader(new File(sequenceFilename)));
String line = null;
while ((line = in.readLine()) != null) {
// We assume that the sequences are in the instance list
// in order.
String[] fields = line.split("\\t");
sequenceID = Integer.parseInt(fields[0]);
documentSequenceIDs[doc] = sequenceID;
if (sequenceID != currentSequenceID) {
numSequences ++;
currentSequenceID = sequenceID;
if (doc != numDocs) { System.out.println("Warning: number of documents with topics (" + numDocs + ") is not equal to number of docs with sequence IDs (" + doc + ")"); }
System.out.println("loaded sequence");
private void sampleState (int doc, Randoms r, boolean initializing) {
if (doc % 10000 == 0) {
if (initializing) {
System.out.println("initializing doc " + doc);
else {
System.out.println("sampling doc " + doc);
long startTime = System.currentTimeMillis();
// It's possible this document contains no words,
// in which case it has no topics, and no entry in the
// documentTopics hash.
if (! documentTopics.containsKey(doc)) { return; }
TIntIntHashMap topicCounts = documentTopics.get(doc);
// if we are in initializing mode, this is meaningless,
// but it won't hurt.
int oldState = documentStates[doc];
int[] currentStateTopicCounts = stateTopicCounts[oldState];
// Look at the document features (topics).
// If we're not in initializing mode, reduce the topic counts
// of the current (old) state.
int docLength = 0;
for (int topic: topicCounts.keys()) {
int topicCount = topicCounts.get(topic);
if (! initializing) {
currentStateTopicCounts[topic] -= topicCount;
docLength += topicCount;
if (! initializing) {
stateTopicTotals[oldState] -= docLength;
recacheStateTopicDistribution(oldState, topicCounts);
int previousSequenceID = -1;
if (doc > 0) {
previousSequenceID = documentSequenceIDs[ doc-1 ];
int sequenceID = documentSequenceIDs[ doc ];
int nextSequenceID = -1;
if (! initializing &&
doc < numDocs - 1) {
nextSequenceID = documentSequenceIDs[ doc+1 ];
double[] stateLogLikelihoods = new double[numStates];
double[] samplingDistribution = new double[numStates];
int nextState, previousState;
if (initializing) {
// Initializing the states is the same as sampling them,
// but we only look at the previous state and we don't decrement
// any counts.
if (previousSequenceID != sequenceID) {
// New sequence, start from scratch
for (int state = 0; state < numStates; state++) {
stateLogLikelihoods[state] = Math.log( (initialStateCounts[state] + pi) /
(numSequences - 1 + sumPi) );
else {
// Continuation
previousState = documentStates[ doc-1 ];
for (int state = 0; state < numStates; state++) {
stateLogLikelihoods[state] = Math.log( stateStateTransitions[previousState][state] + gamma );
if (Double.isInfinite(stateLogLikelihoods[state])) {
System.out.println("infinite end");
else {
// There are four cases:
if (previousSequenceID != sequenceID && sequenceID != nextSequenceID) {
// 1. This is a singleton document
for (int state = 0; state < numStates; state++) {
stateLogLikelihoods[state] = Math.log( (initialStateCounts[state] + pi) /
(numSequences - 1 + sumPi) );
else if (previousSequenceID != sequenceID) {
// 2. This is the beginning of a sequence
nextState = documentStates[doc+1];
assert(stateStateTransitions[oldState][nextState] >= 0);
for (int state = 0; state < numStates; state++) {
stateLogLikelihoods[state] = Math.log( (stateStateTransitions[state][nextState] + gamma) *
(initialStateCounts[state] + pi) /
(numSequences - 1 + sumPi) );
if (Double.isInfinite(stateLogLikelihoods[state])) {
System.out.println("infinite beginning");
else if (sequenceID != nextSequenceID) {
// 3. This is the end of a sequence
previousState = documentStates[doc-1];
assert(stateStateTransitions[previousState][oldState] >= 0);
for (int state = 0; state < numStates; state++) {
stateLogLikelihoods[state] = Math.log( stateStateTransitions[previousState][state] + gamma );
if (Double.isInfinite(stateLogLikelihoods[state])) {
System.out.println("infinite end");
else {
// 4. This is the middle of a sequence
nextState = documentStates[doc+1];
if (stateStateTransitions[oldState][nextState] < 0) {
System.out.println(oldState + " -> " + nextState);
assert (stateStateTransitions[oldState][nextState] >= 0);
previousState = documentStates[doc-1];
assert(stateStateTransitions[previousState][oldState] >= 0);
for (int state = 0; state < numStates; state++) {
if (previousState == state && state == nextState) {
stateLogLikelihoods[state] =
Math.log( (stateStateTransitions[previousState][state] + gamma) *
(stateStateTransitions[state][nextState] + 1 + gamma) /
(stateTransitionTotals[state] + 1 + gammaSum) );
else if (previousState == state) {
stateLogLikelihoods[state] =
Math.log( (stateStateTransitions[previousState][state] + gamma) *
(stateStateTransitions[state][nextState] + gamma) /
(stateTransitionTotals[state] + 1 + gammaSum) );
else {
stateLogLikelihoods[state] =
Math.log( (stateStateTransitions[previousState][state] + gamma) *
(stateStateTransitions[state][nextState] + gamma) /
(stateTransitionTotals[state] + gammaSum) );
if (Double.isInfinite(stateLogLikelihoods[state])) {
System.out.println("infinite middle: " + doc);
System.out.println(previousState + " -> " +
state + " -> " + nextState);
System.out.println(stateStateTransitions[previousState][state] + " -> " +
stateStateTransitions[state][nextState] + " / " +
double max = Double.NEGATIVE_INFINITY;
for (int state = 0; state < numStates; state++) {
stateLogLikelihoods[state] -= stateTransitionTotals[state] / 10;
currentStateTopicCounts = stateTopicCounts[state];
double[][] currentStateLogGammaCache = topicLogGammaCache[state];
int totalTokens = 0;
for (int topic: topicCounts.keys()) {
int count = topicCounts.get(topic);
// Cached Sampling Distribution
stateLogLikelihoods[state] += currentStateLogGammaCache[topic][count];
// Hybrid version
if (count < currentStateLogGammaCache[topic].length) {
stateLogLikelihoods[state] += currentStateLogGammaCache[topic][count];
else {
int i = currentStateLogGammaCache[topic].length - 1;
stateLogLikelihoods[state] +=
currentStateLogGammaCache[topic][ i ];
for (; i < count; i++) {
stateLogLikelihoods[state] +=
Math.log(alpha[topic] + currentStateTopicCounts[topic] + i);
for (int j=0; j < count; j++) {
stateLogLikelihoods[state] +=
Math.log( (alpha[topic] + currentStateTopicCounts[topic] + j) /
(alphaSum + stateTopicTotals[state] + totalTokens) );
if (Double.isNaN(stateLogLikelihoods[state])) {
System.out.println("NaN: " + alpha[topic] + " + " +
currentStateTopicCounts[topic] + " + " +
j + ") /\n" +
"(" + alphaSum + " + " +
stateTopicTotals[state] + " + " + totalTokens);
// Cached Sampling Distribution
stateLogLikelihoods[state] -= docLogGammaCache[state][ docLength ];
// Hybrid version
if (docLength < docLogGammaCache[state].length) {
stateLogLikelihoods[state] -= docLogGammaCache[state][docLength];
else {
int i = docLogGammaCache[state].length - 1;
stateLogLikelihoods[state] -=
docLogGammaCache[state][ i ];
for (; i < docLength; i++) {
stateLogLikelihoods[state] -=
Math.log(alphaSum + stateTopicTotals[state] + i);
if (stateLogLikelihoods[state] > max) {
max = stateLogLikelihoods[state];
double sum = 0.0;
for (int state = 0; state < numStates; state++) {
if (Double.isNaN(samplingDistribution[state])) {
assert(! Double.isNaN(samplingDistribution[state]));
samplingDistribution[state] =
Math.exp(stateLogLikelihoods[state] - max);
sum += samplingDistribution[state];
if (Double.isNaN(samplingDistribution[state])) {
assert(! Double.isNaN(samplingDistribution[state]));
if (doc % 100 == 0) {
int newState = r.nextDiscrete(samplingDistribution, sum);
documentStates[doc] = newState;
for (int topic = 0; topic < numTopics; topic++) {
stateTopicCounts[newState][topic] += topicCounts.get(topic);
stateTopicTotals[newState] += docLength;
recacheStateTopicDistribution(newState, topicCounts);
if (initializing) {
// If we're initializing the states, don't bother
// looking at the next state.
if (previousSequenceID != sequenceID) {
else {
previousState = documentStates[doc-1];
else {
if (previousSequenceID != sequenceID && sequenceID != nextSequenceID) {
// 1. This is a singleton document
else if (previousSequenceID != sequenceID) {
// 2. This is the beginning of a sequence
nextState = documentStates[doc+1];
else if (sequenceID != nextSequenceID) {
// 3. This is the end of a sequence
previousState = documentStates[doc-1];
else {
// 4. This is the middle of a sequence
previousState = documentStates[doc-1];
nextState = documentStates[doc+1];
public String printStateTransitions() {
StringBuffer out = new StringBuffer();
IDSorter[] sortedTopics = new IDSorter[numTopics];
for (int s = 0; s < numStates; s++) {
for (int topic=0; topic<numTopics; topic++) {
sortedTopics[topic] = new IDSorter(topic, (double) stateTopicCounts[s][topic] / stateTopicTotals[s]);
out.append("\n" + s + "\n");
for (int i=0; i<4; i++) {
int topic = sortedTopics[i].getID();
out.append(stateTopicCounts[s][topic] + "\t" + topicKeys[topic] + "\n");
out.append("[" + initialStateCounts[s] + "/" + numSequences + "] ");
out.append("[" + stateTransitionTotals[s] + "]");
for (int t = 0; t < numStates; t++) {
if (s == t) {
out.append("[" + stateStateTransitions[s][t] + "]");
else {
return out.toString();
public String stateTransitionMatrix() {
StringBuffer out = new StringBuffer();
for (int s = 0; s < numStates; s++) {
for (int t = 0; t < numStates; t++) {
return out.toString();
public String stateTopics() {
StringBuffer out = new StringBuffer();
for (int s = 0; s < numStates; s++) {
for (int topic=0; topic<numTopics; topic++) {
out.append(stateTopicCounts[s][topic] + "\t");
return out.toString();
public static void main (String[] args) throws IOException {
if (args.length != 4) {
System.err.println("Usage: MultinomialHMM [num topics] [lda state file] [lda keys file] [sequence metadata file]");
int numTopics = Integer.parseInt(args[0]);
MultinomialHMM hmm =
new MultinomialHMM (numTopics, args[1], 150);