/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/**
@author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/
package cc.mallet.classify.tests;
//import edu.umass.cs.mallet.base.pipe.SerialPipe;
import junit.framework.*;
import java.net.URI;
import java.util.Iterator;
import cc.mallet.classify.*;
import cc.mallet.pipe.*;
import cc.mallet.pipe.iterator.ArrayIterator;
import cc.mallet.pipe.iterator.PipeInputIterator;
import cc.mallet.pipe.iterator.RandomTokenSequenceIterator;
import cc.mallet.types.*;
import cc.mallet.util.*;
/**
 * Smoke tests for MALLET classifier trainers on small, randomly generated data sets.
 * Each test trains every configured trainer, prints each classifier's accuracy, and
 * fails if any reported accuracy is not a valid fraction in [0, 1].
 */
public class TestClassifiers extends TestCase
{
  public TestClassifiers (String name)
  {
    super (name);
  }

  /** Builds a new alphabet containing "feature0" .. "feature{size-1}". */
  private static Alphabet dictOfSize (int size)
  {
    Alphabet ret = new Alphabet ();
    for (int i = 0; i < size; i++) {
      ret.lookupIndex (featureName (i));
    }
    return ret;
  }

  /** Name of the i-th synthetic feature; shared so dictOfSize and expandDict agree. */
  private static String featureName (int i)
  {
    return "feature" + i;
  }

  /**
   * Trains on one half of a 200-instance random data set and reports accuracy on
   * both the training half and the held-out half.
   */
  public void testRandomTrained ()
  {
    // Other trainers (e.g. NaiveBayesTrainer, DecisionTreeTrainer) can be added to
    // this array to include them in the comparison.
    ClassifierTrainer[] trainers = new ClassifierTrainer[] { new MaxEntTrainer () };
    Alphabet fd = dictOfSize (3);
    String[] classNames = new String[] {"class0", "class1", "class2"};
    InstanceList ilist = new InstanceList (new Randoms (1), fd, classNames, 200);
    // Fixed seeds keep the generated data and the split reproducible between runs.
    InstanceList[] lists = ilist.split (new java.util.Random (2), new double[] {.5, .5});
    Classifier[] classifiers = train (trainers, lists[0]);
    reportAccuracy ("Accuracy on training set:", classifiers, lists[0]);
    reportAccuracy ("Accuracy on testing set:", classifiers, lists[1]);
  }

  /**
   * Generates training data over a 3-feature alphabet, grows the alphabet to 25
   * features, trains, and then evaluates the classifiers on test instances drawn
   * over the enlarged alphabet.
   */
  public void testNewFeatures ()
  {
    ClassifierTrainer[] trainers = new ClassifierTrainer[] { new MaxEntTrainer () };
    Alphabet fd = dictOfSize (3);
    String[] classNames = new String[] {"class0", "class1", "class2"};
    Randoms r = new Randoms (1);
    InstanceList training = new InstanceList (r, fd, classNames, 50);
    // Grow the alphabet only after the training data has been generated, so the
    // test instances generated below may use features absent from the training data.
    expandDict (fd, 25);
    Classifier[] classifiers = train (trainers, training);
    reportAccuracy ("Accuracy on training set:", classifiers, training);

    InstanceList testing = new InstanceList (training.getPipe ());
    Iterator<Instance> iter = new RandomTokenSequenceIterator (
        r, new Dirichlet (fd, 2.0),
        30, 0,
        10, 50,
        classNames);
    testing.addThruPipe (iter);
    for (int i = 0; i < testing.size (); i++) {
      Instance inst = testing.get (i);
      System.out.println ("DATA:" + inst.getData ());
    }
    reportAccuracy ("Accuracy on testing set:", classifiers, testing);
  }

  /** Trains every trainer on {@code data}; result[i] is the classifier from trainers[i]. */
  private static Classifier[] train (ClassifierTrainer[] trainers, InstanceList data)
  {
    Classifier[] classifiers = new Classifier[trainers.length];
    for (int i = 0; i < trainers.length; i++) {
      classifiers[i] = trainers[i].train (data);
    }
    return classifiers;
  }

  /**
   * Prints {@code header}, then one "className: accuracy" line per classifier
   * evaluated on {@code data}.
   *
   * @throws junit.framework.AssertionFailedError if an accuracy is outside [0, 1]
   *         (this includes NaN)
   */
  private static void reportAccuracy (String header, Classifier[] classifiers, InstanceList data)
  {
    System.out.println (header);
    for (Classifier classifier : classifiers) {
      String className = classifier.getClass ().getName ();
      double accuracy = new Trial (classifier, data).getAccuracy ();
      System.out.println (className + ": " + accuracy);
      assertTrue (className + " accuracy out of range: " + accuracy,
          accuracy >= 0.0 && accuracy <= 1.0);
    }
  }

  /**
   * Re-enables growth on {@code fd} and looks up "feature0" .. "feature{size-1}",
   * adding any that are not yet present.
   */
  private void expandDict (Alphabet fd, int size)
  {
    fd.startGrowth ();
    for (int i = 0; i < size; i++) {
      fd.lookupIndex (featureName (i), true);
    }
  }

  public static Test suite ()
  {
    return new TestSuite (TestClassifiers.class);
  }

  @Override
  protected void setUp ()
  {
  }

  public static void main (String[] args)
  {
    junit.textui.TestRunner.run (suite ());
  }
}