/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.topics;
import java.util.*;
import java.util.zip.*;
import java.io.*;
import java.text.NumberFormat;
import cc.mallet.types.*;
import cc.mallet.util.Randoms;
/**
 * Latent Dirichlet Allocation for loosely parallel corpora in arbitrary languages.
 * <p>
 * Each training document is a tuple of {@link Instance}s, one per language.
 * All languages of a document share a single set of document-topic counts,
 * while each language keeps its own topic-word counts and its own symmetric
 * Dirichlet prior ({@code betas[language]}).
 * <p>
 * The per-token sampler splits the (unnormalized) sampling distribution into
 * three buckets: a smoothing-only bucket, a document-topic ("beta") bucket and
 * a topic-word ("topic term") bucket.
 * <p>
 * Type/topic counts are packed into single ints: the topic index occupies the
 * low {@code topicBits} bits and the count the remaining high bits. Each
 * per-type array is kept sorted in descending numeric order, which is also
 * descending order of count; unused slots hold 0 and sit at the end.
 *
 * @author David Mimno, Andrew McCallum
 */
public class PolylingualTopicModel implements Serializable {

    /**
     * The parallel instances of one document (one per language) together with
     * the topic assigned to each of their tokens.
     * Analogous to a cc.mallet.classify.Classification.
     */
    public class TopicAssignment implements Serializable {
        public Instance[] instances;
        public LabelSequence[] topicSequences;
        /** Not populated by this class. */
        public Labeling topicDistribution;

        public TopicAssignment (Instance[] instances, LabelSequence[] topicSequences) {
            this.instances = instances;
            this.topicSequences = topicSequences;
        }
    }

    /** Number of languages; set by {@link #addInstances}. */
    int numLanguages = 1;

    protected ArrayList<TopicAssignment> data; // the training instances and their topic assignments
    protected LabelAlphabet topicAlphabet; // the alphabet for the topics
    /** Intended size of the automatic stoplist; stoplist creation is currently disabled. */
    protected int numStopwords = 0;
    protected int numTopics; // Number of topics to be fit

    /** Names of documents to exclude from training, or null to train on every document. */
    HashSet<String> testingIDs = null;

    // These values are used to encode type/topic counts as
    // count/topic pairs in a single int.
    protected int topicMask;
    protected int topicBits;

    protected Alphabet[] alphabets;
    protected int[] vocabularySizes;

    protected double[] alpha; // Dirichlet(alpha,alpha,...) is the distribution over topics
    protected double alphaSum;
    protected double[] betas; // Per-language prior on per-topic multinomial distribution over words
    protected double[] betaSums; // Per language: betas[language] * vocabularySizes[language]
    protected int[] languageMaxTypeCounts; // Per language: largest training-token count of any single type

    public static final double DEFAULT_BETA = 0.01;

    // Per language: sum over topics of alpha[t] * beta / (tokensPerTopic[t] + betaSum),
    // maintained incrementally by the sampler.
    protected double[] languageSmoothingOnlyMasses;
    // Per <language, topic>: alpha[t] / (tokensPerTopic[t] + betaSum) between documents;
    // temporarily (alpha[t] + n_{t|d}) / (...) while a document is being sampled.
    protected double[][] languageCachedCoefficients;

    // Bucket counters: reset every iteration by estimate(), never incremented.
    int topicTermCount = 0;
    int betaTopicCount = 0;
    int smoothingOnlyCount = 0;

    // Not used by the current implementation.
    protected int[] oneDocTopicCounts;
    protected int[][][] languageTypeTopicCounts; // indexed by <language, type, slot>; slot = (count << topicBits) + topic
    protected int[][] languageTokensPerTopic; // indexed by <language, topic>

    // for dirichlet estimation
    protected int[] docLengthCounts; // histogram of document sizes, summed over languages
    protected int[][] topicDocCounts; // histogram of document/topic counts, indexed by <topic index, count within a document>

    protected int iterationsSoFar = 1;
    public int numIterations = 1000;
    public int burninPeriod = 5;
    public int saveSampleInterval = 5; // was 10;
    public int optimizeInterval = 10;
    public int showTopicsInterval = 10; // was 50;
    public int wordsPerTopic = 7;

    // NOTE: the model-output interval/filename are stored but not currently used;
    // this class has no model-serialization routine that consumes them.
    protected int outputModelInterval = 0;
    protected String outputModelFilename;

    protected int saveStateInterval = 0;
    protected String stateFilename = null;

    protected Randoms random;
    protected NumberFormat formatter;
    protected boolean printLogLikelihood = false;

    /** Creates a model with {@code numberOfTopics} topics and alphaSum = numberOfTopics. */
    public PolylingualTopicModel (int numberOfTopics) {
        this (numberOfTopics, numberOfTopics);
    }

    public PolylingualTopicModel (int numberOfTopics, double alphaSum) {
        this (numberOfTopics, alphaSum, new Randoms());
    }

    /** Builds a label alphabet containing "topic0" .. "topic{numTopics-1}". */
    private static LabelAlphabet newLabelAlphabet (int numTopics) {
        LabelAlphabet ret = new LabelAlphabet();
        for (int i = 0; i < numTopics; i++) {
            ret.lookupIndex("topic"+i);
        }
        return ret;
    }

    public PolylingualTopicModel (int numberOfTopics, double alphaSum, Randoms random) {
        this (newLabelAlphabet (numberOfTopics), alphaSum, random);
    }

