// Build the training set by running `data` through pipe `p`
// (both are defined outside this excerpt).
InstanceList instances = new InstanceList(p);
instances.addThruPipe(new ArrayIterator(data));
// Baseline model: a CRF over the pipe's data/target alphabets with fully
// connected label states. No weights are frozen on this one.
CRF crf1 = new CRF(p.getDataAlphabet(), p.getTargetAlphabet());
crf1.addFullyConnectedStatesForLabels();
CRFTrainerByLabelLikelihood crft1 = new CRFTrainerByLabelLikelihood(
crf1);
crft1.trainIncremental(instances);
// Second model, constructed exactly like the baseline above.
CRF crf2 = new CRF(p.getDataAlphabet(), p.getTargetAlphabet());
crf2.addFullyConnectedStatesForLabels();
// Freeze every even-indexed weight set (0, 2, 4, ...) before training, so
// that the trainer is expected to leave those weights untouched.
for (int i = 0; i < crf2.getWeights().length; i += 2)
crf2.freezeWeights(i);
// Train on the same instances as the baseline model.
CRFTrainerByLabelLikelihood crft2 = new CRFTrainerByLabelLikelihood(
crf2);
crft2.trainIncremental(instances);
// Verify that training did not move the frozen weights: for every
// even-indexed (frozen) weight set, both the default weight and each stored
// entry of the sparse weight vector must still be 0.0 (within 1e-10).
// NOTE(review): this presumes the weights start out at zero - confirm
// against the CRF initialization.
SparseVector[] w = crf2.getWeights();
double[] b = crf2.getDefaultWeights();
for (int i = 0; i < w.length; i += 2) {
assertEquals(0.0, b[i], 1e-10);
// Iterate over the stored locations of the sparse vector only.
for (int loc = 0; loc < w[i].numLocations(); loc++) {
assertEquals(0.0, w[i].valueAtLocation(loc), 1e-10);
}
}
// Check that the model with frozen weights has a worse likelihood
// Fetch each trainer's optimizable objective over the same instances and
// read its current value: val1 belongs to the unfrozen baseline model,
// val2 to the model trained with frozen weights.
Optimizable.ByGradientValue optable1 = crft1
.getOptimizableCRF(instances);
Optimizable.ByGradientValue optable2 = crft2
.getOptimizableCRF(instances);
double val1 = optable1.getValue();
double val2 = optable2.getValue();
assertTrue(
"Error: Freezing weights does not harm log-likelihood! Full "