boolean shouldSaveState,
boolean readjustTopicsAndStats /* currently ignored */) {
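// Gibbs-sample a new topic for every token in this document. The per-token
// sampling mass factors into three buckets (cf. the SparseLDA decomposition
// of Yao, Mimno & McCallum, KDD 2009):
//
//   p(z = t) ∝ (alpha_t + n_{t|d}) * (beta + n_{w|t}) / (tokensPerTopic[t] + betaSum)
//            =  alpha_t * beta / (tokensPerTopic[t] + betaSum)                  [smoothingOnlyMass]
//            +  n_{t|d} * beta / (tokensPerTopic[t] + betaSum)                  [topicBetaMass]
//            + (alpha_t + n_{t|d}) * n_{w|t} / (tokensPerTopic[t] + betaSum)    [topicTermMass]
//
// where n_{t|d} is the document-topic count (localTopicCounts), n_{w|t} the
// type-topic count (typeTopicCounts), and tokensPerTopic[t] the total number
// of tokens currently assigned to topic t.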
int[] oneDocTopics = topicSequence.getFeatures();
TIntIntHashMap currentTypeTopicCounts;
int type, oldTopic, newTopic;
double topicWeightsSum;
int docLength = tokenSequence.getLength();
// Populate the per-document topic counts
TIntIntHashMap localTopicCounts = new TIntIntHashMap();
for (int position = 0; position < docLength; position++) {
localTopicCounts.adjustOrPutValue(oneDocTopics[position], 1, 1);
}
// Initialize the topic count/beta sampling bucket
double topicBetaMass = 0.0;
for (int topic: localTopicCounts.keys()) {
int n = localTopicCounts.get(topic);
// accumulate the normalizing mass for the (beta * n_{t|d}) bucket
topicBetaMass += beta * n / (tokensPerTopic[topic] + betaSum);
// update the coefficients for the non-zero topics
cachedCoefficients[topic] = (alpha[topic] + n) / (tokensPerTopic[topic] + betaSum);
}
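// cachedCoefficients[t] now holds (alpha_t + n_{t|d}) / (tokensPerTopic[t] + betaSum)
// for topics that occur in this document; for every other topic it stays at the
// smoothing-only value alpha_t / (tokensPerTopic[t] + betaSum), which the reset
// loop at the end of this method restores. smoothingOnlyMass is adjusted
// incrementally below rather than rebuilt per document.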
double topicTermMass = 0.0;
double[] topicTermScores = new double[numTopics];
int[] topicTermIndices;
int[] topicTermValues;
int i;
double score;
// Iterate over the positions (words) in the document
for (int position = 0; position < docLength; position++) {
type = tokenSequence.getIndexAtPosition(position);
oldTopic = oneDocTopics[position];
currentTypeTopicCounts = typeTopicCounts[type];
// The token is currently assigned to oldTopic, so its type-topic count must be positive.
assert(currentTypeTopicCounts.get(oldTopic) > 0);
// Remove this token from all counts.
// Note that we actually want to remove the key if it goes
// to zero, not set it to 0.
if (currentTypeTopicCounts.get(oldTopic) == 1) {
currentTypeTopicCounts.remove(oldTopic);
}
else {
currentTypeTopicCounts.adjustValue(oldTopic, -1);
}
smoothingOnlyMass -= alpha[oldTopic] * beta /
(tokensPerTopic[oldTopic] + betaSum);
topicBetaMass -= beta * localTopicCounts.get(oldTopic) /
(tokensPerTopic[oldTopic] + betaSum);
if (localTopicCounts.get(oldTopic) == 1) {
localTopicCounts.remove(oldTopic);
}
else {
localTopicCounts.adjustValue(oldTopic, -1);
}
tokensPerTopic[oldTopic]--;
smoothingOnlyMass += alpha[oldTopic] * beta /
(tokensPerTopic[oldTopic] + betaSum);
topicBetaMass += beta * localTopicCounts.get(oldTopic) /
(tokensPerTopic[oldTopic] + betaSum);
cachedCoefficients[oldTopic] =
(alpha[oldTopic] + localTopicCounts.get(oldTopic)) /
(tokensPerTopic[oldTopic] + betaSum);
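// All three bucket masses and cachedCoefficients[oldTopic] now describe the
// document with the current token removed. If oldTopic was just dropped from
// localTopicCounts, Trove's get() returns 0 for the missing key, so the
// beta-bucket contribution vanishes and the coefficient falls back to its
// smoothing-only value, as intended.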
topicTermMass = 0.0;
topicTermIndices = currentTypeTopicCounts.keys();
topicTermValues = currentTypeTopicCounts.getValues();
for (i=0; i < topicTermIndices.length; i++) {
int topic = topicTermIndices[i];
score =
cachedCoefficients[topic] * topicTermValues[i];
// ((alpha[topic] + localTopicCounts.get(topic)) *
// topicTermValues[i]) /
// (tokensPerTopic[topic] + betaSum);
// Note: I tried only doing this next bit if
// score > 0, but it didn't make any difference,
// at least in the first few iterations.
topicTermMass += score;
topicTermScores[i] = score;
// topicTermIndices[i] = topic;
}
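// topicTermMass now covers only the topics with a non-zero count for this
// word type; every other topic contributes through the smoothing and beta
// buckets alone.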
// indicate that this is the last topic
// topicTermIndices[i] = -1;
double sample = random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
double origSample = sample;
// Make sure it actually gets set
newTopic = -1;
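// Decide which bucket the draw falls in. The topic-term bucket is checked
// first: it typically holds most of the mass and only requires a walk over
// the topics already scored above.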
if (sample < topicTermMass) {
//topicTermCount++;
i = -1;
while (sample > 0) {
i++;
sample -= topicTermScores[i];
}
newTopic = topicTermIndices[i];
}
else {
sample -= topicTermMass;
if (sample < topicBetaMass) {
//betaTopicCount++;
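// Each beta-bucket entry is beta * n_{t|d} / (tokensPerTopic + betaSum), so
// dividing the residual sample by beta lets the loop below subtract the
// simpler n_{t|d} / (tokensPerTopic + betaSum) terms.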
sample /= beta;
topicTermIndices = localTopicCounts.keys();
topicTermValues = localTopicCounts.getValues();
for (i=0; i < topicTermIndices.length; i++) {
newTopic = topicTermIndices[i];
sample -= topicTermValues[i] /
(tokensPerTopic[newTopic] + betaSum);
if (sample <= 0.0) {
break;
}
}
}
else {
//smoothingOnlyCount++;
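// Each smoothing-bucket entry is alpha_t * beta / (tokensPerTopic + betaSum);
// strip off the beta-bucket mass and the common beta factor, then walk all
// topics subtracting alpha_t / (tokensPerTopic + betaSum).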
sample -= topicBetaMass;
sample /= beta;
for (int topic = 0; topic < numTopics; topic++) {
sample -= alpha[topic] /
(tokensPerTopic[topic] + betaSum);
if (sample <= 0.0) {
newTopic = topic;
break;
}
}
}
}
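// A draw can overshoot all three buckets if the incrementally maintained
// masses have drifted slightly from their true values (floating-point
// error); log the state and fall back to the last topic rather than
// aborting the sweep.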
if (newTopic == -1) {
System.err.println("LDAHyper sampling error: "+ origSample + " " + sample + " " + smoothingOnlyMass + " " +
topicBetaMass + " " + topicTermMass);
newTopic = numTopics-1; // TODO is this an appropriate fallback?
//throw new IllegalStateException ("LDAHyper: New topic not sampled.");
}
//assert(newTopic != -1);
// Put that new topic into the counts
oneDocTopics[position] = newTopic;
currentTypeTopicCounts.adjustOrPutValue(newTopic, 1, 1);
smoothingOnlyMass -= alpha[newTopic] * beta /
(tokensPerTopic[newTopic] + betaSum);
topicBetaMass -= beta * localTopicCounts.get(newTopic) /
(tokensPerTopic[newTopic] + betaSum);
localTopicCounts.adjustOrPutValue(newTopic, 1, 1);
tokensPerTopic[newTopic]++;
// update the coefficients for the non-zero topics
cachedCoefficients[newTopic] =
(alpha[newTopic] + localTopicCounts.get(newTopic)) /
(tokensPerTopic[newTopic] + betaSum);
smoothingOnlyMass += alpha[newTopic] * beta /
(tokensPerTopic[newTopic] + betaSum);
topicBetaMass += beta * localTopicCounts.get(newTopic) /
(tokensPerTopic[newTopic] + betaSum);
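// The bucket masses, cached coefficient, and all counts now include the
// newly sampled assignment again, so the next token sees a consistent state.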
assert(currentTypeTopicCounts.get(newTopic) > 0);
}
// Clean up our mess: reset the coefficients to values with only
// smoothing. The next doc will update its own non-zero topics...
for (int topic: localTopicCounts.keys()) {
cachedCoefficients[topic] =
alpha[topic] / (tokensPerTopic[topic] + betaSum);
}
if (shouldSaveState) {
// Update the document-topic count histogram,
// for dirichlet estimation
docLengthCounts[ docLength ]++;
for (int topic: localTopicCounts.keys()) {
topicDocCounts[topic][ localTopicCounts.get(topic) ]++;
}
}
}