Package joshua.discriminative.training.learning_algorithm

Examples of joshua.discriminative.training.learning_algorithm.GradientBasedOptimizer
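Both examples below follow the same pattern: construct a subclass of GradientBasedOptimizer (DefaultCRF for CRF training, DefaultPerceptron for averaged-perceptron training), hand it to a discriminative learner together with the feature set, then loop over the training data, feeding the learner one sentence at a time and writing the updated model(s) to disk after every pass. The first example trains from packed hypergraphs through HGDiscriminativeLearner; the second trains from n-best lists through NBESTDiscriminativeLearner. Both listings are excerpts from larger training methods, so variables such as max_loop, start_time, and the file-name lists are defined outside the code shown.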


    double init_gain = 0.1;
    double sigma = 0.5;
    boolean is_minimize_score = true;
   
    //setup optimizer
    GradientBasedOptimizer optimizer = null;
    if(usingCRF){
      HashMap<String,Double> crfModel = new HashMap<String,Double>();
      if(f_feature_set!=null){
        DiscriminativeSupport.loadModel(f_feature_set, crfModel, null);
      }else{
        System.out.println("In crf, must specify feature set"); System.exit(0);
      }
      optimizer = new DefaultCRF(crfModel, train_size, batch_update_size, converge_pass, init_gain, sigma, is_minimize_score);
      optimizer.initModel(-1, 1);//TODO optimal initial parameters
      hgdl = new HGDiscriminativeLearner(optimizer, new HashSet<String>(crfModel.keySet()));
      hgdl.reset_baseline_feat();//add and init baseline feature
      System.out.println("size3: " + optimizer.getSumModel().size());
    }else{//perceptron
      HashMap<String,Double> perceptron_sum_model = new HashMap<String,Double>();
      HashMap<String,Double> perceptron_avg_model = new HashMap<String,Double>();
      HashMap<String,Double> perceptronModel = new HashMap<String,Double>();
      if(f_feature_set!=null){
        DiscriminativeSupport.loadModel(f_feature_set, perceptronModel, null);
        perceptronModel.put(baselineFeatName, 1.0);
        System.out.println("feature set size is " + perceptronModel.size());
      }else{
        System.out.println("In perceptron, should specify feature set");       
      }
      optimizer = new DefaultPerceptron(perceptron_sum_model, perceptron_avg_model, train_size, batch_update_size, converge_pass, init_gain, sigma, is_minimize_score);
      hgdl = new HGDiscriminativeLearner(optimizer, new HashSet<String>(perceptronModel.keySet()));
      hgdl.reset_baseline_feat();
    }   
       
    //#####begin to do training
    int g_sent_id=0;
    for(int loop_id=0; loop_id<max_loop; loop_id++){
      System.out.println("###################################Loop " + loop_id);
      for(int fid=0; fid < l_file_train_items.size(); fid++){
        System.out.println("############Process file id " + fid);
        DiskHyperGraph dhg_train = new DiskHyperGraph(symbolTbl, ngramStateID, saveModelCosts, null);
        dhg_train.initRead((String)l_file_train_items.get(fid), (String)l_file_train_rules.get(fid),tbl_sent_selected);
        DiskHyperGraph dhg_orc =null;
        BufferedReader t_reader_orc =null;
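        //the oracle is read either as a hypergraph (when oracle rules files are given) or as a plain 1-best string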
        if(l_file_orc_rules!=null){
          dhg_orc = new DiskHyperGraph(symbolTbl, ngramStateID, saveModelCosts, null);
          dhg_orc.initRead((String)l_file_orc_items.get(fid), (String)l_file_orc_rules.get(fid), tbl_sent_selected);
        }else{
          t_reader_orc = FileUtilityOld.getReadFileStream((String)l_file_orc_items.get(fid),"UTF-8");
        }
         
        int total_num_sent = Integer.parseInt((String)l_num_sents.get(fid));
        for(int sent_id=0; sent_id < total_num_sent; sent_id++){
          System.out.println("#Process sentence " + g_sent_id);
          HyperGraph hg_train = dhg_train.readHyperGraph();
          HyperGraph hg_orc =null;
          String hyp_oracle =null;
          if(l_file_orc_rules!=null)
            hg_orc = dhg_orc.readHyperGraph();
          else
            hyp_oracle = FileUtilityOld.readLineLzf(t_reader_orc);
          if(hg_train!=null){//sent is not skipped
            if(l_file_orc_rules!=null)
              hgdl.processOneSent(hg_train, hg_orc, null, l_feat_templates, l_feat_templates_nobaseline);
            else
              hgdl.processOneSent(hg_train, hyp_oracle, null, l_feat_templates, l_feat_templates_nobaseline);
          }
          g_sent_id++;
        }
      }
   
      if(usingCRF){
        hgdl.update_model(true);//apply the accumulated update at the end of the pass
        FileUtilityOld.printHashTbl(optimizer.getSumModel(), f_model_out_prefix+".crf." + loop_id, false, false);
      }else{//perceptron: dump both the sum model and the averaged model
        hgdl.update_model(true);
        ((DefaultPerceptron)optimizer).force_update_avg_model();
        FileUtilityOld.printHashTbl(optimizer.getSumModel(), f_model_out_prefix+".sum." + loop_id, false, false);
        FileUtilityOld.printHashTbl(optimizer.getAvgModel(), f_model_out_prefix+".avg." + loop_id, false, true);
      }
      System.out.println("Time cost: " + ((System.currentTimeMillis()-start_time)/1000));
    }
  }
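Distilled from the example above, the setup step looks like the following minimal sketch. The class names and the DefaultPerceptron constructor signature are taken from the snippet; the parameter values and the wrapper class PerceptronSetupSketch are illustrative assumptions, not part of the original source.

import java.util.HashMap;
import java.util.HashSet;

import joshua.discriminative.training.learning_algorithm.DefaultPerceptron;
import joshua.discriminative.training.learning_algorithm.GradientBasedOptimizer;

public class PerceptronSetupSketch {
  public static void main(String[] args) {
    //training parameters: illustrative values, except initGain and sigma,
    //which match the examples on this page
    int trainSize = 1000;     //assumed number of training sentences
    int batchUpdateSize = 1;  //assumed: update after every sentence
    int convergePass = 1;     //assumed number of passes before cooling starts
    double initGain = 0.1;
    double sigma = 0.5;
    boolean isMinimizeScore = true;

    //the perceptron keeps two tables: the raw summed updates and the
    //running average that is dumped as the ".avg" model above
    HashMap<String,Double> sumModel = new HashMap<String,Double>();
    HashMap<String,Double> avgModel = new HashMap<String,Double>();

    GradientBasedOptimizer optimizer = new DefaultPerceptron(
        sumModel, avgModel, trainSize, batchUpdateSize,
        convergePass, initGain, sigma, isMinimizeScore);

    System.out.println("initial model size: " + optimizer.getSumModel().size());
  }
}

From here the training loop proceeds exactly as above: call processOneSent(...) on the learner once per sentence, then update_model(true) and dump the models once per pass. The second example below applies the same optimizer to n-best training.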


    double initGain = 0.1;
    double sigma = 0.5;
    boolean isMinimizeScore = false;
   
    //setup optimizer
    GradientBasedOptimizer optimizer = null;
    if(usingCRF){
      HashMap<String,Double> crfModel = new HashMap<String,Double>();
      if(initModelFile!=null){
        DiscriminativeSupport.loadModel(initModelFile, crfModel, null);
      }else{
        System.out.println("In crf, must specify feature set");
        System.exit(0);
      }
      optimizer = new DefaultCRF(crfModel, trainSize, batchUpdateSize, convergePass, initGain, sigma, isMinimizeScore);
      optimizer.initModel(0, 0);//TODO optimal initial parameters
      ndl = new NBESTDiscriminativeLearner(optimizer, new HashSet<String>(crfModel.keySet()));
      ndl.resetBaselineFeat();//add and init baseline feature
    }else{//perceptron
      HashMap<String,Double> perceptronSumModel = new HashMap<String,Double>();
      HashMap<String,Double> perceptronAvgModel = new HashMap<String,Double>();
      HashMap<String,Double> perceptronModel = new HashMap<String,Double>();
      if(initModelFile!=null){
        DiscriminativeSupport.loadModel(initModelFile, perceptronModel, null);
        perceptronModel.put(baselineFeatName, 1.0);
      }else{
        System.out.println("In perceptron, should specify feature set");       
      }
      optimizer = new DefaultPerceptron(perceptronSumModel, perceptronAvgModel, trainSize, batchUpdateSize, convergePass, initGain, sigma, isMinimizeScore);
      ndl = new NBESTDiscriminativeLearner(optimizer, new HashSet<String>(perceptronModel.keySet()));
      ndl.resetBaselineFeat();
    }

    //TODO
    ndl.optimizer.set_no_cooling();//disable cooling of the learning rate
    //ndl.p_optimizer.set_no_regularization();
     
    for(int loop_id=0; loop_id<max_loop; loop_id++){
      System.out.println("###################################Loop " + loop_id);
      for(int fid=0; fid < l_file_train_nbest.size(); fid++){
        System.out.println("#######Process file id " + fid);
        BufferedReader t_reader_nbest = FileUtilityOld.getReadFileStream((String)l_file_train_nbest.get(fid),"UTF-8");
        BufferedReader t_reader_orc = FileUtilityOld.getReadFileStream((String)l_file_orc.get(fid),"UTF-8");
        String line=null;
        int old_sent_id=-1;
        ArrayList<String> nbest = new ArrayList<String>();
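        //accumulate hypotheses line by line; when the sentence id changes, train on the completed n-best list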
        while((line=FileUtilityOld.readLineLzf(t_reader_nbest))!=null){
          String[] fds = line.split("\\s+\\|{3}\\s+");//fields are separated by " ||| "; fds[0] is the sentence id
          int new_sent_id = Integer.parseInt(fds[0]);
          if(old_sent_id!=-1 && old_sent_id!=new_sent_id){           
            String hyp_oracle = FileUtilityOld.readLineLzf(t_reader_orc);
            if(tbl_sent_selected.containsKey(old_sent_id)){
              System.out.println("#Process sentence " + old_sent_id);
              ndl.processOneSent( nbest, hyp_oracle, null);
            }else
              System.out.println("#Skip sentence " + old_sent_id);
            nbest.clear();
            //Support.print_hash_tbl(perceptron.g_tbl_sum_model);//debug
          }
          old_sent_id = new_sent_id;
          nbest.add(line);
        }
        //process the last accumulated n-best list (no sentence-id change triggers it)
        String hyp_oracle = FileUtilityOld.readLineLzf(t_reader_orc);
        if(tbl_sent_selected.containsKey(old_sent_id)){
          System.out.println("#Process sentence " + old_sent_id);
          ndl.processOneSent( nbest, hyp_oracle, null);
        }else
          System.out.println("#Skip sentence " + old_sent_id);
        nbest.clear();
       
        FileUtilityOld.closeReadFile(t_reader_nbest);
        FileUtilityOld.closeReadFile(t_reader_orc);
      }
     
      if(usingCRF){
        ndl.updateModel(true);//apply the accumulated update at the end of the pass
        FileUtilityOld.printHashTbl(optimizer.getSumModel(), f_model_out_prefix+".crf." + loop_id, false, false);
      }else{//perceptron: dump both the sum model and the averaged model
        ndl.updateModel(true);
        ((DefaultPerceptron)optimizer).force_update_avg_model();
        FileUtilityOld.printHashTbl(optimizer.getSumModel(), f_model_out_prefix+".sum." + loop_id, false, false);
        FileUtilityOld.printHashTbl(optimizer.getAvgModel(), f_model_out_prefix+".avg." + loop_id, false, true);
      }
      System.out.println("Time cost: " + ((System.currentTimeMillis()-start_time)/1000));
      //clean up     
    }
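The n-best reader above relies on the field separator " ||| " and the regex "\\s+\\|{3}\\s+". The following self-contained illustration shows that split; the sample line and the meaning of the fields beyond the sentence id are assumptions for demonstration only.

import java.util.Arrays;

public class NbestSplitDemo {
  public static void main(String[] args) {
    //hypothetical n-best line: sentence id, hypothesis, then further fields
    String line = "12 ||| der Hund bellt ||| 0.5 -1.2 ||| -3.7";
    String[] fds = line.split("\\s+\\|{3}\\s+");
    int sentId = Integer.parseInt(fds[0]);
    System.out.println("sentence id: " + sentId);     //12
    System.out.println("hypothesis:  " + fds[1]);     //der Hund bellt
    System.out.println("all fields:  " + Arrays.toString(fds));
  }
}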
 
