package org.sf.mustru.test;
import java.io.File;
import java.io.FileInputStream;
import java.io.ObjectInputStream;
import com.aliasi.util.Files;
import com.aliasi.classify.ConfusionMatrix;
import com.aliasi.classify.LMClassifier;
import com.aliasi.classify.JointClassification;
import org.sf.mustru.utils.Constants;
import org.sf.mustru.utils.StringTools;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
import java.io.IOException;
import java.text.NumberFormat;
/**
 * Test classification using the LingPipe classifier with a collection
 * of Reuters documents. Three categories with 10-15 documents each
 * are used for training.
 *
 * <p>The driver reads a previously compiled classifier from the training
 * directory, classifies every file found under
 * <code>testing/tcat/&lt;category&gt;</code>, and logs the resulting
 * confusion matrix together with the overall accuracy.
 */
public class TestMClassifier
{
  //*-- define the testing directory and the list of categories
  //*-- four categories - sugar, coffee, cocoa, and misc
  private static final String MUSTRU_HOME = Constants.MUSTRU_HOME;
  private static final File TESTING_DIR = new File(
      MUSTRU_HOME + File.separator + "data" + File.separator + "testing" + File.separator + "tcat");
  private static final String[] CATEGORIES = {"sugar", "coffee", "cocoa", "misc"};

  /**
   * Test the classifier using the files provided in the test/category directories.
   *
   * @param args command line arguments (not used)
   * @throws ClassNotFoundException if the serialized classifier's class cannot be resolved
   * @throws IOException if the model file or a test document cannot be read
   */
  public static void main(String[] args) throws ClassNotFoundException, IOException
  {
    PropertyConfigurator.configure (Constants.LOG4J_FILE);
    Logger logger = Logger.getLogger(TestMClassifier.class.getName());
    logger.debug("Started TestClassifier");

    //*-- read the classification model
    String modelFile = MUSTRU_HOME + File.separator + "data" + File.separator + "training"
        + File.separator + "tcat" + File.separator + "tcat_classifier";
    LMClassifier compiledClassifier = readClassifier(modelFile);

    //*-- nf formats the accuracy ratio (0..1), so a single integer digit is enough
    ConfusionMatrix confMatrix = new ConfusionMatrix(CATEGORIES);
    NumberFormat nf = NumberFormat.getInstance();
    nf.setMaximumIntegerDigits(1); nf.setMaximumFractionDigits(3);

    //*-- scores get their own formatter: a cap of one integer digit would silently
    //*-- drop the leading digits of any score whose magnitude is 10 or more
    NumberFormat scoreFormat = NumberFormat.getInstance();
    scoreFormat.setMaximumFractionDigits(3);

    //*-- loop through the categories and test the classification of test documents
    for (int i = 0; i < CATEGORIES.length; ++i)
    {
      File classDir = new File(TESTING_DIR, CATEGORIES[i]);

      //*-- File.list() returns null when the directory is missing or unreadable
      String[] testingFiles = classDir.list();
      if (testingFiles == null)
      {
        logger.warn("Skipping category " + CATEGORIES[i] + ": cannot list directory " + classDir);
        continue;
      }

      //*-- for each file, find the best category using the classifier and compare with the
      //*-- designated category
      for (int j = 0; j < testingFiles.length; ++j)
      {
        String text = Files.readFromFile( new File(classDir, testingFiles[j]) );
        logger.debug("Testing on " + CATEGORIES[i] + File.separator + testingFiles[j]);
        JointClassification jc = compiledClassifier.classifyJoint(text);
        confMatrix.increment(CATEGORIES[i], jc.bestCategory());
        logger.debug("Best Category: " + jc.bestCategory() );

        //*-- scores are addressed by rank in the classification result, so the loop is
        //*-- bounded by the result's own size rather than by the local category list
        StringBuffer sb = new StringBuffer();
        sb.append("Scores ");
        for (int k = 0; k < jc.size(); k++) sb.append(scoreFormat.format(jc.score(k)) + " ");
        logger.debug(sb);
      } //*-- end of inner for
    } //*-- end of outer for

    logResults(logger, confMatrix, nf);
  }

  /**
   * Read the serialized classifier from the given file, closing the file
   * even when deserialization fails.
   *
   * @param modelFile path of the compiled classifier
   * @return the deserialized classifier
   * @throws ClassNotFoundException if the serialized object's class cannot be resolved
   * @throws IOException if the file cannot be opened or read
   */
  private static LMClassifier readClassifier(String modelFile) throws ClassNotFoundException, IOException
  {
    FileInputStream in = new FileInputStream(modelFile);
    try
    {
      ObjectInputStream oi = new ObjectInputStream(in);
      return (LMClassifier) oi.readObject();
    }
    finally
    {
      //*-- closing the underlying stream releases the file handle on every path
      in.close();
    }
  }

  /**
   * Log the confusion matrix as a table followed by the accuracy summary.
   *
   * @param logger destination for the report lines
   * @param confMatrix confusion matrix populated by the test run
   * @param nf formatter used for the total accuracy value
   */
  private static void logResults(Logger logger, ConfusionMatrix confMatrix, NumberFormat nf)
  {
    logger.info("--------------------------------------------");
    logger.info("- Results ");
    logger.info("--------------------------------------------");

    //*-- header row: a label column followed by one column per category
    int[][] imatrix = confMatrix.matrix();
    StringBuffer sb = new StringBuffer();
    sb.append(StringTools.fillin("CATEGORY", 10, true, ' ') );
    for (int i = 0; i < CATEGORIES.length; i++) sb.append(StringTools.fillin(CATEGORIES[i], 8, false, ' ') );
    logger.info(sb.toString());

    //*-- one row per reference category, one cell per response category
    for (int i = 0; i < imatrix.length; i++)
    {
      sb = new StringBuffer();
      sb.append(StringTools.fillin(CATEGORIES[i], 10, true, ' ', 10 - CATEGORIES[i].length() ) );
      for (int j = 0; j < imatrix[i].length; j++)
      {
        String out = "" + imatrix[i][j];
        sb.append(StringTools.fillin(out, 8, false, ' ', 8 - out.length() ) );
      }
      logger.info(sb.toString());
    }

    logger.info("Total Accuracy: " + nf.format(confMatrix.totalAccuracy()) );
    logger.info("Total Correct : " + confMatrix.totalCorrect() + " out of " + confMatrix.totalCount() );
  }
}