package cc.mallet.fst;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.BitSet;
import java.util.Random;
import java.util.logging.Logger;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.OptimizationException;
import cc.mallet.optimize.Optimizer;
import cc.mallet.util.MalletLogger;
/**
* A CRF trainer that can combine multiple objective functions, each represented
* by an Optimizable.ByGradientValue.
*/
public class CRFTrainerByValueGradients extends TransducerTrainer implements TransducerTrainer.ByOptimization {
// Registered under this class's own name (it previously used CRFTrainerByLabelLikelihood's),
// so log records are attributed to, and can be filtered for, this trainer.
private static final Logger logger = MalletLogger.getLogger(CRFTrainerByValueGradients.class.getName());
// The CRF handed to the constructor; returned by getTransducer() and getCRF().
CRF crf;
// gsc: keep objects instead of classnames, this will give more flexibility to the
// user to setup new CRFOptimizable* objects and then pass them directly in the constructor,
// so the CRFOptimizable inner class no longer creates CRFOptimizable* objects
Optimizable.ByGradientValue[] optimizableByValueGradientObjects;
// Class[] optimizableByValueGradientClasses;
// Created lazily by getOptimizableCRF; replaced when a different training set object is passed in.
OptimizableCRF ocrf;
// Created lazily by getOptimizer(InstanceList); set to null whenever ocrf is replaced.
Optimizer opt;
// Incremented in train(InstanceList, int) after each optimizer step that completes without an OptimizationException.
int iterationCount = 0;
// The value returned by isConverged() and isFinishedTraining().
boolean converged;
// gsc: removing these options, the user ought to set the weights before
// creating the trainer object
// boolean useSparseWeights = true;
// // gsc
// boolean useUnsupportedTrick = false;
// Various values from CRF acting as indicators of when we need to ...
private int cachedValueWeightsStamp = -1; // ... re-calculate the summed value in getValue() because weights' values changed
private int cachedGradientWeightsStamp = -1; // ... re-calculate the summed gradient in getValueGradient() because weights' values changed
// gsc: removing this because the user will call setWeightsDimensionsAsIn
// private int cachedWeightsStructureStamp = -1; // ... re-allocate crf.weights, expectations & constraints because new states, transitions
// Use ocrf.trainingSet to see when we need to create a new OptimizableCRF because we are using a different training InstanceList than last time
// gsc: number of times to reset (the optimizer), and continue training when the "could not step in
// current direction" exception occurs
public static final int DEFAULT_MAX_RESETS = 3;
// Current reset limit; changed through setMaxResets.
int maxResets = DEFAULT_MAX_RESETS;
/**
 * Creates a trainer for the given CRF and objective functions.
 * <p>
 * Both arguments are stored by reference; the array is not copied.
 *
 * @param crf the CRF to train
 * @param optimizableByValueGradientObjects the objective functions to combine
 */
public CRFTrainerByValueGradients (CRF crf, Optimizable.ByGradientValue[] optimizableByValueGradientObjects) {
    this.optimizableByValueGradientObjects = optimizableByValueGradientObjects;
    this.crf = crf;
}
/** Returns the CRF held by this trainer, typed as a Transducer. */
public Transducer getTransducer() {
    return crf;
}
/** Returns the CRF held by this trainer. */
public CRF getCRF () {
    return crf;
}
/** Returns the current optimizer, which is null if none has been created yet or it has been discarded. */
public Optimizer getOptimizer() {
    return opt;
}
/** Returns the value of this trainer's convergence flag. */
public boolean isConverged() {
    return converged;
}
/** Returns the value of this trainer's convergence flag; identical to isConverged(). */
public boolean isFinishedTraining() {
    return converged;
}
/** Returns the number of optimizer iterations counted so far by this trainer. */
public int getIteration () {
    return iterationCount;
}
// gsc
/** Returns the array of objective functions held by this trainer (the array itself, not a copy). */
public Optimizable.ByGradientValue[] getOptimizableByGradientValueObjects() {
    Optimizable.ByGradientValue[] objectives = optimizableByValueGradientObjects;
    return objectives;
}
/**
 * Returns the optimizable CRF, which contains this trainer's collection of objective functions,
 * for the given training set.
 * <p>
 * A new one is created when none exists yet, or when the existing one was created for a
 * different InstanceList object (compared by reference). Whenever a new one is created the
 * optimizer is set to null.
 * <p>
 * gsc: the weights' dimensions are no longer set here; the user should call
 * setWeightsDimensionsAsIn before the optimizable and trainer objects are created.
 */
public OptimizableCRF getOptimizableCRF (InstanceList trainingSet) {
    boolean reusable = (ocrf != null) && (ocrf.trainingSet == trainingSet);
    if (reusable) {
        return ocrf;
    }
    ocrf = new OptimizableCRF (crf, trainingSet);
    opt = null;
    return ocrf;
}
/**
 * Returns the L-BFGS optimizer for the given training set, creating one if none exists
 * or if the existing one does not optimize the current optimizable CRF.
 * <p>
 * Also creates the optimizable CRF if required.
 */
public Optimizer getOptimizer (InstanceList trainingSet) {
    getOptimizableCRF(trainingSet); // this will set this.ocrf if necessary
    boolean usable = (opt != null) && (opt.getOptimizable() == ocrf);
    if (!usable) {
        opt = new LimitedMemoryBFGS(ocrf); // Alternative: opt = new ConjugateGradient (0.001);
    }
    return opt;
}
/**
 * Trains the CRF with an iteration limit of Integer.MAX_VALUE, i.e. effectively
 * until train(InstanceList, int) reports convergence.
 */
public boolean trainIncremental (InstanceList training)
{
    final int noPracticalLimit = Integer.MAX_VALUE;
    return train (training, noPracticalLimit);
}
/**
 * Trains a CRF until convergence or the specified number of iterations, whichever is earlier.
 * <p>
 * Also creates an optimizable CRF and an optimizer if required. When the optimizer throws an
 * OptimizationException, the optimizer is discarded and re-created, at most maxResets times;
 * an exception after that limit ends training and is reported as convergence.
 * <p>
 * The outcome is recorded in this trainer, so isConverged() and isFinishedTraining() reflect
 * the most recent call that actually ran (a non-positive numIterations leaves them untouched).
 *
 * @param trainingSet the instances to train on
 * @param numIterations the maximum number of optimizer iterations; if not positive nothing is done
 * @return true if the optimizer reported convergence or the resets were exhausted, false otherwise
 */
public boolean train (InstanceList trainingSet, int numIterations) {
    if (numIterations <= 0)
        return false;
    assert (trainingSet.size() > 0);
    // getOptimizer also creates the optimizable CRF, so this sets this.ocrf and this.opt if necessary.
    getOptimizer(trainingSet);

    int numResets = 0;
    // Update the field, not a shadowing local: otherwise isConverged() could never become true.
    converged = false;
    logger.info ("CRF about to train with "+numIterations+" iterations");
    for (int i = 0; i < numIterations; i++) {
        try {
            // gsc: timing each iteration
            long startTime = System.currentTimeMillis();
            converged = opt.optimize (1);
            long elapsedSecs = (System.currentTimeMillis() - startTime) / 1000;
            logger.info ("CRF finished one iteration of maximizer, i="+i+", "+elapsedSecs+" secs.");
            iterationCount++;
            runEvaluators();
        } catch (OptimizationException e) {
            // gsc: resetting the optimizer for specified number of times
            // The stack trace goes through the logger instead of straight to stderr.
            logger.log (java.util.logging.Level.WARNING, "Catching exception.", e);
            if (numResets < maxResets) {
                // reset the optimizer and get a new one
                logger.info("Resetting optimizer.");
                ++numResets;
                opt = null;
                getOptimizer(trainingSet);
            } else {
                logger.info("Saying converged.");
                converged = true;
            }
        }
        if (converged) {
            logger.info ("CRF training has converged, i="+i);
            break;
        }
    }
    return converged;
}
/**
 * Train a CRF on various-sized subsets of the data. This method is typically used to accelerate training by
 * quickly getting to reasonable parameters on only a subset of the data first, then on progressively more data.
 * @param training The training Instances.
 * @param numIterationsPerProportion Maximum number of Maximizer iterations per training proportion.
 * @param trainingProportions If non-null, train on increasingly
 * larger portions of the data, e.g. new double[] {0.2, 0.5, 1.0}. This can sometimes speedup convergence.
 * Be sure to end in 1.0 if you want to train on all the data in the end.
 * If null, a single round is trained on all of the data.
 * @return True if training has converged in the last round.
 */
public boolean train (InstanceList training, int numIterationsPerProportion, double[] trainingProportions)
{
    // The contract above allows a null schedule; without this guard it ended in a NullPointerException.
    if (trainingProportions == null)
        return this.train (training, numIterationsPerProportion);
    assert (trainingProportions.length > 0);
    boolean lastRoundConverged = false;
    for (int i = 0; i < trainingProportions.length; i++) {
        double proportion = trainingProportions[i];
        assert (proportion <= 1.0);
        // The value is a fraction in (0, 1], not a percentage.
        logger.info ("Training on a proportion of "+proportion+" of the data this round.");
        if (proportion == 1.0) {
            lastRoundConverged = this.train (training, numIterationsPerProportion);
        } else {
            // A fresh Random(1) is used every round, so each split is seeded identically.
            InstanceList[] parts = training.split (new Random(1),
                    new double[] {proportion, 1-proportion});
            lastRoundConverged = this.train (parts[0], numIterationsPerProportion);
        }
    }
    return lastRoundConverged;
}
// gsc: see comment in getOptimizableCRF
// public void setUseSparseWeights (boolean b) { useSparseWeights = b; }
// public boolean getUseSparseWeights () { return useSparseWeights; }
//
// // gsc
// public void setUseUnsupportedTrick (boolean b) { useUnsupportedTrick = b; }
// public boolean getUseUnsupportedTrick () { return useUnsupportedTrick; }
// gsc: change max. number of times the optimizer can be reset after the
// "could not step in current direction" exception before training is declared converged
/**
 * Sets the max. number of times the optimizer can be reset after an
 * OptimizationException during training. Once that many resets have been used,
 * the next such exception ends training and is reported as convergence.
 * <p>
 * Default value: {@code DEFAULT_MAX_RESETS}.
 */
public void setMaxResets(int maxResets) {
    this.maxResets = maxResets;
}
/**
 * An optimizable CRF that contains a collection of objective functions.
 * <p>
 * Its value and gradient are the sums of the values and gradients of the objective functions.
 * Both are cached and recomputed only when crf.weightsValueChangeStamp differs from the stamp
 * recorded (in the enclosing trainer) at the time of the last computation.
 */
public class OptimizableCRF implements Optimizable.ByGradientValue, Serializable
{
    // The training set this object was created for; the trainer compares it by reference.
    InstanceList trainingSet;
    // Placeholder until getValue() first finds the cached value stale and recomputes it.
    double cachedValue = -123456789;
    // Sum of the objectives' gradients as of the last recomputation in getValueGradient().
    double[] cachedGradie;
    // Only touched by the serialization hooks below.
    BitSet infiniteValues = null;
    CRF crf;
    // The objective functions; the enclosing trainer's array at construction time.
    Optimizable.ByGradientValue[] opts;

