/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package seekfeel.miners.supervised;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.TreeMap;
import java.util.logging.Level;
import java.util.logging.Logger;

import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;

import seekfeel.utilities.PropertiesGetter;
/**
 * Thin wrapper around the LibSVM Java API.
 *
 * <p>{@link #train} fits a C-SVC model with an RBF kernel and saves it to the file named by the
 * {@code SVMModelFile} property. The {@code classify} methods load that file and predict labels
 * for sparse feature vectors. A feature vector is a map from feature index to feature value.
 * LibSVM expects the nodes of every vector in ascending index order, so each vector is sorted by
 * index before it is handed to LibSVM or written to a training file.
 *
 * @author Ahmed
 */
public class LibSVMWrapper {

    private static final Logger LOGGER = Logger.getLogger(LibSVMWrapper.class.getName());

    /** Name of the property that holds the path of the persisted model file. */
    private static final String MODEL_FILE_PROPERTY = "SVMModelFile";

    /** Value reported by {@code classify} when no model could be loaded. */
    private static final double CLASSIFICATION_FAILED = -1;

    public LibSVMWrapper() {
    }

    /**
     * Trains a C-SVC model (RBF kernel, C = 1, gamma = 0.5) and saves it to the file named by the
     * {@code SVMModelFile} property. The hyper-parameters are fixed; no grid search is performed.
     * A failure to save the model is logged, not thrown.
     *
     * @param labels class label of each sample
     * @param trainingSamples feature vectors (index to value), index-aligned with {@code labels}
     * @throws IllegalArgumentException if the two lists differ in size or are empty
     */
    public void train(ArrayList<Integer> labels, ArrayList<LinkedHashMap<Integer, Double>> trainingSamples) {
        requireSameSize(labels, trainingSamples);
        if (labels.isEmpty()) {
            // LibSVM fails with an ArrayIndexOutOfBoundsException on an empty problem.
            throw new IllegalArgumentException("Cannot train an SVM model without training samples");
        }
        svm_problem prob = constructProblem(labels, trainingSamples);
        svm_parameter param = createParameters();
        String paramError = svm.svm_check_parameter(prob, param);
        if (paramError != null) {
            // The parameters are constants, so this only fires if createParameters() is wrong.
            throw new IllegalStateException("Invalid SVM parameters: " + paramError);
        }
        svm_model theModel = svm.svm_train(prob, param);
        try {
            svm.svm_save_model(PropertiesGetter.getProperty(MODEL_FILE_PROPERTY), theModel);
        } catch (IOException ex) {
            LOGGER.log(Level.SEVERE, null, ex);
        }
    }

    /**
     * Predicts a label for each sample using the persisted model. The model file is read once for
     * the whole batch.
     *
     * @param trainingSamples feature vectors (index to value) to classify
     * @return one predicted label per sample, in input order; every entry is -1 if the model could
     *     not be loaded
     */
    public ArrayList<Double> classify(ArrayList<LinkedHashMap<Integer, Double>> trainingSamples) {
        ArrayList<Double> labels = new ArrayList<Double>(trainingSamples.size());
        svm_model model = loadModel();
        for (LinkedHashMap<Integer, Double> sample : trainingSamples) {
            labels.add(predict(model, sample));
        }
        return labels;
    }

    /**
     * Predicts a label for a single sample using the persisted model.
     *
     * @param trainingSample feature vector (index to value) to classify
     * @return the predicted label, or -1 if the model could not be loaded. Note that -1 may also be
     *     a genuine class label.
     */
    public double classify(LinkedHashMap<Integer, Double> trainingSample) {
        return predict(loadModel(), trainingSample);
    }

