package joshua.discriminative.syntax_reorder;
import java.util.Comparator;
import java.util.HashMap ;
import java.util.PriorityQueue;
import java.util.ArrayList;
import joshua.corpus.vocab.SymbolTable;
/*Zhifei Li, <zhifei.work@gmail.com>
* Johns Hopkins University
*/
/* Public interfaces:
 * TMGrammar: initializes and loads the grammar
 * TrieNode: matches a symbol to find the next-layer node
 * RuleBin: provides its rules in sorted order
 * Rule: holds the information for a single rule
 */
/**
 * A Hiero-style grammar stored as a trie over the French (source) side of its rules.
 *
 * <p>The grammar is composed of trie nodes. Each {@link TrieNode} has:
 * <ol>
 *   <li>a {@link RuleBin}: the rules whose French side ends at this node, and</li>
 *   <li>a map of next-layer trie nodes, keyed by the next French symbol id.</li>
 * </ol>
 *
 * <p>Rules with the same French side and the same English signature (lhs + English side) are merged by
 * summing their feature scores. Each bin keeps at most {@link #MAX_N_RULES_SAME_FRENCH} rules, ranked by
 * {@code feat_scores[0]}.
 *
 * <p>No synchronization is used. {@link RuleBin#get_sorted_rules()} mutates its bin when the bin has
 * unsorted rules, so every bin should be sorted before the grammar is read concurrently.
 */
public class TrieBasedHieroGrammarScorer {
    /** Maximum number of rules kept per French side; the lowest-scoring rules are pruned first. */
    public static int MAX_N_RULES_SAME_FRENCH=40;

    private int num_rule_read=0;   // number of add_rule calls
    private int num_rule_pruned=0; // number of rules dropped by RuleBin pruning
    private int num_rule_bin=0;    // number of RuleBins (distinct French sides) created
    private TrieNode root = null;
    static private double tem_estcost =0.0;//debug; not updated anywhere in this class
    SymbolTable symbolTable = null; //TODO: must be set before printing rules, which calls symbolTable.getWord

    /** Creates an empty grammar consisting of a single root trie node. */
    public TrieBasedHieroGrammarScorer(){
        root = new TrieNode();
    }

    /** @return the root of the French-side trie */
    public TrieNode get_root(){
        return root;
    }

    /** Normalizes the grammar. Not implemented yet: currently a no-op. */
    public void score_grammar(){
    }

    /** Dumps the grammar. Not implemented yet: currently a no-op. */
    public void dump_grammar(){
    }

    /**
     * Inserts a rule, creating trie nodes along its French side as needed.
     *
     * @param p_rule the rule to insert; its {@code french} and {@code feat_scores} arrays must be non-null
     * @return {@code p_rule} itself. If an equivalent rule already exists in the bin, the scores of
     *         {@code p_rule} are added to that existing rule, and {@code p_rule} is not stored.
     */
    public Rule add_rule(Rule p_rule){
        num_rule_read++;

        // Walk down the trie along the French side, creating missing nodes.
        TrieNode pos = root;
        for(int cur_sym_id : p_rule.french){
            TrieNode next_layer = pos.match_symbol(cur_sym_id);
            if(next_layer==null){
                next_layer = new TrieNode();
                if(pos.tbl_children==null)
                    pos.tbl_children = new HashMap<Integer, TrieNode>();
                pos.tbl_children.put(cur_sym_id, next_layer);
            }
            pos = next_layer;
        }

        // Add the rule into the bin of the node where its French side ends.
        if(pos.rule_bin==null){
            pos.rule_bin = new RuleBin();
            num_rule_bin++;
        }
        pos.rule_bin.add_rule(p_rule);
        return p_rule;
    }

    /**
     * Sorts every RuleBin in the grammar.
     *
     * <p>Call this once the grammar is complete, so that later get_sorted_rules calls only read the
     * cached lists and need no synchronization.
     */
    private void ensure_grammar_sorted(){
        if(root!=null)
            root.ensure_sorted();
    }

    /** Prints summary statistics about the grammar to standard output. */
    protected void print_grammar(){
        System.out.println("###########Grammar###########");
        System.out.println(String.format("####num_rules: %d; num_bins: %d; num_pruned: %d; sumest_cost: %.5f",num_rule_read, num_rule_bin, num_rule_pruned, tem_estcost));
        /*if(root!=null)
            root.print_info(Support.DEBUG);*/
    }

