package org.sf.mustru.train;
import java.io.File;
import java.io.FileOutputStream;
import java.io.ObjectOutputStream;
import com.aliasi.util.Files;
import com.aliasi.classify.DynamicLMClassifier;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
import org.sf.mustru.utils.Constants;
import java.io.IOException;
/**
 * Trains a LingPipe {@code DynamicLMClassifier} on the documents stored under
 * {@code MUSTRU_HOME/data/training/tcat} and writes the compiled model to the
 * file {@code tcat_classifier} in that same directory.
 *
 * <p>Each entry of {@code CATEGORIES} (sugar, coffee, cocoa, misc) must have a
 * sub-directory of the training directory with the same name; every regular
 * file in that sub-directory is used as one training document for the category.
 * The corpus was originally described as a collection of Reuters documents with
 * roughly 10-15 documents per category.
 */
public class TrainMClassifier
{
  //*-- define the training directory, the list of categories, and the language model parameters
  private static final String MUSTRU_HOME = Constants.MUSTRU_HOME;
  private static final File TRAINING_DIR = new File(
      MUSTRU_HOME + File.separator + "data" + File.separator + "training" + File.separator + "tcat");
  private static final String[] CATEGORIES = {"sugar", "coffee", "cocoa", "misc"};
  private static final int NGRAM_SIZE = 6;
  private static final boolean BOUNDED = false;

  /**
   * Trains the classifier on every category directory and dumps the compiled model to disk.
   *
   * @param args command line arguments; not used
   * @throws ClassNotFoundException declared for backward compatibility with existing callers
   * @throws IOException if a category directory is missing or cannot be listed, if a
   *         training file cannot be read, or if the model file cannot be written
   */
  public static void main(String[] args) throws ClassNotFoundException, IOException
  {
    PropertyConfigurator.configure(Constants.LOG4J_FILE);
    Logger logger = Logger.getLogger(TrainMClassifier.class.getName());
    logger.debug("Started TrainMClassifier");

    DynamicLMClassifier classifier = new DynamicLMClassifier(CATEGORIES, NGRAM_SIZE, BOUNDED);

    //*-- Start of training
    //*-- loop through the list of categories and verify that a directory exists for each one.
    for (int i = 0; i < CATEGORIES.length; ++i)
    {
      File classDir = new File(TRAINING_DIR, CATEGORIES[i]);
      if (!classDir.isDirectory())
      {
        //*-- abort here: continuing would dereference the null returned by list() below
        String message = "Could not find training directory=" + classDir;
        logger.fatal(message);
        throw new IOException(message);
      }

      //*-- get the list of training files for the category; list() returns null on an I/O error
      String[] trainingFiles = classDir.list();
      if (trainingFiles == null)
      {
        String message = "Could not list training directory=" + classDir;
        logger.fatal(message);
        throw new IOException(message);
      }

      //*-- train the classifier on each of the files
      for (int j = 0; j < trainingFiles.length; ++j)
      {
        File trainingFile = new File(classDir, trainingFiles[j]);
        //*-- skip nested directories and other non-file entries, which cannot be read as text
        if (!trainingFile.isFile())
        {
          logger.warn("Skipping non-file entry " + trainingFile);
          continue;
        }
        String text = Files.readFromFile(trainingFile);
        logger.debug("Training on " + CATEGORIES[i] + File.separator + trainingFiles[j]);
        classifier.train(CATEGORIES[i], text);
      } //*-- end of inner for
    } //*-- end of outer for
    //*-- end of training

    //*-- dump the classification model to a file
    logger.info("Start compiling classifier");
    String modelFile = MUSTRU_HOME + File.separator + "data" + File.separator + "training" + File.separator + "tcat" + File.separator + "tcat_classifier";
    ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(modelFile));
    try
    {
      classifier.compileTo(os);
    }
    finally
    {
      //*-- always release the file handle, even when compilation fails
      os.close();
    }
    logger.info("End compiling classifier");
    logger.debug("Ended TrainMClassifier");
  }
}