package joshua.discriminative.training.oracle;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.KBestExtractor;
import joshua.discriminative.FileUtilityOld;
import joshua.discriminative.semiring_parsing.AtomicSemiring;
/* Given a general hypergraph and a string (known to be contained in the hypergraph), this
 * class returns a filtered hypergraph that contains only the derivations yielding that string.
 * The method is not exact; in other words, it may actually contain derivations that do not yield
 * the given string. That is, the recall (of the derivations we are seeking) is perfect, but the
 * precision is not.
 * */
/** This program assumes the transition cost in the input hypergraph is properly set!
 * */
public class StringSumInHG {

  SymbolTable symbolTbl;

  /** Feature ID of the baseline language model. */
  int lmFeatID = 0;// TODO

  /** Order of the baseline language model. */
  int baselineLMOrder = 5;// TODO

  /** Log-semiring accumulator used to sum derivation log-probabilities. */
  AtomicSemiring atomicSemirng = new AtomicSemiring(1, 0);

  // kbest extraction
  KBestExtractor kbestExtractor;

  /** Effectively unbounded, so the filtered hypergraph is exhaustively enumerated. */
  int topN = 1000000000;
  boolean useUniqueNbest = false;
  boolean useTreeNbest = false;
  boolean addCombinedCost = true;

  /** Approximate filter: perfect recall of matching derivations, imperfect precision. */
  ApproximateFilterHGByOneString p_filter;

  /**
   * @param symbol_ symbol table used to convert the reference string to word IDs
   * @param kbestextractor extractor used to enumerate derivations of the filtered hypergraph
   * @param filter_ approximate string filter applied to the hypergraph
   */
  public StringSumInHG(SymbolTable symbol_, KBestExtractor kbestextractor,
      ApproximateFilterHGByOneString filter_) {
    symbolTbl = symbol_;
    kbestExtractor = kbestextractor;
    p_filter = filter_;
  }

  /**
   * For a given string, computes the (approximate) log-sum of the probabilities of all
   * derivations in {@code hg} whose yield equals {@code ref_string}.
   *
   * @param hg input hypergraph (transition costs must already be set)
   * @param sentenceID sentence index passed through to the k-best extractor
   * @param ref_string reference string whose derivation sum is sought
   * @return log-sum of the probabilities of the matching ("good") derivations;
   *         {@link Double#NEGATIVE_INFINITY} if none match
   */
  public double computeSumForString(HyperGraph hg, int sentenceID, String ref_string) {
    System.out.println("Now process string: " + ref_string);
    int[] ref_sent_wrds_in = symbolTbl.addTerminals(ref_string);

    // ### filter hypergraph down to (approximately) the derivations yielding ref_string
    HyperGraph filtedHG = p_filter.approximate_filter_hg(hg, ref_sent_wrds_in);

    // ### extract all possible derivations in the filtered hypergraph
    ArrayList<String> nonUniqueNbestStrings = new ArrayList<String>();
    kbestExtractor.lazyKBestExtractOnHG(filtedHG, null, this.topN, sentenceID, nonUniqueNbestStrings);
    if (nonUniqueNbestStrings.size() >= topN) {
      // hitting topN means the enumeration may be truncated, which would make the sum wrong
      System.out.println("number of possible derivations reaches topN, should increase its value");
      System.exit(1);
    }

    // ### check whether each derivation yields the ref string, and compute the true sum
    int num_good_derivations = 0;
    double good_log_sum_prob = Double.NEGATIVE_INFINITY;
    int num_bad_derivations = 0;
    double bad_log_sum_prob = Double.NEGATIVE_INFINITY;
    for (String derivation_string : nonUniqueNbestStrings) {
      // k-best line format: fields separated by " ||| "; fds[1] is the hypothesis string,
      // the last field is its log-probability
      String[] fds = derivation_string.split("\\s+\\|{3}\\s+");
      String hyp_string = fds[1];
      // FIX: parse directly instead of the deprecated new Double(...) boxing constructor
      double log_prob = Double.parseDouble(fds[fds.length - 1]);// TODO: use inside_outside_scaling_factor here
      if (hyp_string.equals(ref_string)) {// FIX: equals() instead of compareTo()==0
        good_log_sum_prob = atomicSemirng.add_in_atomic_semiring(good_log_sum_prob, log_prob);
        num_good_derivations++;
      } else {
        bad_log_sum_prob = atomicSemirng.add_in_atomic_semiring(bad_log_sum_prob, log_prob);
        num_bad_derivations++;
      }
    }
    System.out.println("good_sum: " + good_log_sum_prob + "; good_num: " + num_good_derivations
        + "; bad_sum: " + bad_log_sum_prob + "; bad_num: " + num_bad_derivations);

    p_filter.clear_state();
    return good_log_sum_prob;
  }

