Package edu.stanford.nlp.classify

Source Code of edu.stanford.nlp.classify.LinearClassifierITest

package edu.stanford.nlp.classify;

import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import junit.framework.Assert;
import junit.framework.TestCase;

import java.util.ArrayList;
import java.util.List;

/**
*
*/
public class LinearClassifierITest extends TestCase {

  private static <L, F> RVFDatum<L, F> newDatum(L label,
                                                F[] features,
                                                Double[] counts) {
    ClassicCounter<F> counter = new ClassicCounter<F>();
    for (int i = 0; i < features.length; i++) {
      counter.setCount(features[i], counts[i]);
    }
    return new RVFDatum<L, F>(counter, label);
  }

  /**
   * Tests string based features
   *
   * @throws Exception
   */
  private static void testStrBinaryDatums(double d1f1, double d1f2, double d2f1, double d2f2) throws Exception {
    RVFDataset<String, String> trainData = new RVFDataset<String, String>();
    RVFDatum<String, String> d1 = newDatum("alpha",
      new String[]{"f1", "f2"},
      new Double[]{d1f1, d1f2});
    RVFDatum<String, String> d2 = newDatum("beta",
      new String[]{"f1", "f2"},
      new Double[]{d2f1, d2f2});
    trainData.add(d1);
    trainData.add(d2);
    LinearClassifierFactory<String, String> lfc = new LinearClassifierFactory<String, String>();
    LinearClassifier<String, String> lc = lfc.trainClassifier(trainData);
    // Try the obvious (should get train data with 100% acc)
    Assert.assertEquals(d1.label(), lc.classOf(d1));
    Assert.assertEquals(d2.label(), lc.classOf(d2));
  }

  public void testStrBinaryDatums() throws Exception {
    testStrBinaryDatums(-1.0, 0.0, 1.0, 0.0);
    testStrBinaryDatums(1.0, 0.0, -1.0, 0.0);
    testStrBinaryDatums(0.0, 1.0, 0.0, -1.0);
    testStrBinaryDatums(0.0, -1.0, 0.0, 1.0);   
    testStrBinaryDatums(1.0, 1.0, -1.0, -1.0);
    testStrBinaryDatums(0.0, 1.0, 1.0, 0.0);
    testStrBinaryDatums(1.0, 0.0, 0.0, 1.0);
  }

  public void testStrMultiClassDatums() throws Exception {
    RVFDataset<String, String> trainData = new RVFDataset<String, String>();
    List<RVFDatum<String, String>> datums = new ArrayList<RVFDatum<String, String>>();
    datums.add(newDatum("alpha",
      new String[]{"f1", "f2"},
      new Double[]{1.0, 0.0}));
    ;
    datums.add(newDatum("beta",
      new String[]{"f1", "f2"},
      new Double[]{0.0, 1.0}));
    datums.add(newDatum("charlie",
      new String[]{"f1", "f2"},
      new Double[]{5.0, 5.0}));
    for (RVFDatum<String, String> datum : datums)
      trainData.add(datum);
    LinearClassifierFactory<String, String> lfc = new LinearClassifierFactory<String, String>();
    LinearClassifier<String, String> lc = lfc.trainClassifier(trainData);

    RVFDatum td1 = newDatum("alpha",
      new String[]{"f1", "f2","f3"},
      new Double[]{2.0, 0.0, 5.5});

    // Try the obvious (should get train data with 100% acc)
    for (RVFDatum<String, String> datum : datums)
      Assert.assertEquals(datum.label(), lc.classOf(datum));

    // Test data
    Assert.assertEquals(td1.label(), lc.classOf(td1));
  }
}
TOP

Related Classes of edu.stanford.nlp.classify.LinearClassifierITest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.