    /**
     * Creates the combined objective and invalidates the enclosing trainer's cache stamps.
     * <p>
     * gsc: the objective functions are taken as ready-made objects from the trainer rather
     * than being constructed here from class names.
     *
     * @param crf the CRF whose parameters are optimized
     * @param ilist the training set this object is associated with
     */
    protected OptimizableCRF (CRF crf, InstanceList ilist)
    {
        // Set up
        this.crf = crf;
        this.trainingSet = ilist;
        this.opts = optimizableByValueGradientObjects;
        cachedGradie = new double[crf.parameters.getNumFactors()];
        cachedValueWeightsStamp = -1;
        cachedGradientWeightsStamp = -1;
    }

    // TODO Move these implementations into CRF.java, and put here stubs that call them!

    /** Returns the number of parameters (factors) of the CRF. */
    public int getNumParameters () {
        return crf.parameters.getNumFactors();
    }

    /** Copies the CRF's parameters into the given buffer. */
    public void getParameters (double[] buffer) {
        crf.parameters.getParameters(buffer);
    }

    /** Returns the CRF parameter at the given index. */
    public double getParameter (int index) {
        return crf.parameters.getParameter(index);
    }

    /** Sets all CRF parameters from the given array and marks the weights as changed. */
    public void setParameters (double [] buff) {
        crf.parameters.setParameters(buff);
        crf.weightsValueChanged();
    }

