/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.oracle;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;
/**
* This class implements general ways of splitting the hypergraph
* based on the coarse-to-fine idea: the input is a hypergraph and the
* output is another hypergraph whose state structures have changed.
*
* @author Zhifei Li, <zhifei.work@gmail.com> (Johns Hopkins University)
* @version $LastChangedDate: 2010-01-14 19:15:28 -0600 (Thu, 14 Jan 2010) $
*/
public abstract class SplitHg {
/**
 * Maps each original hypergraph node to the list of virtual items produced by
 * splitting it. The table is cleared at the start of split_hg() and receives
 * one entry per visited node in split_item().
 */
HashMap<HGNode, ArrayList<VirtualItem>> g_tbl_split_virtual_items =
new HashMap<HGNode, ArrayList<VirtualItem>>();
//number of items or deductions after splitting the hypergraph
// Both counters are reset by split_hg(). g_num_virtual_items grows in split_item();
// g_num_virtual_deductions is never incremented in this class
// (presumably subclasses maintain it -- TODO confirm).
public int g_num_virtual_items = 0;
public int g_num_virtual_deductions = 0;
//Note: the implementations of the following two functions should call add_deduction
/**
 * Handles an axiom hyperedge, i.e. one whose antecedent-node list is null.
 * Invoked by redo_combine(); implementations are expected to register the
 * resulting virtual deduction through add_deduction().
 *
 * @param parent_item the original node that owns cur_dt
 * @param virtual_item_sigs signature-to-virtual-item table specific to parent_item
 * @param cur_dt the original hyperedge being processed
 */
protected abstract void process_one_combination_axiom(HGNode parent_item, HashMap<String, VirtualItem> virtual_item_sigs, HyperEdge cur_dt);
/**
 * Handles one combination of antecedent virtual items for a hyperedge that has
 * one or two antecedent nodes. Invoked by redo_combine() once per combination;
 * implementations are expected to register the resulting virtual deduction
 * through add_deduction().
 *
 * @param parent_item the original node that owns cur_dt
 * @param virtual_item_sigs signature-to-virtual-item table specific to parent_item
 * @param cur_dt the original hyperedge being processed
 * @param l_ant_virtual_item one virtual item per antecedent node of cur_dt, in antecedent order
 */
protected abstract void process_one_combination_nonaxiom(HGNode parent_item, HashMap<String, VirtualItem> virtual_item_sigs, HyperEdge cur_dt, ArrayList<VirtualItem> l_ant_virtual_item);
//#### The following functions must be called after running split_hg() and before g_tbl_split_virtual_items is cleared
/**
 * Returns the best cost recorded on the best virtual deduction of the single
 * virtual goal item.
 *
 * @param hg the original hypergraph whose goal node is looked up
 * @param g_tbl_split_virtual_items table from original nodes to their virtual items
 * @return the best cost of the virtual goal item's best deduction
 * @throws RuntimeException if the goal node does not have exactly one virtual item
 */
public double get_best_goal_cost(HyperGraph hg, HashMap<HGNode, ArrayList<VirtualItem>> g_tbl_split_virtual_items){
    VirtualItem goal = get_virtual_goal_item(hg, g_tbl_split_virtual_items);
    return goal.best_virtual_deduction.best_cost;
}
/**
 * Looks up the unique virtual item that corresponds to the goal node of the
 * original hypergraph.
 *
 * @param original_hg the original hypergraph; its goalNode is used as the lookup key
 * @param g_tbl_split_virtual_items table from original nodes to their virtual items
 * @return the single virtual item registered for the goal node
 * @throws RuntimeException if the goal node has no entry in the table, or if its
 *         entry does not contain exactly one virtual item
 */
public VirtualItem get_virtual_goal_item(HyperGraph original_hg, HashMap<HGNode, ArrayList<VirtualItem>> g_tbl_split_virtual_items){
    ArrayList<VirtualItem> l_virtual_items = g_tbl_split_virtual_items.get(original_hg.goalNode);
    // A missing entry means the goal node was never split (or the table was
    // cleared); report that explicitly instead of failing with a bare NPE below.
    if (l_virtual_items == null) {
        throw new RuntimeException("no virtual goal items found for the goal node; was split_hg() run on this hypergraph?");
    }
    if (l_virtual_items.size() != 1) {
        // TODO: log this properly, fail properly
        throw new RuntimeException("number of virtual goal items is not equal to one (found " + l_virtual_items.size() + ")");
    }
    return l_virtual_items.get(0);
}
/**
 * Builds the 1-best tree as a hypergraph. The 1-best is ranked by the split
 * hypergraph, but the returned hypergraph is in the form of the original one.
 *
 * @param original_hg the original hypergraph
 * @param g_tbl_split_virtual_items table from original nodes to their virtual items
 * @return a new hypergraph rooted at a copy of the goal node that keeps only the best deduction
 */
public HyperGraph get_1best_tree_hg(HyperGraph original_hg, HashMap<HGNode, ArrayList<VirtualItem>> g_tbl_split_virtual_items){
    VirtualItem virtualGoal = get_virtual_goal_item(original_hg, g_tbl_split_virtual_items);
    HGNode goalCopy = clone_item_with_best_deduction(virtualGoal);
    //TODO: number of items/deductions
    HyperGraph oneBest = new HyperGraph(goalCopy, -1, -1, original_hg.sentID, original_hg.sentLen);
    // Recursively replace the antecedents below the copied goal node.
    get_1best_tree_item(virtualGoal, goalCopy);
    return oneBest;
}
/**
 * Recursively rewires the best hyperedge of onebest_item so that each of its
 * antecedent slots points at a fresh copy of the corresponding antecedent
 * virtual item's best deduction.
 *
 * @param virtual_it the virtual item whose best deduction drives the recursion
 * @param onebest_item the copied node whose best hyperedge is updated in place
 */
private void get_1best_tree_item(VirtualItem virtual_it, HGNode onebest_item){
    ArrayList<VirtualItem> antecedents = virtual_it.best_virtual_deduction.l_ant_virtual_items;
    if (antecedents == null) {
        return; // no antecedents to descend into
    }
    int position = 0;
    for (VirtualItem antecedent : antecedents) {
        HGNode antecedentCopy = clone_item_with_best_deduction(antecedent);
        onebest_item.bestHyperedge.getAntNodes().set(position, antecedentCopy);
        get_1best_tree_item(antecedent, antecedentCopy);
        position++;
    }
}
//TODO: tbl_states
/**
 * Creates a new node with the span, left-hand side and DP states of the
 * original node behind virtual_it, whose only (and best) hyperedge is a copy
 * of the virtual item's best deduction.
 *
 * @param virtual_it the virtual item to copy from
 * @return the newly created node
 */
private static HGNode clone_item_with_best_deduction(VirtualItem virtual_it){
    HGNode source = virtual_it.p_item;
    HyperEdge bestEdgeCopy = clone_deduction(virtual_it.best_virtual_deduction);
    ArrayList<HyperEdge> edges = new ArrayList<HyperEdge>();
    edges.add(bestEdgeCopy);
    return new HGNode(source.i, source.j, source.lhs, edges, bestEdgeCopy, source.getDPStates());
}
/**
 * Copies the original hyperedge behind a virtual deduction. The antecedent
 * list is copied into a new list (or left null when the original has none),
 * because get_1best_tree_item later overwrites its entries.
 *
 * @param virtual_dt the virtual deduction whose original hyperedge is copied
 * @return a new hyperedge with the same rule, scores and source path
 */
private static HyperEdge clone_deduction(VirtualDeduction virtual_dt){
    HyperEdge source = virtual_dt.p_dt;
    List<HGNode> sourceAntecedents = source.getAntNodes();
    ArrayList<HGNode> antecedentsCopy =
        (sourceAntecedents == null) ? null : new ArrayList<HGNode>(sourceAntecedents);
    return new HyperEdge(
        source.getRule(),
        source.bestDerivationLogP,
        source.getTransitionLogP(false),
        antecedentsCopy,
        source.getSourcePath());
}
// ############### split hg #####
/**
 * Splits the given hypergraph: resets the counters and the node-to-virtual-items
 * table, then recursively splits every node reachable from the goal node.
 *
 * @param hg the hypergraph to split
 */
public void split_hg(HyperGraph hg){
    //TODO: more pre-process in the extended class
    // Start from a clean slate so results of a previous call do not leak in.
    g_num_virtual_deductions = 0;
    g_num_virtual_items = 0;
    g_tbl_split_virtual_items.clear();

    split_item(hg.goalNode);
}
/**
 * For one original node, computes its list of virtual items and stores it in
 * g_tbl_split_virtual_items. Nodes that already have an entry are skipped.
 *
 * @param it the original node to split
 */
private void split_item(HGNode it){
    if (g_tbl_split_virtual_items.containsKey(it)) {
        return; // already processed
    }

    // Signature table shared by all hyperedges of this node.
    HashMap<String, VirtualItem> virtual_item_sigs = new HashMap<String, VirtualItem>();

    // Recurse on each hyperedge unless the subclass says the node can be skipped.
    if (speed_up_item(it)) {
        for (HyperEdge dt : it.hyperedges) {
            split_deduction(dt, virtual_item_sigs, it);
        }
    }

    // The virtual items that result from splitting this node.
    ArrayList<VirtualItem> splitItems = new ArrayList<VirtualItem>(virtual_item_sigs.values());
    g_tbl_split_virtual_items.put(it, splitItems);
    g_num_virtual_items += splitItems.size();
}
/**
 * Splits one hyperedge: first splits each of its antecedent nodes, then
 * recombines the hyperedge over the antecedents' virtual items.
 *
 * @param cur_dt the hyperedge to process
 * @param virtual_item_sigs signature table of the parent node, updated by the recombination
 * @param parent_item the node that owns cur_dt
 */
private void split_deduction(HyperEdge cur_dt, HashMap<String, VirtualItem> virtual_item_sigs, HGNode parent_item){
    if (!speed_up_deduction(cur_dt)) {
        return; // no need to continue
    }

    // Recursively split all antecedent nodes so each has its virtual-item list.
    List<HGNode> antecedents = cur_dt.getAntNodes();
    if (antecedents != null) {
        for (HGNode antecedent : antecedents) {
            split_item(antecedent);
        }
    }

    // Recombine the hyperedge.
    redo_combine(cur_dt, virtual_item_sigs, parent_item);
}
/**
 * Re-applies a hyperedge over every combination of its antecedents' virtual
 * items. A null antecedent list is the axiom case; lists of size one and two
 * are enumerated; any other size is rejected.
 *
 * @param cur_dt the hyperedge to recombine
 * @param virtual_item_sigs signature table of the parent node
 * @param parent_item the node that owns cur_dt
 * @throws RuntimeException if the antecedent list is non-null and its size is neither one nor two
 */
private void redo_combine(HyperEdge cur_dt, HashMap<String, VirtualItem> virtual_item_sigs, HGNode parent_item){
    List<HGNode> antecedents = cur_dt.getAntNodes();

    // axiom case: no nonterminal
    if (antecedents == null) {
        process_one_combination_axiom(parent_item, virtual_item_sigs, cur_dt);
        return;
    }

    switch (antecedents.size()) {
    case 1: {
        // arity: one
        ArrayList<VirtualItem> candidates = g_tbl_split_virtual_items.get(antecedents.get(0));
        for (VirtualItem candidate : candidates) {
            ArrayList<VirtualItem> combination = new ArrayList<VirtualItem>();
            combination.add(candidate);
            process_one_combination_nonaxiom(parent_item, virtual_item_sigs, cur_dt, combination);
        }
        break;
    }
    case 2: {
        // arity: two -- enumerate the cross product of both antecedents' virtual items
        HGNode first = antecedents.get(0);
        HGNode second = antecedents.get(1);
        ArrayList<VirtualItem> firstCandidates = g_tbl_split_virtual_items.get(first);
        ArrayList<VirtualItem> secondCandidates = g_tbl_split_virtual_items.get(second);
        for (VirtualItem left : firstCandidates) {
            for (VirtualItem right : secondCandidates) {
                ArrayList<VirtualItem> combination = new ArrayList<VirtualItem>();
                combination.add(left);
                combination.add(right);
                process_one_combination_nonaxiom(parent_item, virtual_item_sigs, cur_dt, combination);
            }
        }
        break;
    }
    default:
        throw new RuntimeException("Sorry, we can only deal with rules with at most TWO non-terminals");
    }
}
/**
 * Registers a virtual deduction under the virtual item that matches the given
 * DP state's signature, creating that virtual item if it does not exist yet.
 * This function should be called by process_one_combination_axiom and
 * process_one_combination_nonaxiom.
 *
 * @param parent_item the original node the deduction belongs to
 * @param virtual_item_sigs signature table specific to parent_item
 * @param t_ded the virtual deduction to add; must not be null
 * @param dpstate the DP state that determines the signature
 * @param maintain_onebest_only passed through to the virtual item
 * @throws RuntimeException if t_ded is null
 */
protected void add_deduction(HGNode parent_item, HashMap<String, VirtualItem> virtual_item_sigs, VirtualDeduction t_ded, DPState dpstate, boolean maintain_onebest_only){
    if (t_ded == null) {
        throw new RuntimeException("deduction is null");
    }

    String signature = VirtualItem.get_signature(parent_item, dpstate);
    VirtualItem existing = virtual_item_sigs.get(signature);
    if (existing == null) {
        // First deduction with this signature: start a new virtual item.
        VirtualItem created = new VirtualItem(parent_item, dpstate, t_ded, maintain_onebest_only);
        virtual_item_sigs.put(signature, created);
        return;
    }
    existing.add_deduction(t_ded, dpstate, maintain_onebest_only);
}
/**
 * Pruning hook consulted by split_item() before a node's hyperedges are
 * processed. When it returns false the node's hyperedges are not visited and
 * the node is recorded with an empty list of virtual items. This default
 * implementation never prunes.
 *
 * @param it the original node about to be split
 * @return false if we can skip the item; true otherwise
 */
protected boolean speed_up_item(HGNode it){
return true;//e.g., if the lm state is not valid, then no need to continue
}
/**
 * Pruning hook consulted by split_deduction() before a hyperedge is processed.
 * When it returns false the hyperedge's antecedents are not split on its
 * behalf and the hyperedge is not recombined. This default implementation
 * never prunes.
 *
 * @param dt the hyperedge about to be processed
 * @return false if we can skip the deduction; true otherwise
 */
protected boolean speed_up_deduction(HyperEdge dt){
return true;// if the rule state is not valid, then no need to continue
}
/**
 * Dynamic-programming state attached to a virtual item. Its signature string
 * is what VirtualItem.get_signature() returns, and add_deduction() uses that
 * string as the key when grouping deductions of one node into virtual items.
 */
protected abstract static class DPState {
/** @return the signature string for this state */
protected abstract String get_signature();
};
/**
 * One of the items obtained by splitting an original node.
 * In general, the variables of an item are:
 * (1) list of hyperedges
 * (2) best hyperedge
 * (3) DP state
 * (4) signature (operated on part/full of DP state)
 */
protected static class VirtualItem {
    /** Pointer to the true item. */
    HGNode p_item = null;
    /** All deductions added so far; stays null while only the one-best is maintained. */
    ArrayList<VirtualDeduction> l_virtual_deductions = null;
    /** The deduction with the lowest best_cost seen so far. */
    VirtualDeduction best_virtual_deduction = null;
    /** Dynamic programming state: not all the variables in dp_state are in the signature. */
    DPState dp_state;