    /**
     * Writes the samples to {@code trainingFileName} in LibSVM's sparse text format. Each sample
     * becomes one line, {@code label index:value index:value ...}, with indices ascending. I/O
     * errors are logged, not thrown. The file is always closed, which also flushes buffered output.
     *
     * @param trainingFileName path of the file to create or overwrite
     * @param labels class label of each sample
     * @param trainingSamples feature vectors (index to value), index-aligned with {@code labels}
     * @throws IllegalArgumentException if the two lists differ in size
     */
    public static void writeTrainingFile(String trainingFileName, ArrayList<Integer> labels, ArrayList<LinkedHashMap<Integer, Double>> trainingSamples) {
        requireSameSize(labels, trainingSamples);
        BufferedWriter bw = null;
        try {
            bw = new BufferedWriter(new FileWriter(trainingFileName));
            for (int i = 0; i < labels.size(); i++) {
                StringBuilder line = new StringBuilder();
                line.append(labels.get(i));
                for (Entry<Integer, Double> feature : sortedByIndex(trainingSamples.get(i)).entrySet()) {
                    line.append(' ').append(feature.getKey()).append(':').append(feature.getValue());
                }
                bw.write(line.toString());
                bw.newLine();
            }
        } catch (IOException ex) {
            LOGGER.log(Level.SEVERE, null, ex);
        } finally {
            if (bw != null) {
                try {
                    bw.close();
                } catch (IOException ex) {
                    // close() flushes, so a failure here can mean lost data: report it.
                    LOGGER.log(Level.SEVERE, null, ex);
                }
            }
        }
    }

    /** Builds the LibSVM problem: one label and one sorted node array per sample. */
    private static svm_problem constructProblem(List<Integer> labels, List<? extends Map<Integer, Double>> trainingSamples) {
        int count = labels.size();
        double[] y = new double[count];
        svm_node[][] x = new svm_node[count][];
        for (int i = 0; i < count; i++) {
            y[i] = labels.get(i);
            x[i] = toNodes(trainingSamples.get(i));
        }
        svm_problem problem = new svm_problem();
        problem.l = count;
        problem.y = y;
        problem.x = x;
        return problem;
    }

    /**
     * Creates the training parameters. svm_parameter has no field initialisers, so every field is
     * set explicitly. The values other than svm_type/kernel_type/gamma/C are the svm-train command
     * line defaults. svm_check_parameter rejects cache_size <= 0 and eps <= 0.
     */
    private static svm_parameter createParameters() {
        svm_parameter param = new svm_parameter();
        param.svm_type = svm_parameter.C_SVC;
        param.kernel_type = svm_parameter.RBF;
        param.gamma = 0.5;
        param.C = 1.0;
        param.degree = 3;
        param.coef0 = 0;
        param.nu = 0.5;
        param.p = 0.1;
        param.cache_size = 100; // MB
        param.eps = 1e-3;
        param.shrinking = 1;
        param.probability = 0;
        param.nr_weight = 0;
        param.weight_label = new int[0];
        param.weight = new double[0];
        return param;
    }

    /** Loads the persisted model, logging and returning null if it cannot be read or parsed. */
    private static svm_model loadModel() {
        try {
            String modelFile = PropertiesGetter.getProperty(MODEL_FILE_PROPERTY);
            svm_model model = svm.svm_load_model(modelFile);
            if (model == null) {
                LOGGER.log(Level.SEVERE, "Could not parse SVM model file {0}", modelFile);
            }
            return model;
        } catch (IOException ex) {
            LOGGER.log(Level.SEVERE, null, ex);
            return null;
        }
    }

    /** Predicts one sample, or returns CLASSIFICATION_FAILED when {@code model} is null. */
    private static double predict(svm_model model, Map<Integer, Double> sample) {
        if (model == null) {
            return CLASSIFICATION_FAILED;
        }
        return svm.svm_predict(model, toNodes(sample));
    }

    /** Converts a feature map to LibSVM nodes in ascending index order. */
    private static svm_node[] toNodes(Map<Integer, Double> features) {
        svm_node[] nodes = new svm_node[features.size()];
        int j = 0;
        for (Entry<Integer, Double> feature : sortedByIndex(features).entrySet()) {
            svm_node node = new svm_node();
            node.index = feature.getKey();
            node.value = feature.getValue();
            nodes[j++] = node;
        }
        return nodes;
    }

    /** Returns a copy of {@code features} ordered by ascending feature index. */
    private static TreeMap<Integer, Double> sortedByIndex(Map<Integer, Double> features) {
        return new TreeMap<Integer, Double>(features);
    }

    /** Rejects label/sample lists whose sizes differ. */
    private static void requireSameSize(List<Integer> labels, List<?> trainingSamples) {
        if (labels.size() != trainingSamples.size()) {
            throw new IllegalArgumentException("labels.size() = " + labels.size()
                    + " but trainingSamples.size() = " + trainingSamples.size());
        }
    }
}