/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package seekfeel.miners.supervised;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map.Entry;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import seekfeel.utilities.PropertiesGetter;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.LibSVM;
import weka.classifiers.functions.SPegasos;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffLoader;
import weka.core.converters.ArffSaver;
/**
 * Wrapper around Weka classifiers for two-class ("positive"/"negative") data.
 * <p>
 * It trains a {@link LibSVM} model and persists it to disk, cross-validates an
 * {@link SPegasos} model, and classifies single samples with the persisted
 * LibSVM model. File locations are read through {@link PropertiesGetter} using
 * the keys {@code TrainingDataSetObject}, {@code SVMModelFile} and
 * {@code NumFeatsFile}.
 * <p>
 * Samples are given as maps from attribute index to attribute value; every
 * attribute that is not present in the map is left at {@code 0.0}.
 *
 * @author Ahmed
 */
public class WekaWrapper {

    /** Value returned by {@link #classify(LinkedHashMap)} when the sample could not be classified. */
    public static final double CLASSIFICATION_FAILED = -2;

    private static final Logger LOGGER = Logger.getLogger(WekaWrapper.class.getName());

    /** Nominal values of the class attribute, in declaration order. */
    private static final String POSITIVE_LABEL = "positive";
    private static final String NEGATIVE_LABEL = "negative";

    /**
     * Trains a LibSVM classifier on the given samples, then writes the model
     * (and the feature count) to disk and saves the training set as an ARFF
     * file at the {@code TrainingDataSetObject} location.
     * <p>
     * Failures while training or saving are logged and printed, not rethrown.
     *
     * @param trainingSamples one map of attribute index to value per sample
     * @param labels          one label per sample, in the same order; {@code 0} is
     *                        stored as "positive", any other value as "negative"
     * @param maxNumFeats     number of feature attributes (the class attribute is added on top)
     */
    public void train(ArrayList<LinkedHashMap<Integer, Double>> trainingSamples, ArrayList<Integer> labels, int maxNumFeats) {
        Instances trainingSet = fillDataSet(trainingSamples, labels, maxNumFeats);
        LibSVM cModel = new LibSVM();
        try {
            cModel.buildClassifier(trainingSet);
            saveTheModel(cModel, maxNumFeats);
            // classify() later reloads this file to recover the attribute structure.
            ArffSaver saver = new ArffSaver();
            saver.setInstances(trainingSet);
            saver.setFile(new File(PropertiesGetter.getProperty("TrainingDataSetObject")));
            saver.writeBatch();
        } catch (Exception ex) {
            report(ex);
        }
    }

    /**
     * Cross-validates an SPegasos classifier on the given samples using a fixed
     * random seed of 1.
     *
     * @param trainingSamples one map of attribute index to value per sample
     * @param labels          one label per sample, in the same order; {@code 0} is
     *                        stored as "positive", any other value as "negative"
     * @param numFolds        requested number of folds; capped at the number of samples
     * @param numFeats        number of feature attributes
     * @return the evaluation object; {@code null} if a failure occurred before it
     *         was created. If the cross-validation itself fails, the evaluation
     *         is still returned but the run did not complete (the error is logged).
     */
    public Evaluation crossValidate(ArrayList<LinkedHashMap<Integer, Double>> trainingSamples, ArrayList<Integer> labels, int numFolds, int numFeats) {
        Instances trainingSet = fillDataSet(trainingSamples, labels, numFeats);
        Instances initial = trainingSet.stringFreeStructure();
        Evaluation eTest = null;
        SPegasos cModel = new SPegasos();
        try {
            // NOTE(review): Evaluation.crossValidateModel appears to rebuild a copy of the
            // classifier for every fold, so this incremental pre-training may be redundant
            // -- confirm against the Weka documentation before removing it.
            cModel.buildClassifier(initial);
            for (int i = 0; i < trainingSet.numInstances(); i++) {
                cModel.updateClassifier(trainingSet.instance(i));
            }
            eTest = new Evaluation(trainingSet);
            // Never ask for more folds than there are instances.
            int folds = Math.min(numFolds, trainingSet.numInstances());
            eTest.crossValidateModel(cModel, trainingSet, folds, new Random(1));
        } catch (Exception ex) {
            report(ex);
        }
        return eTest;
    }

