Package joshua.discriminative.training.oracle

Source Code of joshua.discriminative.training.oracle.SplitHg$DPState

package joshua.discriminative.training.oracle;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;

import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;


/* Zhifei Li, <zhifei.work@gmail.com>
* Johns Hopkins University
*/

/* This class implements general ways of splitting a hypergraph, based on the coarse-to-fine idea.
 * Input: a hypergraph.
 * Output: another hypergraph whose items have a changed (split) state structure.
 */

public abstract class SplitHg {
  HashMap<HGNode, ArrayList<VirtualItem> > g_tbl_split_virtual_items =  new HashMap<HGNode, ArrayList<VirtualItem> >();//Key: item; Value: a list of split virtual items
 
  //number of items or deductions after splitting the hypergraph
  public int g_num_virtual_items = 0;
  public int g_num_virtual_deductions = 0;
 
  //Note: the implementaion of the folowing two functions should call add_deduction
  protected abstract void process_one_combination_axiom(HGNode parent_item, HashMap virtual_item_sigs, HyperEdge cur_dt);
  protected abstract void process_one_combination_nonaxiom(HGNode parent_item, HashMap virtual_item_sigs, HyperEdge cur_dt, ArrayList<VirtualItem> l_ant_virtual_item);
 
  //#### all the functions should be called after running split_hg(), before clearing g_tbl_split_virtual_items
  public double get_best_goal_cost(HyperGraph hg, HashMap g_tbl_split_virtual_items){
    double res = get_virtual_goal_item(hg, g_tbl_split_virtual_items).best_virtual_deduction.best_cost;
    //System.out.println("best bleu is " +res);
    return res;
  }

 
 
  public VirtualItem get_virtual_goal_item(HyperGraph original_hg, HashMap g_tbl_split_virtual_items){
    ArrayList l_virtual_items = (ArrayList) g_tbl_split_virtual_items.get(original_hg.goalNode);
    if(l_virtual_items.size()!=1){System.out.println("number of virtual goal items is not equal to one"); System.exit(0);}
    return (VirtualItem)l_virtual_items.get(0);
  }
 
 
  //get the 1best tree hg, the 1-best is ranked by the split hypergraph, but the return hypergraph is in the form of the original hg 
  public HyperGraph get_1best_tree_hg(HyperGraph original_hg, HashMap g_tbl_split_virtual_items){
    VirtualItem virutal_goal_item =  get_virtual_goal_item(original_hg, g_tbl_split_virtual_items);
    HGNode onebest_goal_item = clone_item_with_best_deduction(virutal_goal_item);   
    HyperGraph res = new HyperGraph(onebest_goal_item, -1, -1, original_hg.sentID, original_hg.sentLen);//TODO: number of items/deductions
    get_1best_tree_item(virutal_goal_item, onebest_goal_item);
    return res;
  }
 
  private void get_1best_tree_item(VirtualItem virtual_it, HGNode onebest_item){ 
    VirtualDeduction virtual_dt = virtual_it.best_virtual_deduction;
    if(virtual_dt.l_ant_virtual_items!=null)
      for(int i=0; i< virtual_dt.l_ant_virtual_items.size(); i++){
        VirtualItem ant_it = (VirtualItem) virtual_dt.l_ant_virtual_items.get(i);
        HGNode new_it = clone_item_with_best_deduction(ant_it);
        onebest_item.bestHyperedge.getAntNodes().set(i, new_it);
        get_1best_tree_item(ant_it,new_it)
      }   
  } 
 
  //TODO: tbl_states
  private static HGNode clone_item_with_best_deduction(VirtualItem virtual_it){
    HGNode original_it = virtual_it.p_item;
    ArrayList<HyperEdge> l_deductions = new ArrayList<HyperEdge>();   
    HyperEdge clone_dt = clone_deduction(virtual_it.best_virtual_deduction);
    l_deductions.add(clone_dt);
    return new HGNode(original_it.i, original_it.j, original_it.lhs,  l_deductions, clone_dt, original_it.getDPStates())
  }
 

