/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.decoder;
import joshua.util.io.LineReader;
import joshua.util.FileUtility;
import joshua.util.Ngram;
import joshua.util.Regex;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.TimeUnit;
/**
* This class implements:
* (1) n-best minimum Bayes risk (MBR) reranking using BLEU as the gain function.
* <p>
* It assumes that each string is unique in the n-best list. In Hiero,
* due to spurious ambiguity, a string may correspond to many
* possible derivations, and ideally the probability of a string
* should be the sum over all the derivations leading to that string.
* In practice, however, one normally uses a Viterbi approximation: the
* probability of a string is the probability of its best derivation. So,
* if one wants to deal with spurious ambiguity, one should do
* that before calling this class.
*
* @author Zhifei Li, <zhifei.work@gmail.com>
* @version $LastChangedDate: 2010-01-07 22:36:11 -0600 (Thu, 07 Jan 2010) $
*/
public class NbestMinRiskReranker {

    // TODO: producing a reranked n-best list is not implemented yet; currently
    // only the 1-best is produced, without any feature scores.
    boolean produceRerankedNbest = false;

    /** Multiplier applied to every hypothesis log-probability before normalization. */
    double scalingFactor = 1.0;

    /** Maximum n-gram order used for n-gram extraction and BLEU. */
    static int bleuOrder = 4;

    /** Clipping flag passed through to BLEU.computeSentenceBleu. */
    static boolean doNgramClip = true;

    /** When true, rank by the linear corpus gain instead of expected sentence-level BLEU. */
    static boolean useGoogleLinearCorpusGain = false;

    /**
     * Results produced by worker threads in multi-threaded mode; ordered by
     * sentence ID via RankerResult#compareTo.
     */
    final PriorityBlockingQueue<RankerResult> resultsQueue =
        new PriorityBlockingQueue<RankerResult>();

    /**
     * @param produceRerankedNbest whether a reranked n-best list should be produced
     *        (not implemented yet; only the 1-best is produced)
     * @param scalingFactor multiplier applied to the log-probabilities before normalization
     */
    public NbestMinRiskReranker(boolean produceRerankedNbest, double scalingFactor) {
        this.produceRerankedNbest = produceRerankedNbest;
        this.scalingFactor = scalingFactor;
    }

