HashMap<WordSequence, Integer> bigrams = new HashMap<WordSequence, Integer>();
HashMap<WordSequence, Integer> trigrams = new HashMap<WordSequence, Integer>();
int wordCount = 0;
// Collect n-gram occurrences from `words`: every position contributes one
// unigram, positions >= 1 also the bigram ending there, and positions >= 2
// also the trigram ending there. addSequence (declared elsewhere) presumably
// increments the Integer count stored under the given key.
// Each occurrence must be added exactly once so that c(h, w) <= c(h); the
// estimates below divide n-gram counts by their history counts, and an
// over-counted n-gram can push a backoff denominator (1 - sum) to zero.
for (int i = 0; i < words.size(); ++i) {
    wordCount++;
    addSequence(unigrams, new WordSequence(words.get(i)));
    if (i >= 1) {
        addSequence(bigrams, new WordSequence(words.get(i - 1), words.get(i)));
    }
    if (i >= 2) {
        addSequence(trigrams, new WordSequence(words.get(i - 2),
                                               words.get(i - 1),
                                               words.get(i)));
    }
}

// Share of probability mass taken from every observed n-gram and reserved
// for backing off to the next lower order.
float discount = .5f;
float deflate = 1 - discount;

// Discounted unigram relative frequencies: deflate * c(w) / wordCount.
Map<WordSequence, Float> uniprobs = new HashMap<WordSequence, Float>();
for (Map.Entry<WordSequence, Integer> e : unigrams.entrySet()) {
    uniprobs.put(e.getKey(),
                 (float) e.getValue() * deflate / wordCount);
}

LogMath lmath = LogMath.getLogMath();
float logUnigramWeight = lmath.linearToLog(unigramWeight);
float invLogUnigramWeight = lmath.linearToLog(1 - unigramWeight);
// log(1 / V), where V is the number of distinct unigrams.
float logUniformProb = -lmath.linearToLog(uniprobs.size());

// Walk the sorted unigrams and the sorted bigrams in lockstep: for each
// unigram history h, the bigrams whose getOldest() equals h are its observed
// continuations.
// NOTE(review): this merge assumes WordSequence.compareTo orders bigrams so
// that those sharing a getOldest() history are contiguous and appear in the
// same order as the sorted unigrams -- confirm against WordSequence.
Set<WordSequence> sorted1grams = new TreeSet<WordSequence>(unigrams.keySet());
Iterator<WordSequence> iter =
    new TreeSet<WordSequence>(bigrams.keySet()).iterator();
WordSequence ws = iter.hasNext() ? iter.next() : null;
for (WordSequence unigram : sorted1grams) {
    // Interpolate the discounted unigram estimate with a uniform
    // distribution: unigramWeight * P(w) + (1 - unigramWeight) / V.
    // addAsLinear presumably adds its two log arguments as linear values.
    float p = lmath.linearToLog(uniprobs.get(unigram));
    p += logUnigramWeight;
    p = lmath.addAsLinear(p, logUniformProb + invLogUnigramWeight);
    logProbs.put(unigram, p);

    // Sum of the (discounted) unigram probabilities of the words observed
    // after this history; getNewest() of a bigram is used as that word.
    float sum = 0.f;
    while (ws != null) {
        int cmp = ws.getOldest().compareTo(unigram);
        if (cmp > 0) {
            break; // ws belongs to a later history; keep it for that one
        }
        if (cmp == 0) {
            sum += uniprobs.get(ws.getNewest());
        }
        ws = iter.hasNext() ? iter.next() : null;
    }
    // Backoff weight: reserved mass renormalized over unseen continuations.
    // NOTE(review): the reserved mass is taken to be exactly `discount`,
    // which only holds when every occurrence of the history is followed by
    // another word (not true for the last word of `words`); confirm this
    // approximation is intended.
    logBackoffs.put(unigram, lmath.linearToLog(discount / (1 - sum)));
}

// Discounted bigram estimates: deflate * c(h, w) / c(h), with h taken as the
// bigram's getOldest() history.
Map<WordSequence, Float> biprobs = new HashMap<WordSequence, Float>();
for (Map.Entry<WordSequence, Integer> entry : bigrams.entrySet()) {
    int unigramCount = unigrams.get(entry.getKey().getOldest());
    biprobs.put(entry.getKey(),
                entry.getValue() * deflate / unigramCount);
}

// Same lockstep walk one order up: sorted bigram histories against the
// sorted trigrams (same ordering assumption as above).
Set<WordSequence> sorted2grams = new TreeSet<WordSequence>(bigrams.keySet());
iter = new TreeSet<WordSequence>(trigrams.keySet()).iterator();
ws = iter.hasNext() ? iter.next() : null;
for (WordSequence biword : sorted2grams) {
    logProbs.put(biword, lmath.linearToLog(biprobs.get(biword)));
    float sum = 0.f;
    while (ws != null) {
        int cmp = ws.getOldest().compareTo(biword);
        if (cmp > 0) {
            break; // ws belongs to a later bigram history
        }
        if (cmp == 0) {
            // getNewest() of a trigram is used as its trailing bigram.
            sum += biprobs.get(ws.getNewest());
        }
        ws = iter.hasNext() ? iter.next() : null;
    }
    logBackoffs.put(biword, lmath.linearToLog(discount / (1 - sum)));
}