    /**
     * Creates a virtual item for the given original node and seeds it with a first deduction.
     *
     * @param item the original node
     * @param dstate the DP state that comes with the first deduction
     * @param fdt the first deduction
     * @param maintain_onebest_only if true, the deduction list is not maintained
     */
    public VirtualItem(HGNode item, DPState dstate, VirtualDeduction fdt, boolean maintain_onebest_only){
        p_item = item;
        add_deduction(fdt, dstate, maintain_onebest_only);
    }

    /**
     * Adds a deduction. Unless only the one-best is maintained, the deduction
     * is appended to the deduction list. If there is no best deduction yet, or
     * the new one has a strictly lower best_cost, it becomes the best deduction
     * and its DP state replaces the current one.
     *
     * @param fdt the deduction to add
     * @param dstate the DP state that comes with the deduction
     * @param maintain_onebest_only if true, the deduction list is not maintained
     */
    public void add_deduction(VirtualDeduction fdt, DPState dstate, boolean maintain_onebest_only){
        if (!maintain_onebest_only) {
            if (l_virtual_deductions == null) {
                l_virtual_deductions = new ArrayList<VirtualDeduction>();
            }
            l_virtual_deductions.add(fdt);
        }

        boolean isNewBest =
            (best_virtual_deduction == null) || (fdt.best_cost < best_virtual_deduction.best_cost);
        if (isNewBest) {
            dp_state = dstate;
            best_virtual_deduction = fdt;
        }
    }

