Source Code of joshua.discriminative.variational_decoder.VariationalDecoder
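
VariationalDecoder reranks a translation hypergraph under a variational approximation: it runs the inside-outside algorithm over the hypergraph, estimates n-gram models that define the variational distribution Q, reranks the hypergraph with the baseline and variational feature functions, and extracts a k-best list from the reranked hypergraph.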

package joshua.discriminative.variational_decoder;


import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.KBestExtractor;
import joshua.decoder.hypergraph.TrivialInsideOutside;
import joshua.discriminative.FileUtilityOld;
import joshua.discriminative.feature_related.feature_function.FeatureTemplateBasedFF;
import joshua.discriminative.feature_related.feature_function.EdgeTblBasedBaselineFF;
import joshua.discriminative.ranker.HGRanker;




public class VariationalDecoder {

  private int topN = 300;
  private boolean useUniqueNbest = true;
  private boolean useTreeNbest = false;
  private boolean addCombinedCost = true;
 
  private KBestExtractor kbestExtractor;
 
  SymbolTable symbolTbl;
  List<FeatureFunction> featFunctions ; //for HG reranking and kbest extraction
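  //maps each approximator to the feature function that will hold its estimated variational model (see decoding())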
  HashMap<VariationalNgramApproximator, FeatureTemplateBasedFF> approximatorMap;
 
 
  double insideOutsideScalingFactor= 0.5;

  HGRanker ranker;
 
  public VariationalDecoder(){
    //do nothing;
  }
 
 
 
 
  /**
   * Decodes one sentence: runs inside-outside, estimates the variational
   * models, reranks the hypergraph, and extracts a k-best list.
   * @return the changed (reranked) hypergraph
   */
  public HyperGraph decoding(HyperGraph hg, int sentenceID, BufferedWriter out) {   
 
    //=== step-1: run inside-outside
    //note, inside and outside will use the transition_cost of each hyperedge, this cost is already linearly interpolated
    TrivialInsideOutside pInsideOutside = new TrivialInsideOutside();
    pInsideOutside.runInsideOutside(hg, 0, 1, insideOutsideScalingFactor);//ADD_MODE=0=sum; LOG_SEMIRING=1;
   
    //=== initialize baseline table
    //TODO: verify this cast; it assumes the baseline feature is the first feature function in the config
    ((EdgeTblBasedBaselineFF)featFunctions.get(0)).collectTransitionLogPs(hg);
   
    //=== step-2: model extraction based on the definition of Q
    for(Map.Entry<VariationalNgramApproximator, FeatureTemplateBasedFF> entry : approximatorMap.entrySet()){
      VariationalNgramApproximator approximator = entry.getKey();
      FeatureTemplateBasedFF featureFunction = entry.getValue();
      HashMap<String, Double> model = approximator.estimateModel(hg, pInsideOutside);
      featureFunction.setModel(model);     
    }
   
    //clean up
    pInsideOutside.clearState();
   
    //=== step-3: rank the HG using the baseline and variational feature
    this.ranker.rankHG(hg);

    //=== step-4: kbest extraction from the reranked HG: remember to add the new feature function into the model list
    try{
      kbestExtractor.lazyKBestExtractOnHG(hg, this.featFunctions, this.topN, sentenceID, out);
    } catch (IOException e) {
      e.printStackTrace();
    }
    return hg;
  }
 
 

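  /**
   * Reads the config file and sets up the symbol table, feature functions,
   * hypergraph ranker, approximator map, and k-best extractor.
   */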
  public void initializeDecoder(String configFile){
    VariationalDecoderConfiguration.readConfigFile(configFile);
    this.symbolTbl = new BuildinSymbol(null);
    this.featFunctions = new ArrayList<FeatureFunction>();
    this.ranker = new HGRanker(featFunctions);
    this.approximatorMap = new HashMap<VariationalNgramApproximator, FeatureTemplateBasedFF>();   
    VariationalDecoderConfiguration.initializeModels(configFile, this.symbolTbl, this.featFunctions, this.approximatorMap);   
    this.insideOutsideScalingFactor =  VariationalDecoderConfiguration.insideoutsideScalingFactor;

    //this.kbestExtractor = new KBestExtractor(p_symbol); 
    this.kbestExtractor = new KBestExtractor(this.symbolTbl, this.useUniqueNbest, this.useTreeNbest, false, this.addCombinedCost, false, true);
  }
 
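  /**
   * Decodes a test set stored as disk hypergraphs, writing an n-best list
   * for each of the numSents sentences to nbestFile.
   */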
  public void  decodingTestSet(String testItemsFile, String testRulesFile, int numSents, String nbestFile) {

    BufferedWriter nbestWriter = FileUtilityOld.getWriteFileStream(nbestFile);
    System.out.println("############Process file  " + testItemsFile);
    DiskHyperGraph diskHG = new DiskHyperGraph(symbolTbl, VariationalDecoderConfiguration.ngramStateID, true, null); //have model costs stored
    diskHG.initRead(testItemsFile, testRulesFile,null);
     
    for(int sentID=0; sentID < numSents; sentID ++){
      System.out.println("#Process sentence " + sentID);
      HyperGraph testhg = diskHG.readHyperGraph();
      /*if(use_constituent_decoding)
        vdecoder.constitudent_decoding(hg_test, sent_id, t_writer_nbest);
      else*/
        decoding(testhg, sentID, nbestWriter);
     
      //Thread.sleep(3000);
    }
    FileUtilityOld.closeWriteFile(nbestWriter);
  }
 
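  /**
   * Copies the template config file to fileToWrite, replacing the weight
   * (the last field) on each model line with the corresponding entry of
   * newWeights; comments, empty lines, and parameter lines are copied as-is.
   */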
  public void writeConfigFile(double[] newWeights, String template, String fileToWrite){
    BufferedReader configReader = FileUtilityOld.getReadFileStream(template);
    BufferedWriter configWriter =  FileUtilityOld.getWriteFileStream(fileToWrite);   
    String line;
    int featID = 0;
    while ((line = FileUtilityOld.readLineLzf(configReader)) != null) {
      line = line.trim();
      if (line.matches("^\\s*\\#.*$") || line.matches("^\\s*$") || line.indexOf("=") != -1) {//comment, empty line, or parameter lines: just copy
         FileUtilityOld.writeLzf(configWriter, line + "\n");
      }else{//models: replace the weight
        String[] fds = line.split("\\s+");
        StringBuffer new_line = new StringBuffer();
        if (fds[fds.length-1].matches("^[\\d\\.\\-\\+]+") == false) {
          System.out.println("last field is not a number; the field is: " + fds[fds.length-1]);
          System.exit(1);
        }
        for(int i=0; i<fds.length-1; i++){
          new_line.append(fds[i]);
          new_line.append(" ");
        }
        new_line.append(newWeights[featID++]);
        FileUtilityOld.writeLzf(configWriter, new_line.toString() + "\n");
      }       
    }
    if (featID != newWeights.length) {
      System.out.println("number of models does not match number of weights");
      System.exit(1);
    }
    FileUtilityOld.closeReadFile(configReader);
    FileUtilityOld.closeWriteFile(configWriter);   
  }
 
  /**
   * Updates the weight of each feature function; this assumes weight_vector
   * is ordered according to the decoder config file.
   */
  public void changeFeatureWeightVector(double[] weight_vector){
    if(featFunctions.size()!=weight_vector.length){
      System.out.println("In changeFeatureWeightVector: number of weights does not match number of feature functions");
      System.exit(1);
    }
    for(int i=0; i<featFunctions.size(); i++){
      FeatureFunction ff = featFunctions.get(i);
      double old_weight = ff.getWeight();
      ff.setWeight(weight_vector[i]);
      System.out.println("Feature function : " + ff.getClass().getSimpleName() + "; weight changed from " + old_weight + " to " + ff.getWeight());
    }
  }
 
 
 
 
  //============================ main function ==============================
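  /**
   * Expects five arguments:
   * testItemsFile testRulesFile numSents nbestFile configFile
   */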
  public static void main(String[] args) throws InterruptedException, IOException {

    /*//##read configuration information
    if(args.length<8){
      System.out.println("wrong command, correct command should be: java Perceptron_HG is_crf lf_train_items lf_train_rules lf_orc_items lf_orc_rules f_l_num_sents f_data_sel f_model_out_prefix use_tm_feat use_lm_feat use_edge_bigram_feat_only f_feature_set use_joint_tm_lm_feature");
      System.out.println("num of args is "+ args.length);
      for(int i=0; i <args.length; i++)System.out.println("arg is: " + args[i]);
      System.exit(0);   
    }*/
    if(args.length!=5){
      System.out.println("Wrong number of arguments; expected 5: testItemsFile testRulesFile numSents nbestFile configFile");
      System.exit(1);
    }
   
    //long start_time = System.currentTimeMillis();   
    String testItemsFile=args[0].trim();
    String testRulesFile=args[1].trim();
    int numSents = Integer.parseInt(args[2].trim());
    String nbestFile=args[3].trim();
    String configFile=args[4].trim();
   
    VariationalDecoder vdecoder = new VariationalDecoder();
    vdecoder.initializeDecoder(configFile);
    vdecoder.decodingTestSet(testItemsFile, testRulesFile, numSents, nbestFile);
  }
 
 
 
  /*
  //return changed hg
  public HyperGraph decoding_old(HyperGraph hg, int sentenceID, BufferedWriter out){   
    //### step-1: run inside-outside
    TrivialInsideOutside p_inside_outside = new TrivialInsideOutside();
    p_inside_outside.run_inside_outside(hg, 0, 1);//ADD_MODE=0=sum; LOG_SEMIRING=1;
   
   
    //### step-2: model extraction based on the definition of Q
    
    //step-2.1: baseline feature: do not use insideoutside
    this.baseline_feat_tbl.clear();
    FeatureExtractionHG.feature_extraction_hg(hg, this.baseline_feat_tbl, null, this.baseline_ft_template); 
    
    //step-2.2: ngram feature
    this.ngram_feat_tbl.clear();
    FeatureExtractionHG.feature_extraction_hg(hg, p_inside_outside, this.ngram_feat_tbl, null, this.ngram_ft_template); 
    VariationalLMFeature.getNormalizedLM(this.ngram_feat_tbl, this.baseline_lm_order, false, false, false); 
   
    //step-2.2: constituent feature
    this.constituent_feat_tbl.clear();
    FeatureExtractionHG.feature_extraction_hg(hg, p_inside_outside, this.constituent_feat_tbl, null, this.constituent_template); 
   
    //clean
    p_inside_outside.clear_state();
   
    //### step-3: rank the HG using the baseline and variational feature
    l_feat_functions.clear();   
   
    FeatureFunction base_ff =  new BaselineFF(this.baseline_feat_id, this.baseline_weight, this.baseline_feat_tbl);
    l_feat_functions.add(base_ff);
     
    FeatureFunction variational_ff =  new VariationalLMFeature(this.baseline_lm_feat_id, this.baseline_lm_order, this.p_symbol, this.ngram_feat_tbl,  this.ngam_weight, false, false, false);
    l_feat_functions.add(variational_ff);
   
    FeatureFunction constituent_ff =  new BaselineFF(this.constituent_feat_id, this.constituent_weight, this.constituent_feat_tbl);
    l_feat_functions.add(constituent_ff);
   
    RankHG.rankHG(hg, l_feat_functions);
   
       
    //### step-4: kbest extraction from the reranked HG: remember to add the new feature function into the model list
    //two features: baseline feature, and reranking feature
    KbestExtraction kbestExtractor = new KbestExtraction(p_symbol);
    kbestExtractor.lazy_k_best_extract_hg(hg, this.l_feat_functions, this.topN, this.use_unique_nbest, sentenceID, out, this.use_tree_nbest, this.add_combined_cost);
   
    return hg;
  }
 
 
 
 
//  return changed hg
  public HyperGraph constitudent_decoding(HyperGraph hg, int sentenceID, BufferedWriter out){   
   
    ConstituentVariationalDecoder decoder = new ConstituentVariationalDecoder();
    decoder.decoding(hg);
     
    //### step-2: model extraction based on the definition of Q
    
    //step-2.1: baseline feature: do not use insideoutside
    this.baseline_feat_tbl.clear();
    FeatureExtractionHG.feature_extraction_hg(hg, this.baseline_feat_tbl, null, this.baseline_ft_template); 
   
   
    //### step-3: rank the HG using the baseline and variational feature
    l_feat_functions.clear();   
   
    FeatureFunction base_ff =  new BaselineFF(this.baseline_feat_id, this.baseline_weight, this.baseline_feat_tbl);
    l_feat_functions.add(base_ff);
     
    //RankHG.rankHG(hg, l_feat_functions);
   
   
    //### step-4: kbest extraction from the reranked HG: remember to add the new feature function into the model list
    //two features: baseline feature, and reranking feature
    this.kbestExtractor.lazy_k_best_extract_hg(hg, this.l_feat_functions, this.topN, this.use_unique_nbest, sentenceID, out, this.use_tree_nbest, this.add_combined_cost);
   
    return hg;
  }
*/ 
 
}
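
A minimal usage sketch (the file paths and sentence count below are placeholders, not part of the original source):

  // from the command line:
  //   java joshua.discriminative.variational_decoder.VariationalDecoder \
  //        test.items test.rules 100 test.nbest vdecoder.config
  // or programmatically:
  VariationalDecoder vdecoder = new VariationalDecoder();
  vdecoder.initializeDecoder("vdecoder.config");
  vdecoder.decodingTestSet("test.items", "test.rules", 100, "test.nbest");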
 