    /**
     * Converts the sparse samples and their labels into a Weka data set whose
     * last attribute is the nominal class.
     *
     * @param trainingSamples one map of attribute index to value per sample
     * @param labels          label of the sample at the same position
     * @param maxNumFeats     number of feature attributes
     * @return the populated data set
     */
    private Instances fillDataSet(ArrayList<LinkedHashMap<Integer, Double>> trainingSamples, ArrayList<Integer> labels, int maxNumFeats) {
        Instances trainingSet = initializeDataSet(maxNumFeats, trainingSamples.size());
        int numAtts = maxNumFeats + 1;
        int classIndex = numAtts - 1;
        // All-zero template that is copied for every sample.
        Instance template = getNewInstance(numAtts);
        int sampleIndex = 0;
        for (LinkedHashMap<Integer, Double> currentSample : trainingSamples) {
            Instance example = (Instance) template.copy();
            // The data set must be attached before the nominal class value can be set by name.
            example.setDataset(trainingSet);
            for (Entry<Integer, Double> feature : currentSample.entrySet()) {
                example.setValue(feature.getKey(), feature.getValue());
            }
            example.setValue(classIndex, labels.get(sampleIndex) == 0 ? POSITIVE_LABEL : NEGATIVE_LABEL);
            sampleIndex++;
            trainingSet.add(example);
        }
        return trainingSet;
    }

    /**
     * Classifies one sample with the LibSVM model previously saved by
     * {@link #train}. The model and the attribute structure (from the
     * {@code TrainingDataSetObject} ARFF file) are reloaded on every call.
     *
     * @param sample map of attribute index to value
     * @return the value returned by the classifier (presumably the index of the
     *         predicted class value, i.e. 0.0 for "positive" and 1.0 for
     *         "negative" -- confirm against the Weka documentation), or
     *         {@link #CLASSIFICATION_FAILED} if the model could not be loaded or
     *         classification failed
     */
    public double classify(LinkedHashMap<Integer, Double> sample) {
        try {
            Classifier cModel = loadClassifier();
            if (cModel == null) {
                // loadClassifier() has already reported why the model is unavailable.
                return CLASSIFICATION_FAILED;
            }
            ArffLoader loader = new ArffLoader();
            loader.setFile(new File(PropertiesGetter.getProperty("TrainingDataSetObject")));
            Instances structure = loader.getStructure();
            structure.setClassIndex(structure.numAttributes() - 1);
            Instance instanceToClassify = getNewInstance(structure.numAttributes());
            for (Entry<Integer, Double> feature : sample.entrySet()) {
                instanceToClassify.setValue(feature.getKey(), feature.getValue());
            }
            instanceToClassify.setDataset(structure);
            instanceToClassify.setClassMissing();
            return cModel.classifyInstance(instanceToClassify);
        } catch (Exception ex) {
            report(ex);
            return CLASSIFICATION_FAILED;
        }
    }

    /**
     * Creates an empty data set with {@code numFeatures} numeric attributes named
     * {@code att0..att(n-1)} followed by the nominal class attribute
     * {@code theClass}, which is set as the class index.
     *
     * @param numFeatures    number of numeric feature attributes
     * @param neededCapacity initial capacity of the data set
     * @return the empty data set
     */
    private Instances initializeDataSet(int numFeatures, int neededCapacity) {
        FastVector fvAllAttributes = new FastVector(numFeatures + 1);
        for (int i = 0; i < numFeatures; i++) {
            fvAllAttributes.addElement(new Attribute("att" + i));
        }
        FastVector fvClassVal = new FastVector(2);
        fvClassVal.addElement(POSITIVE_LABEL);
        fvClassVal.addElement(NEGATIVE_LABEL);
        fvAllAttributes.addElement(new Attribute("theClass", fvClassVal));
        Instances emptySet = new Instances("dataSet", fvAllAttributes, neededCapacity);
        emptySet.setClassIndex(numFeatures);
        return emptySet;
    }

