Package joshua.discriminative.training.oracle

Source Code of joshua.discriminative.training.oracle.StringSumInHG

package joshua.discriminative.training.oracle;

import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;

import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.KBestExtractor;
import joshua.discriminative.FileUtilityOld;
import joshua.discriminative.semiring_parsing.AtomicSemiring;



/* Given a general hypergraph and a string (known to be contained in the hypergraph), this
* class returns a filtered hypergraph that contains only the derivations yielding that string
* The method is not exact, in other words, it may actually contains derivations that do not yield the given string
* Or, the recall (of the derivations we are seeking) is perfect, but the precision is not
* */


/**This programm assumes the transition_cost in the input hypergraph is properly set!!!!!!!!!!!!!!!!!
* */

public class StringSumInHG {
 
  SymbolTable symbolTbl;
  int lmFeatID = 0;//TODO
  int baselineLMOrder = 5;//TODO
  AtomicSemiring atomicSemirng = new AtomicSemiring(1,0);
 
  //kbest
  KBestExtractor kbestExtractor;
  int topN= 1000000000;//to make sure the filtered hypergraph is exhausitively enumerated
  boolean useUniqueNbest =false;
  boolean useTreeNbest = false;
  boolean addCombinedCost = true
 
  ApproximateFilterHGByOneString p_filter;
 
  public StringSumInHG(SymbolTable symbol_, KBestExtractor kbestextractor, ApproximateFilterHGByOneString filter_){
    symbolTbl = symbol_;
    kbestExtractor = kbestextractor;
    p_filter = filter_;
  }
 
  //for a given string, find the approximate and true sum
  public double computeSumForString(HyperGraph hg, int sentenceID, String ref_string){
    System.out.println("Now process string: " + ref_string);
    int[] ref_sent_wrds_in = symbolTbl.addTerminals(ref_string);
   
    //### filter hypergraph
    HyperGraph filtedHG = p_filter.approximate_filter_hg(hg,ref_sent_wrds_in);
   
    //debug
    /*
    double inside_outside_scaling_factor =1.0;
    TrivialInsideOutside p_inside_outside = new TrivialInsideOutside();
    p_inside_outside.run_inside_outside(filted_hg, 0, 1, inside_outside_scaling_factor);//ADD_MODE=0=sum; LOG_SEMIRING=1;
    double io_norm_constant =  p_inside_outside.get_normalization_constant();   
    p_inside_outside.clear_state();
   
    p_inside_outside.run_inside_outside(hg, 0, 1, inside_outside_scaling_factor);//ADD_MODE=0=sum; LOG_SEMIRING=1;
    System.out.println("global norm is " + p_inside_outside.get_normalization_constant());
    p_inside_outside.clear_state();*/
   
    //### extract all possible derivations in the filted hypergraph
    ArrayList<String> nonUniqueNbestStrings = new ArrayList<String>();
    kbestExtractor.lazyKBestExtractOnHG(filtedHG, null, this.topN,  sentenceID, nonUniqueNbestStrings);//???????????????????????????
    if(nonUniqueNbestStrings.size()>=topN){
      System.out.println("number of possible derivations reaches topN, should increase its value");
      System.exit(1);
    }
   
    //### check if each derivation yields the ref string, and compute the true_sum
    int num_good_derivations = 0;
    double good_log_sum_prob = Double.NEGATIVE_INFINITY;
    int num_bad_derivations = 0;   
    double bad_log_sum_prob = Double.NEGATIVE_INFINITY;
    for(String derivation_string : nonUniqueNbestStrings){
      //System.out.println(derivation_string);
      String[] fds = derivation_string.split("\\s+\\|{3}\\s+");
      String hyp_string = fds[1];
      double log_prob = new Double(fds[fds.length-1]);//TODO: use inside_outside_scaling_factor here
     
      if(hyp_string.compareTo(ref_string)==0){//the same
        good_log_sum_prob = atomicSemirng.add_in_atomic_semiring(good_log_sum_prob, log_prob);
        num_good_derivations++;
      }else{
        bad_log_sum_prob = atomicSemirng.add_in_atomic_semiring(bad_log_sum_prob, log_prob);
        num_bad_derivations++;
      }
      //System.out.println("log_prob: " + log_prob + "; sum: " + good_log_sum_prob);
    }
    System.out.println("good_sum: " + good_log_sum_prob + "; good_num: " + num_good_derivations + "; bad_sum: " + bad_log_sum_prob + "; bad_num: " + num_bad_derivations);
    //if(Math.abs(io_norm_constant-good_log_sum_prob)>1e-3){System.out.println("Norm is not equal!!!!!!!!!!!!!");System.exit(1);}
    p_filter.clear_state();
    return good_log_sum_prob;
  }
 

