Package joshua.discriminative.training.oracle

Source Code of joshua.discriminative.training.oracle.OracleExtractionOnHGV2$PrefixGrammar

/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.discriminative.training.oracle;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;





import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.Support;
import joshua.decoder.ff.state_maintenance.NgramDPState;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.KBestExtractor;
import joshua.util.FileUtility;

/**
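* Extracts from a hypergraph the derivation that scores best under an
* approximated sentence-level BLEU against a reference. The approximation:
* (1) does not consider the clipping effect;
* (2) in the dynamic programming, does not maintain different states for different hypothesis lengths;
* (3) calculates the brevity penalty from the average reference length;
* (4) uses sentence-level BLEU instead of document-level BLEU.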
*
* @author Zhifei Li, <zhifei.work@gmail.com> (Johns Hopkins University)
* @version $LastChangedDate: 2009-04-02 15:34:43 -0400 $
*/
public class OracleExtractionOnHGV2 extends RefineHG<DPStateOracle> {
 
  static String BACKOFF_LEFT_LM_STATE_SYM="<lzfbo>";
  public int BACKOFF_LEFT_LM_STATE_SYM_ID;//used for equivelant state
 
  static String NULL_LEFT_LM_STATE_SYM="<lzflnull>";
  public int NULL_LEFT_LM_STATE_SYM_ID;//used for equivelant state
 
  static String NULL_RIGHT_LM_STATE_SYM="<lzfrnull>";
  public int NULL_RIGHT_LM_STATE_SYM_ID;//used for equivelant state
 
 
//  int[] ref_sentence;//reference string (not tree)
  protected  int srcSentLen =0;
  protected  int refSentLen =0;
  protected  int lmOrder=4; //only used for decide whether to get the LM state by this class or not in compute_state
  static protected boolean doLocalNgramClip =false;
  static protected boolean maitainLengthState = false;
  static protected  int bleuOrder=4;
 
  static boolean useLeftEquivState = true;
  static boolean useRightEquivState = true;
 
  HashMap<String, Boolean> suffixTbl = new HashMap<String, Boolean>();
  HashMap<String, Boolean> prefixTbl = new HashMap<String, Boolean>();
 
  static PrefixGrammar prefixGrammar = new PrefixGrammar();//TODO
  static PrefixGrammar suffixGrammar = new PrefixGrammar();//TODO
 
 
 
  protected HashMap<String, Integer> refNgramsTbl = new HashMap<String, Integer>();
 

  static boolean alwaysMaintainSeperateLMState = true; //if true: the virtual item maintain its own lm state regardless whether lm_order>=g_bleu_order
 
  /**
   *
   */
  SymbolTable symbolTable;
 
  int ngramStateID=0; //the baseline LM feature id
 
  /**
   * Constructs a new object capable of extracting
   * a tree from a hypergraph that most closely matches
   * a provided oracle sentence.
   * @param symbolTable_
   * @param lmFeatID_
   */
  public OracleExtractionOnHGV2(SymbolTable symbolTable_, int lmFeatID_){
    this.symbolTable = symbolTable_;
    this.ngramStateID = lmFeatID_;
    this.BACKOFF_LEFT_LM_STATE_SYM_ID = symbolTable.addTerminal(BACKOFF_LEFT_LM_STATE_SYM);
    this.NULL_LEFT_LM_STATE_SYM_ID = symbolTable.addTerminal(NULL_RIGHT_LM_STATE_SYM);
    this.NULL_RIGHT_LM_STATE_SYM_ID = symbolTable.addTerminal(NULL_RIGHT_LM_STATE_SYM);
  }
 
 
 
  //find the oracle hypothesis in the nbest list
  public Object[] oracleExtractOnNbest(KBestExtractor kbest_extractor, HyperGraph hg, int n, boolean do_ngram_clip, String ref_sent){
    if(hg.goalNode==null)
      return null;
    kbest_extractor.resetState();       
    int next_n=0;
    double orc_bleu=-1;
    String orc_sent=null;
    while(true){
      String hyp_sent = kbest_extractor.getKthHyp(hg.goalNode, ++next_n, -1, null, null);//?????????
      //System.out.println(hyp_sent);
      if(hyp_sent==null || next_n > n) break;
      double t_bleu = computeSentenceBleu(this.symbolTable, ref_sent, hyp_sent, do_ngram_clip, 4);
      if(t_bleu>orc_bleu){
        orc_bleu = t_bleu;
        orc_sent = hyp_sent;
      }     
    }
    System.out.println("Oracle sent in nbest: " + orc_sent);
    System.out.println("Oracle bleu in nbest: " + orc_bleu);
    Object[] res = new Object[2];
    res[0]=orc_sent;
    res[1]=orc_bleu;
    return res;
  }

 
  public HyperGraph oracleExtractOnHG(HyperGraph hg, int srcSentLenIn, int lmOrder_,  String refSentStr){
   
    int[] refSent = this.symbolTable.addTerminals(refSentStr.split("\\s+"));
    lmOrder= lmOrder_;   
    srcSentLen = srcSentLenIn;
    refSentLen = refSent.length;   
   
    refNgramsTbl.clear();
    getNgrams(refNgramsTbl,bleuOrder,refSent, false)
    if(useLeftEquivState || useRightEquivState){
      prefixTbl.clear();  suffixTbl.clear();
      setupPrefixSuffixTbl(refSent,  bleuOrder, prefixTbl, suffixTbl);
      setupPrefixSuffixGrammar(refSent,  bleuOrder, prefixGrammar, suffixGrammar);//TODO
    }
   
    return splitHG(hg);
  }
   
  private double computeAvgLen(int spanLen, int srcSentLen, int refSentLen){
    return (spanLen>=srcSentLen) ? refSentLen :  spanLen*refSentLen*1.0/srcSentLen;//avg len?
  }

 

  @Override
  protected HyperEdge createNewHyperEdge(HyperEdge originalEdge, List<HGNode> antVirtualItems, DPStateOracle dps) {
    return new HyperEdge(originalEdge.getRule(), dps.bestDerivationLogP, null, antVirtualItems, originalEdge.getSourcePath());
  }


   
//  =================== commmon funcions ==========================
  //based on tbl_oracle_states, tbl_ref_ngrams, and dt, get the state
  //get the new state: STATE_BEST_DEDUCT STATE_BEST_BLEU STATE_BEST_LEN NGRAM_MATCH_COUNTS
  protected  DPStateOracle computeState(HGNode parentNode, HyperEdge dt, List<HGNode> antVirtualItems){ 
    double refLen = computeAvgLen(parentNode.j-parentNode.i, srcSentLen, refSentLen);
   
    //=== hypereges under "goal item" does not have rule
    if(dt.getRule()==null){
      if(antVirtualItems.size()!=1){
        System.out.println("error deduction under goal item have more than one item");
        System.exit(0);
      }
      double bleu = antVirtualItems.get(0).bestHyperedge.bestDerivationLogP;
      return  new DPStateOracle(0, null, null,null, bleu);//no DPState at all
    }
   
    //======== hypereges *not* under "goal item"   
    HashMap<String, Integer> newNgramCounts = new HashMap<String, Integer>();//new ngrams created due to the combination
    HashMap<String, Integer> oldNgramCounts = new HashMap<String, Integer>();//the ngram that has already been computed
    int hypLen =0;
    int[] numNgramMatches = new int[bleuOrder];
    int[] enWords = dt.getRule().getEnglish();
   
    //=== calulate new and old ngram counts, and len
      ArrayList<Integer> words= new ArrayList<Integer>();
      ArrayList<Integer> leftStateSequence = null; //used for compute left-lm state
      ArrayList<Integer> rightStateSequence = null; //used for compute right-lm state
      int correctLMOrder = lmOrder;
      if(alwaysMaintainSeperateLMState==true || lmOrder<bleuOrder) {
        leftStateSequence = new ArrayList<Integer>();
        rightStateSequence = new ArrayList<Integer>();
        correctLMOrder = bleuOrder;//if lm_order is smaller than g_bleu_order, we will get the lm state by ourself
      }
     
      //==== get leftStateSequence, rightStateSequence, hypLen, num_ngram_match
      for(int c=0; c<enWords.length; c++){
       
        int c_id = enWords[c];
        if(symbolTable.isNonterminal(c_id)==true){
         
          int index=this.symbolTable.getTargetNonterminalIndex(c_id);
          DPStateOracle antDPState = (DPStateOracle)((RefinedNode)antVirtualItems.get(index)).dpState;         
          hypLen += antDPState.bestLen;
          for(int t=0; t<bleuOrder; t++)
            numNgramMatches[t] += antDPState.ngramMatches[t];
             
          List<Integer> l_context = antDPState.leftLMState;
          List<Integer> r_context = antDPState.rightLMState;
         
          for(int t : l_context){//always have l_context
            words.add(t);
            if(leftStateSequence!=null && leftStateSequence.size()<bleuOrder-1)
              leftStateSequence.add(t);
          }
          getNgrams(oldNgramCounts, bleuOrder, l_context, true);         
          if(r_context.size()>=correctLMOrder-1){//the right and left are NOT overlapping       
            getNgrams(newNgramCounts, bleuOrder, words, true);
            getNgrams(oldNgramCounts, bleuOrder, r_context, true);
            words.clear();//start a new chunk   
            if(rightStateSequence!=null)rightStateSequence.clear();
            for(int t : r_context)
              words.add(t);           
          }
          if(rightStateSequence!=null)
            for(int t : r_context)
              rightStateSequence.add(t);
        }else{
          words.add(c_id);
          hypLen += 1;
         
          if(leftStateSequence!=null && leftStateSequence.size()<bleuOrder-1)
            leftStateSequence.add(c_id);
         
          if(rightStateSequence!=null)
            rightStateSequence.add(c_id);
        }
      }
      getNgrams(newNgramCounts, bleuOrder, words, true);
   
      //=== now deduct ngram counts
      Iterator iter = newNgramCounts.keySet().iterator();
      while(iter.hasNext()){
        String ngram = (String)iter.next();
        if(refNgramsTbl.containsKey(ngram)){
          int finalCount = newNgramCounts.get(ngram);
          if(oldNgramCounts.containsKey(ngram)){
            finalCount -= oldNgramCounts.get(ngram);
            if(finalCount<0){
              System.out.println("error: negative count for ngram: "+ this.symbolTable.getWord(11844) + "; new: " + newNgramCounts.get(ngram) +"; old: " +oldNgramCounts.get(ngram) );
              System.exit(0);
            }
          }
          if(finalCount>0){//TODO: not correct/global ngram clip
            if(doLocalNgramClip)
              numNgramMatches[ngram.split("\\s+").length-1] += Support.findMin(finalCount, refNgramsTbl.get(ngram)) ;
            else
              numNgramMatches[ngram.split("\\s+").length-1] += finalCount; //do not do any cliping         
          }
        }
      }
     
      //=== now calculate the BLEU score and state
      List<Integer> leftLMState = null;
      List<Integer>  rightLMState= null;
    if(alwaysMaintainSeperateLMState==false && lmOrder>=bleuOrder){  //do not need to change lm state, just use orignal lm state
      NgramDPState state     = (NgramDPState) parentNode.getDPState(this.ngramStateID);
      leftLMState = state.getLeftLMStateWords();
      rightLMState = state.getRightLMStateWords();
    }else{
      leftLMState = getLeftEquivState(leftStateSequence, suffixTbl);
      rightLMState = getRightEquivState(rightStateSequence, prefixTbl);
     
      //debug
      //System.out.println("lm_order is " + lm_order);
      //compare_two_int_arrays(left_lm_state, (int[])parent_item.tbl_states.get(Symbol.LM_L_STATE_SYM_ID));
      //compare_two_int_arrays(right_lm_state, (int[])parent_item.tbl_states.get(Symbol.LM_R_STATE_SYM_ID));
      //end           
    }
   
    double bleu = computeBleu(hypLen, refLen, numNgramMatches, bleuOrder);
   
    return  new DPStateOracle(hypLen, numNgramMatches, leftLMState, rightLMState, bleu);
  }
 
 
  private List<Integer> getLeftEquivState(List<Integer> leftStateSequence, HashMap<String, Boolean> suffixTbl){
   
    int l_size = (leftStateSequence.size()<bleuOrder-1)? leftStateSequence.size() : (bleuOrder-1);
   
    if(useLeftEquivState==false || l_size<bleuOrder-1){//regular
      return leftStateSequence;
    }else{
      List<Integer> leftLMState = new ArrayList<Integer>(l_size);
      for(int i=l_size-1; i>=0; i--){//right to left
        if(isASuffixInTbl(leftStateSequence, 0, i, suffixTbl)){
        //if(is_a_suffix_in_grammar(left_state_sequence, 0, i, grammar_suffix)){
          for(int j=i; j>=0; j--)
            leftLMState.set(j, leftStateSequence.get(j));
          break;
        }else{
          leftLMState.set(i, this.NULL_LEFT_LM_STATE_SYM_ID);
        }
      }
      return leftLMState;
    }
   
  }
 
 
  private  List<Integer> getRightEquivState(List<Integer> rightStateSequence, HashMap<String, Boolean> prefixTbl){
   
    int r_size = (rightStateSequence.size()<bleuOrder-1)? rightStateSequence.size() : (bleuOrder-1);
   
    if(useRightEquivState==false || r_size<bleuOrder-1){//regular
      return rightStateSequence;
    }else{
       List<Integer> rightLMState = new ArrayList<Integer>(r_size);
      for(int i=0; i<r_size; i++){//left to right
        if(isAPrefixInTbl(rightStateSequence, rightStateSequence.size()-r_size+i, rightStateSequence.size()-1, prefixTbl)){
        //if(is_a_prefix_in_grammar(right_state_sequence, right_state_sequence.size()-r_size+i, right_state_sequence.size()-1, grammar_prefix)){
          for(int j=i; j<r_size; j++)
            rightLMState.set(j, rightStateSequence.get(rightStateSequence.size()-r_size+j) );
          break;
        }else{
          rightLMState.set(i, this.NULL_RIGHT_LM_STATE_SYM_ID );
        }
      }
      //System.out.println("origi right:" + Symbol.get_string(right_state_sequence)+ "; equiv right:" + Symbol.get_string(right_lm_state));
      return rightLMState; 
    }
   
  }
 
  //=================================================================================================
  //==================== ngram extraction functions ==========================================
  //=================================================================================================
  public void getNgrams(HashMap<String, Integer> tbl, int order, int[] wrds, boolean ignoreNullEquivSymbol){
    for(int i=0; i<wrds.length; i++)
      for(int j=0; j<order && j+i<wrds.length; j++){//ngram: [i,i+j]
        boolean contain_null=false;
        StringBuffer ngram = new StringBuffer();
        for(int k=i; k<=i+j; k++){
          if(wrds[k]==this.NULL_LEFT_LM_STATE_SYM_ID || wrds[k]==this.NULL_RIGHT_LM_STATE_SYM_ID ){
            contain_null=true;
            if(ignoreNullEquivSymbol)
              break;
          }
          ngram.append(wrds[k]);
          if(k<i+j) ngram.append(" ");
        }
        if(ignoreNullEquivSymbol && contain_null)
          continue;//skip this ngram
        String ngram_str = ngram.toString();
        if(tbl.containsKey(ngram_str))
          tbl.put(ngram_str,  tbl.get(ngram_str)+1);
        else
          tbl.put(ngram_str, 1);
      }
  }
 
//  accumulate ngram counts into tbl
  public void getNgrams(HashMap<String, Integer>  tbl, int order, List<Integer> wrds, boolean ignoreNullEquivSymbol){
    for(int i=0; i<wrds.size(); i++)
      for(int j=0; j<order && j+i<wrds.size(); j++){//ngram: [i,i+j]
        boolean contain_null=false;
        StringBuffer ngram = new StringBuffer();
        for(int k=i; k<=i+j; k++){
          int t_wrd = wrds.get(k);
          if(t_wrd==this.NULL_LEFT_LM_STATE_SYM_ID || t_wrd==this.NULL_RIGHT_LM_STATE_SYM_ID ){
            contain_null=true;
            if(ignoreNullEquivSymbol)
              break;
          }
          ngram.append(t_wrd);
          if(k<i+j) ngram.append(" ");
        }
        if(ignoreNullEquivSymbol && contain_null)
          continue;//skip this ngram
        String ngram_str = ngram.toString();
        if(tbl.containsKey(ngram_str))
          tbl.put(ngram_str, tbl.get(ngram_str)+1);
        else
          tbl.put(ngram_str, 1);
      }
  }
 
 
  //=================================================================================================
  //==================== BLEU-related functions ==========================================
  //=================================================================================================
  //TODO: consider merge with joshua.decoder.BLEU
 
  //do_ngram_clip: consider global n-gram clip
  public  double computeSentenceBleu(SymbolTable p_symbol, String ref_sent, String hyp_sent, boolean do_ngram_clip, int bleu_order){
    int[] numeric_ref_sent = p_symbol.addTerminals(ref_sent.split("\\s+"));
    int[] numeric_hyp_sent = p_symbol.addTerminals(hyp_sent.split("\\s+"));
    return computeSentenceBleu(numeric_ref_sent, numeric_hyp_sent, do_ngram_clip, bleu_order);   
  }
 
 
  public  double computeSentenceBleu( int[] ref_sent, int[] hyp_sent, boolean do_ngram_clip, int bleu_order){
    double res_bleu = 0;
    int order =4;
    HashMap<String, Integer>  ref_ngram_tbl = new HashMap<String, Integer> ();
    getNgrams(ref_ngram_tbl, order, ref_sent,false);
    HashMap<String, Integer>  hyp_ngram_tbl = new HashMap<String, Integer> ();
    getNgrams(hyp_ngram_tbl, order, hyp_sent,false);
   
    int[] num_ngram_match = new int[order];
    for(Iterator it = hyp_ngram_tbl.keySet().iterator(); it.hasNext();){
      String ngram = (String) it.next();
      if(ref_ngram_tbl.containsKey(ngram)){
        if(do_ngram_clip)
          num_ngram_match[ngram.split("\\s+").length-1] += Support.findMin(ref_ngram_tbl.get(ngram),hyp_ngram_tbl.get(ngram)); //ngram clip
        else
          num_ngram_match[ngram.split("\\s+").length-1] += hyp_ngram_tbl.get(ngram);//without ngram count clipping         
        }
    }
    res_bleu = computeBleu(hyp_sent.length, ref_sent.length, num_ngram_match, bleu_order);
    //System.out.println("hyp_len: " + hyp_sent.length + "; ref_len:" + ref_sent.length + "; bleu: " + res_bleu +" num_ngram_matches: " + num_ngram_match[0] + " " +num_ngram_match[1]+
    //    " " + num_ngram_match[2] + " " +num_ngram_match[3]);

    return res_bleu;
  }
   
 
//  sentence-bleu: BLEU= bp * prec; where prec = exp (sum 1/4 * log(prec[order]))
  public static double computeBleu(int hypLen, double refLen, int[] numNgramMatches, int bleuOrder){
    if(hypLen<=0 || refLen<=0){
      System.out.println("error: ref or hyp is zero len");
      System.exit(0);
    }
   
    double res=0;   
    double wt = 1.0/bleuOrder;
    double prec = 0;
    double smoothFactor=1.0;
    for(int t=0; t<bleuOrder && t<hypLen; t++){
      if(numNgramMatches[t]>0)
        prec += wt*Math.log(numNgramMatches[t]*1.0/(hypLen-t));
      else{
        smoothFactor *= 0.5;//TODO
        prec += wt*Math.log(smoothFactor/(hypLen-t));
      }
    }
   
    double bp = (hypLen>=refLen) ? 1.0 : Math.exp(1-refLen/hypLen)
   
    res = bp*Math.exp(prec);
    //System.out.println("hyp_len: " + hyp_len + "; ref_len:" + ref_len + "prec: " + Math.exp(prec) + "; bp: " + bp + "; bleu: " + res);
    return res;
  }
 
 
 
 
 
  //=================================================================================================
  //==================== table-based suffix/prefix lookup==========================================
  //=================================================================================================
 
 
  public static void setupPrefixSuffixTbl(int[] wrds, int order, HashMap<String, Boolean>  prefixTbl, HashMap<String, Boolean> suffixTbl){
    for(int i=0; i<wrds.length; i++)
      for(int j=0; j<order && j+i<wrds.length; j++){//ngram: [i,i+j]
        StringBuffer ngram = new StringBuffer()

        //=== prefix
        for(int k=i; k<i+j; k++){//all ngrams [i,i+j-1]
          ngram.append(wrds[k]);
          prefixTbl.put(ngram.toString(),true);
          ngram.append(" ");
        }       
       
        //=== suffix: right-most wrd first
        ngram = new StringBuffer();
        for(int k=i+j; k>i; k--){//all ngrams [i+1,i+j]: reverse order
          ngram.append(wrds[k]);
          suffixTbl.put(ngram.toString(),true);//stored in reverse order
          ngram.append(" ");
        }       
      }
  }
 

  private boolean isAPrefixInTbl(List<Integer> rightStateSequence, int startPos, int endPos, HashMap<String, Boolean> prefixTbl){
   
    if( rightStateSequence.get(startPos)==this.NULL_RIGHT_LM_STATE_SYM_ID)
      return false;
   
    StringBuffer prefix = new StringBuffer();
    for(int i=startPos; i<=endPos; i++){
      prefix.append(rightStateSequence.get(i));
      if(i<endPos)
        prefix.append(" ");
    }   
   
    return  prefixTbl.containsKey(prefix.toString());
  }
 
 
  private boolean isASuffixInTbl(List<Integer> leftStateSequence, int startPos, int endPos, HashMap<String, Boolean> suffixTbl){
   
    if( leftStateSequence.get(endPos)==this.NULL_LEFT_LM_STATE_SYM_ID)
      return false;
   
    StringBuffer suffix = new StringBuffer();
    for(int i=endPos; i>=startPos; i--){//right-most first
      suffix.append(leftStateSequence.get(i));
      if(i>startPos) suffix.append(" ");
    }   
   
    return suffixTbl.containsKey(suffix.toString());
  }
 
 
 
 
  //=================================================================================================
  //==================== grammar-based suffix/prefix lookup==========================================
  //=================================================================================================
 
  public static void setupPrefixSuffixGrammar(int[] wrds,  int order, PrefixGrammar prefix_gr, PrefixGrammar suffix_gr){
    for(int i=0; i<wrds.length; i++)
      for(int j=0; j<order && j+i<wrds.length; j++){//ngram: [i,i+j]
        //### prefix
        prefix_gr.add_ngram(wrds, i, i+j-1);//ngram: [i,i+j-1]
              
        //### suffix: right-most wrd first
        int[] reverse_wrds = new int[j];
        for(int k=i+j, t=0; k>i; k--){//all ngrams [i+1,i+j]: reverse order
          reverse_wrds[t++] = wrds[k];
        }
        suffix_gr.add_ngram(reverse_wrds, 0, j-1);
      }
  }
 
 
 
  private boolean isAPrefixInGrammar(ArrayList<Integer> rightStateSequence, int start_pos, int end_pos, PrefixGrammar gr_prefix){
    if( rightStateSequence.get(start_pos)==this.NULL_RIGHT_LM_STATE_SYM_ID)
      return false;
    return gr_prefix.containNgram(rightStateSequence,  start_pos,  end_pos);
  }
 
  private boolean isASuffixInGrammar(ArrayList<Integer> leftStateSequence, int start_pos, int end_pos, PrefixGrammar grammar_suffix){
    if( leftStateSequence.get(end_pos)== this.NULL_LEFT_LM_STATE_SYM_ID)
      return false;
    ArrayList<Integer> suffix = new ArrayList<Integer>();
    for(int i=end_pos; i>=start_pos; i--){//right-most first
      suffix.add(leftStateSequence.get(i));
    }   
    return grammar_suffix.containNgram(suffix,  0,  suffix.size()-1);
  }
 
 
  /*a backoff node is a hashtable, it may include:
   * (1) probabilititis for next words
   * (2) pointers to a next-layer backoff node (hashtable)
   * (3) backoff weight for this node
   * (4) suffix/prefix flag to indicate that there is ngrams start from this suffix
     */
  private static class PrefixGrammar {
    HashMap<Integer, HashMap> root = new HashMap<Integer, HashMap>();
   
    //add prefix information
    public void add_ngram(int[] wrds, int start_pos, int end_pos){     
      //######### identify the position, and insert the trinodes if necessary
      HashMap<Integer, HashMap> pos = root;
      for(int k=start_pos; k <=end_pos; k++){
        int cur_sym_id=wrds[k];
        HashMap<Integer, HashMap> next_layer = pos.get(cur_sym_id);
        if(next_layer!=null){
          pos=next_layer;
        }else{   
          HashMap<Integer, HashMap> tem = new HashMap<Integer, HashMap>();//next layer node
          pos.put(cur_sym_id, tem);
          pos = tem;
        }
      }
    }
   
    public boolean containNgram(ArrayList<Integer> wrds, int start_pos, int end_pos){
      if(end_pos<start_pos)return false;
      HashMap pos = root;
      for(int k=start_pos; k <=end_pos; k++){
        int cur_sym_id=  wrds.get(k);
        HashMap next_layer = (HashMap) pos.get(cur_sym_id);
        if(next_layer!=null){
          pos=next_layer;
        }else{
          return false;
        }
      }
      return true;
    }     
  }
 
 
 


  //=================================================================================================
  //====================  example main function ==========================================
  //=================================================================================================
 
 
  /*for 919 sent, time_on_reading: 148797
  time_on_orc_extract: 580286*/
  public static void main(String[] args) throws IOException {
 
    /*String f_hypergraphs="C:\\Users\\zli\\Documents\\mt03.src.txt.ss.nbest.hg.items";
    String f_rule_tbl="C:\\Users\\zli\\Documents\\mt03.src.txt.ss.nbest.hg.rules";
    String f_ref_files="C:\\Users\\zli\\Documents\\mt03.ref.txt.1";
    String f_orc_out ="C:\\Users\\zli\\Documents\\mt03.orc.txt";*/
    if(args.length!=6){
      System.out.println("wrong command, correct command should be: java Decoder f_hypergraphs f_rule_tbl f_ref_files f_orc_out lm_order orc_extract_nbest");
      System.out.println("num of args is "+ args.length);
      for(int i=0; i <args.length; i++)System.out.println("arg is: " + args[i]);
      System.exit(0);   
    }   
    String f_hypergraphs = args[0].trim();
    String f_rule_tbl = args[1].trim();
    String f_ref_files = args[2].trim();
    String f_orc_out =  args[3].trim();
    int lm_order = Integer.parseInt(args[4].trim());
    boolean orc_extract_nbest = new Boolean(args[5].trim()); //oracle extraction from nbest or hg
   
    boolean saveModelScores = true;
   
    //????????????????????????????????????????????????????
    int baseline_lm_feat_id = 0;
    //??????????????????????????????????????
   
    SymbolTable p_symbolTable = new BuildinSymbol(null);
   
    KBestExtractor kbestExtractor =null;
    int topN=300;//TODO
    boolean extract_unique_nbest = true;//TODO
    boolean do_ngram_clip_nbest = true; //TODO
    if(orc_extract_nbest==true){
      System.out.println("oracle extraction from nbest list");
      kbestExtractor = new KBestExtractor(p_symbolTable, extract_unique_nbest, false, false, false,  false, true);
    }
   
    BufferedWriter orc_out = FileUtility.getWriteFileStream(f_orc_out);
    boolean rerankKbestOracles = true;
    BufferedWriter rerankOrcOut=null;
    if(rerankKbestOracles==true){
      rerankOrcOut = FileUtility.getWriteFileStream(f_orc_out+".rerank");
    }
   
    long start_time0 = System.currentTimeMillis();
    long time_on_reading = 0;
    long time_on_orc_extract = 0;
    BufferedReader t_reader_ref = FileUtility.getReadFileStream(f_ref_files);
   
    DiskHyperGraph dhg_read  = new DiskHyperGraph(p_symbolTable, baseline_lm_feat_id, saveModelScores, null);
   
 
    dhg_read.initRead(f_hypergraphs, f_rule_tbl, null);
   
    KBestExtractor oracleKbestExtractor = new KBestExtractor(p_symbolTable, extract_unique_nbest, false, false, true,  false, true);//extract kbest oracles
    KBestExtractor rerankOracleKbestExtractor = new KBestExtractor(p_symbolTable, extract_unique_nbest, false, false, false,  false, true);//extract kbest oracles
    int topKOracles= 500;//TODO
    //OracleExtractionOnHGV2 orc_extractor = new OracleExtractionOnHGV2(p_symbolTable, baseline_lm_feat_id);
    OracleExtractionOnHGV3 orc_extractor = new OracleExtractionOnHGV3(p_symbolTable);
    String ref_sent= null;
    int sent_id=0;
    long start_time = System.currentTimeMillis();
    while( (ref_sent=FileUtility.read_line_lzf(t_reader_ref))!= null ){
      System.out.println("############Process sentence " + sent_id);
      start_time = System.currentTimeMillis();
      sent_id++;
      //if(sent_id>10)break;
     
      HyperGraph hg = dhg_read.readHyperGraph();
      if(hg==null)continue;
     
      double orc_bleu=0;
     
      //System.out.println("read disk hyp: " + (System.currentTimeMillis()-start_time));
      time_on_reading += System.currentTimeMillis()-start_time;
      start_time = System.currentTimeMillis();
     
      if(orc_extract_nbest){
        Object[] res = orc_extractor.oracleExtractOnNbest(kbestExtractor, hg, topN, do_ngram_clip_nbest, ref_sent);
        String orc_sent = (String) res[0];
        orc_bleu = (Double) res[1];
        orc_out.write(orc_sent+"\n");
      }else{       
        HyperGraph hg_oracle = orc_extractor.oracleExtractOnHG(hg, hg.sentLen, lm_order, ref_sent);
        oracleKbestExtractor.lazyKBestExtractOnHG(hg_oracle, null, topKOracles, hg.sentID, orc_out);
        orc_bleu = hg_oracle.goalNode.bestHyperedge.bestDerivationLogP;
        time_on_orc_extract += System.currentTimeMillis()-start_time;
        //System.out.println("num_virtual_items: " + orc_extractor.numRefinedNodes + " num_virtual_dts: " + orc_extractor.numRefinedEdges);
        //System.out.println("oracle extract: " + (System.currentTimeMillis()-start_time));
       
       
        //==== rerank the kbest-oracles to verify the approximation for DP is ok
        if(rerankKbestOracles){
          Object[] res = orc_extractor.oracleExtractOnNbest(rerankOracleKbestExtractor, hg_oracle, topKOracles, do_ngram_clip_nbest, ref_sent);
          String orc_sent = (String) res[0];
          //double rerankedOrcBleu = (Double) res[1];
          rerankOrcOut.write(orc_sent+"\n");
        }       
      }     
     
      System.out.println("orc bleu is " + orc_bleu);
    }
    t_reader_ref.close();
    orc_out.close();
    if(rerankOrcOut!=null)
      rerankOrcOut.close();
   
    System.out.println("time_on_reading: " + time_on_reading);
    System.out.println("time_on_orc_extract: " + time_on_orc_extract);
    System.out.println("total running time: "
      + (System.currentTimeMillis() - start_time0));
  }



 
 
}
TOP

Related Classes of joshua.discriminative.training.oracle.OracleExtractionOnHGV2$PrefixGrammar

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.