Package joshua.discriminative.training

Source Code of joshua.discriminative.training.NBESTDiscriminativeLearner

package joshua.discriminative.training;

import java.io.BufferedReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import joshua.discriminative.DiscriminativeSupport;
import joshua.discriminative.FileUtilityOld;
import joshua.discriminative.training.learning_algorithm.DefaultCRF;
import joshua.discriminative.training.learning_algorithm.DefaultPerceptron;
import joshua.discriminative.training.learning_algorithm.GradientBasedOptimizer;




public class NBESTDiscriminativeLearner  {
 
  GradientBasedOptimizer optimizer =  null;
  static boolean usingCRF = true;//default is crf

 
  //##feature realted
  HashMap<String,Double> empiricalFeatsTbl = new HashMap<String,Double>();//experical feature counts
  HashMap<String,Double> modelFeatsTbl = new HashMap<String,Double>()//feature counts assigned by model
  HashSet<String> restrictedFeatureSet = null;// only consider feature in this set, if null, then ignore this
 
  //## batch update related
  int numProcessedExamples=0;
       
  //## reranking: baseline
  static String baselineFeatName ="baseline_lzf";
  static boolean fixBaseline= true;//TODO: the current version assume we always fix the baseline during training
  static double baselineScale = 1.0; //the scale for baseline relative to the corrective score
 
 
  static int startNgramOrder =1;
  static int endNgramOrder =2;
  static int hypStrID = 1;//hyp format: sent_id ||| hyp_str ||| features ||| final score
 
  public NBESTDiscriminativeLearner(GradientBasedOptimizer optimizer, HashSet<String> restrictedFeatureSet){
    this.optimizer = optimizer;   
    this.restrictedFeatureSet = restrictedFeatureSet;
  }
 

//  all hyp are represented as a hyper-graph
  public void processOneSent( ArrayList nbest, String hyp_oracle, String ref_sent){   
    //####feature extraction using the current model, get modelFeatsTbl
    if(usingCRF){
      getFeatureExpection(modelFeatsTbl, optimizer.getSumModel(),  restrictedFeatureSet,  nbest);
    }else{//perceptron
      String  rerankedOnebest = rerankNbest(optimizer.getSumModel(), restrictedFeatureSet, nbest);
      featureExtraction(rerankedOnebest, modelFeatsTbl, restrictedFeatureSet, hypStrID, false);
    }
   
    featureExtraction(hyp_oracle, empiricalFeatsTbl, restrictedFeatureSet, 0, false);   
       
    //####common: update the sum and avg model
    /*System.out.println("g_tbl_feats_model size: " + g_tbl_feats_model.size());
    System.out.println("g_tbl_feats_emperical size: " + g_tbl_feats_empirical.size());
    System.out.println("g_restricted_feature_set size: " + g_restricted_feature_set.size());
    */
    //####common: update the models
    numProcessedExamples++;
    updateModel(false);
  }
 