    /**
     * @param topicAlphabet one label per topic; its size fixes the number of topics
     * @param alphaSum initial sum of the symmetric document-topic prior
     * @param random source of randomness for initialization and sampling
     */
    public PolylingualTopicModel (LabelAlphabet topicAlphabet, double alphaSum, Randoms random)
    {
        this.data = new ArrayList<TopicAssignment>();
        this.topicAlphabet = topicAlphabet;
        this.numTopics = topicAlphabet.size();

        if (Integer.bitCount(numTopics) == 1) {
            // exact power of 2
            topicMask = numTopics - 1;
            topicBits = Integer.bitCount(topicMask);
        }
        else {
            // otherwise add an extra bit
            topicMask = Integer.highestOneBit(numTopics) * 2 - 1;
            topicBits = Integer.bitCount(topicMask);
        }

        this.alphaSum = alphaSum;
        this.alpha = new double[numTopics];
        Arrays.fill(alpha, alphaSum / numTopics);
        this.random = random;

        formatter = NumberFormat.getInstance();
        formatter.setMaximumFractionDigits(5);

        System.err.println("Polylingual LDA: " + numTopics + " topics, " + topicBits + " topic bits, " +
                           Integer.toBinaryString(topicMask) + " topic mask");
    }

    /**
     * Reads one document name per line (UTF-8). Documents whose language-0
     * instance name is in this set are excluded by {@link #addInstances},
     * so this must be called before addInstances to have any effect.
     *
     * @throws IOException if the file cannot be read
     */
    public void loadTestingIDs(File testingIDFile) throws IOException {
        testingIDs = new HashSet<String>();
        BufferedReader in = new BufferedReader(
            new InputStreamReader(new FileInputStream(testingIDFile), "UTF-8"));
        try {
            String id;
            while ((id = in.readLine()) != null) {
                testingIDs.add(id);
            }
        }
        finally {
            in.close();
        }
    }

    public LabelAlphabet getTopicAlphabet() { return topicAlphabet; }
    public int getNumTopics() { return numTopics; }
    public ArrayList<TopicAssignment> getData() { return data; }

    public void setNumIterations (int numIterations) {
        this.numIterations = numIterations;
    }

    public void setBurninPeriod (int burninPeriod) {
        this.burninPeriod = burninPeriod;
    }

    /**
     * @param interval print top words every {@code interval} iterations (0 disables)
     * @param n number of words to print per topic and language
     */
    public void setTopicDisplay(int interval, int n) {
        this.showTopicsInterval = interval;
        this.wordsPerTopic = n;
    }

    public void setRandomSeed(int seed) {
        random = new Randoms(seed);
    }

    public void setOptimizeInterval(int interval) {
        this.optimizeInterval = interval;
    }

    public void setModelOutput(int interval, String filename) {
        this.outputModelInterval = interval;
        this.outputModelFilename = filename;
    }

    /** Define how often and where to save the state
     *
     * @param interval Save a copy of the state every <code>interval</code> iterations.
     * @param filename Save the state to this file, with the iteration number as a suffix
     */
    public void setSaveState(int interval, String filename) {
        this.saveStateInterval = interval;
        this.stateFilename = filename;
    }

    /**
     * Returns true if document {@code doc} should be excluded from training.
     * The language-0 instance name is authoritative for every language so that
     * vocabulary totals and the stored training data skip exactly the same documents.
     */
    private boolean isTestingDocument (InstanceList pivot, int doc) {
        return testingIDs != null &&
            doc < pivot.size() &&
            testingIDs.contains(pivot.get(doc).getName());
    }

