    double init_gain = 0.1;
    double sigma = 0.5;
    boolean is_minimize_score = true;

    // setup optimizer
    GradientBasedOptimizer optimizer = null;
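
    // Two training modes: DefaultCRF trains over the loaded feature set and exposes a single
    // sum model; DefaultPerceptron additionally maintains a running sum model and an averaged model.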
    if (usingCRF) {
        HashMap crfModel = new HashMap();
        if (f_feature_set != null) {
            DiscriminativeSupport.loadModel(f_feature_set, crfModel, null);
        } else {
            System.out.println("In CRF mode, a feature set must be specified");
            System.exit(1);
        }
        optimizer = new DefaultCRF(crfModel, train_size, batch_update_size, converge_pass, init_gain, sigma, is_minimize_score);
        optimizer.initModel(-1, 1); // TODO: find optimal initial parameters
        hgdl = new HGDiscriminativeLearner(optimizer, new HashSet<String>(crfModel.keySet()));
        hgdl.reset_baseline_feat(); // add and initialize the baseline feature
        System.out.println("CRF sum-model size: " + optimizer.getSumModel().size());
    } else { // perceptron
        HashMap perceptron_sum_model = new HashMap();
        HashMap perceptron_avg_model = new HashMap();
        HashMap perceptronModel = new HashMap();
        if (f_feature_set != null) {
            DiscriminativeSupport.loadModel(f_feature_set, perceptronModel, null);
            perceptronModel.put(baselineFeatName, 1.0);
            System.out.println("feature set size is " + perceptronModel.size());
        } else {
            System.out.println("In perceptron mode, a feature set should be specified");
        }
        optimizer = new DefaultPerceptron(perceptron_sum_model, perceptron_avg_model, train_size, batch_update_size, converge_pass, init_gain, sigma, is_minimize_score);
        hgdl = new HGDiscriminativeLearner(optimizer, new HashSet<String>(perceptronModel.keySet()));
        hgdl.reset_baseline_feat(); // add and initialize the baseline feature
    }

    // ##### begin training
    int g_sent_id = 0;
    for (int loop_id = 0; loop_id < max_loop; loop_id++) {
        System.out.println("################################### Loop " + loop_id);
        for (int fid = 0; fid < l_file_train_items.size(); fid++) {
            System.out.println("############ Process file id " + fid);
            DiskHyperGraph dhg_train = new DiskHyperGraph(symbolTbl, ngramStateID, saveModelCosts, null);
            dhg_train.initRead((String)l_file_train_items.get(fid), (String)l_file_train_rules.get(fid), tbl_sent_selected);
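
            // The oracle can come either from oracle hypergraph files or, if no oracle rule
            // files are given, from a plain-text file with one oracle translation per sentence.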
            DiskHyperGraph dhg_orc = null;
            BufferedReader t_reader_orc = null;
            if (l_file_orc_rules != null) {
                dhg_orc = new DiskHyperGraph(symbolTbl, ngramStateID, saveModelCosts, null);
                dhg_orc.initRead((String)l_file_orc_items.get(fid), (String)l_file_orc_rules.get(fid), tbl_sent_selected);
            } else {
                t_reader_orc = FileUtilityOld.getReadFileStream((String)l_file_orc_items.get(fid), "UTF-8");
            }
            int total_num_sent = Integer.parseInt((String)l_num_sents.get(fid));
            for (int sent_id = 0; sent_id < total_num_sent; sent_id++) {
                System.out.println("# Process sentence " + g_sent_id);
                HyperGraph hg_train = dhg_train.readHyperGraph();
                HyperGraph hg_orc = null;
                String hyp_oracle = null;
                if (l_file_orc_rules != null) {
                    hg_orc = dhg_orc.readHyperGraph();
                } else {
                    hyp_oracle = FileUtilityOld.readLineLzf(t_reader_orc);
                }
                if (hg_train != null) { // sent is not skipped
                    if (l_file_orc_rules != null) {
                        hgdl.processOneSent(hg_train, hg_orc, null, l_feat_templates, l_feat_templates_nobaseline);
                    } else {
                        hgdl.processOneSent(hg_train, hyp_oracle, null, l_feat_templates, l_feat_templates_nobaseline);
                    }
                }
                g_sent_id++;
            }
        }
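
        // After each pass over the data, update the model and write the model table(s)
        // for this pass to disk.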
        if (usingCRF) {
            hgdl.update_model(true);
            FileUtilityOld.printHashTbl(optimizer.getSumModel(), f_model_out_prefix + ".crf." + loop_id, false, false);
        } else { // perceptron
            hgdl.update_model(true);
            ((DefaultPerceptron)optimizer).force_update_avg_model();
            FileUtilityOld.printHashTbl(optimizer.getSumModel(), f_model_out_prefix + ".sum." + loop_id, false, false);
            FileUtilityOld.printHashTbl(optimizer.getAvgModel(), f_model_out_prefix + ".avg." + loop_id, false, true);
        }
        System.out.println("Time cost (seconds): " + ((System.currentTimeMillis() - start_time) / 1000));
    }
}