Package: org.fnlp.ml.types

Examples of org.fnlp.ml.types.InstanceSet


    Pipe fpipe = new StringArray2IndexArray(factory, true);
    //构造转换器组
    Pipe pipe = new SeriesPipes(new Pipe[]{lpipe,fpipe});
   
    //构建训练集
    train = new InstanceSet(pipe, factory);
    SimpleFileReader reader = new SimpleFileReader (path,true);
    train.loadThruStagePipes(reader);
    al.setStopIncrement(true);
   
    //构建测试集
    test = new InstanceSet(pipe, factory);   
    reader = new SimpleFileReader (path,true);
    test.loadThruStagePipes(reader)

    System.out.println("Train Number: " + train.size());
    System.out.println("Test Number: " + test.size());
View Full Code Here


    // Toy dataset: five 3-dimensional instances labeled Y or N.
    float data[][] = {{1,1,-1},{0,0,0},{1,1,1},{1,1,1},{1,-1,1}};
    String target[] = {"Y","Y","N","N","N"};
   
   
    // Build the training set.
    trainset = new InstanceSet(factory);
   
    // NOTE(review): snippet is truncated here by the source page — the loop
    // body continues past this fragment.
    for(int i=0;i<data.length;i++){
      // Wrap each dense row as a sparse vector; map the label string to its index.
      ISparseVector sv = new HashSparseVector(data[i],true);
      int l = lf.lookupIndex(target[i]);
      Instance inst = new Instance(sv,l);
View Full Code Here

    Pipe targetpp = new Target2Label(af.DefaultLabelAlphabet());   
   
    //建立pipe组合
    SeriesPipes pp = new SeriesPipes(new Pipe[]{ngrampp,targetpp,indexpp});
   
    InstanceSet instset = new InstanceSet(pp,af);
   
    //用不同的Reader读取相应格式的文件
    Reader reader = new FileReader(trainDataPath,"UTF-8",".data");
   
    //读入数据,并进行数据处理
    instset.loadThruStagePipes(reader);
       
    float percent = 0.8f;
   
    //将数据集分为训练是和测试集
    InstanceSet[] splitsets = instset.split(percent);
   
    InstanceSet trainset = splitsets[0];
    InstanceSet testset = splitsets[1]
   
    /**
     * 建立分类器
     */   
    OnlineTrainer trainer = new OnlineTrainer(af);
    Linear pclassifier = trainer.train(trainset);
    pp.removeTargetPipe();
    pclassifier.setPipe(pp);
    af.setStopIncrement(true);
   
    //将分类器保存到模型文件
    pclassifier.saveTo(modelFile)
    pclassifier = null;
   
    //从模型文件读入分类器
    Linear cl =Linear.loadFrom(modelFile);
   
    //性能评测
    Evaluation eval = new Evaluation(testset);
    eval.eval(cl,1);

    /**
     * 测试
     */
    System.out.println("类别 : 文本内容");
    System.out.println("===================");
    for(int i=0;i<testset.size();i++){
      Instance data = testset.getInstance(i);
     
      Integer gold = (Integer) data.getTarget();
      String pred_label = cl.getStringLabel(data);
      String gold_label = cl.getLabel(gold);
     
      if(pred_label.equals(gold_label))
        System.out.println(pred_label+" : "+testset.getInstance(i).getSource());
      else
        System.err.println(gold_label+"->"+pred_label+" : "+testset.getInstance(i).getSource());
    }
   
   
    /**
     * 分类器使用
 
View Full Code Here

    Pipe targetpp = new Target2Label(af.DefaultLabelAlphabet())
    //建立pipe组合
    SeriesPipes pp = new SeriesPipes(new Pipe[]{ngrampp,targetpp,sparsepp});

    System.out.print("\nReading data......\n");
    InstanceSet instset = new InstanceSet(pp,af)
    Reader reader = new MyDocumentReader(trainDataPath,"gbk");
    instset.loadThruStagePipes(reader);
    System.out.print("..Reading data complete\n");
   
    //将数据集分为训练是和测试集
    System.out.print("Sspliting....");
    float percent = 0.9f;
    InstanceSet[] splitsets = instset.split(percent);
   
    InstanceSet trainset = splitsets[0];
    InstanceSet testset = splitsets[1]
    System.out.print("..Spliting complete!\n");

    System.out.print("Training...\n");
    BayesTrainer trainer=new BayesTrainer();
    BayesClassifier classifier= (BayesClassifier) trainer.train(trainset);
    pp.removeTargetPipe();
    classifier.setPipe(pp);
    af.setStopIncrement(true);
    System.out.print("..Training complete!\n");
    System.out.print("Saving model...\n");
    classifier.saveTo(bayesModelFile)
    classifier = null;
    System.out.print("..Saving model complete!\n");
    /**
     * Test: measure Bayes accuracy on the held-out test set.
     */
    System.out.print("Loading model...\n");
    BayesClassifier bayes;
    bayes =BayesClassifier.loadFrom(bayesModelFile);
//    bayes =classifier;
    System.out.print("..Loading model complete!\n");
   
    System.out.println("Testing Bayes...");
    int count=0;
    for(int i=0;i<testset.size();i++){
      Instance data = testset.getInstance(i);
      Integer gold = (Integer) data.getTarget();
      // Ask for the top-3 predictions; getLabel() returns the best candidate.
      Predict<String> pres=bayes.classify(data, Type.STRING, 3);
      String pred_label=pres.getLabel();
//      String pred_label = bayes.getStringLabel(data);
      String gold_label = bayes.getLabel(gold);
     
      if(pred_label.equals(gold_label)){
        //System.out.println(pred_label+" : "+testsetbayes.getInstance(i).getTempData());
        count++;
      }
      else{
//        System.err.println(gold_label+"->"+pred_label+" : "+testset.getInstance(i).getTempData());
//        for(int j=0;j<3;j++)
//          System.out.println(pres.getLabel(j)+":"+pres.getScore(j));
      }
    }
    int bayesCount=count;
    System.out.println("..Testing Bayes complete!");
    System.out.println("Bayes Precision:"+((float)bayesCount/testset.size())+"("+bayesCount+"/"+testset.size()+")");


    /**
     * Knn
     */
    System.out.print("\nKnn\n");
    //建立字典管理器
    AlphabetFactory af2 = AlphabetFactory.buildFactory();
    //使用n元特征
    ngrampp = new NGram(new int[] {2,3});
    //将字符特征转换成字典索引; 
    sparsepp=new StringArray2SV(af2);
    //将目标值对应的索引号作为类别
    targetpp = new Target2Label(af2.DefaultLabelAlphabet())
    //建立pipe组合
    pp = new SeriesPipes(new Pipe[]{ngrampp,targetpp,sparsepp});

    System.out.print("Init dataset...");
    trainset.setAlphabetFactory(af2)
    trainset.setPipes(pp)
    testset.setAlphabetFactory(af2)
    testset.setPipes(pp);     
    for(int i=0;i<trainset.size();i++){
      Instance inst=trainset.get(i);
      inst.setData(inst.getSource());
      int target_id=Integer.parseInt(inst.getTarget().toString());
      inst.setTarget(af.DefaultLabelAlphabet().lookupString(target_id));
      pp.addThruPipe(inst);
    }   
    for(int i=0;i<testset.size();i++){
      Instance inst=testset.get(i);
      inst.setData(inst.getSource());
      int target_id=Integer.parseInt(inst.getTarget().toString());
      inst.setTarget(af.DefaultLabelAlphabet().lookupString(target_id));
      pp.addThruPipe(inst);
    }

    System.out.print("complete!\n");
    System.out.print("Training Knn...\n");
    SparseVectorSimilarity sim=new SparseVectorSimilarity();
    pp.removeTargetPipe();
    KNNClassifier knn=new KNNClassifier(trainset, pp, sim, af2, 7)
    af2.setStopIncrement(true)
    System.out.print("..Training compelte!\n");
    System.out.print("Saving model...\n");
    knn.saveTo(knnModelFile)
    knn = null;
    System.out.print("..Saving model compelte!\n");

   
    System.out.print("Loading model...\n");
    knn =KNNClassifier.loadFrom(knnModelFile);
    System.out.print("..Loading model compelte!\n");
    System.out.println("Testing Knn...\n");
    count=0;
    for(int i=0;i<testset.size();i++){
      Instance data = testset.getInstance(i);
      Integer gold = (Integer) data.getTarget();
      // Ask for the top-3 predictions; getLabel() returns the best candidate.
      Predict<String> pres=(Predict<String>) knn.classify(data, Type.STRING, 3);
      String pred_label=pres.getLabel();
      String gold_label = knn.getLabel(gold);
     
      if(pred_label.equals(gold_label)){
        //System.out.println(pred_label+" : "+testsetknn.getInstance(i).getTempData());
        count++;
      }
      else{
//        System.err.println(gold_label+"->"+pred_label+" : "+testset.getInstance(i).getTempData());
//        for(int j=0;j<3;j++)
//          System.out.println(pres.getLabel(j)+":"+pres.getScore(j));
      }
    }
    int knnCount=count;
    System.out.println("..Testing Knn Complete");
    // Re-print Bayes precision alongside Knn for side-by-side comparison.
    System.out.println("Bayes Precision:"+((float)bayesCount/testset.size())+"("+bayesCount+"/"+testset.size()+")");
    System.out.println("Knn Precision:"+((float)knnCount/testset.size())+"("+knnCount+"/"+testset.size()+")");
   
    //建立字典管理器
    AlphabetFactory af3 = AlphabetFactory.buildFactory();
    //使用n元特征
    ngrampp = new NGram(new int[] {2,3 });
    //将字符特征转换成字典索引
    Pipe indexpp = new StringArray2IndexArray(af3);
    //将目标值对应的索引号作为类别
    targetpp = new Target2Label(af3.DefaultLabelAlphabet());   
   
    //建立pipe组合
    pp = new SeriesPipes(new Pipe[]{ngrampp,targetpp,indexpp});
   
    trainset.setAlphabetFactory(af3)
    trainset.setPipes(pp)
    testset.setAlphabetFactory(af3)
    testset.setPipes(pp);
    for(int i=0;i<trainset.size();i++){
      Instance inst=trainset.get(i);
      inst.setData(inst.getSource());
      int target_id=Integer.parseInt(inst.getTarget().toString());
      inst.setTarget(af.DefaultLabelAlphabet().lookupString(target_id));
      pp.addThruPipe(inst);
    }   
    for(int i=0;i<testset.size();i++){
      Instance inst=testset.get(i);
      inst.setData(inst.getSource());
      int target_id=Integer.parseInt(inst.getTarget().toString());
      inst.setTarget(af.DefaultLabelAlphabet().lookupString(target_id));
      pp.addThruPipe(inst);
    }     
   
    /**
     * 建立分类器
     */   
    OnlineTrainer trainer3 = new OnlineTrainer(af3);
    Linear pclassifier = trainer3.train(trainset);
    pp.removeTargetPipe();
    pclassifier.setPipe(pp);
    af.setStopIncrement(true);
   
    //将分类器保存到模型文件
    pclassifier.saveTo(linearModelFile)
    pclassifier = null;
   
    //从模型文件读入分类器
    Linear cl =Linear.loadFrom(linearModelFile);
   
    //性能评测
    Evaluation eval = new Evaluation(testset);
    eval.eval(cl,1);
    /**
     * 测试
     */
    System.out.println("类别 : 文本内容");
    System.out.println("===================");
    count=0;
    for(int i=0;i<testset.size();i++){
      Instance data = testset.getInstance(i);
     
      Integer gold = (Integer) data.getTarget();
      String pred_label = cl.getStringLabel(data);
      String gold_label = cl.getLabel(gold);
     
      if(pred_label.equals(gold_label)){
        //System.out.println(pred_label+" : "+testsetliner.getInstance(i).getSource());
        count++;
      }
      else{
//        System.err.println(gold_label+"->"+pred_label+" : "+testset.getInstance(i).getTempData());
      }
    }
    int linearCount=count;
    System.out.println("结果");
    System.out.println("labelSize: "+af.getLabelSize());
    System.out.println("instsetSize: "+instset.size());
    System.out.println("Bayes Precision:"+((float)bayesCount/testset.size())+"("+bayesCount+"/"+testset.size()+")");
    System.out.println("Knn Precision:"+((float)knnCount/testset.size())+"("+knnCount+"/"+testset.size()+")");
    System.out.println("Linear Precision:"+((float)linearCount/testset.size())+"("+linearCount+"/"+testset.size()+")")
  }
