// Package joshua.discriminative.syntax_reorder
//
// Source Code of joshua.discriminative.syntax_reorder.TrieBasedHieroGrammarScorer$TrieNode

package joshua.discriminative.syntax_reorder;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;

import joshua.corpus.vocab.SymbolTable;

/*Zhifei Li, <zhifei.work@gmail.com>
* Johns Hopkins University
*/

/*public interfaces
* TMGrammar: init and load the grammar
* TrieNode: match symbol for next layer
* RuleBin: get sorted rules
* Rule: rule information
* */

public class TrieBasedHieroGrammarScorer  {
  public static int MAX_N_RULES_SAME_FRENCH=40;
 
  private  int num_rule_read=0;
  private int num_rule_pruned=0;
  private int num_rule_bin=0;
  private TrieNode  root = null
 
  static private double tem_estcost =0.0;//debug
 
 
  SymbolTable symbolTable = null; //TODO
 
  /*TMGrammar is composed by Trie nodes
  Each trie node has:
  (1) RuleBin: a list of rules matching the french sides so far
  (2) a HashMap  of next-layer trie nodes, the next french word used as the key in HashMap 
  */
 
  public TrieBasedHieroGrammarScorer(){
    root = new TrieNode();
  }
   
  public TrieNode get_root(){
    return root;
  }
 
  //normalize the grammar
  public void score_grammar(){
   
  }
 
  public void dump_grammar(){
   
  }
 
