/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.topics;
import java.util.Arrays;
import java.io.*;
import cc.mallet.types.*;
import cc.mallet.util.Randoms;
/**
* Like Latent Dirichlet Allocation, but with integrated phrase discovery.
* @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
* based on C code by Xuerui Wang.
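*
* <p>A minimal usage sketch (the filename is hypothetical; instances must carry
* {@link FeatureSequenceWithBigrams} data, mirroring the {@code main} method at
* the bottom of this class):
* <pre>
* InstanceList ilist = InstanceList.load (new File ("instances.mallet"));
* TopicalNGrams tng = new TopicalNGrams (10);
* tng.estimate (ilist, 200, 1, 0, null, new Randoms ());
* tng.printTopWords (20, true);
* </pre>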
*/
public class TopicalNGrams implements Serializable {
int numTopics;
Alphabet uniAlphabet;
Alphabet biAlphabet;
double alpha, beta, gamma, delta, tAlpha, vBeta, vGamma, delta1, delta2;
InstanceList ilist; // containing FeatureSequenceWithBigrams in the data field of each instance
int[][] topics; // {0...T-1}, the topic index, indexed by <document index, sequence index>
int[][] grams; // {0,1}, the bigram status, indexed by <document index, sequence index> TODO: Make this boolean?
int numTypes; // number of unique unigrams
int numBitypes; // number of unique bigrams
int numTokens; // total number of word occurrences
// "totalNgram"
int biTokens; // total number of tokens currently generated as bigrams (only used for progress messages)
// "docTopic"
int[][] docTopicCounts; // indexed by <document index, topic index>
// Used to calculate p(x|w,t). "ngramCount"
int[][][] typeNgramTopicCounts; // indexed by <feature index, ngram status, topic index>
// Used to calculate p(w|t) and p(w|t,w), "topicWord" and "topicNgramWord"
int[][] unitypeTopicCounts; // indexed by <feature index, topic index>
int[][] bitypeTopicCounts; // indexed by <bifeature index, topic index>
// "sumWords"
int[] tokensPerTopic; // indexed by <topic index>
// "sumNgramWords"
int[][] bitokensPerTopic; // indexed by <feature index, topic index>, where the feature index is that of the previous (conditioning) word
public TopicalNGrams (int numberOfTopics)
{
this (numberOfTopics, 50.0, 0.01, 0.01, 0.03, 0.2, 1000); // alphaSum, beta, gamma, delta, delta1, delta2
}
public TopicalNGrams (int numberOfTopics, double alphaSum, double beta, double gamma, double delta,
double delta1, double delta2)
{
this.numTopics = numberOfTopics;
this.alpha = alphaSum / numTopics; // smoothing over the choice of topic
this.beta = beta; // smoothing over the choice of unigram words
this.gamma = gamma; // smoothing over the choice of bigram words
this.delta = delta; // smoothing over the choice of unigram/bigram generation
this.delta1 = delta1; // pseudo-count for the unigram outcome of the bigram-status decision. TODO: Clean this up.
this.delta2 = delta2; // pseudo-count for the bigram outcome of the bigram-status decision
System.out.println("alpha :"+alphaSum);
System.out.println("beta :"+beta);
System.out.println("gamma :"+gamma);
System.out.println("delta :"+delta);
System.out.println("delta1 :"+delta1);
System.out.println("delta2 :"+delta2);
}
public void estimate (InstanceList documents, int numIterations, int showTopicsInterval,
int outputModelInterval, String outputModelFilename,
Randoms r)
{
ilist = documents;
uniAlphabet = ilist.getDataAlphabet();
biAlphabet = ((FeatureSequenceWithBigrams)ilist.get(0).getData()).getBiAlphabet();
numTypes = uniAlphabet.size();
numBitypes = biAlphabet.size();
int numDocs = ilist.size();
topics = new int[numDocs][];
grams = new int[numDocs][];
docTopicCounts = new int[numDocs][numTopics];
typeNgramTopicCounts = new int[numTypes][2][numTopics];
unitypeTopicCounts = new int[numTypes][numTopics];
bitypeTopicCounts = new int[numBitypes][numTopics];
tokensPerTopic = new int[numTopics];
bitokensPerTopic = new int[numTypes][numTopics];
tAlpha = alpha * numTopics;
vBeta = beta * numTypes;
vGamma = gamma * numTypes;
long startTime = System.currentTimeMillis();
// Initialize with random assignments of tokens to topics
// and finish allocating this.topics and this.grams
int topic, gram, seqLen, fi;
for (int di = 0; di < numDocs; di++) {
FeatureSequenceWithBigrams fs = (FeatureSequenceWithBigrams) ilist.get(di).getData();
seqLen = fs.getLength();
numTokens += seqLen;
topics[di] = new int[seqLen];
grams[di] = new int[seqLen];
// Randomly assign tokens to topics
int prevFi = -1, prevTopic = -1;
for (int si = 0; si < seqLen; si++) {
// randomly sample a topic for the word at position si
topic = r.nextInt(numTopics);
// if a bigram is allowed at position si, then sample a gram status for it.
gram = (fs.getBiIndexAtPosition(si) == -1 ? 0 : r.nextInt(2));
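// gram == 1 means this token is generated as the second word of a bigram with
// the previous token; gram == 0 means it is generated as a unigram.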
if (gram != 0) biTokens++;
topics[di][si] = topic;
grams[di][si] = gram;
docTopicCounts[di][topic]++;
fi = fs.getIndexAtPosition(si);
if (prevFi != -1)
typeNgramTopicCounts[prevFi][gram][prevTopic]++;
if (gram == 0) {
unitypeTopicCounts[fi][topic]++;
tokensPerTopic[topic]++;
} else {
bitypeTopicCounts[fs.getBiIndexAtPosition(si)][topic]++;
bitokensPerTopic[prevFi][topic]++;
}
prevFi = fi; prevTopic = topic;
}
}
for (int iterations = 0; iterations < numIterations; iterations++) {
sampleTopicsForAllDocs (r);
if (iterations % 10 == 0) System.out.print (iterations); else System.out.print (".");
System.out.flush();
if (showTopicsInterval != 0 && iterations % showTopicsInterval == 0 && iterations > 0) {
System.out.println ();
printTopWords (5, false);
}
if (outputModelInterval != 0 && iterations % outputModelInterval == 0 && iterations > 0) {
this.write (new File(outputModelFilename+'.'+iterations));
}
}
System.out.println ("\nTotal time (sec): " + ((System.currentTimeMillis() - startTime)/1000.0));
}
/* One iteration of Gibbs sampling, across all documents. */
private void sampleTopicsForAllDocs (Randoms r)
{
double[] uniTopicWeights = new double[numTopics];
double[] biTopicWeights = new double[numTopics*2];
// Loop over every document in the corpus; the per-token loop is inside sampleTopicsForOneDoc
for (int di = 0; di < topics.length; di++) {
sampleTopicsForOneDoc ((FeatureSequenceWithBigrams)ilist.get(di).getData(),
topics[di], grams[di], docTopicCounts[di],
uniTopicWeights, biTopicWeights,
r);
}
}
private void sampleTopicsForOneDoc (FeatureSequenceWithBigrams oneDocTokens,
int[] oneDocTopics, int[] oneDocGrams,
int[] oneDocTopicCounts, // indexed by topic index
double[] uniTopicWeights, // length==numTopics
double[] biTopicWeights, // length==numTopics*2: joint topic/gram sampling
Randoms r)
{
int[] currentTypeTopicCounts;
int[] currentBitypeTopicCounts;
int[] previousBitokensPerTopic;
int type, bitype, oldGram, nextGram, newGram, oldTopic, newTopic;
double topicWeightsSum, tw;
// xxx int docLen = oneDocTokens.length;
int docLen = oneDocTokens.getLength();
// Iterate over the positions (words) in the document
for (int si = 0; si < docLen; si++) {
type = oneDocTokens.getIndexAtPosition(si);
bitype = oneDocTokens.getBiIndexAtPosition(si);
//if (bitype == -1) System.out.println ("biblock "+si+" at "+uniAlphabet.lookupObject(type));
oldTopic = oneDocTopics[si];
oldGram = oneDocGrams[si];
nextGram = (si == docLen-1) ? -1 : oneDocGrams[si+1];
//nextGram = (si == docLen-1) ? -1 : (oneDocTokens.getBiIndexAtPosition(si+1) == -1 ? 0 : 1);
boolean bigramPossible = (bitype != -1);
assert (!(!bigramPossible && oldGram == 1));
if (!bigramPossible) {
// Remove this token from all counts
oneDocTopicCounts[oldTopic]--;
tokensPerTopic[oldTopic]--;
unitypeTopicCounts[type][oldTopic]--;
if (si != docLen-1) {
typeNgramTopicCounts[type][nextGram][oldTopic]--;
assert (typeNgramTopicCounts[type][nextGram][oldTopic] >= 0);
}
assert (oneDocTopicCounts[oldTopic] >= 0);
assert (tokensPerTopic[oldTopic] >= 0);
assert (unitypeTopicCounts[type][oldTopic] >= 0);
// Build a distribution over topics for this token
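// The unnormalized weight for topic t is
//   (unitypeTopicCounts[type][t] + beta) / (tokensPerTopic[t] + vBeta)
//   * (oneDocTopicCounts[t] + alpha)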
Arrays.fill (uniTopicWeights, 0.0);
topicWeightsSum = 0;
currentTypeTopicCounts = unitypeTopicCounts[type];
for (int ti = 0; ti < numTopics; ti++) {
tw = ((currentTypeTopicCounts[ti] + beta) / (tokensPerTopic[ti] + vBeta))
* ((oneDocTopicCounts[ti] + alpha)); // the remaining normalizing term is constant across all topics, so it is omitted
topicWeightsSum += tw;
uniTopicWeights[ti] = tw;
}
// Sample a topic assignment from this distribution
newTopic = r.nextDiscrete (uniTopicWeights, topicWeightsSum);
// Put that new topic into the counts
oneDocTopics[si] = newTopic;
oneDocTopicCounts[newTopic]++;
unitypeTopicCounts[type][newTopic]++;
tokensPerTopic[newTopic]++;
if (si != docLen-1)
typeNgramTopicCounts[type][nextGram][newTopic]++;
} else {
// Bigram is possible
int prevType = oneDocTokens.getIndexAtPosition(si-1);
int prevTopic = oneDocTopics[si-1];
// Remove this token from all counts
oneDocTopicCounts[oldTopic]--;
typeNgramTopicCounts[prevType][oldGram][prevTopic]--;
if (si != docLen-1)
typeNgramTopicCounts[type][nextGram][oldTopic]--;
if (oldGram == 0) {
unitypeTopicCounts[type][oldTopic]--;
tokensPerTopic[oldTopic]--;
} else {
bitypeTopicCounts[bitype][oldTopic]--;
bitokensPerTopic[prevType][oldTopic]--;
biTokens--;
}
assert (oneDocTopicCounts[oldTopic] >= 0);
assert (typeNgramTopicCounts[prevType][oldGram][prevTopic] >= 0);
assert (si == docLen-1 || typeNgramTopicCounts[type][nextGram][oldTopic] >= 0);
assert (unitypeTopicCounts[type][oldTopic] >= 0);
assert (tokensPerTopic[oldTopic] >= 0);
assert (bitypeTopicCounts[bitype][oldTopic] >= 0);
assert (bitokensPerTopic[prevType][oldTopic] >= 0);
assert (biTokens >= 0);
// Build a joint distribution over topics and ngram-status for this token
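// Index t*2 holds the unigram outcome for topic t and index t*2+1 the bigram outcome:
//   unigram: (unitypeTopicCounts[type][t] + beta) / (tokensPerTopic[t] + vBeta)
//            * (oneDocTopicCounts[t] + alpha) * (typeNgramTopicCounts[prevType][0][prevTopic] + delta1)
//   bigram:  (bitypeTopicCounts[bitype][t] + gamma) / (bitokensPerTopic[prevType][t] + vGamma)
//            * (oneDocTopicCounts[t] + alpha) * (typeNgramTopicCounts[prevType][1][prevTopic] + delta2)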
Arrays.fill (biTopicWeights, 0.0);
topicWeightsSum = 0;
currentTypeTopicCounts = unitypeTopicCounts[type];
currentBitypeTopicCounts = bitypeTopicCounts[bitype];
previousBitokensPerTopic = bitokensPerTopic[prevType];
for (int ti = 0; ti < numTopics; ti++) {
newTopic = ti << 1; // just using this variable as an index into [ti*2+gram]
// The unigram outcome
tw =
(currentTypeTopicCounts[ti] + beta) / (tokensPerTopic[ti] + vBeta)
* (oneDocTopicCounts[ti] + alpha)
* (typeNgramTopicCounts[prevType][0][prevTopic] + delta1);
topicWeightsSum += tw;
biTopicWeights[newTopic] = tw;
// The bigram outcome
newTopic++;
tw =
(currentBitypeTopicCounts[ti] + gamma) / (previousBitokensPerTopic[ti] + vGamma)
* (oneDocTopicCounts[ti] + alpha)
* (typeNgramTopicCounts[prevType][1][prevTopic] + delta2);
topicWeightsSum += tw;
biTopicWeights[newTopic] = tw;
}
// Sample a topic assignment from this distribution
newTopic = r.nextDiscrete (biTopicWeights, topicWeightsSum);
// Decode the sampled outcome into a gram status and a topic index
newGram = newTopic % 2;
newTopic /= 2;
// Put that new topic into the counts
oneDocTopics[si] = newTopic;
oneDocGrams[si] = newGram;
oneDocTopicCounts[newTopic]++;
typeNgramTopicCounts[prevType][newGram][prevTopic]++;
if (si != docLen-1)
typeNgramTopicCounts[type][nextGram][newTopic]++;
if (newGram == 0) {
unitypeTopicCounts[type][newTopic]++;
tokensPerTopic[newTopic]++;
} else {
bitypeTopicCounts[bitype][newTopic]++;
bitokensPerTopic[prevType][newTopic]++;
biTokens++;
}
}
}
}
public void printTopWords (int numWords, boolean useNewLines)
{
class WordProb implements Comparable {
int wi; double p;
public WordProb (int wi, double p) { this.wi = wi; this.p = p; }
public final int compareTo (Object o2) {
if (p > ((WordProb)o2).p)
return -1;
else if (p == ((WordProb)o2).p)
return 0;
else return 1;
}
}
for (int ti = 0; ti < numTopics; ti++) {
// Unigrams
WordProb[] wp = new WordProb[numTypes];
for (int wi = 0; wi < numTypes; wi++)
wp[wi] = new WordProb (wi, (double)unitypeTopicCounts[wi][ti]);
Arrays.sort (wp);
int numToPrint = Math.min(wp.length, numWords);
if (useNewLines) {
System.out.println ("\nTopic "+ti+" unigrams");
for (int i = 0; i < numToPrint; i++)
System.out.println (uniAlphabet.lookupObject(wp[i].wi).toString()
+ " " + wp[i].p/tokensPerTopic[ti]);
} else {
System.out.print ("Topic "+ti+": ");
for (int i = 0; i < numToPrint; i++)
System.out.print (uniAlphabet.lookupObject(wp[i].wi).toString() + " ");
}
// Bigrams
/*
wp = new WordProb[numBitypes];
int bisum = 0;
for (int wi = 0; wi < numBitypes; wi++) {
wp[wi] = new WordProb (wi, ((double)bitypeTopicCounts[wi][ti]));
bisum += bitypeTopicCounts[wi][ti];
}
Arrays.sort (wp);
numToPrint = Math.min(wp.length, numWords);
if (useNewLines) {
System.out.println ("\nTopic "+ti+" bigrams");
for (int i = 0; i < numToPrint; i++)
System.out.println (biAlphabet.lookupObject(wp[i].wi).toString() + " " + wp[i].p/bisum);
} else {
System.out.print (" ");
for (int i = 0; i < numToPrint; i++)
System.out.print (biAlphabet.lookupObject(wp[i].wi).toString() + " ");
System.out.println();
}
*/
// Ngrams
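// Reassemble phrases from the sampled state: scanning each document backwards,
// a maximal run of tokens with gram == 1, together with the token that starts
// the run, is joined with '_' and counted in an AugmentableFeatureVector.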
AugmentableFeatureVector afv = new AugmentableFeatureVector(new Alphabet(), 10000, false);
for (int di = 0; di < topics.length; di++) {
FeatureSequenceWithBigrams fs = (FeatureSequenceWithBigrams) ilist.get(di).getData();
for (int si = topics[di].length-1; si >= 0; si--) {
if (topics[di][si] == ti && grams[di][si] == 1) {
String gramString = uniAlphabet.lookupObject(fs.getIndexAtPosition(si)).toString();
while (grams[di][si] == 1 && --si >= 0)
gramString = uniAlphabet.lookupObject(fs.getIndexAtPosition(si)).toString() + "_" + gramString;
afv.add(gramString, 1.0);
}
}
}
//System.out.println ("pre-sorting");
int numNgrams = afv.numLocations();
//System.out.println ("post-sorting "+numNgrams);
wp = new WordProb[numNgrams];
int ngramSum = 0;
for (int loc = 0; loc < numNgrams; loc++) {
wp[loc] = new WordProb (afv.indexAtLocation(loc), afv.valueAtLocation(loc));
ngramSum += wp[loc].p;
}
Arrays.sort (wp);
int numUnitypeTokens = 0, numBitypeTokens = 0, numUnitypeTypes = 0, numBitypeTypes = 0;
for (int fi = 0; fi < numTypes; fi++) {
numUnitypeTokens += unitypeTopicCounts[fi][ti];
if (unitypeTopicCounts[fi][ti] != 0)
numUnitypeTypes++;
}
for (int fi = 0; fi < numBitypes; fi++) {
numBitypeTokens += bitypeTopicCounts[fi][ti];
if (bitypeTopicCounts[fi][ti] != 0)
numBitypeTypes++;
}
if (useNewLines) {
System.out.println ("\nTopic "+ti+" unigrams "+numUnitypeTokens+"/"+numUnitypeTypes+" bigrams "+numBitypeTokens+"/"+numBitypeTypes
+" phrases "+Math.round(afv.oneNorm())+"/"+numNgrams);
for (int i = 0; i < Math.min(numNgrams,numWords); i++)
System.out.println (afv.getAlphabet().lookupObject(wp[i].wi).toString() + " " + wp[i].p/ngramSum);
} else {
System.out.print (" (unigrams "+numUnitypeTokens+"/"+numUnitypeTypes+" bigrams "+numBitypeTokens+"/"+numBitypeTypes
+" phrases "+Math.round(afv.oneNorm())+"/"+numNgrams+")\n ");
//System.out.print (" (unique-ngrams="+numNgrams+" ngram-count="+Math.round(afv.oneNorm())+")\n ");
for (int i = 0; i < Math.min(numNgrams, numWords); i++)
System.out.print (afv.getAlphabet().lookupObject(wp[i].wi).toString() + " ");
System.out.println();
}
}
}
public void printDocumentTopics (File f) throws IOException
{
PrintWriter writer = new PrintWriter (new FileWriter (f));
printDocumentTopics (writer);
writer.close();
}
public void printDocumentTopics (PrintWriter pw) {
printDocumentTopics (pw, 0.0, -1); // no threshold, print all topics
}
public void printDocumentTopics (PrintWriter pw, double threshold, int max)
{
pw.println ("#doc source topic proportions");
int docLen;
double[] topicDist = new double[numTopics];
for (int di = 0; di < topics.length; di++) {
pw.print (di); pw.print (' ');
pw.print (ilist.get(di).getSource().toString()); pw.print (' ');
docLen = topics[di].length;
for (int ti = 0; ti < numTopics; ti++)
topicDist[ti] = (((float)docTopicCounts[di][ti])/docLen);
if (max < 0) max = numTopics;
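// Repeatedly select and print the largest remaining topic proportion, then zero
// it out, stopping after 'max' topics or once the proportion falls below 'threshold'.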
for (int tp = 0; tp < max; tp++) {
double maxvalue = 0;
int maxindex = -1;
for (int ti = 0; ti < numTopics; ti++)
if (topicDist[ti] > maxvalue) {
maxvalue = topicDist[ti];
maxindex = ti;
}
if (maxindex == -1 || topicDist[maxindex] < threshold)
break;
pw.print (maxindex+" "+topicDist[maxindex]+" ");
topicDist[maxindex] = 0;
}
pw.println (' ');
}
}
public void printState (File f) throws IOException
{
PrintWriter writer = new PrintWriter (new FileWriter(f));
printState (writer);
writer.close();
}
public void printState (PrintWriter pw)
{
pw.println ("#doc pos typeindex type bigrampossible? topic bigram");
for (int di = 0; di < topics.length; di++) {
FeatureSequenceWithBigrams fs = (FeatureSequenceWithBigrams) ilist.get(di).getData();
for (int si = 0; si < topics[di].length; si++) {
int type = fs.getIndexAtPosition(si);
pw.print(di); pw.print(' ');
pw.print(si); pw.print(' ');
pw.print(type); pw.print(' ');
pw.print(uniAlphabet.lookupObject(type)); pw.print(' ');
pw.print(fs.getBiIndexAtPosition(si)==-1 ? 0 : 1); pw.print(' ');
pw.print(topics[di][si]); pw.print(' ');
pw.print(grams[di][si]); pw.println();
}
}
}
public void write (File f) {
try {
ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream(f));
oos.writeObject(this);
oos.close();
}
catch (IOException e) {
System.err.println("Exception writing file " + f + ": " + e);
}
}
// Serialization
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 0;
private static final int NULL_INTEGER = -1;
private void writeIntArray2 (int[][] a, ObjectOutputStream out) throws IOException {
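// Assumes a non-empty rectangular array: every row is written with length a[0].length.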
out.writeInt (a.length);
int d2 = a[0].length;
out.writeInt (d2);
for (int i = 0; i < a.length; i++)
for (int j = 0; j < d2; j++)
out.writeInt (a[i][j]);
}
private int[][] readIntArray2 (ObjectInputStream in) throws IOException {
int d1 = in.readInt();
int d2 = in.readInt();
int[][] a = new int[d1][d2];
for (int i = 0; i < d1; i++)
for (int j = 0; j < d2; j++)
a[i][j] = in.readInt();
return a;
}
private void writeObject (ObjectOutputStream out) throws IOException {
out.writeInt (CURRENT_SERIAL_VERSION);
out.writeObject (ilist);
out.writeInt (numTopics);
out.writeDouble (alpha);
out.writeDouble (beta);
out.writeDouble (gamma);
out.writeDouble (delta);
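// Note: delta1 and delta2 are not serialized here (and are not read back in
// readObject), so they are not restored after deserialization.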
out.writeDouble (tAlpha);
out.writeDouble (vBeta);
out.writeDouble (vGamma);
out.writeInt (numTypes);
out.writeInt (numBitypes);
out.writeInt (numTokens);
out.writeInt (biTokens);
for (int di = 0; di < topics.length; di ++)
for (int si = 0; si < topics[di].length; si++)
out.writeInt (topics[di][si]);
for (int di = 0; di < topics.length; di ++)
for (int si = 0; si < topics[di].length; si++)
out.writeInt (grams[di][si]);
writeIntArray2 (docTopicCounts, out);
for (int fi = 0; fi < numTypes; fi++)
for (int n = 0; n < 2; n++)
for (int ti = 0; ti < numTopics; ti++)
out.writeInt (typeNgramTopicCounts[fi][n][ti]);
writeIntArray2 (unitypeTopicCounts, out);
writeIntArray2 (bitypeTopicCounts, out);
for (int ti = 0; ti < numTopics; ti++)
out.writeInt (tokensPerTopic[ti]);
writeIntArray2 (bitokensPerTopic, out);
}
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
int version = in.readInt ();
ilist = (InstanceList) in.readObject ();
uniAlphabet = ilist.getDataAlphabet();
biAlphabet = ((FeatureSequenceWithBigrams)ilist.get(0).getData()).getBiAlphabet();
numTopics = in.readInt();
alpha = in.readDouble();
beta = in.readDouble();
gamma = in.readDouble();
delta = in.readDouble();
tAlpha = in.readDouble();
vBeta = in.readDouble();
vGamma = in.readDouble();
numTypes = in.readInt();
numBitypes = in.readInt();
numTokens = in.readInt();
biTokens = in.readInt();
int numDocs = ilist.size();
topics = new int[numDocs][];
grams = new int[numDocs][];
for (int di = 0; di < ilist.size(); di++) {
int docLen = ((FeatureSequence)ilist.get(di).getData()).getLength();
topics[di] = new int[docLen];
for (int si = 0; si < docLen; si++)
topics[di][si] = in.readInt();
}
for (int di = 0; di < ilist.size(); di++) {
int docLen = ((FeatureSequence)ilist.get(di).getData()).getLength();
grams[di] = new int[docLen];
for (int si = 0; si < docLen; si++)
grams[di][si] = in.readInt();
}
docTopicCounts = readIntArray2 (in);
typeNgramTopicCounts = new int[numTypes][2][numTopics];
for (int fi = 0; fi < numTypes; fi++)
for (int n = 0; n < 2; n++)
for (int ti = 0; ti < numTopics; ti++)
typeNgramTopicCounts[fi][n][ti] = in.readInt();
unitypeTopicCounts = readIntArray2 (in);
bitypeTopicCounts = readIntArray2 (in);
tokensPerTopic = new int[numTopics];
for (int ti = 0; ti < numTopics; ti++)
tokensPerTopic[ti] = in.readInt();
bitokensPerTopic = readIntArray2 (in);
}
// Just for testing. Recommended usage is via mallet/bin/vectors2topics.
public static void main (String[] args)
{
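// Usage: TopicalNGrams <instance-list-file> [numIterations] [numTopWords]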
InstanceList ilist = InstanceList.load (new File(args[0]));
int numIterations = args.length > 1 ? Integer.parseInt(args[1]) : 1000;
int numTopWords = args.length > 2 ? Integer.parseInt(args[2]) : 20;
System.out.println ("Data loaded.");
TopicalNGrams tng = new TopicalNGrams (10);
tng.estimate (ilist, numIterations, 1, 0, null, new Randoms());
tng.printTopWords (numTopWords, true);
}
}