Package seekfeel.miners.supervised

Source Code of seekfeel.miners.supervised.LibSVMWrapper

/*
 * LibSVM wrapper used by the SeekFeel supervised miners: trains an RBF C-SVC model,
 * persists it to the file named by the "SVMModelFile" property, and classifies samples with it.
 */
package seekfeel.miners.supervised;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.TreeMap;
import java.util.logging.Level;
import java.util.logging.Logger;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;
import seekfeel.utilities.PropertiesGetter;

/**
*
* @author Ahmed
*/
public class LibSVMWrapper {

    public LibSVMWrapper() {
    }

    public void train(ArrayList<Integer> labels, ArrayList<LinkedHashMap<Integer, Double>> trainingSamples) {
        //
        // Training Steps
        //
        // 1- Form The File with SVM format
        // 2- loading the file using dataset loader into a problem object
        // 3- perform grid selection
        // 4- traing using the selected parameters and save into the model
        svm_problem prob = constructProblem(labels, trainingSamples);
        svm_parameter param = new svm_parameter();
        param.svm_type = svm_parameter.C_SVC;
        param.kernel_type = svm_parameter.RBF;
        param.gamma = 0.5;
        param.C = 1.0;
        svm_model theModel = svm.svm_train(prob, param);
        try {
            svm.svm_save_model(PropertiesGetter.getProperty("SVMModelFile"), theModel);
        } catch (IOException ex) {
            Logger.getLogger(LibSVMWrapper.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    public ArrayList<Double> classify(ArrayList<LinkedHashMap<Integer, Double>> trainingSamples) {
        ArrayList<Double> labels = new ArrayList<Double>();
        for (int i = 0; i < trainingSamples.size(); i++) {
            labels.add(classify(trainingSamples.get(i)));
        }
        return labels;
    }

    public double classify(LinkedHashMap<Integer, Double> trainingSample) {
        try {
            svm_model currentModel = svm.svm_load_model(PropertiesGetter.getProperty("SVMModelFile"));
            Iterator<Entry<Integer, Double>> it = trainingSample.entrySet().iterator();
            int m = trainingSample.size();
            svm_node[] x = new svm_node[m];
            Entry<Integer, Double> currentEntry;
            for (int j = 0; j < m; j++) {
                currentEntry = it.next();
                x[j] = new svm_node();
                x[j].index = currentEntry.getKey();
                x[j].value = currentEntry.getValue();
            }
            return svm.svm_predict(currentModel, x);
        } catch (IOException ex) {
            Logger.getLogger(LibSVMWrapper.class.getName()).log(Level.SEVERE, null, ex);
            return -1;
        }
    }

    public static void writeTrainingFile(String trainingFileName, ArrayList<Integer> labels, ArrayList<LinkedHashMap<Integer, Double>> trainingSamples) {
        try {
            BufferedWriter bw = new BufferedWriter(new FileWriter(trainingFileName));
            StringBuilder currentLine;
            LinkedHashMap<Integer, Double> currentFeatures;
            Iterator<Entry<Integer, Double>> it;
            Entry<Integer, Double> currentEntry;
            for (int i = 0; i < labels.size(); i++) {
                currentLine = new StringBuilder();
                currentLine.append(labels.get(i));
                currentFeatures = trainingSamples.get(i);
                it = currentFeatures.entrySet().iterator();
                while (it.hasNext()) {
                    currentEntry = it.next();
                    currentLine.append(" ").append(currentEntry.getKey()).append(":").append(currentEntry.getValue());
                }
                bw.write(currentLine.toString());
                bw.newLine();
            }
        } catch (IOException ex) {
            Logger.getLogger(LibSVMWrapper.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    private svm_problem constructProblem(ArrayList<Integer> labels, ArrayList<LinkedHashMap<Integer, Double>> trainingSamples) {

        double[] y = new double[labels.size()];
        for (int i = 0; i < labels.size(); i++) {
            y[i] = labels.get(i);
        }
        int max_index = 0;
        HashMap<Integer, Double> currentFeatures;
        Iterator<Entry<Integer, Double>> it;
        Entry<Integer, Double> currentEntry;
        int m;
        svm_node[] x;
        svm_node[][] allSamples = new svm_node[labels.size()][];
        for (int i = 0; i < labels.size(); i++) {
            currentFeatures = trainingSamples.get(i);
            it = currentFeatures.entrySet().iterator();
            m = currentFeatures.size();
            x = new svm_node[m];
            for (int j = 0; j < m; j++) {
                currentEntry = it.next();
                x[j] = new svm_node();
                x[j].index = currentEntry.getKey();
                x[j].value = currentEntry.getValue();
            }
            if (m > 0) {
                max_index = Math.max(max_index, x[m - 1].index);
            }
            allSamples[i] = x;
        }
        svm_problem newProblem = new svm_problem();
        newProblem.l = labels.size();
        newProblem.y = y;
        newProblem.x = allSamples;
        return newProblem;
    }
}
TOP

Related Classes of seekfeel.miners.supervised.LibSVMWrapper

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.