boolean useSaved,
boolean useSparseWeights)
{
Pipe p = makeSpacePredictionPipe ();
MEMM savedCRF;
File f = new File("TestObject.obj");
InstanceList instances = new InstanceList(p);
instances.addThruPipe(new ArrayIterator(data));
InstanceList[] lists = instances.split(new double[]{.5, .5});
MEMM crf = new MEMM(p.getDataAlphabet(), p.getTargetAlphabet());
crf.addFullyConnectedStatesForLabels();
if (useSparseWeights)
crf.setWeightsDimensionAsIn(lists[0]);
else
crf.setWeightsDimensionDensely();
MEMMTrainer memmt = new MEMMTrainer (crf);
// memmt.setUseSparseWeights (useSparseWeights);
if (testValueAndGradient) {
Optimizable.ByGradientValue minable = memmt.getOptimizableMEMM(lists[0]);
TestOptimizable.testValueAndGradient(minable);
} else {
System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0]));
System.out.println("Testing Accuracy before training = " + crf.averageTokenAccuracy(lists[1]));
savedCRF = crf;
System.out.println("Training serialized crf.");
memmt.train(lists[0], 100);
double preTrainAcc = crf.averageTokenAccuracy(lists[0]);
double preTestAcc = crf.averageTokenAccuracy(lists[1]);
System.out.println("Training Accuracy after training = " + preTrainAcc);
System.out.println("Testing Accuracy after training = " + preTestAcc);
try {
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f));
oos.writeObject(crf);
oos.close();
} catch (IOException e) {
System.err.println("Exception writing file: " + e);
}
System.err.println("Wrote out CRF");
// And read it back in
if (useSaved) {
crf = null;
try {
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(f));
crf = (MEMM) ois.readObject();
ois.close();
} catch (IOException e) {
System.err.println("Exception reading file: " + e);
} catch (ClassNotFoundException cnfe) {
System.err.println("Cound not find class reading in object: " + cnfe);
}
System.err.println("Read in CRF.");
crf = savedCRF;
double postTrainAcc = crf.averageTokenAccuracy(lists[0]);
double postTestAcc = crf.averageTokenAccuracy(lists[1]);
System.out.println("Training Accuracy after saving = " + postTrainAcc);
System.out.println("Testing Accuracy after saving = " + postTestAcc);
assertEquals(postTrainAcc, preTrainAcc, 0.0001);
assertEquals(postTestAcc, preTestAcc, 0.0001);