    /**
     * Returns the signature of this item; not all the variables in dp_state are in the signature.
     *
     * @return the signature computed from this item's node and DP state
     */
    public String get_signature(){
        return get_signature(p_item, dp_state);
    }

    /**
     * Computes the signature for a node / DP-state pair. Currently the
     * signature is exactly the DP state's signature.
     *
     * @param item the original node (not used at present)
     * @param dstate the DP state supplying the signature
     * @return the DP state's signature
     */
    public static String get_signature(HGNode item, DPState dstate){
        //TODO: the item itself is not part of the signature
        return dstate.get_signature();
    }
}
/**
 * A deduction in the split hypergraph: an original hyperedge paired with one
 * specific choice of antecedent virtual items and the resulting best cost.
 */
protected static class VirtualDeduction {
    /** Pointer to the true deduction. */
    HyperEdge p_dt = null;
    /** Antecedent virtual items, or null when there are none. */
    ArrayList<VirtualItem> l_ant_virtual_items = null;
    /**
     * The 1-best cost of all possible derivations:
     * best costs of ant items + non_stateless_transition_cost + r.statelesscost.
     */
    double best_cost = Double.POSITIVE_INFINITY;

    /**
     * @param dt the original hyperedge
     * @param ant_items the antecedent virtual items, or null
     * @param best_cost_in the best cost of this deduction
     */
    public VirtualDeduction(HyperEdge dt, ArrayList<VirtualItem> ant_items, double best_cost_in){
        p_dt = dt;
        l_ant_virtual_items = ant_items;
        best_cost = best_cost_in;
    }

    /**
     * Returns best_cost minus the best cost of each antecedent virtual item's
     * best deduction. Note: transition_cost is already linearly interpolated.
     *
     * @return the cost contributed by this deduction alone
     */
    public double get_transition_cost(){
        if (l_ant_virtual_items == null) {
            return best_cost;
        }
        double remaining = best_cost;
        for (VirtualItem antecedent : l_ant_virtual_items) {
            remaining -= antecedent.best_virtual_deduction.best_cost;
        }
        return remaining;
    }
}
}