Package joshua.decoder.ff.lm

Source Code of joshua.decoder.ff.lm.NgramExtractor

package joshua.decoder.ff.lm;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.ff.state_maintenance.DPState;
import joshua.decoder.ff.state_maintenance.NgramDPState;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.util.Ngram;

/**
 * Extracts n-grams from translation rules, hyperedges, and n-gram LM
 * dynamic-programming states, returning them as maps from n-gram key to count.
 * Depending on useIntegerNgram, the key is built from word ids or from word
 * strings (see getNgrams at the bottom of this class).
 */
public class NgramExtractor {

  private int ngramStateID;
  private int baselineLMOrder;
  private SymbolTable symbolTable;
  private boolean useIntegerNgram;
 
  private static String START_SYM="<s>";
  private int START_SYM_ID;
  private static String STOP_SYM="</s>";
  private int STOP_SYM_ID;
 
 
  //by default, useIntegerNgram is true
  public NgramExtractor(SymbolTable symbolTable, int ngramStateID, int baselineLMOrder){
    this(symbolTable, ngramStateID, true, baselineLMOrder);   
  }
 
  public NgramExtractor(SymbolTable symbolTable, int ngramStateID, boolean useIntegerNgram, int baselineLMOrder){
   
    this.symbolTable = symbolTable;
    this.ngramStateID = ngramStateID;
    this.useIntegerNgram = useIntegerNgram;
    this.baselineLMOrder = baselineLMOrder;
   
    this.START_SYM_ID = this.symbolTable.addTerminal(START_SYM);
    this.STOP_SYM_ID = this.symbolTable.addTerminal(STOP_SYM);   
  }
 
   /** For a generative model, startNgramOrder should be set equal to endNgramOrder. */
   public HashMap<String,Integer> getTransitionNgrams(HyperEdge dt, int startNgramOrder, int endNgramOrder){
     return getTransitionNgrams(dt.getRule(), dt.getAntNodes(), startNgramOrder, endNgramOrder);
   }
 
   /** For a generative model, startNgramOrder should be set equal to endNgramOrder. */
   public HashMap<String,Integer> getTransitionNgrams(Rule rule, List<HGNode> antNodes, int startNgramOrder, int endNgramOrder){
     return computeTransitionNgrams(rule, antNodes, startNgramOrder, endNgramOrder);
   }
  
   /** Does not work for a generative model. */
   public HashMap<String,Integer>  getRuleNgrams(Rule rule, int startNgramOrder, int endNgramOrder){
     return computeRuleNgrams(rule, startNgramOrder, endNgramOrder);
   }
 
   /** For a generative model, always set startNgramOrder=1 to allow partial ngrams. */
   public HashMap<String,Integer>  getFutureNgrams(Rule rule, DPState curDPState, int startNgramOrder, int endNgramOrder){
     //TODO: do not consider <s> and </s>
     boolean addStart = false;
     boolean addEnd = false;
    
     return  computeFutureNgrams( (NgramDPState)curDPState, startNgramOrder, endNgramOrder, addStart, addEnd);
   }
 
   /** For a generative model, always set startNgramOrder=1 to allow partial ngrams. */
   public HashMap<String,Integer>  getFinalTransitionNgrams(HyperEdge edge, int startNgramOrder, int endNgramOrder){
     return getFinalTransitionNgrams(edge.getAntNodes().get(0), startNgramOrder, endNgramOrder);
   }
  
   /** For a generative model, always set startNgramOrder=1 to allow partial ngrams. */
   public HashMap<String,Integer>  getFinalTransitionNgrams(HGNode antNode, int startNgramOrder, int endNgramOrder){
     return computeFinalTransitionNgrams(antNode,  startNgramOrder,  endNgramOrder);
   }
  
  
 
   /** Works for both generative and discriminative models. */
   //TODO: consider speeding up this function
   private HashMap<String,Integer> computeTransitionNgrams(Rule rule, List<HGNode> antNodes, int startNgramOrder, int endNgramOrder){
       
        if(baselineLMOrder < endNgramOrder){
          System.out.println("baselineLMOrder is too small");
          System.exit(0);
        }
       
      //====   hyperedges not under "goal item"   
      HashMap<String, Integer> newNgramCounts = new HashMap<String, Integer>();//new ngrams created due to the combination
      HashMap<String, Integer> oldNgramCounts = new HashMap<String, Integer>();//the ngram that has already been computed
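      //the transition ngrams returned are newNgramCounts minus oldNgramCounts (see the deduction loop below)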
      int[] enWords = rule.getEnglish();
     
      //a continuous sequence of words due to combination; the sequence stops whenever the right-lm-state jumps in (i.e., having eclipsed words)
      List<Integer> words = new ArrayList<Integer>();   
     
       for(int c=0; c<enWords.length; c++){
         int curID = enWords[c];
         if(symbolTable.isNonterminal(curID)){//non-terminal words
           //== get the left and right context
           int index = symbolTable.getTargetNonterminalIndex(curID);
           HGNode antNode = antNodes.get(index);
           NgramDPState state = (NgramDPState) antNode.getDPState(this.ngramStateID);
           //System.out.println("lm_feat_is: " + this.lm_feat_id + " ; state is: " + state);
           List<Integer> leftContext = state.getLeftLMStateWords();
           List<Integer> rightContext = state.getRightLMStateWords();
           if (leftContext.size() != rightContext.size()) {
             System.out.println("getAllNgrams: left and right contexts have unequal lengths");
             System.exit(1);
           }

           //== find new ngrams created
           for(int t : leftContext)
             words.add(t);
           this.getNgrams(oldNgramCounts, startNgramOrder, endNgramOrder, leftContext);

           if(rightContext.size()>=baselineLMOrder-1){//the right and left contexts are NOT overlapping
             this.getNgrams(oldNgramCounts, startNgramOrder, endNgramOrder, rightContext);
             this.getNgrams(newNgramCounts, startNgramOrder, endNgramOrder, words);

             //start a new chunk; the sequence stops whenever the right-lm-state jumps in (i.e., having eclipsed words)
             words.clear();
             for(int t : rightContext)
               words.add(t);
           }
         }else{//terminal words
           words.add(curID);
         }
       }
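       //== collect the ngrams of the final chunk of accumulated words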
     
      this.getNgrams(newNgramCounts, startNgramOrder, endNgramOrder, words);
     
     
       //=== now deduct ngram counts
       HashMap<String, Integer> res = new HashMap<String, Integer>();
       for(Map.Entry<String, Integer> entry : newNgramCounts.entrySet()){
         String ngram = entry.getKey();
         int finalCount = entry.getValue();
         if(oldNgramCounts.containsKey(ngram)){
           finalCount -= oldNgramCounts.get(ngram);
           if(finalCount<0){
             System.out.println("error: negative count for ngram: " + ngram + "; new: " + entry.getValue() + "; old: " + oldNgramCounts.get(ngram));
             System.exit(0);
           }
         }
         if(finalCount>0)
           res.put(ngram, finalCount);
       }
       return res;
   }

   /** Works for both generative and discriminative models. */
   private HashMap<String,Integer>  computeFinalTransitionNgrams(HGNode antNode, int startNgramOrder, int endNgramOrder){
      
       if(baselineLMOrder < endNgramOrder){
          System.out.println("baselineLMOrder is too small");
          System.exit(0);
        }
     
      HashMap<String, Integer> res = new HashMap<String, Integer>();
      NgramDPState state     = (NgramDPState) antNode.getDPState(this.ngramStateID);
     
      List<Integer> currentNgram = new ArrayList<Integer>();
      List<Integer>   leftContext = state.getLeftLMStateWords();   
      List<Integer>   rightContext = state.getRightLMStateWords();
      if (leftContext.size() != rightContext.size()) {
        System.out.println("computeFinalTransition: left and right contexts have unequal lengths");
        System.exit(1);
      }
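      //ngrams at the sentence boundary: prepend <s> to the left LM state, then switch to the right LM state and append </s>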
     
      //============ left context
      currentNgram.add(START_SYM_ID);
      for (int i = 0; i < leftContext.size(); i++) {
        int t = leftContext.get(i);
        currentNgram.add(t);
       
        if(currentNgram.size()>=startNgramOrder && currentNgram.size()<=endNgramOrder)
          this.getNgrams(res, currentNgram.size(), currentNgram.size(), currentNgram);
       
        if (currentNgram.size() == baselineLMOrder) {
          currentNgram.remove(0);
        }
      }
     
      //============ right context
      //switch context: get the last possible new ngram: this ngram can be <s> a </s>
      int tSize = currentNgram.size();
      for (int i = 0; i < rightContext.size(); i++) {//replace context
        currentNgram.set(tSize - rightContext.size() + i, rightContext.get(i));
      }     
      currentNgram.add(STOP_SYM_ID);

      if(currentNgram.size()>=startNgramOrder && currentNgram.size()<=endNgramOrder)
        this.getNgrams(res, currentNgram.size(), currentNgram.size(), currentNgram);
     
      return res;
   }
  
 
 
  /**TODO: This does not work for a generative model.
   * For example, for a rule "a b x_0 c d", under a generative model we only want
   * the ngrams: a; a b; c; c d; but not b and d on their own.
   **/
  private HashMap<String, Integer> computeRuleNgrams(Rule rule, int startNgramOrder, int endNgramOrder) {
   
    if(baselineLMOrder < endNgramOrder){
        System.out.println("baselineLMOrder is too small");
        System.exit(0);
      }
     
    HashMap<String, Integer> newNgramCounts = new HashMap<String, Integer>();//new ngrams created due to the combination
   
    int[] enWords = rule.getEnglish();
     List<Integer> words = new ArrayList<Integer>();
    for (int c = 0; c < enWords.length; c++) {
      int curWrd = enWords[c];
      if (symbolTable.isNonterminal(curWrd)) {
        this.getNgrams(newNgramCounts, startNgramOrder, endNgramOrder, words);
        words.clear();
      } else {
        words.add(curWrd);
      }
    }
    this.getNgrams(newNgramCounts, startNgramOrder, endNgramOrder, words);
    return newNgramCounts;
  }
   
 
  /**TODO:
   * This does not work when addStart == true or addEnd == true.
   * But if both addStart == false and addEnd == false, then it works for both discriminative and generative models.
   **/
  private HashMap<String, Integer>  computeFutureNgrams(NgramDPState state, int startNgramOrder, int endNgramOrder, boolean addStart, boolean addEnd) {
   
    if(baselineLMOrder < endNgramOrder){
        System.out.println("baselineLMOrder is too small");
        System.exit(0);
      }
   
    HashMap<String, Integer> res = new HashMap<String, Integer>();
   
    List<Integer> currentNgram = new ArrayList<Integer>();
    List<Integer>   leftContext = state.getLeftLMStateWords();   
    List<Integer>   rightContext = state.getRightLMStateWords();
    if (leftContext.size() != rightContext.size()) {
      System.out.println("computeFinalTransition: left and right contexts have unequal lengths");
      System.exit(1);
    }
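    //only the left LM state can contribute ngrams here; ngrams involving the right LM state remain incomplete unless addEnd is true (see below)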
   
    //============ left context
    if (addStart == true){//this does not really work
      /**TODO: this does not really work as
       * the new ngrams generated should be different for discriminative or generative model.
       * */
      currentNgram.add(START_SYM_ID);
    }
    //approximate the full-ngram with smaller-order ngrams
    for (int i = 0; i < leftContext.size(); i++) {
      int t = leftContext.get(i);
      currentNgram.add(t);
     
      if(currentNgram.size()>=startNgramOrder && currentNgram.size()<=endNgramOrder-1)
        this.getNgrams(res, currentNgram.size(), currentNgram.size(), currentNgram);
     
      if (currentNgram.size() == baselineLMOrder) {
        currentNgram.remove(0);
      }
    }
   
    //============ right context
    //switch context: get the last possible new ngram: this ngram can be <s> a </s>
    if (addEnd == true){//only when add_end is true, we get a complete ngram, otherwise, all ngrams in r_state are incomplete and we should do nothing
      /**TODO: this will be different for a discriminative vs. a generative model.
       * For example, for a discriminative model, we may get new ngrams like "a b </s>" and "b </s>".
       * But for a generative model, we will get at most one ngram, whose order is baselineLMOrder
      */
    }
    return res;
  }
 
 
 
  /*
  private void getNgrams(HashMap<String,Integer> tbl,  int startNgramOrder, int endNgramOrder,  int[] wrds){
    if(useIntegerNgram)
      Ngram.getNgrams(tbl, startNgramOrder, endNgramOrder, wrds);
    else
      Ngram.getNgrams(symbolTable, tbl, startNgramOrder, endNgramOrder, wrds);
  }
  */
 
   private void getNgrams(HashMap<String,Integer> tbl,  int startNgramOrder, int endNgramOrder,  List<Integer> wrds){
    if(useIntegerNgram)
      Ngram.getNgrams(tbl, startNgramOrder, endNgramOrder, wrds);
    else
      Ngram.getNgrams(symbolTable, tbl, startNgramOrder, endNgramOrder, wrds);
   
  }
   
  

}
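
A minimal usage sketch, assuming the surrounding Joshua decoder supplies the SymbolTable, the Rule, and the antecedent HGNodes; the ngram-state id 0 and baseline LM order 3 below are placeholder values, not part of the class:

import java.util.HashMap;
import java.util.List;

import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.ff.lm.NgramExtractor;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;

public class NgramExtractorExample {

  /** Counts the new ngrams (orders 1..3) created when a rule is applied over its antecedent nodes. */
  public static HashMap<String, Integer> transitionNgrams(
      SymbolTable vocab, Rule rule, List<HGNode> antNodes) {
    // 0 is an assumed ngram-state id and 3 an assumed baseline LM order
    NgramExtractor extractor = new NgramExtractor(vocab, 0, 3);
    return extractor.getTransitionNgrams(rule, antNodes, 1, 3);
  }

  /** Counts the ngrams (orders 1..3) contained entirely inside the rule's terminal word sequences. */
  public static HashMap<String, Integer> ruleNgrams(SymbolTable vocab, Rule rule) {
    NgramExtractor extractor = new NgramExtractor(vocab, 0, 3);
    return extractor.getRuleNgrams(rule, 1, 3);
  }
}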