/**
* Copyright 2009 DigitalPebble Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.digitalpebble.classification.util;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import com.digitalpebble.classification.Document;
import com.digitalpebble.classification.Lexicon;
import com.digitalpebble.classification.TextClassifier;
import de.bwaldvogel.liblinear.Model;
/**
 * Command-line helpers to inspect the feature weights of a liblinear model and
 * to classify the content of a text file with a trained classifier.
 */
public class ModelUtils {

    /**
     * Prints out the attributes and their weights from the models generated by
     * liblinear. This is different from CorpusUtils.dumpBestAttributes which
     * computes a score for the attributes regardless of the model.
     * <p>
     * Weights strictly between -0.001 and +0.001 are reported as 0.
     *
     * @param modelPath
     *            path of the liblinear model file
     * @param lexiconF
     *            path handed to the {@code Lexicon} constructor
     * @param topAttributesNumber
     *            maximum number of attributes printed per label. A negative
     *            value (the command line default is -1) disables the limit:
     *            every attribute / label pair is then printed on its own
     *            tab-separated line. A value of 0 prints nothing.
     * @throws IOException
     *             propagated from loading the lexicon or the model
     **/
    public static void getAttributeScores(String modelPath, String lexiconF,
            int topAttributesNumber) throws IOException {
        // load the model + the lexicon; this works only for liblinear models
        Lexicon lexicon = new Lexicon(lexiconF);
        Model liblinearModel = Model.load(new File(modelPath));
        double[] weights = liblinearModel.getFeatureWeights();
        int numClasses = liblinearModel.getNrClass();
        int numFeatures = liblinearModel.getNrFeature();
        Map<Integer, String> invertedAttributeIndex = lexicon
                .getInvertedIndex();

        // "top 0" has nothing to report
        if (topAttributesNumber == 0) {
            return;
        }
        // any negative value means "no limit"; a single flag keeps the queue
        // creation, the insertion and the final dump consistent
        final boolean limitToTop = topAttributesNumber > 0;

        Map<String, WeightedAttributeQueue> topAttributesPerLabel = new HashMap<String, WeightedAttributeQueue>(
                numClasses);

        // initialise one bounded queue per label
        if (limitToTop) {
            for (int classNum = 0; classNum < numClasses; classNum++) {
                String classLabel = lexicon.getLabel(classNum);
                topAttributesPerLabel.put(classLabel,
                        new WeightedAttributeQueue(topAttributesNumber));
            }
        }

        // The weights of one feature are stored next to each other, one per
        // class, e.g. with 22 classes:
        // idx 1 in class 1 -> 0 x 22 + 0 = 0
        // idx 2 in class 1 -> 1 x 22 + 0 = 22
        // idx 1 in class 2 -> 0 x 22 + 1 = 1
        // idx 2 in class 2 -> 1 x 22 + 1 = 23
        // NOTE(review): this assumes one weight per class for every feature.
        // liblinear appears to keep a single weight vector for two-class
        // models with most solvers - confirm the behaviour on binary models.
        for (int classNum = 0; classNum < numClasses; classNum++) {
            String classLabel = lexicon.getLabel(classNum);
            WeightedAttributeQueue queue = topAttributesPerLabel
                    .get(classLabel);
            for (int featNum = 0; featNum < numFeatures; featNum++) {
                int pos = featNum * numClasses + classNum;
                double featWeight = weights[pos];
                // the inverted index is looked up with featNum + 1
                // (presumably the lexicon indices start at 1)
                String attLabel = invertedAttributeIndex.get(featNum + 1);
                // display the values between -0.001 and +0.001 as 0
                if (Math.abs(featWeight) < 0.001) {
                    featWeight = 0;
                }
                if (limitToTop) {
                    // keep only the top n terms for this label
                    queue.insertWithOverflow(new WeightedAttribute(attLabel,
                            featWeight));
                } else {
                    System.out.println(attLabel + "\t" + classLabel + "\t"
                            + featWeight);
                }
            }
        }

        if (!limitToTop) {
            return;
        }

        // dump the attributes per label
        for (Map.Entry<String, WeightedAttributeQueue> entry : topAttributesPerLabel
                .entrySet()) {
            System.out.println("LABEL : " + entry.getKey());
            WeightedAttributeQueue queue = entry.getValue();
            // fill the array from its end so that the order in which the
            // queue pops its elements is reverted
            String[] sorted = new String[queue.size()];
            for (int i = sorted.length - 1; i >= 0; i--) {
                WeightedAttribute wa = queue.pop();
                sorted[i] = wa.label + " : " + wa.weight;
            }
            for (int j = 0; j < sorted.length; j++) {
                System.out.println((j + 1) + "\t" + sorted[j]);
            }
        }
    }