    /** A node of the French-side trie. */
    public class TrieNode
    {
        private RuleBin rule_bin=null;                        // rules whose French side ends here; null if none
        private HashMap<Integer, TrieNode> tbl_children=null; // next French symbol id -> child; null until the first child

        /**
         * Looks up the next-layer trie node for a French symbol.
         *
         * @param sym_id symbol id of the next French word
         * @return the child node, or null if there is no child for {@code sym_id}
         */
        public TrieNode match_symbol(int sym_id){
            if(tbl_children==null)
                return null;
            return tbl_children.get(sym_id);
        }

        /** @return the rules whose French side ends at this node, or null if there are none */
        public RuleBin get_rule_bin(){
            return rule_bin;
        }

        /** @return true if no child node has ever been added to this node */
        public boolean is_no_child_trienodes(){
            return (tbl_children==null);
        }

        // Recursively sorts every RuleBin in the subtree rooted at this node.
        private void ensure_sorted(){
            if(rule_bin!=null)
                rule_bin.get_sorted_rules();
            if(tbl_children!=null){
                for(TrieNode child : tbl_children.values())
                    child.ensure_sorted();
            }
        }

        // Prints this node's bin and, recursively, all child nodes to standard output.
        private void print_info(int level){
            System.out.println("###########TrieNode###########");
            if(rule_bin!=null){
                System.out.println("##### RuleBin(in TrieNode) is");
                rule_bin.print_info(level);
            }
            if(tbl_children!=null){
                for(TrieNode child : tbl_children.values()){
                    System.out.println("##### ChildTrieNode(in TrieNode) is");
                    child.print_info(level);
                }
            }
        }
    }

    /**
     * All rules with the same French side (and thus the same arity).
     *
     * <p>Pending rules live in a min-heap ordered by {@link Rule#FrequencyComparator}, so pruning always drops
     * the lowest-scoring rule. {@link #get_sorted_rules()} drains the heap into a list in descending order.
     */
    public class RuleBin {
        private PriorityQueue<Rule> heap_rules=null;   // pending rules; null while the bin is sorted
        private boolean sorted=false;                   // true when l_sorted_rules holds every rule of the bin
        private ArrayList<Rule> l_sorted_rules = new ArrayList<Rule>();
        private HashMap<String, Rule> tbl_eng_rules = new HashMap<String, Rule>(); // English signature -> merged rule

        /**
         * Returns this bin's rules ordered by decreasing {@code feat_scores[0]}.
         *
         * <p>The first call after an insertion sorts the bin, and later calls return the cached list. This method
         * is not synchronized. Sort every bin before concurrent readers use the grammar; the enclosing
         * grammar's ensure_grammar_sorted does that.
         *
         * @return the sorted rules; the same list instance is returned on every call
         */
        public ArrayList<Rule> get_sorted_rules(){
            if(!sorted){
                l_sorted_rules.clear();
                if(heap_rules!=null){
                    // poll() yields rules in ascending order, so fill the array from the back.
                    Rule[] descending = new Rule[heap_rules.size()];
                    for(int i=descending.length-1; i>=0; i--)
                        descending[i] = heap_rules.poll();
                    l_sorted_rules.ensureCapacity(descending.length);
                    for(Rule r : descending)
                        l_sorted_rules.add(r);
                }
                sorted=true;
                heap_rules=null;
            }
            return l_sorted_rules;
        }

        /**
         * Adds a rule to the bin and prunes the bin down to MAX_N_RULES_SAME_FRENCH rules.
         *
         * <p>If the bin already holds a rule with the same English signature, the new rule's feature scores are
         * added to that rule instead. This works both before and after the bin has been sorted.
         *
         * @param rl rule to add; its {@code feat_scores} must be non-null
         */
        private void add_rule(Rule rl){
            if(sorted)
                reopen_for_insertion();
            if(heap_rules==null)
                heap_rules = new PriorityQueue<Rule>(1, Rule.FrequencyComparator);//TODO: initial capacity?

            String sig = rl.get_eng_signature();
            Rule old_rule = tbl_eng_rules.get(sig);
            if(old_rule!=null){
                // Take the rule out of the heap BEFORE changing its key, so the heap is never left inconsistent.
                // TODO: remove(Object) is a linear scan; too expensive for large bins
                heap_rules.remove(old_rule);
                for(int i=0; i< rl.feat_scores.length; i++)
                    old_rule.feat_scores[i] += rl.feat_scores[i];
                // NOTE(review): a rule pruned earlier is still in tbl_eng_rules, so it re-enters the heap here
                // with its accumulated scores; confirm this is intended.
                heap_rules.add(old_rule);
            }else{
                tbl_eng_rules.put(sig, rl);
                heap_rules.add(rl);
            }
            num_rule_pruned += run_pruning();
        }

