Package org.sf.mustru.test

Source Code of org.sf.mustru.test.TestMClassifier

package org.sf.mustru.test;

import java.io.File;
import java.io.FileInputStream;
import java.io.ObjectInputStream;

import com.aliasi.util.Files;
import com.aliasi.classify.ConfusionMatrix;
import com.aliasi.classify.LMClassifier;
import com.aliasi.classify.JointClassification;

import org.sf.mustru.utils.Constants;
import org.sf.mustru.utils.StringTools;

import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;

import java.io.IOException;
import java.text.NumberFormat;

/**
* Test classification using the Lingpipe classifier with a collection
* of Reuters documents. Three categories with 10-15 documents each
* are used for training.
*/
public class TestMClassifier
    //*-- define the training and testing directories and list of categories
    //*-- three categories - sugar, coffee, and cocoa
    private static String MUSTRU_HOME =  Constants.MUSTRU_HOME;
    private static File TESTING_DIR =  new File(
        MUSTRU_HOME + File.separator + "data" + File.separator + "testing" + File.separator + "tcat");
    private static String[] CATEGORIES = {"sugar", "coffee", "cocoa", "misc"};
   
    /**
     * Test the classifier using the files provided in the test/category directories
     */ 
    public static void main(String[] args) throws ClassNotFoundException, IOException
    {
      PropertyConfigurator.configure (Constants.LOG4J_FILE);
      Logger logger = Logger.getLogger(TestMClassifier.class.getName());
      logger.debug("Started TestClassifier");

      //*-- read the classification model
      String modelFile = MUSTRU_HOME + File.separator + "data" + File.separator + "training" + File.separator + "tcat" + File.separator + "tcat_classifier";
      ObjectInputStream oi = new ObjectInputStream( new FileInputStream(modelFile) );
      LMClassifier compiledClassifier = (LMClassifier) oi.readObject();
      oi.close();
     
      //*-- loop through the identical categories and test the classification of test documents
    ConfusionMatrix confMatrix = new ConfusionMatrix(CATEGORIES);
      NumberFormat nf = NumberFormat.getInstance();
      nf.setMaximumIntegerDigits(1); nf.setMaximumFractionDigits(3);
    for (int i=0; i < CATEGORIES.length; ++i)
       {
    File classDir = new File(TESTING_DIR, CATEGORIES[i]);
    String[] testingFiles = classDir.list();
       
        //*-- for each file, find the best category using the classifier and compare with the
        //*-- designated category
    for (int j=0; j < testingFiles.length; ++j)
         {
      String text = Files.readFromFile( new File(classDir, testingFiles[j]) );
      logger.debug("Testing on " + CATEGORIES[i] + File.separator + testingFiles[j]);
      JointClassification jc =  compiledClassifier.classifyJoint(text);
          confMatrix.increment(CATEGORIES[i], jc.bestCategory());
          logger.debug("Best Category: " + jc.bestCategory() );
          StringBuffer sb = new StringBuffer();
          sb.append("Scores ");
          for (int k = 0; k < CATEGORIES.length; k++) sb.append(nf.format(jc.score(k)) + " ");
          logger.debug(sb);
     } //*-- end of inner for
    } //*-- end of outer for
     
      logger.info("--------------------------------------------");
      logger.info("- Results ");
      logger.info("--------------------------------------------");
      int[][] imatrix = confMatrix.matrix();
      StringBuffer sb = new StringBuffer();
      sb.append(StringTools.fillin("CATEGORY", 10, true, ' ') );
      for (int i = 0; i < CATEGORIES.length; i++) sb.append(StringTools.fillin(CATEGORIES[i], 8, false, ' ') );
      logger.info(sb.toString());
     
      for (int i = 0; i < imatrix.length; i++)
      { sb = new StringBuffer();
        sb.append(StringTools.fillin(CATEGORIES[i], 10, true, ' ', 10 - CATEGORIES[i].length() ) );
        for (int j = 0; j < imatrix.length; j++)
         {  String out = "" + imatrix[i][j];
          sb.append(StringTools.fillin(out, 8, false, ' ', 8 - out.length() ) );
         }
        logger.info(sb.toString());
      }
     
    logger.info("Total Accuracy: " + nf.format(confMatrix.totalAccuracy()) );
      logger.info("Total Correct : " + confMatrix.totalCorrect() + " out of " + confMatrix.totalCount() );
    }
}
TOP

Related Classes of org.sf.mustru.test.TestMClassifier

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.