/*
* This file is part of ALOE.
*
* ALOE is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
* ALOE is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
* You should have received a copy of the GNU General Public License
* along with ALOE. If not, see <http://www.gnu.org/licenses/>.
*
* Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl)
*/
package etc.aloe.cscw2013;
import etc.aloe.data.ExampleSet;
import etc.aloe.data.Model;
import etc.aloe.processes.FeatureWeighting;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import weka.classifiers.Classifier;
import weka.classifiers.functions.SMO;
import weka.classifiers.meta.CostSensitiveClassifier;
import weka.core.Instances;
/**
 * Extracts top features and feature weights from a linear support vector
 * machine (SMO) classifier.
 *
 * Also works with a CostSensitiveClassifier wrapping an SMO.
 *
 * Only the weight vector stored at position [0][1] of the SMO's sparse weight
 * arrays is read.
 * NOTE(review): this presumably means the classifier is expected to be a
 * two-class problem -- confirm against callers.
 *
 * @author Michael Brooks <mjbrooks@uw.edu>
 */
public class SMOFeatureWeighting implements FeatureWeighting {

    /**
     * Orders entries from the largest to the smallest absolute weight.
     *
     * The absolute value is compared rather than the square of the weight:
     * squaring underflows to zero (or overflows to infinity) for extreme
     * magnitudes, which would make distinct weights compare as equal.
     */
    private static final Comparator<Map.Entry<String, Double>> BY_DESCENDING_MAGNITUDE =
            new Comparator<Map.Entry<String, Double>>() {
                @Override
                public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
                    return Double.compare(Math.abs(o2.getValue()), Math.abs(o1.getValue()));
                }
            };

    /**
     * Orders entries alphabetically by feature name.
     */
    private static final Comparator<Map.Entry<String, Double>> BY_FEATURE_NAME =
            new Comparator<Map.Entry<String, Double>>() {
                @Override
                public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
                    return o1.getKey().compareTo(o2.getKey());
                }
            };

    /**
     * Returns the names of the features with the largest absolute weights.
     *
     * Features are ranked by weight magnitude, largest first. Because the
     * sort is stable and the input is ordered by name, features with equal
     * magnitude stay in alphabetical order.
     *
     * @param trainingExamples the examples whose attributes name the features
     * @param model the model wrapping the SMO classifier
     * @param topN the maximum number of feature names to return; values less
     * than one yield an empty list
     * @return at most topN feature names, most heavily weighted first
     */
    @Override
    public List<String> getTopFeatures(ExampleSet trainingExamples, Model model, int topN) {
        List<Map.Entry<String, Double>> weights = getFeatureWeights(trainingExamples, model);
        Collections.sort(weights, BY_DESCENDING_MAGNITUDE);

        int count = Math.max(0, Math.min(topN, weights.size()));
        List<String> result = new ArrayList<String>(count);
        for (int i = 0; i < count; i++) {
            result.add(weights.get(i).getKey());
        }
        return result;
    }

    /**
     * Returns the weight the SMO assigns to each feature it has a sparse
     * weight for, sorted alphabetically by feature name.
     *
     * @param trainingExamples the examples whose attributes name the features
     * @param model the model wrapping the SMO classifier; must be a WekaModel
     * @return a new list of (feature name, weight) entries ordered by name
     * @throws IllegalArgumentException if the classifier is not an SMO (or a
     * CostSensitiveClassifier wrapping one), or if the SMO exposes no sparse
     * weights
     */
    @Override
    public List<Entry<String, Double>> getFeatureWeights(ExampleSet trainingExamples, Model model) {
        WekaModel wekaModel = (WekaModel) model;
        Classifier classifier = wekaModel.getClassifier();
        Instances dataFormat = trainingExamples.getInstances();

        SMO smo = getSMO(classifier);
        double[] sparseWeights = smo.sparseWeights()[0][1];
        int[] sparseIndices = smo.sparseIndices()[0][1];
        if (sparseWeights == null || sparseIndices == null) {
            // Fail with a descriptive message instead of a bare
            // NullPointerException further down.
            throw new IllegalArgumentException(
                    "SMO classifier did not expose sparse weights (is it trained with a linear kernel?)");
        }

        // sparseIndices[i] is the attribute index that sparseWeights[i] belongs to.
        Map<String, Double> weights = new HashMap<String, Double>();
        for (int i = 0; i < sparseWeights.length; i++) {
            String name = dataFormat.attribute(sparseIndices[i]).name();
            weights.put(name, sparseWeights[i]);
        }

        List<Map.Entry<String, Double>> entries =
                new ArrayList<Map.Entry<String, Double>>(weights.entrySet());
        Collections.sort(entries, BY_FEATURE_NAME);
        return entries;
    }

    /**
     * Given a classifier, attempts to cast it to an SMO or get the contained
     * SMO.
     *
     * @param classifier an SMO, or a CostSensitiveClassifier wrapping an SMO
     * @return the SMO itself, or the SMO wrapped by the cost-sensitive
     * classifier
     * @throws IllegalArgumentException if no SMO can be obtained
     */
    private SMO getSMO(Classifier classifier) {
        if (classifier instanceof CostSensitiveClassifier) {
            // Unwrap one level to reach the underlying base classifier.
            classifier = ((CostSensitiveClassifier) classifier).getClassifier();
        }
        if (!(classifier instanceof SMO)) {
            throw new IllegalArgumentException("Classifier was neither SMO or CostSensitiveClassifier(SMO)");
        }
        return (SMO) classifier;
    }
}