  public static void main(String[] args) throws InterruptedException, IOException {
    /*//##read configuration information
    if(args.length<8){
      System.out.println("wrong command, correct command should be: java Perceptron_HG is_crf lf_train_items lf_train_rules lf_orc_items lf_orc_rules f_l_num_sents f_data_sel f_model_out_prefix use_tm_feat use_lm_feat use_edge_bigram_feat_only f_feature_set use_joint_tm_lm_feature");
      System.out.println("num of args is "+ args.length);
      for(int i=0; i <args.length; i++)System.out.println("arg is: " + args[i]);
      System.exit(0);   
    }*/
       
    if(args.length!=8){
      System.out.println("Wrong number of parameters, it must be  7");
      System.exit(1);
    }
   
   
    String f_test_items=args[0].trim();
    String f_test_rules=args[1].trim();
    int num_sents=new Integer(args[2].trim());
    String f_nbest=args[3].trim();//output
    String f_1best=args[4].trim();//output
    int topN = new Integer(args[5].trim());//nbest size of the unique strings
   
    int baseline_lm_order = new Integer(args[7].trim());//nbest size of the unique strings
 
    int baseline_lm_feat_id = 0;//???????
   
    int max_num_words =25;
   
    SymbolTable p_symbol = new BuildinSymbol(null);
    KBestExtractor kbest_extractor = new KBestExtractor(p_symbol, true, false, false, false,  false, true);//????????????
    ApproximateFilterHGByOneString filter = new ApproximateFilterHGByOneString(p_symbol,baseline_lm_feat_id,baseline_lm_order);
    StringSumInHG p_sumer = new StringSumInHG(p_symbol, kbest_extractor, filter);
   
    //#### process test set
    BufferedWriter t_writer_nbest =  FileUtilityOld.getWriteFileStream(f_nbest)
    BufferedWriter t_writer_1best =  FileUtilityOld.getWriteFileStream(f_1best);
    System.out.println("############Process file  " + f_test_items);
    DiskHyperGraph dhg_test = new DiskHyperGraph(p_symbol, baseline_lm_feat_id, true, null); //have model costs stored
    dhg_test.initRead(f_test_items, f_test_rules,null);
     
    for(int sent_id=0; sent_id < num_sents; sent_id ++){
      System.out.println("#Process sentence " + sent_id);
      HyperGraph hg_test = dhg_test.readHyperGraph();     
      //if(sent_id==1)System.exit(1);
      //generate a unique nbest of strings based on viterbi cost
      ArrayList<String> nonUniqueNbestStrings = new ArrayList<String>();
      kbest_extractor.lazyKBestExtractOnHG(hg_test, null, topN, sent_id, nonUniqueNbestStrings);
     
      double max_prob = Double.NEGATIVE_INFINITY;
      String max_string = "";
     
      //chech if the sentence is too long
      boolean skip=false;
      for(String unique_string : nonUniqueNbestStrings){
        //System.out.println(unique_string);
        String[] fds = unique_string.split("\\s+\\|{3}\\s+");
        String hyp_string = fds[1];
        String[] wrds = hyp_string.split("\\s+");
        if(wrds.length>max_num_words){
          skip=true;
          break;
        }
      }
      if(skip==false){
        for(String unique_string : nonUniqueNbestStrings){
          //System.out.println(unique_string);
          String[] fds = unique_string.split("\\s+\\|{3}\\s+");
          String hyp_string = fds[1];
          String[] wrds = hyp_string.split("\\s+");
          if(wrds.length>max_num_words)
            break;
          double true_sum_prob = p_sumer.computeSumForString(hg_test, sent_id, hyp_string);
          System.out.println( unique_string + " ||| " + true_sum_prob);
          FileUtilityOld.writeLzf(t_writer_nbest, unique_string + " ||| " + true_sum_prob + "\n");
          if(true_sum_prob>max_prob){
            max_string = hyp_string;
            max_prob = true_sum_prob;
          }
        }
      }else{
        System.out.println("lzf; skip sentence " + sent_id);
      }
      FileUtilityOld.writeLzf(t_writer_1best, max_string + "\n");
    }
    FileUtilityOld.closeWriteFile(t_writer_nbest);
    FileUtilityOld.closeWriteFile(t_writer_1best);
       
  }
   
}
TOP

Related Classes of joshua.discriminative.training.oracle.StringSumInHG

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.