    /**
     * Serializes the classifier to the {@code SVMModelFile} location and writes
     * the feature count as text to the {@code NumFeatsFile} location. Failures
     * are logged and printed, not rethrown.
     *
     * @param csfr            classifier to persist
     * @param currentNumFeats number of feature attributes the classifier was trained with
     */
    private void saveTheModel(Classifier csfr, int currentNumFeats) {
        OutputStream outStream = null;
        ObjectOutputStream objectOutputStream = null;
        BufferedWriter bw = null;
        try {
            outStream = new FileOutputStream(PropertiesGetter.getProperty("SVMModelFile"));
            objectOutputStream = new ObjectOutputStream(outStream);
            objectOutputStream.writeObject(csfr);
            // Flush now so the model file is complete even if writing the feature count fails.
            objectOutputStream.flush();
            bw = new BufferedWriter(new FileWriter(PropertiesGetter.getProperty("NumFeatsFile")));
            bw.write(Integer.toString(currentNumFeats));
        } catch (Exception ex) {
            report(ex);
        } finally {
            // Each resource is closed independently and only if it was actually opened.
            close(bw);
            close(objectOutputStream);
            close(outStream);
        }
    }

    /**
     * Creates an instance with the given number of attributes, all set to 0.0.
     *
     * @param numAttributes total number of attributes, including the class attribute
     * @return the zero-filled instance
     */
    private Instance getNewInstance(int numAttributes) {
        Instance newInst = new Instance(numAttributes);
        for (int j = 0; j < numAttributes; j++) {
            newInst.setValue(j, 0.0);
        }
        return newInst;
    }

    /**
     * Loads the LibSVM model from the {@code SVMModelFile} location.
     * <p>
     * SECURITY NOTE: this uses Java native deserialization, which must only be
     * applied to trusted files; the model file is assumed to be produced by
     * {@link #saveTheModel} and not writable by untrusted parties.
     *
     * @return the loaded classifier, or {@code null} if loading failed (the
     *         failure is logged and printed)
     */
    private Classifier loadClassifier() {
        InputStream is = null;
        ObjectInputStream objectInputStream = null;
        try {
            is = new FileInputStream(PropertiesGetter.getProperty("SVMModelFile"));
            objectInputStream = new ObjectInputStream(is);
            return (LibSVM) objectInputStream.readObject();
        } catch (Exception ex) {
            report(ex);
            return null;
        } finally {
            close(objectInputStream);
            close(is);
        }
    }

    /**
     * Reads the feature count written by {@link #saveTheModel} from the
     * {@code NumFeatsFile} location.
     *
     * @return the stored feature count, or 0 if it could not be read or parsed
     */
    private int getNumFeatures() {
        int numFeats = 0;
        BufferedReader br = null;
        try {
            br = new BufferedReader(new FileReader(PropertiesGetter.getProperty("NumFeatsFile")));
            numFeats = Integer.parseInt(br.readLine());
        } catch (Exception ex) {
            report(ex);
        } finally {
            close(br);
        }
        return numFeats;
    }

    /** Logs the exception at SEVERE level and echoes it to standard output. */
    private static void report(Exception ex) {
        LOGGER.log(Level.SEVERE, null, ex);
        System.out.println(ex.toString());
    }

    /** Closes the resource if it was opened; a failure to close is reported, not rethrown. */
    private static void close(Closeable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (IOException ex) {
            report(ex);
        }
    }
}