/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.topics;
import cc.mallet.types.*;
import cc.mallet.util.Randoms;
import java.util.Arrays;
import java.util.List;
import java.util.ArrayList;
import java.util.zip.*;
import java.io.*;
import java.text.NumberFormat;
import gnu.trove.*;
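/**
* A hidden Markov model over documents. Each document is represented by its
* LDA topic counts, each hidden state has a Dirichlet-multinomial distribution
* over topics, and document states are resampled one at a time.
* @author David Mimno, Andrew McCallum
*/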
public class MultinomialHMM {
int numTopics; // Number of topics to be fit
int numStates; // Number of hidden states
// One more than the largest document index read by loadTopicsFromFile
int numDocs;
// Number of sequence-ID changes counted by loadSequenceIDsFromFile
int numSequences;
// Dirichlet(alpha,alpha,...) is the distribution over topics
double[] alpha;
double alphaSum;
// Prior on per-topic multinomial distribution over words
// (beta and betaSum are not read by any code visible in this file)
double beta;
double betaSum;
// Prior on the state-state transition distributions;
// gammaSum is set to gamma * numStates by initialize()
double gamma;
double gammaSum;
// Pseudo-count added to each initial-state count;
// sumPi is set to numStates * pi by initialize()
double pi;
double sumPi;
// Maps document index -> (topic -> number of tokens with that topic).
// Documents with no tokens have no entry.
TIntObjectHashMap<TIntIntHashMap> documentTopics;
// Sequence ID of each document, indexed by document
int[] documentSequenceIDs;
// Currently assigned hidden state of each document
int[] documentStates;
// [state][topic] token counts, summed over the documents assigned to the state
int[][] stateTopicCounts;
// Per-state sum of stateTopicCounts over all topics
int[] stateTopicTotals;
// [from][to] counts of transitions between the states of adjacent
// documents in the same sequence
int[][] stateStateTransitions;
// Per-state transition totals, adjusted by sampleState alongside
// stateStateTransitions
int[] stateTransitionTotals;
// Number of sequences whose first document is assigned to each state
int[] initialStateCounts;
// Keep track of the most times each topic is
// used in any document
int[] maxTokensPerTopic;
// The size of the largest document
int maxDocLength;
// Rather than calculating log gammas for every state and every topic
// we cache log predictive distributions for every possible state
// and document.
// topicLogGammaCache[state][topic][n] is the sum over i < n of
// log(alpha[topic] + stateTopicCounts[state][topic] + i).
double[][][] topicLogGammaCache;
// docLogGammaCache[state][n] is the sum over i < n of
// log(alphaSum + stateTopicTotals[state] + i).
double[][] docLogGammaCache;
// Number of sampling sweeps performed by sample()
int numIterations = 1000;
// The four settings below are stored but not read by any code
// visible in this file.
int burninPeriod = 200;
int saveSampleInterval = 10;
int optimizeInterval = 0;
int showTopicsInterval = 50;
// Human-readable key words for each topic, filled in by loadAlphaFromFile
String[] topicKeys;
// Source of randomness for sampling; created lazily by initialize()
// unless setRandomSeed was called first
Randoms random;
NumberFormat formatter;
/**
 * Builds a model from a saved per-token topic assignment file.
 *
 * Loads the topic counts of every document, then sizes the per-document
 * arrays and the log-gamma caches from what was loaded: each
 * (state, topic) cache is long enough for the largest count of that topic
 * in any single document, and each per-state document cache is long
 * enough for the longest document.
 *
 * @param numberOfTopics number of topics used in the topic state file
 * @param topicsFilename file passed to {@link #loadTopicsFromFile(String)}
 * @param numStates number of hidden states
 * @throws IOException if the topic state file cannot be read
 */
public MultinomialHMM (int numberOfTopics, String topicsFilename, int numStates) throws IOException {
    formatter = NumberFormat.getInstance();
    formatter.setMaximumFractionDigits(5);

    System.out.println("LDA HMM: " + numberOfTopics);

    documentTopics = new TIntObjectHashMap<TIntIntHashMap>();

    // Symmetric prior: every alpha starts at alphaSum / numTopics.
    numTopics = numberOfTopics;
    alphaSum = numberOfTopics;
    alpha = new double[numberOfTopics];
    Arrays.fill(alpha, alphaSum / numTopics);
    topicKeys = new String[numTopics];

    // This initializes numDocs as well
    loadTopicsFromFile(topicsFilename);

    documentStates = new int[numDocs];
    documentSequenceIDs = new int[numDocs];
    maxTokensPerTopic = new int[numTopics];
    maxDocLength = 0;

    // Find the largest per-document count of each topic and the
    // longest document; both determine cache sizes below.
    for (int doc = 0; doc < numDocs; doc++) {
        TIntIntHashMap countsForDoc = documentTopics.get(doc);
        if (countsForDoc == null) {
            // Empty documents have no entry.
            continue;
        }

        int[] topicsInDoc = countsForDoc.keys();
        int tokensInDoc = 0;
        for (int k = 0; k < topicsInDoc.length; k++) {
            int topic = topicsInDoc[k];
            int n = countsForDoc.get(topic);
            maxTokensPerTopic[topic] = Math.max(maxTokensPerTopic[topic], n);
            tokensInDoc += n;
        }
        maxDocLength = Math.max(maxDocLength, tokensInDoc);
    }

    this.numStates = numStates;
    initialStateCounts = new int[numStates];

    topicLogGammaCache = new double[numStates][numTopics][];
    for (int s = 0; s < numStates; s++) {
        double[][] cachesForState = topicLogGammaCache[s];
        for (int t = 0; t < numTopics; t++) {
            cachesForState[t] = new double[maxTokensPerTopic[t] + 1];
        }
    }

    System.out.println(maxDocLength);
    docLogGammaCache = new double[numStates][maxDocLength + 1];
}
/**
 * Sets gamma, the prior on the state-state transition distributions.
 * initialize() derives gammaSum from this value, so call this before
 * initialize().
 *
 * @param g the new value of gamma
 */
public void setGamma(double g) {
this.gamma = g;
}
/**
 * Sets the number of sweeps over all documents performed by sample().
 *
 * @param numIterations the number of sampling iterations
 */
public void setNumIterations (int numIterations) {
this.numIterations = numIterations;
}
/**
 * Stores the burn-in period. No code visible in this file reads the
 * stored value.
 *
 * @param burninPeriod the burn-in period, presumably in iterations
 */
public void setBurninPeriod (int burninPeriod) {
this.burninPeriod = burninPeriod;
}
/**
 * Stores the topic display interval in showTopicsInterval. No code
 * visible in this file reads the stored value.
 *
 * @param interval the display interval, presumably in iterations
 */
public void setTopicDisplayInterval(int interval) {
this.showTopicsInterval = interval;
}
/**
 * Replaces the random number generator with a new Randoms built from
 * the given seed. If this is never called, initialize() creates a
 * Randoms with no explicit seed.
 *
 * @param seed the seed passed to the Randoms constructor
 */
public void setRandomSeed(int seed) {
random = new Randoms(seed);
}
/**
 * Stores the optimization interval. No code visible in this file reads
 * the stored value.
 *
 * @param interval the optimization interval, presumably in iterations
 */
public void setOptimizeInterval(int interval) {
this.optimizeInterval = interval;
}
/**
 * Resets all state-level counts and assigns an initial state to every
 * document by sampling documents in order, each conditioned only on the
 * documents before it.
 *
 * Sequence IDs should be loaded (loadSequenceIDsFromFile) and gamma set
 * (setGamma) before calling this, because both are read here and during
 * the initial sampling pass. The sequence count recorded while loading
 * sequence IDs is left untouched.
 */
public void initialize () {
    if (random == null) {
        random = new Randoms();
    }

    gammaSum = gamma * numStates;

    stateTopicCounts = new int[numStates][numTopics];
    stateTopicTotals = new int[numStates];
    stateStateTransitions = new int[numStates][numStates];
    stateTransitionTotals = new int[numStates];
    // Allocated once in the constructor, so clear it explicitly to keep
    // it consistent with the freshly allocated count arrays above when
    // initialize() is called more than once.
    Arrays.fill(initialStateCounts, 0);

    pi = 1000.0;
    sumPi = numStates * pi;

    // Note: numSequences is deliberately NOT reset here. It is counted
    // by loadSequenceIDsFromFile, which runs before initialize(), and
    // nothing in this method recomputes it.

    // The code to cache topic distributions
    // takes an int-int hashmap as a mask to only update
    // the distributions for topics that have actually changed.
    // Here we create a dummy count hash that has all the topics.
    TIntIntHashMap allTopicsDummy = new TIntIntHashMap();
    for (int topic = 0; topic < numTopics; topic++) {
        allTopicsDummy.put(topic, 1);
    }

    for (int state = 0; state < numStates; state++) {
        recacheStateTopicDistribution(state, allTopicsDummy);
    }

    for (int doc = 0; doc < numDocs; doc++) {
        sampleState(doc, random, true);
    }
}
/**
 * Rebuilds the cached cumulative log terms for one state.
 *
 * For every topic that is a key of topicCounts, entry n of that topic's
 * cache becomes the sum over i &lt; n of
 * log(alpha[topic] + stateTopicCounts[state][topic] + i). The per-state
 * document cache is always rebuilt in full from alphaSum and
 * stateTopicTotals[state].
 *
 * @param state the state whose caches are rebuilt
 * @param topicCounts only its keys are used, as the set of topics to refresh
 */
private void recacheStateTopicDistribution(int state, TIntIntHashMap topicCounts) {
    final int[] countsForState = stateTopicCounts[state];
    final double[][] cachesForState = topicLogGammaCache[state];

    final int[] topicsToRefresh = topicCounts.keys();
    for (int k = 0; k < topicsToRefresh.length; k++) {
        final int topic = topicsToRefresh[k];
        final double[] cumulative = cachesForState[topic];

        cumulative[0] = 0.0;
        int n = 1;
        while (n < cumulative.length) {
            cumulative[n] = cumulative[n - 1]
                + Math.log( alpha[topic] + n - 1 + countsForState[topic] );
            n++;
        }
    }

    final double[] docCumulative = docLogGammaCache[state];
    docCumulative[0] = 0.0;
    int n = 1;
    while (n < docCumulative.length) {
        docCumulative[n] = docCumulative[n - 1]
            + Math.log( alphaSum + n - 1 + stateTopicTotals[state] );
        n++;
    }
}
/**
 * Runs numIterations sampling sweeps, resampling the state of every
 * document once per sweep.
 *
 * The time taken by each sweep (in milliseconds) is printed to standard
 * output. After every 10th sweep three files are written to the current
 * working directory, each suffixed with the iteration number:
 * "state_state_matrix.", "state_topics." and "states." (one state per
 * line, in document order). The total elapsed time is printed at the end.
 *
 * @throws IOException if one of the output files cannot be opened
 */
public void sample() throws IOException {
    long startTime = System.currentTimeMillis();

    for (int iterations = 1; iterations <= numIterations; iterations++) {
        long iterationStart = System.currentTimeMillis();

        for (int doc = 0; doc < numDocs; doc++) {
            sampleState (doc, random, false);
        }

        System.out.print((System.currentTimeMillis() - iterationStart) + " ");

        if (iterations % 10 == 0) {
            System.out.println ("<" + iterations + "> ");

            writeTextFile("state_state_matrix." + iterations, stateTransitionMatrix());
            writeTextFile("state_topics." + iterations, stateTopics());

            PrintWriter out =
                new PrintWriter(new BufferedWriter(new FileWriter("states." + iterations)));
            try {
                for (int doc = 0; doc < documentStates.length; doc++) {
                    out.println(documentStates[doc]);
                }
            }
            finally {
                // Close even if writing fails so the file handle is not leaked.
                out.close();
            }
        }
        System.out.flush();
    }

    long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
    long minutes = seconds / 60; seconds %= 60;
    long hours = minutes / 60; minutes %= 60;
    long days = hours / 24; hours %= 24;

    System.out.print ("\nTotal time: ");
    if (days != 0) { System.out.print(days); System.out.print(" days "); }
    if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
    if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
    System.out.print(seconds); System.out.println(" seconds");
}

/** Writes contents to the named file, always closing the writer. */
private static void writeTextFile(String filename, String contents) throws IOException {
    PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter(filename)));
    try {
        out.print(contents);
    }
    finally {
        out.close();
    }
}
/**
 * Reads per-token topic assignments and accumulates them into
 * documentTopics as per-document topic counts.
 *
 * Each data line is split on single spaces; field 0 is the document
 * index and field 4 is the topic. Lines starting with "#" and empty
 * lines are skipped. numDocs is set to one more than the largest
 * document index seen. A file name ending in ".gz" is read through a
 * GZIP stream.
 *
 * @param stateFilename the topic state file, optionally gzipped
 * @throws IOException if the file cannot be read, a data line has fewer
 *         than five fields, or a document or topic index is out of range
 * @throws NumberFormatException if the document or topic field is not an integer
 */
public void loadTopicsFromFile(String stateFilename) throws IOException {
    InputStream raw = new FileInputStream(stateFilename);
    BufferedReader in = null;

    try {
        if (stateFilename.endsWith(".gz")) {
            // If this throws, raw still refers to the file stream and is
            // closed in the finally block.
            raw = new GZIPInputStream(raw);
        }
        in = new BufferedReader(new InputStreamReader(raw));

        numDocs = 0;

        int lineNumber = 0;
        String line = null;
        while ((line = in.readLine()) != null) {
            lineNumber++;
            if (line.length() == 0 || line.startsWith("#")) {
                continue;
            }

            String[] fields = line.split(" ");
            if (fields.length < 5) {
                throw new IOException(stateFilename + ", line " + lineNumber +
                                      ": expected at least 5 fields but found " + fields.length);
            }

            int doc = Integer.parseInt(fields[0]);
            int topic = Integer.parseInt(fields[4]);

            if (doc < 0) {
                throw new IOException(stateFilename + ", line " + lineNumber +
                                      ": negative document index " + doc);
            }
            // Topics index arrays of length numTopics later on, so reject
            // bad values here where the offending line can be reported.
            if (topic < 0 || topic >= numTopics) {
                throw new IOException(stateFilename + ", line " + lineNumber +
                                      ": topic " + topic + " is outside [0, " + numTopics + ")");
            }

            // Now add the new topic
            TIntIntHashMap topicCounts = documentTopics.get(doc);
            if (topicCounts == null) {
                topicCounts = new TIntIntHashMap();
                documentTopics.put(doc, topicCounts);
            }

            if (topicCounts.containsKey(topic)) {
                topicCounts.increment(topic);
            }
            else {
                topicCounts.put(topic, 1);
            }

            if (doc >= numDocs) { numDocs = doc + 1; }
        }
    }
    finally {
        if (in != null) {
            in.close();
        }
        else {
            raw.close();
        }
    }

    System.out.println("loaded topics, " + numDocs + " documents");
}
/**
 * Reads the topic keys file: one topic per line, split on whitespace,
 * with the topic index in field 0 and the topic's key words from
 * field 2 onward. The key words are stored in topicKeys.
 *
 * The alpha value in field 1 is NOT used: every listed topic gets
 * alpha 1.0 (the parse of field 1 is commented out in the original
 * code, which looks intentional). alphaSum is then recomputed as the
 * sum of the whole alpha array, so it stays consistent even if a topic
 * is listed twice or not at all.
 *
 * @param alphaFilename the topic keys file
 * @throws IOException if the file cannot be read or a topic index is out of range
 * @throws NumberFormatException if the topic field is not an integer
 */
public void loadAlphaFromFile(String alphaFilename) throws IOException {
    BufferedReader in = new BufferedReader(new FileReader(new File(alphaFilename)));

    try {
        int lineNumber = 0;
        String line = null;
        while ((line = in.readLine()) != null) {
            lineNumber++;
            // Trim so leading whitespace does not produce an empty first field.
            line = line.trim();
            if (line.length() == 0) { continue; }

            String[] fields = line.split("\\s+");

            int topic = Integer.parseInt(fields[0]);
            if (topic < 0 || topic >= alpha.length) {
                throw new IOException(alphaFilename + ", line " + lineNumber +
                                      ": topic " + topic + " is outside [0, " + alpha.length + ")");
            }

            alpha[topic] = 1.0; // Double.parseDouble(fields[1]);

            StringBuilder topicKey = new StringBuilder();
            for (int i = 2; i < fields.length; i++) {
                topicKey.append(fields[i]).append(' ');
            }
            topicKeys[topic] = topicKey.toString();
        }
    }
    finally {
        in.close();
    }

    // alphaSum normalizes the per-state document cache, so it must equal
    // the sum of all alphas, not the number of lines read.
    alphaSum = 0.0;
    for (int topic = 0; topic < alpha.length; topic++) {
        alphaSum += alpha[topic];
    }

    System.out.println("loaded alpha");
}
/*
public void loadStatesFromFile(String stateFilename) throws IOException {
int doc = 0;
int state;
BufferedReader in = new BufferedReader(new FileReader(new File(stateFilename)));
String line = null;
while ((line = in.readLine()) != null) {
// We assume that the sequences are in the instance list
// in order.
state = Integer.parseInt(line);
documentStates[doc] = state;
// Additional bookkeeping will be performed when we load sequence IDs,
// so states MUST be loaded before sequences.
doc++;
}
in.close();
System.out.println("loaded states");
}
*/
/**
 * Reads one sequence ID per document, in document order, from a
 * tab-separated file whose first field is the integer sequence ID.
 *
 * Consecutive documents with the same ID belong to the same sequence.
 * numSequences is recounted from scratch on every call as the number of
 * times the ID changes. If the file has more lines than there are
 * documents, the extra lines are counted for the warning but otherwise
 * ignored.
 *
 * @param sequenceFilename the sequence metadata file
 * @throws IOException if the file cannot be read
 * @throws NumberFormatException if a sequence ID is not an integer
 */
public void loadSequenceIDsFromFile(String sequenceFilename) throws IOException {
    int doc = 0;
    int sequenceID;
    int currentSequenceID = -1;

    // Recount rather than accumulate, so loading twice does not double the total.
    numSequences = 0;

    BufferedReader in = new BufferedReader(new FileReader(new File(sequenceFilename)));
    try {
        String line = null;
        while ((line = in.readLine()) != null) {
            // We assume that the sequences are in the instance list
            // in order.
            String[] fields = line.split("\\t");
            sequenceID = Integer.parseInt(fields[0]);

            // Guard the array write so a longer-than-expected file reaches
            // the warning below instead of failing with an index error.
            if (doc < documentSequenceIDs.length) {
                documentSequenceIDs[doc] = sequenceID;
                if (sequenceID != currentSequenceID) {
                    numSequences ++;
                }
                currentSequenceID = sequenceID;
            }

            doc++;
        }
    }
    finally {
        in.close();
    }

    if (doc != numDocs) { System.out.println("Warning: number of documents with topics (" + numDocs + ") is not equal to number of docs with sequence IDs (" + doc + ")"); }

    System.out.println("loaded sequence");
}
/**
 * Samples a new hidden state for one document and updates all counts.
 *
 * The document's topic counts are first removed from its current state
 * (unless initializing), along with the transitions into and out of it.
 * Each candidate state is then scored by a transition term, which
 * depends on where the document sits in its sequence, plus an emission
 * term read from the log-gamma caches. A state is drawn in proportion
 * to the exponentiated scores and the counts are added back for it.
 *
 * In initializing mode only the previous document is considered and
 * nothing is decremented, because the document has not been counted yet.
 *
 * Documents with no entry in documentTopics are skipped entirely.
 * NOTE(review): neighbors of a skipped document still read its
 * documentStates entry and adjust transitions involving it, although
 * the initializing pass never counted those transitions -- confirm
 * whether empty documents can occur inside a sequence.
 *
 * @param doc index of the document to resample
 * @param r source of randomness for the draw
 * @param initializing true during the first pass made by initialize()
 */
private void sampleState (int doc, Randoms r, boolean initializing) {
    // It's possible this document contains no words,
    // in which case it has no topics, and no entry in the
    // documentTopics hash.
    if (! documentTopics.containsKey(doc)) { return; }

    TIntIntHashMap topicCounts = documentTopics.get(doc);
    int[] docTopics = topicCounts.keys();

    // If we are in initializing mode, this is meaningless,
    // but it won't hurt.
    int oldState = documentStates[doc];
    int[] oldStateTopicCounts = stateTopicCounts[oldState];

    // Look at the document features (topics).
    // If we're not in initializing mode, reduce the topic counts
    // of the current (old) state.
    int docLength = 0;
    for (int topic: docTopics) {
        int topicCount = topicCounts.get(topic);
        if (! initializing) {
            oldStateTopicCounts[topic] -= topicCount;
        }
        docLength += topicCount;
    }

    if (! initializing) {
        stateTopicTotals[oldState] -= docLength;
        recacheStateTopicDistribution(oldState, topicCounts);
    }

    int previousSequenceID = -1;
    if (doc > 0) {
        previousSequenceID = documentSequenceIDs[ doc-1 ];
    }

    int sequenceID = documentSequenceIDs[ doc ];

    // The next document is only consulted when resampling.
    int nextSequenceID = -1;
    if (! initializing && doc < numDocs - 1) {
        nextSequenceID = documentSequenceIDs[ doc+1 ];
    }

    boolean startsSequence = (previousSequenceID != sequenceID);
    boolean endsSequence = (sequenceID != nextSequenceID);

    double[] stateLogLikelihoods = new double[numStates];
    double[] samplingDistribution = new double[numStates];

    int nextState, previousState;

    if (initializing) {
        // Initializing the states is the same as sampling them,
        // but we only look at the previous state and we don't decrement
        // any counts.
        if (startsSequence) {
            // New sequence, start from scratch
            for (int state = 0; state < numStates; state++) {
                stateLogLikelihoods[state] = Math.log( (initialStateCounts[state] + pi) /
                                                       (numSequences - 1 + sumPi) );
            }
        }
        else {
            // Continuation
            previousState = documentStates[ doc-1 ];
            for (int state = 0; state < numStates; state++) {
                stateLogLikelihoods[state] = Math.log( stateStateTransitions[previousState][state] + gamma );
                if (Double.isInfinite(stateLogLikelihoods[state])) {
                    System.out.println("infinite end");
                }
            }
        }
    }
    else {
        // There are four cases:
        if (startsSequence && endsSequence) {
            // 1. This is a singleton document
            initialStateCounts[oldState]--;

            for (int state = 0; state < numStates; state++) {
                stateLogLikelihoods[state] = Math.log( (initialStateCounts[state] + pi) /
                                                       (numSequences - 1 + sumPi) );
            }
        }
        else if (startsSequence) {
            // 2. This is the beginning of a sequence
            initialStateCounts[oldState]--;

            nextState = documentStates[doc+1];
            stateStateTransitions[oldState][nextState]--;
            assert(stateStateTransitions[oldState][nextState] >= 0);
            stateTransitionTotals[oldState]--;

            for (int state = 0; state < numStates; state++) {
                stateLogLikelihoods[state] = Math.log( (stateStateTransitions[state][nextState] + gamma) *
                                                       (initialStateCounts[state] + pi) /
                                                       (numSequences - 1 + sumPi) );
                if (Double.isInfinite(stateLogLikelihoods[state])) {
                    System.out.println("infinite beginning");
                }
            }
        }
        else if (endsSequence) {
            // 3. This is the end of a sequence
            previousState = documentStates[doc-1];
            stateStateTransitions[previousState][oldState]--;
            assert(stateStateTransitions[previousState][oldState] >= 0);

            for (int state = 0; state < numStates; state++) {
                stateLogLikelihoods[state] = Math.log( stateStateTransitions[previousState][state] + gamma );
                if (Double.isInfinite(stateLogLikelihoods[state])) {
                    System.out.println("infinite end");
                }
            }
        }
        else {
            // 4. This is the middle of a sequence
            nextState = documentStates[doc+1];
            stateStateTransitions[oldState][nextState]--;
            if (stateStateTransitions[oldState][nextState] < 0) {
                // Dump diagnostics before the assertion below fires.
                System.out.println(printStateTransitions());
                System.out.println(oldState + " -> " + nextState);
                System.out.println(sequenceID);
            }
            assert (stateStateTransitions[oldState][nextState] >= 0);
            stateTransitionTotals[oldState]--;

            previousState = documentStates[doc-1];
            stateStateTransitions[previousState][oldState]--;
            assert(stateStateTransitions[previousState][oldState] >= 0);

            // NOTE(review): only the cell count of previous -> current was
            // decremented above, so stateTransitionTotals[previousState]
            // still includes that transition. The extra "+ 1" in the two
            // previousState == state denominators below may therefore
            // count it twice -- confirm against the intended predictive
            // distribution before changing.
            for (int state = 0; state < numStates; state++) {
                if (previousState == state && state == nextState) {
                    stateLogLikelihoods[state] =
                        Math.log( (stateStateTransitions[previousState][state] + gamma) *
                                  (stateStateTransitions[state][nextState] + 1 + gamma) /
                                  (stateTransitionTotals[state] + 1 + gammaSum) );
                }
                else if (previousState == state) {
                    stateLogLikelihoods[state] =
                        Math.log( (stateStateTransitions[previousState][state] + gamma) *
                                  (stateStateTransitions[state][nextState] + gamma) /
                                  (stateTransitionTotals[state] + 1 + gammaSum) );
                }
                else {
                    stateLogLikelihoods[state] =
                        Math.log( (stateStateTransitions[previousState][state] + gamma) *
                                  (stateStateTransitions[state][nextState] + gamma) /
                                  (stateTransitionTotals[state] + gammaSum) );
                }

                if (Double.isInfinite(stateLogLikelihoods[state])) {
                    System.out.println("infinite middle: " + doc);
                    System.out.println(previousState + " -> " +
                                       state + " -> " + nextState);
                    System.out.println(stateStateTransitions[previousState][state] + " -> " +
                                       stateStateTransitions[state][nextState] + " / " +
                                       stateTransitionTotals[state]);
                }
            }
        }
    }

    double max = Double.NEGATIVE_INFINITY;

    for (int state = 0; state < numStates; state++) {
        // Penalty that grows with the number of transitions counted for
        // the state. NOTE(review): this is integer division, so the
        // penalty moves in whole steps of 10 transitions -- confirm that
        // is intended.
        stateLogLikelihoods[state] -= stateTransitionTotals[state] / 10;

        // Cached Sampling Distribution: add the cumulative log term for
        // each topic's count, then subtract the one for the document length.
        double[][] currentStateLogGammaCache = topicLogGammaCache[state];
        for (int topic: docTopics) {
            int count = topicCounts.get(topic);
            stateLogLikelihoods[state] += currentStateLogGammaCache[topic][count];
        }
        stateLogLikelihoods[state] -= docLogGammaCache[state][ docLength ];

        if (stateLogLikelihoods[state] > max) {
            max = stateLogLikelihoods[state];
        }
    }

    // Subtract the maximum before exponentiating to avoid underflow.
    double sum = 0.0;
    for (int state = 0; state < numStates; state++) {
        samplingDistribution[state] =
            Math.exp(stateLogLikelihoods[state] - max);
        sum += samplingDistribution[state];

        if (Double.isNaN(samplingDistribution[state])) {
            System.out.println(stateLogLikelihoods[state]);
        }
        assert(! Double.isNaN(samplingDistribution[state]));
    }

    int newState = r.nextDiscrete(samplingDistribution, sum);

    documentStates[doc] = newState;

    // Only the document's own topics have non-zero counts to add.
    int[] newStateTopicCounts = stateTopicCounts[newState];
    for (int topic: docTopics) {
        newStateTopicCounts[topic] += topicCounts.get(topic);
    }
    stateTopicTotals[newState] += docLength;
    recacheStateTopicDistribution(newState, topicCounts);

    if (initializing) {
        // If we're initializing the states, don't bother
        // looking at the next state.
        if (startsSequence) {
            initialStateCounts[newState]++;
        }
        else {
            previousState = documentStates[doc-1];
            stateStateTransitions[previousState][newState]++;
            // Totals are indexed by the SOURCE state of a transition
            // everywhere else in this method, so count this one against
            // previousState, not newState.
            stateTransitionTotals[previousState]++;
        }
    }
    else {
        // Mirror of the four cases above: a sequence start counts an
        // initial state; otherwise add previous -> new. Unless this is a
        // sequence end, also add new -> next and its total.
        if (startsSequence) {
            initialStateCounts[newState]++;
        }
        else {
            // The total for previousState was never decremented for this
            // transition, so it is not incremented here either.
            previousState = documentStates[doc-1];
            stateStateTransitions[previousState][newState]++;
        }

        if (! endsSequence) {
            nextState = documentStates[doc+1];
            stateStateTransitions[newState][nextState]++;
            stateTransitionTotals[newState]++;
        }
    }
}
/**
 * Builds a human-readable summary of every state.
 *
 * For each state this lists the count and key words of up to four
 * topics taken from the front of the sorted topic list (presumably the
 * most frequent ones; the order is defined by IDSorter). It then shows
 * the state's initial-state count over the number of sequences, its
 * transition total, and its row of the transition matrix with the
 * self-transition in square brackets.
 *
 * @return the formatted summary
 */
public String printStateTransitions() {
    StringBuilder out = new StringBuilder();

    IDSorter[] sortedTopics = new IDSorter[numTopics];
    // Never read past the end of the array when there are fewer than four topics.
    int topicsToShow = Math.min(4, numTopics);

    for (int s = 0; s < numStates; s++) {
        int total = stateTopicTotals[s];
        for (int topic = 0; topic < numTopics; topic++) {
            // A state with no tokens would give 0/0 = NaN weights; use 0 instead.
            double weight = (total > 0) ? (double) stateTopicCounts[s][topic] / total : 0.0;
            sortedTopics[topic] = new IDSorter(topic, weight);
        }
        Arrays.sort(sortedTopics);

        out.append("\n" + s + "\n");
        for (int i = 0; i < topicsToShow; i++) {
            int topic = sortedTopics[i].getID();
            out.append(stateTopicCounts[s][topic] + "\t" + topicKeys[topic] + "\n");
        }
        out.append("\n");

        out.append("[" + initialStateCounts[s] + "/" + numSequences + "] ");
        out.append("[" + stateTransitionTotals[s] + "]");

        for (int t = 0; t < numStates; t++) {
            out.append("\t");
            if (s == t) {
                out.append("[" + stateStateTransitions[s][t] + "]");
            }
            else {
                out.append(stateStateTransitions[s][t]);
            }
        }
        out.append("\n");
    }

    return out.toString();
}
/**
 * Formats the state-to-state transition counts as text: one line per
 * source state, with every count followed by a tab.
 *
 * @return the transition count matrix as tab-separated text
 */
public String stateTransitionMatrix() {
    StringBuilder text = new StringBuilder();

    for (int from = 0; from < numStates; from++) {
        int[] row = stateStateTransitions[from];
        for (int to = 0; to < numStates; to++) {
            text.append(row[to]).append("\t");
        }
        text.append("\n");
    }

    return text.toString();
}
/**
 * Formats the per-state topic counts as text: one line per state, with
 * every topic count followed by a tab.
 *
 * @return the state-by-topic count matrix as tab-separated text
 */
public String stateTopics() {
    StringBuilder text = new StringBuilder();

    for (int state = 0; state < numStates; state++) {
        int[] countsForState = stateTopicCounts[state];
        for (int topic = 0; topic < numTopics; topic++) {
            text.append(countsForState[topic]).append("\t");
        }
        text.append("\n");
    }

    return text.toString();
}
/**
 * Command-line entry point. Trains a 150-state model with gamma 1.0 and
 * random seed 1.
 *
 * Arguments, in order: number of topics, LDA state file, LDA keys file,
 * sequence metadata file. With any other number of arguments a usage
 * message is printed to standard error and the process exits with a
 * non-zero status.
 *
 * @param args the four command-line arguments
 * @throws IOException if an input file cannot be read or an output file
 *         cannot be written
 */
public static void main (String[] args) throws IOException {
    if (args.length != 4) {
        System.err.println("Usage: MultinomialHMM [num topics] [lda state file] [lda keys file] [sequence metadata file]");
        // A usage error is a failure; report it to the calling shell.
        System.exit(1);
    }

    int numTopics = Integer.parseInt(args[0]);

    MultinomialHMM hmm =
        new MultinomialHMM (numTopics, args[1], 150);

    hmm.setGamma(1.0);
    hmm.setRandomSeed(1);

    // Order matters: sequence IDs must be loaded before initialize(),
    // which samples an initial state for every document.
    hmm.loadAlphaFromFile(args[2]);
    hmm.loadSequenceIDsFromFile(args[3]);

    hmm.initialize();
    hmm.sample();
}
}