  public void updateModel(boolean force_update){
   
    if(force_update || numProcessedExamples>=optimizer.getBatchSize()){
      /*//debug
      System.out.println("baseline feature emprical " + g_tbl_feats_empirical.get(g_baseline_feat_name));
      System.out.println("baseline feature model "    + g_tbl_feats_model.get(g_baseline_feat_name));
      //edn*/
     
      optimizer.updateModel(empiricalFeatsTbl, modelFeatsTbl);
      //System.out.println("baseline feature weight " + p_optimizer.get_sum_model().get(g_baseline_feat_name));
      resetBaselineFeat();
      //System.out.println("baseline feature weight " + p_optimizer.get_sum_model().get(g_baseline_feat_name));
     
      empiricalFeatsTbl.clear();
      modelFeatsTbl.clear();
      numProcessedExamples=0;           
   
  }
 
  public void resetBaselineFeat(){
    if(fixBaseline)
      optimizer.setFeatureWeight(baselineFeatName, baselineScale);
    else{
      System.out.println("not implemented"); System.exit(0);
   
  }
 
  public String rerankNbest(HashMap<String,Double> corrective_model,  HashSet<String> restrictedFeatSet, ArrayList nbest){
    double best_score = -1000000000;
    String best_hyp=null;
    HashMap<String,Double> tbl_feat_set= new HashMap<String,Double>();
    for(int i=0; i<nbest.size(); i++){
      tbl_feat_set.clear();
      String cur_hyp = (String)nbest.get(i);
      featureExtraction(cur_hyp, tbl_feat_set, restrictedFeatSet, hypStrID, true);
      double cur_score = DiscriminativeSupport.computeLinearCombinationLogP(tbl_feat_set, corrective_model);
      if(i==0 || cur_score > best_score){//maximize
        best_score = cur_score;
        best_hyp = cur_hyp;
      }
    }
    return best_hyp;
  }
 

  //accumulate feature expectations in res_feat_tbl
  public void getFeatureExpection(HashMap<String,Double> res_feat_tbl, HashMap<String,Double> corrective_model,  HashSet<String> restricted_feat_set, ArrayList nbest){
    //### get noralization constant, remember features, remember the combined linear score
    double normalization_constant = Double.NEGATIVE_INFINITY;//log-semiring
    ArrayList<HashMap> l_feats = new ArrayList<HashMap>();
    ArrayList<Double> l_score = new ArrayList<Double>();   
    for(int i=0; i<nbest.size(); i++){
      HashMap<String,Double> tbl_feat_set= new HashMap<String,Double>();
      String cur_hyp = (String)nbest.get(i);
      featureExtraction(cur_hyp, tbl_feat_set, restricted_feat_set, hypStrID, true);
      double curScore = DiscriminativeSupport.computeLinearCombinationLogP(tbl_feat_set, corrective_model);
      normalization_constant = addInLogSemiring(normalization_constant, curScore,0);
      l_feats.add(tbl_feat_set);
      l_score.add(curScore);
    }
   
    ////### get expected feature count
    double sum=0;
    for(int i=0; i<nbest.size(); i++){
      double cur_score = l_score.get(i);
      HashMap<String, Double> feats = l_feats.get(i);
      double post_prob = Math.exp(cur_score-normalization_constant);
      sum += post_prob;
      //accumulate feature counts
      for (Map.Entry<String, Double> entry : feats.entrySet() ){
          DiscriminativeSupport.increaseCount(res_feat_tbl, entry.getKey(), entry.getValue()*post_prob);
      }
    }
    System.out.println("Sum is " + sum);
  }
 
 
//  OR: return Math.log(Math.exp(x) + Math.exp(y));
  private double addInLogSemiring(double x, double y, int add_mode){//prevent over-flow
    if(add_mode==0){//sum
      if(x==Double.NEGATIVE_INFINITY)//if y is also n-infinity, then return n-infinity
        return y;
      if(y==Double.NEGATIVE_INFINITY)
        return x;
     
      if(y<=x)
        return x + Math.log(1+Math.exp(y-x));
      else//x<y
        return y + Math.log(1+Math.exp(x-y));
    }else if (add_mode==1){//viter-min
      return (x<=y)?x:y;
    }else if (add_mode==2){//viter-max
      return (x>=y)?x:y;
    }else{
      System.out.println("invalid add mode"); System.exit(0); return 0;
    }
  }
 
  public static void featureExtraction(String hyp, HashMap<String,Double> feat_tbl, HashSet<String> restricted_feat_set, int hyp_str_id, boolean extract_baseline_feat){
    String[] fds = hyp.split("\\s+\\|{3}\\s+");
   
    //### baseline feature
    if(extract_baseline_feat){
      String score = replaceBadSymbol(fds[fds.length-1]);
      double baseline_score = new Double(score);
      feat_tbl.put(baselineFeatName, baseline_score);
    }
   
    //### ngram feature
    String[] wrds = fds[hyp_str_id].split("\\s+");
    for(int i=0; i<wrds.length; i++)
      for(int j=startNgramOrder-1; j<endNgramOrder  && j+i<wrds.length; j++){//ngram: [i,i+j]
        StringBuffer ngram = new StringBuffer();
        for(int k=i; k<=i+j; k++){
          String t_wrd = wrds[k];
          ngram.append(t_wrd);
          if(k<i+j) ngram.append(" ");
        }
        String ngram_str = ngram.toString();
        if(restricted_feat_set==null || restricted_feat_set.contains(ngram_str)){//filter
          DiscriminativeSupport.increaseCount(feat_tbl, ngram_str, 1.0);
        }
    }
  } 
 
  private static String replaceBadSymbol(String in){
    if(in.startsWith("--"))
      return in.substring(1);
    else
      return in;
  }
 
 
  public static void main(String[] args) {   
    if(args.length<5){
      System.out.println("wrong command, correct command should be: java is_using_crf f_l_train_nbest f_l_orc f_data_sel f_model_out_prefix [f_feature_set]");
      System.out.println("num of args is "+ args.length);
      for(int i=0; i <args.length; i++) System.out.println("arg is: " + args[i]);
      System.exit(0);   
    }
    long start_time = System.currentTimeMillis();     
    boolean is_using_crf =  new Boolean(args[0].trim());
    NBESTDiscriminativeLearner.usingCRF=is_using_crf;
    String f_l_train_nbest=args[1].trim();
    String f_l_orc=args[2].trim();
    String f_data_sel=args[3].trim();
    String f_model_out_prefix =args[4].trim();   
    String initModelFile=null;
    if(args.length>5)
      initModelFile = args[5].trim();
   
    int max_loop = 3;//TODO
   
    List<String> l_file_train_nbest = DiscriminativeSupport.readFileList(f_l_train_nbest);
    List<String> l_file_orc = DiscriminativeSupport.readFileList(f_l_orc)
    HashMap tbl_sent_selected = DiscriminativeSupport.setupDataSelTbl(f_data_sel);//for data selection
   
    //####### training
   
     
//    ###### INIT Model ###################
    NBESTDiscriminativeLearner ndl = null
    //TODO optimal parameters
    int trainSize = 610000;
    int batchUpdateSize = 1;
    int convergePass =  1;
    double initGain = 0.1;
    double sigma = 0.5;
    boolean isMinimizeScore = false;
   
    //setup optimizer
    GradientBasedOptimizer optimizer = null;
    if(usingCRF){
      HashMap<String,Double> crfModel = new HashMap<String,Double>();
      if(initModelFile!=null)
        DiscriminativeSupport.loadModel(initModelFile, crfModel, null);
      else{
        System.out.println("In crf, must specify feature set");
        System.exit(0);
      }
      optimizer = new DefaultCRF(crfModel, trainSize, batchUpdateSize, convergePass, initGain, sigma, isMinimizeScore);
      optimizer.initModel(0, 0);//TODO optimal initial parameters
      ndl = new NBESTDiscriminativeLearner(optimizer, new HashSet<String>(crfModel.keySet()));
      ndl.resetBaselineFeat();//add and init baseline feature
    }else{//perceptron
      HashMap<String,Double> perceptronSumModel = new HashMap<String,Double>();
      HashMap<String,Double> perceptronAvgModel = new HashMap<String,Double>();
      HashMap<String,Double> perceptronModel = new HashMap<String,Double>();
      if(initModelFile!=null){
        DiscriminativeSupport.loadModel(initModelFile, perceptronModel, null);
        perceptronModel.put(baselineFeatName, 1.0);
      }else{
        System.out.println("In perceptron, should specify feature set");       
      }
      optimizer = new DefaultPerceptron(perceptronSumModel, perceptronAvgModel,trainSize, batchUpdateSize, convergePass, initGain, sigma, isMinimizeScore);
      ndl = new NBESTDiscriminativeLearner(optimizer,  new HashSet<String>(perceptronModel.keySet()));
      ndl.resetBaselineFeat();
   
   
    //TODO
    ndl.optimizer.set_no_cooling();
    //ndl.p_optimizer.set_no_regularization();
     
    for(int loop_id=0; loop_id<max_loop; loop_id++){
      System.out.println("###################################Loop " + loop_id);
      for(int fid=0; fid < l_file_train_nbest.size(); fid++){
        System.out.println("#######Process file id " + fid);
        BufferedReader t_reader_nbest = FileUtilityOld.getReadFileStream((String)l_file_train_nbest.get(fid),"UTF-8");
        BufferedReader t_reader_orc = FileUtilityOld.getReadFileStream((String)l_file_orc.get(fid),"UTF-8");
        String line=null;
        int old_sent_id=-1;
        ArrayList<String> nbest = new ArrayList<String>();
        while((line=FileUtilityOld.readLineLzf(t_reader_nbest))!=null){
          String[] fds = line.split("\\s+\\|{3}\\s+");
          int new_sent_id = new Integer(fds[0]);
          if(old_sent_id!=-1 && old_sent_id!=new_sent_id){           
            String hyp_oracle = FileUtilityOld.readLineLzf(t_reader_orc);
            if(tbl_sent_selected.containsKey(old_sent_id)){
              System.out.println("#Process sentence " + old_sent_id);
              ndl.processOneSent( nbest, hyp_oracle, null);
            }else
              System.out.println("#Skip sentence " + old_sent_id);
            nbest.clear();
            //Support.print_hash_tbl(perceptron.g_tbl_sum_model);//debug
          }
          old_sent_id = new_sent_id;
          nbest.add(line);
        }
        //last nbest
        String hyp_oracle = FileUtilityOld.readLineLzf(t_reader_orc);
        if(tbl_sent_selected.containsKey(old_sent_id)){
          System.out.println("#Process sentence " + old_sent_id);
          ndl.processOneSent( nbest, hyp_oracle, null);
        }else
          System.out.println("#Skip sentence " + old_sent_id);
        nbest.clear();
       
        FileUtilityOld.closeReadFile(t_reader_nbest);
        FileUtilityOld.closeReadFile(t_reader_orc);
      }
     
      if(usingCRF){
        ndl.updateModel(true);
        FileUtilityOld.printHashTbl(optimizer.getSumModel(), f_model_out_prefix+".crf." + loop_id, false, false);
      }else{//perceptron
        ndl.updateModel(true);
        ((DefaultPerceptron)optimizer).force_update_avg_model();
        FileUtilityOld.printHashTbl(optimizer.getSumModel(), f_model_out_prefix+".sum." + loop_id, false, false);
        FileUtilityOld.printHashTbl(optimizer.getAvgModel(), f_model_out_prefix+".avg." + loop_id, false, true);
      }
      System.out.println("Time cost: " + ((System.currentTimeMillis()-start_time)/1000));
      //clean up     
    }
 
  }
 

 
}
TOP

Related Classes of joshua.discriminative.training.NBESTDiscriminativeLearner

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.