View Full Code Here

    System.out.print("Loading training data ...");
    long beginTime = System.currentTimeMillis();

    Pipe pipe = createProcessor(false);
    InstanceSet trainSet = new InstanceSet(pipe, factory);

    LabelAlphabet labels = factory.DefaultLabelAlphabet();
    IFeatureAlphabet features = factory.DefaultFeatureAlphabet();

    // Training set: load through the staged pipes.
    trainSet.loadThruStagePipes(new SequenceReader(train,true, "utf8"));

    long endTime = System.currentTimeMillis();
    System.out.println(" done!");
    System.out
    .println("Time escape: " + (endTime - beginTime) / 1000 + "s");
    System.out.println();

    // Report counts.
    System.out.println("Training Number: " + trainSet.size());

    System.out.println("Label Number: " + labels.size()); // number of labels
    System.out.println("Feature Number: " + features.size()); // number of features

    // Freeze the feature/label alphabets so test data adds no new entries.
    features.setStopIncrement(true);
    labels.setStopIncrement(true);

    InstanceSet testSet = null;
    // /////////////////
    if (testfile != null) {

      Pipe tpipe;
      if (false) {// branch for unlabeled test data — dead as written (hard-coded false)
        tpipe = new SeriesPipes(new Pipe[] { featurePipe });
      } else {
        tpipe = pipe;
      }

      // Test set
      testSet = new InstanceSet(tpipe);

      testSet.loadThruStagePipes(new SequenceReader(testfile, true, "utf8"));
      System.out.println("Test Number: " + testSet.size()); // number of instances
    }

    /**
     *
     * 更新参数的准则
 
View Full Code Here

    long starttime = System.currentTimeMillis();
    // Extract features from the samples via the Pipe.
    Pipe pipe = createProcessor(true);

    // Test set
    InstanceSet testSet = new InstanceSet(pipe);

    testSet.loadThruStagePipes(new SequenceReader(testfile, true, "utf8"));
    System.out.println("Test Number: " + testSet.size()); // number of instances

    long featuretime = System.currentTimeMillis();

    boolean acc = true;
    double error = 0;
    int senError = 0;
    int len = 0;
    boolean hasENG = false;
    int ENG_all = 0, ENG_right = 0;
    Loss loss = new HammingLoss();

    String[][] labelsSet = new String[testSet.size()][];
    String[][] targetSet = new String[testSet.size()][];
    LabelAlphabet labels = cl.getAlphabetFactory().buildLabelAlphabet(
        "labels");
    for (int i = 0; i < testSet.size(); i++) {
      Instance carrier = testSet.get(i);
      int[] pred = (int[]) cl.classify(carrier).getLabel(0);
      if (acc) {
        len += pred.length;
        // Hamming loss counts per-token disagreements with the gold target.
        double e = loss.calc(carrier.getTarget(), pred);
        error += e;
        if(e != 0)
          senError++;
        // Evaluate mixed Chinese/English corpora (only when hasENG is set).
        if(hasENG) {
          String[][] origin = (String[][])carrier.getSource();
          int[] target = (int[])carrier.getTarget();
          for(int j = 0; j < target.length; j++) {
            if(origin[j][0].contains("ENG")) {
              ENG_all++;
              if(target[j] == pred[j])
                ENG_right++;
            }
          }
        }
      }
      labelsSet[i] = labels.lookupString(pred);
      targetSet[i] = labels.lookupString((int[])carrier.getTarget());
    }

    long endtime = System.currentTimeMillis();
    System.out.println("totaltime\t" + (endtime - starttime) / 1000.0);
    System.out.println("feature\t" + (featuretime - starttime) / 1000.0);
    System.out.println("predict\t" + (endtime - featuretime) / 1000.0);
   
    if (acc) {
      System.out.println("Test Accuracy:\t" + (1 - error / len));
      System.out.println("Sentence Accuracy:\t" + ((double)(testSet.size() - senError) / testSet.size()));
      if(hasENG)
        System.out.println("ENG Accuracy:\t" + ((double)ENG_right / ENG_all));
    }

    if (output != null) {
View Full Code Here

    // Map target strings to label indices via the default label alphabet.
    Pipe targetpp = new Target2Label(af.DefaultLabelAlphabet());   
   
    // Assemble the pipe chain: n-gram features -> label mapping -> index features.
    SeriesPipes pp = new SeriesPipes(new Pipe[]{ngrampp,targetpp,indexpp});
   
    InstanceSet trainset = new InstanceSet(pp,af);
    InstanceSet testset = new InstanceSet(pp,af);
   
    // Use the Reader matching the on-disk data format.
    Reader reader = new DocumentReader(trainDataPath);
   
    // Load the data and push it through the staged pipes.
    trainset.loadThruStagePipes(reader);
   
    reader = new DocumentReader(testDataPath);
     
    testset.loadThruStagePipes(reader);
   
   
    /**
     * 建立分类器
     */   
 
