DataInputStream in = new DataInputStream(new ByteArrayInputStream(r.get()));
Classifier c = PolymorphicWritable.read(in, Classifier.class);
assertEquals(2, c.getCategories().size());
assertEquals("0", c.getCategories().get(0));
assertEquals("1", c.getCategories().get(1));
OnlineLogisticRegression model = c.getModel();
assertEquals(lr.getModel().currentLearningRate(), model.currentLearningRate(), 1e-10);
in.close();
// with that many data points, model should point in the same direction as the original vector
Vector v = model.getBeta().viewRow(0);
double z = n.dot(v) / (n.norm(2) * v.norm(2));
assertEquals(1.0, z, 1e-2);
// just for grins, we should check whether the model actually computes the correct values
List<String> categories = ImmutableList.of("0", "1");
for (Tuple example : examples) {
double score = model.classifyScalar(PigVector.fromBytes((DataByteArray) example.get(1)));
int actual = categories.indexOf(example.get(0));
score = score * actual + (1 - score) * (1 - actual);
assertTrue(score > 0.4);
}
}