package cc.mallet.fst;

import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
/**
* @author Gregory Druck gdruck@cs.umass.edu
*
* Multi-threaded version of CRF trainer. Note that multi-threaded feature induction
* and hyperbolic prior are not supported by this code.
*/
public class CRFTrainerByThreadedLabelLikelihood extends TransducerTrainer implements TransducerTrainer.ByOptimization {
	private static Logger logger = MalletLogger.getLogger(CRFTrainerByThreadedLabelLikelihood.class.getName());

	static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0;

	// Weight-structure options; see setUseSparseWeights / setAddNoFactors / setUseSomeUnsupportedTrick.
	private boolean useSparseWeights;
	private boolean useNoWeights;
	private transient boolean useSomeUnsupportedTrick;
	// True once the optimizer has reported convergence (or training stopped on an exception).
	private boolean converged;
	private int numThreads;
	// Total number of optimizer iterations performed across all train() calls.
	private int iterationCount;
	private double gaussianPriorVariance;
	private CRF crf;
	private CRFOptimizableByBatchLabelLikelihood optimizable;
	private ThreadedOptimizable threadedOptimizable;
	private Optimizer optimizer;
	// Last value of crf.weightsStructureChangeStamp for which the weight structure
	// was set up; a mismatch in getOptimizableCRF forces a rebuild.
	private int cachedWeightsStructureStamp;

	/**
	 * @param crf the CRF to train
	 * @param numThreads number of threads used to compute batch values and gradients
	 */
	public CRFTrainerByThreadedLabelLikelihood (CRF crf, int numThreads) {
		this.crf = crf;
		this.useSparseWeights = true;
		this.useNoWeights = false;
		this.useSomeUnsupportedTrick = true;
		this.converged = false;
		this.numThreads = numThreads;
		this.iterationCount = 0;
		this.gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
		this.cachedWeightsStructureStamp = -1;
	}

	public Transducer getTransducer() { return crf; }
	public CRF getCRF () { return crf; }
	public Optimizer getOptimizer() { return optimizer; }
	public boolean isConverged() { return converged; }
	public boolean isFinishedTraining() { return converged; }
	public int getIteration () { return iterationCount; }
	public void setGaussianPriorVariance (double p) { gaussianPriorVariance = p; }
	public double getGaussianPriorVariance () { return gaussianPriorVariance; }
	public void setUseSparseWeights (boolean b) { useSparseWeights = b; }
	public boolean getUseSparseWeights () { return useSparseWeights; }

	/** Sets whether to use the 'some unsupported trick.' This trick is, if training a CRF
	 * where some training has been done and sparse weights are used, to add a few weights
	 * for features that do not occur in the training data.
	 * <p>
	 * This generally leads to better accuracy at only a small memory cost.
	 *
	 * @param b Whether to use the trick
	 */
	public void setUseSomeUnsupportedTrick (boolean b) { useSomeUnsupportedTrick = b; }

	/**
	 * Use this method to specify whether or not factors
	 * are added to the CRF by this trainer. If you have
	 * already setup the factors in your CRF, you may
	 * not want the trainer to add additional factors.
	 *
	 * @param flag If true, this trainer adds no factors to the CRF.
	 */
	public void setAddNoFactors(boolean flag) {
		this.useNoWeights = flag;
	}

	/** Shuts down the worker threads; call this when finished with the trainer. */
	public void shutdown() {
		threadedOptimizable.shutdown();
	}

	/**
	 * Returns the optimizable for this CRF and training set, (re)building the
	 * CRF's weight structure and the cached optimizable when either is stale.
	 *
	 * @param trainingSet the instances to optimize the likelihood of
	 * @return the (possibly cached) optimizable for {@code trainingSet}
	 */
	public CRFOptimizableByBatchLabelLikelihood getOptimizableCRF (InstanceList trainingSet) {
		if (cachedWeightsStructureStamp != crf.weightsStructureChangeStamp) {
			if (!useNoWeights) {
				if (useSparseWeights) {
					crf.setWeightsDimensionAsIn (trainingSet, useSomeUnsupportedTrick);
				}
				else {
					crf.setWeightsDimensionDensely ();
				}
			}
			optimizable = null; // weight structure changed; cached optimizable is invalid
			cachedWeightsStructureStamp = crf.weightsStructureChangeStamp;
		}
		if (optimizable == null || optimizable.trainingSet != trainingSet) {
			optimizable = new CRFOptimizableByBatchLabelLikelihood(crf, trainingSet, numThreads);
			optimizable.setGaussianPriorVariance(gaussianPriorVariance);
			threadedOptimizable = new ThreadedOptimizable(optimizable, trainingSet, crf.getParameters().getNumFactors(),
					new CRFCacheStaleIndicator(crf));
			optimizer = null; // force getOptimizer to wrap the new optimizable
		}
		return optimizable;
	}