    /**
     * Reranks the n-best list of one sentence and returns the hypothesis with
     * the highest expected gain (the MBR 1-best).
     * <p>
     * Each line is expected to look like
     * "sent_id ||| hyp_itself ||| feature scores ||| combined score", where the
     * last field is the linear combination of the feature scores (treated as a
     * log-probability). When the hypothesis is empty the line has only three
     * fields and the hypothesis is taken to be "".
     *
     * @param nbest the n-best lines of a single sentence; must not be empty
     * @param sentID the sentence ID every line is expected to carry
     * @return the hypothesis string with the maximum expected gain
     * @throws IllegalArgumentException if {@code nbest} is empty
     * @throws RuntimeException if a line carries a different sentence ID
     */
    public String processOneSent(List<String> nbest, int sentID) {
        System.out.println("Now process sentence " + sentID);
        if (nbest.isEmpty()) {
            // Without this check an empty list only fails later, as a confusing
            // "probabilities not sum to one" error.
            throw new IllegalArgumentException("empty nbest list for sentence " + sentID);
        }

        // step-0: parse every line into hypothesis, length, n-gram table and score
        List<String> hypsItself = new ArrayList<String>();
        List<Double> baselineScores = new ArrayList<Double>(); // linear combination of all baseline features
        List<HashMap<String, Integer>> ngramTbls = new ArrayList<HashMap<String, Integer>>();
        List<Integer> sentLens = new ArrayList<Integer>();
        for (String hyp : nbest) {
            String[] fds = Regex.threeBarsWithSpace.split(hyp);
            int tSentID = Integer.parseInt(fds[0]);
            if (sentID != tSentID) {
                throw new RuntimeException("sentence_id does not match");
            }
            String hypothesis = (fds.length == 4) ? fds[1] : "";
            hypsItself.add(hypothesis);
            String[] words = Regex.spaces.split(hypothesis);
            sentLens.add(words.length);
            HashMap<String, Integer> ngramTbl = new HashMap<String, Integer>();
            Ngram.getNgrams(ngramTbl, 1, bleuOrder, words);
            ngramTbls.add(ngramTbl);
            // The score is the last field: index 3 normally, index 2 when the
            // hypothesis itself is empty.
            int finalIndex = fds.length - 1;
            baselineScores.add(Double.parseDouble(fds[finalIndex]));
        }

        // step-1: convert the scores, in place, into a normalized distribution
        computeNormalizedProbs(baselineScores, scalingFactor);
        List<Double> normalizedProbs = baselineScores;

        // only needed by the linear corpus gain
        HashMap<String, Double> posteriorCountsTbl = null;
        if (useGoogleLinearCorpusGain) {
            posteriorCountsTbl = new HashMap<String, Double>();
            getGooglePosteriorCounts(ngramTbls, normalizedProbs, posteriorCountsTbl);
        }

        // step-2: pick the hypothesis with the maximum expected gain.
        // TODO (zhifei): with sentence-level BLEU this is O(n^2) in the n-best
        // size; it could be made O(n) by first estimating a model on the n-best
        // and then reranking the n-best with that model.
        double bestGain = Double.NEGATIVE_INFINITY;
        String bestHyp = null;
        for (int i = 0; i < hypsItself.size(); i++) {
            int curHypLen = sentLens.get(i);
            HashMap<String, Integer> curHypNgramTbl = ngramTbls.get(i);
            double curGain;
            if (useGoogleLinearCorpusGain) {
                curGain = computeExpectedLinearCorpusGain(curHypLen, curHypNgramTbl, posteriorCountsTbl);
            } else {
                curGain = computeExpectedGain(curHypLen, curHypNgramTbl, ngramTbls, sentLens, normalizedProbs);
            }
            if (i == 0 || curGain > bestGain) { // maximize; the first hypothesis always seeds the best
                bestGain = curGain;
                bestHyp = hypsItself.get(i);
            }
        }

        // step-3: TODO when produceRerankedNbest is set, sort by gain and emit the
        // reranked n-best; for now only the 1-best is returned.
        System.out.println("best gain: " + bestGain);
        if (null == bestHyp) {
            throw new RuntimeException("mbr reranked one best is null, must be wrong");
        }
        return bestHyp;
    }

    /**
     * Converts a list of log-probabilities, in place, into a normalized
     * distribution: entry i becomes
     * exp(scalingFactor * logp_i) / sum_j exp(scalingFactor * logp_j).
     * The normalization constant is accumulated in the log semiring to avoid
     * overflow.
     *
     * @param nbestLogProbs log-probabilities; overwritten with probabilities in [0,1]
     * @param scalingFactor multiplier applied to every log-probability first
     * @throws RuntimeException if a probability is NaN or the probabilities do
     *         not sum to one (within 1e-4), e.g. for an empty list
     */
    public static void computeNormalizedProbs(List<Double> nbestLogProbs, double scalingFactor) {
        // log of the normalization constant
        double normalizationConstant = Double.NEGATIVE_INFINITY;
        for (double logp : nbestLogProbs) {
            normalizationConstant = addInLogSemiring(normalizationConstant, logp * scalingFactor, 0);
        }

        double tSum = 0;
        for (int i = 0; i < nbestLogProbs.size(); i++) {
            // keep the original value: it is needed for the error message and
            // would otherwise already be overwritten by the NaN
            double logProb = nbestLogProbs.get(i);
            double normalizedProb = Math.exp(logProb * scalingFactor - normalizationConstant);
            if (Double.isNaN(normalizedProb)) {
                throw new RuntimeException(
                    "prob is NaN, must be wrong\nnbest_logps.get(i): "
                    + logProb
                    + "; scaling_factor: " + scalingFactor
                    + "; normalization_constant:" + normalizationConstant );
            }
            tSum += normalizedProb;
            nbestLogProbs.set(i, normalizedProb);
        }

        // sanity check
        if (Math.abs(tSum - 1.0) > 1e-4) {
            throw new RuntimeException("probabilities not sum to one, must be wrong");
        }
    }

