System.out.print("..Loading model complete!\n");
System.out.println("Testing Bayes...");
// Score the reloaded Bayes model: an instance counts as correct when the
// top-ranked of its (up to) 3 predicted labels equals its gold label.
int count=0;
for(int i=0;i<testset.size();i++){
    Instance instance = testset.getInstance(i);
    // The gold target is an Integer label index; bayes.getLabel turns it into the label string.
    Integer goldIndex = (Integer) instance.getTarget();
    Predict<String> topPredictions = bayes.classify(instance, Type.STRING, 3);
    String predLabel = topPredictions.getLabel();
    String goldLabel = bayes.getLabel(goldIndex);
    if(predLabel.equals(goldLabel)){
        count++;
    }
}
int bayesCount=count;
System.out.println("..Testing Bayes complete!");
// The figure printed as "Precision" is overall accuracy: correct / total test instances.
System.out.println("Bayes Precision:"+((float)bayesCount/testset.size())+"("+bayesCount+"/"+testset.size()+")");
/**
 * Knn
 */
System.out.print("\nKnn\n");
// Fresh dictionary manager so the Knn feature space is independent of the Bayes one.
AlphabetFactory af2 = AlphabetFactory.buildFactory();
// Use n-gram features with n = 2 and n = 3.
ngrampp = new NGram(new int[] {2,3});
// Convert the character features into dictionary indices.
sparsepp=new StringArray2SV(af2);
// Use the target's index in af2's label alphabet as the class.
targetpp = new Target2Label(af2.DefaultLabelAlphabet());
// Chain the pipes: n-grams, then target mapping, then feature indexing.
pp = new SeriesPipes(new Pipe[]{ngrampp,targetpp,sparsepp});
System.out.print("Init dataset...");
trainset.setAlphabetFactory(af2);
trainset.setPipes(pp);
testset.setAlphabetFactory(af2);
testset.setPipes(pp);
// Re-process every instance from its raw source. The current targets are
// integer label indices (presumably from the Bayes factory `af`); decode them
// back to label strings so Target2Label can re-index them against af2.
for(int i=0;i<trainset.size();i++){
    Instance instance=trainset.get(i);
    instance.setData(instance.getSource());
    int labelId=Integer.parseInt(instance.getTarget().toString());
    String labelName=af.DefaultLabelAlphabet().lookupString(labelId);
    instance.setTarget(labelName);
    pp.addThruPipe(instance);
}
for(int i=0;i<testset.size();i++){
    Instance instance=testset.get(i);
    instance.setData(instance.getSource());
    int labelId=Integer.parseInt(instance.getTarget().toString());
    String labelName=af.DefaultLabelAlphabet().lookupString(labelId);
    instance.setTarget(labelName);
    pp.addThruPipe(instance);
}
System.out.print("complete!\n");
System.out.print("Training Knn...\n");
SparseVectorSimilarity sim=new SparseVectorSimilarity();
// Remove the target pipe before the pipe is handed to the classifier.
pp.removeTargetPipe();
// NOTE(review): the trailing 7 is presumably k (number of neighbours) -- confirm against KNNClassifier.
KNNClassifier knn=new KNNClassifier(trainset, pp, sim, af2, 7);
// Freeze af2 now that the Knn feature space is built (presumed meaning of setStopIncrement).
af2.setStopIncrement(true);
System.out.print("..Training complete!\n");
System.out.print("Saving model...\n");
knn.saveTo(knnModelFile);
// Discard the in-memory model and read it back, so later use goes through knnModelFile.
knn = null;
System.out.print("..Saving model complete!\n");
System.out.print("Loading model...\n");
knn =KNNClassifier.loadFrom(knnModelFile);
System.out.print("..Loading model complete!\n");
System.out.println("Testing Knn...\n");
// Reuse the shared counter: number of test instances whose top-ranked Knn
// label equals the gold label.
count=0;
for(int i=0;i<testset.size();i++){
    Instance instance = testset.getInstance(i);
    // Gold target is an Integer label index; knn.getLabel resolves it to a string.
    Integer goldIndex = (Integer) instance.getTarget();
    Predict<String> topPredictions=(Predict<String>) knn.classify(instance, Type.STRING, 3);
    String predLabel = topPredictions.getLabel();
    String goldLabel = knn.getLabel(goldIndex);
    if(predLabel.equals(goldLabel)){
        count++;
    }
}
int knnCount=count;
System.out.println("..Testing Knn Complete");
// Side-by-side summary; "Precision" here is overall accuracy (correct / total).
System.out.println("Bayes Precision:"+((float)bayesCount/testset.size())+"("+bayesCount+"/"+testset.size()+")");
System.out.println("Knn Precision:"+((float)knnCount/testset.size())+"("+knnCount+"/"+testset.size()+")");
// Build a fresh dictionary manager for the linear classifier's feature space.
AlphabetFactory af3 = AlphabetFactory.buildFactory();
// Use n-gram features with n = 2 and n = 3.
ngrampp = new NGram(new int[] {2,3 });
// Convert the character features into dictionary indices.
Pipe indexpp = new StringArray2IndexArray(af3);
// Use the target's index in af3's label alphabet as the class.
targetpp = new Target2Label(af3.DefaultLabelAlphabet());
// Chain the pipes: n-grams, then target mapping, then feature indexing.
pp = new SeriesPipes(new Pipe[]{ngrampp,targetpp,indexpp});
trainset.setAlphabetFactory(af3);
trainset.setPipes(pp);
testset.setAlphabetFactory(af3);
testset.setPipes(pp);
// Re-process every instance from its raw source. At this point the targets
// are label indices produced by the Knn pipeline's
// Target2Label(af2.DefaultLabelAlphabet()), so they must be decoded with af2.
// Decoding with the Bayes factory `af` would attach the wrong label strings
// whenever af and af2 numbered the labels in a different order.
// NOTE(review): testset is piped before af3 is frozen, so test-only n-grams
// may be added to af3's feature alphabet -- confirm this is intended.
for(int i=0;i<trainset.size();i++){
    Instance inst=trainset.get(i);
    inst.setData(inst.getSource());
    int target_id=Integer.parseInt(inst.getTarget().toString());
    inst.setTarget(af2.DefaultLabelAlphabet().lookupString(target_id));
    pp.addThruPipe(inst);
}
for(int i=0;i<testset.size();i++){
    Instance inst=testset.get(i);
    inst.setData(inst.getSource());
    int target_id=Integer.parseInt(inst.getTarget().toString());
    inst.setTarget(af2.DefaultLabelAlphabet().lookupString(target_id));
    pp.addThruPipe(inst);
}
/**
 * Build the classifier.
 */