    /**
     * Loads aligned training data and randomly initializes topic assignments.
     * Document {@code d} of every language is treated as a translation of
     * document {@code d} of language 0. Intended to be called once per model:
     * all count arrays are reallocated on each call.
     *
     * @param training one instance list per language; each instance's data must
     *        be a FeatureSequence
     */
    public void addInstances (InstanceList[] training) {
        numLanguages = training.length;

        languageTokensPerTopic = new int[numLanguages][numTopics];

        alphabets = new Alphabet[ numLanguages ];
        vocabularySizes = new int[ numLanguages ];
        betas = new double[ numLanguages ];
        betaSums = new double[ numLanguages ];
        languageMaxTypeCounts = new int[ numLanguages ];
        languageTypeTopicCounts = new int[ numLanguages ][][];

        int numInstances = training[0].size();

        for (int language = 0; language < numLanguages; language++) {

            if (training[language].size() != numInstances) {
                System.err.println("Warning: language " + language + " has " +
                                   training[language].size() + " instances, lang 0 has " +
                                   numInstances);
            }

            alphabets[ language ] = training[ language ].getDataAlphabet();
            vocabularySizes[ language ] = alphabets[ language ].size();

            betas[language] = DEFAULT_BETA;
            betaSums[language] = betas[language] * vocabularySizes[ language ];

            languageTypeTopicCounts[language] = new int[ vocabularySizes[language] ][];
            int[][] typeTopicCounts = languageTypeTopicCounts[language];

            // Get the total number of training occurrences of each word type,
            // skipping exactly the documents the data loop below skips.
            int[] typeTotals = new int[ vocabularySizes[language] ];
            for (int doc = 0; doc < training[language].size(); doc++) {
                if (isTestingDocument(training[0], doc)) {
                    continue;
                }
                FeatureSequence tokens = (FeatureSequence) training[language].get(doc).getData();
                for (int position = 0; position < tokens.getLength(); position++) {
                    typeTotals[ tokens.getIndexAtPosition(position) ]++;
                }
            }

            // Automatic stoplist creation (putting the numStopwords most frequent
            // types into a non-sampled topic, -1) is currently disabled.

            // Allocate enough space so that we never have to worry about
            // overflows: either the number of topics or the number of times
            // the type occurs.
            for (int type = 0; type < vocabularySizes[language]; type++) {
                if (typeTotals[type] > languageMaxTypeCounts[language]) {
                    languageMaxTypeCounts[language] = typeTotals[type];
                }
                typeTopicCounts[type] = new int[ Math.min(numTopics, typeTotals[type]) ];
            }
        }

        for (int doc = 0; doc < numInstances; doc++) {

            if (isTestingDocument(training[0], doc)) {
                continue;
            }

            Instance[] instances = new Instance[ numLanguages ];
            LabelSequence[] topicSequences = new LabelSequence[ numLanguages ];

            for (int language = 0; language < numLanguages; language++) {

                int[][] typeTopicCounts = languageTypeTopicCounts[language];
                int[] tokensPerTopic = languageTokensPerTopic[language];

                instances[language] = training[language].get(doc);
                FeatureSequence tokens = (FeatureSequence) instances[language].getData();
                topicSequences[language] =
                    new LabelSequence(topicAlphabet, new int[ tokens.size() ]);

                int[] topics = topicSequences[language].getFeatures();
                for (int position = 0; position < tokens.size(); position++) {

                    int type = tokens.getIndexAtPosition(position);
                    int[] currentTypeTopicCounts = typeTopicCounts[ type ];

                    int topic = random.nextInt(numTopics);

                    topics[position] = topic;
                    tokensPerTopic[topic]++;

                    // The format for these arrays is the topic in the
                    // rightmost bits and the count in the remaining (left)
                    // bits. Since the count is in the high bits, sorting
                    // (desc) by the numeric value of the int guarantees that
                    // higher counts come before lower counts.

                    // The array is either empty or sorted (descending). We
                    // are only adding counts, so if we find an existing slot
                    // for the topic we only need to bubble it left.
                    // The allocation above guarantees a matching slot or an
                    // empty one exists.
                    int index = 0;
                    int currentTopic = currentTypeTopicCounts[index] & topicMask;
                    while (currentTypeTopicCounts[index] > 0 && currentTopic != topic) {
                        index++;
                        currentTopic = currentTypeTopicCounts[index] & topicMask;
                    }

                    int currentValue = currentTypeTopicCounts[index] >> topicBits;

                    if (currentValue == 0) {
                        // new value is 1, so we don't have to worry about sorting
                        // (except by topic suffix, which doesn't matter)
                        currentTypeTopicCounts[index] =
                            (1 << topicBits) + topic;
                    }
                    else {
                        currentTypeTopicCounts[index] =
                            ((currentValue + 1) << topicBits) + topic;

                        // Now ensure that the array is still sorted by
                        // bubbling this value up.
                        while (index > 0 &&
                               currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
                            int temp = currentTypeTopicCounts[index];
                            currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
                            currentTypeTopicCounts[index - 1] = temp;

                            index--;
                        }
                    }
                }
            }

            TopicAssignment t = new TopicAssignment (instances, topicSequences);
            data.add (t);
        }

        initializeHistograms();

        languageSmoothingOnlyMasses = new double[ numLanguages ];
        languageCachedCoefficients = new double[ numLanguages ][ numTopics ];

        cacheValues();
    }

    /**
     * Gather statistics on the size of documents
     * and create histograms for use in Dirichlet hyperparameter
     * optimization.
     */
    private void initializeHistograms() {

        int maxTokens = 0;
        int totalTokens = 0;

        for (int doc = 0; doc < data.size(); doc++) {
            int length = 0;
            for (LabelSequence sequence : data.get(doc).topicSequences) {
                length += sequence.getLength();
            }

            if (length > maxTokens) {
                maxTokens = length;
            }

            totalTokens += length;
        }

        System.err.println("max tokens: " + maxTokens);
        System.err.println("total tokens: " + totalTokens);

        docLengthCounts = new int[maxTokens + 1];
        topicDocCounts = new int[numTopics][maxTokens + 1];
    }

    /**
     * Recomputes, from the current alpha, betas and topic totals, each
     * language's smoothing-only mass and the smoothing-only cached coefficients.
     */
    private void cacheValues() {
        for (int language = 0; language < numLanguages; language++) {
            languageSmoothingOnlyMasses[language] = 0.0;

            for (int topic=0; topic < numTopics; topic++) {
                languageSmoothingOnlyMasses[language] +=
                    alpha[topic] * betas[language] /
                    (languageTokensPerTopic[language][topic] + betaSums[language]);
                languageCachedCoefficients[language][topic] =
                    alpha[topic] / (languageTokensPerTopic[language][topic] + betaSums[language]);
            }
        }
    }

    /** Zeroes the document-length and document/topic histograms. */
    private void clearHistograms() {
        Arrays.fill(docLengthCounts, 0);
        for (int topic = 0; topic < topicDocCounts.length; topic++) {
            Arrays.fill(topicDocCounts[topic], 0);
        }
    }

    /**
     * Returns true if at least one non-empty document has been recorded in the
     * histograms since they were last cleared, i.e. there is data to learn alpha from.
     */
    private boolean hasHistogramSamples() {
        for (int length = 1; length < docLengthCounts.length; length++) {
            if (docLengthCounts[length] != 0) {
                return true;
            }
        }
        return false;
    }

    /** Runs {@link #numIterations} Gibbs sampling iterations. */
    public void estimate () throws IOException {
        estimate (numIterations);
    }