    /**
     * Expected gain of hypothesis e under the n-best distribution:
     * Gain(e) = negative risk = sum_{e'} G(e, e') P(e'), where G is sentence-level
     * BLEU computed from precomputed n-gram tables (e' is passed to
     * BLEU.computeSentenceBleu first, presumably as the reference).
     *
     * @param curHypLen length in words of e
     * @param curHypNgramTbl n-gram counts of e
     * @param ngramTbls n-gram counts of every e' in the n-best
     * @param sentLens length of every e' in the n-best
     * @param nbestProbs normalized probability P(e') of every e'
     * @return the expected gain of e
     */
    public double computeExpectedGain(int curHypLen, HashMap<String, Integer> curHypNgramTbl,
            List<HashMap<String, Integer>> ngramTbls, List<Integer> sentLens, List<Double> nbestProbs) {
        double gain = 0;
        for (int i = 0; i < nbestProbs.size(); i++) {
            HashMap<String, Integer> trueHypNgramTbl = ngramTbls.get(i);
            double trueProb = nbestProbs.get(i);
            int trueLen = sentLens.get(i);
            gain += trueProb * BLEU.computeSentenceBleu(trueLen, trueHypNgramTbl, curHypLen, curHypNgramTbl, doNgramClip, bleuOrder);
        }
        return gain;
    }

    /**
     * String-based variant of the expected gain:
     * Gain(e) = sum_{e'} BLEU(e', e) P(e').
     *
     * @param curHyp the hypothesis e
     * @param nbestHyps every e' in the n-best
     * @param nbestProbs normalized probability P(e') of every e'
     * @return the expected gain of e
     */
    public static double computeExpectedGain(String curHyp, List<String> nbestHyps, List<Double> nbestProbs) {
        double gain = 0;
        for (int i = 0; i < nbestHyps.size(); i++) {
            String trueHyp = nbestHyps.get(i);
            double trueProb = nbestProbs.get(i);
            gain += trueProb * BLEU.computeSentenceBleu(trueHyp, curHyp, doNgramClip, bleuOrder);
        }
        return gain;
    }

    /**
     * Fills {@code posteriorCountsTbl} with the posterior probability of every
     * n-gram w appearing in the n-best: p(w) = sum_e P(e) * [w occurs in e].
     * This is the quantity that computeExpectedLinearCorpusGain multiplies with
     * the hypothesis n-gram counts.
     *
     * @param ngramTbls n-gram tables of the hypotheses (only the keys are used)
     * @param normalizedProbs normalized probability of each hypothesis
     * @param posteriorCountsTbl output table; existing entries are added to
     */
    void getGooglePosteriorCounts(List<HashMap<String, Integer>> ngramTbls, List<Double> normalizedProbs,
            HashMap<String, Double> posteriorCountsTbl) {
        for (int i = 0; i < ngramTbls.size(); i++) {
            double prob = normalizedProbs.get(i);
            // indicator, not count: each distinct n-gram contributes P(e) once
            for (String ngram : ngramTbls.get(i).keySet()) {
                Double accumulated = posteriorCountsTbl.get(ngram);
                posteriorCountsTbl.put(ngram, (accumulated == null ? 0.0 : accumulated) + prob);
            }
        }
    }

