/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.topics;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.zip.*;
import java.io.*;
import java.text.NumberFormat;
import cc.mallet.types.*;
import cc.mallet.util.Randoms;
/**
* An implementation of the left-to-right marginal probability estimator
* presented in Wallach et al., "Evaluation Methods for Topic Models", ICML (2009).
*
* @author David Mimno
*/
public class MarginalProbEstimator implements Serializable {
protected int numTopics; // Number of topics in the model being evaluated
// These values are used to encode type/topic counts as
// count/topic pairs in a single int.
protected int topicMask;
protected int topicBits;
protected double[] alpha; // Dirichlet(alpha[0], alpha[1], ...) is the prior over per-document topic distributions
protected double alphaSum;
protected double beta; // Prior on per-topic multinomial distribution over words
protected double betaSum;
protected double smoothingOnlyMass = 0.0;
protected double[] cachedCoefficients;
protected int[][] typeTopicCounts; // indexed by type; each row is a list of count/topic pairs packed into single ints (see topicMask/topicBits)
protected int[] tokensPerTopic; // indexed by <topic index>
protected Randoms random;
public MarginalProbEstimator (int numTopics,
double[] alpha, double alphaSum,
double beta,
int[][] typeTopicCounts,
int[] tokensPerTopic) {
this.numTopics = numTopics;
if (Integer.bitCount(numTopics) == 1) {
// exact power of 2
topicMask = numTopics - 1;
topicBits = Integer.bitCount(topicMask);
}
else {
// otherwise add an extra bit
topicMask = Integer.highestOneBit(numTopics) * 2 - 1;
topicBits = Integer.bitCount(topicMask);
}
this.typeTopicCounts = typeTopicCounts;
this.tokensPerTopic = tokensPerTopic;
this.alphaSum = alphaSum;
this.alpha = alpha;
this.beta = beta;
this.betaSum = beta * typeTopicCounts.length;
this.random = new Randoms();
cachedCoefficients = new double[ numTopics ];
// Initialize the smoothing-only sampling bucket
smoothingOnlyMass = 0;
// Initialize the cached coefficients, using only smoothing.
// These values will be selectively replaced in documents with
// non-zero counts in particular topics.
for (int topic=0; topic < numTopics; topic++) {
smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum);
cachedCoefficients[topic] = alpha[topic] / (tokensPerTopic[topic] + betaSum);
}
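// For reference, these two quantities implement the three-bucket decomposition
// of the (un-normalized) topic weight for a word type w that is used below:
//   (alpha[t] + n_t|doc) * (beta + n_w|t) / (tokensPerTopic[t] + betaSum)
//     =  alpha[t] * beta              / (tokensPerTopic[t] + betaSum)   -> smoothingOnlyMass
//     +  n_t|doc * beta               / (tokensPerTopic[t] + betaSum)   -> topicBetaMass
//     + (alpha[t] + n_t|doc) * n_w|t  / (tokensPerTopic[t] + betaSum)   -> topicTermMass
// so only topics with non-zero document counts or non-zero word counts need
// to be visited when the last two buckets are computed.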
System.err.println("Topic Evaluator: " + numTopics + " topics, " + topicBits + " topic bits, " +
Integer.toBinaryString(topicMask) + " topic mask");
}
public int[] getTokensPerTopic() { return tokensPerTopic; }
public int[][] getTypeTopicCounts() { return typeTopicCounts; }
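/**
 * Estimate the log marginal probability of a set of held-out documents using
 * the left-to-right algorithm, averaging over <code>numParticles</code>
 * independent runs per document.
 *
 * <p>A minimal usage sketch (variable and file names are illustrative; assumes
 * an estimator previously saved with standard Java serialization):</p>
 * <pre>{@code
 * InstanceList testing = InstanceList.load(new File("testing.mallet"));
 * MarginalProbEstimator estimator = MarginalProbEstimator.read(new File("estimator.ser"));
 * double totalLogLikelihood = estimator.evaluateLeftToRight(testing, 10, false, null);
 * }</pre>
 *
 * @param testing held-out documents whose data are FeatureSequences over the
 *    same alphabet used to build the type/topic counts
 * @param numParticles number of left-to-right passes to average per document
 * @param usingResampling if true, resample the topics of all earlier tokens in
 *    a document before scoring each position (slower, but closer to the
 *    algorithm as originally described)
 * @param docProbabilityStream if non-null, the log probability of each document
 *    is written to this stream, one value per line
 * @return the total log probability (natural log) of the testing set
 */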
public double evaluateLeftToRight (InstanceList testing, int numParticles, boolean usingResampling,
PrintStream docProbabilityStream) {
random = new Randoms();
double logNumParticles = Math.log(numParticles);
double totalLogLikelihood = 0;
for (Instance instance : testing) {
FeatureSequence tokenSequence = (FeatureSequence) instance.getData();
double docLogLikelihood = 0;
double[][] particleProbabilities = new double[ numParticles ][];
for (int particle = 0; particle < numParticles; particle++) {
particleProbabilities[particle] =
leftToRight(tokenSequence, usingResampling);
}
for (int position = 0; position < particleProbabilities[0].length; position++) {
double sum = 0;
for (int particle = 0; particle < numParticles; particle++) {
sum += particleProbabilities[particle][position];
}
if (sum > 0.0) {
docLogLikelihood += Math.log(sum) - logNumParticles;
}
}
if (docProbabilityStream != null) {
docProbabilityStream.println(docLogLikelihood);
}
totalLogLikelihood += docLogLikelihood;
}
return totalLogLikelihood;
}
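/**
 * Run a single left-to-right pass ("particle") over one document, returning
 * the estimated marginal probability (not the log) of the token at each
 * position. Out-of-vocabulary tokens are skipped and keep probability 0.
 */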
protected double[] leftToRight (FeatureSequence tokenSequence, boolean usingResampling) {
int[] oneDocTopics = new int[tokenSequence.getLength()];
double[] wordProbabilities = new double[tokenSequence.getLength()];
int[] currentTypeTopicCounts;
int type, oldTopic, newTopic;
int docLength = tokenSequence.getLength();
// Keep track of the number of tokens we've examined, not
// including out-of-vocabulary words
int tokensSoFar = 0;
int[] localTopicCounts = new int[numTopics];
int[] localTopicIndex = new int[numTopics];
// Build an array that densely lists the topics that
// have non-zero counts.
int denseIndex = 0;
// Record the total number of non-zero topics
int nonZeroTopics = denseIndex;
// Initialize the topic count/beta sampling bucket
double topicBetaMass = 0.0;
double topicTermMass = 0.0;
double[] topicTermScores = new double[numTopics];
int i;
double score;
double logLikelihood = 0;
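// Outline of one particle (one call to this method): for each position
// "limit" in the document,
//   1. if usingResampling, re-sample the topics of all positions < limit,
//      holding the model's type/topic counts fixed ("clamped");
//   2. record p(w_limit | w_<limit) by summing the three sampling buckets
//      over topics and dividing by (alphaSum + tokensSoFar);
//   3. sample a topic for position limit and add it to the document counts.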
// All counts are now zero; we are starting completely fresh.
// Iterate over the positions (words) in the document
for (int limit = 0; limit < docLength; limit++) {
// Record the marginal probability of the token
// at the current limit, summed over all topics.
if (usingResampling) {
// Iterate up to the current limit
for (int position = 0; position < limit; position++) {
type = tokenSequence.getIndexAtPosition(position);
oldTopic = oneDocTopics[position];
// Check for out-of-vocabulary words
if (type >= typeTopicCounts.length ||
typeTopicCounts[type] == null) {
continue;
}
currentTypeTopicCounts = typeTopicCounts[type];
// Remove this token from all counts.
// Remove this topic's contribution to the
// normalizing constants.
// Note that we are using clamped estimates of P(w|t),
// so we are NOT changing smoothingOnlyMass.
topicBetaMass -= beta * localTopicCounts[oldTopic] /
(tokensPerTopic[oldTopic] + betaSum);
// Decrement the local doc/topic counts
localTopicCounts[oldTopic]--;
// Maintain the dense index if we are deleting
// the old topic.
if (localTopicCounts[oldTopic] == 0) {
// First get to the dense location associated with
// the old topic.
denseIndex = 0;
// We know it's in there somewhere, so we don't
// need bounds checking.
while (localTopicIndex[denseIndex] != oldTopic) {
denseIndex++;
}
// shift all remaining dense indices to the left.
while (denseIndex < nonZeroTopics) {
if (denseIndex < localTopicIndex.length - 1) {
localTopicIndex[denseIndex] =
localTopicIndex[denseIndex + 1];
}
denseIndex++;
}
nonZeroTopics --;
}
// Add the old topic's contribution back into the
// normalizing constants.
topicBetaMass += beta * localTopicCounts[oldTopic] /
(tokensPerTopic[oldTopic] + betaSum);
// Reset the cached coefficient for this topic
cachedCoefficients[oldTopic] =
(alpha[oldTopic] + localTopicCounts[oldTopic]) /
(tokensPerTopic[oldTopic] + betaSum);
// Now go over the type/topic counts, calculating the score
// for each topic.
int index = 0;
int currentTopic, currentValue;
topicTermMass = 0.0;
while (index < currentTypeTopicCounts.length &&
currentTypeTopicCounts[index] > 0) {
currentTopic = currentTypeTopicCounts[index] & topicMask;
currentValue = currentTypeTopicCounts[index] >> topicBits;
score =
cachedCoefficients[currentTopic] * currentValue;
topicTermMass += score;
topicTermScores[index] = score;
index++;
}
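// Sample one of the three buckets: topicTermMass covers topics where this
// word type has a non-zero count, topicBetaMass covers topics with non-zero
// counts in this document, and smoothingOnlyMass covers the prior alone.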
double sample = random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
double origSample = sample;
// Make sure it actually gets set
newTopic = -1;
if (sample < topicTermMass) {
i = -1;
while (sample > 0) {
i++;
sample -= topicTermScores[i];
}
newTopic = currentTypeTopicCounts[i] & topicMask;
}
else {
sample -= topicTermMass;
if (sample < topicBetaMass) {
//betaTopicCount++;
sample /= beta;
for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
int topic = localTopicIndex[denseIndex];
sample -= localTopicCounts[topic] /
(tokensPerTopic[topic] + betaSum);
if (sample <= 0.0) {
newTopic = topic;
break;
}
}
}
else {
//smoothingOnlyCount++;
sample -= topicBetaMass;
sample /= beta;
newTopic = 0;
sample -= alpha[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
while (sample > 0.0) {
newTopic++;
sample -= alpha[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
}
}
}
if (newTopic == -1) {
System.err.println("sampling error: "+ origSample + " " + sample + " " + smoothingOnlyMass + " " +
topicBetaMass + " " + topicTermMass);
newTopic = numTopics-1; // TODO is this appropriate
//throw new IllegalStateException ("WorkerRunnable: New topic not sampled.");
}
//assert(newTopic != -1);
// Put that new topic into the counts
oneDocTopics[position] = newTopic;
topicBetaMass -= beta * localTopicCounts[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
localTopicCounts[newTopic]++;
// If this is a new topic for this document,
// add the topic to the dense index.
if (localTopicCounts[newTopic] == 1) {
// First find the point where we
// should insert the new topic by going to
// the end (which is the only reason we're keeping
// track of the number of non-zero
// topics) and working backwards
denseIndex = nonZeroTopics;
while (denseIndex > 0 &&
localTopicIndex[denseIndex - 1] > newTopic) {
localTopicIndex[denseIndex] =
localTopicIndex[denseIndex - 1];
denseIndex--;
}
localTopicIndex[denseIndex] = newTopic;
nonZeroTopics++;
}
// update the coefficients for the non-zero topics
cachedCoefficients[newTopic] =
(alpha[newTopic] + localTopicCounts[newTopic]) /
(tokensPerTopic[newTopic] + betaSum);
topicBetaMass += beta * localTopicCounts[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
}
}
// We've just resampled all tokens UP TO the current limit;
// now sample the token AT the current limit.
type = tokenSequence.getIndexAtPosition(limit);
// Check for out-of-vocabulary words
if (type >= typeTopicCounts.length ||
typeTopicCounts[type] == null) {
continue;
}
currentTypeTopicCounts = typeTopicCounts[type];
int index = 0;
int currentTopic, currentValue;
topicTermMass = 0.0;
while (index < currentTypeTopicCounts.length &&
currentTypeTopicCounts[index] > 0) {
currentTopic = currentTypeTopicCounts[index] & topicMask;
currentValue = currentTypeTopicCounts[index] >> topicBits;
score =
cachedCoefficients[currentTopic] * currentValue;
topicTermMass += score;
topicTermScores[index] = score;
//System.out.println(" " + currentTopic + " = " + currentValue);
index++;
}
/* // Debugging, to make sure we're getting the right probabilities
for (int topic = 0; topic < numTopics; topic++) {
index = 0;
int displayCount = 0;
while (index < currentTypeTopicCounts.length &&
currentTypeTopicCounts[index] > 0) {
currentTopic = currentTypeTopicCounts[index] & topicMask;
currentValue = currentTypeTopicCounts[index] >> topicBits;
if (currentTopic == topic) {
displayCount = currentValue;
break;
}
index++;
}
System.out.print(topic + "\t");
System.out.print("(" + localTopicCounts[topic] + " + " + alpha[topic] + ") / " +
"(" + alphaSum + " + " + tokensSoFar + ") * ");
System.out.println("(" + displayCount + " + " + beta + ") / " +
"(" + tokensPerTopic[topic] + " + " + betaSum + ") =" +
((displayCount + beta) / (tokensPerTopic[topic] + betaSum)));
}
*/
double sample = random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
double origSample = sample;
// Note that we've been absorbing (alphaSum + tokensSoFar) into
// the normalizing constant. The true marginal probability needs
// this term, so we put it back in here.
wordProbabilities[limit] +=
(smoothingOnlyMass + topicBetaMass + topicTermMass) /
(alphaSum + tokensSoFar);
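// Written out, the quantity just added is
//   sum_t [ (alpha[t] + n_t|doc) / (alphaSum + tokensSoFar) ]
//         * [ (beta + n_w|t) / (betaSum + tokensPerTopic[t]) ]
// since the three masses sum to
//   sum_t (alpha[t] + n_t|doc) * (beta + n_w|t) / (tokensPerTopic[t] + betaSum).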
//System.out.println("normalizer: " + alphaSum + " + " + tokensSoFar);
tokensSoFar++;
// Make sure it actually gets set
newTopic = -1;
if (sample < topicTermMass) {
i = -1;
while (sample > 0) {
i++;
sample -= topicTermScores[i];
}
newTopic = currentTypeTopicCounts[i] & topicMask;
}
else {
sample -= topicTermMass;
if (sample < topicBetaMass) {
//betaTopicCount++;
sample /= beta;
for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
int topic = localTopicIndex[denseIndex];
sample -= localTopicCounts[topic] /
(tokensPerTopic[topic] + betaSum);
if (sample <= 0.0) {
newTopic = topic;
break;
}
}
}
else {
//smoothingOnlyCount++;
sample -= topicBetaMass;
sample /= beta;
newTopic = 0;
sample -= alpha[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
while (sample > 0.0) {
newTopic++;
sample -= alpha[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
}
}
}
if (newTopic == -1) {
System.err.println("sampling error: "+ origSample + " " +
sample + " " + smoothingOnlyMass + " " +
topicBetaMass + " " + topicTermMass);
newTopic = numTopics-1; // TODO is this appropriate
}
// Put that new topic into the counts
oneDocTopics[limit] = newTopic;
topicBetaMass -= beta * localTopicCounts[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
localTopicCounts[newTopic]++;
// If this is a new topic for this document,
// add the topic to the dense index.
if (localTopicCounts[newTopic] == 1) {
// First find the point where we
// should insert the new topic by going to
// the end (which is the only reason we're keeping
// track of the number of non-zero
// topics) and working backwards
denseIndex = nonZeroTopics;
while (denseIndex > 0 &&
localTopicIndex[denseIndex - 1] > newTopic) {
localTopicIndex[denseIndex] =
localTopicIndex[denseIndex - 1];
denseIndex--;
}
localTopicIndex[denseIndex] = newTopic;
nonZeroTopics++;
}
// update the coefficients for the non-zero topics
cachedCoefficients[newTopic] =
(alpha[newTopic] + localTopicCounts[newTopic]) /
(tokensPerTopic[newTopic] + betaSum);
topicBetaMass += beta * localTopicCounts[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
//System.out.println(type + "\t" + newTopic + "\t" + logLikelihood);
}
// Clean up our mess: reset the coefficients to values with only
// smoothing. The next doc will update its own non-zero topics...
for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
int topic = localTopicIndex[denseIndex];
cachedCoefficients[topic] =
alpha[topic] / (tokensPerTopic[topic] + betaSum);
}
return wordProbabilities;
}
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 0;
private static final int NULL_INTEGER = -1;
private void writeObject (ObjectOutputStream out) throws IOException {
out.writeInt (CURRENT_SERIAL_VERSION);
out.writeInt(numTopics);
out.writeInt(topicMask);
out.writeInt(topicBits);
out.writeObject(alpha);
out.writeDouble(alphaSum);
out.writeDouble(beta);
out.writeDouble(betaSum);
out.writeObject(typeTopicCounts);
out.writeObject(tokensPerTopic);
out.writeObject(random);
out.writeDouble(smoothingOnlyMass);
out.writeObject(cachedCoefficients);
}
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
int version = in.readInt ();
numTopics = in.readInt();
topicMask = in.readInt();
topicBits = in.readInt();
alpha = (double[]) in.readObject();
alphaSum = in.readDouble();
beta = in.readDouble();
betaSum = in.readDouble();
typeTopicCounts = (int[][]) in.readObject();
tokensPerTopic = (int[]) in.readObject();
random = (Randoms) in.readObject();
smoothingOnlyMass = in.readDouble();
cachedCoefficients = (double[]) in.readObject();
}
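/**
 * Restore a serialized estimator from disk.
 *
 * <p>A hypothetical round-trip sketch (file name is illustrative): the
 * estimator is written with standard Java serialization and restored here.</p>
 * <pre>{@code
 * ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream("estimator.ser"));
 * oos.writeObject(estimator);
 * oos.close();
 * MarginalProbEstimator restored = MarginalProbEstimator.read(new File("estimator.ser"));
 * }</pre>
 */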
public static MarginalProbEstimator read (File f) throws Exception {
MarginalProbEstimator estimator = null;
ObjectInputStream ois = new ObjectInputStream (new FileInputStream(f));
estimator = (MarginalProbEstimator) ois.readObject();
ois.close();
return estimator;
}
}