    /**
     * Runs exactly {@code iterationsThisRound} Gibbs sampling iterations,
     * continuing the iteration counter from any previous call.
     *
     * @throws IOException if saving the sampling state fails
     */
    public void estimate (int iterationsThisRound) throws IOException {

        int maxIteration = iterationsSoFar + iterationsThisRound;
        long totalTime = 0;

        // Exclusive bound: iterationsSoFar .. maxIteration - 1.
        for ( ; iterationsSoFar < maxIteration; iterationsSoFar++) {
            long iterationStart = System.currentTimeMillis();

            if (showTopicsInterval != 0 && iterationsSoFar != 0 && iterationsSoFar % showTopicsInterval == 0) {
                System.out.println();
                printTopWords (System.out, wordsPerTopic, false);
            }

            if (saveStateInterval != 0 && iterationsSoFar % saveStateInterval == 0) {
                this.printState(new File(stateFilename + '.' + iterationsSoFar));
            }

            if (iterationsSoFar > burninPeriod && optimizeInterval != 0 &&
                iterationsSoFar % optimizeInterval == 0) {

                // Only learn alpha when document/topic samples have been collected;
                // with empty histograms there is nothing to fit.
                if (hasHistogramSamples()) {
                    alphaSum = Dirichlet.learnParameters(alpha, topicDocCounts, docLengthCounts);
                }
                optimizeBetas();
                clearHistograms();
                cacheValues();
            }

            // Loop over every document in the corpus
            topicTermCount = betaTopicCount = smoothingOnlyCount = 0;

            boolean saveSample = iterationsSoFar >= burninPeriod &&
                saveSampleInterval != 0 &&
                iterationsSoFar % saveSampleInterval == 0;

            for (int doc = 0; doc < data.size(); doc++) {
                sampleTopicsForOneDoc (data.get(doc), saveSample);
            }

            long elapsedMillis = System.currentTimeMillis() - iterationStart;
            totalTime += elapsedMillis;

            if ((iterationsSoFar + 1) % 10 == 0) {
                double ll = modelLogLikelihood();
                System.out.println(elapsedMillis + "\t" + totalTime + "\t" +
                                   ll);
            }
            else {
                System.out.print(elapsedMillis + " ");
            }
        }
    }

    /**
     * Re-estimates each language's symmetric topic-word concentration
     * (betaSums[language], and betas[language] = betaSums / vocabulary size)
     * from the current type/topic counts.
     */
    public void optimizeBetas() {

        for (int language = 0; language < numLanguages; language++) {

            // The histogram starts at count 0, so if all of the
            // tokens of the most frequent type were assigned to one topic,
            // we would need to store a maxTypeCount + 1 count.
            int[] countHistogram = new int[languageMaxTypeCounts[language] + 1];

            // Now count the number of type/topic pairs that have
            // each number of tokens.
            int[][] typeTopicCounts = languageTypeTopicCounts[language];
            int[] tokensPerTopic = languageTokensPerTopic[language];

            for (int type = 0; type < vocabularySizes[language]; type++) {
                int[] counts = typeTopicCounts[type];
                for (int index = 0; index < counts.length && counts[index] > 0; index++) {
                    countHistogram[ counts[index] >> topicBits ]++;
                }
            }

            // Figure out how large we need to make the "observation lengths"
            // histogram.
            int maxTopicSize = 0;
            for (int topic = 0; topic < numTopics; topic++) {
                if (tokensPerTopic[topic] > maxTopicSize) {
                    maxTopicSize = tokensPerTopic[topic];
                }
            }

            // Now allocate it and populate it.
            int[] topicSizeHistogram = new int[maxTopicSize + 1];
            for (int topic = 0; topic < numTopics; topic++) {
                topicSizeHistogram[ tokensPerTopic[topic] ]++;
            }

            betaSums[language] = Dirichlet.learnSymmetricConcentration(countHistogram,
                                                                       topicSizeHistogram,
                                                                       vocabularySizes[ language ],
                                                                       betaSums[language]);
            betas[language] = betaSums[language] / vocabularySizes[ language ];
        }
    }

