/* Copyright (C) 2010 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.fst.semi_supervised;
import java.util.ArrayList;
import java.util.logging.Level;
import java.util.logging.Logger;
import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
/**
* Trains a CRF using Generalized Expectation constraints that
* consider either a single label or a pair of labels of a linear chain CRF.
*
* See:
* "Generalized Expectation Criteria for Semi-Supervised Learning of Conditional Random Fields"
* Gideon Mann and Andrew McCallum
* ACL 2008
*
* @author Gregory Druck
*/
public class CRFTrainerByGE extends TransducerTrainer {

  private static final Logger logger = MalletLogger.getLogger(CRFTrainerByGE.class.getName());

  private static final int DEFAULT_NUM_RESETS = 1;
  private static final double DEFAULT_GPV = 10.0;

  // Result of the most recent optimizer step (or true after a caught exception).
  private boolean converged;
  // Number of optimizer steps that completed without throwing; never reset by train().
  private int iteration;
  private int numThreads;
  private int numResets;
  private double gaussianPriorVariance;
  private ArrayList<GEConstraint> constraints;
  private CRF crf;
  private StateLabelMap stateLabelMap;
  // Both are created lazily on first use and then cached for the life of this trainer.
  private CRFOptimizableByGE optimizable;
  private Optimizer optimizer;

  /**
   * Creates a trainer that uses a single thread.
   *
   * @param crf the CRF whose parameters are trained
   * @param constraints the GE constraints to train with
   */
  public CRFTrainerByGE(CRF crf, ArrayList<GEConstraint> constraints) {
    this(crf, constraints, 1);
  }

  /**
   * Creates a trainer with the default number of resets and the default
   * Gaussian prior variance.
   *
   * @param crf the CRF whose parameters are trained
   * @param constraints the GE constraints to train with
   * @param numThreads thread count handed to the GE optimizable
   */
  public CRFTrainerByGE(CRF crf, ArrayList<GEConstraint> constraints, int numThreads) {
    this.converged = false;
    this.iteration = 0;
    this.constraints = constraints;
    this.crf = crf;
    this.numThreads = numThreads;
    this.numResets = DEFAULT_NUM_RESETS;
    this.gaussianPriorVariance = DEFAULT_GPV;
    // default one to one state label map
    // other maps can be set with setStateLabelMap
    this.stateLabelMap = new StateLabelMap(crf.getOutputAlphabet(), true);
  }

  /**
   * @return the number of optimizer steps that have completed without
   *         throwing, accumulated over every call to {@link #train}
   */
  @Override
  public int getIteration() {
    return iteration;
  }

  /** @return the CRF being trained */
  @Override
  public Transducer getTransducer() {
    return crf;
  }

  /** @return the convergence flag left by the most recent call to {@link #train} */
  @Override
  public boolean isFinishedTraining() {
    return converged;
  }

  /**
   * Sets the Gaussian prior variance passed to the optimizable when it is
   * created. Calling this after the optimizable has been created has no
   * effect on that instance.
   *
   * @param gpv Gaussian prior variance
   */
  public void setGaussianPriorVariance(double gpv) {
    this.gaussianPriorVariance = gpv;
  }

  /**
   * Sets number of resets of L-BFGS during
   * optimization. Resetting more times
   * can be useful since the GE objective
   * function is non-convex
   *
   * @param numResets Number of resets of L-BFGS
   */
  public void setNumResets(int numResets) {
    this.numResets = numResets;
  }

  /**
   * Sets the map between states in the CRF FST and labels. Only takes effect
   * if called before the optimizable is first created.
   *
   * @param map the state-to-label map to use
   */
  public void setStateLabelMap(StateLabelMap map) {
    this.stateLabelMap = map;
  }

  /**
   * Sets the optimizer used by {@link #train}. When none is set, a
   * {@link LimitedMemoryBFGS} is created on demand.
   *
   * @param optimizer the optimizer to use
   */
  public void setOptimizer(Optimizer optimizer) {
    this.optimizer = optimizer;
  }

