/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.oracle;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.Support;
import joshua.decoder.ff.state_maintenance.NgramDPState;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.KBestExtractor;
import joshua.decoder.hypergraph.ViterbiExtractor;
import joshua.util.FileUtility;
/**
* approximated BLEU
* (1) do not consider clipping effect
* (2) in the dynamic programming, do not maintain different states
* for different hyp length
* (3) brevity penalty is calculated based on the avg ref length
* (4) using sentence-level BLEU, instead of doc-level BLEU
*
* @author Zhifei Li, <zhifei.work@gmail.com> (Johns Hopkins University)
* @version $LastChangedDate: 2010-02-08 13:03:13 -0600 (Mon, 08 Feb 2010) $
*/
public class OracleExtractionHG extends SplitHg {
// Surface strings of the placeholder symbols used for equivalent LM states.
// The matching *_ID fields are assigned in the constructor through
// SymbolTable.addTerminal(...).
static String BACKOFF_LEFT_LM_STATE_SYM="<lzfbo>";
public int BACKOFF_LEFT_LM_STATE_SYM_ID;//used for equivalent state
static String NULL_LEFT_LM_STATE_SYM="<lzflnull>";
public int NULL_LEFT_LM_STATE_SYM_ID;//used for equivalent state
static String NULL_RIGHT_LM_STATE_SYM="<lzfrnull>";
public int NULL_RIGHT_LM_STATE_SYM_ID;//used for equivalent state
// int[] ref_sentence;//reference string (not tree)
// Source and reference sentence lengths; both are set per sentence by oracle_extract_hg.
protected int src_sent_len =0;
protected int ref_sent_len =0;
protected int g_lm_order=4; //only used in compute_state to decide whether this class derives the LM state itself
// When true, compute_state clips each n-gram's match count by its reference count.
static protected boolean do_local_ngram_clip =false;
// When true, DPStateOracle.get_signature() includes the hypothesis length.
static protected boolean maitain_length_state = false;
// Maximum n-gram order used for BLEU and for the left/right state width (order-1 words).
static protected int g_bleu_order=4;
// Enable replacing state words that cannot extend into a reference n-gram by null markers.
static boolean using_left_equiv_state = true;
static boolean using_right_equiv_state = true;
// Prefix/suffix lookup tables filled by setup_prefix_suffix_tbl from the reference sentence.
// Keys are space-separated symbol ids; suffix keys are stored right-most word first.
HashMap<String, Boolean> tbl_suffix = new HashMap<String, Boolean>();
HashMap<String, Boolean> tbl_prefix = new HashMap<String, Boolean>();
static PrefixGrammar grammar_prefix = new PrefixGrammar();//TODO
static PrefixGrammar grammar_suffix = new PrefixGrammar();//TODO
// Reference n-gram counts: key is the space-separated symbol ids of an n-gram,
// value is how often it occurs in the reference. Rebuilt for every sentence in oracle_extract_hg.
protected HashMap<String, Integer> tbl_ref_ngrams = new HashMap<String, Integer>();
static boolean always_maintain_seperate_lm_state = true; //if true: the virtual item maintains its own lm state regardless of whether lm_order>=g_bleu_order
SymbolTable p_symbolTable;
int lm_feat_id=0; //the baseline LM feature id
/**
 * Constructs a new object capable of extracting a tree
 * from a hypergraph that most closely matches a provided
 * oracle sentence.
 * <p>
 * It seems that the symbol table here should only need to
 * represent monolingual terminals, plus nonterminals.
 * <p>
 * The three placeholder symbols (back-off, left-null, right-null) are
 * registered as terminals in the given table and their ids are cached
 * in the corresponding <code>*_ID</code> fields.
 *
 * @param symbolTable table used to map words to ids; the placeholder
 *                    symbols are added to it
 * @param lm_feat_id_ id of the baseline language-model feature, used to
 *                    look up the LM state stored on hypergraph nodes
 */
public OracleExtractionHG(SymbolTable symbolTable, int lm_feat_id_){
    this.p_symbolTable = symbolTable;
    this.lm_feat_id = lm_feat_id_;
    this.BACKOFF_LEFT_LM_STATE_SYM_ID = p_symbolTable.addTerminal(BACKOFF_LEFT_LM_STATE_SYM);
    // The left-null marker must be registered under its own symbol; it was
    // previously registered with the right-null string, which made both ids
    // identical and left NULL_LEFT_LM_STATE_SYM unused.
    this.NULL_LEFT_LM_STATE_SYM_ID = p_symbolTable.addTerminal(NULL_LEFT_LM_STATE_SYM);
    this.NULL_RIGHT_LM_STATE_SYM_ID = p_symbolTable.addTerminal(NULL_RIGHT_LM_STATE_SYM);
}
/*for 919 sent, time_on_reading: 148797
time_on_orc_extract: 580286*/
/**
 * Command-line driver: reads hypergraphs from disk together with one
 * reference sentence per line, extracts the oracle translation for each
 * sentence (either from the top-N list or from the whole hypergraph) and
 * writes one oracle sentence per line to the output file.
 *
 * @param args exactly six arguments: f_hypergraphs f_rule_tbl f_ref_files
 *             f_orc_out lm_order orc_extract_nbest
 * @throws IOException if reading the inputs or writing the output fails
 */
public static void main(String[] args) throws IOException {
    /*String f_hypergraphs="C:\\Users\\zli\\Documents\\mt03.src.txt.ss.nbest.hg.items";
    String f_rule_tbl="C:\\Users\\zli\\Documents\\mt03.src.txt.ss.nbest.hg.rules";
    String f_ref_files="C:\\Users\\zli\\Documents\\mt03.ref.txt.1";
    String f_orc_out ="C:\\Users\\zli\\Documents\\mt03.orc.txt";*/
    if (6 != args.length) {
        System.out.println("Usage: java Decoder f_hypergraphs f_rule_tbl f_ref_files f_orc_out lm_order orc_extract_nbest");
        System.out.println("num of args is "+ args.length);
        for (int i = 0; i < args.length; i++) {
            System.out.println("arg is: " + args[i]);
        }
        System.exit(1);
    }
    String f_hypergraphs = args[0].trim();
    String f_rule_tbl = args[1].trim();
    String f_ref_files = args[2].trim();
    String f_orc_out = args[3].trim();
    int lm_order = Integer.parseInt(args[4].trim());
    boolean orc_extract_nbest = Boolean.valueOf(args[5].trim()); // oracle extraction from nbest or hg

    //??????????????????????????????????????
    int baseline_lm_feat_id = 0;
    //??????????????????????????????????????

    SymbolTable p_symbolTable = new BuildinSymbol(null);
    KBestExtractor kbest_extractor = null;
    int topN = 300;//TODO
    boolean extract_unique_nbest = true;//TODO
    boolean do_ngram_clip_nbest = true; //TODO
    if (orc_extract_nbest) {
        System.out.println("oracle extraction from nbest list");
        kbest_extractor = new KBestExtractor(p_symbolTable, extract_unique_nbest, false, false, false, false, true);
    }

    long start_time0;
    long time_on_reading = 0;
    long time_on_orc_extract = 0;

    BufferedWriter orc_out = null;
    BufferedReader t_reader_ref = null;
    try {
        orc_out = FileUtility.getWriteFileStream(f_orc_out);
        start_time0 = System.currentTimeMillis();
        t_reader_ref = FileUtility.getReadFileStream(f_ref_files);

        DiskHyperGraph dhg_read = new DiskHyperGraph(p_symbolTable, baseline_lm_feat_id, true, null);
        dhg_read.initRead(f_hypergraphs, f_rule_tbl, null);

        OracleExtractionHG orc_extractor = new OracleExtractionHG(p_symbolTable, baseline_lm_feat_id);
        String ref_sent = null;
        long start_time;
        int sent_id = 0;
        while ((ref_sent = FileUtility.read_line_lzf(t_reader_ref)) != null) {
            System.out.println("############Process sentence " + sent_id);
            start_time = System.currentTimeMillis();
            sent_id++;
            //if(sent_id>10)break;

            HyperGraph hg = dhg_read.readHyperGraph();
            if (hg == null) continue;
            String orc_sent = null;
            double orc_bleu = 0;

            //System.out.println("read disk hyp: " + (System.currentTimeMillis()-start_time));
            time_on_reading += System.currentTimeMillis() - start_time;
            start_time = System.currentTimeMillis();

            if (orc_extract_nbest) {
                Object[] res = orc_extractor.oracle_extract_nbest(kbest_extractor, hg, topN, do_ngram_clip_nbest, ref_sent);
                // oracle_extract_nbest returns null when the hypergraph has no goal
                // node; keep the defaults instead of dereferencing null so that one
                // output line is still written for this reference.
                if (res != null) {
                    orc_sent = (String) res[0];
                    orc_bleu = (Double) res[1];
                }
                // Account for n-best extraction time too (previously only the
                // hypergraph mode contributed to this counter).
                time_on_orc_extract += System.currentTimeMillis() - start_time;
            } else {
                HyperGraph hg_oracle = orc_extractor.oracle_extract_hg(hg, hg.sentLen, lm_order, ref_sent);
                orc_sent = ViterbiExtractor.extractViterbiString(p_symbolTable, hg_oracle.goalNode);
                orc_bleu = orc_extractor.get_best_goal_cost(hg, orc_extractor.g_tbl_split_virtual_items);

                time_on_orc_extract += System.currentTimeMillis() - start_time;
                System.out.println("num_virtual_items: " + orc_extractor.g_num_virtual_items + " num_virtual_dts: " + orc_extractor.g_num_virtual_deductions);
                //System.out.println("oracle extract: " + (System.currentTimeMillis()-start_time));
            }

            orc_out.write(orc_sent+"\n");
            System.out.println("orc bleu is " + orc_bleu);
        }
    } finally {
        // Close both streams even when extraction fails part-way; the nested
        // finally guarantees the writer is closed (and flushed) even if closing
        // the reader throws.
        try {
            if (t_reader_ref != null) t_reader_ref.close();
        } finally {
            if (orc_out != null) orc_out.close();
        }
    }

    System.out.println("time_on_reading: " + time_on_reading);
    System.out.println("time_on_orc_extract: " + time_on_orc_extract);
    System.out.println("total running time: "
        + (System.currentTimeMillis() - start_time0));
}
/**
 * Finds the oracle hypothesis in the n-best list: the hypothesis with the
 * highest sentence-level BLEU against the reference.
 *
 * @param kbest_extractor extractor used to enumerate hypotheses; its state is reset first
 * @param hg              hypergraph to enumerate from
 * @param n               maximum number of hypotheses to score
 * @param do_ngram_clip   whether n-gram counts are clipped by the reference counts
 * @param ref_sent        reference sentence
 * @return a two-element array {oracle sentence (String), oracle BLEU (Double)},
 *         or null when the hypergraph has no goal node
 */
public Object[] oracle_extract_nbest(KBestExtractor kbest_extractor, HyperGraph hg, int n, boolean do_ngram_clip, String ref_sent){
    if (null == hg.goalNode) {
        return null;
    }
    kbest_extractor.resetState();

    String best_sent = null;
    double best_bleu = -1;
    for (int rank = 1; ; rank++) {
        String candidate = kbest_extractor.getKthHyp(hg.goalNode, rank, -1, null, null);//?????????
        // stop when the list is exhausted or the requested depth is exceeded
        if (candidate == null || rank > n) {
            break;
        }
        double candidate_bleu = compute_sentence_bleu(this.p_symbolTable, ref_sent, candidate, do_ngram_clip, 4);
        if (candidate_bleu > best_bleu) {
            best_bleu = candidate_bleu;
            best_sent = candidate;
        }
    }

    System.out.println("Oracle sent: " + best_sent);
    System.out.println("Oracle bleu: " + best_bleu);

    Object[] result = new Object[2];
    result[0] = best_sent;
    result[1] = best_bleu;
    return result;
}
/**
 * Extracts the oracle derivation from a hypergraph for one sentence.
 * <p>
 * Per-sentence state (reference n-gram counts, prefix/suffix tables and
 * grammars, sentence lengths, LM order) is reset and rebuilt from the
 * reference before the hypergraph is split.
 *
 * @param hg              hypergraph to search
 * @param src_sent_len_in length of the source sentence
 * @param lm_order        order of the language model used by the decoder
 * @param ref_sent_str    whitespace-separated reference sentence
 * @return the result of get_1best_tree_hg on the split hypergraph
 */
public HyperGraph oracle_extract_hg(HyperGraph hg, int src_sent_len_in, int lm_order,
    String ref_sent_str)
{
    // trim first: a leading blank would otherwise produce an empty first token
    int[] ref_sent = this.p_symbolTable.addTerminals(ref_sent_str.trim().split("\\s+"));
    g_lm_order = lm_order;
    src_sent_len = src_sent_len_in;
    ref_sent_len = ref_sent.length;

    tbl_ref_ngrams.clear();
    get_ngrams(tbl_ref_ngrams, g_bleu_order, ref_sent, false);
    if (using_left_equiv_state || using_right_equiv_state) {
        tbl_prefix.clear();
        tbl_suffix.clear();
        setup_prefix_suffix_tbl(ref_sent, g_bleu_order, tbl_prefix, tbl_suffix);
        // The grammars are static and were never emptied, so they kept the
        // n-grams of every earlier reference; reset them like the tables above.
        grammar_prefix.root.clear();
        grammar_suffix.root.clear();
        setup_prefix_suffix_grammar(ref_sent, g_bleu_order, grammar_prefix, grammar_suffix);//TODO
    }
    split_hg(hg);

    //System.out.println("best bleu is " + get_best_goal_cost( hg, g_tbl_split_virtual_items));
    return get_1best_tree_hg(hg, g_tbl_split_virtual_items);
}
/* Handles a deduction that has no antecedent items:
 * (1) computes the oracle state and BLEU for the deduction
 * (2) wraps it in a new virtual deduction (cost is the negated BLEU)
 *     and registers it through add_deduction */
protected void process_one_combination_axiom(HGNode parent_item, HashMap<String, VirtualItem> virtual_item_sigs, HyperEdge cur_dt){
    if (cur_dt.getRule() == null) {
        throw new RuntimeException("error null rule in axiom");
    }

    // reference length prorated by the fraction of the source this item spans
    double avg_ref_len;
    if (parent_item.j - parent_item.i >= src_sent_len) {
        avg_ref_len = ref_sent_len;
    } else {
        avg_ref_len = (parent_item.j - parent_item.i) * ref_sent_len * 1.0 / src_sent_len;//avg len?
    }

    double[] bleu_holder = new double[1];
    DPStateOracle state = compute_state(parent_item, cur_dt, null, tbl_ref_ngrams,
        do_local_ngram_clip, g_lm_order, avg_ref_len, bleu_holder, tbl_suffix, tbl_prefix);
    VirtualDeduction virtual_dt = new VirtualDeduction(cur_dt, null, -bleu_holder[0]);//cost: -best_bleu
    g_num_virtual_deductions++;
    add_deduction(parent_item, virtual_item_sigs, virtual_dt, state, true);
}
/* Handles a deduction that has antecedent virtual items:
 * (1) creates a new virtual deduction from cur_dt and the antecedent virtual items
 * (2) hands it to add_deduction together with its oracle state; virtual_item_sigs
 *     is the signature map specific to parent_item
 **/
protected void process_one_combination_nonaxiom(HGNode parent_item, HashMap<String, VirtualItem> virtual_item_sigs, HyperEdge cur_dt, ArrayList<VirtualItem> l_ant_virtual_item){
    if (l_ant_virtual_item == null) {
        throw new RuntimeException("wrong call in process_one_combination_nonaxiom");
    }

    // reference length prorated by the fraction of the source this item spans
    double avg_ref_len;
    if (parent_item.j - parent_item.i >= src_sent_len) {
        avg_ref_len = ref_sent_len;
    } else {
        avg_ref_len = (parent_item.j - parent_item.i) * ref_sent_len * 1.0 / src_sent_len;//avg len?
    }

    double[] bleu_holder = new double[1];
    DPStateOracle state = compute_state(parent_item, cur_dt, l_ant_virtual_item, tbl_ref_ngrams,
        do_local_ngram_clip, g_lm_order, avg_ref_len, bleu_holder, tbl_suffix, tbl_prefix);
    VirtualDeduction virtual_dt = new VirtualDeduction(cur_dt, l_ant_virtual_item, -bleu_holder[0]);//cost: -best_bleu
    g_num_virtual_deductions++;
    add_deduction(parent_item, virtual_item_sigs, virtual_dt, state, true);
}
/**
 * Dynamic-programming state kept at a virtual item: hypothesis length,
 * per-order n-gram match counts, and the left/right LM state words.
 * All three arrays may be null (the goal item is created that way).
 */
protected static class DPStateOracle extends DPState {
    int best_len; //this may not be used in the signature
    int[] ngram_matches;
    int[] left_lm_state;
    int[] right_lm_state;

