package cc.mallet.fst;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.logging.Logger;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
/**
* Implements label likelihood gradient computations for batches of data; the
* per-batch computations can easily be parallelized. <p>
*
* The gradient computations are the same as those of
* <tt>CRFOptimizableByLabelLikelihood</tt>. <p>
*
* <b>Note:</b> Expectations corresponding to each batch of data can be computed
* in parallel. During gradient computation, the prior and the constraints are
* incorporated into the expectations of the last batch (see
* <tt>getBatchValue, getBatchValueGradient</tt>). <p>
*
* <b>Note:</b> This implementation ignores instances with infinite weights (see
* <tt>getExpectationValue</tt>).
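*
* A minimal usage sketch (illustrative only: <tt>trainingData</tt> is a
* hypothetical, already-built <tt>InstanceList</tt>, and four batches are an
* arbitrary choice):
* <pre>{@code
* CRF crf = new CRF(trainingData.getPipe(), null);
* crf.addFullyConnectedStatesForLabels();
* CRFOptimizableByBatchLabelLikelihood optimizable =
*     new CRFOptimizableByBatchLabelLikelihood(crf, trainingData, 4);
* }</pre>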
*
* @author Gaurav Chandalia
*/
public class CRFOptimizableByBatchLabelLikelihood implements Optimizable.ByCombiningBatchGradient, Serializable {
private static Logger logger = MalletLogger.getLogger(CRFOptimizableByBatchLabelLikelihood.class.getName());
static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0;
static final double DEFAULT_HYPERBOLIC_PRIOR_SLOPE = 0.2;
static final double DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS = 10.0;
protected CRF crf;
protected InstanceList trainingSet;
// number of batches of training set
protected int numBatches;
// batch-specific expectations
protected List<CRF.Factors> expectations;
// constraints over whole training set
protected CRF.Factors constraints;
// value and gradient for each batch, to avoid sharing
protected double[] cachedValue;
protected List<double[]> cachedGradient;
boolean usingHyperbolicPrior = false;
double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
double hyperbolicPriorSlope = DEFAULT_HYPERBOLIC_PRIOR_SLOPE;
double hyperbolicPriorSharpness = DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS;
public CRFOptimizableByBatchLabelLikelihood(CRF crf, InstanceList ilist, int numBatches) {
// set up
this.crf = crf;
this.trainingSet = ilist;
this.numBatches = numBatches;
cachedValue = new double[this.numBatches];
cachedGradient = new ArrayList<double[]>(this.numBatches);
expectations = new ArrayList<CRF.Factors>(this.numBatches);
int numFactors = crf.parameters.getNumFactors();
for (int i = 0; i < this.numBatches; ++i) {
cachedGradient.add(new double[numFactors]);
expectations.add(new CRF.Factors(crf.parameters));
}
constraints = new CRF.Factors(crf.parameters);
gatherConstraints(ilist);
}
/**
* Sets the constraints by running forward-backward with the <i>output label
* sequence provided</i>, thus restricting the lattice to only those paths that
* agree with the label sequence.
*/
protected void gatherConstraints(InstanceList ilist) {
logger.info("Gathering constraints...");
assert (constraints.structureMatches(crf.parameters));
constraints.zero();
for (Instance instance : ilist) {
FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
FeatureSequence output = (FeatureSequence) instance.getTarget();
double instanceWeight = ilist.getInstanceWeight(instance);
Transducer.Incrementor incrementor =
instanceWeight == 1.0 ? constraints.new Incrementor()
: constraints.new WeightedIncrementor(instanceWeight);
new SumLatticeDefault (this.crf, input, output, incrementor);
}
constraints.assertNotNaNOrInfinite();
}
/**
* Computes the log probability of a batch of training data, filling in the
* corresponding expectations as well. Each instance contributes
* <tt>instanceWeight * (labeledWeight - unlabeledWeight)</tt> to the value.
*/
protected double getExpectationValue(int batchIndex, int[] batchAssignments) {
// Reset expectations to zero before we fill them again
CRF.Factors batchExpectations = expectations.get(batchIndex);
batchExpectations.zero();
// count the number of instances that have infinite weight
int numInfLabeledWeight = 0;
int numInfUnlabeledWeight = 0;
int numInfWeight = 0;
double value = 0;
double unlabeledWeight, labeledWeight, weight;
for (int ii = batchAssignments[0]; ii < batchAssignments[1]; ii++) {
Instance instance = trainingSet.get(ii);
double instanceWeight = trainingSet.getInstanceWeight(instance);
FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
FeatureSequence output = (FeatureSequence) instance.getTarget();
labeledWeight = new SumLatticeDefault (this.crf, input, output, null).getTotalWeight();
if (Double.isInfinite (labeledWeight)) {
++numInfLabeledWeight;
}
Transducer.Incrementor incrementor = instanceWeight == 1.0 ? batchExpectations.new Incrementor()
: batchExpectations.new WeightedIncrementor (instanceWeight);
unlabeledWeight = new SumLatticeDefault (this.crf, input, null, incrementor).getTotalWeight();
if (Double.isInfinite (unlabeledWeight)) {
++numInfUnlabeledWeight;
}
// weight is log(conditional probability of the correct label sequence)
weight = labeledWeight - unlabeledWeight;
if (Double.isInfinite(weight)) {
++numInfWeight;
} else {
// Weights are log probabilities, and we want to return a log probability
value += weight * instanceWeight;
}
}
batchExpectations.assertNotNaNOrInfinite();
if (numInfLabeledWeight > 0 || numInfUnlabeledWeight > 0 || numInfWeight > 0) {
logger.warning("Batch: " + batchIndex + ", Number of instances with:\n" +
"\t -infinite labeled weight: " + numInfLabeledWeight + "\n" +
"\t -infinite unlabeled weight: " + numInfUnlabeledWeight + "\n" +
"\t -infinite weight: " + numInfWeight);
}
return value;
}
/**
* Returns the log probability of a batch of training sequence labels; if this
* is the last batch, the prior over parameters is incorporated as well.
* <tt>batchAssignments</tt> holds the batch's start (inclusive) and end
* (exclusive) indices into the training set.
*/
public double getBatchValue(int batchIndex, int[] batchAssignments) {
assert(batchIndex < this.numBatches) : "Incorrect batch index: " + batchIndex + ", range(0, " +
this.numBatches + ")";
assert(batchAssignments.length == 2 && batchAssignments[0] <= batchAssignments[1])
: "Invalid batch assignments: " + Arrays.toString(batchAssignments);
// Get the value of all the true labels for current batch, also filling in expectations
double value = getExpectationValue(batchIndex, batchAssignments);
if (batchIndex == numBatches-1) {
if (usingHyperbolicPrior) // Hyperbolic prior
value += crf.parameters.hyberbolicPrior(hyperbolicPriorSlope, hyperbolicPriorSharpness); // [sic] spelling matches CRF.Factors
else // Gaussian prior
value += crf.parameters.gaussianPrior(gaussianPriorVariance);
}
assert(!(Double.isNaN(value) || Double.isInfinite(value)))
: "Label likelihood is NaN/Infinite, batchIndex: " + batchIndex + "batchAssignments: " + Arrays.toString(batchAssignments);
// update cache
cachedValue[batchIndex] = value;
return value;
}
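/**
* Fills <tt>buffer</tt> with this batch's contribution to the (un-negated)
* gradient, i.e. its expectations; for the last batch the constraints and the
* prior gradient are first subtracted from those expectations. The overall
* negation is applied later, in <tt>combineGradients</tt>.
*/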
public void getBatchValueGradient(double[] buffer, int batchIndex, int[] batchAssignments) {
assert(batchIndex < this.numBatches) : "Incorrect batch index: " + batchIndex + ", range(0, " +
this.numBatches + ")";
assert(batchAssignments.length == 2 && batchAssignments[0] <= batchAssignments[1])
: "Invalid batch assignments: " + Arrays.toString(batchAssignments);
CRF.Factors batchExpectations = expectations.get(batchIndex);
if (batchIndex == numBatches-1) {
// the check on the crf parameters has to be done only once; infinite values are allowed
crf.parameters.assertNotNaN();
// factor the constraints and the prior into the expectations of last batch
// Gradient = (constraints - expectations + prior) = -(expectations - constraints - prior)
// The minus sign is factored in combineGradients method after all gradients are computed
batchExpectations.plusEquals(constraints, -1.0);
if (usingHyperbolicPrior)
batchExpectations.plusEqualsHyperbolicPriorGradient(crf.parameters, -hyperbolicPriorSlope, hyperbolicPriorSharpness);
else
batchExpectations.plusEqualsGaussianPriorGradient(crf.parameters, -gaussianPriorVariance);
batchExpectations.assertNotNaNOrInfinite();
}
double[] gradient = cachedGradient.get(batchIndex);
// set the cached gradient
batchExpectations.getParameters(gradient);
System.arraycopy(gradient, 0, buffer, 0, gradient.length);
}
/**
* Adds gradients from all batches. <p>
* <b>Note:</b> assumes buffer is already initialized.
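* <p>
* A sketch of the expected calling sequence (hypothetical driver code:
* <tt>optimizable</tt> and <tt>trainingData</tt> are as in the class-level
* example, and batch boundaries here simply split the training set evenly):
* <pre>{@code
* int n = trainingData.size(), k = optimizable.getNumBatches();
* List<double[]> batchGradients = new ArrayList<double[]>();
* for (int b = 0; b < k; ++b) {
*   int[] range = new int[] { b * n / k, (b + 1) * n / k };
*   double[] g = new double[optimizable.getNumParameters()];
*   optimizable.getBatchValue(b, range); // also fills this batch's expectations
*   optimizable.getBatchValueGradient(g, b, range);
*   batchGradients.add(g);
* }
* double[] gradient = new double[optimizable.getNumParameters()];
* optimizable.combineGradients(batchGradients, gradient); // sums batches, negates once
* }</pre>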
*/
public void combineGradients(Collection<double[]> batchGradients, double[] buffer) {
assert(buffer.length == crf.parameters.getNumFactors())
: "Incorrect buffer length: " + buffer.length + ", expected: " + crf.parameters.getNumFactors();
Arrays.fill(buffer, 0);
for (double[] gradient : batchGradients) {
MatrixOps.plusEquals(buffer, gradient);
}
// -(...) from getBatchValueGradient
MatrixOps.timesEquals(buffer, -1.0);
}
public int getNumBatches() { return numBatches; }
public void setUseHyperbolicPrior (boolean f) { usingHyperbolicPrior = f; }
public void setHyperbolicPriorSlope (double p) { hyperbolicPriorSlope = p; }
public void setHyperbolicPriorSharpness (double p) { hyperbolicPriorSharpness = p; }
public double getHyperbolicPriorSlope () { return hyperbolicPriorSlope; }
public double getHyperbolicPriorSharpness () { return hyperbolicPriorSharpness; }
public void setGaussianPriorVariance (double p) { gaussianPriorVariance = p; }
public double getGaussianPriorVariance () { return gaussianPriorVariance; }
public int getNumParameters () {return crf.parameters.getNumFactors();}
public void getParameters (double[] buffer) {
crf.parameters.getParameters(buffer);
}
public double getParameter (int index) {
return crf.parameters.getParameter(index);
}
public void setParameters (double [] buff) {
crf.parameters.setParameters(buff);
crf.weightsValueChanged();
}
public void setParameter (int index, double value) {
crf.parameters.setParameter(index, value);
crf.weightsValueChanged();
}
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 0;
private void writeObject (ObjectOutputStream out) throws IOException {
out.writeInt (CURRENT_SERIAL_VERSION);
out.writeObject(trainingSet);
out.writeObject(crf);
out.writeInt(numBatches);
out.writeObject(cachedValue);
for (double[] gradient : cachedGradient)
out.writeObject(gradient);
}
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
in.readInt ();
trainingSet = (InstanceList) in.readObject();
crf = (CRF)in.readObject();
numBatches = in.readInt();
cachedValue = (double[]) in.readObject();
cachedGradient = new ArrayList<double[]>(numBatches);
for (int i = 0; i < numBatches; ++i)
cachedGradient.add((double[]) in.readObject()); // add, not set: the list starts empty
// expectations and constraints are not serialized; rebuild them as the constructor does
expectations = new ArrayList<CRF.Factors>(numBatches);
for (int i = 0; i < numBatches; ++i)
expectations.add(new CRF.Factors(crf.parameters));
constraints = new CRF.Factors(crf.parameters);
gatherConstraints(trainingSet);
}
public static class Factory {
public Optimizable.ByCombiningBatchGradient newCRFOptimizable (CRF crf, InstanceList trainingData, int numBatches) {
return new CRFOptimizableByBatchLabelLikelihood (crf, trainingData, numBatches);
}
}
}