    /**
     * Resamples the topic of every token of one document, in every language,
     * updating the global type/topic counts, topic totals and cached sampling
     * quantities as it goes.
     *
     * @param topicAssignment the document to resample
     * @param shouldSaveState if true, record this document's topic counts in
     *        the histograms used for alpha optimization
     */
    protected void sampleTopicsForOneDoc (TopicAssignment topicAssignment,
                                          boolean shouldSaveState) {

        int[] currentTypeTopicCounts;
        int type, oldTopic, newTopic;

        // Topic counts for this document, shared across all of its languages.
        int[] localTopicCounts = new int[numTopics];
        // Sorted list of the topics that have non-zero local counts.
        int[] localTopicIndex = new int[numTopics];

        for (int language = 0; language < numLanguages; language++) {
            int[] oneDocTopics =
                topicAssignment.topicSequences[language].getFeatures();
            int docLength =
                topicAssignment.topicSequences[language].getLength();

            // populate topic counts
            for (int position = 0; position < docLength; position++) {
                localTopicCounts[oneDocTopics[position]]++;
            }
        }

        // Build an array that densely lists the topics that
        // have non-zero counts.
        int denseIndex = 0;
        for (int topic = 0; topic < numTopics; topic++) {
            if (localTopicCounts[topic] != 0) {
                localTopicIndex[denseIndex] = topic;
                denseIndex++;
            }
        }

        // Record the total number of non-zero topics
        int nonZeroTopics = denseIndex;

        // Scratch space for the topic-word bucket weights, reused for every token.
        double[] topicTermScores = new double[numTopics];

        for (int language = 0; language < numLanguages; language++) {
            int[] oneDocTopics =
                topicAssignment.topicSequences[language].getFeatures();
            int docLength =
                topicAssignment.topicSequences[language].getLength();
            FeatureSequence tokenSequence =
                (FeatureSequence) topicAssignment.instances[language].getData();

            int[][] typeTopicCounts = languageTypeTopicCounts[language];
            int[] tokensPerTopic = languageTokensPerTopic[language];
            double beta = betas[language];
            double betaSum = betaSums[language];

            // The smoothing-only bucket is maintained incrementally across documents.
            double smoothingOnlyMass = languageSmoothingOnlyMasses[language];

            // Between documents these hold the smoothing-only coefficients;
            // they are raised to (alpha + n) / (...) for this document's topics
            // below and restored at the end of this language's pass.
            double[] cachedCoefficients =
                languageCachedCoefficients[language];

            // Initialize the topic count/beta sampling bucket
            double topicBetaMass = 0.0;

            // Initialize cached coefficients and the topic/beta
            // normalizing constant.
            for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
                int topic = localTopicIndex[denseIndex];
                int n = localTopicCounts[topic];

                // initialize the normalization constant for the (B * n_{t|d}) term
                topicBetaMass += beta * n / (tokensPerTopic[topic] + betaSum);

                // update the coefficients for the non-zero topics
                cachedCoefficients[topic] = (alpha[topic] + n) / (tokensPerTopic[topic] + betaSum);
            }

            double topicTermMass = 0.0;
            int i;
            double score;

            // Iterate over the positions (words) in the document
            for (int position = 0; position < docLength; position++) {
                type = tokenSequence.getIndexAtPosition(position);
                oldTopic = oneDocTopics[position];
                if (oldTopic == -1) { continue; }

                currentTypeTopicCounts = typeTopicCounts[type];

                // Remove this token from all counts.

                // Remove this topic's contribution to the
                // normalizing constants
                smoothingOnlyMass -= alpha[oldTopic] * beta /
                    (tokensPerTopic[oldTopic] + betaSum);
                topicBetaMass -= beta * localTopicCounts[oldTopic] /
                    (tokensPerTopic[oldTopic] + betaSum);

                // Decrement the local doc/topic counts
                localTopicCounts[oldTopic]--;

                // Maintain the dense index, if we are deleting
                // the old topic
                if (localTopicCounts[oldTopic] == 0) {

                    // First get to the dense location associated with
                    // the old topic. We know it's in there somewhere,
                    // so we don't need bounds checking.
                    denseIndex = 0;
                    while (localTopicIndex[denseIndex] != oldTopic) {
                        denseIndex++;
                    }

                    // shift all remaining dense indices to the left.
                    while (denseIndex < nonZeroTopics) {
                        if (denseIndex < localTopicIndex.length - 1) {
                            localTopicIndex[denseIndex] =
                                localTopicIndex[denseIndex + 1];
                        }
                        denseIndex++;
                    }

                    nonZeroTopics --;
                }

                // Decrement the global topic count totals
                tokensPerTopic[oldTopic]--;

                // Add the old topic's contribution back into the
                // normalizing constants.
                smoothingOnlyMass += alpha[oldTopic] * beta /
                    (tokensPerTopic[oldTopic] + betaSum);
                topicBetaMass += beta * localTopicCounts[oldTopic] /
                    (tokensPerTopic[oldTopic] + betaSum);

                // Reset the cached coefficient for this topic
                cachedCoefficients[oldTopic] =
                    (alpha[oldTopic] + localTopicCounts[oldTopic]) /
                    (tokensPerTopic[oldTopic] + betaSum);

                // Now go over the type/topic counts, decrementing
                // where appropriate, and calculating the score
                // for each topic at the same time.
                int index = 0;
                int currentTopic, currentValue;

                boolean alreadyDecremented = false;

                topicTermMass = 0.0;

                while (index < currentTypeTopicCounts.length &&
                       currentTypeTopicCounts[index] > 0) {
                    currentTopic = currentTypeTopicCounts[index] & topicMask;
                    currentValue = currentTypeTopicCounts[index] >> topicBits;

                    if (! alreadyDecremented &&
                        currentTopic == oldTopic) {

                        // We're decrementing and adding up the
                        // sampling weights at the same time, but
                        // decrementing may require us to reorder
                        // the topics, so after we're done here,
                        // look at this cell in the array again.
                        currentValue --;
                        if (currentValue == 0) {
                            currentTypeTopicCounts[index] = 0;
                        }
                        else {
                            currentTypeTopicCounts[index] =
                                (currentValue << topicBits) + oldTopic;
                        }

                        // Shift the reduced value to the right, if necessary.
                        int subIndex = index;
                        while (subIndex < currentTypeTopicCounts.length - 1 &&
                               currentTypeTopicCounts[subIndex] < currentTypeTopicCounts[subIndex + 1]) {
                            int temp = currentTypeTopicCounts[subIndex];
                            currentTypeTopicCounts[subIndex] = currentTypeTopicCounts[subIndex + 1];
                            currentTypeTopicCounts[subIndex + 1] = temp;

                            subIndex++;
                        }

                        alreadyDecremented = true;
                    }
                    else {
                        score =
                            cachedCoefficients[currentTopic] * currentValue;
                        topicTermMass += score;
                        topicTermScores[index] = score;

                        index++;
                    }
                }

                // topicTermScores[0 .. numScoredTopics-1] are valid for this token.
                int numScoredTopics = index;

                double sample = random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
                double origSample = sample;

                // Make sure it actually gets set
                newTopic = -1;

                if (sample < topicTermMass) {

                    // Start at the first entry so that a sample of exactly 0
                    // still selects a valid slot; the upper bound guards
                    // against floating-point residue.
                    i = 0;
                    sample -= topicTermScores[0];
                    while (sample > 0 && i < numScoredTopics - 1) {
                        i++;
                        sample -= topicTermScores[i];
                    }

                    newTopic = currentTypeTopicCounts[i] & topicMask;
                    currentValue = currentTypeTopicCounts[i] >> topicBits;

                    currentTypeTopicCounts[i] = ((currentValue + 1) << topicBits) + newTopic;

                    // Bubble the new value up, if necessary
                    while (i > 0 &&
                           currentTypeTopicCounts[i] > currentTypeTopicCounts[i - 1]) {
                        int temp = currentTypeTopicCounts[i];
                        currentTypeTopicCounts[i] = currentTypeTopicCounts[i - 1];
                        currentTypeTopicCounts[i - 1] = temp;

                        i--;
                    }
                }
                else {
                    sample -= topicTermMass;

                    if (sample < topicBetaMass) {
                        sample /= beta;

                        for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
                            int topic = localTopicIndex[denseIndex];

                            sample -= localTopicCounts[topic] /
                                (tokensPerTopic[topic] + betaSum);

                            if (sample <= 0.0) {
                                newTopic = topic;
                                break;
                            }
                        }
                    }
                    else {
                        sample -= topicBetaMass;
                        sample /= beta;

                        newTopic = 0;
                        sample -= alpha[newTopic] /
                            (tokensPerTopic[newTopic] + betaSum);

                        // Bounded so accumulated drift in smoothingOnlyMass
                        // cannot push newTopic past the last topic.
                        while (sample > 0.0 && newTopic < numTopics - 1) {
                            newTopic++;
                            sample -= alpha[newTopic] /
                                (tokensPerTopic[newTopic] + betaSum);
                        }
                    }

                    // Fall back BEFORE touching the type/topic counts, so an
                    // invalid topic (-1) is never written into them.
                    if (newTopic == -1) {
                        System.err.println("PolylingualTopicModel sampling error: "+ origSample + " " + sample + " " + smoothingOnlyMass + " " +
                                           topicBetaMass + " " + topicTermMass);
                        newTopic = numTopics-1; // TODO is this appropriate
                    }

                    // Move to the position for the new topic,
                    // which may be the first empty position if this
                    // is a new topic for this word.
                    index = 0;
                    while (currentTypeTopicCounts[index] > 0 &&
                           (currentTypeTopicCounts[index] & topicMask) != newTopic) {
                        index++;
                    }

                    // index should now be set to the position of the new topic,
                    // which may be an empty cell at the end of the list.
                    if (currentTypeTopicCounts[index] == 0) {
                        // inserting a new topic, guaranteed to be in
                        // order w.r.t. count, if not topic.
                        currentTypeTopicCounts[index] = (1 << topicBits) + newTopic;
                    }
                    else {
                        currentValue = currentTypeTopicCounts[index] >> topicBits;
                        currentTypeTopicCounts[index] = ((currentValue + 1) << topicBits) + newTopic;

                        // Bubble the increased value left, if necessary
                        while (index > 0 &&
                               currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
                            int temp = currentTypeTopicCounts[index];
                            currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
                            currentTypeTopicCounts[index - 1] = temp;

                            index--;
                        }
                    }
                }

                // Put that new topic into the counts
                oneDocTopics[position] = newTopic;

                smoothingOnlyMass -= alpha[newTopic] * beta /
                    (tokensPerTopic[newTopic] + betaSum);
                topicBetaMass -= beta * localTopicCounts[newTopic] /
                    (tokensPerTopic[newTopic] + betaSum);

                localTopicCounts[newTopic]++;

                // If this is a new topic for this document,
                // add the topic to the dense index.
                if (localTopicCounts[newTopic] == 1) {

                    // Find the insertion point by starting at the end
                    // (which is why we track the number of non-zero
                    // topics) and working backwards.
                    denseIndex = nonZeroTopics;

                    while (denseIndex > 0 &&
                           localTopicIndex[denseIndex - 1] > newTopic) {

                        localTopicIndex[denseIndex] =
                            localTopicIndex[denseIndex - 1];
                        denseIndex--;
                    }

                    localTopicIndex[denseIndex] = newTopic;
                    nonZeroTopics++;
                }

                tokensPerTopic[newTopic]++;

                // update the coefficients for the non-zero topics
                cachedCoefficients[newTopic] =
                    (alpha[newTopic] + localTopicCounts[newTopic]) /
                    (tokensPerTopic[newTopic] + betaSum);

                smoothingOnlyMass += alpha[newTopic] * beta /
                    (tokensPerTopic[newTopic] + betaSum);
                topicBetaMass += beta * localTopicCounts[newTopic] /
                    (tokensPerTopic[newTopic] + betaSum);
            }