  /**
   * Sets the optimizer. Despite its name this method has always set the
   * optimizer, not the optimizable; it is kept so existing callers compile.
   *
   * @param optimizer the optimizer to use
   * @deprecated misnamed; use {@link #setOptimizer(Optimizer)} instead
   */
  @Deprecated
  public void setOptimizable(Optimizer optimizer) {
    setOptimizer(optimizer);
  }

  /**
   * Returns the GE optimizable, creating it on the first call from the
   * current CRF, constraints, state label map, thread count and prior
   * variance. Later calls return the cached instance and ignore
   * {@code unlabeled}.
   *
   * @param unlabeled instances used when the optimizable is first created
   * @return the (possibly cached) optimizable
   */
  public Optimizable.ByGradientValue getOptimizable(InstanceList unlabeled) {
    if (optimizable == null) {
      optimizable = new CRFOptimizableByGE(crf, constraints, unlabeled, stateLabelMap, numThreads);
      optimizable.setGaussianPriorVariance(gaussianPriorVariance);
    }
    return optimizable;
  }

  /**
   * Returns the optimizer, creating a {@link LimitedMemoryBFGS} over
   * {@code optimizable} if none has been set or created yet. When an
   * optimizer already exists the argument is ignored.
   *
   * @param optimizable objective handed to a newly created optimizer
   * @return the (possibly cached) optimizer
   */
  public Optimizer getOptimizer(Optimizable.ByGradientValue optimizable) {
    if (optimizer == null) {
      optimizer = new LimitedMemoryBFGS(optimizable);
    }
    return optimizer;
  }

  /**
   * Runs up to {@code numIterations} single-step optimizations, resetting an
   * L-BFGS optimizer and continuing each time it reports convergence, for at
   * most {@code numResets} restarts. Evaluators run after every successful
   * step. Any exception thrown by a step is logged and treated as
   * convergence for that pass.
   *
   * @param unlabeledSet instances used to build the optimizable on first use
   * @param numIterations upper bound on optimizer steps for this call
   * @return the convergence flag after the last step
   * @throws IllegalStateException if no constraints were supplied
   */
  @Override
  public boolean train(InstanceList unlabeledSet, int numIterations) {
    if (constraints == null || constraints.isEmpty()) {
      throw new IllegalStateException("No constraints specified!");
    }

    getOptimizable(unlabeledSet);
    getOptimizer(optimizable);
    resetIfLimitedMemoryBFGS();

    converged = false;
    logger.info("CRF about to train with " + numIterations + " iterations");

    try {
      // sometimes resetting the optimizer helps to find
      // a better parameter setting
      // iter is shared across passes so the total step budget is numIterations.
      int iter = 0;
      for (int reset = 0; reset < numResets + 1; reset++) {
        for (; iter < numIterations; iter++) {
          try {
            converged = optimizer.optimize(1);
            iteration++;
            logger.info("CRF finished one iteration of maximizer, i=" + iter);
            runEvaluators();
          } catch (Exception e) {
            // Deliberate best-effort: a failed step ends this pass rather than aborting training.
            logger.log(Level.WARNING, "Catching exception; saying converged.", e);
            converged = true;
          }
          if (converged) {
            logger.info("CRF training has converged, i=" + iter);
            break;
          }
        }
        resetIfLimitedMemoryBFGS();
      }
    } finally {
      // shutdown threads, even if an Error escapes the loop above
      // NOTE(review): the optimizable stays cached after shutdown; confirm it
      // can be reused if train() is called again on this trainer.
      optimizable.shutdown();
    }
    return converged;
  }

  /** Resets the optimizer when it is a {@link LimitedMemoryBFGS}; otherwise does nothing. */
  private void resetIfLimitedMemoryBFGS() {
    if (optimizer instanceof LimitedMemoryBFGS) {
      ((LimitedMemoryBFGS) optimizer).reset();
    }
  }
}