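* @param cumProbThreshold cumulative-probability cutoff: entries that survive {@code lexProbThreshold}
*     are re-normalized and added to the output, taken in descending {@code PairOfStringFloat}
*     natural order, until their running sum is no longer below this value (the entry that
*     crosses the cutoff is still kept)
* @param maxNumTrans maximum number of entries kept in the output distribution of each source term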
*/
public static void normalize(Map<String, HMapSFW> probMap, float lexProbThreshold, float cumProbThreshold, int maxNumTrans) {
for (String sourceTerm : probMap.keySet()) {
HMapSFW probDist = probMap.get(sourceTerm);
TreeSet<PairOfStringFloat> sortedFilteredProbDist = new TreeSet<PairOfStringFloat>();
HMapSFW normProbDist = new HMapSFW();
// compute normalization factor
float sumProb = 0;
for (Entry<String> entry : probDist.entrySet()) {
sumProb += entry.getValue();
}
// normalize values and remove low-prob entries based on normalized values
float sumProb2 = 0;
for (Entry<String> entry : probDist.entrySet()) {
float pr = entry.getValue() / sumProb;
if (pr > lexProbThreshold) {
sumProb2 += pr;
sortedFilteredProbDist.add(new PairOfStringFloat(entry.getKey(), pr));
}
}
// re-normalize values after removal of low-prob entries
float cumProb = 0;
int cnt = 0;
while (cnt < maxNumTrans && cumProb < cumProbThreshold && !sortedFilteredProbDist.isEmpty()) {
PairOfStringFloat entry = sortedFilteredProbDist.pollLast();
float pr = entry.getValue() / sumProb2;
cumProb += pr;
normProbDist.put(entry.getKey(), pr);
cnt++;
}
probMap.put(sourceTerm, normProbDist);
}