            // Save the smoothing-only mass to the global cache
            languageSmoothingOnlyMasses[language] = smoothingOnlyMass;

            // Restore smoothing-only coefficients for this document's topics.
            // The cache persists across documents, so without this reset the
            // next document would score topics it does not contain using this
            // document's (alpha + n) numerators.
            for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
                int topic = localTopicIndex[denseIndex];
                cachedCoefficients[topic] =
                    alpha[topic] / (tokensPerTopic[topic] + betaSum);
            }
        }

        if (shouldSaveState) {
            // Update the document-topic count histogram,
            // for dirichlet estimation
            int totalLength = 0;

            for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
                int topic = localTopicIndex[denseIndex];

                topicDocCounts[topic][ localTopicCounts[topic] ]++;
                totalLength += localTopicCounts[topic];
            }

            docLengthCounts[ totalLength ]++;
        }
    }

    /**
     * Writes the top words of every topic to {@code file}; the stream is
     * closed even if printing fails.
     */
    public void printTopWords (File file, int numWords, boolean useNewLines) throws IOException {
        PrintStream out = new PrintStream (file);
        try {
            printTopWords(out, numWords, useNewLines);
        }
        finally {
            out.close();
        }
    }

    /**
     * For each topic, prints its alpha, then one line per language with the
     * topic's token count, that language's beta, and up to {@code numWords}
     * of its highest-count words.
     *
     * @param usingNewLines currently ignored
     */
    @SuppressWarnings("unchecked") // generic array creation for the per-topic sorted sets
    public void printTopWords (PrintStream out, int numWords, boolean usingNewLines) {

        TreeSet<IDSorter>[][] languageTopicSortedWords = new TreeSet[numLanguages][numTopics];

        for (int language = 0; language < numLanguages; language++) {
            TreeSet<IDSorter>[] topicSortedWords = languageTopicSortedWords[language];
            int[][] typeTopicCounts = languageTypeTopicCounts[language];

            for (int topic = 0; topic < numTopics; topic++) {
                topicSortedWords[topic] = new TreeSet<IDSorter>();
            }

            for (int type = 0; type < vocabularySizes[language]; type++) {
                int[] topicCounts = typeTopicCounts[type];
                for (int index = 0; index < topicCounts.length && topicCounts[index] > 0; index++) {
                    int topic = topicCounts[index] & topicMask;
                    int count = topicCounts[index] >> topicBits;
                    topicSortedWords[topic].add(new IDSorter(type, count));
                }
            }
        }

        for (int topic = 0; topic < numTopics; topic++) {

            out.println (topic + "\t" + formatter.format(alpha[topic]));

            for (int language = 0; language < numLanguages; language++) {

                out.print(" " + language + "\t" + languageTokensPerTopic[language][topic] + "\t" + betas[language] + "\t");

                TreeSet<IDSorter> sortedWords = languageTopicSortedWords[language][topic];
                Alphabet alphabet = alphabets[language];

                // Count from 0 so that exactly numWords words are printed.
                int word = 0;
                Iterator<IDSorter> iterator = sortedWords.iterator();
                while (iterator.hasNext() && word < numWords) {
                    IDSorter info = iterator.next();
                    out.print(alphabet.lookupObject(info.getID()) + " ");
                    word++;
                }

                out.println();
            }
        }
    }

    /**
     * Writes per-document topic proportions to {@code f} (UTF-8). The writer
     * is always closed so buffered output is flushed to disk.
     */
    public void printDocumentTopics (File f) throws IOException {
        PrintWriter out = new PrintWriter (f, "UTF-8");
        try {
            printDocumentTopics (out);
        }
        finally {
            out.close();
        }
    }

    public void printDocumentTopics (PrintWriter pw) {
        printDocumentTopics (pw, 0.0, -1);
    }

    /**
     * Prints, for each document, its index followed by "topic proportion" pairs
     * in descending order of proportion, pooling tokens across all languages.
     *
     * @param pw A print writer
     * @param threshold Only print topics with proportion greater than this number
     * @param max Print no more than this many topics
     */
    public void printDocumentTopics (PrintWriter pw, double threshold, int max) {
        pw.print ("#doc source topic proportion ...\n");
        int docLength;
        int[] topicCounts = new int[ numTopics ];

        IDSorter[] sortedTopics = new IDSorter[ numTopics ];
        for (int topic = 0; topic < numTopics; topic++) {
            // Initialize the sorters with dummy values
            sortedTopics[topic] = new IDSorter(topic, topic);
        }

        if (max < 0 || max > numTopics) {
            max = numTopics;
        }

        for (int di = 0; di < data.size(); di++) {

            pw.print (di); pw.print (' ');

            int totalLength = 0;

            for (int language = 0; language < numLanguages; language++) {

                LabelSequence topicSequence = data.get(di).topicSequences[language];
                int[] currentDocTopics = topicSequence.getFeatures();

                docLength = topicSequence.getLength();
                totalLength += docLength;

                // Count up the tokens
                for (int token=0; token < docLength; token++) {
                    topicCounts[ currentDocTopics[token] ]++;
                }
            }

            // And normalize; an empty document gets all-zero proportions
            // rather than NaN.
            for (int topic = 0; topic < numTopics; topic++) {
                float proportion = totalLength == 0 ?
                    0.0f : (float) topicCounts[topic] / totalLength;
                sortedTopics[topic].set(topic, proportion);
            }

            Arrays.sort(sortedTopics);

            for (int i = 0; i < max; i++) {
                if (sortedTopics[i].getWeight() < threshold) { break; }

                pw.print (sortedTopics[i].getID() + " " +
                          sortedTopics[i].getWeight() + " ");
            }
            pw.print (" \n");

            Arrays.fill(topicCounts, 0);
        }
    }

    /**
     * Writes the sampling state to {@code f} as gzip-compressed UTF-8 text;
     * the stream is closed even if printing fails.
     */
    public void printState (File f) throws IOException {
        PrintStream out =
            new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))),
                            false, "UTF-8");
        try {
            printState(out);
        }
        finally {
            out.close();
        }
    }

    /**
     * Prints one line per token: document index, language, position, type
     * index, type string and assigned topic.
     */
    public void printState (PrintStream out) {

        out.println ("#doc lang pos typeindex type topic");

        for (int doc = 0; doc < data.size(); doc++) {
            for (int language =0; language < numLanguages; language++) {
                FeatureSequence tokenSequence = (FeatureSequence) data.get(doc).instances[language].getData();
                LabelSequence topicSequence = data.get(doc).topicSequences[language];

                for (int pi = 0; pi < topicSequence.getLength(); pi++) {
                    int type = tokenSequence.getIndexAtPosition(pi);
                    int topic = topicSequence.getIndexAtPosition(pi);
                    out.print(doc); out.print(' ');
                    out.print(language); out.print(' ');
                    out.print(pi); out.print(' ');
                    out.print(type); out.print(' ');
                    out.print(alphabets[language].lookupObject(type)); out.print(' ');
                    out.print(topic); out.println();
                }
            }
        }
    }

    /**
     * Returns the joint log likelihood of the current topic assignments.
     * <p>
     * The likelihood is a combination of a Dirichlet-multinomial for the
     * topics in each document (shared by all its languages) and, per language,
     * a Dirichlet-multinomial for the words in each topic. For one
     * Dirichlet-multinomial the log likelihood is
     * <pre>
     *   logGamma(sum_i alpha_i) - logGamma(sum_i (alpha_i + N_i))
     *   + sum_i [ logGamma(alpha_i + N_i) - logGamma(alpha_i) ]
     * </pre>
     * For the topic-word part, sum_i alpha_i is beta * |V| = betaSums[language],
     * and the logGamma(betaSum) term is contributed once per topic.
     *
     * @throws IllegalStateException if the result becomes NaN
     */
    public double modelLogLikelihood() {
        double logLikelihood = 0.0;

        // Do the documents first

        int[] topicCounts = new int[numTopics];
        double[] topicLogGammas = new double[numTopics];

        for (int topic=0; topic < numTopics; topic++) {
            topicLogGammas[ topic ] = Dirichlet.logGammaStirling( alpha[topic] );
        }

        for (int doc=0; doc < data.size(); doc++) {
            int totalLength = 0;

            for (int language = 0; language < numLanguages; language++) {

                LabelSequence topicSequence = data.get(doc).topicSequences[language];
                int[] currentDocTopics = topicSequence.getFeatures();
                int length = topicSequence.getLength();

                totalLength += length;

                // Count up the tokens
                for (int token=0; token < length; token++) {
                    topicCounts[ currentDocTopics[token] ]++;
                }
            }

            for (int topic=0; topic < numTopics; topic++) {
                if (topicCounts[topic] > 0) {
                    logLikelihood += (Dirichlet.logGammaStirling(alpha[topic] + topicCounts[topic]) -
                                      topicLogGammas[ topic ]);
                }
            }

            // subtract the (count + parameter) sum term
            logLikelihood -= Dirichlet.logGammaStirling(alphaSum + totalLength);

            Arrays.fill(topicCounts, 0);
        }

        // add the parameter sum term
        logLikelihood += data.size() * Dirichlet.logGammaStirling(alphaSum);

        // And the topics

        for (int language = 0; language < numLanguages; language++) {
            int[][] typeTopicCounts = languageTypeTopicCounts[language];
            int[] tokensPerTopic = languageTokensPerTopic[language];
            double beta = betas[language];
            double betaSum = betaSums[language];

            // Count the number of type-topic pairs
            int nonZeroTypeTopics = 0;

            for (int type=0; type < vocabularySizes[language]; type++) {
                int[] counts = typeTopicCounts[type];

                for (int index = 0; index < counts.length && counts[index] > 0; index++) {
                    int count = counts[index] >> topicBits;

                    nonZeroTypeTopics++;
                    logLikelihood += Dirichlet.logGammaStirling(beta + count);

                    if (Double.isNaN(logLikelihood)) {
                        throw new IllegalStateException(
                            "NaN log likelihood at type/topic count " + count);
                    }
                }
            }

            for (int topic=0; topic < numTopics; topic++) {
                logLikelihood -=
                    Dirichlet.logGammaStirling( betaSum + tokensPerTopic[ topic ] );

                if (Double.isNaN(logLikelihood)) {
                    throw new IllegalStateException(
                        "NaN log likelihood after topic " + topic + " " + tokensPerTopic[ topic ]);
                }
            }

            // logGamma(|V| * beta) once for every topic
            logLikelihood += Dirichlet.logGammaStirling(betaSum) * numTopics;

            // logGamma(beta) for every type/topic pair with a non-zero count
            logLikelihood -= Dirichlet.logGammaStirling(beta) * nonZeroTypeTopics;
        }

        if (Double.isNaN(logLikelihood)) {
            throw new IllegalStateException("NaN log likelihood at the end");
        }

        return logLikelihood;
    }

    /**
     * Command-line entry point: trains a model on one serialized InstanceList
     * per language and writes the final sampling state.
     */
    public static void main (String[] args) throws IOException {
        if (args.length < 4) {
            System.err.println("Usage: PolylingualTopicModel [num topics] [file to save state] [testing IDs file] [language 0 instances] ...");
            System.exit(1);
        }

        int numTopics = Integer.parseInt(args[0]);

        String stateFileName = args[1];
        File testingIDsFile = new File(args[2]);

        InstanceList[] training = new InstanceList[ args.length - 3 ];
        for (int language=0; language < training.length; language++) {
            training[language] = InstanceList.load(new File(args[language + 3]));
            System.err.println("loaded " + args[language + 3]);
        }

        PolylingualTopicModel lda = new PolylingualTopicModel (numTopics, 2.0);
        lda.printLogLikelihood = true;
        lda.setTopicDisplay(50, 7);
        lda.loadTestingIDs(testingIDsFile);
        lda.addInstances(training);
        lda.setSaveState(200, stateFileName);

        lda.estimate();

        lda.printState(new File(stateFileName));
    }
}