package aima.test.core.unit.learning.framework;
import java.util.Arrays;
import java.util.List;
import org.junit.Assert;
import org.junit.Test;
import aima.core.learning.framework.DataSet;
import aima.core.learning.framework.DataSetFactory;
import aima.core.learning.framework.DataSetSpecification;
import aima.core.learning.framework.Example;
import aima.core.learning.neural.IrisDataSetNumerizer;
import aima.core.learning.neural.Numerizer;
import aima.core.learning.neural.RabbitEyeDataSet;
import aima.core.util.datastructure.Pair;
/**
 * Unit tests for the learning framework's data sets: normalisation and
 * example extraction for the rabbit-eye data, loading of the restaurant and
 * iris data sets, non-destructive example removal, and the iris numerizer.
 *
 * @author Ravi Mohan
 *
 */
public class DataSetTest {
    private static final String YES = "Yes";

    /** Base name of the data file handed to {@link RabbitEyeDataSet}. */
    private static final String RABBIT_EYES = "rabbiteyes";

    /** Absolute tolerance used for every floating point comparison below. */
    private static final double TOLERANCE = 0.001;

    // Not referenced by any test in this class; retained as-is.
    DataSetSpecification spec;

    /**
     * Normalises the rabbit-eye data and checks the reported means, standard
     * deviations, the number of normalised rows, and the first and last
     * normalised rows.
     */
    @Test
    public void testNormalizationOfFileBasedDataProducesCorrectMeanStdDevAndNormalizedValues()
            throws Exception {
        RabbitEyeDataSet rabbitEyes = new RabbitEyeDataSet();
        rabbitEyes.createNormalizedDataFromFile(RABBIT_EYES);

        List<Double> columnMeans = rabbitEyes.getMeans();
        Assert.assertEquals(2, columnMeans.size());
        assertPairEquals(244.771, 145.505, columnMeans);

        List<Double> columnStdevs = rabbitEyes.getStdevs();
        Assert.assertEquals(2, columnStdevs.size());
        assertPairEquals(213.554, 65.776, columnStdevs);

        List<List<Double>> rows = rabbitEyes.getNormalizedData();
        Assert.assertEquals(70, rows.size());
        // first row
        assertPairEquals(-1.0759, -1.882, rows.get(0));
        // last row
        assertPairEquals(2.880, 1.538, rows.get(69));
    }

    /**
     * Builds examples from the rabbit-eye data and checks that each call to
     * {@code getExampleAtRandom()} lowers {@code howManyExamplesLeft()} by
     * one, starting from 70.
     */
    @Test
    public void testExampleFormation() throws Exception {
        RabbitEyeDataSet rabbitEyes = new RabbitEyeDataSet();
        rabbitEyes.createExamplesFromFile(RABBIT_EYES);

        int expectedRemaining = 70;
        Assert.assertEquals(expectedRemaining, rabbitEyes.howManyExamplesLeft());
        for (int draw = 0; draw < 2; draw++) {
            rabbitEyes.getExampleAtRandom();
            expectedRemaining--;
            Assert.assertEquals(expectedRemaining,
                    rabbitEyes.howManyExamplesLeft());
        }
    }

    /**
     * Loads the restaurant data set and checks its size together with a few
     * attribute values and the target value of the example at index 0.
     */
    @Test
    public void testLoadsDatasetFile() throws Exception {
        DataSet restaurants = DataSetFactory.getRestaurantDataSet();
        Assert.assertEquals(12, restaurants.size());

        Example head = restaurants.getExample(0);
        Assert.assertEquals(YES, head.getAttributeValueAsString("alternate"));
        Assert.assertEquals("$$$", head.getAttributeValueAsString("price"));
        Assert.assertEquals("0-10",
                head.getAttributeValueAsString("wait_estimate"));
        Assert.assertEquals(YES, head.getAttributeValueAsString("will_wait"));
        Assert.assertEquals(YES, head.targetValue());
    }

