/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.topics;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.zip.*;
import java.io.*;
import java.text.NumberFormat;
import cc.mallet.types.*;
import cc.mallet.util.Randoms;
/**
 * An implementation of the topic model marginal probability estimators
 * presented in Wallach et al., "Evaluation Methods for Topic Models", ICML (2009).
 *
 * @author David Mimno
 */
public class MarginalProbEstimator implements Serializable {
protected int numTopics; // Number of topics to be fit
// These values are used to encode type/topic counts as
// count/topic pairs in a single int.
protected int topicMask;
protected int topicBits;
protected double[] alpha; // Dirichlet(alpha,alpha,...) is the distribution over topics
protected double alphaSum;
protected double beta; // Prior on per-topic multinomial distribution over words
protected double betaSum;
protected double smoothingOnlyMass = 0.0;
protected double[] cachedCoefficients;
protected int[][] typeTopicCounts; // indexed by <feature index, topic index>
protected int[] tokensPerTopic; // indexed by <topic index>
protected Randoms random;
public MarginalProbEstimator (int numTopics,
double[] alpha, double alphaSum,
double beta,
int[][] typeTopicCounts,
int[] tokensPerTopic) {
this.numTopics = numTopics;
if (Integer.bitCount(numTopics) == 1) {
// exact power of 2
topicMask = numTopics - 1;
topicBits = Integer.bitCount(topicMask);
else {
// otherwise add an extra bit
topicMask = Integer.highestOneBit(numTopics) * 2 - 1;
topicBits = Integer.bitCount(topicMask);
this.typeTopicCounts = typeTopicCounts;
this.tokensPerTopic = tokensPerTopic;
this.alphaSum = alphaSum;
this.alpha = alpha;
this.beta = beta;
this.betaSum = beta * typeTopicCounts.length;
this.random = new Randoms();
cachedCoefficients = new double[ numTopics ];
// Initialize the smoothing-only sampling bucket
smoothingOnlyMass = 0;
// Initialize the cached coefficients, using only smoothing.
// These values will be selectively replaced in documents with
// non-zero counts in particular topics.
for (int topic=0; topic < numTopics; topic++) {
smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum);
cachedCoefficients[topic] = alpha[topic] / (tokensPerTopic[topic] + betaSum);
System.err.println("Topic Evaluator: " + numTopics + " topics, " + topicBits + " topic bits, " +
Integer.toBinaryString(topicMask) + " topic mask");
public int[] getTokensPerTopic() { return tokensPerTopic; }
public int[][] getTypeTopicCounts() { return typeTopicCounts; }
public double evaluateLeftToRight (InstanceList testing, int numParticles, boolean usingResampling,
PrintStream docProbabilityStream) {
random = new Randoms();
double logNumParticles = Math.log(numParticles);
double totalLogLikelihood = 0;
for (Instance instance : testing) {
FeatureSequence tokenSequence = (FeatureSequence) instance.getData();
double docLogLikelihood = 0;
double[][] particleProbabilities = new double[ numParticles ][];
for (int particle = 0; particle < numParticles; particle++) {
particleProbabilities[particle] =
leftToRight(tokenSequence, usingResampling);
for (int position = 0; position < particleProbabilities[0].length; position++) {
double sum = 0;
for (int particle = 0; particle < numParticles; particle++) {
sum += particleProbabilities[particle][position];
if (sum > 0.0) {
docLogLikelihood += Math.log(sum) - logNumParticles;
if (docProbabilityStream != null) {
totalLogLikelihood += docLogLikelihood;
return totalLogLikelihood;
protected double[] leftToRight (FeatureSequence tokenSequence, boolean usingResampling) {
int[] oneDocTopics = new int[tokenSequence.getLength()];
double[] wordProbabilities = new double[tokenSequence.getLength()];
int[] currentTypeTopicCounts;
int type, oldTopic, newTopic;
double topicWeightsSum;
int docLength = tokenSequence.getLength();
// Keep track of the number of tokens we've examined, not
// including out-of-vocabulary words
int tokensSoFar = 0;
int[] localTopicCounts = new int[numTopics];
int[] localTopicIndex = new int[numTopics];
// Build an array that densely lists the topics that
// have non-zero counts.
int denseIndex = 0;
// Record the total number of non-zero topics
int nonZeroTopics = denseIndex;
// Initialize the topic count/beta sampling bucket
double topicBetaMass = 0.0;
double topicTermMass = 0.0;
double[] topicTermScores = new double[numTopics];
int[] topicTermIndices;
int[] topicTermValues;
int i;
double score;
double logLikelihood = 0;
// All counts are now zero, we are starting completely fresh.
// Iterate over the positions (words) in the document
for (int limit = 0; limit < docLength; limit++) {
// Record the marginal probability of the token
// at the current limit, summed over all topics.
if (usingResampling) {
// Iterate up to the current limit
for (int position = 0; position < limit; position++) {
type = tokenSequence.getIndexAtPosition(position);
oldTopic = oneDocTopics[position];
// Check for out-of-vocabulary words
if (type >= typeTopicCounts.length ||
typeTopicCounts[type] == null) {
currentTypeTopicCounts = typeTopicCounts[type];
// Remove this token from all counts.
// Remove this topic's contribution to the
// normalizing constants.
// Note that we are using clamped estimates of P(w|t),
// so we are NOT changing smoothingOnlyMass.
topicBetaMass -= beta * localTopicCounts[oldTopic] /
(tokensPerTopic[oldTopic] + betaSum);
// Decrement the local doc/topic counts
// Maintain the dense index, if we are deleting
// the old topic
if (localTopicCounts[oldTopic] == 0) {
// First get to the dense location associated with
// the old topic.
denseIndex = 0;
// We know it's in there somewhere, so we don't
// need bounds checking.
while (localTopicIndex[denseIndex] != oldTopic) {
// shift all remaining dense indices to the left.
while (denseIndex < nonZeroTopics) {
if (denseIndex < localTopicIndex.length - 1) {
localTopicIndex[denseIndex] =
localTopicIndex[denseIndex + 1];
nonZeroTopics --;
// Add the old topic's contribution back into the
// normalizing constants.
topicBetaMass += beta * localTopicCounts[oldTopic] /
(tokensPerTopic[oldTopic] + betaSum);
// Reset the cached coefficient for this topic
cachedCoefficients[oldTopic] =
(alpha[oldTopic] + localTopicCounts[oldTopic]) /
(tokensPerTopic[oldTopic] + betaSum);
// Now go over the type/topic counts, calculating the score
// for each topic.
int index = 0;
int currentTopic, currentValue;
boolean alreadyDecremented = false;
topicTermMass = 0.0;
while (index < currentTypeTopicCounts.length &&
currentTypeTopicCounts[index] > 0) {
currentTopic = currentTypeTopicCounts[index] & topicMask;
currentValue = currentTypeTopicCounts[index] >> topicBits;
score =
cachedCoefficients[currentTopic] * currentValue;
topicTermMass += score;
topicTermScores[index] = score;
double sample = random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
double origSample = sample;
// Make sure it actually gets set
newTopic = -1;
if (sample < topicTermMass) {
i = -1;
while (sample > 0) {
sample -= topicTermScores[i];
newTopic = currentTypeTopicCounts[i] & topicMask;
else {
sample -= topicTermMass;
if (sample < topicBetaMass) {
sample /= beta;
for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
int topic = localTopicIndex[denseIndex];
sample -= localTopicCounts[topic] /
(tokensPerTopic[topic] + betaSum);
if (sample <= 0.0) {
newTopic = topic;
else {
sample -= topicBetaMass;
sample /= beta;
newTopic = 0;
sample -= alpha[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
while (sample > 0.0) {
sample -= alpha[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
if (newTopic == -1) {
System.err.println("sampling error: "+ origSample + " " + sample + " " + smoothingOnlyMass + " " +
topicBetaMass + " " + topicTermMass);
newTopic = numTopics-1; // TODO is this appropriate
//throw new IllegalStateException ("WorkerRunnable: New topic not sampled.");
//assert(newTopic != -1);
// Put that new topic into the counts
oneDocTopics[position] = newTopic;
topicBetaMass -= beta * localTopicCounts[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
// If this is a new topic for this document,
// add the topic to the dense index.
if (localTopicCounts[newTopic] == 1) {
// First find the point where we
// should insert the new topic by going to
// the end (which is the only reason we're keeping
// track of the number of non-zero
// topics) and working backwards
denseIndex = nonZeroTopics;
while (denseIndex > 0 &&
localTopicIndex[denseIndex - 1] > newTopic) {
localTopicIndex[denseIndex] =
localTopicIndex[denseIndex - 1];
localTopicIndex[denseIndex] = newTopic;
// update the coefficients for the non-zero topics
cachedCoefficients[newTopic] =
(alpha[newTopic] + localTopicCounts[newTopic]) /
(tokensPerTopic[newTopic] + betaSum);
topicBetaMass += beta * localTopicCounts[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
// We've just resampled all tokens UP TO the current limit,
// now sample the token AT the current limit.
type = tokenSequence.getIndexAtPosition(limit);
// Check for out-of-vocabulary words
if (type >= typeTopicCounts.length ||
typeTopicCounts[type] == null) {
currentTypeTopicCounts = typeTopicCounts[type];
int index = 0;
int currentTopic, currentValue;
topicTermMass = 0.0;
while (index < currentTypeTopicCounts.length &&
currentTypeTopicCounts[index] > 0) {
currentTopic = currentTypeTopicCounts[index] & topicMask;
currentValue = currentTypeTopicCounts[index] >> topicBits;
score =
cachedCoefficients[currentTopic] * currentValue;
topicTermMass += score;
topicTermScores[index] = score;
//System.out.println(" " + currentTopic + " = " + currentValue);
/* // Debugging, to make sure we're getting the right probabilities
for (int topic = 0; topic < numTopics; topic++) {
index = 0;
int displayCount = 0;
while (index < currentTypeTopicCounts.length &&
currentTypeTopicCounts[index] > 0) {
currentTopic = currentTypeTopicCounts[index] & topicMask;
currentValue = currentTypeTopicCounts[index] >> topicBits;
if (currentTopic == topic) {
displayCount = currentValue;
System.out.print(topic + "\t");
System.out.print("(" + localTopicCounts[topic] + " + " + alpha[topic] + ") / " +
"(" + alphaSum + " + " + tokensSoFar + ") * ");
System.out.println("(" + displayCount + " + " + beta + ") / " +
"(" + tokensPerTopic[topic] + " + " + betaSum + ") =" +
((displayCount + beta) / (tokensPerTopic[topic] + betaSum)));
double sample = random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
double origSample = sample;
// Note that we've been absorbing (alphaSum + docLength) into
// the normalizing constant. The true marginal probability needs
// this term, so we stick it back in.
wordProbabilities[limit] +=
(smoothingOnlyMass + topicBetaMass + topicTermMass) /
(alphaSum + tokensSoFar);
//System.out.println("normalizer: " + alphaSum + " + " + tokensSoFar);
// Make sure it actually gets set
newTopic = -1;
if (sample < topicTermMass) {
i = -1;
while (sample > 0) {
sample -= topicTermScores[i];
newTopic = currentTypeTopicCounts[i] & topicMask;
else {
sample -= topicTermMass;
if (sample < topicBetaMass) {
sample /= beta;
for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
int topic = localTopicIndex[denseIndex];
sample -= localTopicCounts[topic] /
(tokensPerTopic[topic] + betaSum);
if (sample <= 0.0) {
newTopic = topic;
else {
sample -= topicBetaMass;
sample /= beta;
newTopic = 0;
sample -= alpha[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
while (sample > 0.0) {
sample -= alpha[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
if (newTopic == -1) {
System.err.println("sampling error: "+ origSample + " " +
sample + " " + smoothingOnlyMass + " " +
topicBetaMass + " " + topicTermMass);
newTopic = numTopics-1; // TODO is this appropriate
// Put that new topic into the counts
oneDocTopics[limit] = newTopic;
topicBetaMass -= beta * localTopicCounts[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
// If this is a new topic for this document,
// add the topic to the dense index.
if (localTopicCounts[newTopic] == 1) {
// First find the point where we
// should insert the new topic by going to
// the end (which is the only reason we're keeping
// track of the number of non-zero
// topics) and working backwards
denseIndex = nonZeroTopics;
while (denseIndex > 0 &&
localTopicIndex[denseIndex - 1] > newTopic) {
localTopicIndex[denseIndex] =
localTopicIndex[denseIndex - 1];
localTopicIndex[denseIndex] = newTopic;
// update the coefficients for the non-zero topics
cachedCoefficients[newTopic] =
(alpha[newTopic] + localTopicCounts[newTopic]) /
(tokensPerTopic[newTopic] + betaSum);
topicBetaMass += beta * localTopicCounts[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
//System.out.println(type + "\t" + newTopic + "\t" + logLikelihood);
// Clean up our mess: reset the coefficients to values with only
// smoothing. The next doc will update its own non-zero topics...
for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
int topic = localTopicIndex[denseIndex];
cachedCoefficients[topic] =
alpha[topic] / (tokensPerTopic[topic] + betaSum);
return wordProbabilities;
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 0;
private static final int NULL_INTEGER = -1;
private void writeObject (ObjectOutputStream out) throws IOException {
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
int version = in.readInt ();
numTopics = in.readInt();
topicMask = in.readInt();
topicBits = in.readInt();
alpha = (double[]) in.readObject();
alphaSum = in.readDouble();
beta = in.readDouble();
betaSum = in.readDouble();
typeTopicCounts = (int[][]) in.readObject();
tokensPerTopic = (int[]) in.readObject();
random = (Randoms) in.readObject();
smoothingOnlyMass = in.readDouble();
cachedCoefficients = (double[]) in.readObject();
public static MarginalProbEstimator read (File f) throws Exception {
MarginalProbEstimator estimator = null;
ObjectInputStream ois = new ObjectInputStream (new FileInputStream(f));
estimator = (MarginalProbEstimator) ois.readObject();
return estimator;