    /** Sets a single CRF parameter and marks the weights as changed. */
    public void setParameter (int index, double value) {
        crf.parameters.setParameter(index, value);
        crf.weightsValueChanged();
    }

    /**
     * Returns the sum of the values of all objective functions at the current weights.
     * The sum is recomputed only if the weights changed since the last computation.
     */
    public double getValue ()
    {
        if (crf.weightsValueChangeStamp != cachedValueWeightsStamp) {
            // The cached value is not up to date; it was calculated for a different set of CRF weights.
            long startingTime = System.currentTimeMillis();
            double total = 0;
            for (int i = 0; i < opts.length; i++)
                total += opts[i].getValue();
            cachedValue = total;
            cachedValueWeightsStamp = crf.weightsValueChangeStamp; // cachedValue is now no longer stale
            logger.info ("getValue() (loglikelihood) = "+cachedValue);
            logger.fine ("Inference milliseconds = "+(System.currentTimeMillis() - startingTime));
        }
        return cachedValue;
    }

    /**
     * Copies the sum of the gradients of all objective functions into the given buffer.
     * The sum is recomputed only if the weights changed since the last computation.
     *
     * @param buffer receives the gradient; must have room for at least as many entries as
     *        there were CRF parameters when this object was created
     */
    public void getValueGradient (double [] buffer)
    {
        // NOTE(review): the original comments say the gradient points "up-hill", i.e. in the
        // direction of higher value; that depends on the objective functions supplied.
        if (cachedGradientWeightsStamp != crf.weightsValueChangeStamp) {
            getValue (); // Refreshes the cached value if it is stale
            MatrixOps.setAll(cachedGradie, 0);
            // Scratch space matches the parameter vector, not the caller's buffer, so a longer
            // buffer cannot cause a length mismatch when accumulating into cachedGradie.
            double[] scratch = new double[cachedGradie.length];
            for (int i = 0; i < opts.length; i++) {
                MatrixOps.setAll(scratch, 0);
                opts[i].getValueGradient(scratch);
                MatrixOps.plusEquals(cachedGradie, scratch);
            }
            cachedGradientWeightsStamp = crf.weightsValueChangeStamp;
        }
        System.arraycopy(cachedGradie, 0, buffer, 0, cachedGradie.length);
    }

