package cc.mallet.topics;
import java.util.ArrayList;
import java.util.Arrays;
import java.io.*;
import cc.mallet.types.*;
import cc.mallet.util.Randoms;
import gnu.trove.*;
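/**
 * Hierarchical LDA over a topic tree of fixed depth, using the nested
 * Chinese restaurant process (nCRP) as a prior over paths and collapsed
 * Gibbs sampling over document paths and per-token level assignments.
 *
 * <p>A minimal usage sketch (the file names are placeholders; the data
 * must be imported as feature sequences):
 *
 * <pre>
 * InstanceList instances = InstanceList.load(new File("training.mallet"));
 * InstanceList testing = InstanceList.load(new File("testing.mallet"));
 *
 * HierarchicalLDA hlda = new HierarchicalLDA();
 * hlda.setTopicDisplay(50, 10);   // print the tree every 50 iterations
 * hlda.initialize(instances, testing, 5, new Randoms());
 * hlda.estimate(250);
 * </pre>
 */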
public class HierarchicalLDA {
InstanceList instances;
InstanceList testing;
NCRPNode rootNode, node;
int numLevels;
int numDocuments;
int numTypes;
double alpha; // smoothing on topic distributions
double gamma; // "imaginary" customers at the next, as yet unused table
double eta; // smoothing on word distributions
double etaSum;
int[][] levels; // indexed < doc, token >
NCRPNode[] documentLeaves; // the currently selected path (i.e., leaf node) through the NCRP tree for each document
int totalNodes = 0;
String stateFile = "hlda.state";
Randoms random;
boolean showProgress = true;
int displayTopicsInterval = 50;
int numWordsToDisplay = 10;
public HierarchicalLDA () {
alpha = 10.0;
gamma = 1.0;
eta = 0.1;
}
public void setAlpha(double alpha) {
this.alpha = alpha;
}
public void setGamma(double gamma) {
this.gamma = gamma;
}
public void setEta(double eta) {
this.eta = eta;
}
public void setStateFile(String stateFile) {
this.stateFile = stateFile;
}
public void setTopicDisplay(int interval, int words) {
displayTopicsInterval = interval;
numWordsToDisplay = words;
}
/**
* Determines whether the sampler shows progress by printing a
* character after every iteration.
*/
public void setProgressDisplay(boolean showProgress) {
this.showProgress = showProgress;
}
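/**
* Set up the sampler: store the training and testing data, sample an
* initial path through the nCRP tree for every document, and assign
* each token a uniformly random level on that path.
*
* @param numLevels the fixed depth of the topic tree
* @param random the source of randomness used throughout sampling
*/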
public void initialize(InstanceList instances, InstanceList testing,
int numLevels, Randoms random) {
this.instances = instances;
this.testing = testing;
this.numLevels = numLevels;
this.random = random;
if (! (instances.get(0).getData() instanceof FeatureSequence)) {
throw new IllegalArgumentException("Input must be a FeatureSequence; use the --feature-sequence option when importing data, for example.");
}
numDocuments = instances.size();
numTypes = instances.getDataAlphabet().size();
etaSum = eta * numTypes;
// Working array for one root-to-leaf path
NCRPNode[] path = new NCRPNode[numLevels];
rootNode = new NCRPNode(numTypes);
levels = new int[numDocuments][];
documentLeaves = new NCRPNode[numDocuments];
// For each document, sample an initial path through the tree (the
// nested CRP may create new nodes along the way) and give each
// token a uniformly random level on that path.
for (int doc=0; doc < numDocuments; doc++) {
FeatureSequence fs = (FeatureSequence) instances.get(doc).getData();
int seqLen = fs.getLength();
path[0] = rootNode;
rootNode.customers++;
for (int level = 1; level < numLevels; level++) {
path[level] = path[level-1].select();
path[level].customers++;
}
node = path[numLevels - 1];
levels[doc] = new int[seqLen];
documentLeaves[doc] = node;
for (int token=0; token < seqLen; token++) {
int type = fs.getIndexAtPosition(token);
levels[doc][token] = random.nextInt(numLevels);
node = path[ levels[doc][token] ];
node.totalTokens++;
node.typeCounts[type]++;
}
}
}
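/**
* Run the collapsed Gibbs sampler for the given number of iterations.
* Each iteration resamples every document's path through the tree and
* then every token's level assignment along that path.
*/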
public void estimate(int numIterations) {
for (int iteration = 1; iteration <= numIterations; iteration++) {
for (int doc=0; doc < numDocuments; doc++) {
samplePath(doc, iteration);
}
for (int doc=0; doc < numDocuments; doc++) {
sampleTopics(doc);
}
if (showProgress) {
System.out.print(".");
if (iteration % 50 == 0) {
System.out.println(" " + iteration);
}
}
if (iteration % displayTopicsInterval == 0) {
printNodes();
}
}
}
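/**
* Gibbs-sample a new path through the tree for one document: remove the
* document's customers and word counts from its current path, compute a
* log weight for every candidate node by combining the nCRP prior with
* the word likelihood (including the weight of starting a new path),
* sample a node, and add the document back along the chosen path.
*/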
public void samplePath(int doc, int iteration) {
NCRPNode[] path = new NCRPNode[numLevels];
NCRPNode node;
int level, token, type;
node = documentLeaves[doc];
for (level = numLevels - 1; level >= 0; level--) {
path[level] = node;
node = node.parent;
}
documentLeaves[doc].dropPath();
TObjectDoubleHashMap<NCRPNode> nodeWeights =
new TObjectDoubleHashMap<NCRPNode>();
// Calculate p(c_m | c_{-m})
calculateNCRP(nodeWeights, rootNode, 0.0);
// Add weights for p(w_m | c, w_{-m}, z)
// The path may have no further customers and therefore
// be unavailable, but it should still exist since we haven't
// reset documentLeaves[doc] yet...
TIntIntHashMap[] typeCounts = new TIntIntHashMap[numLevels];
int[] docLevels;
for (level = 0; level < numLevels; level++) {
typeCounts[level] = new TIntIntHashMap();
}
docLevels = levels[doc];
FeatureSequence fs = (FeatureSequence) instances.get(doc).getData();
// Save the counts of every word at each level, and remove
// counts from the current path
for (token = 0; token < docLevels.length; token++) {
level = docLevels[token];
type = fs.getIndexAtPosition(token);
if (! typeCounts[level].containsKey(type)) {
typeCounts[level].put(type, 1);
}
else {
typeCounts[level].increment(type);
}
path[level].typeCounts[type]--;
assert(path[level].typeCounts[type] >= 0);
path[level].totalTokens--;
assert(path[level].totalTokens >= 0);
}
// Calculate the weight for a new path at a given level.
double[] newTopicWeights = new double[numLevels];
for (level = 1; level < numLevels; level++) { // Skip the root...
int[] types = typeCounts[level].keys();
int totalTokens = 0;
for (int t: types) {
for (int i=0; i<typeCounts[level].get(t); i++) {
newTopicWeights[level] +=
Math.log((eta + i) / (etaSum + totalTokens));
totalTokens++;
}
}
//if (iteration > 1) { System.out.println(newTopicWeights[level]); }
}
calculateWordLikelihood(nodeWeights, rootNode, 0.0, typeCounts, newTopicWeights, 0, iteration);
NCRPNode[] nodes = nodeWeights.keys(new NCRPNode[] {});
double[] weights = new double[nodes.length];
double sum = 0.0;
double max = Double.NEGATIVE_INFINITY;
// To avoid underflow, we're using log weights and normalizing the node weights so that
// the largest weight is always 1.
for (int i=0; i<nodes.length; i++) {
if (nodeWeights.get(nodes[i]) > max) {
max = nodeWeights.get(nodes[i]);
}
}
for (int i=0; i<nodes.length; i++) {
weights[i] = Math.exp(nodeWeights.get(nodes[i]) - max);
/*
if (iteration > 1) {
if (nodes[i] == documentLeaves[doc]) {
System.out.print("* ");
}
System.out.println(((NCRPNode) nodes[i]).level + "\t" + weights[i] +
"\t" + nodeWeights.get(nodes[i]));
}
*/
sum += weights[i];
}
//if (iteration > 1) {System.out.println();}
node = nodes[ random.nextDiscrete(weights, sum) ];
// If we have picked an internal node, we need to
// add a new path.
if (! node.isLeaf()) {
node = node.getNewLeaf();
}
node.addPath();
documentLeaves[doc] = node;
for (level = numLevels - 1; level >= 0; level--) {
int[] types = typeCounts[level].keys();
for (int t: types) {
node.typeCounts[t] += typeCounts[level].get(t);
node.totalTokens += typeCounts[level].get(t);
}
node = node.parent;
}
}
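/**
* Recursively compute the nested CRP prior over paths. The log
* probability of following existing branches accumulates on the way
* down; the weight stored for each node corresponds to branching off
* into a new path at that node.
*/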
public void calculateNCRP(TObjectDoubleHashMap<NCRPNode> nodeWeights,
NCRPNode node, double weight) {
for (NCRPNode child: node.children) {
calculateNCRP(nodeWeights, child,
weight + Math.log((double) child.customers / (node.customers + gamma)));
}
nodeWeights.put(node, weight + Math.log(gamma / (node.customers + gamma)));
}
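/**
* Add the word likelihood p(w_m | c, w_{-m}, z) to each candidate node:
* score this document's tokens at <code>level</code> against the node's
* current word counts, pass the accumulated weight down to the children,
* and, for internal nodes, add the precomputed weights of new (empty)
* topics for the remaining levels.
*/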
public void calculateWordLikelihood(TObjectDoubleHashMap<NCRPNode> nodeWeights,
NCRPNode node, double weight,
TIntIntHashMap[] typeCounts, double[] newTopicWeights,
int level, int iteration) {
// First calculate the likelihood of the words at this level, given
// this topic.
double nodeWeight = 0.0;
int[] types = typeCounts[level].keys();
int totalTokens = 0;
//if (iteration > 1) { System.out.println(level + " " + nodeWeight); }
for (int type: types) {
for (int i=0; i<typeCounts[level].get(type); i++) {
nodeWeight +=
Math.log((eta + node.typeCounts[type] + i) /
(etaSum + node.totalTokens + totalTokens));
totalTokens++;
/*
if (iteration > 1) {
System.out.println("(" +eta + " + " + node.typeCounts[type] + " + " + i + ") /" +
"(" + etaSum + " + " + node.totalTokens + " + " + totalTokens + ")" +
" : " + nodeWeight);
}
*/
}
}
//if (iteration > 1) { System.out.println(level + " " + nodeWeight); }
// Propagate that weight to the child nodes
for (NCRPNode child: node.children) {
calculateWordLikelihood(nodeWeights, child, weight + nodeWeight,
typeCounts, newTopicWeights, level + 1, iteration);
}
// Finally, if this is an internal node, add the weight of
// a new path
level++;
while (level < numLevels) {
nodeWeight += newTopicWeights[level];
level++;
}
nodeWeights.adjustValue(node, nodeWeight);
}
/** Propagate a topic weight to a node and all its children.
The weight is assumed to be in log space.
*/
public void propagateTopicWeight(TObjectDoubleHashMap<NCRPNode> nodeWeights,
NCRPNode node, double weight) {
if (! nodeWeights.containsKey(node)) {
// calculating the NCRP prior proceeds from the
// root down (ie following child links),
// but adding the word-topic weights comes from
// the bottom up, following parent links and then
// child links. It's possible that the leaf node may have
// been removed just prior to this round, so the current
// node may not have an NCRP weight. If so, it's not
// going to be sampled anyway, so ditch it.
return;
}
for (NCRPNode child: node.children) {
propagateTopicWeight(nodeWeights, child, weight);
}
nodeWeights.adjustValue(node, weight);
}
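/**
* Gibbs-sample a level assignment for every token in the document,
* holding the document's path fixed. Each level's weight combines the
* document's smoothed level counts (alpha) with the smoothed
* topic-word probability at that level (eta).
*/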
public void sampleTopics(int doc) {
FeatureSequence fs = (FeatureSequence) instances.get(doc).getData();
int seqLen = fs.getLength();
int[] docLevels = levels[doc];
NCRPNode[] path = new NCRPNode[numLevels];
NCRPNode node;
int[] levelCounts = new int[numLevels];
int type, token, level;
double sum;
// Get the leaf
node = documentLeaves[doc];
for (level = numLevels - 1; level >= 0; level--) {
path[level] = node;
node = node.parent;
}
double[] levelWeights = new double[numLevels];
// Initialize level counts
for (token = 0; token < seqLen; token++) {
levelCounts[ docLevels[token] ]++;
}
for (token = 0; token < seqLen; token++) {
type = fs.getIndexAtPosition(token);
levelCounts[ docLevels[token] ]--;
node = path[ docLevels[token] ];
node.typeCounts[type]--;
node.totalTokens--;
sum = 0.0;
for (level=0; level < numLevels; level++) {
levelWeights[level] =
(alpha + levelCounts[level]) *
(eta + path[level].typeCounts[type]) /
(etaSum + path[level].totalTokens);
sum += levelWeights[level];
}
level = random.nextDiscrete(levelWeights, sum);
docLevels[token] = level;
levelCounts[ docLevels[token] ]++;
node = path[ level ];
node.typeCounts[type]++;
node.totalTokens++;
}
}
/**
* Writes the current sampling state to the file specified in <code>stateFile</code>.
*/
public void printState() throws IOException, FileNotFoundException {
printState(new PrintWriter(new BufferedWriter(new FileWriter(stateFile))));
}
/**
* Write a text file describing the current sampling state.
*/
public void printState(PrintWriter out) throws IOException {
int doc = 0;
Alphabet alphabet = instances.getDataAlphabet();
for (Instance instance: instances) {
FeatureSequence fs = (FeatureSequence) instance.getData();
int seqLen = fs.getLength();
int[] docLevels = levels[doc];
NCRPNode node;
int type, token, level;
StringBuffer path = new StringBuffer();
// Start with the leaf, and build a string describing the path for this doc
node = documentLeaves[doc];
for (level = numLevels - 1; level >= 0; level--) {
path.append(node.nodeID + " ");
node = node.parent;
}
for (token = 0; token < seqLen; token++) {
type = fs.getIndexAtPosition(token);
level = docLevels[token];
// The "" forces string concatenation between the path buffer and the int type
out.println(path + "" + type + " " + alphabet.lookupObject(type) + " " + level + " ");
}
doc++;
}
}
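/**
* Print the current topic tree to standard output, one node per line,
* indented by depth and showing token/customer counts followed by the
* top words for that node.
*/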
public void printNodes() {
printNode(rootNode, 0);
}
public void printNode(NCRPNode node, int indent) {
StringBuffer out = new StringBuffer();
for (int i=0; i<indent; i++) {
out.append(" ");
}
out.append(node.totalTokens + "/" + node.customers + " ");
out.append(node.getTopWords(numWordsToDisplay));
System.out.println(out);
for (NCRPNode child: node.children) {
printNode(child, indent + 1);
}
}
/** For use with empirical likelihood evaluation: repeatedly sample a
* path through the tree and a distribution over levels on that path,
* form the implied distribution over words, and score the held-out
* documents, averaging over samples in log space.
*/
public double empiricalLikelihood(int numSamples, InstanceList testing) {
NCRPNode[] path = new NCRPNode[numLevels];
NCRPNode node;
path[0] = rootNode;
FeatureSequence fs;
int sample, level, type, token, doc, seqLen;
Dirichlet dirichlet = new Dirichlet(numLevels, alpha);
double[] levelWeights;
double[] multinomial = new double[numTypes];
double[][] likelihoods = new double[ testing.size() ][ numSamples ];
for (sample = 0; sample < numSamples; sample++) {
Arrays.fill(multinomial, 0.0);
for (level = 1; level < numLevels; level++) {
path[level] = path[level-1].selectExisting();
}
levelWeights = dirichlet.nextDistribution();
for (type = 0; type < numTypes; type++) {
for (level = 0; level < numLevels; level++) {
node = path[level];
multinomial[type] +=
levelWeights[level] *
(eta + node.typeCounts[type]) /
(etaSum + node.totalTokens);
}
}
for (type = 0; type < numTypes; type++) {
multinomial[type] = Math.log(multinomial[type]);
}
for (doc=0; doc<testing.size(); doc++) {
fs = (FeatureSequence) testing.get(doc).getData();
seqLen = fs.getLength();
for (token = 0; token < seqLen; token++) {
type = fs.getIndexAtPosition(token);
likelihoods[doc][sample] += multinomial[type];
}
}
}
double averageLogLikelihood = 0.0;
double logNumSamples = Math.log(numSamples);
for (doc=0; doc<testing.size(); doc++) {
double max = Double.NEGATIVE_INFINITY;
for (sample = 0; sample < numSamples; sample++) {
if (likelihoods[doc][sample] > max) {
max = likelihoods[doc][sample];
}
}
double sum = 0.0;
for (sample = 0; sample < numSamples; sample++) {
sum += Math.exp(likelihoods[doc][sample] - max);
}
averageLogLikelihood += Math.log(sum) + max - logNumSamples;
}
return averageLogLikelihood;
}
/**
* This method is primarily for testing purposes. The {@link cc.mallet.topics.tui.HierarchicalLDATUI}
* class has a more flexible interface for command-line use.
*/
public static void main (String[] args) {
try {
InstanceList instances = InstanceList.load(new File(args[0]));
InstanceList testing = InstanceList.load(new File(args[1]));
HierarchicalLDA sampler = new HierarchicalLDA();
sampler.initialize(instances, testing, 5, new Randoms());
sampler.estimate(250);
} catch (Exception e) {
e.printStackTrace();
}
}
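/**
* A node in the nested CRP tree. Each node is a topic: it records the
* number of documents whose paths pass through it (customers) and the
* counts of the word types assigned to it (typeCounts, totalTokens).
*/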
class NCRPNode {
int customers;
ArrayList<NCRPNode> children;
NCRPNode parent;
int level;
int totalTokens;
int[] typeCounts;
public int nodeID;
public NCRPNode(NCRPNode parent, int dimensions, int level) {
customers = 0;
this.parent = parent;
children = new ArrayList<NCRPNode>();
this.level = level;
//System.out.println("new node at level " + level);
totalTokens = 0;
typeCounts = new int[dimensions];
nodeID = totalNodes;
totalNodes++;
}
public NCRPNode(int dimensions) {
this(null, dimensions, 0);
}
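/** Create a new child topic one level below this node. */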
public NCRPNode addChild() {
NCRPNode node = new NCRPNode(this, typeCounts.length, level + 1);
children.add(node);
return node;
}
public boolean isLeaf() {
return level == numLevels - 1;
}
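/** Grow a chain of new nodes from this node down to the leaf level and return the new leaf. */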
public NCRPNode getNewLeaf() {
NCRPNode node = this;
for (int l=level; l<numLevels - 1; l++) {
node = node.addChild();
}
return node;
}
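/** Remove one customer (document) from this leaf and each of its ancestors, pruning any node whose customer count drops to zero. */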
public void dropPath() {
NCRPNode node = this;
node.customers--;
if (node.customers == 0) {
node.parent.remove(node);
}
for (int l = 1; l < numLevels; l++) {
node = node.parent;
node.customers--;
if (node.customers == 0) {
node.parent.remove(node);
}
}
}
public void remove(NCRPNode node) {
children.remove(node);
}
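/** Add one customer (document) to this leaf and each of its ancestors. */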
public void addPath() {
NCRPNode node = this;
node.customers++;
for (int l = 1; l < numLevels; l++) {
node = node.parent;
node.customers++;
}
}
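/** Choose an existing child with weight child.customers / (gamma + customers); used by the empirical likelihood estimator, which never creates new nodes. */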
public NCRPNode selectExisting() {
double[] weights = new double[children.size()];
int i = 0;
for (NCRPNode child: children) {
weights[i] = (double) child.customers / (gamma + customers);
i++;
}
int choice = random.nextDiscrete(weights);
return children.get(choice);
}
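/** Choose a child according to the nested CRP: a new child with probability gamma / (gamma + customers), or an existing child with probability child.customers / (gamma + customers). */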
public NCRPNode select() {
double[] weights = new double[children.size() + 1];
weights[0] = gamma / (gamma + customers);
int i = 1;
for (NCRPNode child: children) {
weights[i] = (double) child.customers / (gamma + customers);
i++;
}
int choice = random.nextDiscrete(weights);
if (choice == 0) {
return(addChild());
}
else {
return children.get(choice - 1);
}
}
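/** Return the numWords most frequent word types in this topic, separated by spaces. */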
public String getTopWords(int numWords) {
IDSorter[] sortedTypes = new IDSorter[numTypes];
for (int type=0; type < numTypes; type++) {
sortedTypes[type] = new IDSorter(type, typeCounts[type]);
}
Arrays.sort(sortedTypes);
Alphabet alphabet = instances.getDataAlphabet();
StringBuffer out = new StringBuffer();
for (int i=0; i<numWords; i++) {
out.append(alphabet.lookupObject(sortedTypes[i].getID()) + " ");
}
return out.toString();
}
}
}