  private static HyperEdge clone_deduction(VirtualDeduction virtual_dt){
    HyperEdge originalEdge = virtual_dt.p_dt;
    ArrayList<HGNode> l_ant_items = null;
    if(originalEdge.getAntNodes()!=null)
      l_ant_items = new ArrayList<HGNode>(originalEdge.getAntNodes());//l_ant_items will be changed in get_1best_tree_item
    HyperEdge res = new HyperEdge(originalEdge.getRule(), originalEdge.bestDerivationLogP, originalEdge.getTransitionLogP(false), l_ant_items, null);
    return res;
  }
 
 
 
//  ############### split hg ##### 
    public  void split_hg(HyperGraph hg){ 
      //TODO: more pre-process in the extended class
      g_tbl_split_virtual_items.clear();
      g_num_virtual_items = 0;
      g_num_virtual_deductions = 0;       
      split_item(hg.goalNode)
   
   
    //for each original Item, get a list of VirtualItem
    private void split_item(HGNode it){
      if(g_tbl_split_virtual_items.containsKey(it))
        return;//already processed
      HashMap virtual_item_sigs = new HashMap();
      //### recursive call on each deduction
      if( speed_up_item(it) ){
        for(HyperEdge dt : it.hyperedges){         
          split_deduction(dt, virtual_item_sigs, it);
        }
      }
      //### item-specific operation
      ArrayList l_virtual_items = new ArrayList();//a list of items result by splitting me
      for(Iterator iter = virtual_item_sigs.keySet().iterator(); iter.hasNext();)
        l_virtual_items.add(virtual_item_sigs.get(iter.next()));
      g_tbl_split_virtual_items.put(it,l_virtual_items);
      g_num_virtual_items += l_virtual_items.size();
      //if(virtual_item_sigs.size()!=1)System.out.println("num of split items is " + virtual_item_sigs.size());
      //get_best_virtual_score(it);//debug
   
   
    private void split_deduction(HyperEdge cur_dt, HashMap virtual_item_sigs, HGNode parent_item){
      if(speed_up_deduction(cur_dt)==false) return;//no need to continue
     
      //### recursively split all my ant items, get a l_split_items for each original item
      if(cur_dt.getAntNodes()!=null)
        for(HGNode ant_it : cur_dt.getAntNodes())
          split_item(ant_it);
     
      //### recombine the deduction
      redo_combine(cur_dt, virtual_item_sigs, parent_item);
   
     
    private void redo_combine(HyperEdge cur_dt, HashMap virtual_item_sigs, HGNode parent_item){
      List<HGNode> l_ant_items = cur_dt.getAntNodes();
      if(l_ant_items!=null){     
        if(l_ant_items.size()==1){//arity: one
          HGNode it = (HGNode)l_ant_items.get(0);
          ArrayList<VirtualItem> l_virtual_items = (ArrayList<VirtualItem>)g_tbl_split_virtual_items.get(it);       
          for(VirtualItem ant_virtual_item: l_virtual_items){
            ArrayList l_ant_virtual_item = new ArrayList();//used in combination
            l_ant_virtual_item.add(ant_virtual_item);
            process_one_combination_nonaxiom(parent_item, virtual_item_sigs, cur_dt,  l_ant_virtual_item);
          }
        }else if(l_ant_items.size()==2){//arity: two
          HGNode it1 = (HGNode)l_ant_items.get(0);
          HGNode it2 = (HGNode)l_ant_items.get(1);
          ArrayList<VirtualItem> l_virtual_items1 = (ArrayList<VirtualItem>)g_tbl_split_virtual_items.get(it1);
          ArrayList<VirtualItem> l_virtual_items2 = (ArrayList<VirtualItem>)g_tbl_split_virtual_items.get(it2);
          for(VirtualItem virtual_it1: l_virtual_items1){
            for(VirtualItem virtual_it2: l_virtual_items2){
              ArrayList l_ant_virtual_item = new ArrayList();//used in combination
              l_ant_virtual_item.add(virtual_it1);
              l_ant_virtual_item.add(virtual_it2);
              process_one_combination_nonaxiom(parent_item, virtual_item_sigs, cur_dt,  l_ant_virtual_item);
            }         
          }
        }else{
          System.out.println("Sorry, we can only deal with rules with at most TWO non-terminals");System.exit(0);
        }     
      }else{//axiom case: no nonterminal
        //System.out.println("no ant!!! " + Symbol.get_string(cur_dt.get_rule().english));
        process_one_combination_axiom(parent_item, virtual_item_sigs, cur_dt);
      }   
    }
   
    //this function should be called by process_one_combination_axiom/process_one_combination_nonaxiom
    //virtual_item_sigs is specific to parent_item
    protected  void add_deduction(HGNode parent_item, HashMap virtual_item_sigs, VirtualDeduction t_ded, DPState dpstate, boolean maintain_onebest_only){
      if(t_ded==null) {System.out.println("deduction is null"); System.exit(0);}
      String sig = VirtualItem.get_signature(parent_item, dpstate);
      VirtualItem t_virtual_item = (VirtualItem)virtual_item_sigs.get(sig);
      if(t_virtual_item!=null){
        t_virtual_item.add_deduction(t_ded, dpstate, maintain_onebest_only);
      }else{
        t_virtual_item = new VirtualItem(parent_item, dpstate, t_ded, maintain_onebest_only);
        virtual_item_sigs.put(sig,t_virtual_item );
      }   
    }
   
    //return false if we can skip the item;
    protected  boolean speed_up_item(HGNode it){
      return true;//e.g., if the lm state is not valid, then no need to continue
    }
   
//    return false if we can skip the deduction;
    protected  boolean speed_up_deduction(HyperEdge dt){
      return true;// if the rule state is not valid, then no need to continue 
    }
   
    protected abstract static class DPState {
      protected abstract String get_signature();
    };
   
   
    /*In general, variables of items
     * (1) list of hyperedges
     * (2) best hyperedge
     * (3) DP state
     * (4) signature (operated on part/full of DP state)
     * */
   
    protected static class VirtualItem {
      HGNode p_item =null;//pointer to the true item
      ArrayList<VirtualDeduction> l_virtual_deductions = null;
      VirtualDeduction best_virtual_deduction=null;
      DPState dp_state;//dynamic programming state: not all the variable in dp_state are in the signature
     
      public VirtualItem(HGNode item, DPState dstate, VirtualDeduction fdt, boolean maintain_onebest_only){
        p_item = item;
        add_deduction(fdt, dstate, maintain_onebest_only);
      }
     
     
      public void add_deduction(VirtualDeduction fdt, DPState dstate, boolean maintain_onebest_only){
        if(maintain_onebest_only==false){
          if(l_virtual_deductions==null) l_virtual_deductions = new ArrayList<VirtualDeduction>();;
          l_virtual_deductions.add(fdt);
        }
        if( best_virtual_deduction==null || fdt.best_cost < best_virtual_deduction.best_cost ) {
          best_virtual_deduction = fdt;
          dp_state = dstate;
        }
      }
     
      // not all the variable in dp_state are in the signature
      public String get_signature(){
        return get_signature(p_item, dp_state);
      }
     
      public static String get_signature(HGNode item, DPState dstate){
        /*StringBuffer res = new StringBuffer();
        //res.append(item); res.append(" ");//TODO:
        res.append(dstate.get_signature());
        return res.toString();*/
        return dstate.get_signature();
      }
    }
   
    protected static class VirtualDeduction {
      HyperEdge p_dt =null;//pointer to the true deduction
      ArrayList<VirtualItem> l_ant_virtual_items=null;
      double best_cost=Double.POSITIVE_INFINITY;//the 1-best cost of all possible derivation: best costs of ant items + non_stateless_transition_cost + r.statelesscost
     
      public VirtualDeduction(HyperEdge dt, ArrayList<VirtualItem> ant_items, double best_cost_in){
        p_dt=dt;
        l_ant_virtual_items = ant_items;
        best_cost = best_cost_in;
      }
     
      public double get_transition_cost(){//note: transition_cost is already linearly interpolated
        double res = best_cost;
        if(l_ant_virtual_items!=null
          for(VirtualItem ant_it : l_ant_virtual_items)
            res -= ant_it.best_virtual_deduction.best_cost;
        return res;
      }
    }   
   
}
TOP

Related Classes of joshua.discriminative.training.oracle.SplitHg$DPState

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.