    /**
     * Command line entry point. Supported commands:
     * <ul>
     * <li>{@code -getAttributeScores modelFile lexicon [topAttributesThreshold]}</li>
     * <li>{@code -classifyTextFile resourceDir input}</li>
     * </ul>
     * The usage is printed when arguments are missing or the command is not
     * recognised.
     *
     * @param args
     *            command name followed by its arguments
     */
    public static void main(String[] args) {
        // both commands need the command name plus two mandatory arguments
        if (args.length < 3) {
            printUsage();
            return;
        }

        String command = args[0];
        if (command.equalsIgnoreCase("-getAttributeScores")) {
            String model = args[1];
            String lexicon = args[2];
            int topAttributesNumber = -1;
            if (args.length > 3) {
                topAttributesNumber = Integer.parseInt(args[3]);
            }
            try {
                getAttributeScores(model, lexicon, topAttributesNumber);
            } catch (Exception e) {
                e.printStackTrace();
            }
        } else if (command.equalsIgnoreCase("-classifyTextFile")) {
            String resourceDir = args[1];
            String inputTextFile = args[2];
            try {
                // load text file as String
                String text = readTextFile(inputTextFile);
                // load classifier
                TextClassifier classifier = TextClassifier
                        .getClassifier(new File(resourceDir));
                // create a document from a String
                String[] tokens = Tokenizer.tokenize(text, true);
                Document doc = classifier.createDocument(tokens);
                // classify
                double[] scores = classifier.classify(doc);
                // get best label
                String label = classifier.getBestLabel(scores);
                System.out.println("Classified as : " + label);
            } catch (Exception e) {
                e.printStackTrace();
            }
        } else {
            // unknown command: tell the user instead of exiting silently
            printUsage();
        }
    }

    /** Prints the command line usage on the standard output. */
    private static void printUsage() {
        StringBuilder buffer = new StringBuilder();
        buffer.append("ModelUtils : \n");
        buffer.append("\t -getAttributeScores modelFile lexicon [topAttributesThreshold]\n");
        buffer.append("\t -classifyTextFile resourceDir input\n");
        System.out.println(buffer.toString());
    }

    /**
     * Reads a whole text file, terminating every line with "\n". The reader is
     * closed even when reading fails.
     */
    private static String readTextFile(String path) throws IOException {
        StringBuilder text = new StringBuilder();
        BufferedReader br = new BufferedReader(new FileReader(new File(path)));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                text.append(line).append("\n");
            }
        } finally {
            br.close();
        }
        return text.toString();
    }
}
/**
 * An attribute label together with its weight. Instances are ordered by their
 * signed weight (not by its absolute value).
 */
class WeightedAttribute implements Comparable<WeightedAttribute> {
    String label;
    double weight;

    WeightedAttribute(String l, double w) {
        label = l;
        weight = w;
    }

    /**
     * Ranks the attributes based on their weight.
     * <p>
     * The comparison relies on {@link Double#compare(double, double)} rather
     * than on the sign of a subtraction, so that it stays a consistent total
     * order when a weight is NaN (NaN ranks above every other value, and -0.0
     * below 0.0).
     *
     * @param att
     *            the attribute to compare this one with
     * @return -1, 0 or +1 when this weight is respectively lower than, equal
     *         to or greater than the weight of {@code att}
     */
    @Override
    public int compareTo(WeightedAttribute att) {
        return Integer.signum(Double.compare(this.weight, att.weight));
    }
}
/**
 * Priority queue of {@link WeightedAttribute} instances, ordered with
 * {@link WeightedAttribute#compareTo(WeightedAttribute)}.
 */
final class WeightedAttributeQueue extends PriorityQueue<WeightedAttribute> {

    /**
     * @param size
     *            value handed to {@code initialize}
     */
    public WeightedAttributeQueue(int size) {
        initialize(size);
    }

    /**
     * @return true when {@code first} ranks strictly below {@code second}
     */
    @Override
    protected final boolean lessThan(WeightedAttribute first,
            WeightedAttribute second) {
        return first.compareTo(second) < 0;
    }
}