    /**
     * Linear corpus gain of a hypothesis:
     * thetas[0] * |e| + sum_w count_w(e) * p(w) * thetas[order(w)].
     *
     * @param curHypLen length in words of the hypothesis
     * @param curHypNgramTbl n-gram counts of the hypothesis
     * @param posteriorCountsTbl posterior n-gram probabilities; a missing n-gram counts as 0
     * @return the linear gain
     */
    double computeExpectedLinearCorpusGain(int curHypLen, HashMap<String, Integer> curHypNgramTbl,
            HashMap<String, Double> posteriorCountsTbl) {
        // TODO: the weights are marked TODO in the original and look like placeholders.
        // Index 0 weighs the length, index n weighs n-grams of order n, so
        // bleuOrder must stay below thetas.length.
        double[] thetas = { -1, 1, 1, 1, 1 };
        double res = 0;
        res += thetas[0] * curHypLen;
        for (Entry<String, Integer> entry : curHypNgramTbl.entrySet()) {
            String ngram = entry.getKey();
            int order = Regex.spaces.split(ngram).length;
            Double postProb = posteriorCountsTbl.get(ngram);
            // an absent n-gram has posterior probability 0 and contributes nothing
            // (the original unboxed null here and threw a NullPointerException)
            if (postProb != null) {
                res += entry.getValue() * postProb * thetas[order];
            }
        }
        return res;
    }

    /**
     * Combines two values in one of three semirings.
     * Mode 0: log-sum, log(exp(x) + exp(y)), computed without overflow.
     * Mode 1: minimum. Mode 2: maximum.
     *
     * @throws RuntimeException for any other mode
     */
    private static double addInLogSemiring(double x, double y, int addMode) {
        if (addMode == 0) { // sum
            if (x == Double.NEGATIVE_INFINITY) { // if y is also -infinity, -infinity is returned
                return y;
            }
            if (y == Double.NEGATIVE_INFINITY) {
                return x;
            }
            // factor out the larger term so exp() never overflows
            if (y <= x) {
                return x + Math.log(1 + Math.exp(y - x));
            } else {
                return y + Math.log(1 + Math.exp(x - y));
            }
        } else if (addMode == 1) { // viterbi-min
            return (x <= y) ? x : y;
        } else if (addMode == 2) { // viterbi-max
            return (x >= y) ? x : y;
        } else {
            throw new RuntimeException("invalid add mode");
        }
    }

    /**
     * Command-line entry point:
     * f_nbest_in f_out produce_reranked_nbest scaling_factor [numThreads].
     * Consecutive lines sharing a sentence ID form one n-best list. One 1-best
     * hypothesis per sentence is written to f_out.
     */
    public static void main(String[] args) throws IOException {
        // If you don't know what to use for scaling factor, try using 1
        if (args.length < 4 || args.length > 5) {
            System.out.println("wrong command, correct command should be: java NbestMinRiskReranker f_nbest_in f_out produce_reranked_nbest scaling_factor [numThreads]");
            System.out.println("num of args is "+ args.length);
            for (String arg : args) {
                System.out.println("arg is: " + arg);
            }
            System.exit(-1);
        }

        long startTime = System.currentTimeMillis();
        String inputNbest = args[0].trim();
        String output = args[1].trim();
        boolean produceRerankedNbest = Boolean.valueOf(args[2].trim());
        double scalingFactor = Double.parseDouble(args[3].trim());
        int numThreads = (args.length == 5) ? Integer.parseInt(args[4].trim()) : 1;

        BufferedWriter outWriter = FileUtility.getWriteFileStream(output);
        try {
            NbestMinRiskReranker mbrReranker =
                new NbestMinRiskReranker(produceRerankedNbest, scalingFactor);
            System.out.println("##############running mbr reranking");
            LineReader nbestReader = new LineReader(inputNbest);
            if (numThreads == 1) {
                rerankSequentially(mbrReranker, nbestReader, outWriter);
            } else {
                rerankInParallel(mbrReranker, nbestReader, outWriter, numThreads);
            }
        } finally {
            // close the output even when reranking fails
            outWriter.close();
        }

        System.out.println("Total running time (seconds) is "
            + (System.currentTimeMillis() - startTime) / 1000.0);
    }

    /**
     * Reranks each sentence on the calling thread. Each 1-best is written (and
     * flushed) as soon as it is computed. Closes {@code nbestReader}.
     */
    private static void rerankSequentially(NbestMinRiskReranker mbrReranker, LineReader nbestReader,
            BufferedWriter outWriter) throws IOException {
        List<String> nbest = new ArrayList<String>();
        int oldSentID = -1;
        try {
            for (String line : nbestReader) {
                int newSentID = parseSentID(line);
                if (oldSentID != -1 && oldSentID != newSentID) {
                    // nbest: list of unique strings
                    writeLineAndFlush(outWriter, mbrReranker.processOneSent(nbest, oldSentID));
                    nbest.clear();
                }
                oldSentID = newSentID;
                nbest.add(line);
            }
        } finally {
            nbestReader.close();
        }
        // last nbest; empty only when the input file is empty
        if (!nbest.isEmpty()) {
            writeLineAndFlush(outWriter, mbrReranker.processOneSent(nbest, oldSentID));
        }
    }