    public DPStateOracle(int blen, int[] matches, int[] left, int[] right){
        best_len = blen;
        ngram_matches = matches;
        left_lm_state = left;
        right_lm_state = right;
    }

    /**
     * Builds the signature string: optional length, left state words,
     * the separator "lzf ", then right state words; every number is
     * followed by a single space.
     */
    protected String get_signature() {
        StringBuilder res = new StringBuilder();
        if (maitain_length_state) {
            res.append(best_len);
            res.append(' ');
        }
        if (null != left_lm_state) { // goal-item have null state
            for (int i = 0; i < left_lm_state.length; i++) {
                res.append(left_lm_state[i]);
                res.append(' ');
            }
        }
        res.append("lzf ");

        if (null != right_lm_state) { // goal-item have null state
            for (int i = 0; i < right_lm_state.length; i++) {
                res.append(right_lm_state[i]);
                res.append(' ');
            }
        }
        //if(left_lm_state==null || right_lm_state==null)System.out.println("sig is: " + res.toString());
        return res.toString();
    }

    /** Prints the length and n-gram match counts to standard output. */
    protected void print(){
        StringBuilder res = new StringBuilder();
        res.append("DPstate: best_len: ");
        res.append(best_len);
        // the goal-item state carries no match counts; skip them rather than fail
        if (null != ngram_matches) {
            for (int i = 0; i < ngram_matches.length; i++) {
                res.append("; ngram: ");
                res.append(ngram_matches[i]);
            }
        }
        System.out.println(res.toString());
    }
}
// ########################## common functions #####################
/**
 * Computes the oracle DP state and BLEU score obtained by applying
 * deduction <code>dt</code> on top of the given antecedent virtual items.
 * <p>
 * For a deduction under the goal item (no rule) the score of the single
 * antecedent is passed through and an empty state is returned. Otherwise
 * the hypothesis length and n-gram match counts of the antecedents are
 * summed, the n-grams newly created by the combination are matched
 * against the reference, and new left/right LM states are derived.
 *
 * @param parent_item          hypergraph node the deduction belongs to
 * @param dt                   the deduction (hyperedge) being applied
 * @param l_ant_virtual_item   antecedent virtual items, indexed by target nonterminal index
 * @param tbl_ref_ngrams       reference n-gram counts
 * @param do_local_ngram_clip  clip each new n-gram count by its reference count
 * @param lm_order             order of the decoder's language model
 * @param ref_len              (prorated) reference length used for the brevity penalty
 * @param bleu_score           output: element 0 receives the BLEU score
 * @param tbl_suffix           suffix table used for the left equivalent state
 * @param tbl_prefix           prefix table used for the right equivalent state
 * @return the new DP state
 * @throws RuntimeException if a goal deduction does not have exactly one antecedent,
 *         or if an n-gram would get a negative count
 */
protected DPStateOracle compute_state(
    HGNode parent_item, HyperEdge dt, ArrayList<VirtualItem> l_ant_virtual_item, HashMap<String,Integer> tbl_ref_ngrams,
    boolean do_local_ngram_clip, int lm_order, double ref_len, double[] bleu_score, HashMap<String, Boolean> tbl_suffix, HashMap<String, Boolean> tbl_prefix
) {
    //##### deductions under "goal item" does not have rule
    if (null == dt.getRule()) {
        if (l_ant_virtual_item.size() != 1) {
            throw new RuntimeException("error deduction under goal item have more than one item");
        }
        bleu_score[0] = -l_ant_virtual_item.get(0).best_virtual_deduction.best_cost;
        return new DPStateOracle(0, null, null,null); // no DPState at all
    }

    //################## deductions *not* under "goal item"
    HashMap<String, Integer> new_ngram_counts = new HashMap<String, Integer>();//new ngrams created due to the combination
    HashMap<String, Integer> old_ngram_counts = new HashMap<String, Integer>();//the ngram that has already been computed
    int total_hyp_len = 0;
    int[] num_ngram_match = new int[g_bleu_order];
    int[] en_words = dt.getRule().getEnglish();

    //####calculate new and old ngram counts, and len
    ArrayList<Integer> words = new ArrayList<Integer>();

    // used for compute left- and right- lm state
    ArrayList<Integer> left_state_sequence = null;
    // used for compute left- and right- lm state
    ArrayList<Integer> right_state_sequence = null;

    int correct_lm_order = lm_order;
    if (always_maintain_seperate_lm_state || lm_order < g_bleu_order) {
        left_state_sequence = new ArrayList<Integer>();
        right_state_sequence = new ArrayList<Integer>();
        correct_lm_order = g_bleu_order; // if lm_order is smaller than g_bleu_order, we will get the lm state by ourself
    }

    //#### get left_state_sequence, right_state_sequence, total_hyp_len, num_ngram_match
    for (int c = 0; c < en_words.length; c++) {
        int c_id = en_words[c];
        if (this.p_symbolTable.isNonterminal(c_id)) {
            int index = this.p_symbolTable.getTargetNonterminalIndex(c_id);
            DPStateOracle ant_state = (DPStateOracle) l_ant_virtual_item.get(index).dp_state;
            total_hyp_len += ant_state.best_len;
            for (int t = 0; t < g_bleu_order; t++) {
                num_ngram_match[t] += ant_state.ngram_matches[t];
            }
            int[] l_context = ant_state.left_lm_state;
            int[] r_context = ant_state.right_lm_state;
            for (int t : l_context) { // always have l_context
                words.add(t);
                if (null != left_state_sequence
                && left_state_sequence.size() < g_bleu_order-1) {
                    left_state_sequence.add(t);
                }
            }
            // n-grams inside the antecedent's own context were already counted there
            get_ngrams(old_ngram_counts, g_bleu_order, l_context, true);
            if (r_context.length >= correct_lm_order-1) { // the right and left are NOT overlapping
                get_ngrams(new_ngram_counts, g_bleu_order, words, true);
                get_ngrams(old_ngram_counts, g_bleu_order, r_context, true);
                words.clear();//start a new chunk
                if (null != right_state_sequence) {
                    right_state_sequence.clear();
                }
                for (int t : r_context) {
                    words.add(t);
                }
            }
            if (null != right_state_sequence) {
                for (int t : r_context) {
                    right_state_sequence.add(t);
                }
            }
        } else {
            words.add(c_id);
            total_hyp_len += 1;
            if (null != left_state_sequence
            && left_state_sequence.size() < g_bleu_order-1) {
                left_state_sequence.add(c_id);
            }
            if (null != right_state_sequence) {
                right_state_sequence.add(c_id);
            }
        }
    }
    get_ngrams(new_ngram_counts, g_bleu_order, words, true);

    //####now deduct ngram counts
    for (String ngram : new_ngram_counts.keySet()) {
        if (tbl_ref_ngrams.containsKey(ngram)) {
            int final_count = new_ngram_counts.get(ngram);
            if (old_ngram_counts.containsKey(ngram)) {
                final_count -= old_ngram_counts.get(ngram);
                if (final_count < 0) {
                    // report the offending n-gram itself (as space-separated symbol
                    // ids) instead of looking up an unrelated hard-coded symbol id
                    throw new RuntimeException("negative count for ngram: "
                        + ngram
                        + "; new: " + new_ngram_counts.get(ngram)
                        + "; old: " + old_ngram_counts.get(ngram) );
                }
            }
            if (final_count > 0) { // TODO: not correct/global ngram clip
                // n-gram keys are space-separated ids, so the token count is the order
                // BUG: use joshua.util.Regex.spaces.split(...)
                int order_index = ngram.split("\\s+").length-1;
                if (do_local_ngram_clip) {
                    num_ngram_match[order_index] +=
                        Support.findMin(final_count, (Integer)tbl_ref_ngrams.get(ngram));
                } else {
                    num_ngram_match[order_index] += final_count; //do not do any clipping
                }
            }
        }
    }

    //####now calculate the BLEU score and state
    int[] left_lm_state = null;
    int[] right_lm_state = null;
    if (!always_maintain_seperate_lm_state && lm_order >= g_bleu_order) { //do not need to change lm state, just use original lm state
        NgramDPState state = (NgramDPState) parent_item.getDPState(this.lm_feat_id);
        left_lm_state = intListToArray( state.getLeftLMStateWords() );
        right_lm_state = intListToArray( state.getRightLMStateWords() );
    } else {
        left_lm_state = get_left_equiv_state(left_state_sequence, tbl_suffix);
        right_lm_state = get_right_equiv_state(right_state_sequence, tbl_prefix);

        //debug
        //System.out.println("lm_order is " + lm_order);
        //compare_two_int_arrays(left_lm_state, (int[])parent_item.tbl_states.get(Symbol.LM_L_STATE_SYM_ID));
        //compare_two_int_arrays(right_lm_state, (int[])parent_item.tbl_states.get(Symbol.LM_R_STATE_SYM_ID));
        //end
    }
    bleu_score[0] = compute_bleu(total_hyp_len, ref_len, num_ngram_match, g_bleu_order);
    //System.out.println("blue score is " + bleu_score[0]);
    return new DPStateOracle(total_hyp_len, num_ngram_match, left_lm_state, right_lm_state);
}
/**
 * Copies a list of boxed integers into a primitive array of the same
 * length and order.
 */
private int[] intListToArray(List<Integer> words){
    int[] result = new int[words.size()];
    int next = 0;
    for (Integer word : words) {
        result[next] = word;
        next++;
    }
    return result;
}
/**
 * Derives the left LM state from the first (at most g_bleu_order-1) words
 * of the sequence. When equivalent states are enabled and the state is
 * full-width, trailing words are replaced by the left-null marker until
 * the remaining words [0..i] are found in the suffix table.
 */
private int[] get_left_equiv_state(ArrayList<Integer> left_state_sequence,
    HashMap<String, Boolean> tbl_suffix)
{
    int width = g_bleu_order - 1;
    int l_size = Math.min(left_state_sequence.size(), width);
    int[] state = new int[l_size];

    if (using_left_equiv_state && l_size >= width) {
        // scan right to left, nulling positions until a known suffix is found
        int last = l_size - 1;
        while (last >= 0 && !is_a_suffix_in_tbl(left_state_sequence, 0, last, tbl_suffix)) {
            //while(!is_a_suffix_in_grammar(left_state_sequence, 0, last, grammar_suffix)){
            state[last] = this.NULL_LEFT_LM_STATE_SYM_ID;
            last--;
        }
        // keep the surviving words verbatim (nothing to copy when last == -1)
        for (int j = last; j >= 0; j--) {
            state[j] = left_state_sequence.get(j);
        }
        //System.out.println("origi left:" + Symbol.get_string(left_state_sequence) + "; equiv left:" + Symbol.get_string(left_lm_state));
    } else { // regular
        for (int j = 0; j < l_size; j++) {
            state[j] = left_state_sequence.get(j);
        }
    }
    return state;
}
/**
 * Tests whether the words [start_pos..end_pos], written right-most word
 * first and separated by single spaces, form a key of the suffix table.
 * Returns false immediately when the word at end_pos is the left-null marker.
 */
private boolean is_a_suffix_in_tbl(ArrayList<Integer> left_state_sequence,
    int start_pos, int end_pos, HashMap<String, Boolean> tbl_suffix)
{
    int right_most = left_state_sequence.get(end_pos);
    if (right_most == this.NULL_LEFT_LM_STATE_SYM_ID) {
        return false;
    }
    StringBuffer key = new StringBuffer();
    int pos = end_pos;
    while (pos >= start_pos) { // right-most first
        if (pos != end_pos) {
            key.append(' ');
        }
        key.append(left_state_sequence.get(pos));
        pos--;
    }
    return tbl_suffix.containsKey(key.toString());
}
// TODO: never called. remove?
/**
 * Grammar-based variant of the suffix test: looks up the words
 * [start_pos..end_pos], right-most first, in the suffix grammar.
 * Returns false immediately when the word at end_pos is the left-null marker.
 */
private boolean is_a_suffix_in_grammar(
    ArrayList<Integer> left_state_sequence,
    int start_pos, int end_pos, PrefixGrammar grammar_suffix)
{
    int right_most = left_state_sequence.get(end_pos);
    if (right_most == this.NULL_LEFT_LM_STATE_SYM_ID) {
        return false;
    }
    ArrayList<Integer> reversed = new ArrayList<Integer>();
    int pos = end_pos;
    while (pos >= start_pos) { // right-most first
        reversed.add(left_state_sequence.get(pos));
        pos--;
    }
    int last_index = reversed.size() - 1;
    return grammar_suffix.contain_ngram(reversed, 0, last_index);
}
/**
 * Derives the right LM state from the last (at most g_bleu_order-1) words
 * of the sequence. When equivalent states are enabled and the state is
 * full-width, leading words are replaced by the right-null marker until
 * the remaining words are found in the prefix table.
 */
private int[] get_right_equiv_state(
    ArrayList<Integer> right_state_sequence,
    HashMap<String, Boolean> tbl_prefix)
{
    int width = g_bleu_order - 1;
    int total = right_state_sequence.size();
    int r_size = Math.min(total, width);
    int offset = total - r_size; // index in the sequence of state position 0
    int[] state = new int[r_size];

    if (using_right_equiv_state && r_size >= width) {
        // scan left to right, nulling positions until a known prefix is found
        int first = 0;
        while (first < r_size
            && !is_a_prefix_in_tbl(right_state_sequence, offset + first, total - 1, tbl_prefix)) {
            //&& !is_a_prefix_in_grammar(right_state_sequence, offset + first, total - 1, grammar_prefix)){
            state[first] = this.NULL_RIGHT_LM_STATE_SYM_ID;
            first++;
        }
        // keep the surviving words verbatim (nothing to copy when first == r_size)
        for (int j = first; j < r_size; j++) {
            state[j] = right_state_sequence.get(offset + j);
        }
        //System.out.println("origi right:" + Symbol.get_string(right_state_sequence)+ "; equiv right:" + Symbol.get_string(right_lm_state));
    } else { // regular
        for (int j = 0; j < r_size; j++) {
            state[j] = right_state_sequence.get(offset + j);
        }
    }
    return state;
}
/**
 * Tests whether the words [start_pos..end_pos], in order and separated by
 * single spaces, form a key of the prefix table.
 * Returns false immediately when the word at start_pos is the right-null marker.
 */
private boolean is_a_prefix_in_tbl(ArrayList<Integer> right_state_sequence,
    int start_pos, int end_pos, HashMap<String, Boolean> tbl_prefix)
{
    int left_most = right_state_sequence.get(start_pos);
    if (left_most == this.NULL_RIGHT_LM_STATE_SYM_ID) {
        return false;
    }
    StringBuffer key = new StringBuffer();
    int pos = start_pos;
    while (pos <= end_pos) {
        if (pos != start_pos) {
            key.append(' ');
        }
        key.append(right_state_sequence.get(pos));
        pos++;
    }
    return tbl_prefix.containsKey(key.toString());
}
// TODO: never called. remove?
/**
 * Grammar-based variant of the prefix test: looks up the words
 * [start_pos..end_pos] in the prefix grammar, unless the word at
 * start_pos is the right-null marker.
 */
private boolean isAPrefixInGrammar(
    ArrayList<Integer> right_state_sequence,
    int start_pos, int end_pos, PrefixGrammar gr_prefix)
{
    int left_most = right_state_sequence.get(start_pos);
    boolean starts_with_null = (left_most == this.NULL_RIGHT_LM_STATE_SYM_ID);
    if (starts_with_null) {
        return false;
    }
    return gr_prefix.contain_ngram(right_state_sequence, start_pos, end_pos);
}
/**
 * Debug check that two arrays have the same length and the same
 * elements position by position.
 *
 * @throws RuntimeException if the lengths or any pair of elements differ
 */
public static void compare_two_int_arrays(int[] a, int[] b) {
    int len = a.length;
    if (len != b.length) {
        throw new RuntimeException("two arrays do not have same size");
    }
    int pos = 0;
    while (pos < len) {
        if (a[pos] != b[pos]) {
            throw new RuntimeException("elements in two arrays are not same");
        }
        pos++;
    }
}
/**
 * Sentence-level BLEU = bp * prec, where prec = exp(sum_t wt * log(prec[t]))
 * with wt = 1/bleu_order. Orders with zero matches are smoothed: each such
 * order halves a running smoothing factor that is used in place of the
 * match count.
 *
 * @param hyp_len         hypothesis length in words
 * @param ref_len         reference length (may be fractional)
 * @param num_ngram_match match counts, index t holds the (t+1)-gram matches
 * @param bleu_order      maximum n-gram order
 * @return the BLEU score
 * @throws RuntimeException if either length is not positive
 */
public static double compute_bleu(int hyp_len, double ref_len, int[] num_ngram_match, int bleu_order){
    if (hyp_len <= 0 || ref_len <= 0){
        throw new RuntimeException("ref or hyp is zero len");
    }
    double wt = 1.0/bleu_order;
    double log_prec = 0;
    double smooth_factor = 1.0;
    // orders longer than the hypothesis have no n-grams and are skipped
    int max_order = Math.min(bleu_order, hyp_len);
    for (int t = 0; t < max_order; t++) {
        double ratio;
        if (num_ngram_match[t] > 0) {
            ratio = num_ngram_match[t]*1.0/(hyp_len-t);
        } else {
            smooth_factor *= 0.5;//TODO
            ratio = smooth_factor/(hyp_len-t);
        }
        log_prec += wt*Math.log(ratio);
    }
    double bp;
    if (hyp_len >= ref_len) {
        bp = 1.0;
    } else {
        bp = Math.exp(1-ref_len/hyp_len);
    }
    //System.out.println("hyp_len: " + hyp_len + "; ref_len:" + ref_len + "prec: " + Math.exp(prec) + "; bp: " + bp + "; bleu: " + res);
    return bp*Math.exp(log_prec);
}
/**
 * Accumulates into tbl the counts of all n-grams of wrds up to the given
 * order. Keys are the word ids joined by single spaces.
 *
 * @param ignore_null_equiv_symbol when true, n-grams containing the
 *        left-null or right-null marker are not counted
 */
public void get_ngrams(HashMap<String,Integer> tbl, int order, int[] wrds, boolean ignore_null_equiv_symbol) {
    for (int start = 0; start < wrds.length; start++) {
        // grow the n-gram one word at a time: [start, end]
        StringBuffer ngram = new StringBuffer();
        int stop = Math.min(wrds.length, start + order);
        for (int end = start; end < stop; end++) {
            int wrd = wrds[end];
            boolean is_null = (wrd == this.NULL_LEFT_LM_STATE_SYM_ID
                || wrd == this.NULL_RIGHT_LM_STATE_SYM_ID);
            if (is_null && ignore_null_equiv_symbol) {
                // every longer n-gram from this start contains the marker too
                break;
            }
            if (end > start) {
                ngram.append(' ');
            }
            ngram.append(wrd);

            String ngram_str = ngram.toString();
            if (tbl.containsKey(ngram_str)) {
                tbl.put(ngram_str, tbl.get(ngram_str) + 1);
            } else {
                tbl.put(ngram_str, 1);
            }
        }
    }
}
/**
 * Accumulates into tbl the counts of all n-grams of wrds up to the given
 * order. Keys are the word ids joined by single spaces.
 *
 * @param ignore_null_equiv_symbol when true, n-grams containing the
 *        left-null or right-null marker are not counted
 */
public void get_ngrams(HashMap<String, Integer> tbl, int order,
    ArrayList<Integer> wrds, boolean ignore_null_equiv_symbol)
{
    int len = wrds.size();
    for (int start = 0; start < len; start++) {
        // grow the n-gram one word at a time: [start, end]
        StringBuffer ngram = new StringBuffer();
        int stop = Math.min(len, start + order);
        for (int end = start; end < stop; end++) {
            int wrd = wrds.get(end);
            boolean is_null = (wrd == this.NULL_LEFT_LM_STATE_SYM_ID
                || wrd == this.NULL_RIGHT_LM_STATE_SYM_ID);
            if (is_null && ignore_null_equiv_symbol) {
                // every longer n-gram from this start contains the marker too
                break;
            }
            if (end > start) {
                ngram.append(' ');
            }
            ngram.append(wrd);

            String ngram_str = ngram.toString();
            if (tbl.containsKey(ngram_str)) {
                tbl.put(ngram_str, tbl.get(ngram_str) + 1);
            } else {
                tbl.put(ngram_str, 1);
            }
        }
    }
}
/**
 * Computes sentence-level BLEU of a hypothesis string against a reference
 * string. Both strings are tokenized on whitespace and mapped to ids
 * through the given symbol table (unknown words are added to it).
 *
 * @param p_symbol      symbol table used to map tokens to ids
 * @param ref_sent      reference sentence
 * @param hyp_sent      hypothesis sentence
 * @param do_ngram_clip consider global n-gram clip
 * @param bleu_order    maximum n-gram order
 * @return the BLEU score
 */
public double compute_sentence_bleu(SymbolTable p_symbol, String ref_sent, String hyp_sent, boolean do_ngram_clip, int bleu_order) {
    // Trim before splitting: String.split keeps a leading empty token when the
    // input starts with whitespace, which would count as an extra word.
    // BUG: use joshua.util.Regex.spaces.split(...)
    int[] numeric_ref_sent = p_symbol.addTerminals(ref_sent.trim().split("\\s+"));
    int[] numeric_hyp_sent = p_symbol.addTerminals(hyp_sent.trim().split("\\s+"));
    return compute_sentence_bleu(numeric_ref_sent, numeric_hyp_sent, do_ngram_clip, bleu_order);
}
/**
 * Computes sentence-level BLEU of a hypothesis against a reference, both
 * given as arrays of symbol ids.
 *
 * @param ref_sent      reference word ids
 * @param hyp_sent      hypothesis word ids
 * @param do_ngram_clip when true, each hypothesis n-gram count is clipped
 *                      by its count in the reference
 * @param bleu_order    maximum n-gram order
 * @return the BLEU score
 */
public double compute_sentence_bleu(int[] ref_sent, int[] hyp_sent, boolean do_ngram_clip, int bleu_order) {
    // Collect n-grams up to the requested order. This used to be hard-coded
    // to 4, so a larger bleu_order made compute_bleu index past the match array.
    int order = Math.max(bleu_order, 0);
    HashMap<String, Integer> ref_ngram_tbl = new HashMap<String, Integer>();
    get_ngrams(ref_ngram_tbl, order, ref_sent, false);
    HashMap<String, Integer> hyp_ngram_tbl = new HashMap<String, Integer>();
    get_ngrams(hyp_ngram_tbl, order, hyp_sent, false);

    int[] num_ngram_match = new int[order];
    for (String ngram : hyp_ngram_tbl.keySet()) {
        if (ref_ngram_tbl.containsKey(ngram)) {
            // n-gram keys are space-separated ids, so the token count is the order
            // BUG: use joshua.util.Regex.spaces.split(...)
            int order_index = ngram.split("\\s+").length-1;
            if (do_ngram_clip) {
                num_ngram_match[order_index] += Support.findMin((Integer)ref_ngram_tbl.get(ngram),(Integer)hyp_ngram_tbl.get(ngram)); //ngram clip
            } else {
                num_ngram_match[order_index] += hyp_ngram_tbl.get(ngram);//without ngram count clipping
            }
        }
    }
    double res_bleu = compute_bleu(hyp_sent.length, ref_sent.length, num_ngram_match, bleu_order);
    //System.out.println("hyp_len: " + hyp_sent.length + "; ref_len:" + ref_sent.length + "; bleu: " + res_bleu +" num_ngram_matches: " + num_ngram_match[0] + " " +num_ngram_match[1]+
    // " " + num_ngram_match[2] + " " +num_ngram_match[3]);
    return res_bleu;
}
// TODO: never called, remove?
/** Prints every element of the state array on one line, each followed by " ---- ". */
private static void printState(Object[] state) {
    System.out.println("State is");
    for (Object element : state) {
        System.out.print(element + " ---- ");
    }
    System.out.println();
}
//#### equivalent lm stuff ############
/**
 * Fills the prefix and suffix tables from the word sequence. For every
 * n-gram [i, i+j] with j &lt; order, its proper prefixes [i, k] (k &lt; i+j)
 * are added to prefix_tbl and its proper suffixes [k, i+j] (k &gt; i) are
 * added to suffix_tbl. Keys are word ids joined by single spaces; suffix
 * keys are written right-most word first.
 */
public static void setup_prefix_suffix_tbl(int[] wrds, int order,
    HashMap<String, Boolean> prefix_tbl, HashMap<String, Boolean> suffix_tbl)
{
    for (int start = 0; start < wrds.length; start++) {
        int max_span = Math.min(order, wrds.length - start);
        for (int span = 0; span < max_span; span++) { // ngram: [start,start+span]
            int end = start + span;

            //### prefix: all ngrams [start,end-1]
            StringBuffer forward = new StringBuffer();
            int k = start;
            while (k < end) {
                forward.append(wrds[k]);
                prefix_tbl.put(forward.toString(), true);
                forward.append(' ');
                k++;
            }

            //### suffix: all ngrams [start+1,end], right-most wrd first
            StringBuffer backward = new StringBuffer();
            k = end;
            while (k > start) {
                backward.append(wrds[k]);
                suffix_tbl.put(backward.toString(), true);//stored in reverse order
                backward.append(' ');
                k--;
            }
        }
    }
}
// #### equivalent lm stuff ############
/**
 * Trie-based counterpart of the prefix/suffix tables: for every n-gram
 * [i, i+j] with j &lt; order, adds its prefix [i, i+j-1] to prefix_gr and its
 * suffix [i+1, i+j], right-most word first, to suffix_gr.
 */
public static void setup_prefix_suffix_grammar(int[] wrds, int order,
    PrefixGrammar prefix_gr, PrefixGrammar suffix_gr)
{
    for (int start = 0; start < wrds.length; start++) {
        int max_span = Math.min(order, wrds.length - start);
        for (int span = 0; span < max_span; span++) { // ngram: [start,start+span]
            //### prefix
            prefix_gr.add_ngram(wrds, start, start + span - 1);//ngram: [start,start+span-1]

            //### suffix: right-most wrd first
            int[] reversed = new int[span];
            for (int t = 0; t < span; t++) { // words [start+1,start+span] in reverse order
                reversed[t] = wrds[start + span - t];
            }
            suffix_gr.add_ngram(reversed, 0, span - 1);
        }
    }
}
/* A trie over symbol ids. Each node is a hash map from a symbol id to the
 * child node reached by that symbol; a word sequence is contained when a
 * path spelling it exists from the root.
 */
private static class PrefixGrammar {

