Package fr.lip6.jkernelmachines.example

Source Code of fr.lip6.jkernelmachines.example.CrossValidationExample

/**
    This file is part of JkernelMachines.

    JkernelMachines is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    JkernelMachines is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with JkernelMachines.  If not, see <http://www.gnu.org/licenses/>.

    Copyright David Picard - 2012

*/
package fr.lip6.jkernelmachines.example;

import java.util.List;

import fr.lip6.jkernelmachines.classifier.Classifier;
import fr.lip6.jkernelmachines.classifier.LaSVM;
import fr.lip6.jkernelmachines.classifier.LaSVMI;
import fr.lip6.jkernelmachines.classifier.NystromLSSVM;
import fr.lip6.jkernelmachines.classifier.SDCA;
import fr.lip6.jkernelmachines.classifier.SMOSVM;
import fr.lip6.jkernelmachines.evaluation.AccuracyEvaluator;
import fr.lip6.jkernelmachines.evaluation.RandomSplitCrossValidation;
import fr.lip6.jkernelmachines.io.ArffImporter;
import fr.lip6.jkernelmachines.io.LibSVMImporter;
import fr.lip6.jkernelmachines.kernel.Kernel;
import fr.lip6.jkernelmachines.kernel.typed.DoubleGaussL2;
import fr.lip6.jkernelmachines.kernel.typed.DoubleLinear;
import fr.lip6.jkernelmachines.projection.DoublePCA;
import fr.lip6.jkernelmachines.type.TrainingSample;
import fr.lip6.jkernelmachines.util.DebugPrinter;

/**
* <p>
* This is a more complex example of how to code with JKernelMachines than can
* be used as a stand-alone program.
* </p>
* <p>
* It reads data from an input file in the libsvm format, and performs a cross-validation evaluation.
* Optional parameters include the number of test to perform, the percentage of data to keep from training, the type of kernel, and the svm algorithm.
* Launch without argument to get the following help:<br/>
* <br />
* CrossValidationExample -f file [-p percent] [-n nbtests] [-k kernel] [-a algorithm] [-pca type] [-vvv]
* <ul>
<li>-f file: the data file in libsvm format</li>
<li>-p percent: the percent of data to keep for training</li>
<li>-n nbtests: the number of test to perform during crossvalidation</li>
<li>-k kernel: the type of kernel (linear or gauss, default gauss)</li>
<li>-a algorithm: type of SVM algorithm(lasvm, lasvmi, smo, nlssvm default lasvm)</li>
<li>-pca type: perform a PCA as preprocessing (no, yes, white, default no)</li>
<li>-v: verbose (v few, vv lot, vvv insane, default none)</li>
* </ul>
* </p>
*
* @author David Picard
*
*/
public class CrossValidationExample {

  /**
   * @param args
   */
  public static void main(String[] args) {

    String file = "";
    double percent = 0.8;
    int nbtest = 10;
    Kernel<double[]> kernel = null;
    Classifier<double[]> svm = null;
    int hasPCA = 0;

    // parsing options
    try {
      for (int i = 0; i < args.length; i++) {
        // parsing file
        if (args[i].equalsIgnoreCase("-f")) {
          i++;
          file = args[i];
        }
        // split percent
        else if (args[i].equalsIgnoreCase("-p")) {
          i++;
          percent = Double.parseDouble(args[i]);
        }
        // number of tests
        else if (args[i].equalsIgnoreCase("-n")) {
          i++;
          nbtest = Integer.parseInt(args[i]);
        }
        // kernel type
        else if (args[i].equalsIgnoreCase("-k")) {
          i++;

          if (args[i].equalsIgnoreCase("gauss")) {
            kernel = new DoubleGaussL2();
          } else {
            kernel = new DoubleLinear();
          }
        }
        // algorithm
        else if (args[i].equalsIgnoreCase("-a")) {
          i++;
          if (args[i].equalsIgnoreCase("lasvmi")) {
            svm = new LaSVMI<double[]>(kernel);
          } else if (args[i].equalsIgnoreCase("sdca")) {
            svm = new SDCA<double[]>(kernel);
          } else if (args[i].equalsIgnoreCase("smo")) {
            svm = new SMOSVM<double[]>(kernel);
          } else if (args[i].equalsIgnoreCase("nlssvm")) {
            NystromLSSVM<double[]> nlssvm = new NystromLSSVM<double[]>(kernel);
            nlssvm.setPercent(0.1);
            svm = nlssvm;
           
          else { // default lasvm
            svm = new LaSVM<double[]>(kernel);
          }
        }
        else if (args[i].equalsIgnoreCase("-pca")) {
          i++;
          if(args[i].equalsIgnoreCase("yes")) {
            hasPCA = 1;
          }
          else if(args[i].equalsIgnoreCase("white")) {
            hasPCA = 2;
          }
        }
        // verbose
        else if (args[i].equalsIgnoreCase("-v")) {
          DebugPrinter.DEBUG_LEVEL = 2;
        } else if (args[i].equalsIgnoreCase("-vv")) {
          DebugPrinter.DEBUG_LEVEL = 3;
        } else if (args[i].equalsIgnoreCase("-vvv")) {
          DebugPrinter.DEBUG_LEVEL = 4;
        }
      }

    } catch (Exception e) {
      printHelp();
      System.exit(-1);
    }

    // check option
    if (file.equalsIgnoreCase("") || kernel == null || svm == null) {
      printHelp();
      System.exit(-1);
    }

    // read data
    List<TrainingSample<double[]>> list = null;
    try {
      if(file.endsWith(".arff")) {
        list = ArffImporter.importFromFile(file);
      }
      else {
        list = LibSVMImporter.importFromFile(file);
      }
    } catch (Exception e) {
      System.out.println("Wrong data file");
      printHelp();
      System.exit(-1);
    }
    if (list == null) {
      System.out.println("Wrong data file");
      printHelp();
      System.exit(-1);
    }
    // perform PCA
    if(hasPCA == 1) {
      DoublePCA pca = new DoublePCA();
      pca.train(list);
      list = pca.projectList(list);
    }
    else if(hasPCA == 2) {
      DoublePCA pca = new DoublePCA();
      pca.train(list);
      list = pca.projectList(list, true);
    }

    // initialize CV
    AccuracyEvaluator<double[]> ev = new AccuracyEvaluator<double[]>();
    RandomSplitCrossValidation<double[]> cv = new RandomSplitCrossValidation<double[]>(
        svm, list, ev);
    cv.setNbTest(nbtest);
    cv.setTrainPercent(percent);

    // do cv
    cv.run();

    // print result
    System.out.println("Accuracy: " + cv.getAverageScore() + " +/- "
        + cv.getStdDevScore());
  }

  private static void printHelp() {
    System.out
        .println("CrossValidationExample -f file [-p percent] [-n nbtests] [-k kernel] [-a algorithm] [-vvv]");
    System.out.println("\t-f file: the data file in libsvm format");
    System.out
        .println("\t-p percent: the percent of data to keep for training");
    System.out
        .println("\t-n nbtests: the number of test to perform during crossvalidation");
    System.out
        .println("\t-k kernel: the type of kernel (linear or gauss, default gauss)");
    System.out
        .println("\t-a algorithm: type of SVM algorithm(lasvm, lasvmi, smo, default lasvm)");
    System.out
        .println("\t-pca type: perform a PCA as preprocessing(no, yes, white, default no");
    System.out
        .println("\t-v: verbose (v few, vv lot, vvv insane, default none)");
  }

}
TOP

Related Classes of fr.lip6.jkernelmachines.example.CrossValidationExample

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.