package cc.mallet.fst;
import java.util.BitSet;
import java.util.logging.Logger;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.fst.MEMM.State;
import cc.mallet.fst.MEMM.TransitionIterator;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.util.MalletLogger;
/**
* Trains and evaluates a {@link MEMM}.
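* <p>
* A minimal usage sketch (an assumption, not code from this file: it presumes a
* {@code pipe} whose instances carry {@link FeatureVectorSequence} data with
* label-sequence targets, and uses the weight-allocation call referenced in the
* comments below):
* <pre>{@code
* MEMM memm = new MEMM(pipe, null);
* memm.addFullyConnectedStatesForLabels();
* memm.setWeightsDimensionAsIn(trainingData, false); // weights must be set before training
* MEMMTrainer trainer = new MEMMTrainer(memm);
* boolean converged = trainer.train(trainingData, 100);
* }</pre>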
*/
public class MEMMTrainer extends TransducerTrainer
{
private static Logger logger = MalletLogger.getLogger(MEMMTrainer.class.getName());
MEMM memm;
private boolean gatheringTrainingData = false;
// After training sets have been gathered in the states, record which
// InstanceList we've gathered them from, so we don't double-count instances.
private InstanceList trainingGatheredFor;
// gsc: the user is supposed to set the weights manually, so this flag is not needed
// boolean useSparseWeights = true;
MEMMOptimizableByLabelLikelihood omemm;
public MEMMTrainer (MEMM memm) {
this.memm = memm;
}
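/** Returns an optimizable (label-likelihood) view of the MEMM for the given training set.
* Note that this does not gather the per-state training sets; see {@link #gatherTrainingSets}. */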
public MEMMOptimizableByLabelLikelihood getOptimizableMEMM (InstanceList trainingSet) {
return new MEMMOptimizableByLabelLikelihood (memm, trainingSet);
}
// public MEMMTrainer setUseSparseWeights (boolean f) { useSparseWeights = f; return this; }
/**
* Trains a MEMM until convergence.
*/
public boolean train (InstanceList training) {
return train (training, Integer.MAX_VALUE);
}
/**
* Trains a MEMM for a specified number of iterations or until convergence,
* whichever occurs first; returns true if training converged within the given iterations.
*/
public boolean train (InstanceList training, int numIterations)
{
if (numIterations <= 0)
return false;
assert (training.size() > 0);
// Allocate space for the parameters, and place transition FeatureVectors in
// per-source-state InstanceLists.
// gatherTrainingSets creates a new InstanceList in each source state and adds
// the FeatureVectors of that state's outgoing transitions to it as the data
// fields of the Instances.
if (trainingGatheredFor != training) {
gatherTrainingSets (training);
}
// gsc: the user has to set the weights manually
// if (useSparseWeights) {
// memm.setWeightsDimensionAsIn (training, false);
// } else {
// memm.setWeightsDimensionDensely ();
// }
/*
if (false) {
// Expectation-based placement of training data would go here.
for (int i = 0; i < training.size(); i++) {
Instance instance = training.get(i);
FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
FeatureSequence output = (FeatureSequence) instance.getTarget();
// Do it for the paths consistent with the labels...
gatheringConstraints = true;
new SumLatticeDefault (this, input, output, true);
// ...and also do it for the paths selected by the current model (so we will get some negative weights)
gatheringConstraints = false;
if (this.someTrainingDone)
// (do this once some training is done)
new SumLatticeDefault (this, input, null, true);
}
gatheringWeightsPresent = false;
SparseVector[] newWeights = new SparseVector[weights.length];
for (int i = 0; i < weights.length; i++) {
int numLocations = weightsPresent[i].cardinality ();
logger.info ("CRF weights["+weightAlphabet.lookupObject(i)+"] num features = "+numLocations);
int[] indices = new int[numLocations];
for (int j = 0; j < numLocations; j++) {
indices[j] = weightsPresent[i].nextSetBit (j == 0 ? 0 : indices[j-1]+1);
//System.out.println ("CRF4 has index "+indices[j]);
}
newWeights[i] = new IndexedSparseVector (indices, new double[numLocations],
numLocations, numLocations, false, false, false);
newWeights[i].plusEqualsSparse (weights[i]);
}
weights = newWeights;
}
*/
omemm = new MEMMOptimizableByLabelLikelihood (memm, training);
// Gather the constraints
omemm.gatherExpectationsOrConstraints (true);
Optimizer maximizer = new LimitedMemoryBFGS(omemm);
int i;
// boolean continueTraining = true;
boolean converged = false;
logger.info ("CRF about to train with "+numIterations+" iterations");
for (i = 0; i < numIterations; i++) {
try {
converged = maximizer.optimize (1);
logger.info ("CRF finished one iteration of maximizer, i="+i);
runEvaluators();
} catch (IllegalArgumentException e) {
e.printStackTrace();
logger.info ("Catching exception; saying converged.");
converged = true;
}
if (converged) {
logger.info ("CRF training has converged, i="+i);
break;
}
}
logger.info ("About to setTrainable(false)");
return converged;
}
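/**
* Fills each source state's {@code trainingSet} with one {@link Instance} per
* outgoing transition observed in the given training data: the instance data is
* the transition's input {@link FeatureVector}, the target is the transition's
* output label, and the instance weight is the transition count from the lattice.
*/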
void gatherTrainingSets (InstanceList training)
{
if (trainingGatheredFor != null) {
// It would be easy enough to support this: just go through all the states and set trainingSet to null.
throw new UnsupportedOperationException ("Training with multiple sets not supported.");
}
trainingGatheredFor = training;
for (int i = 0; i < training.size(); i++) {
Instance instance = training.get(i);
FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
FeatureSequence output = (FeatureSequence) instance.getTarget();
// Do it for the paths consistent with the labels...
new SumLatticeDefault (memm, input, output, new Transducer.Incrementor() {
public void incrementFinalState(Transducer.State s, double count) { }
public void incrementInitialState(Transducer.State s, double count) { }
public void incrementTransition(Transducer.TransitionIterator ti, double count) {
MEMM.State source = (MEMM.State) ti.getSourceState();
if (count != 0) {
// Create the source state's trainingSet if it doesn't exist yet.
if (source.trainingSet == null)
// New InstanceList with a null pipe, because it doesn't do any processing of input.
source.trainingSet = new InstanceList (null);
// TODO We should make sure we don't add duplicates (e.g. through a second call to setWeightsDimension...)!
// TODO Note that when the training data still allows ambiguous outgoing transitions
// this will add the same FV more than once to the source state's trainingSet, each
// with >1.0 weight. Not incorrect, but inefficient.
// System.out.println ("From: "+source.getName()+" ---> "+getOutput()+" : "+getInput());
source.trainingSet.add (new Instance(ti.getInput (), ti.getOutput (), null, null), count);
}
}
});
}
}
/**
* Not implemented yet.
*
* @throws UnsupportedOperationException
*/
public boolean train (InstanceList training, InstanceList validation, InstanceList testing,
TransducerEvaluator eval, int numIterations,
int numIterationsPerProportion,
double[] trainingProportions)
{
throw new UnsupportedOperationException();
}
/**
* Not implemented yet.
*
* @throws UnsupportedOperationException
*/
public boolean trainWithFeatureInduction (InstanceList trainingData,
InstanceList validationData, InstanceList testingData,
TransducerEvaluator eval, int numIterations,
int numIterationsBetweenFeatureInductions,
int numFeatureInductions,
int numFeaturesPerFeatureInduction,
double trueLabelProbThreshold,
boolean clusteredFeatureInduction,
double[] trainingProportions,
String gainName)
{
throw new UnsupportedOperationException();
}
public void printInstanceLists ()
{
for (int i = 0; i < memm.numStates(); i++) {
State state = (State) memm.getState (i);
InstanceList training = state.trainingSet;
System.out.println ("State "+i+" : "+state.getName());
if (training == null) {
System.out.println ("No data");
continue;
}
for (int j = 0; j < training.size(); j++) {
Instance inst = training.get (j);
System.out.println ("From : "+state.getName()+" To : "+inst.getTarget());
System.out.println ("Instance "+j);
System.out.println (inst.getTarget());
System.out.println (inst.getData());
}
}
}
/**
* Represents the terms in the objective function.
* <p>
* The weights are trained by matching the expectations of the model to the observations gathered from the data.
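* <p>
* Concretely, for each transition weight the gradient of this objective is
* (constraint - expectation): the feature count observed in the gathered
* per-state training sets minus the count expected under the current model,
* both filled in by {@link #gatherExpectationsOrConstraints(boolean)}.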
*/
@SuppressWarnings("serial")
public class MEMMOptimizableByLabelLikelihood extends CRFOptimizableByLabelLikelihood implements Optimizable.ByGradientValue
{
BitSet infiniteValues = null;
protected MEMMOptimizableByLabelLikelihood (MEMM memm, InstanceList trainingData)
{
super (memm, trainingData);
expectations = new CRF.Factors (memm);
constraints = new CRF.Factors (memm);
}
// if gatherConstraints is false, returns the log probability of the training labels
protected double gatherExpectationsOrConstraints (boolean gatherConstraints)
{
// Instance values must either always or never be included in
// the total values; we can't just sometimes skip a value
// because it is infinite, since that would throw off the total values.
boolean initializingInfiniteValues = false;
CRF.Factors factors = gatherConstraints ? constraints : expectations;
CRF.Factors.Incrementor factorIncrementor = factors.new Incrementor ();
if (infiniteValues == null) {
infiniteValues = new BitSet ();
initializingInfiniteValues = true;
}
double labelLogProb = 0;
for (int i = 0; i < memm.numStates(); i++) {
MEMM.State s = (State) memm.getState (i);
if (s.trainingSet == null) {
System.out.println ("Empty training set for state "+s.name);
continue;
}
for (int j = 0; j < s.trainingSet.size(); j++) {
Instance instance = s.trainingSet.get (j);
double instWeight = s.trainingSet.getInstanceWeight (j);
FeatureVector fv = (FeatureVector) instance.getData ();
String labelString = (String) instance.getTarget ();
TransitionIterator iter = new TransitionIterator (s, fv, gatherConstraints?labelString:null, memm);
while (iter.hasNext ()) {
iter.nextState(); // just advance the iterator; the destination state itself is unused
double weight = iter.getWeight();
factorIncrementor.incrementTransition(iter, Math.exp(weight) * instWeight);
//iter.incrementCount (Math.exp(weight) * instWeight);
// Use equals(), not ==: the target and the iterator's output may be equal but distinct objects.
if (!gatherConstraints && labelString.equals (iter.getOutput ())) {
if (!Double.isInfinite (weight))
labelLogProb += instWeight * weight;
else {
logger.warning ("State "+i+" transition "+j+" has infinite cost; skipping.");
if (initializingInfiniteValues)
throw new IllegalStateException ("Infinite-cost transitions not yet supported"); //infiniteValues.set (j);
else if (!infiniteValues.get(j))
throw new IllegalStateException ("Instance i used to have non-infinite value, "
+"but now it has infinite value.");
}
}
}
}
}
// Force the initial & final weight parameters to stay 0: giving the expectation
// and constraint factors identical (zero) values makes their gradient contribution zero.
for (int i = 0; i < memm.numStates(); i++) {
factors.initialWeights[i] = 0.0;
factors.finalWeights[i] = 0.0;
}
return labelLogProb;
}
// log probability of the training sequence labels, and fill in expectations[]
protected double getExpectationValue ()
{
return gatherExpectationsOrConstraints (false);
}
}
@Override
public int getIteration() {
// Iteration count is not tracked by this trainer; always returns 0.
return 0;
}
@Override
public Transducer getTransducer() {
return memm;
}
@Override
public boolean isFinishedTraining() {
// Convergence is reported by train(); no separate finished flag is tracked.
return false;
}
}