View Full Code Here

    long starttime = System.currentTimeMillis();
    // 将样本通过Pipe抽取特征
    Pipe pipe = createProcessor(true);

    // 测试集
    InstanceSet testSet = new InstanceSet(pipe);

    testSet.loadThruStagePipes(new SequenceReader(testfile, true, "utf8"));
    System.out.println("Test Number: " + testSet.size()); // 样本个数

    long featuretime = System.currentTimeMillis();

    boolean acc = true;
    float error = 0;
    int senError = 0;
    int len = 0;
    boolean hasENG = false;
    int ENG_all = 0, ENG_right = 0;
    Loss loss = new HammingLoss();

    String[][] labelsSet = new String[testSet.size()][];
    String[][] targetSet = new String[testSet.size()][];
    LabelAlphabet labels = cl.getAlphabetFactory().buildLabelAlphabet(
        "labels");
    for (int i = 0; i < testSet.size(); i++) {
      Instance carrier = testSet.get(i);
      int[] pred = (int[]) cl.classify(carrier).getLabel(0);
      if (acc) {
        len += pred.length;
        double e = loss.calc(carrier.getTarget(), pred);
        error += e;
        if(e != 0)
          senError++;
        //测试中英混杂语料
        if(hasENG) {
          String[][] origin = (String[][])carrier.getSource();
          int[] target = (int[])carrier.getTarget();
          for(int j = 0; j < target.length; j++) {
            if(origin[j][0].contains("ENG")) {
              ENG_all++;
              if(target[j] == pred[j])
                ENG_right++;
            }
          }
        }
      }
      labelsSet[i] = labels.lookupString(pred);
      targetSet[i] = labels.lookupString((int[])carrier.getTarget());
    }

    long endtime = System.currentTimeMillis();
    System.out.println("totaltime\t" + (endtime - starttime) / 1000.0);
    System.out.println("feature\t" + (featuretime - starttime) / 1000.0);
    System.out.println("predict\t" + (endtime - featuretime) / 1000.0);
   
    if (acc) {
      System.out.println("Test Accuracy:\t" + (1 - error / len));
      System.out.println("Sentence Accuracy:\t" + ((double)(testSet.size() - senError) / testSet.size()));
      if(hasENG)
        System.out.println("ENG Accuracy:\t" + ((double)ENG_right / ENG_all));
    }

    if (output != null) {
View Full Code Here

    //将目标值对应的索引号作为类别
    Pipe targetpp = new Target2Label(af.DefaultLabelAlphabet())
    //建立pipe组合
    SeriesPipes pp = new SeriesPipes(new Pipe[]{ngrampp,targetpp,sparsepp});
   
    InstanceSet instset = new InstanceSet(pp,af);
   
    //用不同的Reader读取相应格式的文件
    Reader reader = new FileReader(trainDataPath,"UTF-8",".data");
   
    //读入数据,并进行数据处理
    instset.loadThruStagePipes(reader);
    //将数据集分为训练是和测试集
    float percent = 0.8f;
    InstanceSet[] splitsets = instset.split(percent);
   
    InstanceSet trainset = splitsets[0];
    InstanceSet testset = splitsets[1]

    /**
     * 测试
     */
    System.out.println("类别 : 文本内容");
    System.out.println("===================");
    for(int i=0;i<testset.size();i++){
      Instance data = testset.getInstance(i);
     
      Integer gold = (Integer) data.getTarget();
      Predict<String> pres=bayes.classify(data, Type.STRING, 3);
      String pred_label=pres.getLabel();
//      String pred_label = bayes.getStringLabel(data);
      String gold_label = bayes.getLabel(gold);
     
      if(pred_label.equals(gold_label))
        System.out.println(pred_label+" : "+testset.getInstance(i).getSource());
      else
        System.err.println(gold_label+"->"+pred_label+" : "+testset.getInstance(i).getSource());
      for(int j=0;j<3;j++)
        System.out.println(pres.getLabel(j)+":"+pres.getScore(j));
    }
  }
