Package com.digitalpebble.classification.util

Source Code of com.digitalpebble.classification.util.WeightedAttribute

/**
* Copyright 2009 DigitalPebble Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/

package com.digitalpebble.classification.util;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

import com.digitalpebble.classification.Document;
import com.digitalpebble.classification.Lexicon;
import com.digitalpebble.classification.TextClassifier;

import de.bwaldvogel.liblinear.Model;

public class ModelUtils {

  /**
   * Prints out the attributes and their weights from the models generated by
   * liblinear. This is different from CorpusUtils.dumpBestAttributes which
   * computes a score for the attributes regardless of the model.
   **/

  public static void getAttributeScores(String modelPath, String lexiconF,
      int topAttributesNumber) throws IOException {
    // load the model + the lexicon
    // try to see if we can get a list of the best scores from the model
    // works only for liblinear
    Lexicon lexicon = new Lexicon(lexiconF);
    Model liblinearModel = Model.load(new File(modelPath));
    double[] weights = liblinearModel.getFeatureWeights();
    // dump all the weights
    int numClasses = liblinearModel.getNrClass();
    int numFeatures = liblinearModel.getNrFeature();

    Map<Integer, String> invertedAttributeIndex = lexicon
        .getInvertedIndex();

    Map<String, WeightedAttributeQueue> topAttributesPerLabel = new HashMap<String, WeightedAttributeQueue>(
        numClasses);

    for (int i = 0; i < weights.length; i++) {
      // get current class num
      int classNum = i / numFeatures;
      int featNum = i % numFeatures;
      String classLabel = lexicon.getLabel(classNum);
      String attLabel = invertedAttributeIndex.get(featNum + 1);

      // display the values between -0.001 and +0.001 as 0
      if (weights[i] < 0.001 && weights[i] > -0.001)
        weights[i] = 0;

      // want to limit to the top n terms?
      if (topAttributesNumber != -1) {
        WeightedAttributeQueue queue = topAttributesPerLabel
            .get(classLabel);
        if (queue == null) {
          queue = new WeightedAttributeQueue(topAttributesNumber);
          topAttributesPerLabel.put(classLabel, queue);
        }
        WeightedAttribute wa = new WeightedAttribute(attLabel,
            weights[i]);
        queue.insertWithOverflow(wa);
        continue;
      }

      System.out
          .println(attLabel + "\t" + classLabel + "\t" + weights[i]);
    }

    // dump the attributes per label
    if (topAttributesNumber < 1)
      return;

    Iterator labelIter = topAttributesPerLabel.keySet().iterator();
    while (labelIter.hasNext()) {
      String label = (String) labelIter.next();
      System.out.println("LABEL : " + label);
      WeightedAttributeQueue queue = topAttributesPerLabel.get(label);
      for (int i = queue.size() - 1; i >= 0; i--) {
        WeightedAttribute wa = queue.pop();
        System.out.println((topAttributesNumber - i) + "\t" + wa.label
            + " : " + wa.weight);
      }
    }
  }

  public static void main(String[] args) {
    if (args.length < 2) {
      StringBuffer buffer = new StringBuffer();
      buffer.append("ModelUtils : \n");
      buffer.append("\t -getAttributeScores modelFile lexicon [topAttributesThreshold]\n");
      buffer.append("\t -classifyTextFile resourceDir input\n");
      System.out.println(buffer.toString());
      return;
    }

    else if (args[0].equalsIgnoreCase("-getAttributeScores")) {
      String model = args[1];
      String lexicon = args[2];
      int topAttributesNumber = -1;
      if (args.length > 3) {
        topAttributesNumber = Integer.parseInt(args[3]);
      }
      try {
        getAttributeScores(model, lexicon, topAttributesNumber);
      } catch (Exception e) {
        e.printStackTrace();
      }
    }

    else if (args[0].equalsIgnoreCase("-classifyTextFile")) {
      String resourceDir = args[1];
      String inputTextFile = args[2];

      // load text file as String
      StringBuffer text = new StringBuffer();
      String line;
      BufferedReader br;
      try {
        br = new BufferedReader(new FileReader(new File(inputTextFile)));

        while ((line = br.readLine()) != null) {
          text.append(line).append("\n");
        }
       
        br.close();

        // load classifier
        TextClassifier classifier = TextClassifier
            .getClassifier(new File(resourceDir));
        // create a document from a String
        String[] tokens = Tokenizer.tokenize(text.toString(), true);
        Document doc = classifier.createDocument(tokens);
        // classify
        double[] scores = classifier.classify(doc);
        // get best label
        String label = classifier.getBestLabel(scores);
        System.out.println("Classified as : "+label);
      } catch (Exception e) {
        // TODO Auto-generated catch block
        e.printStackTrace();
      }
    }

  }

}

class WeightedAttribute {

  String label;
  double weight;

  WeightedAttribute(String l, double w) {
    label = l;
    weight = w;
  }

  // rank the attributes based on weight
  public int compareTo(WeightedAttribute att) {
    double absoltarget = att.weight;
    if (absoltarget < 0)
      absoltarget = -absoltarget;
    double absolsource = this.weight;
    if (absolsource < 0)
      absolsource = -absolsource;
    double diff = absolsource - absoltarget;
    if (diff < 0)
      return -1;
    if (diff > 0)
      return 1;
    return 0;
  }

}

/**
* Sorts SuggestWord instances
*
*/
final class WeightedAttributeQueue extends PriorityQueue<WeightedAttribute> {

  public WeightedAttributeQueue(int size) {
    initialize(size);
  }

  @Override
  protected final boolean lessThan(WeightedAttribute wa, WeightedAttribute wb) {
    int val = wa.compareTo(wb);
    return val < 0;
  }
}
TOP

Related Classes of com.digitalpebble.classification.util.WeightedAttribute

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark originally of Sun Microsystems, Inc., now owned by Oracle. Contact coftware#gmail.com.