/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.discriminative.training.oracle;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.Support;
import joshua.decoder.ff.state_maintenance.NgramDPState;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.KBestExtractor;
import joshua.util.FileUtility;
/**
* approximated BLEU
* (1) do not consider clipping effect
* (2) in the dynamic programming, do not maintain different states for different hyp length
* (3) brief penalty is calculated based on the avg ref length
* (4) using sentence-level BLEU, instead of doc-level BLEU
*
* @author Zhifei Li, <zhifei.work@gmail.com> (Johns Hopkins University)
* @version $LastChangedDate: 2009-04-02 15:34:43 -0400 $
*/
public class OracleExtractionOnHGV2 extends RefineHG<DPStateOracle> {
static String BACKOFF_LEFT_LM_STATE_SYM="<lzfbo>";
public int BACKOFF_LEFT_LM_STATE_SYM_ID;//used for equivelant state
static String NULL_LEFT_LM_STATE_SYM="<lzflnull>";
public int NULL_LEFT_LM_STATE_SYM_ID;//used for equivelant state
static String NULL_RIGHT_LM_STATE_SYM="<lzfrnull>";
public int NULL_RIGHT_LM_STATE_SYM_ID;//used for equivelant state
// int[] ref_sentence;//reference string (not tree)
protected int srcSentLen =0;
protected int refSentLen =0;
protected int lmOrder=4; //only used for decide whether to get the LM state by this class or not in compute_state
static protected boolean doLocalNgramClip =false;
static protected boolean maitainLengthState = false;
static protected int bleuOrder=4;
static boolean useLeftEquivState = true;
static boolean useRightEquivState = true;
HashMap<String, Boolean> suffixTbl = new HashMap<String, Boolean>();
HashMap<String, Boolean> prefixTbl = new HashMap<String, Boolean>();
static PrefixGrammar prefixGrammar = new PrefixGrammar();//TODO
static PrefixGrammar suffixGrammar = new PrefixGrammar();//TODO
protected HashMap<String, Integer> refNgramsTbl = new HashMap<String, Integer>();
static boolean alwaysMaintainSeperateLMState = true; //if true: the virtual item maintain its own lm state regardless whether lm_order>=g_bleu_order
/**
*
*/
SymbolTable symbolTable;
int ngramStateID=0; //the baseline LM feature id
/**
* Constructs a new object capable of extracting
* a tree from a hypergraph that most closely matches
* a provided oracle sentence.
* @param symbolTable_
* @param lmFeatID_
*/
public OracleExtractionOnHGV2(SymbolTable symbolTable_, int lmFeatID_){
this.symbolTable = symbolTable_;
this.ngramStateID = lmFeatID_;
this.BACKOFF_LEFT_LM_STATE_SYM_ID = symbolTable.addTerminal(BACKOFF_LEFT_LM_STATE_SYM);
this.NULL_LEFT_LM_STATE_SYM_ID = symbolTable.addTerminal(NULL_RIGHT_LM_STATE_SYM);
this.NULL_RIGHT_LM_STATE_SYM_ID = symbolTable.addTerminal(NULL_RIGHT_LM_STATE_SYM);
}
//find the oracle hypothesis in the nbest list
public Object[] oracleExtractOnNbest(KBestExtractor kbest_extractor, HyperGraph hg, int n, boolean do_ngram_clip, String ref_sent){
if(hg.goalNode==null)
return null;
kbest_extractor.resetState();
int next_n=0;
double orc_bleu=-1;
String orc_sent=null;
while(true){
String hyp_sent = kbest_extractor.getKthHyp(hg.goalNode, ++next_n, -1, null, null);//?????????
//System.out.println(hyp_sent);
if(hyp_sent==null || next_n > n) break;
double t_bleu = computeSentenceBleu(this.symbolTable, ref_sent, hyp_sent, do_ngram_clip, 4);
if(t_bleu>orc_bleu){
orc_bleu = t_bleu;
orc_sent = hyp_sent;
}
}
System.out.println("Oracle sent in nbest: " + orc_sent);
System.out.println("Oracle bleu in nbest: " + orc_bleu);
Object[] res = new Object[2];
res[0]=orc_sent;
res[1]=orc_bleu;
return res;
}
public HyperGraph oracleExtractOnHG(HyperGraph hg, int srcSentLenIn, int lmOrder_, String refSentStr){
int[] refSent = this.symbolTable.addTerminals(refSentStr.split("\\s+"));
lmOrder= lmOrder_;
srcSentLen = srcSentLenIn;
refSentLen = refSent.length;
refNgramsTbl.clear();
getNgrams(refNgramsTbl,bleuOrder,refSent, false);
if(useLeftEquivState || useRightEquivState){
prefixTbl.clear(); suffixTbl.clear();
setupPrefixSuffixTbl(refSent, bleuOrder, prefixTbl, suffixTbl);
setupPrefixSuffixGrammar(refSent, bleuOrder, prefixGrammar, suffixGrammar);//TODO
}
return splitHG(hg);
}
private double computeAvgLen(int spanLen, int srcSentLen, int refSentLen){
return (spanLen>=srcSentLen) ? refSentLen : spanLen*refSentLen*1.0/srcSentLen;//avg len?
}
@Override
protected HyperEdge createNewHyperEdge(HyperEdge originalEdge, List<HGNode> antVirtualItems, DPStateOracle dps) {
return new HyperEdge(originalEdge.getRule(), dps.bestDerivationLogP, null, antVirtualItems, originalEdge.getSourcePath());
}
// =================== commmon funcions ==========================
//based on tbl_oracle_states, tbl_ref_ngrams, and dt, get the state
//get the new state: STATE_BEST_DEDUCT STATE_BEST_BLEU STATE_BEST_LEN NGRAM_MATCH_COUNTS
protected DPStateOracle computeState(HGNode parentNode, HyperEdge dt, List<HGNode> antVirtualItems){
double refLen = computeAvgLen(parentNode.j-parentNode.i, srcSentLen, refSentLen);
//=== hypereges under "goal item" does not have rule
if(dt.getRule()==null){
if(antVirtualItems.size()!=1){
System.out.println("error deduction under goal item have more than one item");
System.exit(0);
}
double bleu = antVirtualItems.get(0).bestHyperedge.bestDerivationLogP;
return new DPStateOracle(0, null, null,null, bleu);//no DPState at all
}
//======== hypereges *not* under "goal item"
HashMap<String, Integer> newNgramCounts = new HashMap<String, Integer>();//new ngrams created due to the combination
HashMap<String, Integer> oldNgramCounts = new HashMap<String, Integer>();//the ngram that has already been computed
int hypLen =0;
int[] numNgramMatches = new int[bleuOrder];
int[] enWords = dt.getRule().getEnglish();
//=== calulate new and old ngram counts, and len
ArrayList<Integer> words= new ArrayList<Integer>();
ArrayList<Integer> leftStateSequence = null; //used for compute left-lm state
ArrayList<Integer> rightStateSequence = null; //used for compute right-lm state
int correctLMOrder = lmOrder;
if(alwaysMaintainSeperateLMState==true || lmOrder<bleuOrder) {
leftStateSequence = new ArrayList<Integer>();
rightStateSequence = new ArrayList<Integer>();
correctLMOrder = bleuOrder;//if lm_order is smaller than g_bleu_order, we will get the lm state by ourself
}
//==== get leftStateSequence, rightStateSequence, hypLen, num_ngram_match
for(int c=0; c<enWords.length; c++){
int c_id = enWords[c];
if(symbolTable.isNonterminal(c_id)==true){
int index=this.symbolTable.getTargetNonterminalIndex(c_id);
DPStateOracle antDPState = (DPStateOracle)((RefinedNode)antVirtualItems.get(index)).dpState;
hypLen += antDPState.bestLen;
for(int t=0; t<bleuOrder; t++)
numNgramMatches[t] += antDPState.ngramMatches[t];
List<Integer> l_context = antDPState.leftLMState;
List<Integer> r_context = antDPState.rightLMState;
for(int t : l_context){//always have l_context
words.add(t);
if(leftStateSequence!=null && leftStateSequence.size()<bleuOrder-1)
leftStateSequence.add(t);
}
getNgrams(oldNgramCounts, bleuOrder, l_context, true);
if(r_context.size()>=correctLMOrder-1){//the right and left are NOT overlapping
getNgrams(newNgramCounts, bleuOrder, words, true);
getNgrams(oldNgramCounts, bleuOrder, r_context, true);
words.clear();//start a new chunk
if(rightStateSequence!=null)rightStateSequence.clear();
for(int t : r_context)
words.add(t);
}
if(rightStateSequence!=null)
for(int t : r_context)
rightStateSequence.add(t);
}else{
words.add(c_id);
hypLen += 1;
if(leftStateSequence!=null && leftStateSequence.size()<bleuOrder-1)
leftStateSequence.add(c_id);
if(rightStateSequence!=null)
rightStateSequence.add(c_id);
}
}
getNgrams(newNgramCounts, bleuOrder, words, true);
//=== now deduct ngram counts
Iterator iter = newNgramCounts.keySet().iterator();
while(iter.hasNext()){
String ngram = (String)iter.next();
if(refNgramsTbl.containsKey(ngram)){
int finalCount = newNgramCounts.get(ngram);
if(oldNgramCounts.containsKey(ngram)){
finalCount -= oldNgramCounts.get(ngram);
if(finalCount<0){
System.out.println("error: negative count for ngram: "+ this.symbolTable.getWord(11844) + "; new: " + newNgramCounts.get(ngram) +"; old: " +oldNgramCounts.get(ngram) );
System.exit(0);
}
}
if(finalCount>0){//TODO: not correct/global ngram clip
if(doLocalNgramClip)
numNgramMatches[ngram.split("\\s+").length-1] += Support.findMin(finalCount, refNgramsTbl.get(ngram)) ;
else
numNgramMatches[ngram.split("\\s+").length-1] += finalCount; //do not do any cliping
}
}
}
//=== now calculate the BLEU score and state
List<Integer> leftLMState = null;
List<Integer> rightLMState= null;
if(alwaysMaintainSeperateLMState==false && lmOrder>=bleuOrder){ //do not need to change lm state, just use orignal lm state
NgramDPState state = (NgramDPState) parentNode.getDPState(this.ngramStateID);
leftLMState = state.getLeftLMStateWords();
rightLMState = state.getRightLMStateWords();
}else{
leftLMState = getLeftEquivState(leftStateSequence, suffixTbl);
rightLMState = getRightEquivState(rightStateSequence, prefixTbl);
//debug
//System.out.println("lm_order is " + lm_order);
//compare_two_int_arrays(left_lm_state, (int[])parent_item.tbl_states.get(Symbol.LM_L_STATE_SYM_ID));
//compare_two_int_arrays(right_lm_state, (int[])parent_item.tbl_states.get(Symbol.LM_R_STATE_SYM_ID));
//end
}
double bleu = computeBleu(hypLen, refLen, numNgramMatches, bleuOrder);
return new DPStateOracle(hypLen, numNgramMatches, leftLMState, rightLMState, bleu);
}
private List<Integer> getLeftEquivState(List<Integer> leftStateSequence, HashMap<String, Boolean> suffixTbl){
int l_size = (leftStateSequence.size()<bleuOrder-1)? leftStateSequence.size() : (bleuOrder-1);
if(useLeftEquivState==false || l_size<bleuOrder-1){//regular
return leftStateSequence;
}else{
List<Integer> leftLMState = new ArrayList<Integer>(l_size);
for(int i=l_size-1; i>=0; i--){//right to left
if(isASuffixInTbl(leftStateSequence, 0, i, suffixTbl)){
//if(is_a_suffix_in_grammar(left_state_sequence, 0, i, grammar_suffix)){
for(int j=i; j>=0; j--)
leftLMState.set(j, leftStateSequence.get(j));
break;
}else{
leftLMState.set(i, this.NULL_LEFT_LM_STATE_SYM_ID);
}
}
return leftLMState;
}
}
private List<Integer> getRightEquivState(List<Integer> rightStateSequence, HashMap<String, Boolean> prefixTbl){
int r_size = (rightStateSequence.size()<bleuOrder-1)? rightStateSequence.size() : (bleuOrder-1);
if(useRightEquivState==false || r_size<bleuOrder-1){//regular
return rightStateSequence;
}else{
List<Integer> rightLMState = new ArrayList<Integer>(r_size);
for(int i=0; i<r_size; i++){//left to right
if(isAPrefixInTbl(rightStateSequence, rightStateSequence.size()-r_size+i, rightStateSequence.size()-1, prefixTbl)){
//if(is_a_prefix_in_grammar(right_state_sequence, right_state_sequence.size()-r_size+i, right_state_sequence.size()-1, grammar_prefix)){
for(int j=i; j<r_size; j++)
rightLMState.set(j, rightStateSequence.get(rightStateSequence.size()-r_size+j) );
break;
}else{
rightLMState.set(i, this.NULL_RIGHT_LM_STATE_SYM_ID );
}
}
//System.out.println("origi right:" + Symbol.get_string(right_state_sequence)+ "; equiv right:" + Symbol.get_string(right_lm_state));
return rightLMState;
}
}
//=================================================================================================
//==================== ngram extraction functions ==========================================
//=================================================================================================
public void getNgrams(HashMap<String, Integer> tbl, int order, int[] wrds, boolean ignoreNullEquivSymbol){
for(int i=0; i<wrds.length; i++)
for(int j=0; j<order && j+i<wrds.length; j++){//ngram: [i,i+j]
boolean contain_null=false;
StringBuffer ngram = new StringBuffer();
for(int k=i; k<=i+j; k++){
if(wrds[k]==this.NULL_LEFT_LM_STATE_SYM_ID || wrds[k]==this.NULL_RIGHT_LM_STATE_SYM_ID ){
contain_null=true;
if(ignoreNullEquivSymbol)
break;
}
ngram.append(wrds[k]);
if(k<i+j) ngram.append(" ");
}
if(ignoreNullEquivSymbol && contain_null)
continue;//skip this ngram
String ngram_str = ngram.toString();
if(tbl.containsKey(ngram_str))
tbl.put(ngram_str, tbl.get(ngram_str)+1);
else
tbl.put(ngram_str, 1);
}
}
// accumulate ngram counts into tbl
public void getNgrams(HashMap<String, Integer> tbl, int order, List<Integer> wrds, boolean ignoreNullEquivSymbol){
for(int i=0; i<wrds.size(); i++)
for(int j=0; j<order && j+i<wrds.size(); j++){//ngram: [i,i+j]
boolean contain_null=false;
StringBuffer ngram = new StringBuffer();
for(int k=i; k<=i+j; k++){
int t_wrd = wrds.get(k);
if(t_wrd==this.NULL_LEFT_LM_STATE_SYM_ID || t_wrd==this.NULL_RIGHT_LM_STATE_SYM_ID ){
contain_null=true;
if(ignoreNullEquivSymbol)
break;
}
ngram.append(t_wrd);
if(k<i+j) ngram.append(" ");
}
if(ignoreNullEquivSymbol && contain_null)
continue;//skip this ngram
String ngram_str = ngram.toString();
if(tbl.containsKey(ngram_str))
tbl.put(ngram_str, tbl.get(ngram_str)+1);
else
tbl.put(ngram_str, 1);
}
}
//=================================================================================================
//==================== BLEU-related functions ==========================================
//=================================================================================================
//TODO: consider merge with joshua.decoder.BLEU
//do_ngram_clip: consider global n-gram clip
public double computeSentenceBleu(SymbolTable p_symbol, String ref_sent, String hyp_sent, boolean do_ngram_clip, int bleu_order){
int[] numeric_ref_sent = p_symbol.addTerminals(ref_sent.split("\\s+"));
int[] numeric_hyp_sent = p_symbol.addTerminals(hyp_sent.split("\\s+"));
return computeSentenceBleu(numeric_ref_sent, numeric_hyp_sent, do_ngram_clip, bleu_order);
}
public double computeSentenceBleu( int[] ref_sent, int[] hyp_sent, boolean do_ngram_clip, int bleu_order){
double res_bleu = 0;
int order =4;
HashMap<String, Integer> ref_ngram_tbl = new HashMap<String, Integer> ();
getNgrams(ref_ngram_tbl, order, ref_sent,false);
HashMap<String, Integer> hyp_ngram_tbl = new HashMap<String, Integer> ();
getNgrams(hyp_ngram_tbl, order, hyp_sent,false);
int[] num_ngram_match = new int[order];
for(Iterator it = hyp_ngram_tbl.keySet().iterator(); it.hasNext();){
String ngram = (String) it.next();
if(ref_ngram_tbl.containsKey(ngram)){
if(do_ngram_clip)
num_ngram_match[ngram.split("\\s+").length-1] += Support.findMin(ref_ngram_tbl.get(ngram),hyp_ngram_tbl.get(ngram)); //ngram clip
else
num_ngram_match[ngram.split("\\s+").length-1] += hyp_ngram_tbl.get(ngram);//without ngram count clipping
}
}
res_bleu = computeBleu(hyp_sent.length, ref_sent.length, num_ngram_match, bleu_order);
//System.out.println("hyp_len: " + hyp_sent.length + "; ref_len:" + ref_sent.length + "; bleu: " + res_bleu +" num_ngram_matches: " + num_ngram_match[0] + " " +num_ngram_match[1]+
// " " + num_ngram_match[2] + " " +num_ngram_match[3]);
return res_bleu;
}
// sentence-bleu: BLEU= bp * prec; where prec = exp (sum 1/4 * log(prec[order]))
public static double computeBleu(int hypLen, double refLen, int[] numNgramMatches, int bleuOrder){
if(hypLen<=0 || refLen<=0){
System.out.println("error: ref or hyp is zero len");
System.exit(0);
}
double res=0;
double wt = 1.0/bleuOrder;
double prec = 0;
double smoothFactor=1.0;
for(int t=0; t<bleuOrder && t<hypLen; t++){
if(numNgramMatches[t]>0)
prec += wt*Math.log(numNgramMatches[t]*1.0/(hypLen-t));
else{
smoothFactor *= 0.5;//TODO
prec += wt*Math.log(smoothFactor/(hypLen-t));
}
}
double bp = (hypLen>=refLen) ? 1.0 : Math.exp(1-refLen/hypLen);
res = bp*Math.exp(prec);
//System.out.println("hyp_len: " + hyp_len + "; ref_len:" + ref_len + "prec: " + Math.exp(prec) + "; bp: " + bp + "; bleu: " + res);
return res;
}
//=================================================================================================
//==================== table-based suffix/prefix lookup==========================================
//=================================================================================================
public static void setupPrefixSuffixTbl(int[] wrds, int order, HashMap<String, Boolean> prefixTbl, HashMap<String, Boolean> suffixTbl){
for(int i=0; i<wrds.length; i++)
for(int j=0; j<order && j+i<wrds.length; j++){//ngram: [i,i+j]
StringBuffer ngram = new StringBuffer();
//=== prefix
for(int k=i; k<i+j; k++){//all ngrams [i,i+j-1]
ngram.append(wrds[k]);
prefixTbl.put(ngram.toString(),true);
ngram.append(" ");
}
//=== suffix: right-most wrd first
ngram = new StringBuffer();
for(int k=i+j; k>i; k--){//all ngrams [i+1,i+j]: reverse order
ngram.append(wrds[k]);
suffixTbl.put(ngram.toString(),true);//stored in reverse order
ngram.append(" ");
}
}
}
private boolean isAPrefixInTbl(List<Integer> rightStateSequence, int startPos, int endPos, HashMap<String, Boolean> prefixTbl){
if( rightStateSequence.get(startPos)==this.NULL_RIGHT_LM_STATE_SYM_ID)
return false;
StringBuffer prefix = new StringBuffer();
for(int i=startPos; i<=endPos; i++){
prefix.append(rightStateSequence.get(i));
if(i<endPos)
prefix.append(" ");
}
return prefixTbl.containsKey(prefix.toString());
}
private boolean isASuffixInTbl(List<Integer> leftStateSequence, int startPos, int endPos, HashMap<String, Boolean> suffixTbl){
if( leftStateSequence.get(endPos)==this.NULL_LEFT_LM_STATE_SYM_ID)
return false;
StringBuffer suffix = new StringBuffer();
for(int i=endPos; i>=startPos; i--){//right-most first
suffix.append(leftStateSequence.get(i));
if(i>startPos) suffix.append(" ");
}
return suffixTbl.containsKey(suffix.toString());
}
//=================================================================================================
//==================== grammar-based suffix/prefix lookup==========================================
//=================================================================================================
public static void setupPrefixSuffixGrammar(int[] wrds, int order, PrefixGrammar prefix_gr, PrefixGrammar suffix_gr){
for(int i=0; i<wrds.length; i++)
for(int j=0; j<order && j+i<wrds.length; j++){//ngram: [i,i+j]
//### prefix
prefix_gr.add_ngram(wrds, i, i+j-1);//ngram: [i,i+j-1]
//### suffix: right-most wrd first
int[] reverse_wrds = new int[j];
for(int k=i+j, t=0; k>i; k--){//all ngrams [i+1,i+j]: reverse order
reverse_wrds[t++] = wrds[k];
}
suffix_gr.add_ngram(reverse_wrds, 0, j-1);
}
}
private boolean isAPrefixInGrammar(ArrayList<Integer> rightStateSequence, int start_pos, int end_pos, PrefixGrammar gr_prefix){
if( rightStateSequence.get(start_pos)==this.NULL_RIGHT_LM_STATE_SYM_ID)
return false;
return gr_prefix.containNgram(rightStateSequence, start_pos, end_pos);
}
private boolean isASuffixInGrammar(ArrayList<Integer> leftStateSequence, int start_pos, int end_pos, PrefixGrammar grammar_suffix){
if( leftStateSequence.get(end_pos)== this.NULL_LEFT_LM_STATE_SYM_ID)
return false;
ArrayList<Integer> suffix = new ArrayList<Integer>();
for(int i=end_pos; i>=start_pos; i--){//right-most first
suffix.add(leftStateSequence.get(i));
}
return grammar_suffix.containNgram(suffix, 0, suffix.size()-1);
}
/*a backoff node is a hashtable, it may include:
* (1) probabilititis for next words
* (2) pointers to a next-layer backoff node (hashtable)
* (3) backoff weight for this node
* (4) suffix/prefix flag to indicate that there is ngrams start from this suffix
*/
private static class PrefixGrammar {
HashMap<Integer, HashMap> root = new HashMap<Integer, HashMap>();
//add prefix information
public void add_ngram(int[] wrds, int start_pos, int end_pos){
//######### identify the position, and insert the trinodes if necessary
HashMap<Integer, HashMap> pos = root;
for(int k=start_pos; k <=end_pos; k++){
int cur_sym_id=wrds[k];
HashMap<Integer, HashMap> next_layer = pos.get(cur_sym_id);
if(next_layer!=null){
pos=next_layer;
}else{
HashMap<Integer, HashMap> tem = new HashMap<Integer, HashMap>();//next layer node
pos.put(cur_sym_id, tem);
pos = tem;
}
}
}
public boolean containNgram(ArrayList<Integer> wrds, int start_pos, int end_pos){
if(end_pos<start_pos)return false;
HashMap pos = root;
for(int k=start_pos; k <=end_pos; k++){
int cur_sym_id= wrds.get(k);
HashMap next_layer = (HashMap) pos.get(cur_sym_id);
if(next_layer!=null){
pos=next_layer;
}else{
return false;
}
}
return true;
}
}
//=================================================================================================
//==================== example main function ==========================================
//=================================================================================================
/*for 919 sent, time_on_reading: 148797
time_on_orc_extract: 580286*/
public static void main(String[] args) throws IOException {
/*String f_hypergraphs="C:\\Users\\zli\\Documents\\mt03.src.txt.ss.nbest.hg.items";
String f_rule_tbl="C:\\Users\\zli\\Documents\\mt03.src.txt.ss.nbest.hg.rules";
String f_ref_files="C:\\Users\\zli\\Documents\\mt03.ref.txt.1";
String f_orc_out ="C:\\Users\\zli\\Documents\\mt03.orc.txt";*/
if(args.length!=6){
System.out.println("wrong command, correct command should be: java Decoder f_hypergraphs f_rule_tbl f_ref_files f_orc_out lm_order orc_extract_nbest");
System.out.println("num of args is "+ args.length);
for(int i=0; i <args.length; i++)System.out.println("arg is: " + args[i]);
System.exit(0);
}
String f_hypergraphs = args[0].trim();
String f_rule_tbl = args[1].trim();
String f_ref_files = args[2].trim();
String f_orc_out = args[3].trim();
int lm_order = Integer.parseInt(args[4].trim());
boolean orc_extract_nbest = new Boolean(args[5].trim()); //oracle extraction from nbest or hg
boolean saveModelScores = true;
//????????????????????????????????????????????????????
int baseline_lm_feat_id = 0;
//??????????????????????????????????????
SymbolTable p_symbolTable = new BuildinSymbol(null);
KBestExtractor kbestExtractor =null;
int topN=300;//TODO
boolean extract_unique_nbest = true;//TODO
boolean do_ngram_clip_nbest = true; //TODO
if(orc_extract_nbest==true){
System.out.println("oracle extraction from nbest list");
kbestExtractor = new KBestExtractor(p_symbolTable, extract_unique_nbest, false, false, false, false, true);
}
BufferedWriter orc_out = FileUtility.getWriteFileStream(f_orc_out);
boolean rerankKbestOracles = true;
BufferedWriter rerankOrcOut=null;
if(rerankKbestOracles==true){
rerankOrcOut = FileUtility.getWriteFileStream(f_orc_out+".rerank");
}
long start_time0 = System.currentTimeMillis();
long time_on_reading = 0;
long time_on_orc_extract = 0;
BufferedReader t_reader_ref = FileUtility.getReadFileStream(f_ref_files);
DiskHyperGraph dhg_read = new DiskHyperGraph(p_symbolTable, baseline_lm_feat_id, saveModelScores, null);
dhg_read.initRead(f_hypergraphs, f_rule_tbl, null);
KBestExtractor oracleKbestExtractor = new KBestExtractor(p_symbolTable, extract_unique_nbest, false, false, true, false, true);//extract kbest oracles
KBestExtractor rerankOracleKbestExtractor = new KBestExtractor(p_symbolTable, extract_unique_nbest, false, false, false, false, true);//extract kbest oracles
int topKOracles= 500;//TODO
//OracleExtractionOnHGV2 orc_extractor = new OracleExtractionOnHGV2(p_symbolTable, baseline_lm_feat_id);
OracleExtractionOnHGV3 orc_extractor = new OracleExtractionOnHGV3(p_symbolTable);
String ref_sent= null;
int sent_id=0;
long start_time = System.currentTimeMillis();
while( (ref_sent=FileUtility.read_line_lzf(t_reader_ref))!= null ){
System.out.println("############Process sentence " + sent_id);
start_time = System.currentTimeMillis();
sent_id++;
//if(sent_id>10)break;
HyperGraph hg = dhg_read.readHyperGraph();
if(hg==null)continue;
double orc_bleu=0;
//System.out.println("read disk hyp: " + (System.currentTimeMillis()-start_time));
time_on_reading += System.currentTimeMillis()-start_time;
start_time = System.currentTimeMillis();
if(orc_extract_nbest){
Object[] res = orc_extractor.oracleExtractOnNbest(kbestExtractor, hg, topN, do_ngram_clip_nbest, ref_sent);
String orc_sent = (String) res[0];
orc_bleu = (Double) res[1];
orc_out.write(orc_sent+"\n");
}else{
HyperGraph hg_oracle = orc_extractor.oracleExtractOnHG(hg, hg.sentLen, lm_order, ref_sent);
oracleKbestExtractor.lazyKBestExtractOnHG(hg_oracle, null, topKOracles, hg.sentID, orc_out);
orc_bleu = hg_oracle.goalNode.bestHyperedge.bestDerivationLogP;
time_on_orc_extract += System.currentTimeMillis()-start_time;
//System.out.println("num_virtual_items: " + orc_extractor.numRefinedNodes + " num_virtual_dts: " + orc_extractor.numRefinedEdges);
//System.out.println("oracle extract: " + (System.currentTimeMillis()-start_time));
//==== rerank the kbest-oracles to verify the approximation for DP is ok
if(rerankKbestOracles){
Object[] res = orc_extractor.oracleExtractOnNbest(rerankOracleKbestExtractor, hg_oracle, topKOracles, do_ngram_clip_nbest, ref_sent);
String orc_sent = (String) res[0];
//double rerankedOrcBleu = (Double) res[1];
rerankOrcOut.write(orc_sent+"\n");
}
}
System.out.println("orc bleu is " + orc_bleu);
}
t_reader_ref.close();
orc_out.close();
if(rerankOrcOut!=null)
rerankOrcOut.close();
System.out.println("time_on_reading: " + time_on_reading);
System.out.println("time_on_orc_extract: " + time_on_orc_extract);
System.out.println("total running time: "
+ (System.currentTimeMillis() - start_time0));
}
}