View Full Code Here

     * Knn
     */
    System.out.print("\nKnn\n");
    System.out.print("\nReading data......\n");
    long time_mark=System.currentTimeMillis();
    InstanceSet instset = new InstanceSet(pp,af)
    Reader reader = new MyDocumentReader(trainDataPath,"gbk");
    instset.loadThruStagePipes(reader);
    System.out.print("..Reading data complete "+(System.currentTimeMillis()-time_mark)+"(ms)\n");
   
    //将数据集分为训练是和测试集
    System.out.print("Sspliting....");
    float percent = 0.9f;
    InstanceSet[] splitsets = instset.split(percent);
   
    InstanceSet trainset = splitsets[0];
    InstanceSet testset = splitsets[1]
    System.out.print("..Spliting complete!\n");
   
    System.out.print("Training Knn...\n");
    time_mark=System.currentTimeMillis();
    SparseVectorSimilarity sim=new SparseVectorSimilarity();
    pp.removeTargetPipe();
    KNNClassifier knn=new KNNClassifier(trainset, pp, sim, af, 9)
    af.setStopIncrement(true)
   
    ItemFrequency tf=new ItemFrequency(trainset);
    FeatureSelect fs=new FeatureSelect(tf.getFeatureSize());
    long time_train=System.currentTimeMillis()-time_mark;
   
    System.out.print("..Training compelte!\n");
    System.out.print("Saving model...\n");
    knn.saveTo(knnModelFile)
    knn = null;
    System.out.print("..Saving model compelte!\n");

   
    System.out.print("Loading model...\n");
    knn =KNNClassifier.loadFrom(knnModelFile);
    System.out.print("..Loading model compelte!\n");
    System.out.println("Testing Knn...\n");
    int count=0;
    // Baseline run: chi-square feature selection keeping 10% of features.
    fs.fS_CS(tf, 0.1f);
    knn.setFs(fs);
    for(int i=0;i<testset.size();i++){
      Instance data = testset.getInstance(i);
      Integer gold = (Integer) data.getTarget();
      // Ask for the top-3 predictions; getLabel() returns the best candidate.
      Predict<String> pres=(Predict<String>) knn.classify(data, Type.STRING, 3);
      String pred_label=pres.getLabel();
      String gold_label = knn.getLabel(gold);
     
      if(pred_label.equals(gold_label)){
        //System.out.println(pred_label+" : "+testsetknn.getInstance(i).getTempData());
        count++;
      }
      else{
//        System.err.println(gold_label+"->"+pred_label+" : "+testset.getInstance(i).getTempData());
//        for(int j=0;j<3;j++)
//          System.out.println(pres.getLabel(j)+":"+pres.getScore(j));
      }
    }
    int knnCount=count;
    System.out.println("..Testing Knn Complete");
    System.out.println("Knn Precision:"+((float)knnCount/testset.size())+"("+knnCount+"/"+testset.size()+")");
    knn.noFeatureSelection();
    int flag=0;
    long time_sum=0,time_times=0;
    float[] percents_cs=new float[]{1.0f,0.9f,0.8f,0.7f,0.5f,0.3f,0.2f,0.1f};
    int[] counts_cs=new int[10];
    for(int test=0;test<percents_cs.length;test++){
      long time_st=System.currentTimeMillis();
      System.out.println("Testing Bayes"+percents_cs[test]+"...");
      if(test!=0){
        fs.fS_CS(tf, percents_cs[test]);
        knn.setFs(fs);
      }
      count=0;
      for(int i=0;i<testset.size();i++){
        Instance data = testset.getInstance(i);
        Integer gold = (Integer) data.getTarget();
        Predict<String> pres=(Predict<String>)knn.classify(data, Type.STRING, 3);
        String pred_label=pres.getLabel();
        String gold_label = knn.getLabel(gold);
       
        if(pred_label.equals(gold_label)){
          count++;
        }
        else{
        }
      }
      counts_cs[test]=count;
      long time_ed=System.currentTimeMillis();
      time_sum+=time_ed-time_st;
      time_times++;
      System.out.println("Knn Precision("+percents_cs[test]+"):"
      +((float)count/testset.size())+"("+count+"/"+testset.size()+")"+"  "+(time_ed-time_st)+"ms");
    }
   
    knn.noFeatureSelection();
    float[] percents_csmax=new float[]{1.0f,0.9f,0.8f,0.7f,0.5f,0.3f,0.2f,0.1f};
    int[] counts_csmax=new int[10];
    for(int test=0;test<percents_csmax.length;test++){
      long time_st=System.currentTimeMillis();
      System.out.println("Testing Bayes"+percents_csmax[test]+"...");
      if(test!=0){
        fs.fS_CS_Max(tf, percents_cs[test]);
        knn.setFs(fs);
      }
      count=0;
      for(int i=0;i<testset.size();i++){
        Instance data = testset.getInstance(i);
        Integer gold = (Integer) data.getTarget();
        Predict<String> pres=(Predict<String>)knn.classify(data, Type.STRING, 3);
        String pred_label=pres.getLabel();
        String gold_label = knn.getLabel(gold);
       
        if(pred_label.equals(gold_label)){
          count++;
        }
        else{
        }
      }
      counts_csmax[test]=count;
      long time_ed=System.currentTimeMillis();
      time_sum+=time_ed-time_st;
      time_times++;
      System.out.println("Knn Precision("+percents_csmax[test]+"):"
      +((float)count/testset.size())+"("+count+"/"+testset.size()+")"+"  "+(time_ed-time_st)+"ms");
    }
    knn.noFeatureSelection();
    float[] percents_ig=new float[]{1.0f,0.9f,0.8f,0.7f,0.5f,0.3f,0.2f,0.1f};
    int[] counts_ig=new int[10];
    for(int test=0;test<percents_ig.length;test++){
      long time_st=System.currentTimeMillis();
      System.out.println("Testing Bayes"+percents_ig[test]+"...");
      if(test!=0){
        fs.fS_IG(tf, percents_cs[test]);
        knn.setFs(fs);
      }
      count=0;
      for(int i=0;i<testset.size();i++){
        Instance data = testset.getInstance(i);
        Integer gold = (Integer) data.getTarget();
        Predict<String> pres=(Predict<String>)knn.classify(data, Type.STRING, 3);
        String pred_label=pres.getLabel();
        String gold_label = knn.getLabel(gold);
       
        if(pred_label.equals(gold_label)){
          count++;
        }
        else{
        }
      }
      counts_ig[test]=count;

      long time_ed=System.currentTimeMillis();
      time_sum+=time_ed-time_st;
      time_times++;
      System.out.println("Knn Precision("+percents_ig[test]+"):"
      +((float)count/testset.size())+"("+count+"/"+testset.size()+")"+"  "+(time_ed-time_st)+"ms");
    }
   
    System.out.println("..Testing Bayes complete!");
    for(int i=0;i<percents_cs.length;i++)
      System.out.println("Knn Precision CS("+percents_cs[i]+"):"
    +((float)counts_cs[i]/testset.size())+"("+counts_cs[i]+"/"+testset.size()+")");
   
    for(int i=0;i<percents_csmax.length;i++)
      System.out.println("Knn Precision CS_Max("+percents_csmax[i]+"):"
    +((float)counts_csmax[i]/testset.size())+"("+counts_csmax[i]+"/"+testset.size()+")");
   
    for(int i=0;i<percents_ig.length;i++)
      System.out.println("Knn Precision IG("+percents_ig[i]+"):"
    +((float)counts_ig[i]/testset.size())+"("+counts_ig[i]+"/"+testset.size()+")");

    System.out.println("\nTrain time: "+time_train+"(ms) for "
        +trainset.size()+" train instances\n");
    if(time_times>0)
      System.out.println("Ave Test time: "+time_sum/time_times+"(ms) for "
          +testset.size()+" test instances\n");
  }
View Full Code Here

TOP

Related Classes of org.fnlp.ml.types.InstanceSet

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Corporation. Contact coftware#gmail.com.