	/**
	 * Returns the optimizer for this training set, creating one if necessary.
	 *
	 * @param trainingSet the instances to optimize the likelihood of
	 * @return an L-BFGS optimizer over the threaded optimizable
	 */
	public Optimizer getOptimizer (InstanceList trainingSet) {
		getOptimizableCRF(trainingSet); // rebuilds optimizable (and nulls optimizer) if stale
		// Compare against threadedOptimizable: the optimizer wraps threadedOptimizable,
		// so the previous comparison against the inner optimizable was always unequal
		// and silently discarded L-BFGS state on every call.
		if (optimizer == null || threadedOptimizable != optimizer.getOptimizable()) {
			optimizer = new LimitedMemoryBFGS(threadedOptimizable);
		}
		return optimizer;
	}

	public boolean trainIncremental (InstanceList training) {
		return train (training, Integer.MAX_VALUE);
	}

	/**
	 * Trains the CRF by maximizing the label likelihood of the training set.
	 *
	 * @param trainingSet the training instances
	 * @param numIterations maximum number of optimizer iterations to run
	 * @return true if training converged
	 */
	public boolean train (InstanceList trainingSet, int numIterations) {
		if (numIterations <= 0) {
			return false;
		}
		assert (trainingSet.size() > 0);
		// getOptimizer() also (re)builds the optimizable if necessary, so a
		// separate getOptimizableCRF() call is redundant.
		getOptimizer(trainingSet);
		// Assign the field (not a shadowing local) so isConverged() and
		// isFinishedTraining() report the true training state.
		converged = false;
		logger.info ("CRF about to train with "+numIterations+" iterations");
		for (int i = 0; i < numIterations; i++) {
			try {
				converged = optimizer.optimize (1);
				iterationCount++;
				logger.info ("CRF finished one iteration of maximizer, i="+i);
				runEvaluators();
			} catch (Exception e) {
				// L-BFGS can throw (e.g. IllegalArgumentException from a failed line
				// search); treat it as convergence rather than aborting training.
				logger.log (Level.WARNING, "Catching exception; saying converged.", e);
				converged = true;
			}
			if (converged) {
				logger.info ("CRF training has converged, i="+i);
				break;
			}
		}
		return converged;
	}

	/**
	 * Train a CRF on various-sized subsets of the data. This method is typically used to accelerate training by
	 * quickly getting to reasonable parameters on only a subset of the parameters first, then on progressively more data.
	 * @param training The training Instances.
	 * @param numIterationsPerProportion Maximum number of Maximizer iterations per training proportion.
	 * @param trainingProportions If non-null, train on increasingly
	 * larger portions of the data, e.g. new double[] {0.2, 0.5, 1.0}. This can sometimes speedup convergence.
	 * Be sure to end in 1.0 if you want to train on all the data in the end.
	 * @return True if training has converged.
	 */
	public boolean train (InstanceList training, int numIterationsPerProportion, double[] trainingProportions)
	{
		assert (trainingProportions.length > 0);
		boolean hasConverged = false;
		for (int i = 0; i < trainingProportions.length; i++) {
			assert (trainingProportions[i] <= 1.0);
			// Multiply by 100: the proportions are fractions, and the message reports a percentage.
			logger.info ("Training on "+(100*trainingProportions[i])+"% of the data this round.");
			if (trainingProportions[i] == 1.0) {
				hasConverged = this.train (training, numIterationsPerProportion);
			}
			else {
				// Fixed seed so successive proportions draw reproducible subsets.
				hasConverged = this.train (training.split (new Random(1),
						new double[] {trainingProportions[i], 1-trainingProportions[i]})[0], numIterationsPerProportion);
			}
		}
		return hasConverged;
	}
}