/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.topics;
import gnu.trove.TIntIntHashMap;
import java.util.Arrays;
import java.util.List;
import java.util.ArrayList;
import java.util.TreeSet;
import java.util.Iterator;
import java.util.zip.*;
import java.io.*;
import java.text.NumberFormat;
import cc.mallet.types.*;
import cc.mallet.util.Randoms;
/**
 * Latent Dirichlet Allocation with optimized hyperparameters.
 *
 * @author David Mimno, Andrew McCallum
 * @deprecated Use ParallelTopicModel instead, which uses substantially faster data structures even for non-parallel operation.
 */
public class LDAHyper implements Serializable {
// Analogous to a cc.mallet.classify.Classification
public class Topication implements Serializable {
public Instance instance;
public LDAHyper model;
public LabelSequence topicSequence;
public Labeling topicDistribution; // not actually constructed by model fitting, but could be added for "test" documents.
public Topication (Instance instance, LDAHyper model, LabelSequence topicSequence) {
this.instance = instance;
this.model = model;
this.topicSequence = topicSequence;
// Maintainable serialization
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 0;
private void writeObject (ObjectOutputStream out) throws IOException {
out.writeObject (instance);
out.writeObject (model);
out.writeObject (topicSequence);
out.writeObject (topicDistribution);
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
int version = in.readInt ();
instance = (Instance) in.readObject();
model = (LDAHyper) in.readObject();
topicSequence = (LabelSequence) in.readObject();
topicDistribution = (Labeling) in.readObject();
// ---- Model state ----
protected ArrayList<Topication> data; // the training instances and their topic assignments
protected Alphabet alphabet; // the alphabet for the input data
protected LabelAlphabet topicAlphabet; // the alphabet for the topics
protected int numTopics; // Number of topics to be fit
// Size of the data alphabet; set from alphabet.size() in initializeForTypes
protected int numTypes;
protected double[] alpha; // Dirichlet(alpha,alpha,...) is the distribution over topics
// Sum of the alpha entries; the constructor spreads it evenly over the topics
protected double alphaSum;
protected double beta; // Prior on per-topic multinomial distribution over words
// beta * numTypes, recomputed whenever numTypes changes
protected double betaSum;
public static final double DEFAULT_BETA = 0.01;
// Sum over topics of alpha[t] * beta / (tokensPerTopic[t] + betaSum)
protected double smoothingOnlyMass = 0.0;
// Per-topic sampling coefficients; alpha[t] / (tokensPerTopic[t] + betaSum) outside of a document sweep
protected double[] cachedCoefficients;
// Diagnostic counters, reset at the start of every iteration in estimate
int topicTermCount = 0;
int betaTopicCount = 0;
int smoothingOnlyCount = 0;
// Instance list for empirical likelihood calculation
protected InstanceList testing = null;
// An array to put the topic counts for the current document.
// Initialized locally below. Defined here to avoid
// garbage collection overhead.
protected int[] oneDocTopicCounts; // indexed by <topic index>
protected gnu.trove.TIntIntHashMap[] typeTopicCounts; // indexed by <feature index, topic index>
protected int[] tokensPerTopic; // indexed by <topic index>
// for dirichlet estimation
protected int[] docLengthCounts; // histogram of document sizes
protected int[][] topicDocCounts; // histogram of document/topic counts, indexed by <topic index, number of tokens of that topic in a document>
// ---- Training schedule ----
public int iterationsSoFar = 0;
public int numIterations = 1000;
public int burninPeriod = 20; // was 50; //was 200;
public int saveSampleInterval = 5; // was 10;
public int optimizeInterval = 20; // was 50;
public int showTopicsInterval = 10; // was 50;
public int wordsPerTopic = 7;
// ---- Periodic output; an interval of 0 disables the corresponding output ----
protected int outputModelInterval = 0;
protected String outputModelFilename;
protected int saveStateInterval = 0;
protected String stateFilename = null;
protected Randoms random;
protected NumberFormat formatter;
protected boolean printLogLikelihood = false;
/**
 * Creates a model whose alphaSum equals the number of topics and whose
 * beta is {@link #DEFAULT_BETA}.
 *
 * @param numberOfTopics number of topics to fit
 */
public LDAHyper (int numberOfTopics) {
	this (numberOfTopics, numberOfTopics, DEFAULT_BETA);
}

/**
 * Creates a model with the given priors and a freshly constructed random
 * number generator.
 *
 * @param numberOfTopics number of topics to fit
 * @param alphaSum       total mass of the topic prior
 * @param beta           per-word prior of each topic
 */
public LDAHyper (int numberOfTopics, double alphaSum, double beta) {
	this (numberOfTopics, alphaSum, beta, new Randoms());
}
private static LabelAlphabet newLabelAlphabet (int numTopics) {
LabelAlphabet ret = new LabelAlphabet();
for (int i = 0; i < numTopics; i++)
return ret;
/**
 * Creates a model with generated topic labels.
 *
 * @param numberOfTopics number of topics to fit
 * @param alphaSum       total mass of the topic prior
 * @param beta           per-word prior of each topic
 * @param random         source of randomness used while sampling
 */
public LDAHyper (int numberOfTopics, double alphaSum, double beta, Randoms random) {
	this (newLabelAlphabet (numberOfTopics), alphaSum, beta, random);
}
/**
 * Creates a model over the given topic labels. The number of topics is the
 * size of the label alphabet, and alphaSum is spread evenly over the topics.
 *
 * @param topicAlphabet labels of the topics
 * @param alphaSum      total mass of the topic prior
 * @param beta          per-word prior of each topic
 * @param random        source of randomness used while sampling
 */
public LDAHyper (LabelAlphabet topicAlphabet, double alphaSum, double beta, Randoms random)
{
	// Topic labels determine the model size
	this.topicAlphabet = topicAlphabet;
	this.numTopics = topicAlphabet.size();

	// Priors: a symmetric alpha vector summing to alphaSum
	this.alphaSum = alphaSum;
	this.beta = beta;
	this.alpha = new double[numTopics];
	Arrays.fill (this.alpha, alphaSum / numTopics);

	// Empty corpus and zeroed count buffers
	this.data = new ArrayList<Topication>();
	this.oneDocTopicCounts = new int[numTopics];
	this.tokensPerTopic = new int[numTopics];

	this.random = random;
	this.formatter = NumberFormat.getInstance();

	System.err.println("LDA: " + numTopics + " topics");
}
/** @return the alphabet for the input data */
public Alphabet getAlphabet() {
	return this.alphabet;
}

/** @return the alphabet for the topics */
public LabelAlphabet getTopicAlphabet() {
	return this.topicAlphabet;
}

/** @return the number of topics being fit */
public int getNumTopics() {
	return this.numTopics;
}

/** @return the training instances together with their topic assignments */
public ArrayList<Topication> getData() {
	return this.data;
}

/** @return the count stored for the given feature (type) under the given topic */
public int getCountFeatureTopic (int featureIndex, int topicIndex) {
	TIntIntHashMap countsForFeature = typeTopicCounts[featureIndex];
	return countsForFeature.get(topicIndex);
}

/** @return the token count stored for the given topic */
public int getCountTokensPerTopic (int topicIndex) {
	return this.tokensPerTopic[topicIndex];
}
/** Held-out instances for empirical likelihood calculation */
public void setTestingInstances(InstanceList testing) {
this.testing = testing;
public void setNumIterations (int numIterations) {
this.numIterations = numIterations;
public void setBurninPeriod (int burninPeriod) {
this.burninPeriod = burninPeriod;
public void setTopicDisplay(int interval, int n) {
this.showTopicsInterval = interval;
this.wordsPerTopic = n;
public void setRandomSeed(int seed) {
random = new Randoms(seed);
public void setOptimizeInterval(int interval) {
this.optimizeInterval = interval;
public void setModelOutput(int interval, String filename) {
this.outputModelInterval = interval;
this.outputModelFilename = filename;
/** Define how often and where to save the state
* @param interval Save a copy of the state every <code>interval</code> iterations.
* @param filename Save the state to this file, with the iteration number as a suffix
public void setSaveState(int interval, String filename) {
this.saveStateInterval = interval;
this.stateFilename = filename;
protected int instanceLength (Instance instance) {
return ((FeatureSequence)instance.getData()).size();
// Can be safely called multiple times. This method will complain if it can't handle the situation
private void initializeForTypes (Alphabet alphabet) {
if (this.alphabet == null) {
this.alphabet = alphabet;
this.numTypes = alphabet.size();
this.typeTopicCounts = new TIntIntHashMap[numTypes];
for (int fi = 0; fi < numTypes; fi++)
typeTopicCounts[fi] = new TIntIntHashMap();
this.betaSum = beta * numTypes;
} else if (alphabet != this.alphabet) {
throw new IllegalArgumentException ("Cannot change Alphabet.");
} else if (alphabet.size() != this.numTypes) {
this.numTypes = alphabet.size();
TIntIntHashMap[] newTypeTopicCounts = new TIntIntHashMap[numTypes];
for (int i = 0; i < typeTopicCounts.length; i++)
newTypeTopicCounts[i] = typeTopicCounts[i];
for (int i = typeTopicCounts.length; i < numTypes; i++)
newTypeTopicCounts[i] = new TIntIntHashMap();
// TODO AKM July 18: Why wasn't the next line there previously?
// this.typeTopicCounts = newTypeTopicCounts;
this.betaSum = beta * numTypes;
} // else, nothing changed, nothing to be done
private void initializeTypeTopicCounts () {
TIntIntHashMap[] newTypeTopicCounts = new TIntIntHashMap[numTypes];
for (int i = 0; i < typeTopicCounts.length; i++)
newTypeTopicCounts[i] = typeTopicCounts[i];
for (int i = typeTopicCounts.length; i < numTypes; i++)
newTypeTopicCounts[i] = new TIntIntHashMap();
this.typeTopicCounts = newTypeTopicCounts;
public void addInstances (InstanceList training) {
initializeForTypes (training.getDataAlphabet());
ArrayList<LabelSequence> topicSequences = new ArrayList<LabelSequence>();
for (Instance instance : training) {
LabelSequence topicSequence = new LabelSequence(topicAlphabet, new int[instanceLength(instance)]);
if (false)
// This method not yet obeying its last "false" argument, and must be for this to work
sampleTopicsForOneDoc((FeatureSequence)instance.getData(), topicSequence, false, false);
else {
Randoms r = new Randoms();
int[] topics = topicSequence.getFeatures();
for (int i = 0; i < topics.length; i++)
topics[i] = r.nextInt(numTopics);
topicSequences.add (topicSequence);
addInstances (training, topicSequences);
public void addInstances (InstanceList training, List<LabelSequence> topics) {
initializeForTypes (training.getDataAlphabet());
assert (training.size() == topics.size());
for (int i = 0; i < training.size(); i++) {
Topication t = new Topication (training.get(i), this, topics.get(i));
data.add (t);
// Include sufficient statistics for this one doc
FeatureSequence tokenSequence = (FeatureSequence) t.instance.getData();
LabelSequence topicSequence = t.topicSequence;
for (int pi = 0; pi < topicSequence.getLength(); pi++) {
int topic = topicSequence.getIndexAtPosition(pi);
typeTopicCounts[tokenSequence.getIndexAtPosition(pi)].adjustOrPutValue(topic, 1, 1);
* Gather statistics on the size of documents
* and create histograms for use in Dirichlet hyperparameter
* optimization.
protected void initializeHistogramsAndCachedValues() {
int maxTokens = 0;
int totalTokens = 0;
int seqLen;
for (int doc = 0; doc < data.size(); doc++) {
FeatureSequence fs = (FeatureSequence) data.get(doc).instance.getData();
seqLen = fs.getLength();
if (seqLen > maxTokens)
maxTokens = seqLen;
totalTokens += seqLen;
// Initialize the smoothing-only sampling bucket
smoothingOnlyMass = 0;
for (int topic = 0; topic < numTopics; topic++)
smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum);
// Initialize the cached coefficients, using only smoothing.
cachedCoefficients = new double[ numTopics ];
for (int topic=0; topic < numTopics; topic++)
cachedCoefficients[topic] = alpha[topic] / (tokensPerTopic[topic] + betaSum);
System.err.println("max tokens: " + maxTokens);
System.err.println("total tokens: " + totalTokens);
docLengthCounts = new int[maxTokens + 1];
topicDocCounts = new int[numTopics][maxTokens + 1];
public void estimate () throws IOException {
estimate (numIterations);
public void estimate (int iterationsThisRound) throws IOException {
long startTime = System.currentTimeMillis();
int maxIteration = iterationsSoFar + iterationsThisRound;
for ( ; iterationsSoFar <= maxIteration; iterationsSoFar++) {
long iterationStart = System.currentTimeMillis();
if (showTopicsInterval != 0 && iterationsSoFar != 0 && iterationsSoFar % showTopicsInterval == 0) {
printTopWords (System.out, wordsPerTopic, false);
if (testing != null) {
double el = empiricalLikelihood(1000, testing);
double ll = modelLogLikelihood();
double mi = topicLabelMutualInformation();
System.out.println(ll + "\t" + el + "\t" + mi);
if (saveStateInterval != 0 && iterationsSoFar % saveStateInterval == 0) {
this.printState(new File(stateFilename + '.' + iterationsSoFar));
if (outputModelInterval != 0 && iterations % outputModelInterval == 0) {
this.write (new File(outputModelFilename+'.'+iterations));
// TODO this condition should also check that we have more than one sample to work with here
// (The number of samples actually obtained is not yet tracked.)
if (iterationsSoFar > burninPeriod && optimizeInterval != 0 &&
iterationsSoFar % optimizeInterval == 0) {
alphaSum = Dirichlet.learnParameters(alpha, topicDocCounts, docLengthCounts);
smoothingOnlyMass = 0.0;
for (int topic = 0; topic < numTopics; topic++) {
smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum);
cachedCoefficients[topic] = alpha[topic] / (tokensPerTopic[topic] + betaSum);
// Loop over every document in the corpus
topicTermCount = betaTopicCount = smoothingOnlyCount = 0;
int numDocs = data.size(); // TODO consider beginning by sub-sampling?
for (int di = 0; di < numDocs; di++) {
FeatureSequence tokenSequence = (FeatureSequence) data.get(di).instance.getData();
LabelSequence topicSequence = (LabelSequence) data.get(di).topicSequence;
sampleTopicsForOneDoc (tokenSequence, topicSequence,
iterationsSoFar >= burninPeriod && iterationsSoFar % saveSampleInterval == 0,
long elapsedMillis = System.currentTimeMillis() - iterationStart;
if (elapsedMillis < 1000) {
System.out.print(elapsedMillis + "ms ");
else {
System.out.print((elapsedMillis/1000) + "s ");
//System.out.println(topicTermCount + "\t" + betaTopicCount + "\t" + smoothingOnlyCount);
if (iterationsSoFar % 10 == 0) {
System.out.println ("<" + iterationsSoFar + "> ");
if (printLogLikelihood) System.out.println (modelLogLikelihood());
long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
long minutes = seconds / 60; seconds %= 60;
long hours = minutes / 60; minutes %= 60;
long days = hours / 24; hours %= 24;
System.out.print ("\nTotal time: ");
if (days != 0) { System.out.print(days); System.out.print(" days "); }
if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
System.out.print(seconds); System.out.println(" seconds");
private void clearHistograms() {
Arrays.fill(docLengthCounts, 0);
for (int topic = 0; topic < topicDocCounts.length; topic++)
Arrays.fill(topicDocCounts[topic], 0);
/** If topicSequence assignments are already set and accounted for in sufficient statistics,
* then readjustTopicsAndStats should be true. The topics will be re-sampled and sufficient statistics changes.
* If operating on a new or a test document, and featureSequence & topicSequence are not already accounted for in the sufficient statistics,
* then readjustTopicsAndStats should be false. The current topic assignments will be ignored, and the sufficient statistics
* will not be changed.
* If you want to estimate the Dirichlet alpha based on the per-document topic multinomials sampled this round,
* then saveStateForAlphaEstimation should be true. */
private void oldSampleTopicsForOneDoc (FeatureSequence featureSequence,
FeatureSequence topicSequence,
boolean saveStateForAlphaEstimation, boolean readjustTopicsAndStats)
long startTime = System.currentTimeMillis();
int[] oneDocTopics = topicSequence.getFeatures();
TIntIntHashMap currentTypeTopicCounts;
int type, oldTopic, newTopic;
double[] topicDistribution;
double topicDistributionSum;
int docLen = featureSequence.getLength();
int adjustedValue;
int[] topicIndices, topicCounts;
double weight;
// populate topic counts
Arrays.fill(oneDocTopicCounts, 0);
if (readjustTopicsAndStats) {
for (int token = 0; token < docLen; token++) {
oneDocTopicCounts[ oneDocTopics[token] ]++;
// Iterate over the tokens (words) in the document
for (int token = 0; token < docLen; token++) {
type = featureSequence.getIndexAtPosition(token);
oldTopic = oneDocTopics[token];
currentTypeTopicCounts = typeTopicCounts[type];
assert (currentTypeTopicCounts.size() != 0);
if (readjustTopicsAndStats) {
// Remove this token from all counts
adjustedValue = currentTypeTopicCounts.adjustOrPutValue(oldTopic, -1, -1);
if (adjustedValue == 0) currentTypeTopicCounts.remove(oldTopic);
else if (adjustedValue == -1) throw new IllegalStateException ("Token count in topic went negative.");
// Build a distribution over topics for this token
topicIndices = currentTypeTopicCounts.keys();
topicCounts = currentTypeTopicCounts.getValues();
topicDistribution = new double[topicIndices.length];
// TODO Yipes, memory allocation in the inner loop! But note that .keys and .getValues is doing this too.
topicDistributionSum = 0;
for (int i = 0; i < topicCounts.length; i++) {
int topic = topicIndices[i];
weight = ((topicCounts[i] + beta) / (tokensPerTopic[topic] + betaSum)) * ((oneDocTopicCounts[topic] + alpha[topic]));
topicDistributionSum += weight;
topicDistribution[topic] = weight;
// Sample a topic assignment from this distribution
newTopic = topicIndices[random.nextDiscrete (topicDistribution, topicDistributionSum)];
if (readjustTopicsAndStats) {
// Put that new topic into the counts
oneDocTopics[token] = newTopic;
typeTopicCounts[type].adjustOrPutValue(newTopic, 1, 1);
if (saveStateForAlphaEstimation) {
// Update the document-topic count histogram, for dirichlet estimation
docLengthCounts[ docLen ]++;
for (int topic=0; topic < numTopics; topic++) {
topicDocCounts[topic][ oneDocTopicCounts[topic] ]++;
protected void sampleTopicsForOneDoc (FeatureSequence tokenSequence,
FeatureSequence topicSequence,
boolean shouldSaveState,
boolean readjustTopicsAndStats /* currently ignored */) {
int[] oneDocTopics = topicSequence.getFeatures();
TIntIntHashMap currentTypeTopicCounts;
int type, oldTopic, newTopic;
double topicWeightsSum;
int docLength = tokenSequence.getLength();
// populate topic counts
TIntIntHashMap localTopicCounts = new TIntIntHashMap();
for (int position = 0; position < docLength; position++) {
localTopicCounts.adjustOrPutValue(oneDocTopics[position], 1, 1);
// Initialize the topic count/beta sampling bucket
double topicBetaMass = 0.0;
for (int topic: localTopicCounts.keys()) {
int n = localTopicCounts.get(topic);
// initialize the normalization constant for the (B * n_{t|d}) term
topicBetaMass += beta * n / (tokensPerTopic[topic] + betaSum);
// update the coefficients for the non-zero topics
cachedCoefficients[topic] = (alpha[topic] + n) / (tokensPerTopic[topic] + betaSum);
double topicTermMass = 0.0;
double[] topicTermScores = new double[numTopics];
int[] topicTermIndices;
int[] topicTermValues;
int i;
double score;
// Iterate over the positions (words) in the document
for (int position = 0; position < docLength; position++) {
type = tokenSequence.getIndexAtPosition(position);
oldTopic = oneDocTopics[position];
currentTypeTopicCounts = typeTopicCounts[type];
assert(currentTypeTopicCounts.get(oldTopic) >= 0);
// Remove this token from all counts.
// Note that we actually want to remove the key if it goes
// to zero, not set it to 0.
if (currentTypeTopicCounts.get(oldTopic) == 1) {
else {
currentTypeTopicCounts.adjustValue(oldTopic, -1);
smoothingOnlyMass -= alpha[oldTopic] * beta /
(tokensPerTopic[oldTopic] + betaSum);
topicBetaMass -= beta * localTopicCounts.get(oldTopic) /
(tokensPerTopic[oldTopic] + betaSum);
if (localTopicCounts.get(oldTopic) == 1) {
else {
localTopicCounts.adjustValue(oldTopic, -1);
smoothingOnlyMass += alpha[oldTopic] * beta /
(tokensPerTopic[oldTopic] + betaSum);
topicBetaMass += beta * localTopicCounts.get(oldTopic) /
(tokensPerTopic[oldTopic] + betaSum);
cachedCoefficients[oldTopic] =
(alpha[oldTopic] + localTopicCounts.get(oldTopic)) /
(tokensPerTopic[oldTopic] + betaSum);
topicTermMass = 0.0;
topicTermIndices = currentTypeTopicCounts.keys();
topicTermValues = currentTypeTopicCounts.getValues();
for (i=0; i < topicTermIndices.length; i++) {
int topic = topicTermIndices[i];
score =
cachedCoefficients[topic] * topicTermValues[i];
// ((alpha[topic] + localTopicCounts.get(topic)) *
// topicTermValues[i]) /
// (tokensPerTopic[topic] + betaSum);
// Note: I tried only doing this next bit if
// score > 0, but it didn't make any difference,
// at least in the first few iterations.
topicTermMass += score;
topicTermScores[i] = score;
// topicTermIndices[i] = topic;
// indicate that this is the last topic
// topicTermIndices[i] = -1;
double sample = random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
double origSample = sample;
// Make sure it actually gets set
newTopic = -1;
if (sample < topicTermMass) {
i = -1;
while (sample > 0) {
sample -= topicTermScores[i];
newTopic = topicTermIndices[i];
else {
sample -= topicTermMass;
if (sample < topicBetaMass) {
sample /= beta;
topicTermIndices = localTopicCounts.keys();
topicTermValues = localTopicCounts.getValues();
for (i=0; i < topicTermIndices.length; i++) {
newTopic = topicTermIndices[i];
sample -= topicTermValues[i] /
(tokensPerTopic[newTopic] + betaSum);
if (sample <= 0.0) {
else {
sample -= topicBetaMass;
sample /= beta;
for (int topic = 0; topic < numTopics; topic++) {
sample -= alpha[topic] /
(tokensPerTopic[topic] + betaSum);
if (sample <= 0.0) {
newTopic = topic;
if (newTopic == -1) {
System.err.println("LDAHyper sampling error: "+ origSample + " " + sample + " " + smoothingOnlyMass + " " +
topicBetaMass + " " + topicTermMass);
newTopic = numTopics-1; // TODO is this appropriate
//throw new IllegalStateException ("LDAHyper: New topic not sampled.");
//assert(newTopic != -1);
// Put that new topic into the counts
oneDocTopics[position] = newTopic;
currentTypeTopicCounts.adjustOrPutValue(newTopic, 1, 1);
smoothingOnlyMass -= alpha[newTopic] * beta /
(tokensPerTopic[newTopic] + betaSum);
topicBetaMass -= beta * localTopicCounts.get(newTopic) /
(tokensPerTopic[newTopic] + betaSum);
localTopicCounts.adjustOrPutValue(newTopic, 1, 1);
// update the coefficients for the non-zero topics
cachedCoefficients[newTopic] =
(alpha[newTopic] + localTopicCounts.get(newTopic)) /
(tokensPerTopic[newTopic] + betaSum);
smoothingOnlyMass += alpha[newTopic] * beta /
(tokensPerTopic[newTopic] + betaSum);
topicBetaMass += beta * localTopicCounts.get(newTopic) /
(tokensPerTopic[newTopic] + betaSum);
assert(currentTypeTopicCounts.get(newTopic) >= 0);
// Clean up our mess: reset the coefficients to values with only
// smoothing. The next doc will update its own non-zero topics...
for (int topic: localTopicCounts.keys()) {
cachedCoefficients[topic] =
alpha[topic] / (tokensPerTopic[topic] + betaSum);
if (shouldSaveState) {
// Update the document-topic count histogram,
// for dirichlet estimation
docLengthCounts[ docLength ]++;
for (int topic: localTopicCounts.keys()) {
topicDocCounts[topic][ localTopicCounts.get(topic) ]++;
public IDSorter[] getSortedTopicWords(int topic) {
IDSorter[] sortedTypes = new IDSorter[ numTypes ];
for (int type = 0; type < numTypes; type++)
sortedTypes[type] = new IDSorter(type, typeTopicCounts[type].get(topic));
return sortedTypes;
public void printTopWords (File file, int numWords, boolean useNewLines) throws IOException {
PrintStream out = new PrintStream (file);
printTopWords(out, numWords, useNewLines);
// TreeSet implementation is ~70x faster than RankedFeatureVector -DM
public void printTopWords (PrintStream out, int numWords, boolean usingNewLines) {
for (int topic = 0; topic < numTopics; topic++) {
TreeSet<IDSorter> sortedWords = new TreeSet<IDSorter>();
for (int type = 0; type < numTypes; type++) {
if (typeTopicCounts[type].containsKey(topic)) {
sortedWords.add(new IDSorter(type, typeTopicCounts[type].get(topic)));
if (usingNewLines) {
out.println ("Topic " + topic);
int word = 1;
Iterator<IDSorter> iterator = sortedWords.iterator();
while (iterator.hasNext() && word < numWords) {
IDSorter info = iterator.next();
out.println(alphabet.lookupObject(info.getID()) + "\t" +
(int) info.getWeight());
else {
out.print (topic + "\t" + formatter.format(alpha[topic]) + "\t" + tokensPerTopic[topic] + "\t");
int word = 1;
Iterator<IDSorter> iterator = sortedWords.iterator();
while (iterator.hasNext() && word < numWords) {
IDSorter info = iterator.next();
out.print(alphabet.lookupObject(info.getID()) + " ");
public void topicXMLReport (PrintWriter out, int numWords) {
out.println("<?xml version='1.0' ?>");
for (int topic = 0; topic < numTopics; topic++) {
out.println(" <topic id='" + topic + "' alpha='" + alpha[topic] +
"' totalTokens='" + tokensPerTopic[topic] + "'>");
TreeSet<IDSorter> sortedWords = new TreeSet<IDSorter>();
for (int type = 0; type < numTypes; type++) {
if (typeTopicCounts[type].containsKey(topic)) {
sortedWords.add(new IDSorter(type, typeTopicCounts[type].get(topic)));
int word = 1;
Iterator<IDSorter> iterator = sortedWords.iterator();
while (iterator.hasNext() && word < numWords) {
IDSorter info = iterator.next();
out.println(" <word rank='" + word + "'>" +
alphabet.lookupObject(info.getID()) +
out.println(" </topic>");
public void topicXMLReportPhrases (PrintStream out, int numWords) {
int numTopics = this.getNumTopics();
gnu.trove.TObjectIntHashMap<String>[] phrases = new gnu.trove.TObjectIntHashMap[numTopics];
Alphabet alphabet = this.getAlphabet();
// Get counts of phrases
for (int ti = 0; ti < numTopics; ti++)
phrases[ti] = new gnu.trove.TObjectIntHashMap<String>();
for (int di = 0; di < this.getData().size(); di++) {
LDAHyper.Topication t = this.getData().get(di);
Instance instance = t.instance;
FeatureSequence fvs = (FeatureSequence) instance.getData();
boolean withBigrams = false;
if (fvs instanceof FeatureSequenceWithBigrams) withBigrams = true;
int prevtopic = -1;
int prevfeature = -1;
int topic = -1;
StringBuffer sb = null;
int feature = -1;
int doclen = fvs.size();
for (int pi = 0; pi < doclen; pi++) {
feature = fvs.getIndexAtPosition(pi);
topic = this.getData().get(di).topicSequence.getIndexAtPosition(pi);
if (topic == prevtopic && (!withBigrams || ((FeatureSequenceWithBigrams)fvs).getBiIndexAtPosition(pi) != -1)) {
if (sb == null)
sb = new StringBuffer (alphabet.lookupObject(prevfeature).toString() + " " + alphabet.lookupObject(feature));
else {
sb.append (" ");
sb.append (alphabet.lookupObject(feature));
} else if (sb != null) {
String sbs = sb.toString();
//System.out.println ("phrase:"+sbs);
if (phrases[prevtopic].get(sbs) == 0)
prevtopic = prevfeature = -1;
sb = null;
} else {
prevtopic = topic;
prevfeature = feature;
// phrases[] now filled with counts
// Now start printing the XML
out.println("<?xml version='1.0' ?>");
double[] probs = new double[alphabet.size()];
for (int ti = 0; ti < numTopics; ti++) {
out.print(" <topic id=\"" + ti + "\" alpha=\"" + alpha[ti] +
"\" totalTokens=\"" + tokensPerTopic[ti] + "\" ");
// For gathering <term> and <phrase> output temporarily
// so that we can get topic-title information before printing it to "out".
ByteArrayOutputStream bout = new ByteArrayOutputStream();
PrintStream pout = new PrintStream (bout);
// For holding candidate topic titles
AugmentableFeatureVector titles = new AugmentableFeatureVector (new Alphabet());
// Print words
for (int type = 0; type < alphabet.size(); type++)
probs[type] = this.getCountFeatureTopic(type, ti) / (double)this.getCountTokensPerTopic(ti);
RankedFeatureVector rfv = new RankedFeatureVector (alphabet, probs);
for (int ri = 0; ri < numWords; ri++) {
int fi = rfv.getIndexAtRank(ri);
pout.println (" <term weight=\""+probs[fi]+"\" count=\""+this.getCountFeatureTopic(fi,ti)+"\">"+alphabet.lookupObject(fi)+ "</term>");
if (ri < 20) // consider top 20 individual words as candidate titles
titles.add(alphabet.lookupObject(fi), this.getCountFeatureTopic(fi,ti));
// Print phrases
Object[] keys = phrases[ti].keys();
int[] values = phrases[ti].getValues();
double counts[] = new double[keys.length];
for (int i = 0; i < counts.length; i++) counts[i] = values[i];
double countssum = MatrixOps.sum (counts);
Alphabet alph = new Alphabet(keys);
rfv = new RankedFeatureVector (alph, counts);
//out.println ("topic "+ti);
int max = rfv.numLocations() < numWords ? rfv.numLocations() : numWords;
//System.out.println ("topic "+ti+" numPhrases="+rfv.numLocations());
for (int ri = 0; ri < max; ri++) {
int fi = rfv.getIndexAtRank(ri);
pout.println (" <phrase weight=\""+counts[fi]/countssum+"\" count=\""+values[fi]+"\">"+alph.lookupObject(fi)+ "</phrase>");
// Any phrase count less than 20 is simply unreliable
if (ri < 20 && values[fi] > 20)
titles.add(alph.lookupObject(fi), 100*values[fi]); // prefer phrases with a factor of 100
// Select candidate titles
StringBuffer titlesStringBuffer = new StringBuffer();
rfv = new RankedFeatureVector (titles.getAlphabet(), titles);
int numTitles = 10;
for (int ri = 0; ri < numTitles && ri < rfv.numLocations(); ri++) {
// Don't add redundant titles
if (titlesStringBuffer.indexOf(rfv.getObjectAtRank(ri).toString()) == -1) {
titlesStringBuffer.append (rfv.getObjectAtRank(ri));
if (ri < numTitles-1)
titlesStringBuffer.append (", ");
} else
out.println("titles=\"" + titlesStringBuffer.toString() + "\">");
out.println(" </topic>");
public void printDocumentTopics (File f) throws IOException {
printDocumentTopics (new PrintWriter (new FileWriter (f) ) );
public void printDocumentTopics (PrintWriter pw) {
printDocumentTopics (pw, 0.0, -1);
* @param pw A print writer
* @param threshold Only print topics with proportion greater than this number
* @param max Print no more than this many topics
public void printDocumentTopics (PrintWriter pw, double threshold, int max) {
pw.print ("#doc source topic proportion ...\n");
int docLen;
int[] topicCounts = new int[ numTopics ];
IDSorter[] sortedTopics = new IDSorter[ numTopics ];
for (int topic = 0; topic < numTopics; topic++) {
// Initialize the sorters with dummy values
sortedTopics[topic] = new IDSorter(topic, topic);
if (max < 0 || max > numTopics) {
max = numTopics;
for (int di = 0; di < data.size(); di++) {
LabelSequence topicSequence = (LabelSequence) data.get(di).topicSequence;
int[] currentDocTopics = topicSequence.getFeatures();
pw.print (di); pw.print (' ');
if (data.get(di).instance.getSource() != null) {
pw.print (data.get(di).instance.getSource());
else {
pw.print ("null-source");
pw.print (' ');
docLen = currentDocTopics.length;
// Count up the tokens
for (int token=0; token < docLen; token++) {
topicCounts[ currentDocTopics[token] ]++;
// And normalize
for (int topic = 0; topic < numTopics; topic++) {
sortedTopics[topic].set(topic, (float) topicCounts[topic] / docLen);
for (int i = 0; i < max; i++) {
if (sortedTopics[i].getWeight() < threshold) { break; }
pw.print (sortedTopics[i].getID() + " " +
sortedTopics[i].getWeight() + " ");
pw.print (" \n");
Arrays.fill(topicCounts, 0);
public void printState (File f) throws IOException {
PrintStream out =
new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
public void printState (PrintStream out) {
out.println ("#doc source pos typeindex type topic");
for (int di = 0; di < data.size(); di++) {
FeatureSequence tokenSequence = (FeatureSequence) data.get(di).instance.getData();
LabelSequence topicSequence = (LabelSequence) data.get(di).topicSequence;
String source = "NA";
if (data.get(di).instance.getSource() != null) {
source = data.get(di).instance.getSource().toString();
for (int pi = 0; pi < topicSequence.getLength(); pi++) {
int type = tokenSequence.getIndexAtPosition(pi);
int topic = topicSequence.getIndexAtPosition(pi);
out.print(di); out.print(' ');
out.print(source); out.print(' ');
out.print(pi); out.print(' ');
out.print(type); out.print(' ');
out.print(alphabet.lookupObject(type)); out.print(' ');
out.print(topic); out.println();
// Turbo topics
// Helper holding corpus-level unigram counts plus a likelihood-ratio test for bigrams.
// NOTE(review): the closing braces of this class and of its members are not present
// in this part of the file, and significantBigrams has no visible body.
private class CorpusWordCounts {
Alphabet unigramAlphabet;
// NOTE(review): field initializers run before the constructor body, so
// unigramAlphabet is still null when it is passed to FeatureCounter here.
// Presumably this should be constructed inside the constructor -- confirm.
FeatureCounter unigramCounts = new FeatureCounter(unigramAlphabet);
// Stores the alphabet; unigramCounts has already been constructed at this point.
public CorpusWordCounts(Alphabet alphabet) {
unigramAlphabet = alphabet;
// Natural log, with log(0) replaced by the large negative constant -1000000.0.
private double mylog(double x) { return (x == 0) ? -1000000.0 : Math.log(x); }
// The likelihood ratio significance test
// Returns -1.0 when nextBigramCount is below minCount; otherwise returns twice a
// difference of count-weighted log terms built from the three log ratios below.
private double significanceTest(int thisUnigramCount, int nextUnigramCount, int nextBigramCount, int nextTotalCount, int minCount) {
if (nextBigramCount < minCount) return -1.0;
assert(nextUnigramCount >= nextBigramCount);
// log( bigram count / count of the first word )
double log_pi_vu = mylog(nextBigramCount) - mylog(thisUnigramCount);
// log( (next-word count - bigram count) / (total count - bigram count) )
double log_pi_vnu = mylog(nextUnigramCount - nextBigramCount) - mylog(nextTotalCount - nextBigramCount);
// log( next-word count / total count )
double log_pi_v_old = mylog(nextUnigramCount) - mylog(nextTotalCount);
// logs of the complements of the first two ratios
double log_1mp_v = mylog(1 - Math.exp(log_pi_vnu));
double log_1mp_vu = mylog(1 - Math.exp(log_pi_vu));
return 2 * (nextBigramCount * log_pi_vu +
(nextUnigramCount - nextBigramCount) * log_pi_vnu -
nextUnigramCount * log_pi_v_old +
(thisUnigramCount- nextBigramCount) * (log_1mp_vu - log_1mp_v));
// NOTE(review): no body is visible for this method.
public int[] significantBigrams(int word) {
public void write (File f) {
try {
ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream(f));
catch (IOException e) {
System.err.println("LDAHyper.write: Exception writing LDAHyper to file " + f + ": " + e);
public static LDAHyper read (File f) {
LDAHyper lda = null;
try {
ObjectInputStream ois = new ObjectInputStream (new FileInputStream(f));
lda = (LDAHyper) ois.readObject();
lda.initializeTypeTopicCounts(); // To work around a bug in Trove?
catch (IOException e) {
System.err.println("Exception reading file " + f + ": " + e);
catch (ClassNotFoundException e) {
System.err.println("Exception reading file " + f + ": " + e);
return lda;
// Serialization
// Class version identifier used by Java serialization.
private static final long serialVersionUID = 1;
// Version number of the custom stream layout. readObject reads a leading int
// into a local named "version" but does not branch on it.
private static final int CURRENT_SERIAL_VERSION = 0;
// NOTE(review): not referenced by writeObject/readObject below; presumably a
// sentinel for an absent int value -- confirm before removing.
private static final int NULL_INTEGER = -1;
private void writeObject (ObjectOutputStream out) throws IOException {
// Instance lists
out.writeObject (data);
out.writeObject (alphabet);
out.writeObject (topicAlphabet);
out.writeInt (numTopics);
out.writeObject (alpha);
out.writeDouble (beta);
out.writeDouble (betaSum);
for (int fi = 0; fi < numTypes; fi++)
out.writeObject (typeTopicCounts[fi]);
for (int ti = 0; ti < numTopics; ti++)
out.writeInt (tokensPerTopic[ti]);
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
int featuresLength;
int version = in.readInt ();
data = (ArrayList<Topication>) in.readObject ();
alphabet = (Alphabet) in.readObject();
topicAlphabet = (LabelAlphabet) in.readObject();
numTopics = in.readInt();
alpha = (double[]) in.readObject();
beta = in.readDouble();
betaSum = in.readDouble();
smoothingOnlyMass = in.readDouble();
cachedCoefficients = (double[]) in.readObject();
iterationsSoFar = in.readInt();
numIterations = in.readInt();
burninPeriod = in.readInt();
saveSampleInterval = in.readInt();
optimizeInterval = in.readInt();
showTopicsInterval = in.readInt();
wordsPerTopic = in.readInt();
outputModelInterval = in.readInt();
outputModelFilename = (String) in.readObject();
saveStateInterval = in.readInt();
stateFilename = (String) in.readObject();
random = (Randoms) in.readObject();
formatter = (NumberFormat) in.readObject();
printLogLikelihood = in.readBoolean();
docLengthCounts = (int[]) in.readObject();
topicDocCounts = (int[][]) in.readObject();
int numDocs = data.size();
this.numTypes = alphabet.size();
typeTopicCounts = new TIntIntHashMap[numTypes];
for (int fi = 0; fi < numTypes; fi++)
typeTopicCounts[fi] = (TIntIntHashMap) in.readObject();
tokensPerTopic = new int[numTopics];
for (int ti = 0; ti < numTopics; ti++)
tokensPerTopic[ti] = in.readInt();
public double topicLabelMutualInformation() {
int doc, level, label, topic, token, type;
int[] docTopics;
if (data.get(0).instance.getTargetAlphabet() == null) {
return 0.0;
int targetAlphabetSize = data.get(0).instance.getTargetAlphabet().size();
int[][] topicLabelCounts = new int[ numTopics ][ targetAlphabetSize ];
int[] topicCounts = new int[ numTopics ];
int[] labelCounts = new int[ targetAlphabetSize ];
int total = 0;
for (doc=0; doc < data.size(); doc++) {
label = data.get(doc).instance.getLabeling().getBestIndex();
LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequence;
docTopics = topicSequence.getFeatures();
for (token = 0; token < docTopics.length; token++) {
topic = docTopics[token];
topicLabelCounts[ topic ][ label ]++;
/* // This block will print out the best topics for each label
IDSorter[] wp = new IDSorter[numTypes];
for (topic = 0; topic < numTopics; topic++) {
for (type = 0; type < numTypes; type++) {
wp[type] = new IDSorter (type, (((double) typeTopicCounts[type][topic]) /
Arrays.sort (wp);
StringBuffer terms = new StringBuffer();
for (int i = 0; i < 8; i++) {
terms.append(" ");
for (label = 0; label < topicLabelCounts[topic].length; label++) {
System.out.println(topicLabelCounts[ topic ][ label ] + "\t" +
double topicEntropy = 0.0;
double labelEntropy = 0.0;
double jointEntropy = 0.0;
double p;
double log2 = Math.log(2);
for (topic = 0; topic < topicCounts.length; topic++) {
if (topicCounts[topic] == 0) { continue; }
p = (double) topicCounts[topic] / total;
topicEntropy -= p * Math.log(p) / log2;
for (label = 0; label < labelCounts.length; label++) {
if (labelCounts[label] == 0) { continue; }
p = (double) labelCounts[label] / total;
labelEntropy -= p * Math.log(p) / log2;
for (topic = 0; topic < topicCounts.length; topic++) {
for (label = 0; label < labelCounts.length; label++) {
if (topicLabelCounts[ topic ][ label ] == 0) { continue; }
p = (double) topicLabelCounts[ topic ][ label ] / total;
jointEntropy -= p * Math.log(p) / log2;
return topicEntropy + labelEntropy - jointEntropy;
public double empiricalLikelihood(int numSamples, InstanceList testing) {
double[][] likelihoods = new double[ testing.size() ][ numSamples ];
double[] multinomial = new double[numTypes];
double[] topicDistribution, currentSample, currentWeights;
Dirichlet topicPrior = new Dirichlet(alpha);
int sample, doc, topic, type, token, seqLen;
FeatureSequence fs;
for (sample = 0; sample < numSamples; sample++) {
topicDistribution = topicPrior.nextDistribution();
Arrays.fill(multinomial, 0.0);
for (topic = 0; topic < numTopics; topic++) {
for (type=0; type<numTypes; type++) {
multinomial[type] +=
topicDistribution[topic] *
(beta + typeTopicCounts[type].get(topic)) /
(betaSum + tokensPerTopic[topic]);
// Convert to log probabilities
for (type=0; type<numTypes; type++) {
assert(multinomial[type] > 0.0);
multinomial[type] = Math.log(multinomial[type]);
for (doc=0; doc<testing.size(); doc++) {
fs = (FeatureSequence) testing.get(doc).getData();
seqLen = fs.getLength();
for (token = 0; token < seqLen; token++) {
type = fs.getIndexAtPosition(token);
// Adding this check since testing instances may
// have types not found in training instances,
// as pointed out by Steven Bethard.
if (type < numTypes) {
likelihoods[doc][sample] += multinomial[type];
double averageLogLikelihood = 0.0;
double logNumSamples = Math.log(numSamples);
for (doc=0; doc<testing.size(); doc++) {
double max = Double.NEGATIVE_INFINITY;
for (sample = 0; sample < numSamples; sample++) {
if (likelihoods[doc][sample] > max) {
max = likelihoods[doc][sample];
double sum = 0.0;
for (sample = 0; sample < numSamples; sample++) {
sum += Math.exp(likelihoods[doc][sample] - max);
averageLogLikelihood += Math.log(sum) + max - logNumSamples;
return averageLogLikelihood;
public double modelLogLikelihood() {
double logLikelihood = 0.0;
int nonZeroTopics;
// The likelihood of the model is a combination of a
// Dirichlet-multinomial for the words in each topic
// and a Dirichlet-multinomial for the topics in each
// document.
// The likelihood function of a dirichlet multinomial is
// Gamma( sum_i alpha_i ) prod_i Gamma( alpha_i + N_i )
// prod_i Gamma( alpha_i ) Gamma( sum_i (alpha_i + N_i) )
// So the log likelihood is
// logGamma ( sum_i alpha_i ) - logGamma ( sum_i (alpha_i + N_i) ) +
// sum_i [ logGamma( alpha_i + N_i) - logGamma( alpha_i ) ]
// Do the documents first
int[] topicCounts = new int[numTopics];
double[] topicLogGammas = new double[numTopics];
int[] docTopics;
for (int topic=0; topic < numTopics; topic++) {
topicLogGammas[ topic ] = Dirichlet.logGammaStirling( alpha[topic] );
for (int doc=0; doc < data.size(); doc++) {
LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequence;
docTopics = topicSequence.getFeatures();
for (int token=0; token < docTopics.length; token++) {
topicCounts[ docTopics[token] ]++;
for (int topic=0; topic < numTopics; topic++) {
if (topicCounts[topic] > 0) {
logLikelihood += (Dirichlet.logGammaStirling(alpha[topic] + topicCounts[topic]) -
topicLogGammas[ topic ]);
// subtract the (count + parameter) sum term
logLikelihood -= Dirichlet.logGammaStirling(alphaSum + docTopics.length);
Arrays.fill(topicCounts, 0);
// add the parameter sum term
logLikelihood += data.size() * Dirichlet.logGammaStirling(alphaSum);
// And the topics
// Count the number of type-topic pairs
int nonZeroTypeTopics = 0;
for (int type=0; type < numTypes; type++) {
int[] usedTopics = typeTopicCounts[type].keys();
for (int topic : usedTopics) {
int count = typeTopicCounts[type].get(topic);
if (count > 0) {
logLikelihood +=
Dirichlet.logGammaStirling(beta + count);
for (int topic=0; topic < numTopics; topic++) {
logLikelihood -=
Dirichlet.logGammaStirling( (beta * numTopics) +
tokensPerTopic[ topic ] );
logLikelihood +=
(Dirichlet.logGammaStirling(beta * numTopics)) -
(Dirichlet.logGammaStirling(beta) * nonZeroTypeTopics);
return logLikelihood;
// Recommended to use mallet/bin/vectors2topics instead.
public static void main (String[] args) throws IOException {
InstanceList training = InstanceList.load (new File(args[0]));
int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;
InstanceList testing =
args.length > 2 ? InstanceList.load (new File(args[2])) : null;
LDAHyper lda = new LDAHyper (numTopics, 50.0, 0.01);
lda.printLogLikelihood = true;