    /**
     * Expects {@code fromFile} to throw when asked for the name
     * "nonexistent".
     *
     * NOTE(review): any {@link Exception} satisfies this test, and the
     * specification and separator arguments are both {@code null}, so the
     * test cannot tell a missing-file failure from one triggered by the
     * null arguments - confirm this is intended.
     */
    @Test(expected = Exception.class)
    public void testThrowsExceptionForNonExistentFile() throws Exception {
        DataSetFactory factory = new DataSetFactory();
        factory.fromFile("nonexistent", null, null);
    }

    /**
     * Loads the iris data set and checks the string form of the
     * "sepal_length" attribute of the example at index 0.
     */
    @Test
    public void testLoadsIrisDataSetWithNumericAndStringAttributes()
            throws Exception {
        DataSet iris = DataSetFactory.getIrisDataSet();
        Example head = iris.getExample(0);
        Assert.assertEquals("5.1",
                head.getAttributeValueAsString("sepal_length"));
    }

    /**
     * Checks that {@code removeExample} leaves the receiving data set at its
     * original size and returns a data set with one example fewer.
     */
    @Test
    public void testNonDestructiveRemoveExample() throws Exception {
        DataSet original = DataSetFactory.getRestaurantDataSet();
        DataSet reduced = original.removeExample(original.getExample(0));
        Assert.assertEquals(12, original.size());
        Assert.assertEquals(11, reduced.size());
    }

    /** Iris example 0 is expected to numerize and denumerize as "setosa". */
    @Test
    public void testNumerizesAndDeNumerizesIrisDataSetExample1()
            throws Exception {
        assertIrisNumerization(0, Arrays.asList(5.1, 3.5, 1.4, 0.2),
                Arrays.asList(0.0, 0.0, 1.0), "setosa");
    }

    /** Iris example 51 is expected to numerize and denumerize as "versicolor". */
    @Test
    public void testNumerizesAndDeNumerizesIrisDataSetExample2()
            throws Exception {
        assertIrisNumerization(51, Arrays.asList(6.4, 3.2, 4.5, 1.5),
                Arrays.asList(0.0, 1.0, 0.0), "versicolor");
    }

    /** Iris example 100 is expected to numerize and denumerize as "virginica". */
    @Test
    public void testNumerizesAndDeNumerizesIrisDataSetExample3()
            throws Exception {
        assertIrisNumerization(100, Arrays.asList(6.3, 3.3, 6.0, 2.5),
                Arrays.asList(1.0, 0.0, 0.0), "virginica");
    }

    /**
     * Asserts that the first two entries of {@code actual} match the given
     * values within {@link #TOLERANCE}.
     */
    private static void assertPairEquals(double expectedFirst,
            double expectedSecond, List<Double> actual) {
        Assert.assertEquals(expectedFirst, actual.get(0), TOLERANCE);
        Assert.assertEquals(expectedSecond, actual.get(1), TOLERANCE);
    }

    /**
     * Numerizes the iris example at {@code index} with an
     * {@link IrisDataSetNumerizer} and checks both halves of the resulting
     * pair, then checks that the expected output vector denumerizes to
     * {@code expectedCategory}.
     *
     * @param index
     *            position of the example within the iris data set
     * @param expectedInput
     *            expected first element of the numerized pair
     * @param expectedOutput
     *            expected second element of the numerized pair; also the
     *            vector passed to {@code denumerize}
     * @param expectedCategory
     *            expected result of denumerizing {@code expectedOutput}
     */
    private static void assertIrisNumerization(int index,
            List<Double> expectedInput, List<Double> expectedOutput,
            String expectedCategory) throws Exception {
        DataSet iris = DataSetFactory.getIrisDataSet();
        Example example = iris.getExample(index);
        Numerizer numerizer = new IrisDataSetNumerizer();

        Pair<List<Double>, List<Double>> numerized = numerizer
                .numerize(example);
        Assert.assertEquals(expectedInput, numerized.getFirst());
        Assert.assertEquals(expectedOutput, numerized.getSecond());

        String category = numerizer.denumerize(expectedOutput);
        Assert.assertEquals(expectedCategory, category);
    }
}