    //Serialization of OptimizableCRF
    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 0;

    // Writes the version followed by trainingSet, cachedValue, cachedGradie, infiniteValues and crf.
    // NOTE(review): opts is not written, and nothing else is (no defaultWriteObject call).
    private void writeObject (ObjectOutputStream out) throws IOException {
        out.writeInt (CURRENT_SERIAL_VERSION);
        out.writeObject(trainingSet);
        out.writeDouble(cachedValue);
        out.writeObject(cachedGradie);
        out.writeObject(infiniteValues);
        out.writeObject(crf);
    }

    // Reads the fields back in the order written; the version number is read and ignored.
    // NOTE(review): opts is not assigned here, so a deserialized instance looks unusable for
    // getValue()/getValueGradient() until that is addressed -- confirm before relying on it.
    private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.readInt ();
        trainingSet = (InstanceList) in.readObject();
        cachedValue = in.readDouble();
        cachedGradie = (double[]) in.readObject();
        infiniteValues = (BitSet) in.readObject();
        crf = (CRF)in.readObject();
    }
}
// Serialization for CRFTrainerByValueGradients
private static final long serialVersionUID = 1;
// Format version number; writeObject writes it as the first int of the stream.
private static final int CURRENT_SERIAL_VERSION = 1;
// NOTE(review): not referenced anywhere in the code shown for this class.
static final int NULL_INTEGER = -1;
/*
 * Incomplete serialization hook: writes the format version and the two cached
 * weight stamps, then unconditionally throws IllegalStateException, so writing
 * a trainer through this hook never completes.
 * Need to check for null pointers (once the implementation is finished).
 */
private void writeObject (ObjectOutputStream out) throws IOException {
out.writeInt (CURRENT_SERIAL_VERSION);
//out.writeInt(defaultFeatureIndex);
out.writeInt(cachedGradientWeightsStamp);
out.writeInt(cachedValueWeightsStamp);
// out.writeInt(cachedWeightsStructureStamp);
// out.writeBoolean (useSparseWeights);
// The remaining state (crf, objectives, optimizer, ...) is not written yet.
throw new IllegalStateException("Implementation not yet complete.");
}
/*
 * Incomplete deserialization hook: reads one int (the format version, which is
 * discarded) and then unconditionally throws IllegalStateException.
 * NOTE(review): writeObject writes three ints but only one is read here.
 */
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
in.readInt ();
//defaultFeatureIndex = in.readInt();
// useSparseWeights = in.readBoolean();
throw new IllegalStateException("Implementation not yet complete.");
}
}