  public Rule add_rule(Rule p_rule){
    num_rule_read++;       
    //######### identify the position, and insert the trinodes if necessary
    TrieNode  pos = root;
    for(int k=0; k < p_rule.french.length; k++){
      int cur_sym_id=p_rule.french[k];     
      TrieNode next_layer=pos.match_symbol(cur_sym_id);
      if(next_layer!=null){
        pos=next_layer;
      }else{   
        TrieNode tem = new TrieNode();//next layer node
        if(pos.tbl_children==nullpos.tbl_children = new HashMap ();
        pos.tbl_children.put(cur_sym_id, tem);
        pos = tem;
      }
    }
    //#########: now add the rule into the trinode
    if(pos.rule_bin==null){
      pos.rule_bin = new RuleBin();
      num_rule_bin++;
    }   
    pos.rule_bin.add_rule(p_rule);     
    return p_rule;
  }
 

  //this method should be called such that all the rules in rulebin are sorted, this will avoid synchronization for get_sorted_rules function
  private void ensure_grammar_sorted(){   
    if(root!=null)
      root.ensure_sorted();
  }
 
  protected void print_grammar(){
    System.out.println("###########Grammar###########");
    System.out.println(String.format("####num_rules: %d; num_bins: %d; num_pruned: %d; sumest_cost: %.5f",num_rule_read, num_rule_bin, num_rule_pruned, tem_estcost));
    /*if(root!=null)
      root.print_info(Support.DEBUG);*/
  }
 
  public class TrieNode
  {
    private  RuleBin rule_bin=null;
    private HashMap  tbl_children=null;
   
    public TrieNode match_symbol(int sym_id){//looking for the next layer trinode corresponding to this symbol
      if(tbl_children==null)
        return null;
      return (TrieNode) tbl_children.get(sym_id);
    }
   
    public RuleBin get_rule_bin(){
      return rule_bin;
    }
   
    public boolean is_no_child_trienodes(){
      return (tbl_children==null);
    }
   
    //recursive call, to make sure all rules are sorted
    private void ensure_sorted(){
      if(rule_bin!=null)
        rule_bin.get_sorted_rules();
      if(tbl_children!=null){
        Object[] tem = tbl_children.values().toArray();
        for(int i=0; i< tem.length; i++){         
          ((TrieNode)tem[i]).ensure_sorted();
        }
      }
    }
   
    private void print_info(int level){
      System.out.println("###########TrieNode###########");
      if(rule_bin!=null){
        System.out.println("##### RuleBin(in TrieNode) is");
        rule_bin.print_info(level);
      }
      if(tbl_children!=null){
        Object[] tem = tbl_children.values().toArray();
        for(int i=0; i< tem.length; i++){
          System.out.println("##### ChildTrieNode(in TrieNode) is");
          ((TrieNode)tem[i]).print_info(level);
        }
      }
    }
  }

  //contain all rules with the same french side (and thus same arity)
  public class RuleBin   {
    private PriorityQueue<Rule> heap_rules=null;
        private boolean sorted=false;
    private ArrayList<Rule> l_sorted_rules = new ArrayList();
    //int arity;   
    private HashMap tbl_eng_rules = new HashMap();
   
    //TODO: now, we assume this function will be called only after all the rules have been read
    //this method need to be synchronized as we will call this function only after the decoding begins
    //to avoid the synchronized method, we should call this once the grammar is finished
    //public synchronized ArrayList<Rule> get_sorted_rules(){   
    public ArrayList<Rule> get_sorted_rules(){
      if(sorted==false){//sort once       
        l_sorted_rules.clear();
        while(heap_rules.size()>0){
          Rule t_r = (Rule) heap_rules.poll();
          l_sorted_rules.add(0,t_r);
        }
        sorted=true;
        heap_rules=null;
      }
      return l_sorted_rules;
    }
   
    private void add_rule(Rule rl){     
      if(heap_rules==null)
        heap_rules = new PriorityQueue(1, Rule.FrequencyComparator);//TODO: initial capacity?
      String sig = rl.get_eng_signature();
      Rule old_rule = (Rule) tbl_eng_rules.get(sig);
      if(old_rule!=null){
        for(int i=0; i< rl.feat_scores.length; i++)
          old_rule.feat_scores[i] += rl.feat_scores[i];
        //TODO ** this is too expenisve
        heap_rules.remove(old_rule);
        heap_rules.add(old_rule);
      }else{
        tbl_eng_rules.put(sig, rl);
        heap_rules.add(rl);
      }
      num_rule_pruned += run_pruning();
    }
   
       
   
    private int run_pruning(){
      int n_pruned=0;
      while(heap_rules.size()>MAX_N_RULES_SAME_FRENCH){
        n_pruned++;
        heap_rules.poll();
      }
      return n_pruned++;
    }
   
    //normalize the rulebin during phrase extraction
    private void score_rulebin(){
     
    }
   
    private void print_info(int level){
      //Support.write_log_line(String.format("RuleBin, arity is %d",arity),level);
      ArrayList t_l = get_sorted_rules();
      for(int i=0; i< t_l.size(); i++)
        ((Rule)t_l.get(i)).print_info(symbolTable);
    }
  }

  public static class Rule{
    //Rule formate: [Phrase] ||| french ||| english ||| feature scores
    public int lhs;//tag of this rule, state to upper layer
    public int[] french;
    public int[] english;    
    public float[] feat_scores;//the feature scores for this rule
    public ArrayList alignments;
   
   
    //public int arity=0;//TODO: disk-grammar does not have this information, so, arity-penalty feature is not supported in disk-grammar
           
    public Rule(int lhs_in, int[] fr_in, int[] eng_in){
      lhs=lhs_in;
      french = fr_in;
      english = eng_in;     
   
   
    //prune grammar based on the relative frequence P(english|french)
    protected static Comparator FrequencyComparator = new Comparator() {
        public int compare(Object rule1, Object rule2) {
          float freq1=  ((Rule) rule1).feat_scores[0];//??
          float freq2=  ((Rule) rule2).feat_scores[0];
          if(freq1 < freq2)
            return -1;
          else if(freq1 == freq2)
            return 0;
          else
            return 1;
        }
    };
   
    public String get_eng_signature(){
      StringBuffer res = new StringBuffer();     
      res.append(lhs);
      res.append(" ");
      for(int i=0; i<english.length; i++){
        res.append(english[i]);
        res.append(" ");
      }     
      return res.toString();
    }
   
    public void print_info(SymbolTable symbolTable){
      //Support.write_log("Rule is: "+ lhs + " ||| " + Support.arrayToString(french, " ") + " ||| " + Support.arrayToString(english, " ") + " |||", level);
      System.out.print("Rule is: "+ symbolTable.getWord(lhs) " ||| ");
      for(int i=0; i<french.length;i++)
         System.out.print( symbolTable.getWord(french[i]) +" ");
      System.out.print("||| ");
      for(int i=0; i<english.length;i++)
         System.out.print( symbolTable.getWord(english[i]) +" ");
      System.out.print("||| ");
      for(int i=0; i< feat_scores.length; i++)
        System.out.print(" " + feat_scores[i]);
      System.out.print("\n");
    }
  }
}

// TOP
//
// Related Classes of joshua.discriminative.syntax_reorder.TrieBasedHieroGrammarScorer$TrieNode
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.