    private static class PrefixGrammarNode extends HashMap<Integer, PrefixGrammarNode> {
        private static final long serialVersionUID = 1L;
    }

    PrefixGrammarNode root = new PrefixGrammarNode();

    /** Inserts wrds[start_pos..end_pos], creating trie nodes where missing. */
    public void add_ngram(int[] wrds, int start_pos, int end_pos) {
        PrefixGrammarNode current = root;
        for (int k = start_pos; k <= end_pos; k++) {
            int sym_id = wrds[k];
            PrefixGrammarNode child = current.get(sym_id);
            if (null == child) {
                // first time this symbol follows the current path
                child = new PrefixGrammarNode();
                current.put(sym_id, child);
            }
            current = child;
        }
    }

    /**
     * Returns true when wrds[start_pos..end_pos] spells a path from the
     * root; an empty range (end_pos &lt; start_pos) is never contained.
     */
    public boolean contain_ngram(ArrayList<Integer> wrds, int start_pos, int end_pos) {
        if (end_pos < start_pos) {
            return false;
        }
        PrefixGrammarNode current = root;
        int k = start_pos;
        while (k <= end_pos) {
            int sym_id = wrds.get(k);
            current = current.get(sym_id);
            if (null == current) {
                return false;
            }
            k++;
        }
        return true;
    }
}
}