/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.topics;
import java.util.Arrays;
import java.util.List;
import java.util.ArrayList;
import java.util.TreeSet;
import java.util.Iterator;
import java.util.concurrent.*;
import java.util.logging.*;
import java.util.zip.*;
import java.io.*;
import java.text.NumberFormat;
import cc.mallet.types.*;
import cc.mallet.topics.TopicAssignment;
import cc.mallet.util.Randoms;
import cc.mallet.util.MalletLogger;
/**
* Simple parallel threaded implementation of LDA,
* following the approximate distributed LDA approach of
* Newman, Asuncion, Smyth and Welling, "Distributed Inference for
* Latent Dirichlet Allocation" (NIPS 2007), with the SparseLDA
* sampling scheme and data structure from Yao, Mimno and McCallum,
* "Efficient Methods for Topic Model Inference on Streaming
* Document Collections" (KDD 2009).
*
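* <p>A minimal training sketch (the file name and parameter values here are
* illustrative, not defaults of this class):
* <pre>
* InstanceList training = InstanceList.load(new File("training.mallet"));
* ParallelTopicModel model = new ParallelTopicModel(100, 5.0, 0.01);
* model.addInstances(training);
* model.setNumThreads(4);
* model.setNumIterations(1000);
* model.estimate();
* </pre>
*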
* @author David Mimno, Andrew McCallum
*/
public class ParallelTopicModel implements Serializable {
protected static Logger logger = MalletLogger.getLogger(ParallelTopicModel.class.getName());
protected ArrayList<TopicAssignment> data; // the training instances and their topic assignments
protected Alphabet alphabet; // the alphabet for the input data
protected LabelAlphabet topicAlphabet; // the alphabet for the topics
protected int numTopics; // Number of topics to be fit
// These values are used to encode type/topic counts as
// count/topic pairs in a single int.
protected int topicMask;
protected int topicBits;
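// For illustration (values are hypothetical): with 32 topics the constructor sets
// topicMask = 31 (binary 11111) and topicBits = 5, so a count of 7 for topic 3 is
// packed as (7 << 5) + 3 = 227; (227 & topicMask) recovers the topic and
// (227 >> topicBits) recovers the count. With 100 topics the mask is rounded up
// to 127 and topicBits becomes 7.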
protected int numTypes;
protected int totalTokens;
protected double[] alpha; // Dirichlet(alpha[0], alpha[1], ...) is the per-document prior over topics
protected double alphaSum;
protected double beta; // Prior on per-topic multinomial distribution over words
protected double betaSum;
protected boolean usingSymmetricAlpha = false;
public static final double DEFAULT_BETA = 0.01;
protected int[][] typeTopicCounts; // indexed by <feature index, topic index>
protected int[] tokensPerTopic; // indexed by <topic index>
// for dirichlet estimation
protected int[] docLengthCounts; // histogram of document sizes: docLengthCounts[n] is the number of documents with n tokens
protected int[][] topicDocCounts; // histogram of per-document topic counts, indexed by <topic index, count>: topicDocCounts[t][n] is the number of documents in which topic t occurs n times
public int numIterations = 1000;
public int burninPeriod = 200;
public int saveSampleInterval = 10;
public int optimizeInterval = 50;
public int temperingInterval = 0;
public int showTopicsInterval = 50;
public int wordsPerTopic = 7;
protected int saveStateInterval = 0;
protected String stateFilename = null;
protected int saveModelInterval = 0;
protected String modelFilename = null;
protected int randomSeed = -1;
protected NumberFormat formatter;
protected boolean printLogLikelihood = true;
// The number of times each type appears in the corpus
int[] typeTotals;
// The max over typeTotals, used for beta optimization
int maxTypeCount;
int numThreads = 1;
public ParallelTopicModel (int numberOfTopics) {
this (numberOfTopics, numberOfTopics, DEFAULT_BETA);
}
public ParallelTopicModel (int numberOfTopics, double alphaSum, double beta) {
this (newLabelAlphabet (numberOfTopics), alphaSum, beta);
}
private static LabelAlphabet newLabelAlphabet (int numTopics) {
LabelAlphabet ret = new LabelAlphabet();
for (int i = 0; i < numTopics; i++)
ret.lookupIndex("topic"+i);
return ret;
}
public ParallelTopicModel (LabelAlphabet topicAlphabet, double alphaSum, double beta)
{
this.data = new ArrayList<TopicAssignment>();
this.topicAlphabet = topicAlphabet;
this.numTopics = topicAlphabet.size();
if (Integer.bitCount(numTopics) == 1) {
// exact power of 2
topicMask = numTopics - 1;
topicBits = Integer.bitCount(topicMask);
}
else {
// otherwise add an extra bit
topicMask = Integer.highestOneBit(numTopics) * 2 - 1;
topicBits = Integer.bitCount(topicMask);
}
this.alphaSum = alphaSum;
this.alpha = new double[numTopics];
Arrays.fill(alpha, alphaSum / numTopics);
this.beta = beta;
tokensPerTopic = new int[numTopics];
formatter = NumberFormat.getInstance();
formatter.setMaximumFractionDigits(5);
logger.info("Coded LDA: " + numTopics + " topics, " + topicBits + " topic bits, " +
Integer.toBinaryString(topicMask) + " topic mask");
}
public Alphabet getAlphabet() { return alphabet; }
public LabelAlphabet getTopicAlphabet() { return topicAlphabet; }
public int getNumTopics() { return numTopics; }
public ArrayList<TopicAssignment> getData() { return data; }
public void setNumIterations (int numIterations) {
this.numIterations = numIterations;
}
public void setBurninPeriod (int burninPeriod) {
this.burninPeriod = burninPeriod;
}
public void setTopicDisplay(int interval, int n) {
this.showTopicsInterval = interval;
this.wordsPerTopic = n;
}
public void setRandomSeed(int seed) {
randomSeed = seed;
}
/** Interval for optimizing Dirichlet hyperparameters */
public void setOptimizeInterval(int interval) {
this.optimizeInterval = interval;
// Make sure we always have at least one sample
// before optimizing hyperparameters
if (saveSampleInterval > optimizeInterval) {
saveSampleInterval = optimizeInterval;
}
}
public void setSymmetricAlpha(boolean b) {
usingSymmetricAlpha = b;
}
public void setTemperingInterval(int interval) {
temperingInterval = interval;
}
public void setNumThreads(int threads) {
this.numThreads = threads;
}
/** Define how often and where to save a text representation of the current state.
* Files are GZipped.
*
* @param interval Save a copy of the state every <code>interval</code> iterations.
* @param filename Save the state to this file, with the iteration number as a suffix
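*
* <p>For example (the path is illustrative), <code>setSaveState(100, "lda-state.gz")</code>
* writes <code>lda-state.gz.100</code>, <code>lda-state.gz.200</code>, and so on.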
*/
public void setSaveState(int interval, String filename) {
this.saveStateInterval = interval;
this.stateFilename = filename;
}
/** Define how often and where to save a serialized model.
*
* @param interval Save a serialized model every <code>interval</code> iterations.
* @param filename Save to this file, with the iteration number as a suffix
*/
public void setSaveSerializedModel(int interval, String filename) {
this.saveModelInterval = interval;
this.modelFilename = filename;
}
public void addInstances (InstanceList training) {
alphabet = training.getDataAlphabet();
numTypes = alphabet.size();
betaSum = beta * numTypes;
typeTopicCounts = new int[numTypes][];
// Get the total number of occurrences of each word type
//int[] typeTotals = new int[numTypes];
typeTotals = new int[numTypes];
int doc = 0;
for (Instance instance : training) {
doc++;
FeatureSequence tokens = (FeatureSequence) instance.getData();
for (int position = 0; position < tokens.getLength(); position++) {
int type = tokens.getIndexAtPosition(position);
typeTotals[ type ]++;
}
}
maxTypeCount = 0;
// Allocate enough space that the packed count arrays can never overflow:
// a type can co-occur with at most min(numTopics, typeTotals[type]) distinct
// topics (a word that occurs 3 times in the corpus needs at most 3 entries,
// however many topics there are).
for (int type = 0; type < numTypes; type++) {
if (typeTotals[type] > maxTypeCount) { maxTypeCount = typeTotals[type]; }
typeTopicCounts[type] = new int[ Math.min(numTopics, typeTotals[type]) ];
}
doc = 0;
Randoms random = null;
if (randomSeed == -1) {
random = new Randoms();
}
else {
random = new Randoms(randomSeed);
}
for (Instance instance : training) {
doc++;
FeatureSequence tokens = (FeatureSequence) instance.getData();
LabelSequence topicSequence =
new LabelSequence(topicAlphabet, new int[ tokens.size() ]);
int[] topics = topicSequence.getFeatures();
for (int position = 0; position < topics.length; position++) {
int topic = random.nextInt(numTopics);
topics[position] = topic;
}
TopicAssignment t = new TopicAssignment (instance, topicSequence);
data.add (t);
}
buildInitialTypeTopicCounts();
initializeHistograms();
}
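// Restore topic assignments from a gzipped state file previously written by printState().
// Each data line has the form "doc source pos typeindex type topic", for example
// (illustrative values) "0 NA 0 12 model 5"; only the typeindex (fields[3]) and
// topic (fields[5]) columns are used here.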
public void initializeFromState(File stateFile) throws IOException {
String line;
String[] fields;
BufferedReader reader = new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(stateFile))));
line = reader.readLine();
// Skip some lines starting with "#" that describe the format and specify hyperparameters
while (line.startsWith("#")) {
line = reader.readLine();
}
fields = line.split(" ");
for (TopicAssignment document: data) {
FeatureSequence tokens = (FeatureSequence) document.instance.getData();
FeatureSequence topicSequence = (FeatureSequence) document.topicSequence;
int[] topics = topicSequence.getFeatures();
for (int position = 0; position < tokens.size(); position++) {
int type = tokens.getIndexAtPosition(position);
if (type == Integer.parseInt(fields[3])) {
topics[position] = Integer.parseInt(fields[5]);
}
else {
System.err.println("instance list and state do not match: " + line);
throw new IllegalStateException();
}
line = reader.readLine();
if (line != null) {
fields = line.split(" ");
}
}
}
reader.close();
buildInitialTypeTopicCounts();
initializeHistograms();
}
public void buildInitialTypeTopicCounts () {
// Clear the topic totals
Arrays.fill(tokensPerTopic, 0);
// Clear the type/topic counts, only
// looking at the entries before the first 0 entry.
for (int type = 0; type < numTypes; type++) {
int[] topicCounts = typeTopicCounts[type];
int position = 0;
while (position < topicCounts.length &&
topicCounts[position] > 0) {
topicCounts[position] = 0;
position++;
}
}
for (TopicAssignment document : data) {
FeatureSequence tokens = (FeatureSequence) document.instance.getData();
FeatureSequence topicSequence = (FeatureSequence) document.topicSequence;
int[] topics = topicSequence.getFeatures();
for (int position = 0; position < tokens.size(); position++) {
int topic = topics[position];
tokensPerTopic[topic]++;
// The format for these arrays is
// the topic in the rightmost bits
// the count in the remaining (left) bits.
// Since the count is in the high bits, sorting (desc)
// by the numeric value of the int guarantees that
// higher counts will be before the lower counts.
int type = tokens.getIndexAtPosition(position);
int[] currentTypeTopicCounts = typeTopicCounts[ type ];
// Start by assuming that the array is either empty
// or is in sorted (descending) order.
// Here we are only adding counts, so if we find
// an existing location with the topic, we only need
// to ensure that it is not larger than its left neighbor.
int index = 0;
int currentTopic = currentTypeTopicCounts[index] & topicMask;
int currentValue;
while (currentTypeTopicCounts[index] > 0 && currentTopic != topic) {
index++;
if (index == currentTypeTopicCounts.length) {
logger.info("overflow on type " + type);
}
currentTopic = currentTypeTopicCounts[index] & topicMask;
}
currentValue = currentTypeTopicCounts[index] >> topicBits;
if (currentValue == 0) {
// new value is 1, so we don't have to worry about sorting
// (except by topic suffix, which doesn't matter)
currentTypeTopicCounts[index] =
(1 << topicBits) + topic;
}
else {
currentTypeTopicCounts[index] =
((currentValue + 1) << topicBits) + topic;
// Now ensure that the array is still sorted by
// bubbling this value up.
while (index > 0 &&
currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
int temp = currentTypeTopicCounts[index];
currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
currentTypeTopicCounts[index - 1] = temp;
index--;
}
}
}
}
}
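// After this pass each row of typeTopicCounts holds packed (count, topic) entries
// sorted in descending order by count and terminated by a zero entry when the row
// is not full. An illustrative row (not real data):
// { (8 << topicBits) + 1, (3 << topicBits) + 6, (1 << topicBits) + 0, 0, 0 }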
public void sumTypeTopicCounts (WorkerRunnable[] runnables) {
// Clear the topic totals
Arrays.fill(tokensPerTopic, 0);
// Clear the type/topic counts, only
// looking at the entries before the first 0 entry.
for (int type = 0; type < numTypes; type++) {
int[] targetCounts = typeTopicCounts[type];
int position = 0;
while (position < targetCounts.length &&
targetCounts[position] > 0) {
targetCounts[position] = 0;
position++;
}
}
for (int thread = 0; thread < numThreads; thread++) {
// Handle the total-tokens-per-topic array
int[] sourceTotals = runnables[thread].getTokensPerTopic();
for (int topic = 0; topic < numTopics; topic++) {
tokensPerTopic[topic] += sourceTotals[topic];
}
// Now handle the individual type topic counts
int[][] sourceTypeTopicCounts =
runnables[thread].getTypeTopicCounts();
for (int type = 0; type < numTypes; type++) {
// Here the source is the individual thread counts,
// and the target is the global counts.
int[] sourceCounts = sourceTypeTopicCounts[type];
int[] targetCounts = typeTopicCounts[type];
int sourceIndex = 0;
while (sourceIndex < sourceCounts.length &&
sourceCounts[sourceIndex] > 0) {
int topic = sourceCounts[sourceIndex] & topicMask;
int count = sourceCounts[sourceIndex] >> topicBits;
int targetIndex = 0;
int currentTopic = targetCounts[targetIndex] & topicMask;
int currentCount;
while (targetCounts[targetIndex] > 0 && currentTopic != topic) {
targetIndex++;
if (targetIndex == targetCounts.length) {
logger.info("overflow in merging on type " + type);
}
currentTopic = targetCounts[targetIndex] & topicMask;
}
currentCount = targetCounts[targetIndex] >> topicBits;
targetCounts[targetIndex] =
((currentCount + count) << topicBits) + topic;
// Now ensure that the array is still sorted by
// bubbling this value up.
while (targetIndex > 0 &&
targetCounts[targetIndex] > targetCounts[targetIndex - 1]) {
int temp = targetCounts[targetIndex];
targetCounts[targetIndex] = targetCounts[targetIndex - 1];
targetCounts[targetIndex - 1] = temp;
targetIndex--;
}
sourceIndex++;
}
}
}
/* // Debugging code to ensure counts are being
// reconstructed correctly.
for (int type = 0; type < numTypes; type++) {
int[] targetCounts = typeTopicCounts[type];
int index = 0;
int count = 0;
while (index < targetCounts.length &&
targetCounts[index] > 0) {
count += targetCounts[index] >> topicBits;
index++;
}
if (count != typeTotals[type]) {
System.err.println("Expected " + typeTotals[type] + ", found " + count);
}
}
*/
}
/**
* Gather statistics on the size of documents
* and create histograms for use in Dirichlet hyperparameter
* optimization.
*/
private void initializeHistograms() {
int maxTokens = 0;
totalTokens = 0;
int seqLen;
for (int doc = 0; doc < data.size(); doc++) {
FeatureSequence fs = (FeatureSequence) data.get(doc).instance.getData();
seqLen = fs.getLength();
if (seqLen > maxTokens)
maxTokens = seqLen;
totalTokens += seqLen;
}
logger.info("max tokens: " + maxTokens);
logger.info("total tokens: " + totalTokens);
docLengthCounts = new int[maxTokens + 1];
topicDocCounts = new int[numTopics][maxTokens + 1];
}
public void optimizeAlpha(WorkerRunnable[] runnables) {
// First clear the sufficient statistic histograms
Arrays.fill(docLengthCounts, 0);
for (int topic = 0; topic < topicDocCounts.length; topic++) {
Arrays.fill(topicDocCounts[topic], 0);
}
for (int thread = 0; thread < numThreads; thread++) {
int[] sourceLengthCounts = runnables[thread].getDocLengthCounts();
int[][] sourceTopicCounts = runnables[thread].getTopicDocCounts();
for (int count=0; count < sourceLengthCounts.length; count++) {
if (sourceLengthCounts[count] > 0) {
docLengthCounts[count] += sourceLengthCounts[count];
sourceLengthCounts[count] = 0;
}
}
for (int topic=0; topic < numTopics; topic++) {
if (! usingSymmetricAlpha) {
for (int count=0; count < sourceTopicCounts[topic].length; count++) {
if (sourceTopicCounts[topic][count] > 0) {
topicDocCounts[topic][count] += sourceTopicCounts[topic][count];
sourceTopicCounts[topic][count] = 0;
}
}
}
else {
// For the symmetric version, we only need one
// count array, which I'm putting in the same
// data structure, but for topic 0. All other
// topic histograms will be empty.
// I'm duplicating this for loop, which
// isn't the best thing, but it means only checking
// whether we are symmetric or not numTopics times,
// instead of numTopics * longest document length.
for (int count=0; count < sourceTopicCounts[topic].length; count++) {
if (sourceTopicCounts[topic][count] > 0) {
topicDocCounts[0][count] += sourceTopicCounts[topic][count];
// ^ the only change
sourceTopicCounts[topic][count] = 0;
}
}
}
}
}
if (usingSymmetricAlpha) {
alphaSum = Dirichlet.learnSymmetricConcentration(topicDocCounts[0],
docLengthCounts,
numTopics,
alphaSum);
for (int topic = 0; topic < numTopics; topic++) {
alpha[topic] = alphaSum / numTopics;
}
}
else {
alphaSum = Dirichlet.learnParameters(alpha, topicDocCounts, docLengthCounts, 1.001, 1.0, 1);
}
}
public void temperAlpha(WorkerRunnable[] runnables) {
// First clear the sufficient statistic histograms
Arrays.fill(docLengthCounts, 0);
for (int topic = 0; topic < topicDocCounts.length; topic++) {
Arrays.fill(topicDocCounts[topic], 0);
}
for (int thread = 0; thread < numThreads; thread++) {
int[] sourceLengthCounts = runnables[thread].getDocLengthCounts();
int[][] sourceTopicCounts = runnables[thread].getTopicDocCounts();
for (int count=0; count < sourceLengthCounts.length; count++) {
if (sourceLengthCounts[count] > 0) {
sourceLengthCounts[count] = 0;
}
}
for (int topic=0; topic < numTopics; topic++) {
for (int count=0; count < sourceTopicCounts[topic].length; count++) {
if (sourceTopicCounts[topic][count] > 0) {
sourceTopicCounts[topic][count] = 0;
}
}
}
}
for (int topic = 0; topic < numTopics; topic++) {
alpha[topic] = 1.0;
}
alphaSum = numTopics;
}
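// Optimize the symmetric beta hyperparameter. The two histograms built below are the
// sufficient statistics for Dirichlet.learnSymmetricConcentration: countHistogram[c]
// counts (type, topic) pairs that currently hold exactly c tokens, and
// topicSizeHistogram[n] counts topics with exactly n tokens in total.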
public void optimizeBeta(WorkerRunnable[] runnables) {
// The histogram is indexed from count 0, so if all of the
// tokens of the most frequent type were assigned to one topic
// we would need maxTypeCount + 1 entries.
int[] countHistogram = new int[maxTypeCount + 1];
// Now count the number of type/topic pairs that have
// each number of tokens.
int index;
for (int type = 0; type < numTypes; type++) {
int[] counts = typeTopicCounts[type];
index = 0;
while (index < counts.length &&
counts[index] > 0) {
int count = counts[index] >> topicBits;
countHistogram[count]++;
index++;
}
}
// Figure out how large we need to make the "observation lengths"
// histogram.
int maxTopicSize = 0;
for (int topic = 0; topic < numTopics; topic++) {
if (tokensPerTopic[topic] > maxTopicSize) {
maxTopicSize = tokensPerTopic[topic];
}
}
// Now allocate it and populate it.
int[] topicSizeHistogram = new int[maxTopicSize + 1];
for (int topic = 0; topic < numTopics; topic++) {
topicSizeHistogram[ tokensPerTopic[topic] ]++;
}
betaSum = Dirichlet.learnSymmetricConcentration(countHistogram,
topicSizeHistogram,
numTypes,
betaSum);
beta = betaSum / numTypes;
logger.info("[beta: " + formatter.format(beta) + "] ");
// Now publish the new value
for (int thread = 0; thread < numThreads; thread++) {
runnables[thread].resetBeta(beta, betaSum);
}
}
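// Run the Gibbs sampler. In each iteration the worker threads sample their share of
// documents against thread-local copies of the count arrays (when more than one thread
// is used; a single thread works on the global arrays directly), sumTypeTopicCounts()
// merges the thread-local counts back into the global arrays, and the merged counts are
// copied back out to every thread before the next iteration. Hyperparameters are
// re-estimated every optimizeInterval iterations once the burn-in period has passed.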
public void estimate () throws IOException {
long startTime = System.currentTimeMillis();
WorkerRunnable[] runnables = new WorkerRunnable[numThreads];
int docsPerThread = data.size() / numThreads;
int offset = 0;
if (numThreads > 1) {
for (int thread = 0; thread < numThreads; thread++) {
int[] runnableTotals = new int[numTopics];
System.arraycopy(tokensPerTopic, 0, runnableTotals, 0, numTopics);
int[][] runnableCounts = new int[numTypes][];
for (int type = 0; type < numTypes; type++) {
int[] counts = new int[typeTopicCounts[type].length];
System.arraycopy(typeTopicCounts[type], 0, counts, 0, counts.length);
runnableCounts[type] = counts;
}
// some docs may be missing at the end due to integer division
if (thread == numThreads - 1) {
docsPerThread = data.size() - offset;
}
Randoms random = null;
if (randomSeed == -1) {
random = new Randoms();
}
else {
random = new Randoms(randomSeed);
}
runnables[thread] = new WorkerRunnable(numTopics,
alpha, alphaSum, beta,
random, data,
runnableCounts, runnableTotals,
offset, docsPerThread);
runnables[thread].initializeAlphaStatistics(docLengthCounts.length);
offset += docsPerThread;
}
}
else {
// If there is only one thread, copy the typeTopicCounts
// arrays directly, rather than allocating new memory.
Randoms random = null;
if (randomSeed == -1) {
random = new Randoms();
}
else {
random = new Randoms(randomSeed);
}
runnables[0] = new WorkerRunnable(numTopics,
alpha, alphaSum, beta,
random, data,
typeTopicCounts, tokensPerTopic,
offset, docsPerThread);
runnables[0].initializeAlphaStatistics(docLengthCounts.length);
// If there is only one thread, we
// can avoid communications overhead.
// This switch informs the thread that it is working directly on the
// global count arrays, so it does not need to rebuild its local
// type/topic statistics after each iteration.
runnables[0].makeOnlyThread();
}
ExecutorService executor = Executors.newFixedThreadPool(numThreads);
for (int iteration = 1; iteration <= numIterations; iteration++) {
long iterationStart = System.currentTimeMillis();
if (showTopicsInterval != 0 && iteration != 0 && iteration % showTopicsInterval == 0) {
logger.info("\n" + displayTopWords (wordsPerTopic, false));
}
if (saveStateInterval != 0 && iteration % saveStateInterval == 0) {
this.printState(new File(stateFilename + '.' + iteration));
}
if (saveModelInterval != 0 && iteration % saveModelInterval == 0) {
this.write(new File(modelFilename + '.' + iteration));
}
if (numThreads > 1) {
// Submit runnables to thread pool
for (int thread = 0; thread < numThreads; thread++) {
if (iteration > burninPeriod && optimizeInterval != 0 &&
iteration % saveSampleInterval == 0) {
runnables[thread].collectAlphaStatistics();
}
logger.fine("submitting thread " + thread);
executor.submit(runnables[thread]);
//runnables[thread].run();
}
// I'm getting some problems that look like
// a thread hasn't started yet when it is first
// polled, so it appears to be finished.
// This only occurs in very short corpora.
try {
Thread.sleep(20);
} catch (InterruptedException e) {
}
boolean finished = false;
while (! finished) {
try {
Thread.sleep(10);
} catch (InterruptedException e) {
}
finished = true;
// Are all the threads done?
for (int thread = 0; thread < numThreads; thread++) {
//logger.info("thread " + thread + " done? " + runnables[thread].isFinished);
finished = finished && runnables[thread].isFinished;
}
}
//System.out.print("[" + (System.currentTimeMillis() - iterationStart) + "] ");
sumTypeTopicCounts(runnables);
//System.out.print("[" + (System.currentTimeMillis() - iterationStart) + "] ");
for (int thread = 0; thread < numThreads; thread++) {
int[] runnableTotals = runnables[thread].getTokensPerTopic();
System.arraycopy(tokensPerTopic, 0, runnableTotals, 0, numTopics);
int[][] runnableCounts = runnables[thread].getTypeTopicCounts();
for (int type = 0; type < numTypes; type++) {
int[] targetCounts = runnableCounts[type];
int[] sourceCounts = typeTopicCounts[type];
int index = 0;
while (index < sourceCounts.length) {
if (sourceCounts[index] != 0) {
targetCounts[index] = sourceCounts[index];
}
else if (targetCounts[index] != 0) {
targetCounts[index] = 0;
}
else {
break;
}
index++;
}
//System.arraycopy(typeTopicCounts[type], 0, counts, 0, counts.length);
}
}
}
else {
if (iteration > burninPeriod && optimizeInterval != 0 &&
iteration % saveSampleInterval == 0) {
runnables[0].collectAlphaStatistics();
}
runnables[0].run();
}
long elapsedMillis = System.currentTimeMillis() - iterationStart;
if (elapsedMillis < 1000) {
logger.fine(elapsedMillis + "ms ");
}
else {
logger.fine((elapsedMillis/1000) + "s ");
}
if (iteration > burninPeriod && optimizeInterval != 0 &&
iteration % optimizeInterval == 0) {
optimizeAlpha(runnables);
optimizeBeta(runnables);
logger.fine("[O " + (System.currentTimeMillis() - iterationStart) + "] ");
}
if (iteration % 10 == 0) {
if (printLogLikelihood) {
logger.info ("<" + iteration + "> LL/token: " + formatter.format(modelLogLikelihood() / totalTokens));
}
else {
logger.info ("<" + iteration + ">");
}
}
}
executor.shutdownNow();
long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
long minutes = seconds / 60; seconds %= 60;
long hours = minutes / 60; minutes %= 60;
long days = hours / 24; hours %= 24;
StringBuilder timeReport = new StringBuilder();
timeReport.append("\nTotal time: ");
if (days != 0) { timeReport.append(days); timeReport.append(" days "); }
if (hours != 0) { timeReport.append(hours); timeReport.append(" hours "); }
if (minutes != 0) { timeReport.append(minutes); timeReport.append(" minutes "); }
timeReport.append(seconds); timeReport.append(" seconds");
logger.info(timeReport.toString());
}
public void printTopWords (File file, int numWords, boolean useNewLines) throws IOException {
PrintStream out = new PrintStream (file);
printTopWords(out, numWords, useNewLines);
out.close();
}
/**
* Return an array of sorted sets (one set per topic). Each set
* contains IDSorter objects with integer keys into the alphabet.
* To get direct access to the Strings, use getTopWords().
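*
* <p>A small usage sketch (assumes a trained model; the topic index is illustrative):
* <pre>
* for (Object o : model.getSortedWords()[0]) {
*     IDSorter sorter = (IDSorter) o;
*     System.out.println(model.getAlphabet().lookupObject(sorter.getID())
*                        + "\t" + sorter.getWeight());
* }
* </pre>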
*/
public TreeSet[] getSortedWords () {
TreeSet[] topicSortedWords = new TreeSet[ numTopics ];
// Initialize the tree sets
for (int topic = 0; topic < numTopics; topic++) {
topicSortedWords[topic] = new TreeSet<IDSorter>();
}
// Collect counts
for (int type = 0; type < numTypes; type++) {
int[] topicCounts = typeTopicCounts[type];
int index = 0;
while (index < topicCounts.length &&
topicCounts[index] > 0) {
int topic = topicCounts[index] & topicMask;
int count = topicCounts[index] >> topicBits;
topicSortedWords[topic].add(new IDSorter(type, count));
index++;
}
}
return topicSortedWords;
}
/** Return an array (one element for each topic) of arrays of words, which
* are the most probable words for that topic in descending order. These
* are returned as Objects, but will probably be Strings.
*
* @param numWords The maximum length of each topic's array of words (may be less).
*/
public Object[][] getTopWords(int numWords) {
TreeSet[] topicSortedWords = getSortedWords();
Object[][] result = new Object[ numTopics ][];
for (int topic = 0; topic < numTopics; topic++) {
TreeSet<IDSorter> sortedWords = topicSortedWords[topic];
// How many words should we report? Some topics may have fewer than
// the default number of words with non-zero weight.
int limit = numWords;
if (sortedWords.size() < numWords) { limit = sortedWords.size(); }
result[topic] = new Object[limit];
Iterator<IDSorter> iterator = sortedWords.iterator();
for (int i=0; i < limit; i++) {
IDSorter info = iterator.next();
result[topic][i] = alphabet.lookupObject(info.getID());
}
}
return result;
}
public void printTopWords (PrintStream out, int numWords, boolean usingNewLines) {
out.print(displayTopWords(numWords, usingNewLines));
}
public String displayTopWords (int numWords, boolean usingNewLines) {
StringBuilder out = new StringBuilder();
TreeSet[] topicSortedWords = getSortedWords();
// Print results for each topic
for (int topic = 0; topic < numTopics; topic++) {
TreeSet<IDSorter> sortedWords = topicSortedWords[topic];
int word = 1;
Iterator<IDSorter> iterator = sortedWords.iterator();
if (usingNewLines) {
out.append (topic + "\t" + formatter.format(alpha[topic]) + "\n");
while (iterator.hasNext() && word <= numWords) {
IDSorter info = iterator.next();
out.append(alphabet.lookupObject(info.getID()) + "\t" + formatter.format(info.getWeight()) + "\n");
word++;
}
}
else {
out.append (topic + "\t" + formatter.format(alpha[topic]) + "\t");
while (iterator.hasNext() && word <= numWords) {
IDSorter info = iterator.next();
out.append(alphabet.lookupObject(info.getID()) + " ");
word++;
}
out.append ("\n");
}
}
return out.toString();
}
public void topicXMLReport (PrintWriter out, int numWords) {
TreeSet[] topicSortedWords = getSortedWords();
out.println("<?xml version='1.0' ?>");
out.println("<topicModel>");
for (int topic = 0; topic < numTopics; topic++) {
out.println(" <topic id='" + topic + "' alpha='" + alpha[topic] +
"' totalTokens='" + tokensPerTopic[topic] + "'>");
int word = 1;
Iterator<IDSorter> iterator = topicSortedWords[topic].iterator();
while (iterator.hasNext() && word <= numWords) {
IDSorter info = iterator.next();
out.println(" <word rank='" + word + "'>" +
alphabet.lookupObject(info.getID()) +
"</word>");
word++;
}
out.println(" </topic>");
}
out.println("</topicModel>");
}
public void topicPhraseXMLReport(PrintWriter out, int numWords) {
int numTopics = this.getNumTopics();
gnu.trove.TObjectIntHashMap<String>[] phrases = new gnu.trove.TObjectIntHashMap[numTopics];
Alphabet alphabet = this.getAlphabet();
// Get counts of phrases
for (int ti = 0; ti < numTopics; ti++)
phrases[ti] = new gnu.trove.TObjectIntHashMap<String>();
for (int di = 0; di < this.getData().size(); di++) {
TopicAssignment t = this.getData().get(di);
Instance instance = t.instance;
FeatureSequence fvs = (FeatureSequence) instance.getData();
boolean withBigrams = false;
if (fvs instanceof FeatureSequenceWithBigrams) withBigrams = true;
int prevtopic = -1;
int prevfeature = -1;
int topic = -1;
StringBuffer sb = null;
int feature = -1;
int doclen = fvs.size();
for (int pi = 0; pi < doclen; pi++) {
feature = fvs.getIndexAtPosition(pi);
topic = this.getData().get(di).topicSequence.getIndexAtPosition(pi);
if (topic == prevtopic && (!withBigrams || ((FeatureSequenceWithBigrams)fvs).getBiIndexAtPosition(pi) != -1)) {
if (sb == null)
sb = new StringBuffer (alphabet.lookupObject(prevfeature).toString() + " " + alphabet.lookupObject(feature));
else {
sb.append (" ");
sb.append (alphabet.lookupObject(feature));
}
} else if (sb != null) {
String sbs = sb.toString();
//logger.info ("phrase:"+sbs);
if (phrases[prevtopic].get(sbs) == 0)
phrases[prevtopic].put(sbs,0);
phrases[prevtopic].increment(sbs);
prevtopic = prevfeature = -1;
sb = null;
} else {
prevtopic = topic;
prevfeature = feature;
}
}
}
// phrases[] now filled with counts
// Now start printing the XML
out.println("<?xml version='1.0' ?>");
out.println("<topics>");
TreeSet[] topicSortedWords = getSortedWords();
double[] probs = new double[alphabet.size()];
for (int ti = 0; ti < numTopics; ti++) {
out.print(" <topic id=\"" + ti + "\" alpha=\"" + alpha[ti] +
"\" totalTokens=\"" + tokensPerTopic[ti] + "\" ");
// For gathering <term> and <phrase> output temporarily
// so that we can get topic-title information before printing it to "out".
ByteArrayOutputStream bout = new ByteArrayOutputStream();
PrintStream pout = new PrintStream (bout);
// For holding candidate topic titles
AugmentableFeatureVector titles = new AugmentableFeatureVector (new Alphabet());
// Print words
int word = 1;
Iterator<IDSorter> iterator = topicSortedWords[ti].iterator();
while (iterator.hasNext() && word <= numWords) {
IDSorter info = iterator.next();
pout.println(" <word weight=\""+(info.getWeight()/tokensPerTopic[ti])+"\" count=\""+Math.round(info.getWeight())+"\">"
+ alphabet.lookupObject(info.getID()) +
"</word>");
word++;
if (word < 20) // consider top 20 individual words as candidate titles
titles.add(alphabet.lookupObject(info.getID()), info.getWeight());
}
/*
for (int type = 0; type < alphabet.size(); type++)
probs[type] = this.getCountFeatureTopic(type, ti) / (double)this.getCountTokensPerTopic(ti);
RankedFeatureVector rfv = new RankedFeatureVector (alphabet, probs);
for (int ri = 0; ri < numWords; ri++) {
int fi = rfv.getIndexAtRank(ri);
pout.println (" <term weight=\""+probs[fi]+"\" count=\""+this.getCountFeatureTopic(fi,ti)+"\">"+alphabet.lookupObject(fi)+ "</term>");
if (ri < 20) // consider top 20 individual words as candidate titles
titles.add(alphabet.lookupObject(fi), this.getCountFeatureTopic(fi,ti));
}
*/
// Print phrases
Object[] keys = phrases[ti].keys();
int[] values = phrases[ti].getValues();
double counts[] = new double[keys.length];
for (int i = 0; i < counts.length; i++) counts[i] = values[i];
double countssum = MatrixOps.sum (counts);
Alphabet alph = new Alphabet(keys);
RankedFeatureVector rfv = new RankedFeatureVector (alph, counts);
int max = rfv.numLocations() < numWords ? rfv.numLocations() : numWords;
for (int ri = 0; ri < max; ri++) {
int fi = rfv.getIndexAtRank(ri);
pout.println (" <phrase weight=\""+counts[fi]/countssum+"\" count=\""+values[fi]+"\">"+alph.lookupObject(fi)+ "</phrase>");
// Any phrase count of 20 or fewer is considered unreliable
if (ri < 20 && values[fi] > 20)
titles.add(alph.lookupObject(fi), 100*values[fi]); // prefer phrases with a factor of 100
}
// Select candidate titles
StringBuffer titlesStringBuffer = new StringBuffer();
rfv = new RankedFeatureVector (titles.getAlphabet(), titles);
int numTitles = 10;
for (int ri = 0; ri < numTitles && ri < rfv.numLocations(); ri++) {
// Don't add redundant titles
if (titlesStringBuffer.indexOf(rfv.getObjectAtRank(ri).toString()) == -1) {
titlesStringBuffer.append (rfv.getObjectAtRank(ri));
if (ri < numTitles-1)
titlesStringBuffer.append (", ");
} else
numTitles++;
}
out.println("titles=\"" + titlesStringBuffer.toString() + "\">");
out.print(bout.toString());
out.println(" </topic>");
}
out.println("</topics>");
}
/**
* Write the internal representation of type-topic counts
* (count/topic pairs in descending order by count) to a file.
*/
public void printTypeTopicCounts(File file) throws IOException {
PrintWriter out = new PrintWriter (new FileWriter (file) );
for (int type = 0; type < numTypes; type++) {
StringBuilder buffer = new StringBuilder();
buffer.append(type + " " + alphabet.lookupObject(type));
int[] topicCounts = typeTopicCounts[type];
int index = 0;
while (index < topicCounts.length &&
topicCounts[index] > 0) {
int topic = topicCounts[index] & topicMask;
int count = topicCounts[index] >> topicBits;
buffer.append(" " + topic + ":" + count);
index++;
}
out.println(buffer);
}
out.close();
}
public void printTopicWordWeights(File file) throws IOException {
PrintWriter out = new PrintWriter (new FileWriter (file) );
printTopicWordWeights(out);
out.close();
}
/**
* Print an unnormalized weight for every word in every topic.
* Most of these will be equal to the smoothing parameter beta.
*/
public void printTopicWordWeights(PrintWriter out) throws IOException {
// Probably not the most efficient way to do this...
for (int topic = 0; topic < numTopics; topic++) {
for (int type = 0; type < numTypes; type++) {
int[] topicCounts = typeTopicCounts[type];
double weight = beta;
int index = 0;
while (index < topicCounts.length &&
topicCounts[index] > 0) {
int currentTopic = topicCounts[index] & topicMask;
if (currentTopic == topic) {
weight += topicCounts[index] >> topicBits;
break;
}
index++;
}
out.println(topic + "\t" + alphabet.lookupObject(type) + "\t" + weight);
}
}
}
public void printDocumentTopics (File file) throws IOException {
PrintWriter out = new PrintWriter (new FileWriter (file) );
printDocumentTopics (out);
out.close();
}
public void printDocumentTopics (PrintWriter out) {
printDocumentTopics (out, 0.0, -1);
}
/**
* @param out A print writer
* @param threshold Only print topics with proportion greater than this number
* @param max Print no more than this many topics
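*
* <p>Each document produces one line of the form (values are illustrative):
* <code>0 /path/to/doc.txt 5 0.41 12 0.18 ...</code>, i.e. the document index,
* its source, and then topic/proportion pairs in descending order of proportion.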
*/
public void printDocumentTopics (PrintWriter out, double threshold, int max) {
out.print ("#doc source topic proportion ...\n");
int docLen;
int[] topicCounts = new int[ numTopics ];
IDSorter[] sortedTopics = new IDSorter[ numTopics ];
for (int topic = 0; topic < numTopics; topic++) {
// Initialize the sorters with dummy values
sortedTopics[topic] = new IDSorter(topic, topic);
}
if (max < 0 || max > numTopics) {
max = numTopics;
}
for (int doc = 0; doc < data.size(); doc++) {
LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequence;
int[] currentDocTopics = topicSequence.getFeatures();
out.print (doc); out.print (' ');
if (data.get(doc).instance.getSource() != null) {
out.print (data.get(doc).instance.getSource());
}
else {
out.print ("null-source");
}
out.print (' ');
docLen = currentDocTopics.length;
// Count up the tokens
for (int token=0; token < docLen; token++) {
topicCounts[ currentDocTopics[token] ]++;
}
// And normalize
for (int topic = 0; topic < numTopics; topic++) {
sortedTopics[topic].set(topic, (float) topicCounts[topic] / docLen);
}
Arrays.sort(sortedTopics);
for (int i = 0; i < max; i++) {
if (sortedTopics[i].getWeight() < threshold) { break; }
out.print (sortedTopics[i].getID() + " " +
sortedTopics[i].getWeight() + " ");
}
out.print (" \n");
Arrays.fill(topicCounts, 0);
}
}
public void printState (File f) throws IOException {
PrintStream out =
new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
printState(out);
out.close();
}
public void printState (PrintStream out) {
out.println ("#doc source pos typeindex type topic");
out.print("#alpha : ");
for (int topic = 0; topic < numTopics; topic++) {
out.print(alpha[topic] + " ");
}
out.println();
out.println("#beta : " + beta);
for (int doc = 0; doc < data.size(); doc++) {
FeatureSequence tokenSequence = (FeatureSequence) data.get(doc).instance.getData();
LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequence;
String source = "NA";
if (data.get(doc).instance.getSource() != null) {
source = data.get(doc).instance.getSource().toString();
}
for (int pi = 0; pi < topicSequence.getLength(); pi++) {
int type = tokenSequence.getIndexAtPosition(pi);
int topic = topicSequence.getIndexAtPosition(pi);
out.print(doc); out.print(' ');
out.print(source); out.print(' ');
out.print(pi); out.print(' ');
out.print(type); out.print(' ');
out.print(alphabet.lookupObject(type)); out.print(' ');
out.print(topic); out.println();
}
}
}
public double modelLogLikelihood() {
double logLikelihood = 0.0;
// The likelihood of the model is a combination of a
// Dirichlet-multinomial for the words in each topic
// and a Dirichlet-multinomial for the topics in each
// document.
// The likelihood function of a Dirichlet-multinomial is
//
//    Gamma( sum_i alpha_i )       prod_i Gamma( alpha_i + N_i )
//   ------------------------  *  -------------------------------
//   prod_i Gamma( alpha_i )      Gamma( sum_i (alpha_i + N_i) )
//
// So the log likelihood is
//   logGamma( sum_i alpha_i ) - logGamma( sum_i (alpha_i + N_i) ) +
//   sum_i [ logGamma( alpha_i + N_i ) - logGamma( alpha_i ) ]
// Do the documents first
int[] topicCounts = new int[numTopics];
double[] topicLogGammas = new double[numTopics];
int[] docTopics;
for (int topic=0; topic < numTopics; topic++) {
topicLogGammas[ topic ] = Dirichlet.logGammaStirling( alpha[topic] );
}
for (int doc=0; doc < data.size(); doc++) {
LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequence;
docTopics = topicSequence.getFeatures();
for (int token=0; token < docTopics.length; token++) {
topicCounts[ docTopics[token] ]++;
}
for (int topic=0; topic < numTopics; topic++) {
if (topicCounts[topic] > 0) {
logLikelihood += (Dirichlet.logGammaStirling(alpha[topic] + topicCounts[topic]) -
topicLogGammas[ topic ]);
}
}
// subtract the (count + parameter) sum term
logLikelihood -= Dirichlet.logGammaStirling(alphaSum + docTopics.length);
Arrays.fill(topicCounts, 0);
}
// add the parameter sum term
logLikelihood += data.size() * Dirichlet.logGammaStirling(alphaSum);
// And the topics
// Count the number of type-topic pairs
int nonZeroTypeTopics = 0;
for (int type=0; type < numTypes; type++) {
// reuse this array as a pointer
topicCounts = typeTopicCounts[type];
int index = 0;
while (index < topicCounts.length &&
topicCounts[index] > 0) {
int topic = topicCounts[index] & topicMask;
int count = topicCounts[index] >> topicBits;
nonZeroTypeTopics++;
logLikelihood += Dirichlet.logGammaStirling(beta + count);
if (Double.isNaN(logLikelihood)) {
System.err.println(count);
System.exit(1);
}
index++;
}
}
for (int topic=0; topic < numTopics; topic++) {
logLikelihood -=
Dirichlet.logGammaStirling( (beta * numTypes) +
tokensPerTopic[ topic ] );
if (Double.isNaN(logLikelihood)) {
logger.info("after topic " + topic + " " + tokensPerTopic[ topic ]);
System.exit(1);
}
}
logLikelihood +=
(Dirichlet.logGammaStirling(beta * numTypes)) -
(Dirichlet.logGammaStirling(beta) * nonZeroTypeTopics);
if (Double.isNaN(logLikelihood)) {
logger.info("at the end");
System.exit(1);
}
return logLikelihood;
}
/** Return a tool for estimating topic distributions for new documents */
public TopicInferencer getInferencer() {
return new TopicInferencer(typeTopicCounts, tokensPerTopic,
data.get(0).instance.getDataAlphabet(),
alpha, beta, betaSum);
}
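// A minimal inference sketch (the instance and the sampling parameters are illustrative;
// getSampledDistribution takes numIterations, thinning, and burn-in arguments):
//   TopicInferencer inferencer = model.getInferencer();
//   double[] topicDistribution = inferencer.getSampledDistribution(testInstance, 100, 10, 10);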
/** Return a tool for evaluating the marginal probability of new documents
* under this model */
public MarginalProbEstimator getProbEstimator() {
return new MarginalProbEstimator(numTopics, alpha, alphaSum, beta,
typeTopicCounts, tokensPerTopic);
}
// Serialization
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 0;
private static final int NULL_INTEGER = -1;
private void writeObject (ObjectOutputStream out) throws IOException {
out.writeInt (CURRENT_SERIAL_VERSION);
out.writeObject(data);
out.writeObject(alphabet);
out.writeObject(topicAlphabet);
out.writeInt(numTopics);
out.writeInt(topicMask);
out.writeInt(topicBits);
out.writeInt(numTypes);
out.writeObject(alpha);
out.writeDouble(alphaSum);
out.writeDouble(beta);
out.writeDouble(betaSum);
out.writeObject(typeTopicCounts);
out.writeObject(tokensPerTopic);
out.writeObject(docLengthCounts);
out.writeObject(topicDocCounts);
out.writeInt(numIterations);
out.writeInt(burninPeriod);
out.writeInt(saveSampleInterval);
out.writeInt(optimizeInterval);
out.writeInt(showTopicsInterval);
out.writeInt(wordsPerTopic);
out.writeInt(saveStateInterval);
out.writeObject(stateFilename);
out.writeInt(saveModelInterval);
out.writeObject(modelFilename);
out.writeInt(randomSeed);
out.writeObject(formatter);
out.writeBoolean(printLogLikelihood);
out.writeInt(numThreads);
}
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
int version = in.readInt ();
data = (ArrayList<TopicAssignment>) in.readObject ();
alphabet = (Alphabet) in.readObject();
topicAlphabet = (LabelAlphabet) in.readObject();
numTopics = in.readInt();
topicMask = in.readInt();
topicBits = in.readInt();
numTypes = in.readInt();
alpha = (double[]) in.readObject();
alphaSum = in.readDouble();
beta = in.readDouble();
betaSum = in.readDouble();
typeTopicCounts = (int[][]) in.readObject();
tokensPerTopic = (int[]) in.readObject();
docLengthCounts = (int[]) in.readObject();
topicDocCounts = (int[][]) in.readObject();
numIterations = in.readInt();
burninPeriod = in.readInt();
saveSampleInterval = in.readInt();
optimizeInterval = in.readInt();
showTopicsInterval = in.readInt();
wordsPerTopic = in.readInt();
saveStateInterval = in.readInt();
stateFilename = (String) in.readObject();
saveModelInterval = in.readInt();
modelFilename = (String) in.readObject();
randomSeed = in.readInt();
formatter = (NumberFormat) in.readObject();
printLogLikelihood = in.readBoolean();
numThreads = in.readInt();
}
public void write (File serializedModelFile) {
try {
ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream(serializedModelFile));
oos.writeObject(this);
oos.close();
} catch (IOException e) {
System.err.println("Problem serializing ParallelTopicModel to file " +
serializedModelFile + ": " + e);
}
}
public static ParallelTopicModel read (File f) throws Exception {
ParallelTopicModel topicModel = null;
ObjectInputStream ois = new ObjectInputStream (new FileInputStream(f));
topicModel = (ParallelTopicModel) ois.readObject();
ois.close();
return topicModel;
}
public static void main (String[] args) {
try {
InstanceList training = InstanceList.load (new File(args[0]));
int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;
ParallelTopicModel lda = new ParallelTopicModel (numTopics, 50.0, 0.01);
lda.printLogLikelihood = true;
lda.setTopicDisplay(50, 7);
lda.addInstances(training);
lda.setNumThreads(args.length > 2 ? Integer.parseInt(args[2]) : 1);
lda.estimate();
logger.info("printing state");
lda.printState(new File("state.gz"));
logger.info("finished printing");
} catch (Exception e) {
e.printStackTrace();
}
}
}