  /**
   * Command-line driver: reads a disk hypergraph per sentence, extracts an n-best list,
   * rescores each unique hypothesis by its derivation sum, and writes the rescored n-best
   * plus the 1-best (by sum) to the given output files.
   *
   * Args: testItems testRules numSents nbestOut 1bestOut topN (unused) baselineLMOrder
   */
  public static void main(String[] args) throws InterruptedException, IOException {
    if (args.length != 8) {
      // FIX: message previously said "must be 7" while the check requires 8 (args[7] is read)
      System.out.println("Wrong number of parameters, it must be 8");
      System.exit(1);
    }
    String f_test_items = args[0].trim();
    String f_test_rules = args[1].trim();
    int num_sents = Integer.parseInt(args[2].trim());// FIX: parseInt over deprecated new Integer(...)
    String f_nbest = args[3].trim();// output
    String f_1best = args[4].trim();// output
    int topN = Integer.parseInt(args[5].trim());// nbest size of the unique strings
    // NOTE(review): args[6] is never read — confirm the intended CLI contract
    int baseline_lm_order = Integer.parseInt(args[7].trim());
    int baseline_lm_feat_id = 0;// ???????
    int max_num_words = 25;// skip sentences whose hypotheses exceed this length

    SymbolTable p_symbol = new BuildinSymbol(null);
    KBestExtractor kbest_extractor = new KBestExtractor(p_symbol, true, false, false, false, false, true);// ????????????
    ApproximateFilterHGByOneString filter =
        new ApproximateFilterHGByOneString(p_symbol, baseline_lm_feat_id, baseline_lm_order);
    StringSumInHG p_sumer = new StringSumInHG(p_symbol, kbest_extractor, filter);

    // #### process test set
    BufferedWriter t_writer_nbest = FileUtilityOld.getWriteFileStream(f_nbest);
    BufferedWriter t_writer_1best = FileUtilityOld.getWriteFileStream(f_1best);
    System.out.println("############Process file " + f_test_items);
    DiskHyperGraph dhg_test = new DiskHyperGraph(p_symbol, baseline_lm_feat_id, true, null); // have model costs stored
    dhg_test.initRead(f_test_items, f_test_rules, null);

    for (int sent_id = 0; sent_id < num_sents; sent_id++) {
      System.out.println("#Process sentence " + sent_id);
      HyperGraph hg_test = dhg_test.readHyperGraph();

      // generate a (non-unique) nbest of strings based on viterbi cost
      ArrayList<String> nonUniqueNbestStrings = new ArrayList<String>();
      kbest_extractor.lazyKBestExtractOnHG(hg_test, null, topN, sent_id, nonUniqueNbestStrings);

      double max_prob = Double.NEGATIVE_INFINITY;
      String max_string = "";

      // check if the sentence is too long; if any hypothesis exceeds max_num_words, skip it
      boolean skip = false;
      for (String unique_string : nonUniqueNbestStrings) {
        String[] fds = unique_string.split("\\s+\\|{3}\\s+");
        String hyp_string = fds[1];
        String[] wrds = hyp_string.split("\\s+");
        if (wrds.length > max_num_words) {
          skip = true;
          break;
        }
      }

      if (skip == false) {
        for (String unique_string : nonUniqueNbestStrings) {
          String[] fds = unique_string.split("\\s+\\|{3}\\s+");
          String hyp_string = fds[1];
          String[] wrds = hyp_string.split("\\s+");
          if (wrds.length > max_num_words)
            break;
          // rescore this hypothesis by the sum of its derivation probabilities
          double true_sum_prob = p_sumer.computeSumForString(hg_test, sent_id, hyp_string);
          System.out.println(unique_string + " ||| " + true_sum_prob);
          FileUtilityOld.writeLzf(t_writer_nbest, unique_string + " ||| " + true_sum_prob + "\n");
          if (true_sum_prob > max_prob) {
            max_string = hyp_string;
            max_prob = true_sum_prob;
          }
        }
      } else {
        System.out.println("lzf; skip sentence " + sent_id);
      }
      // NOTE: for skipped sentences this writes an empty line, keeping output aligned by sent_id
      FileUtilityOld.writeLzf(t_writer_1best, max_string + "\n");
    }
    FileUtilityOld.closeWriteFile(t_writer_nbest);
    FileUtilityOld.closeWriteFile(t_writer_1best);
  }
}