package edu.stanford.nlp.classify;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import junit.framework.Assert;
import junit.framework.TestCase;
import java.util.ArrayList;
import java.util.List;
/**
 * Integration tests for {@link LinearClassifier}: each test trains a classifier
 * with a default-constructed {@link LinearClassifierFactory} on a tiny,
 * hand-built {@link RVFDataset} of string-keyed, real-valued features and then
 * checks the labels the classifier predicts.
 */
public class LinearClassifierITest extends TestCase {

  /**
   * Builds a real-valued datum from parallel arrays of features and counts.
   *
   * @param label    the label to attach to the datum
   * @param features the feature keys
   * @param counts   the feature values; {@code counts[i]} is stored for {@code features[i]}
   * @param <L>      the label type
   * @param <F>      the feature type
   * @return a datum holding the given label and feature counts
   * @throws IllegalArgumentException if the two arrays differ in length
   */
  private static <L, F> RVFDatum<L, F> newDatum(L label,
                                                F[] features,
                                                Double[] counts) {
    // Fail loudly on a malformed fixture instead of silently dropping counts
    // or dying later with an index error.
    if (features.length != counts.length) {
      throw new IllegalArgumentException("features and counts must have the same length, got "
          + features.length + " and " + counts.length);
    }
    ClassicCounter<F> counter = new ClassicCounter<F>();
    for (int i = 0; i < features.length; i++) {
      counter.setCount(features[i], counts[i]);
    }
    return new RVFDatum<L, F>(counter, label);
  }

  /**
   * Trains a classifier on two datums, labelled "alpha" and "beta", that share
   * the string features "f1" and "f2", and asserts that each training datum is
   * classified with its own label.
   *
   * @param d1f1 value of "f1" for the "alpha" datum
   * @param d1f2 value of "f2" for the "alpha" datum
   * @param d2f1 value of "f1" for the "beta" datum
   * @param d2f2 value of "f2" for the "beta" datum
   * @throws Exception if classifier training or classification throws
   */
  private static void testStrBinaryDatums(double d1f1, double d1f2, double d2f1, double d2f2) throws Exception {
    RVFDataset<String, String> trainData = new RVFDataset<String, String>();
    RVFDatum<String, String> d1 = newDatum("alpha",
        new String[]{"f1", "f2"},
        new Double[]{d1f1, d1f2});
    RVFDatum<String, String> d2 = newDatum("beta",
        new String[]{"f1", "f2"},
        new Double[]{d2f1, d2f2});
    trainData.add(d1);
    trainData.add(d2);

    LinearClassifierFactory<String, String> lfc = new LinearClassifierFactory<String, String>();
    LinearClassifier<String, String> lc = lfc.trainClassifier(trainData);

    // Try the obvious (should get train data with 100% acc).
    // The messages identify which of the many parameter combinations failed.
    Assert.assertEquals("alpha datum (f1=" + d1f1 + ", f2=" + d1f2 + ")",
        d1.label(), lc.classOf(d1));
    Assert.assertEquals("beta datum (f1=" + d2f1 + ", f2=" + d2f2 + ")",
        d2.label(), lc.classOf(d2));
  }

  /**
   * Runs the two-class check over several feature-value combinations:
   * opposite signs on a single feature, opposite signs on both features,
   * and each class active on a different feature.
   *
   * @throws Exception if classifier training or classification throws
   */
  public void testStrBinaryDatums() throws Exception {
    testStrBinaryDatums(-1.0, 0.0, 1.0, 0.0);
    testStrBinaryDatums(1.0, 0.0, -1.0, 0.0);
    testStrBinaryDatums(0.0, 1.0, 0.0, -1.0);
    testStrBinaryDatums(0.0, -1.0, 0.0, 1.0);
    testStrBinaryDatums(1.0, 1.0, -1.0, -1.0);
    testStrBinaryDatums(0.0, 1.0, 1.0, 0.0);
    testStrBinaryDatums(1.0, 0.0, 0.0, 1.0);
  }

  /**
   * Trains a three-class classifier ("alpha", "beta", "charlie") over the
   * string features "f1" and "f2", asserts that every training datum is
   * classified with its own label, and then asserts that a further datum,
   * which also carries a feature "f3" absent from the training datums,
   * is classified as "alpha".
   *
   * @throws Exception if classifier training or classification throws
   */
  public void testStrMultiClassDatums() throws Exception {
    RVFDataset<String, String> trainData = new RVFDataset<String, String>();
    List<RVFDatum<String, String>> datums = new ArrayList<RVFDatum<String, String>>();
    datums.add(newDatum("alpha",
        new String[]{"f1", "f2"},
        new Double[]{1.0, 0.0}));
    datums.add(newDatum("beta",
        new String[]{"f1", "f2"},
        new Double[]{0.0, 1.0}));
    datums.add(newDatum("charlie",
        new String[]{"f1", "f2"},
        new Double[]{5.0, 5.0}));
    for (RVFDatum<String, String> datum : datums) {
      trainData.add(datum);
    }

    LinearClassifierFactory<String, String> lfc = new LinearClassifierFactory<String, String>();
    LinearClassifier<String, String> lc = lfc.trainClassifier(trainData);

    // Fully parameterized (not a raw RVFDatum) so label() and classOf() stay type-checked.
    RVFDatum<String, String> td1 = newDatum("alpha",
        new String[]{"f1", "f2", "f3"},
        new Double[]{2.0, 0.0, 5.5});

    // Try the obvious (should get train data with 100% acc)
    for (RVFDatum<String, String> datum : datums) {
      Assert.assertEquals("training datum labelled " + datum.label(),
          datum.label(), lc.classOf(datum));
    }

    // Test data
    Assert.assertEquals("test datum with extra feature f3",
        td1.label(), lc.classOf(td1));
  }
}