Package com.digitalpebble.classification.util

Source Code of com.digitalpebble.classification.util.WeightedAttribute

/**
* Copyright 2009 DigitalPebble Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/

package com.digitalpebble.classification.util;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import com.digitalpebble.classification.Document;
import com.digitalpebble.classification.Lexicon;
import com.digitalpebble.classification.TextClassifier;

import de.bwaldvogel.liblinear.Model;

public class ModelUtils {

  /**
   * Prints out the attributes and their weights from the models generated by
   * liblinear. This is different from CorpusUtils.dumpBestAttributes which
   * computes a score for the attributes regardless of the model.
   **/

  public static void getAttributeScores(String modelPath, String lexiconF,
      int topAttributesNumber) throws IOException {
    // load the model + the lexicon
    // try to see if we can get a list of the best scores from the model
    // works only for liblinear
    Lexicon lexicon = new Lexicon(lexiconF);
    Model liblinearModel = Model.load(new File(modelPath));
    double[] weights = liblinearModel.getFeatureWeights();
    // dump all the weights
    int numClasses = liblinearModel.getNrClass();
    int numFeatures = liblinearModel.getNrFeature();

    Map<Integer, String> invertedAttributeIndex = lexicon
        .getInvertedIndex();

    Map<String, WeightedAttributeQueue> topAttributesPerLabel = new HashMap<String, WeightedAttributeQueue>(
        numClasses);

    // for (int i = 0; i < nr_w; i++) {
    // double contrib = w[(idx - 1) * nr_w + i] * lx.getValue();
    // }
    //
    // idx 1 in class 1 -> 0 x 22 + 0 = 0
    // idx 2 in class 1 -> 1 x 22 + 0 = 22
    // idx 1 in class 2 -> 0 x 22 + 1 = 1
    // idx 2 in class 2 -> 1 x 22 + 1 = 23

    // initialise the queues
    if (topAttributesNumber != -1) {
      for (int classNum = 0; classNum < numClasses; classNum++) {
        String classLabel = lexicon.getLabel(classNum);
        WeightedAttributeQueue queue = new WeightedAttributeQueue(
            topAttributesNumber);
        topAttributesPerLabel.put(classLabel, queue);
      }
    }

    for (int classNum = 0; classNum < numClasses; classNum++) {
      String classLabel = lexicon.getLabel(classNum);
      WeightedAttributeQueue queue = topAttributesPerLabel
          .get(classLabel);
      for (int featNum = 0; featNum < numFeatures; featNum++) {
        int pos = featNum * numClasses + classNum;
        double featWeight = weights[pos];
        String attLabel = invertedAttributeIndex.get(featNum + 1);

        // display the values between -0.001 and +0.001 as 0
        if (featWeight < 0.001 && featWeight > -0.001)
          featWeight = 0;
        // want to limit to the top n terms?
        if (topAttributesNumber != -1) {
          WeightedAttribute wa = new WeightedAttribute(attLabel,
              featWeight);
          queue.insertWithOverflow(wa);
          continue;
        }

        System.out.println(attLabel + "\t" + classLabel + "\t"
            + featWeight);
      }
    }

    // dump the attributes per label
    if (topAttributesNumber < 1)
      return;

    Iterator<String> labelIter = topAttributesPerLabel.keySet().iterator();
    while (labelIter.hasNext()) {
      String label = (String) labelIter.next();
      System.out.println("LABEL : " + label);
      WeightedAttributeQueue queue = topAttributesPerLabel.get(label);
      // revert the order
      String[] sorted = new String[queue.size()];

      for (int i = queue.size() - 1; i >= 0; i--) {
        WeightedAttribute wa = queue.pop();
        sorted[i] = wa.label + " : " + wa.weight;
      }
     
      for (int j=0;j<sorted.length;j++){
        System.out.println(j+1 + "\t" + sorted[j]);
      }
    }
  }

  public static void main(String[] args) {
    if (args.length < 2) {
      StringBuffer buffer = new StringBuffer();
      buffer.append("ModelUtils : \n");
      buffer.append("\t -getAttributeScores modelFile lexicon [topAttributesThreshold]\n");
      buffer.append("\t -classifyTextFile resourceDir input\n");
      System.out.println(buffer.toString());
      return;
    }

    else if (args[0].equalsIgnoreCase("-getAttributeScores")) {
      String model = args[1];
      String lexicon = args[2];
      int topAttributesNumber = -1;
      if (args.length > 3) {
        topAttributesNumber = Integer.parseInt(args[3]);
      }
      try {
        getAttributeScores(model, lexicon, topAttributesNumber);
      } catch (Exception e) {
        e.printStackTrace();
      }
    }

    else if (args[0].equalsIgnoreCase("-classifyTextFile")) {
      String resourceDir = args[1];
      String inputTextFile = args[2];

      // load text file as String
      StringBuffer text = new StringBuffer();
      String line;
      BufferedReader br;
      try {
        br = new BufferedReader(new FileReader(new File(inputTextFile)));

        while ((line = br.readLine()) != null) {
          text.append(line).append("\n");
        }

        br.close();

        // load classifier
        TextClassifier classifier = TextClassifier
            .getClassifier(new File(resourceDir));
        // create a document from a String
        String[] tokens = Tokenizer.tokenize(text.toString(), true);
        Document doc = classifier.createDocument(tokens);
        // classify
        double[] scores = classifier.classify(doc);
        // get best label
        String label = classifier.getBestLabel(scores);
        System.out.println("Classified as : " + label);
      } catch (Exception e) {
        // TODO Auto-generated catch block
        e.printStackTrace();
      }
    }

  }

}

/**
 * An attribute label paired with its weight. Instances are ordered by their
 * signed weight, so the most positive weight ranks highest. The absolute value
 * is not used.
 * <p>
 * Note: the ordering only looks at the weight and is therefore inconsistent
 * with {@link Object#equals(Object)}.
 */
class WeightedAttribute implements Comparable<WeightedAttribute> {

  String label;
  double weight;

  WeightedAttribute(String l, double w) {
    label = l;
    weight = w;
  }

  /**
   * Ranks the attributes on their signed weight.
   * <p>
   * Uses {@link Double#compare(double, double)} to get a total order: NaN
   * sorts above every other value, and -0.0 sorts below 0.0. A heap built on
   * this comparison therefore stays consistent.
   *
   * @param att the attribute to compare with
   * @return a negative value, zero or a positive value if this weight ranks
   *         below, equal to or above the weight of {@code att}
   */
  @Override
  public int compareTo(WeightedAttribute att) {
    return Double.compare(this.weight, att.weight);
  }

}

/**
* Sorts SuggestWord instances
*
*/
final class WeightedAttributeQueue extends PriorityQueue<WeightedAttribute> {

  public WeightedAttributeQueue(int size) {
    initialize(size);
  }

  @Override
  protected final boolean lessThan(WeightedAttribute wa, WeightedAttribute wb) {
    int val = wa.compareTo(wb);
    return val < 0;
  }
}
TOP

Related Classes of com.digitalpebble.classification.util.WeightedAttribute

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware#gmail.com.