Package cc.mallet.classify.tests

Source Code of cc.mallet.classify.tests.TestClassifiers

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */




/**
   @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/

package cc.mallet.classify.tests;

//import edu.umass.cs.mallet.base.pipe.SerialPipe;
import junit.framework.*;
import java.net.URI;
import java.util.Iterator;

import cc.mallet.classify.*;
import cc.mallet.pipe.*;
import cc.mallet.pipe.iterator.ArrayIterator;
import cc.mallet.pipe.iterator.PipeInputIterator;
import cc.mallet.pipe.iterator.RandomTokenSequenceIterator;
import cc.mallet.types.*;
import cc.mallet.util.*;

public class TestClassifiers extends TestCase
{
  public TestClassifiers (String name)
  {
    super (name);
  }

  private static Alphabet dictOfSize (int size)
  {
    Alphabet ret = new Alphabet ();
    for (int i = 0; i < size; i++)
      ret.lookupIndex ("feature"+i);
    return ret;
  }
 
  public void testRandomTrained ()
  {
    ClassifierTrainer[] trainers = new ClassifierTrainer[1];
    //trainers[0] = new NaiveBayesTrainer();
    trainers[0] = new MaxEntTrainer();
    //trainers[2] = new DecisionTreeTrainer();

    Alphabet fd = dictOfSize (3);
    String[] classNames = new String[] {"class0", "class1", "class2"};

    InstanceList ilist = new InstanceList (new Randoms(1), fd, classNames, 200);

    InstanceList lists[] = ilist.split (new java.util.Random(2), new double[] {.5, .5});
    //System.out.println ("Training set size = "+lists[0].size());
    //System.out.println ("Testing set size = "+lists[1].size());
   
    Classifier[] classifiers = new Classifier[trainers.length];
    for (int i = 0; i < trainers.length; i++)
      classifiers[i] = trainers[i].train (lists[0]);

    System.out.println ("Accuracy on training set:");
    for (int i = 0; i < trainers.length; i++)
      System.out.println (classifiers[i].getClass().getName()
                          + ": " + new Trial (classifiers[i], lists[0]).getAccuracy());

    System.out.println ("Accuracy on testing set:");
    for (int i = 0; i < trainers.length; i++)
      System.out.println (classifiers[i].getClass().getName()
                          + ": " + new Trial (classifiers[i], lists[1]).getAccuracy());
  }
 
  public void testNewFeatures ()
    {
      ClassifierTrainer[] trainers = new ClassifierTrainer[1];
      trainers[0] = new MaxEntTrainer();

      Alphabet fd = dictOfSize (3);
      String[] classNames = new String[] {"class0", "class1", "class2"};

      Randoms r = new Randoms(1);
      InstanceList training = new InstanceList (r, fd, classNames, 50);
      expandDict (fd, 25);

      Classifier[] classifiers = new Classifier[trainers.length];
      for (int i = 0; i < trainers.length; i++)
        classifiers[i] = trainers[i].train (training);

      System.out.println ("Accuracy on training set:");
      for (int i = 0; i < trainers.length; i++)
        System.out.println (classifiers[i].getClass().getName()
                            + ": " + new Trial (classifiers[i], training).getAccuracy());

      InstanceList testing = new InstanceList (training.getPipe ());
      Iterator<Instance> iter = new RandomTokenSequenceIterator (
        r,  new Dirichlet (fd, 2.0),
        30, 0,
        10, 50,
        classNames);
      testing.addThruPipe (iter);

      for (int i = 0; i < testing.size (); i++) {
        Instance inst = testing.get (i);
        System.out.println ("DATA:"+inst.getData());
      }

      System.out.println ("Accuracy on testing set:");
      for (int i = 0; i < trainers.length; i++)
        System.out.println (classifiers[i].getClass().getName()
                            + ": " + new Trial (classifiers[i], testing).getAccuracy());
    }

  private void expandDict (Alphabet fd, int size)
  {
    fd.startGrowth ();
    for (int i = 0; i < size; i++)
      fd.lookupIndex ("feature"+i, true);
  }

  public static Test suite ()
  {
    return new TestSuite (TestClassifiers.class);
  }

  protected void setUp ()
  {
  }

  public static void main (String[] args)
  {
    junit.textui.TestRunner.run (suite());
  }
 
}
TOP

Related Classes of cc.mallet.classify.tests.TestClassifiers

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle. Contact coftware#gmail.com.