/* Copyright (C) 2009 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.fst.semi_supervised;
import java.util.logging.Logger;
import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFOptimizableByGradientValues;
import cc.mallet.fst.CRFOptimizableByLabelLikelihood;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
/**
* A CRF trainer that maximizes the log-likelihood plus
* a weighted entropy regularization term on unlabeled
* data. Intuitively, it aims to make the CRF's predictions
* on unlabeled data more confident.
*
* References:
* Feng Jiao, Shaojun Wang, Chi-Hoon Lee, Russell Greiner, Dale Schuurmans
* "Semi-supervised conditional random fields for improved sequence segmentation and labeling"
* ACL 2006
*
* Gideon Mann, Andrew McCallum
* "Efficient Computation of Entropy Gradient for Semi-Supervised Conditional Random Fields"
* HLT/NAACL 2007
*
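* <p>
* A minimal usage sketch (the pipe, instance lists, entropy weight, and
* iteration count below are illustrative placeholders, not values
* prescribed by this class):
* <pre>
* CRF crf = new CRF(pipe, null);
* crf.addFullyConnectedStatesForLabels();
* CRFTrainerByEntropyRegularization trainer =
*     new CRFTrainerByEntropyRegularization(crf);
* trainer.setEntropyWeight(0.1); // gamma in [Jiao et al. 06]
* trainer.train(labeled, unlabeled, 100);
* </pre>
*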
* @author Gregory Druck
*/
public class CRFTrainerByEntropyRegularization extends TransducerTrainer implements TransducerTrainer.ByOptimization {
private static Logger logger = MalletLogger.getLogger(CRFTrainerByEntropyRegularization.class.getName());
private static final int DEFAULT_NUM_RESETS = 1;
private static final double DEFAULT_ER_SCALING_FACTOR = 1;
private static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1;
private boolean converged;
private int iteration;
private double entRegScalingFactor;
private double gaussianPriorVariance;
private CRF crf;
private LimitedMemoryBFGS bfgs;
public CRFTrainerByEntropyRegularization(CRF crf) {
this.crf = crf;
this.iteration = 0;
this.entRegScalingFactor = DEFAULT_ER_SCALING_FACTOR;
this.gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
}
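/**
* Sets the variance of the Gaussian prior over parameters, used by
* the label likelihood term.
*
* @param variance Gaussian prior variance
*/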
public void setGaussianPriorVariance(double variance) {
this.gaussianPriorVariance = variance;
}
/**
* Sets the scaling factor for the entropy regularization term.
* In [Jiao et al. 06], this is gamma.
*
* @param gamma weight of the entropy regularization term
*/
public void setEntropyWeight(double gamma) {
this.entRegScalingFactor = gamma;
}
@Override
public int getIteration() {
return this.iteration;
}
@Override
public Transducer getTransducer() {
return this.crf;
}
@Override
public boolean isFinishedTraining() {
return this.converged;
}
/**
* Not supported, since this trainer requires both labeled and unlabeled
* data; use {@link #train(InstanceList, InstanceList, int)} instead.
*/
@Override
public boolean train(InstanceList trainingSet, int numIterations) {
throw new RuntimeException("Use train(InstanceList labeled, InstanceList unlabeled, int numIterations) instead.");
}
/**
* Performs CRF training with label likelihood and entropy regularization.
* The CRF is first trained with label likelihood only; that parameter
* setting is then used as the starting point for optimizing the combined
* objective.
*
* @param labeled Labeled data, used only for the label likelihood term.
* @param unlabeled Unlabeled data, used only for the entropy regularization term.
* @param numIterations Maximum number of optimizer iterations per training stage.
* @return true if training has converged.
*/
public boolean train(InstanceList labeled, InstanceList unlabeled, int numIterations) {
if (iteration == 0) {
// train with log-likelihood only first
CRFOptimizableByLabelLikelihood likelihood =
new CRFOptimizableByLabelLikelihood(crf, labeled);
likelihood.setGaussianPriorVariance(gaussianPriorVariance);
this.bfgs = new LimitedMemoryBFGS(likelihood);
logger.info("CRF about to train with " + numIterations + " iterations");
for (int i = 0; i < numIterations; i++) {
try {
converged = bfgs.optimize(1);
iteration++;
logger.info("CRF finished one iteration of maximizer, i=" + i);
runEvaluators();
} catch (Exception e) {
// the optimizer can throw when it is unable to make further
// progress (e.g. a failed line search); log the error and
// treat training as converged
e.printStackTrace();
logger.info("Caught exception during optimization; treating training as converged.");
converged = true;
}
if (converged) {
logger.info("CRF training has converged, i=" + i);
break;
}
}
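// reset the iteration count so that iterations of the combined
// objective below are counted from zero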
iteration = 0;
}
// train with log-likelihood + entropy regularization
CRFOptimizableByLabelLikelihood likelihood = new CRFOptimizableByLabelLikelihood(crf, labeled);
likelihood.setGaussianPriorVariance(gaussianPriorVariance);
CRFOptimizableByEntropyRegularization regularization = new CRFOptimizableByEntropyRegularization(crf, unlabeled);
regularization.setScalingFactor(this.entRegScalingFactor);
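// combine the two objectives; CRFOptimizableByGradientValues sums
// the values and gradients of its component objectives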
CRFOptimizableByGradientValues regLikelihood = new CRFOptimizableByGradientValues(crf,
new Optimizable.ByGradientValue[] { likelihood, regularization });
this.bfgs = new LimitedMemoryBFGS(regLikelihood);
converged = false;
logger.info("CRF about to train with " + numIterations + " iterations");
// sometimes resetting the optimizer helps to find
// a better parameter setting
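// (each call to reset() clears the optimizer's internal state, so the
// next pass restarts the search from the current parameters)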
for (int reset = 0; reset < DEFAULT_NUM_RESETS + 1; reset++) {
for (int i = 0; i < numIterations; i++) {
try {
converged = bfgs.optimize(1);
iteration++;
logger.info("CRF finished one iteration of maximizer, i=" + i);
runEvaluators();
} catch (Exception e) {
// as above, treat optimizer exceptions as convergence
e.printStackTrace();
logger.info("Caught exception during optimization; treating training as converged.");
converged = true;
}
if (converged) {
logger.info("CRF training has converged, i=" + i);
break;
}
}
this.bfgs.reset();
}
return converged;
}
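/**
* Returns the optimizer from the most recent training stage;
* null if training has not started.
*/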
public Optimizer getOptimizer() {
return bfgs;
}
}