package joshua.discriminative.variational_decoder.nbest;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.KBestExtractor;
import joshua.discriminative.FileUtilityOld;
import joshua.discriminative.semiring_parsing.AtomicSemiring;
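
/**
 * Crunches an n-best list of derivations extracted from a hypergraph: derivations
 * that share the same target string are merged, and their (scaled) log probabilities
 * are summed in the log semiring. The result is either an unordered n-best list of
 * unique strings with summed scores, or the single string whose summed probability
 * is highest.
 */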
public class NbestCrunching {
int topN = 300;
boolean useUniqueNbest = false;
boolean useTreeNbest = false;//still produce strings (not trees), though p(y) is accumulated over all derivations d with yield(d)=y
boolean addCombinedCost = true;
SymbolTable symbolTbl;
KBestExtractor kbestExtractor;
double scalingFactor = 1.0;
AtomicSemiring atomicSemiring = new AtomicSemiring(1, 0);
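//The (1, 0) arguments presumably select LOG_SEMIRING=1 with ADD_MODE=0 (sum),
//matching the constants noted in the commented-out inside-outside call below;
//under that reading, add_in_atomic_semiring(a, b) computes log(e^a + e^b).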
public NbestCrunching(SymbolTable symbolTbl, double insideOutsideScalingFactor, int topN){
this.scalingFactor = insideOutsideScalingFactor;
this.topN = topN;
this.symbolTbl = symbolTbl;
this.kbestExtractor = new KBestExtractor(this.symbolTbl, this.useUniqueNbest, this.useTreeNbest, false, this.addCombinedCost, false, true);
}
//If returnDisorderNbest is true, return an unordered n-best list with one entry per unique string,
//scored by the sum of its derivations' log probs: "sentID ||| hyp ||| empty_feature_scores ||| sumLogProb".
//Otherwise, return the single hypothesis string with the best summed probability.
public List<String> processOneSent(HyperGraph hg, int sentenceID, boolean returnDisorderNbest){
List<String> result = new ArrayList<String> ();
/*
//### step-1: run inside-outside, rank the hg, and get normalization constant
//note, inside and outside will use the transition_cost of each hyperedge, this cost is already linearly interpolated
TrivialInsideOutside p_inside_outside = new TrivialInsideOutside();
p_inside_outside.run_inside_outside(hg, 0, 1, inside_outside_scaling_factor);//ADD_MODE=0=sum; LOG_SEMIRING=1;
double norm_constant = p_inside_outside.get_normalization_constant();
p_inside_outside.clear_state();
*/
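//Note: with step-1 disabled, no normalization constant is available, so the
//summed log probs computed below remain unnormalized.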
//### step-2: get a nbest derivations
List<String> nonUniqueNbestStrings = new ArrayList<String>();
kbestExtractor.lazyKBestExtractOnHG(hg, null, this.topN, sentenceID, nonUniqueNbestStrings);
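//Each extracted line is expected to have the form
//"sentID ||| hypothesis ||| feature_scores ||| totalLogProb"; since useUniqueNbest
//is false, several derivations may share the same hypothesis string.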
//### step-3: get the sum for each of the unique strings
HashMap<String, Double> uniqueStringsSumProbTbl = new HashMap<String, Double>();
HashMap<String, Double> uniqueStringsViterbiProbTbl = new HashMap<String, Double>();//debug
HashMap<String, Integer> uniqueStringsNumDuplicatesTbl = new HashMap<String, Integer>();//debug
for(String derivationString : nonUniqueNbestStrings){
//System.out.println(derivationString);
String[] fds = derivationString.split("\\s+\\|{3}\\s+");
String hypString = fds[1];
double logProb = Double.parseDouble(fds[fds.length-1])*scalingFactor;//scaled (unnormalized) log prob; scalingFactor is the inside-outside scaling factor
Double oldSum = uniqueStringsSumProbTbl.get(hypString);
if(oldSum==null){
oldSum= Double.NEGATIVE_INFINITY;//zero prob
uniqueStringsNumDuplicatesTbl.put(hypString, 1);
uniqueStringsViterbiProbTbl.put(hypString, logProb);
}else{
uniqueStringsNumDuplicatesTbl.put(hypString, uniqueStringsNumDuplicatesTbl.get(hypString)+1);
}
uniqueStringsSumProbTbl.put(hypString, atomicSemiring.add_in_atomic_semiring(oldSum, logProb));//log-add: accumulates p(y) over duplicate derivations
}
//### step-4: emit the disorder nbest, or find the translation string with the best sum-probability
if(returnDisorderNbest){
for(String hyp : uniqueStringsSumProbTbl.keySet()){
StringBuilder fullHyp = new StringBuilder();
fullHyp.append(sentenceID); fullHyp.append(" ||| ");
fullHyp.append(hyp); fullHyp.append(" ||| ");
fullHyp.append("empty_feature_scores"); fullHyp.append(" ||| ");
fullHyp.append(uniqueStringsSumProbTbl.get(hyp));
result.add(fullHyp.toString());
//System.out.println(fullHyp.toString());
}
System.out.println("n_derivations=" + nonUniqueNbestStrings.size() + "; n_strings=" + uniqueStringsSumProbTbl.size());
}else{
double bestSumProb = Double.NEGATIVE_INFINITY;
String bestString = null;
for(String hyp : uniqueStringsSumProbTbl.keySet()){
double sumProb = uniqueStringsSumProbTbl.get(hyp);
if(sumProb > bestSumProb){
bestSumProb = sumProb;
bestString = hyp;
}
//System.out.println(sentenceID + " ||| " + hyp + " ||| " + uniqueStringsNumDuplicatesTbl.get(hyp) + " ||| " + uniqueStringsViterbiProbTbl.get(hyp) + " ||| " + sumProb);
}
System.out.println(sentenceID + " ||| " + bestString + " ||| " + bestSumProb);//unnormalized log prob
result.add(bestString);
}
return result;
}
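
//A minimal reference sketch (unused above) of the log-space addition that
//add_in_atomic_semiring is assumed to perform in the log semiring: for log
//probabilities a and b it returns log(e^a + e^b), computed stably. The name
//logAdd is ours, purely for illustration.
private static double logAdd(double logA, double logB) {
if (logA == Double.NEGATIVE_INFINITY) return logB;
if (logB == Double.NEGATIVE_INFINITY) return logA;
double max = Math.max(logA, logB);
//factor out the max so Math.exp cannot overflow
return max + Math.log(Math.exp(logA - max) + Math.exp(logB - max));
}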
public static void main(String[] args) throws InterruptedException, IOException {
if(args.length!=6){
System.out.println("Wrong number of parameters, it must be 5");
System.exit(1);
}
String testItemsFile = args[0].trim();
String testRulesFile = args[1].trim();
int numSents = Integer.parseInt(args[2].trim());
String onebestFile = args[3].trim();//output
int topN = Integer.parseInt(args[4].trim());
double insideOutsideScalingFactor = Double.parseDouble(args[5].trim());
int ngramStateID = 0;
SymbolTable symbolTbl = new BuildinSymbol(null);
NbestCrunching cruncher = new NbestCrunching(symbolTbl, insideOutsideScalingFactor, topN);
BufferedWriter onebestWriter = FileUtilityOld.getWriteFileStream(onebestFile);
System.out.println("############Process file " + testItemsFile);
DiskHyperGraph diskHG = new DiskHyperGraph(symbolTbl, ngramStateID, true, null); //model costs are stored in the hypergraph
diskHG.initRead(testItemsFile, testRulesFile,null);
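//Hypergraphs are stored sequentially in the items file, one per sentence,
//so they are read back in sentence order.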
for(int sentID=0; sentID < numSents; sentID++){
System.out.println("#Process sentence " + sentID);
HyperGraph testHG = diskHG.readHyperGraph();
List<String> oneBest = cruncher.processOneSent(testHG, sentID, false);//produce the reranked onebest
FileUtilityOld.writeLzf(onebestWriter, oneBest.get(0) + "\n");
}
FileUtilityOld.closeWriteFile(onebestWriter);
}
}