    /**
     * Reranks sentences on a fixed-size thread pool. After all tasks finish,
     * the 1-bests are written in ascending sentence-ID order. Closes
     * {@code nbestReader}.
     *
     * @throws RuntimeException if some sentence produced no result (its task threw)
     */
    private static void rerankInParallel(NbestMinRiskReranker mbrReranker, LineReader nbestReader,
            BufferedWriter outWriter, int numThreads) throws IOException {
        ExecutorService threadPool = Executors.newFixedThreadPool(numThreads);
        int numTasks = 0;
        try {
            List<String> nbest = new ArrayList<String>();
            int oldSentID = -1;
            try {
                for (String line : nbestReader) {
                    int newSentID = parseSentID(line);
                    if (oldSentID != -1 && oldSentID != newSentID) {
                        // RankerTask copies the list, so it can be reused here
                        threadPool.execute(mbrReranker.new RankerTask(nbest, oldSentID));
                        numTasks++;
                        nbest.clear();
                    }
                    oldSentID = newSentID;
                    nbest.add(line);
                }
            } finally {
                nbestReader.close();
            }
            // last nbest; empty only when the input file is empty
            if (!nbest.isEmpty()) {
                threadPool.execute(mbrReranker.new RankerTask(nbest, oldSentID));
                numTasks++;
            }
        } finally {
            // always shut down: the pool's non-daemon threads would otherwise keep the JVM alive
            threadPool.shutdown();
        }

        try {
            threadPool.awaitTermination(Integer.MAX_VALUE, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
            return;
        }

        int numWritten = 0;
        RankerResult result;
        while ((result = mbrReranker.resultsQueue.poll()) != null) {
            outWriter.write(result.toString());
            outWriter.newLine();
            numWritten++;
        }
        outWriter.flush();
        if (numWritten != numTasks) {
            // A task that threw left no result; the output would otherwise be
            // silently misaligned with the input.
            throw new RuntimeException("only " + numWritten + " of " + numTasks
                + " sentences were reranked successfully");
        }
    }

    /** Returns the sentence ID stored in the first "|||"-separated field of an n-best line. */
    private static int parseSentID(String line) {
        return Integer.parseInt(Regex.threeBarsWithSpace.split(line)[0]);
    }

    /** Writes {@code text} followed by a line separator, then flushes. */
    private static void writeLineAndFlush(BufferedWriter out, String text) throws IOException {
        out.write(text);
        out.newLine();
        out.flush();
    }

    /** Reranks one sentence on a worker thread and enqueues the 1-best into resultsQueue. */
    private class RankerTask implements Runnable {
        final List<String> nbest;
        final int sentID;

        RankerTask(final List<String> nbest, final int sentID) {
            // defensive copy: the caller clears and reuses its list
            this.nbest = new ArrayList<String>(nbest);
            this.sentID = sentID;
        }

        @Override
        public void run() {
            String result = processOneSent(nbest, sentID);
            resultsQueue.add(new RankerResult(result, sentID));
        }
    }

    /** A reranked 1-best hypothesis tagged with its sentence number; ordered by sentence number. */
    private static class RankerResult implements Comparable<RankerResult> {
        final String result;
        final Integer sentenceNumber;

        RankerResult(String result, int sentenceNumber) {
            this.result = result;
            this.sentenceNumber = sentenceNumber;
        }

        @Override
        public int compareTo(RankerResult o) {
            return sentenceNumber.compareTo(o.sentenceNumber);
        }

        /** Returns the 1-best hypothesis string. */
        @Override
        public String toString() {
            return result;
        }
    }
}