        // Moves the already-sorted rules back into a heap. A rule added after get_sorted_rules() is then
        // merged, pruned and returned by the next get_sorted_rules() call like any other rule.
        private void reopen_for_insertion(){
            heap_rules = new PriorityQueue<Rule>(Math.max(1, l_sorted_rules.size()), Rule.FrequencyComparator);
            heap_rules.addAll(l_sorted_rules);
            l_sorted_rules.clear();
            sorted=false;
        }

        /**
         * Removes the lowest-scoring rules until at most MAX_N_RULES_SAME_FRENCH remain.
         *
         * @return the number of rules removed
         */
        private int run_pruning(){
            int limit = Math.max(0, MAX_N_RULES_SAME_FRENCH); // a negative cap would otherwise loop forever
            int n_pruned=0;
            while(heap_rules.size()>limit){
                heap_rules.poll();
                n_pruned++;
            }
            return n_pruned;
        }

        /** Normalizes the rule bin during phrase extraction. Not implemented yet: currently a no-op. */
        private void score_rulebin(){
        }

        // Prints every rule of the bin, in sorted order, to standard output (requires symbolTable to be set).
        private void print_info(int level){
            for(Rule rule : get_sorted_rules())
                rule.print_info(symbolTable);
        }
    }

    /**
     * A single grammar rule, in the format {@code [lhs] ||| french ||| english ||| feature scores}.
     *
     * <p>Arity is not stored: disk grammars do not provide it, so the arity-penalty feature is not supported.
     */
    public static class Rule{
        public int lhs;             // symbol id of the rule's left-hand-side tag
        public int[] french;        // French (source) side, as symbol ids
        public int[] english;       // English (target) side, as symbol ids
        public float[] feat_scores; // feature scores; not set by the constructor, must be assigned before use
        public ArrayList alignments; // NOTE(review): element type not visible in this file; unused here

        /**
         * Creates a rule without feature scores or alignments.
         *
         * @param lhs_in  symbol id of the left-hand side
         * @param fr_in   French side symbol ids
         * @param eng_in  English side symbol ids
         */
        public Rule(int lhs_in, int[] fr_in, int[] eng_in){
            lhs=lhs_in;
            french = fr_in;
            english = eng_in;
        }

        /**
         * Orders rules by ascending {@code feat_scores[0]}, which pruning treats as the relative frequency
         * P(english|french). NOTE(review): confirm the feature layout.
         * {@link Float#compare} yields a consistent total order, even for NaN scores.
         */
        protected static Comparator<Rule> FrequencyComparator = new Comparator<Rule>() {
            public int compare(Rule rule1, Rule rule2) {
                return Float.compare(rule1.feat_scores[0], rule2.feat_scores[0]);
            }
        };

        /**
         * @return a key identifying the rule's target side: lhs, then each English symbol id, each followed
         *         by a single space
         */
        public String get_eng_signature(){
            StringBuilder res = new StringBuilder();
            res.append(lhs);
            res.append(' ');
            for(int sym : english){
                res.append(sym);
                res.append(' ');
            }
            return res.toString();
        }

        /**
         * Prints the rule to standard output as {@code lhs ||| french ||| english |||  scores}.
         *
         * @param symbolTable table used to turn symbol ids back into words; must be non-null
         */
        public void print_info(SymbolTable symbolTable){
            System.out.print("Rule is: "+ symbolTable.getWord(lhs) + " ||| ");
            for(int i=0; i<french.length;i++)
                System.out.print( symbolTable.getWord(french[i]) +" ");
            System.out.print("||| ");
            for(int i=0; i<english.length;i++)
                System.out.print( symbolTable.getWord(english[i]) +" ");
            System.out.print("||| ");
            for(int i=0; i< feat_scores.length; i++)
                System.out.print(" " + feat_scores[i]);
            System.out.print("\n");
        }
    }
}