OnlineTrainer trainer3 = new OnlineTrainer(af3);
Linear pclassifier = trainer3.train(trainset);
// Remove the target pipe before attaching the pipe to the classifier.
pp.removeTargetPipe();
pclassifier.setPipe(pp);
// Freeze af3 -- the factory behind this classifier's pipe -- just as af2 is
// frozen after building the Knn classifier. (Freezing `af` here would leave
// af3 unfrozen.)
af3.setStopIncrement(true);
// Save the classifier to the model file.
pclassifier.saveTo(linearModelFile);
pclassifier = null;
// Read the classifier back from the model file.
Linear cl =Linear.loadFrom(linearModelFile);
// Evaluate performance on the test set.
Evaluation eval = new Evaluation(testset);
eval.eval(cl,1);
/**
 * Testing: reset the shared correct-prediction counter, then print the
 * header ("category : text content") for per-instance output.
 */
count=0;
System.out.println("类别 : 文本内容");
System.out.println("===================");
for(int i=0;i<testset.size();i++){
Instance data = testset.getInstance(i);
Integer gold = (Integer) data.getTarget();
String pred_label = cl.getStringLabel(data);
String gold_label = cl.getLabel(gold);
if(pred_label.equals(gold_label)){
//System.out.println(pred_label+" : "+testsetliner.getInstance(i).getSource());