/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/**
@author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/
package cc.mallet.fst;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import java.text.DecimalFormat;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureInducer;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.IndexedSparseVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.types.Sequence;
import cc.mallet.types.SparseVector;
import cc.mallet.pipe.Noop;
import cc.mallet.pipe.Pipe;
import cc.mallet.util.ArrayUtils;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
/* There are several different kinds of numeric values:
"weights" range from -Inf to Inf. High weights make a path more
likely. These don't appear directly in Transducer.java, but appear
as parameters to many subclasses, such as CRFs. Weights are also
often summed, or combined in a dot product with feature vectors.
"unnormalized costs" range from -Inf to Inf. High costs make a
path less likely. Unnormalized costs can be obtained from negated
weights or negated sums of weights. These are often returned by a
TransitionIterator's getValue() method. The LatticeNode.alpha
values are unnormalized costs.
"normalized costs" range from 0 to Inf. High costs make a path
less likely. Normalized costs can safely be considered as the
-log(probability) of some event. They can be obtained by
subtracting a (negative) normalizer from unnormalized costs, for
example, subtracting the total cost of a lattice. Typically
initialCosts and finalCosts are examples of normalized costs, but
they are also allowed to be unnormalized costs. The gammas[][],
stateGammas[], and transitionXis[][] are all normalized costs, as
well as the return value of Lattice.getValue().
"probabilities" range from 0 to 1. High probabilities make a path
more likely. They are obtained from normalized costs by negating
and exponentiating: probability = exp(-normalized cost).
"sums of probabilities" range from 0 to positive numbers. They are
the sum of several probabilities. These are passed to the
incrementCount() methods.
*/
/**
* Represents a CRF model.
*/
public class CRF extends Transducer implements Serializable
{
private static Logger logger = MalletLogger.getLogger(CRF.class.getName());
// Separator used when composing higher-order state names out of label names, e.g. "LABEL1,LABEL2"
static final String LABEL_SEPARATOR = ",";
// Dictionary over input features (observational tests)
protected Alphabet inputAlphabet;
// Dictionary over output labels
protected Alphabet outputAlphabet;
// The finite-state machine: all states, the subset allowed to begin a path, and name lookup
protected ArrayList<State> states = new ArrayList<State> ();
protected ArrayList<State> initialStates = new ArrayList<State> ();
protected HashMap<String,State> name2state = new HashMap<String,State> ();
// All trainable parameters: transition weights, default weights, initial/final state weights
protected Factors parameters = new Factors ();
//SparseVector[] weights;
//double[] defaultWeights;	// parameters for default feature
//Alphabet weightAlphabet = new Alphabet ();
//boolean[] weightsFrozen;
// FeatureInduction can fill this in
protected FeatureSelection globalFeatureSelection;
// "featureSelections" is on a per- weights[i] basis, and over-rides
// (permanently disabling) FeatureInducer's and
// setWeightsDimensionsAsIn() from using these features on these transitions
protected FeatureSelection[] featureSelections;
// Store here the induced feature conjunctions so that these conjunctions can be added to test instances before transduction
protected ArrayList<FeatureInducer> featureInducers = new ArrayList<FeatureInducer>();
// An integer index that gets incremented each time this CRFs parameters get changed
protected int weightsValueChangeStamp = 0;
// An integer index that gets incremented each time this CRFs parameters' structure get changed
protected int weightsStructureChangeStamp = 0;
protected int cachedNumParametersStamp = -1; // A copy of weightsStructureChangeStamp the last time numParameters was calculated
protected int numParameters;
/** A simple, transparent container to hold the parameters or sufficient statistics for the CRF.
 * The same class is used both for the model's parameters (CRF.parameters) and for
 * gathered statistics such as expectations and constraints, which mirror its structure. */
public static class Factors implements Serializable {
	public Alphabet weightAlphabet;
	public SparseVector[] weights; // parameters on transitions, indexed by "weight index"
	public double[] defaultWeights;// parameters for default features, indexed by "weight index"
	public boolean[] weightsFrozen; // flag, if true indicating that the weights of this "weight index" should not be changed by learning, indexed by "weight index"
	public double [] initialWeights; // indexed by state index
	public double [] finalWeights; // indexed by state index

	/** Construct a new empty Factors with a new empty weightsAlphabet, 0-length initialWeights and finalWeights, and the other arrays null. */
	public Factors () {
		weightAlphabet = new Alphabet();
		initialWeights = new double[0];
		finalWeights = new double[0];
		// Leave the rest as null. They will get set later by addState() and addWeight()
		// Alternatively, we could create zero-length arrays
	}

	/** Construct new Factors by mimicking the structure of the other one, but with zero values.
	 * Always simply point to the other's Alphabet; do not clone it. */
	public Factors (Factors other) {
		weightAlphabet = other.weightAlphabet;
		weights = new SparseVector[other.weights.length];
		for (int i = 0; i < weights.length; i++)
			weights[i] = (SparseVector) other.weights[i].cloneMatrixZeroed();
		defaultWeights = new double[other.defaultWeights.length];
		weightsFrozen = other.weightsFrozen; // We don't copy here because we want "expectation" and "constraint" factors to get changes to a CRF.parameters factor. Alternatively we declare freezing to be a change of structure, and force reallocation of "expectations", etc.
		initialWeights = new double[other.initialWeights.length];
		finalWeights = new double[other.finalWeights.length];
	}

	/** Construct new Factors by copying the other one. */
	public Factors (Factors other, boolean cloneAlphabet) {
		weightAlphabet = cloneAlphabet ? (Alphabet) other.weightAlphabet.clone() : other.weightAlphabet;
		weights = new SparseVector[other.weights.length];
		for (int i = 0; i < weights.length; i++)
			weights[i] = (SparseVector) other.weights[i].cloneMatrix();
		defaultWeights = other.defaultWeights.clone();
		weightsFrozen = other.weightsFrozen; // shared, not copied (see note in the zeroed-copy constructor)
		initialWeights = other.initialWeights.clone();
		finalWeights = other.finalWeights.clone();
	}

	/** Construct a new Factors with the same structure as the parameters of 'crf', but with values initialized to zero.
	 * This method is typically used to allocate storage for sufficient statistics, expectations, constraints, etc. */
	public Factors (CRF crf) {
		// TODO Change this implementation to this(crf.parameters)
		weightAlphabet = crf.parameters.weightAlphabet; // TODO consider cloning this instead
		weights = new SparseVector[crf.parameters.weights.length];
		for (int i = 0; i < weights.length; i++)
			weights[i] = (SparseVector) crf.parameters.weights[i].cloneMatrixZeroed();
		defaultWeights = new double[crf.parameters.weights.length];
		weightsFrozen = crf.parameters.weightsFrozen;
		assert (crf.numStates() == crf.parameters.initialWeights.length);
		assert (crf.parameters.initialWeights.length == crf.parameters.finalWeights.length);
		initialWeights = new double[crf.parameters.initialWeights.length];
		finalWeights = new double[crf.parameters.finalWeights.length];
	}

	/** Return the total number of scalar parameters held: one initial and one final
	 * weight per state, plus one default weight and all sparse locations per weight index. */
	public int getNumFactors () {
		assert (initialWeights.length == finalWeights.length);
		assert (defaultWeights.length == weights.length);
		int ret = initialWeights.length + finalWeights.length + defaultWeights.length;
		for (int i = 0; i < weights.length; i++)
			ret += weights[i].numLocations();
		return ret;
	}

	/** Set all parameter values to zero, keeping the structure intact. */
	public void zero () {
		for (int i = 0; i < weights.length; i++)
			weights[i].setAll(0);
		Arrays.fill(defaultWeights, 0);
		Arrays.fill(initialWeights, 0);
		Arrays.fill(finalWeights, 0);
	}

	/** Return true if 'other' has the same array lengths and sparse-vector sizes as this.
	 * Used to decide whether cached statistics must be reallocated. */
	public boolean structureMatches (Factors other) {
		if (weightAlphabet.size() != other.weightAlphabet.size()) return false;
		if (weights.length != other.weights.length) return false;
		// gsc: checking each SparseVector's size within weights.
		for (int i = 0; i < weights.length; i++)
			if (weights[i].numLocations() != other.weights[i].numLocations()) return false;
		// Note that we are not checking the indices of the SparseVectors in weights
		if (defaultWeights.length != other.defaultWeights.length) return false;
		assert (initialWeights.length == finalWeights.length);
		if (initialWeights.length != other.initialWeights.length) return false;
		return true;
	}

	/** Assert (when -ea is enabled) that no parameter value is NaN. */
	public void assertNotNaN () {
		for (int i = 0; i < weights.length; i++)
			assert (!weights[i].isNaN());
		assert (!MatrixOps.isNaN(defaultWeights));
		assert (!MatrixOps.isNaN(initialWeights));
		assert (!MatrixOps.isNaN(finalWeights));
	}

	// gsc: checks all weights to make sure there are no NaN or Infinite values,
	// this method can be called for checking the weights of constraints and
	// expectations but not for crf.parameters since it can have infinite
	// weights associated with states that are not likely.
	public void assertNotNaNOrInfinite () {
		for (int i = 0; i < weights.length; i++)
			assert (!weights[i].isNaNOrInfinite());
		assert (!MatrixOps.isNaNOrInfinite(defaultWeights));
		assert (!MatrixOps.isNaNOrInfinite(initialWeights));
		assert (!MatrixOps.isNaNOrInfinite(finalWeights));
	}

	/** this += other * factor, ignoring frozen weights. */
	public void plusEquals (Factors other, double factor) {
		plusEquals(other, factor, false);
	}

	/** this += other * factor.  If obeyWeightsFrozen, frozen weight indices
	 * (and their default weights) are left untouched. */
	public void plusEquals (Factors other, double factor, boolean obeyWeightsFrozen) {
		for (int i = 0; i < weights.length; i++) {
			if (obeyWeightsFrozen && weightsFrozen[i]) continue;
			this.weights[i].plusEqualsSparse(other.weights[i], factor);
			this.defaultWeights[i] += other.defaultWeights[i] * factor;
		}
		for (int i = 0; i < initialWeights.length; i++) {
			this.initialWeights[i] += other.initialWeights[i] * factor;
			this.finalWeights[i] += other.finalWeights[i] * factor;
		}
	}

	/** Return the log(p(parameters)) according to a zero-mean Gaussian with given variance. */
	public double gaussianPrior (double variance) {
		double value = 0;
		double priorDenom = 2 * variance;
		assert (initialWeights.length == finalWeights.length);
		// Infinite weights encode hard (im)possibility constraints, not learned values,
		// so they are excluded from the prior.
		for (int i = 0; i < initialWeights.length; i++) {
			if (!Double.isInfinite(initialWeights[i])) value -= initialWeights[i] * initialWeights[i] / priorDenom;
			if (!Double.isInfinite(finalWeights[i])) value -= finalWeights[i] * finalWeights[i] / priorDenom;
		}
		double w;
		for (int i = 0; i < weights.length; i++) {
			if (!Double.isInfinite(defaultWeights[i])) value -= defaultWeights[i] * defaultWeights[i] / priorDenom;
			for (int j = 0; j < weights[i].numLocations(); j++) {
				w = weights[i].valueAtLocation (j);
				if (!Double.isInfinite(w)) value -= w * w / priorDenom;
			}
		}
		return value;
	}

	/** Add the gradient of the zero-mean Gaussian prior (evaluated at 'other') into this,
	 * skipping frozen weight indices and infinite-valued entries. */
	public void plusEqualsGaussianPriorGradient (Factors other, double variance) {
		assert (initialWeights.length == finalWeights.length);
		for (int i = 0; i < initialWeights.length; i++) {
			// gsc: checking initial/final weights of crf.parameters as well since we could
			// have a state machine where some states have infinite initial and/or final weight
			if (!Double.isInfinite(initialWeights[i]) && !Double.isInfinite(other.initialWeights[i]))
				initialWeights[i] -= other.initialWeights[i] / variance;
			if (!Double.isInfinite(finalWeights[i]) && !Double.isInfinite(other.finalWeights[i]))
				finalWeights[i] -= other.finalWeights[i] / variance;
		}
		double w, ow;
		for (int i = 0; i < weights.length; i++) {
			if (weightsFrozen[i]) continue;
			// TODO Note that there doesn't seem to be a way to freeze the initialWeights and finalWeights
			// TODO Should we also obey FeatureSelection here? No need; it is enforced by the creation of the weights.
			if (!Double.isInfinite(defaultWeights[i])) defaultWeights[i] -= other.defaultWeights[i] / variance;
			for (int j = 0; j < weights[i].numLocations(); j++) {
				w = weights[i].valueAtLocation (j);
				ow = other.weights[i].valueAtLocation (j);
				if (!Double.isInfinite(w)) weights[i].setValueAtLocation(j, w - (ow/variance));
			}
		}
	}

	/** Return the log(p(parameters)) according to a hyperbolic curve that is a smooth approximation to an L1 prior.
	 * NOTE(review): the method name preserves a historical misspelling ("hyberbolic") for API compatibility.
	 * NOTE(review): unlike gaussianPrior, the defaultWeights term here is not guarded
	 * against infinite values -- confirm whether that asymmetry is intended. */
	public double hyberbolicPrior (double slope, double sharpness) {
		double value = 0;
		assert (initialWeights.length == finalWeights.length);
		for (int i = 0; i < initialWeights.length; i++) {
			if (!Double.isInfinite(initialWeights[i]))
				value -= (slope / sharpness * Math.log (Maths.cosh (sharpness * -initialWeights[i])));
			if (!Double.isInfinite(finalWeights[i]))
				value -= (slope / sharpness * Math.log (Maths.cosh (sharpness * -finalWeights[i])));
		}
		double w;
		for (int i = 0; i < weights.length; i++) {
			value -= (slope / sharpness * Math.log (Maths.cosh (sharpness * defaultWeights[i])));
			for (int j = 0; j < weights[i].numLocations(); j++) {
				w = weights[i].valueAtLocation(j);
				if (!Double.isInfinite(w))
					value -= (slope / sharpness * Math.log (Maths.cosh (sharpness * w)));
			}
		}
		return value;
	}

	/** Add the gradient of the hyperbolic prior (evaluated at 'other') into this,
	 * skipping frozen weight indices and infinite-valued entries. */
	public void plusEqualsHyperbolicPriorGradient (Factors other, double slope, double sharpness) {
		// TODO This method could use some careful checking over, especially for flipped negations
		assert (initialWeights.length == finalWeights.length);
		double ss = slope * sharpness;
		for (int i = 0; i < initialWeights.length; i++) {
			// gsc: checking initial/final weights of crf.parameters as well since we could
			// have a state machine where some states have infinite initial and/or final weight
			if (!Double.isInfinite(initialWeights[i]) && !Double.isInfinite(other.initialWeights[i]))
				initialWeights[i] += ss * Maths.tanh (-other.initialWeights[i]);
			if (!Double.isInfinite(finalWeights[i]) && !Double.isInfinite(other.finalWeights[i]))
				finalWeights[i] += ss * Maths.tanh (-other.finalWeights[i]);
		}
		double w, ow;
		for (int i = 0; i < weights.length; i++) {
			if (weightsFrozen[i]) continue;
			// TODO Note that there doesn't seem to be a way to freeze the initialWeights and finalWeights
			// TODO Should we also obey FeatureSelection here? No need; it is enforced by the creation of the weights.
			if (!Double.isInfinite(defaultWeights[i])) defaultWeights[i] += ss * Maths.tanh(-other.defaultWeights[i]);
			for (int j = 0; j < weights[i].numLocations(); j++) {
				w = weights[i].valueAtLocation (j);
				ow = other.weights[i].valueAtLocation (j);
				if (!Double.isInfinite(w)) weights[i].setValueAtLocation(j, w + (ss * Maths.tanh(-ow)));
			}
		}
	}

	/** Instances of this inner class can be passed to various inference methods, which can then
	 * gather/increment sufficient statistics counts into the containing Factor instance.
	 * (Deliberately a non-static inner class: it writes into the enclosing Factors' arrays.) */
	public class Incrementor implements Transducer.Incrementor {
		public void incrementFinalState(Transducer.State s, double count) {
			finalWeights[s.getIndex()] += count;
		}
		public void incrementInitialState(Transducer.State s, double count) {
			initialWeights[s.getIndex()] += count;
		}
		public void incrementTransition(Transducer.TransitionIterator ti, double count) {
			int index = ti.getIndex();
			CRF.State source = (CRF.State)ti.getSourceState();
			int nwi = source.weightsIndices[index].length;
			int weightsIndex;
			// A transition may use several weight sets; increment each of them.
			for (int wi = 0; wi < nwi; wi++) {
				weightsIndex = source.weightsIndices[index][wi];
				// For frozen weights, don't even gather their sufficient statistics; this is how we ensure that the gradient for these will be zero
				if (weightsFrozen[weightsIndex]) continue;
				// TODO Should we also obey FeatureSelection here? No need; it is enforced by the creation of the weights.
				weights[weightsIndex].plusEqualsSparse ((FeatureVector)ti.getInput(), count);
				defaultWeights[weightsIndex] += count;
			}
		}
	}

	/** Return the L1 norm of all parameters, ignoring impossible initial/final weights. */
	public double getParametersAbsNorm ()
	{
		double ret = 0;
		for (int i = 0; i < initialWeights.length; i++) {
			if (initialWeights[i] > Transducer.IMPOSSIBLE_WEIGHT)
				ret += Math.abs(initialWeights[i]);
			if (finalWeights[i] > Transducer.IMPOSSIBLE_WEIGHT)
				ret += Math.abs(finalWeights[i]);
		}
		for (int i = 0; i < weights.length; i++) {
			ret += Math.abs(defaultWeights[i]);
			int nl = weights[i].numLocations();
			for (int j = 0; j < nl; j++)
				ret += Math.abs(weights[i].valueAtLocation(j));
		}
		return ret;
	}

	/** Like Incrementor, but scales every count by a fixed per-instance weight. */
	public class WeightedIncrementor implements Transducer.Incrementor {
		double instanceWeight = 1.0;
		public WeightedIncrementor (double instanceWeight) {
			this.instanceWeight = instanceWeight;
		}
		public void incrementFinalState(Transducer.State s, double count) {
			finalWeights[s.getIndex()] += count * instanceWeight;
		}
		public void incrementInitialState(Transducer.State s, double count) {
			initialWeights[s.getIndex()] += count * instanceWeight;
		}
		public void incrementTransition(Transducer.TransitionIterator ti, double count) {
			int index = ti.getIndex();
			CRF.State source = (CRF.State)ti.getSourceState();
			int nwi = source.weightsIndices[index].length;
			int weightsIndex;
			count *= instanceWeight;
			for (int wi = 0; wi < nwi; wi++) {
				weightsIndex = source.weightsIndices[index][wi];
				// For frozen weights, don't even gather their sufficient statistics; this is how we ensure that the gradient for these will be zero
				if (weightsFrozen[weightsIndex]) continue;
				// TODO Should we also obey FeatureSelection here? No need; it is enforced by the creation of the weights.
				weights[weightsIndex].plusEqualsSparse ((FeatureVector)ti.getInput(), count);
				defaultWeights[weightsIndex] += count;
			}
		}
	}

	/** Flatten all parameters into 'buffer'.  Layout: interleaved (initial,final) pairs per
	 * state, then for each weight index its default weight followed by its sparse locations.
	 * getParameter()/setParameter()/setParameters() use the same layout. */
	public void getParameters (double[] buffer)
	{
		if (buffer.length != getNumFactors ())
			throw new IllegalArgumentException ("Expected size of buffer: " + getNumFactors() + ", actual size: " + buffer.length);
		int pi = 0;
		for (int i = 0; i < initialWeights.length; i++) {
			buffer[pi++] = initialWeights[i];
			buffer[pi++] = finalWeights[i];
		}
		for (int i = 0; i < weights.length; i++) {
			buffer[pi++] = defaultWeights[i];
			int nl = weights[i].numLocations();
			for (int j = 0; j < nl; j++)
				buffer[pi++] = weights[i].valueAtLocation(j);
		}
	}

	/** Return the single parameter at flat 'index' (see getParameters() for the layout). */
	public double getParameter (int index) {
		int numStateParms = 2 * initialWeights.length;
		if (index < numStateParms) {
			// Even indices are initial weights, odd indices final weights.
			if (index % 2 == 0)
				return initialWeights[index/2];
			return finalWeights[index/2];
		}
		index -= numStateParms;
		for (int i = 0; i < weights.length; i++) {
			if (index == 0)
				return this.defaultWeights[i];
			index--;
			if (index < weights[i].numLocations())
				return weights[i].valueAtLocation (index);
			index -= weights[i].numLocations();
		}
		throw new IllegalArgumentException ("index too high = "+index);
	}

	/** Overwrite all parameters from the flat array 'buff' (see getParameters() for the layout). */
	public void setParameters (double [] buff) {
		assert (buff.length == getNumFactors());
		int pi = 0;
		for (int i = 0; i < initialWeights.length; i++) {
			initialWeights[i] = buff[pi++];
			finalWeights[i] = buff[pi++];
		}
		for (int i = 0; i < weights.length; i++) {
			this.defaultWeights[i] = buff[pi++];
			int nl = weights[i].numLocations();
			for (int j = 0; j < nl; j++)
				weights[i].setValueAtLocation (j, buff[pi++]);
		}
	}

	/** Set the single parameter at flat 'index' (see getParameters() for the layout). */
	public void setParameter (int index, double value) {
		int numStateParms = 2 * initialWeights.length;
		if (index < numStateParms) {
			if (index % 2 == 0)
				initialWeights[index/2] = value;
			else
				finalWeights[index/2] = value;
		} else {
			index -= numStateParms;
			for (int i = 0; i < weights.length; i++) {
				if (index == 0) {
					defaultWeights[i] = value;
					return;
				}
				index--;
				if (index < weights[i].numLocations()) {
					weights[i].setValueAtLocation (index, value);
					return;
				} else {
					index -= weights[i].numLocations();
				}
			}
			throw new IllegalArgumentException ("index too high = "+index);
		}
	}

	// gsc: Serialization for Factors
	private static final long serialVersionUID = 1;
	private static final int CURRENT_SERIAL_VERSION = 1;
	private void writeObject (ObjectOutputStream out) throws IOException {
		out.writeInt (CURRENT_SERIAL_VERSION);
		out.writeObject (weightAlphabet);
		out.writeObject (weights);
		out.writeObject (defaultWeights);
		out.writeObject (weightsFrozen);
		out.writeObject (initialWeights);
		out.writeObject (finalWeights);
	}
	private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
		// Version is read for stream compatibility but currently unused.
		int version = in.readInt ();
		weightAlphabet = (Alphabet) in.readObject ();
		weights = (SparseVector[]) in.readObject ();
		defaultWeights = (double[]) in.readObject ();
		weightsFrozen = (boolean[]) in.readObject ();
		initialWeights = (double[]) in.readObject ();
		finalWeights = (double[]) in.readObject ();
	}
}
/** Construct a CRF using the given input and output pipes.
 * Alphabets are taken from the input pipe. */
public CRF (Pipe inputPipe, Pipe outputPipe)
{
	super (inputPipe, outputPipe);
	this.inputAlphabet = inputPipe.getDataAlphabet();
	// NOTE(review): the target alphabet is read from the *input* pipe (its target field),
	// not from outputPipe -- presumably intentional; confirm against callers.
	this.outputAlphabet = inputPipe.getTargetAlphabet();
	//inputAlphabet.stopGrowth();
}
/** Construct a CRF directly from alphabets, wrapping them in a no-op input pipe.
 * The input alphabet is frozen so no new features can be added after construction. */
public CRF (Alphabet inputAlphabet, Alphabet outputAlphabet)
{
	super (new Noop(inputAlphabet, outputAlphabet), null);
	inputAlphabet.stopGrowth();
	logger.info ("CRF input dictionary size = "+inputAlphabet.size());
	//xxx outputAlphabet.stopGrowth();
	this.inputAlphabet = inputAlphabet;
	this.outputAlphabet = outputAlphabet;
}
/** Create a CRF whose states and weights are a copy of those from another CRF. */
public CRF (CRF other)
{
	// This assumes that "other" has non-null inputPipe and outputPipe. We'd need to add another constructor to handle this if not.
	this (other.getInputPipe (), other.getOutputPipe ());
	copyStatesAndWeightsFrom (other);
	assertWeightsLength ();
}
/** Replace this CRF's states and weights with copies of those in 'initialCRF'. */
private void copyStatesAndWeightsFrom (CRF initialCRF)
{
	this.parameters = new Factors (initialCRF.parameters, true); // This will copy all the transition weights
	this.parameters.weightAlphabet = (Alphabet) initialCRF.parameters.weightAlphabet.clone();
	//weightAlphabet = (Alphabet) initialCRF.weightAlphabet.clone ();
	//weights = new SparseVector [initialCRF.weights.length];
	states.clear ();
	// Clear these, because they will be filled by this.addState()
	this.parameters.initialWeights = new double[0];
	this.parameters.finalWeights = new double[0];
	for (int i = 0; i < initialCRF.states.size(); i++) {
		State s = (State) initialCRF.getState (i);
		// Translate the other CRF's weight indices back into weight names so addState()
		// can re-resolve them against this CRF's (cloned) weight alphabet.
		String[][] weightNames = new String[s.weightsIndices.length][];
		for (int j = 0; j < weightNames.length; j++) {
			int[] thisW = s.weightsIndices[j];
			weightNames[j] = (String[]) initialCRF.parameters.weightAlphabet.lookupObjects(thisW, new String [s.weightsIndices[j].length]);
		}
		addState (s.name, initialCRF.parameters.initialWeights[i], initialCRF.parameters.finalWeights[i],
		          s.destinationNames, s.labels, weightNames);
	}
	featureSelections = initialCRF.featureSelections.clone ();
	// yyy weightsFrozen = (boolean[]) initialCRF.weightsFrozen.clone();
}
/** The dictionary over input features. */
public Alphabet getInputAlphabet () { return inputAlphabet; }
/** The dictionary over output labels. */
public Alphabet getOutputAlphabet () { return outputAlphabet; }
/** This method should be called whenever the CRFs weights (parameters) have their structure/arity/number changed. */
public void weightsStructureChanged () {
	// A structural change necessarily implies a value change, so bump both stamps.
	weightsStructureChangeStamp += 1;
	weightsValueChangeStamp += 1;
}
/** This method should be called whenever the CRFs weights (parameters) are changed. */
public void weightsValueChanged () {
	weightsValueChangeStamp += 1;
}
// This method can be over-ridden in subclasses of CRF to return subclasses of CRF.State
protected CRF.State newState (String name, int index,
                              double initialWeight, double finalWeight,
                              String[] destinationNames,
                              String[] labelNames,
                              String[][] weightNames,
                              CRF crf)
{
	return new State (name, index, initialWeight, finalWeight,
	                  destinationNames, labelNames, weightNames, crf);
}
/** Add a state to the machine.
 * @param name unique state name; throws IllegalArgumentException on duplicates
 * @param initialWeight weight for beginning a path in this state (IMPOSSIBLE_WEIGHT forbids it)
 * @param finalWeight weight for ending a path in this state
 * @param destinationNames names of states reachable from this one, one per transition
 * @param labelNames output label emitted on each corresponding transition
 * @param weightNames for each transition, the weight-set names it uses */
public void addState (String name, double initialWeight, double finalWeight,
                      String[] destinationNames,
                      String[] labelNames,
                      String[][] weightNames)
{
	assert (weightNames.length == destinationNames.length);
	assert (labelNames.length == destinationNames.length);
	weightsStructureChanged();
	if (name2state.get(name) != null)
		throw new IllegalArgumentException ("State with name `"+name+"' already exists.");
	parameters.initialWeights = MatrixOps.append(parameters.initialWeights, initialWeight);
	parameters.finalWeights = MatrixOps.append(parameters.finalWeights, finalWeight);
	State s = newState (name, states.size(), initialWeight, finalWeight,
	                    destinationNames, labelNames, weightNames, this);
	// NOTE(review): s.print() runs unconditionally for every added state -- looks like
	// leftover diagnostics; confirm whether it should be removed or routed through logger.
	s.print ();
	states.add (s);
	if (initialWeight > IMPOSSIBLE_WEIGHT)
		initialStates.add (s);
	name2state.put (name, s);
}
/** Add a state where each transition uses exactly one weight set;
 * each name is wrapped into a singleton group before delegating. */
public void addState (String name, double initialWeight, double finalWeight,
                      String[] destinationNames,
                      String[] labelNames,
                      String[] weightNames)
{
	String[][] groupedWeightNames = new String[weightNames.length][];
	for (int i = 0; i < weightNames.length; i++)
		groupedWeightNames[i] = new String[] { weightNames[i] };
	this.addState (name, initialWeight, finalWeight, destinationNames, labelNames, groupedWeightNames);
}
/** Default gives separate parameters to each transition: every arc gets a
 * weight set named "source->destination:label". */
public void addState (String name, double initialWeight, double finalWeight,
                      String[] destinationNames,
                      String[] labelNames)
{
	assert (destinationNames.length == labelNames.length);
	String[] perTransitionWeightNames = new String[labelNames.length];
	for (int i = 0; i < perTransitionWeightNames.length; i++)
		perTransitionWeightNames[i] = name + "->" + destinationNames[i] + ":" + labelNames[i];
	this.addState (name, initialWeight, finalWeight, destinationNames, labelNames, perTransitionWeightNames);
}
/** Add a state with parameters equal zero, and labels on out-going arcs
 * the same name as their destination state names. */
public void addState (String name, String[] destinationNames)
{
	this.addState (name, 0, 0, destinationNames, destinationNames);
}
/** Add a group of states that are fully connected with each other,
 * with parameters equal zero, and labels on their out-going arcs
 * the same name as their destination state names. */
public void addFullyConnectedStates (String[] stateNames)
{
	// Every state can reach every state (including itself).
	for (String stateName : stateNames)
		addState (stateName, stateNames);
}
/** Create one state per output label and fully connect them.
 * Assumes the entries in the outputAlphabet are Strings. */
public void addFullyConnectedStatesForLabels ()
{
	int numLabels = outputAlphabet.size();
	String[] labels = new String[numLabels];
	for (int i = 0; i < numLabels; i++) {
		logger.info ("CRF: outputAlphabet.lookup class = "+
		             outputAlphabet.lookupObject(i).getClass().getName());
		labels[i] = (String) outputAlphabet.lookupObject(i);
	}
	addFullyConnectedStates (labels);
}
/** Add a start state named {@code <START>}, making all previously added states non-initial. */
public void addStartState ()
{
	addStartState ("<START>");
}
/** Add a start state with the given name, connected to all existing states,
 * and make every previously added state impossible to start in. */
public void addStartState (String name)
{
	int existingStates = numStates ();
	// Forbid starting in any of the current states...
	for (int i = 0; i < existingStates; i++)
		parameters.initialWeights[i] = IMPOSSIBLE_WEIGHT;
	// ...then add the new start state, with arcs to all of them and initialWeight 0.0.
	String[] dests = new String [existingStates];
	for (int i = 0; i < existingStates; i++)
		dests[i] = getState(i).getName();
	addState (name, 0, 0.0, dests, dests);
}
/** Make 'state' the only state in which a path may begin;
 * every other state gets IMPOSSIBLE_WEIGHT as its initial weight. */
public void setAsStartState (State state)
{
	for (int i = 0; i < numStates(); i++) {
		Transducer.State candidate = getState (i);
		candidate.setInitialWeight (candidate == state ? 0 : IMPOSSIBLE_WEIGHT);
	}
	weightsValueChanged();
}
/** Convenience overload: compute label-to-label connections with no start label. */
private boolean[][] labelConnectionsIn (InstanceList trainingSet)
{
	return labelConnectionsIn (trainingSet, null);
}
/** Compute which label-to-label transitions occur in the given training set.
 * @param trainingSet instances whose targets are FeatureSequences of labels
 * @param start if non-null, the name of a start label that is additionally connected
 *        to every label; it must already be present in the output alphabet
 * @return a boolean matrix indexed [sourceLabelIndex][destinationLabelIndex]
 * @throws IllegalArgumentException if 'start' is not an existing output label */
private boolean[][] labelConnectionsIn (InstanceList trainingSet, String start)
{
	int numLabels = outputAlphabet.size();
	boolean[][] connections = new boolean[numLabels][numLabels];
	for (int i = 0; i < trainingSet.size(); i++) {
		Instance instance = trainingSet.get(i);
		FeatureSequence output = (FeatureSequence) instance.getTarget();
		// Mark every consecutive label pair observed in this target sequence.
		for (int j = 1; j < output.size(); j++) {
			int sourceIndex = outputAlphabet.lookupIndex (output.get(j-1));
			int destIndex = outputAlphabet.lookupIndex (output.get(j));
			assert (sourceIndex >= 0 && destIndex >= 0);
			connections[sourceIndex][destIndex] = true;
		}
	}
	// Handle start state: it connects to every label.
	if (start != null) {
		int startIndex = outputAlphabet.lookupIndex (start);
		// Guard: lookupIndex grows the alphabet when 'start' is not already a label,
		// returning an index >= numLabels that would previously have caused an
		// ArrayIndexOutOfBoundsException below. Fail with a clear message instead.
		if (startIndex < 0 || startIndex >= numLabels)
			throw new IllegalArgumentException ("Start label `"+start+"' is not in the output alphabet.");
		for (int j = 0; j < numLabels; j++) {
			connections[startIndex][j] = true;
		}
	}
	return connections;
}
/**
 * Add states to create a first-order Markov model on labels, adding only
 * those transitions that occur in the given trainingSet.
 */
public void addStatesForLabelsConnectedAsIn (InstanceList trainingSet)
{
	int numLabels = outputAlphabet.size();
	boolean[][] connections = labelConnectionsIn (trainingSet);
	for (int i = 0; i < numLabels; i++) {
		// Collect only the destination labels actually observed after label i.
		ArrayList<String> destinations = new ArrayList<String>();
		for (int j = 0; j < numLabels; j++)
			if (connections[i][j])
				destinations.add ((String)outputAlphabet.lookupObject(j));
		addState ((String)outputAlphabet.lookupObject(i),
		          destinations.toArray(new String[destinations.size()]));
	}
}
/**
 * Add as many states as there are labels, but don't create separate weights
 * for each source-destination pair of states. Instead have all the incoming
 * transitions to a state share the same weights.
 */
public void addStatesForHalfLabelsConnectedAsIn (InstanceList trainingSet)
{
	int numLabels = outputAlphabet.size();
	boolean[][] connections = labelConnectionsIn (trainingSet);
	for (int i = 0; i < numLabels; i++) {
		ArrayList<String> destinations = new ArrayList<String>();
		for (int j = 0; j < numLabels; j++)
			if (connections[i][j])
				destinations.add ((String)outputAlphabet.lookupObject(j));
		String[] destinationNames = destinations.toArray(new String[destinations.size()]);
		// Using the destination names as the weight names makes all arcs
		// into a given state share one weight set.
		addState ((String)outputAlphabet.lookupObject(i), 0.0, 0.0,
		          destinationNames, destinationNames, destinationNames);
	}
}
/**
 * Add as many states as there are labels, but don't create separate
 * observational-test-weights for each source-destination pair of
 * states---instead have all the incoming transitions to a state share the
 * same observational-feature-test weights. However, do create separate
 * default feature for each transition, (which acts as an HMM-style transition
 * probability).
 */
public void addStatesForThreeQuarterLabelsConnectedAsIn (InstanceList trainingSet)
{
	int numLabels = outputAlphabet.size();
	boolean[][] connections = labelConnectionsIn (trainingSet);
	for (int i = 0; i < numLabels; i++) {
		int numDestinations = 0;
		for (int j = 0; j < numLabels; j++)
			if (connections[i][j]) numDestinations++;
		String[] destinationNames = new String[numDestinations];
		String[][] weightNames = new String[numDestinations][];
		int destinationIndex = 0;
		for (int j = 0; j < numLabels; j++)
			if (connections[i][j]) {
				String labelName = (String)outputAlphabet.lookupObject(j);
				destinationNames[destinationIndex] = labelName;
				// Each transition uses two weight sets: shared observational and per-pair default.
				weightNames[destinationIndex] = new String[2];
				// The "half-labels" will include all observed tests
				weightNames[destinationIndex][0] = labelName;
				// The "transition" weights will include only the default feature
				String wn = (String)outputAlphabet.lookupObject(i) + "->" + (String)outputAlphabet.lookupObject(j);
				weightNames[destinationIndex][1] = wn;
				int wi = getWeightsIndex (wn);
				// A new empty FeatureSelection won't allow any features here, so we only
				// get the default feature for transitions
				featureSelections[wi] = new FeatureSelection(trainingSet.getDataAlphabet());
				destinationIndex++;
			}
		addState ((String)outputAlphabet.lookupObject(i), 0.0, 0.0,
		          destinationNames, destinationNames, weightNames);
	}
}
/** Fully-connected variant of {@code addStatesForThreeQuarterLabelsConnectedAsIn}:
 * one state per label, every transition allowed; observational weights are shared per
 * destination while each source-&gt;destination pair keeps its own default (transition) weight. */
public void addFullyConnectedStatesForThreeQuarterLabels (InstanceList trainingSet)
{
	int numLabels = outputAlphabet.size();
	for (int i = 0; i < numLabels; i++) {
		String[] destinationNames = new String[numLabels];
		String[][] weightNames = new String[numLabels][];
		for (int j = 0; j < numLabels; j++) {
			String labelName = (String)outputAlphabet.lookupObject(j);
			destinationNames[j] = labelName;
			weightNames[j] = new String[2];
			// The "half-labels" will include all observational tests
			weightNames[j][0] = labelName;
			// The "transition" weights will include only the default feature
			String wn = (String)outputAlphabet.lookupObject(i) + "->" + (String)outputAlphabet.lookupObject(j);
			weightNames[j][1] = wn;
			int wi = getWeightsIndex (wn);
			// A new empty FeatureSelection won't allow any features here, so we only
			// get the default feature for transitions
			featureSelections[wi] = new FeatureSelection(trainingSet.getDataAlphabet());
		}
		addState ((String)outputAlphabet.lookupObject(i), 0.0, 0.0,
		          destinationNames, destinationNames, weightNames);
	}
}
/**
 * Creates a fully-connected second-order (bi-label) model: one state per
 * ordered label pair, each transitioning to every pair that shares its
 * trailing label.
 */
public void addFullyConnectedStatesForBiLabels ()
{
	int numLabels = outputAlphabet.size();
	String[] labels = new String[numLabels];
	// This is assuming the entries in the outputAlphabet are Strings!
	for (int li = 0; li < numLabels; li++) {
		logger.info ("CRF: outputAlphabet.lookup class = "+
				outputAlphabet.lookupObject(li).getClass().getName());
		labels[li] = (String) outputAlphabet.lookupObject(li);
	}
	for (int prev = 0; prev < numLabels; prev++)
		for (int curr = 0; curr < numLabels; curr++) {
			// State (prev,curr) can move to any (curr,next) pair.
			String[] destinationNames = new String[numLabels];
			for (int next = 0; next < numLabels; next++)
				destinationNames[next] = labels[curr]+LABEL_SEPARATOR+labels[next];
			addState (labels[prev]+LABEL_SEPARATOR+labels[curr], 0.0, 0.0,
					destinationNames, labels);
		}
}
/**
* Add states to create a second-order Markov model on labels, adding only
* those transitions that occur in the given trainingSet.
*/
public void addStatesForBiLabelsConnectedAsIn (InstanceList trainingSet)
{
	int numLabels = outputAlphabet.size();
	// connections[i][j] is true iff label j ever follows label i in the training data
	boolean[][] connections = labelConnectionsIn (trainingSet);
	for (int i = 0; i < numLabels; i++) {
		for (int j = 0; j < numLabels; j++) {
			// Only create a state for label pairs (i,j) observed in training.
			if (!connections[i][j])
				continue;
			// Count transitions out of state (i,j): one per label k with j->k observed.
			int numDestinations = 0;
			for (int k = 0; k < numLabels; k++)
				if (connections[j][k]) numDestinations++;
			String[] destinationNames = new String[numDestinations];
			String[] labels = new String[numDestinations];
			int destinationIndex = 0;
			for (int k = 0; k < numLabels; k++)
				if (connections[j][k]) {
					// Destination state is the pair (j,k); the emitted label is k.
					destinationNames[destinationIndex] =
						(String)outputAlphabet.lookupObject(j)+LABEL_SEPARATOR+(String)outputAlphabet.lookupObject(k);
					labels[destinationIndex] = (String)outputAlphabet.lookupObject(k);
					destinationIndex++;
				}
			addState ((String)outputAlphabet.lookupObject(i)+LABEL_SEPARATOR+
					(String)outputAlphabet.lookupObject(j), 0.0, 0.0,
					destinationNames, labels);
		}
	}
}
/**
 * Creates a fully-connected third-order (tri-label) model: one state per
 * ordered label triple, each transitioning to every triple that shares its
 * trailing two labels.
 */
public void addFullyConnectedStatesForTriLabels ()
{
	int numLabels = outputAlphabet.size();
	String[] labels = new String[numLabels];
	// This is assuming the entries in the outputAlphabet are Strings!
	for (int li = 0; li < numLabels; li++) {
		logger.info ("CRF: outputAlphabet.lookup class = "+
				outputAlphabet.lookupObject(li).getClass().getName());
		labels[li] = (String) outputAlphabet.lookupObject(li);
	}
	for (int a = 0; a < numLabels; a++)
		for (int b = 0; b < numLabels; b++)
			for (int c = 0; c < numLabels; c++) {
				// State (a,b,c) can move to any (b,c,d) triple.
				String[] destinationNames = new String[numLabels];
				for (int d = 0; d < numLabels; d++)
					destinationNames[d] = labels[b]+LABEL_SEPARATOR+labels[c]+LABEL_SEPARATOR+labels[d];
				addState (labels[a]+LABEL_SEPARATOR+labels[b]+LABEL_SEPARATOR+labels[c], 0.0, 0.0,
						destinationNames, labels);
			}
}
/**
 * Adds a single state named {@code name} that emits every label and always
 * transitions back to itself (an order-0 model with one state).
 */
public void addSelfTransitioningStateForAllLabels (String name)
{
	int numLabels = outputAlphabet.size();
	String[] labels = new String[numLabels];
	String[] destinationNames = new String[numLabels];
	// This is assuming the entries in the outputAlphabet are Strings!
	for (int li = 0; li < numLabels; li++) {
		logger.info ("CRF: outputAlphabet.lookup class = "+
				outputAlphabet.lookupObject(li).getClass().getName());
		labels[li] = (String) outputAlphabet.lookupObject(li);
		// Every transition loops back to this same state.
		destinationNames[li] = name;
	}
	addState (name, 0.0, 0.0, destinationNames, labels);
}
/**
 * Joins the given labels with LABEL_SEPARATOR to form a state name.
 *
 * @param labels the label sequence (an empty array yields the empty string)
 * @return the labels joined by LABEL_SEPARATOR
 */
private String concatLabels(String[] labels)
{
	// StringBuilder instead of StringBuffer: this is a private,
	// single-threaded helper, so the synchronization was pure overhead.
	StringBuilder buf = new StringBuilder();
	for (int i = 0; i < labels.length; i++)
	{
		if (i > 0)
			buf.append(LABEL_SEPARATOR);
		buf.append(labels[i]);
	}
	return buf.toString();
}
/**
 * Builds the name of the state reached by appending {@code next} to the
 * history: the final (k-1) labels of {@code history} followed by
 * {@code next}, joined by LABEL_SEPARATOR.
 *
 * @param history the current label history (most recent last)
 * @param k the k-gram length of the resulting name
 * @param next the label being appended
 */
private String nextKGram(String[] history, int k, String next)
{
	String sep = "";
	// StringBuilder instead of StringBuffer: private single-threaded helper,
	// so the synchronized buffer was unnecessary.
	StringBuilder buf = new StringBuilder();
	// Keep only the most recent k-1 history labels; `next` supplies the k-th.
	int start = history.length + 1 - k;
	for (int i = start; i < history.length; i++)
	{
		buf.append(sep).append(history[i]);
		sep = LABEL_SEPARATOR;
	}
	buf.append(sep).append(next);
	return buf.toString();
}
/**
 * Tests whether the transition prev->curr is permitted by the optional
 * patterns.  Both patterns are matched against the two labels joined by
 * LABEL_SEPARATOR; {@code no} forbids matching pairs, {@code yes} (if
 * present) permits only matching pairs.
 */
private boolean allowedTransition(String prev, String curr,
		Pattern no, Pattern yes)
{
	String pair = concatLabels(new String[]{prev, curr});
	boolean forbidden = (no != null) && no.matcher(pair).matches();
	boolean outsideWhitelist = (yes != null) && !yes.matcher(pair).matches();
	return !forbidden && !outsideWhitelist;
}
/** Returns true iff every successive label pair in {@code history} is an allowed transition. */
private boolean allowedHistory(String[] history, Pattern no, Pattern yes) {
	for (int i = 0; i + 1 < history.length; i++) {
		if (!allowedTransition(history[i], history[i + 1], no, yes))
			return false;
	}
	return true;
}
/**
* Assumes that the CRF's output alphabet contains
* <code>String</code>s. Creates an order-<em>n</em> CRF with input
* predicates and output labels given by <code>trainingSet</code>
* and order, connectivity, and weights given by the remaining
* arguments.
*
* @param trainingSet the training instances
* @param orders an array of increasing non-negative numbers giving
* the orders of the features for this CRF. The largest number
* <em>n</em> is the Markov order of the CRF. States are
* <em>n</em>-tuples of output labels. Each of the other numbers
* <em>k</em> in <code>orders</code> represents a weight set shared
* by all destination states whose last (most recent) <em>k</em>
* labels agree. If <code>orders</code> is <code>null</code>, an
* order-0 CRF is built.
* @param defaults If non-null, it must be the same length as
* <code>orders</code>, with <code>true</code> positions indicating
* that the weight set for the corresponding order contains only the
* weight for a default feature; otherwise, the weight set has
* weights for all features built from input predicates.
* @param start The label that represents the context of the start of
* a sequence. It may be also used for sequence labels. If no label of
this name exists, one will be added. Connections will be added between
* the start label and all other labels, even if <tt>fullyConnected</tt> is
* <tt>false</tt>. This argument may be null, in which case no special
* start state is added.
* @param forbidden If non-null, specifies what pairs of successive
* labels are not allowed, both for constructing <em>n</em>-order
* states or for transitions. A label pair (<em>u</em>,<em>v</em>)
* is not allowed if <em>u</em> + "," + <em>v</em> matches
* <code>forbidden</code>.
* @param allowed If non-null, specifies what pairs of successive
* labels are allowed, both for constructing <em>n</em>-order
* states or for transitions. A label pair (<em>u</em>,<em>v</em>)
* is allowed only if <em>u</em> + "," + <em>v</em> matches
* <code>allowed</code>.
* @param fullyConnected Whether to include all allowed transitions,
* even those not occurring in <code>trainingSet</code>,
* @return The name of the start state.
*
*/
public String addOrderNStates(InstanceList trainingSet, int[] orders,
		boolean[] defaults, String start,
		Pattern forbidden, Pattern allowed,
		boolean fullyConnected)
{
	boolean[][] connections = null;
	if (start != null)
		outputAlphabet.lookupIndex (start);
	if (!fullyConnected)
		connections = labelConnectionsIn (trainingSet, start);
	int order = -1;
	// Validate `defaults` against `orders` *inside* the null check: the old
	// code dereferenced orders.length first, so a null `orders` with non-null
	// `defaults` threw NullPointerException instead of a meaningful
	// IllegalArgumentException.
	if (orders == null) {
		if (defaults != null)
			throw new IllegalArgumentException("Defaults must be null or match orders");
		order = 0;
	}
	else
	{
		if (defaults != null && defaults.length != orders.length)
			throw new IllegalArgumentException("Defaults must be null or match orders");
		for (int i = 0; i < orders.length; i++) {
			if (orders[i] <= order)
				throw new IllegalArgumentException("Orders must be non-negative and in ascending order");
			order = orders[i];
		}
		if (order < 0) order = 0;
	}
	if (order > 0)
	{
		// Enumerate all `order`-length label histories, odometer-style.
		int[] historyIndexes = new int[order];
		String[] history = new String[order];
		String label0 = (String)outputAlphabet.lookupObject(0);
		for (int i = 0; i < order; i++)
			history[i] = label0;
		int numLabels = outputAlphabet.size();
		while (historyIndexes[0] < numLabels)
		{
			logger.info("Preparing " + concatLabels(history));
			if (allowedHistory(history, forbidden, allowed))
			{
				String stateName = concatLabels(history);
				int nt = 0;
				String[] destNames = new String[numLabels];
				String[] labelNames = new String[numLabels];
				String[][] weightNames = new String[numLabels][orders.length];
				for (int nextIndex = 0; nextIndex < numLabels; nextIndex++)
				{
					String next = (String)outputAlphabet.lookupObject(nextIndex);
					if (allowedTransition(history[order-1], next, forbidden, allowed)
							&& (fullyConnected ||
									connections[historyIndexes[order-1]][nextIndex]))
					{
						destNames[nt] = nextKGram(history, order, next);
						labelNames[nt] = next;
						// One weight set per requested order; lower orders are
						// shared among states whose recent labels agree.
						for (int i = 0; i < orders.length; i++)
						{
							weightNames[nt][i] = nextKGram(history, orders[i]+1, next);
							if (defaults != null && defaults[i]) {
								int wi = getWeightsIndex (weightNames[nt][i]);
								// Using empty feature selection gives us only the
								// default features
								featureSelections[wi] =
									new FeatureSelection(trainingSet.getDataAlphabet());
							}
						}
						nt++;
					}
				}
				if (nt < numLabels)
				{
					// Trim arrays to the number of transitions actually allowed.
					destNames = Arrays.copyOf(destNames, nt);
					labelNames = Arrays.copyOf(labelNames, nt);
					weightNames = Arrays.copyOf(weightNames, nt);
				}
				for (int i = 0; i < destNames.length; i++)
				{
					StringBuffer b = new StringBuffer();
					for (int j = 0; j < orders.length; j++)
						b.append(" ").append(weightNames[i][j]);
					logger.info(stateName + "->" + destNames[i] +
							"(" + labelNames[i] + ")" + b.toString());
				}
				addState (stateName, 0.0, 0.0, destNames, labelNames, weightNames);
			}
			// Advance the history odometer (least-significant position last);
			// position 0 is deliberately not reset so the while-loop terminates.
			for (int o = order-1; o >= 0; o--)
				if (++historyIndexes[o] < numLabels)
				{
					history[o] = (String)outputAlphabet.lookupObject(historyIndexes[o]);
					break;
				} else if (o > 0)
				{
					historyIndexes[o] = 0;
					history[o] = label0;
				}
		}
		// The start state is the history consisting of `order` copies of `start`.
		for (int i = 0; i < order; i++)
			history[i] = start;
		return concatLabels(history);
	}
	// Order 0: one fully-connected state per label.
	String[] stateNames = new String[outputAlphabet.size()];
	for (int s = 0; s < outputAlphabet.size(); s++)
		stateNames[s] = (String)outputAlphabet.lookupObject(s);
	for (int s = 0; s < outputAlphabet.size(); s++)
		addState(stateNames[s], 0.0, 0.0, stateNames, stateNames, stateNames);
	return start;
}
/** Returns the state with the given name, or null if none exists. */
public State getState (String name)
{
	return name2state.get(name);
}
/**
 * Replaces the weight vector at {@code weightsIndex}.
 * @throws IllegalArgumentException if the index is out of bounds
 */
public void setWeights (int weightsIndex, SparseVector transitionWeights)
{
	weightsStructureChanged();
	if (weightsIndex >= parameters.weights.length || weightsIndex < 0)
		throw new IllegalArgumentException ("weightsIndex "+weightsIndex+" is out of bounds");
	parameters.weights[weightsIndex] = transitionWeights;
}
/** Replaces the named weight vector, creating the weight set if the name is new. */
public void setWeights (String weightName, SparseVector transitionWeights)
{
	setWeights (getWeightsIndex (weightName), transitionWeights);
}
/** Returns the name of the weight set at the given index. */
public String getWeightsName (int weightIndex)
{
	return (String) parameters.weightAlphabet.lookupObject (weightIndex);
}
/** Returns the named weight vector (creating the weight set if the name is new). */
public SparseVector getWeights (String weightName)
{
	return parameters.weights[getWeightsIndex (weightName)];
}
/** Returns the weight vector at the given index. */
public SparseVector getWeights (int weightIndex)
{
	return parameters.weights[weightIndex];
}
/** Returns the per-weight-set default (bias) weights; the live array, not a copy. */
public double[] getDefaultWeights () {
	return parameters.defaultWeights;
}
/** Returns the array of weight vectors; the live array, not a copy. */
public SparseVector[] getWeights () {
	return parameters.weights;
}
/** Replaces the entire array of weight vectors. */
public void setWeights (SparseVector[] m) {
	weightsStructureChanged();
	parameters.weights = m;
}
/** Replaces all default (bias) weights. */
public void setDefaultWeights (double[] w) {
	weightsStructureChanged();
	parameters.defaultWeights = w;
}
/** Sets a single default (bias) weight; a value change only, not a structure change. */
public void setDefaultWeight (int widx, double val) {
	weightsValueChanged();
	parameters.defaultWeights[widx] = val;
}
// Support for making cc.mallet.optimize.Optimizable CRFs
/** Returns whether the weight set at the given index is frozen (excluded from training). */
public boolean isWeightsFrozen (int weightsIndex)
{
	return parameters.weightsFrozen [weightsIndex];
}
/**
 * Freezes a set of weights to their current values.
 * Frozen weights are used for labeling sequences (as in <tt>transduce</tt>),
 * but are not modified by the <tt>train</tt> methods.
 * @param weightsIndex Index of weight set to freeze.
 */
public void freezeWeights (int weightsIndex)
{
	parameters.weightsFrozen [weightsIndex] = true;
}
/**
 * Freezes a set of weights to their current values.
 * Frozen weights are used for labeling sequences (as in <tt>transduce</tt>),
 * but are not modified by the <tt>train</tt> methods.
 * @param weightsName Name of weight set to freeze.
 */
public void freezeWeights (String weightsName)
{
	int widx = getWeightsIndex (weightsName);
	freezeWeights (widx);
}
/**
 * Unfreezes a set of weights.
 * Frozen weights are used for labeling sequences (as in <tt>transduce</tt>),
 * but are not modified by the <tt>train</tt> methods.
 * @param weightsName Name of weight set to unfreeze.
 */
public void unfreezeWeights (String weightsName)
{
	int widx = getWeightsIndex (weightsName);
	parameters.weightsFrozen[widx] = false;
}
/** Sets the feature selection for one weight set; null means all features are allowed. */
public void setFeatureSelection (int weightIdx, FeatureSelection fs)
{
	featureSelections [weightIdx] = fs;
	weightsStructureChanged(); // Is this necessary? -akm 11/2007
}
/** Convenience form of {@link #setWeightsDimensionAsIn(InstanceList,boolean)} without the unsupported trick. */
public void setWeightsDimensionAsIn (InstanceList trainingData) {
	setWeightsDimensionAsIn(trainingData, false);
}
// gsc: changing this to consider the case when trainingData is a mix of labeled and unlabeled data,
// and we want to use the unlabeled data as well to set some weights (while using the unsupported trick)
// *note*: 'target' sequence of an unlabeled instance is either null or is of size zero.
/**
 * Sizes the sparse weight vectors to hold exactly the features observed in
 * {@code trainingData} (plus any features already present), preserving the
 * current weight values.  When {@code useSomeUnsupportedTrick} is true and
 * the model already has non-zero parameters, features on high-probability
 * transitions under the current model are also added, allowing training to
 * assign them negative weight.
 */
public void setWeightsDimensionAsIn (InstanceList trainingData, boolean useSomeUnsupportedTrick)
{
	final BitSet[] weightsPresent;
	int numWeights = 0;
	// The value doesn't actually change, because the "new" parameters will have zero value
	// but the gradient changes because the parameters now have different layout.
	weightsStructureChanged();
	weightsPresent = new BitSet[parameters.weights.length];
	for (int i = 0; i < parameters.weights.length; i++)
		weightsPresent[i] = new BitSet();
	// Put in the weights that are already there
	for (int i = 0; i < parameters.weights.length; i++)
		for (int j = parameters.weights[i].numLocations()-1; j >= 0; j--)
			weightsPresent[i].set (parameters.weights[i].indexAtLocation(j));
	// Put in the weights in the training set
	for (int i = 0; i < trainingData.size(); i++) {
		Instance instance = trainingData.get(i);
		FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
		FeatureSequence output = (FeatureSequence) instance.getTarget();
		// gsc: trainingData can have unlabeled instances as well
		if (output != null && output.size() > 0) {
			// Do it for the paths consistent with the labels...
			sumLatticeFactory.newSumLattice (this, input, output, new Transducer.Incrementor() {
				public void incrementTransition (Transducer.TransitionIterator ti, double count) {
					State source = (CRF.State)ti.getSourceState();
					FeatureVector input = (FeatureVector)ti.getInput();
					int index = ti.getIndex();
					int nwi = source.weightsIndices[index].length;
					for (int wi = 0; wi < nwi; wi++) {
						int weightsIndex = source.weightsIndices[index][wi];
						for (int i = 0; i < input.numLocations(); i++) {
							int featureIndex = input.indexAtLocation(i);
							// Respect both the global and the per-weight-set feature selections.
							if ((globalFeatureSelection == null || globalFeatureSelection.contains(featureIndex))
									&& (featureSelections == null
											|| featureSelections[weightsIndex] == null
											|| featureSelections[weightsIndex].contains(featureIndex)))
								weightsPresent[weightsIndex].set (featureIndex);
						}
					}
				}
				public void incrementInitialState (Transducer.State s, double count) { }
				public void incrementFinalState (Transducer.State s, double count) { }
			});
		}
		// ...and also do it for the paths selected by the current model (so we will get some negative weights)
		if (useSomeUnsupportedTrick && this.getParametersAbsNorm() > 0) {
			if (i == 0)
				logger.info ("CRF: Incremental training detected. Adding weights for some unsupported features...");
			// (do this once some training is done)
			// NOTE(review): this Incrementor duplicates the one above except for the
			// probability threshold; consider factoring out if this code is revisited.
			sumLatticeFactory.newSumLattice (this, input, null, new Transducer.Incrementor() {
				public void incrementTransition (Transducer.TransitionIterator ti, double count) {
					if (count < 0.2) // Only create features for transitions with probability above 0.2
						return; // This 0.2 is somewhat arbitrary -akm
					State source = (CRF.State)ti.getSourceState();
					FeatureVector input = (FeatureVector)ti.getInput();
					int index = ti.getIndex();
					int nwi = source.weightsIndices[index].length;
					for (int wi = 0; wi < nwi; wi++) {
						int weightsIndex = source.weightsIndices[index][wi];
						for (int i = 0; i < input.numLocations(); i++) {
							int featureIndex = input.indexAtLocation(i);
							if ((globalFeatureSelection == null || globalFeatureSelection.contains(featureIndex))
									&& (featureSelections == null
											|| featureSelections[weightsIndex] == null
											|| featureSelections[weightsIndex].contains(featureIndex)))
								weightsPresent[weightsIndex].set (featureIndex);
						}
					}
				}
				public void incrementInitialState (Transducer.State s, double count) { }
				public void incrementFinalState (Transducer.State s, double count) { }
			});
		}
	}
	// Rebuild each weight vector with exactly the indices recorded above.
	SparseVector[] newWeights = new SparseVector[parameters.weights.length];
	for (int i = 0; i < parameters.weights.length; i++) {
		int numLocations = weightsPresent[i].cardinality ();
		logger.info ("CRF weights["+parameters.weightAlphabet.lookupObject(i)+"] num features = "+numLocations);
		int[] indices = new int[numLocations];
		for (int j = 0; j < numLocations; j++) {
			indices[j] = weightsPresent[i].nextSetBit (j == 0 ? 0 : indices[j-1]+1);
			//System.out.println ("CRF4 has index "+indices[j]);
		}
		newWeights[i] = new IndexedSparseVector (indices, new double[numLocations],
				numLocations, numLocations, false, false, false);
		newWeights[i].plusEqualsSparse (parameters.weights[i]); // Put in the previous weights
		numWeights += (numLocations + 1); // +1 for the default (bias) weight
	}
	logger.info("Number of weights = "+numWeights);
	parameters.weights = newWeights;
}
/**
 * Gives every weight set a dense weight vector over all input features
 * (or over its FeatureSelection, if one is set), preserving current values.
 */
public void setWeightsDimensionDensely ()
{
	weightsStructureChanged();
	SparseVector[] newWeights = new SparseVector [parameters.weights.length];
	int max = inputAlphabet.size();
	int numWeights = 0;
	logger.info ("CRF using dense weights, num input features = "+max);
	for (int i = 0; i < parameters.weights.length; i++) {
		int nfeatures;
		if (featureSelections[i] == null) {
			// No selection: truly dense over every input feature.
			nfeatures = max;
			newWeights [i] = new SparseVector (null, new double [max],
					max, max, false, false, false);
		} else {
			// Respect the featureSelection
			FeatureSelection fs = featureSelections[i];
			nfeatures = fs.getBitSet ().cardinality ();
			int[] idxs = new int [nfeatures];
			int j = 0, thisIdx = -1;
			// Collect the selected feature indices in ascending order.
			while ((thisIdx = fs.nextSelectedIndex (thisIdx + 1)) >= 0) {
				idxs[j++] = thisIdx;
			}
			newWeights[i] = new IndexedSparseVector (idxs, new double [nfeatures], nfeatures, nfeatures, false, false, false);
		}
		newWeights [i].plusEqualsSparse (parameters.weights [i]); // keep previous values
		numWeights += (nfeatures + 1); // +1 for the default (bias) weight
	}
	logger.info("Number of weights = "+numWeights);
	parameters.weights = newWeights;
}
/**
 * Returns the index of the weight set with the given name, creating it (and
 * growing all parallel parameter arrays by one slot) if it does not exist.
 *
 * @throws IllegalArgumentException if the weight alphabet is frozen and the
 *         name is unknown
 */
public int getWeightsIndex (String weightName)
{
	int wi = parameters.weightAlphabet.lookupIndex (weightName);
	if (wi == -1)
		throw new IllegalArgumentException ("Alphabet frozen, and no weight with name "+ weightName);
	if (parameters.weights == null) {
		// First weight set ever: allocate all parallel arrays at size 1.
		assert (wi == 0);
		parameters.weights = new SparseVector[1];
		parameters.defaultWeights = new double[1];
		featureSelections = new FeatureSelection[1];
		parameters.weightsFrozen = new boolean [1];
		parameters.weights[0] = new IndexedSparseVector ();
		parameters.defaultWeights[0] = 0;
		featureSelections[0] = null;
		weightsStructureChanged();
	} else if (wi == parameters.weights.length) {
		// New name: grow every parallel array by one slot.  Arrays.copyOf
		// replaces the old element-by-element copy loop.
		int n = parameters.weights.length;
		parameters.weights = Arrays.copyOf (parameters.weights, n + 1);
		parameters.defaultWeights = Arrays.copyOf (parameters.defaultWeights, n + 1);
		featureSelections = Arrays.copyOf (featureSelections, n + 1);
		parameters.weights[wi] = new IndexedSparseVector ();
		parameters.defaultWeights[wi] = 0; // redundant (copyOf zero-fills) but explicit
		featureSelections[wi] = null;
		parameters.weightsFrozen = ArrayUtils.append (parameters.weightsFrozen, false);
		weightsStructureChanged();
	}
	return wi;
}
/** Sanity check: all per-weight-set parallel arrays exist and have the same length. */
private void assertWeightsLength ()
{
	if (parameters.weights == null)
		return;
	assert parameters.defaultWeights != null;
	assert featureSelections != null;
	assert parameters.weightsFrozen != null;
	int expected = parameters.weights.length;
	assert parameters.defaultWeights.length == expected;
	assert featureSelections.length == expected;
	assert parameters.weightsFrozen.length == expected;
}
/** Returns the number of states in this CRF. */
public int numStates () { return states.size(); }
/** Returns the state at the given index. */
public Transducer.State getState (int index) {
	return states.get(index); }
/** Iterator over the states that may begin a sequence.  NOTE(review): raw Iterator; element type appears to be State. */
public Iterator initialStateIterator () {
	return initialStates.iterator (); }
/** CRFs are always trainable. */
public boolean isTrainable () { return true; }
// gsc: accessor methods
/** Stamp incremented whenever any weight value changes. */
public int getWeightsValueChangeStamp() {
	return weightsValueChangeStamp;
}
// kedar: access structure stamp method
/** Stamp incremented whenever the weights structure (layout) changes. */
public int getWeightsStructureChangeStamp() {
	return weightsStructureChangeStamp;
}
/** Returns the underlying parameter factors; the live object, not a copy. */
public Factors getParameters ()
{
	return parameters;
}
// gsc
/** Returns the L1 norm of all parameters: initial/final state weights, default weights, and sparse weights. */
public double getParametersAbsNorm ()
{
	double ret = 0;
	for (int i = 0; i < numStates(); i++) {
		ret += Math.abs (parameters.initialWeights[i]);
		ret += Math.abs (parameters.finalWeights[i]);
	}
	for (int i = 0; i < parameters.weights.length; i++) {
		ret += Math.abs (parameters.defaultWeights[i]);
		ret += parameters.weights[i].absNorm();
	}
	return ret;
}
/** Only sets the parameter from the first group of parameters. */
public void setParameter (int sourceStateIndex, int destStateIndex, int featureIndex, double value)
{
	setParameter(sourceStateIndex, destStateIndex, featureIndex, 0, value);
}
/**
 * Sets a single transition parameter.
 *
 * @param featureIndex the input feature, or a negative value to set the
 *        default (bias) weight of the weight set
 * @param weightIndex which of the transition's weight sets to modify
 * @throws IllegalArgumentException if there is no transition between the states
 */
public void setParameter (int sourceStateIndex, int destStateIndex, int featureIndex, int weightIndex, double value)
{
	weightsValueChanged();
	State source = (State)getState(sourceStateIndex);
	State dest = (State) getState(destStateIndex);
	int rowIndex;
	// Find the row of source's transition table that leads to dest.
	for (rowIndex = 0; rowIndex < source.destinationNames.length; rowIndex++)
		if (source.destinationNames[rowIndex].equals (dest.name))
			break;
	if (rowIndex == source.destinationNames.length)
		// Fixed message typo: "transtition" -> "transition".
		throw new IllegalArgumentException ("No transition from state "+sourceStateIndex+" to state "+destStateIndex+".");
	int weightsIndex = source.weightsIndices[rowIndex][weightIndex];
	if (featureIndex < 0)
		parameters.defaultWeights[weightsIndex] = value;
	else {
		parameters.weights[weightsIndex].setValue (featureIndex, value);
	}
}
/** Only gets the parameter from the first group of parameters. */
public double getParameter (int sourceStateIndex, int destStateIndex, int featureIndex)
{
	return getParameter(sourceStateIndex,destStateIndex,featureIndex,0);
}
/**
 * Returns a single transition parameter.
 *
 * @param featureIndex the input feature, or a negative value to get the
 *        default (bias) weight of the weight set
 * @param weightIndex which of the transition's weight sets to read
 * @throws IllegalArgumentException if there is no transition between the states
 */
public double getParameter (int sourceStateIndex, int destStateIndex, int featureIndex, int weightIndex)
{
	State source = (State)getState(sourceStateIndex);
	State dest = (State) getState(destStateIndex);
	int rowIndex;
	// Find the row of source's transition table that leads to dest.
	for (rowIndex = 0; rowIndex < source.destinationNames.length; rowIndex++)
		if (source.destinationNames[rowIndex].equals (dest.name))
			break;
	if (rowIndex == source.destinationNames.length)
		// Fixed message typo: "transtition" -> "transition".
		throw new IllegalArgumentException ("No transition from state "+sourceStateIndex+" to state "+destStateIndex+".");
	int weightsIndex = source.weightsIndices[rowIndex][weightIndex];
	if (featureIndex < 0)
		return parameters.defaultWeights[weightsIndex];
	return parameters.weights[weightsIndex].value (featureIndex);
}
/**
 * Returns the total number of parameters: initial and final weights for each
 * state, plus the default weight and sparse weights of every weight set.
 * The count is cached and recomputed only when the weights structure changes.
 */
public int getNumParameters () {
	if (cachedNumParametersStamp != weightsStructureChangeStamp) {
		this.numParameters = 2 * this.numStates() + this.parameters.defaultWeights.length;
		for (int i = 0; i < parameters.weights.length; i++)
			numParameters += parameters.weights[i].numLocations();
		// Record the stamp so the cached value is actually reused; the old
		// code never updated it, so it recomputed on every call.
		cachedNumParametersStamp = weightsStructureChangeStamp;
	}
	return this.numParameters;
}
/**
 * This method is deprecated.
 * Labels each instance in {@code testing} with its Viterbi (best) output
 * sequence, after applying any induced feature conjunctions.
 * @deprecated use {@link #induceFeaturesFor} followed by ordinary decoding
 */
// But it is here as a reminder to do something about induceFeaturesFor(). */
@Deprecated
public Sequence[] predict (InstanceList testing) {
	testing.setFeatureSelection(this.globalFeatureSelection);
	for (int i = 0; i < featureInducers.size(); i++) {
		FeatureInducer klfi = (FeatureInducer)featureInducers.get(i);
		klfi.induceFeaturesFor (testing, false, false);
	}
	Sequence[] ret = new Sequence[testing.size()];
	for (int i = 0; i < testing.size(); i++) {
		Instance instance = testing.get(i);
		Sequence input = (Sequence) instance.getData();
		Sequence trueOutput = (Sequence) instance.getTarget();
		assert (input.size() == trueOutput.size());
		// Viterbi decoding of the (possibly feature-augmented) input.
		Sequence predOutput = new MaxLatticeDefault(this, input).bestOutputSequence();
		assert (predOutput.size() == trueOutput.size());
		ret[i] = predOutput;
	}
	return ret;
}
/**
 * This method is deprecated.
 * @deprecated Always throws IllegalStateException; use CRF.induceFeaturesFor() instead.
 */
@Deprecated
public void evaluate (TransducerEvaluator eval, InstanceList testing) {
	throw new IllegalStateException ("This method is no longer usable. Use CRF.induceFeaturesFor() instead.");
	/*
	testing.setFeatureSelection(this.globalFeatureSelection);
	for (int i = 0; i < featureInducers.size(); i++) {
		FeatureInducer klfi = (FeatureInducer)featureInducers.get(i);
		klfi.induceFeaturesFor (testing, false, false);
	}
	eval.evaluate (this, true, 0, true, 0.0, null, null, testing);
	*/
}
/** When the CRF has done feature induction, these new feature conjunctions must be
 * created in the test or validation data in order for them to take effect. */
public void induceFeaturesFor (InstanceList instances) {
	instances.setFeatureSelection (this.globalFeatureSelection);
	// Replay every recorded inducer over the new data.
	for (FeatureInducer inducer : featureInducers)
		inducer.induceFeaturesFor (instances, false, false);
}
// TODO Put support to Optimizable here, including getValue(InstanceList)??
/** Prints a human-readable description of this CRF (states, transitions, weights) to System.out. */
public void print ()
{
	print (new PrintWriter (new OutputStreamWriter (System.out), true));
}
/** Writes a human-readable dump of all states, transitions, and weights to {@code out}. */
public void print (PrintWriter out)
{
	out.println ("*** CRF STATES ***");
	for (int i = 0; i < numStates (); i++) {
		State s = (State) getState (i);
		out.print ("STATE NAME=\"");
		out.print (s.name); out.print ("\" ("); out.print (s.destinations.length); out.print (" outgoing transitions)\n");
		out.print (" "); out.print ("initialWeight = "); out.print (parameters.initialWeights[i]); out.print ('\n');
		out.print (" "); out.print ("finalWeight = "); out.print (parameters.finalWeights[i]); out.print ('\n');
		out.println (" transitions:");
		for (int j = 0; j < s.destinations.length; j++) {
			out.print (" "); out.print (s.name); out.print (" -> "); out.println (s.getDestinationState (j).getName ());
			for (int k = 0; k < s.weightsIndices[j].length; k++) {
				out.print (" WEIGHTS = \"");
				int widx = s.weightsIndices[j][k];
				out.print (parameters.weightAlphabet.lookupObject (widx).toString ());
				out.print ("\"\n");
			}
		}
		out.println ();
	}
	if (parameters.weights == null)
		out.println ("\n\n*** NO WEIGHTS ***");
	else {
		out.println ("\n\n*** CRF WEIGHTS ***");
		for (int widx = 0; widx < parameters.weights.length; widx++) {
			out.println ("WEIGHTS NAME = " + parameters.weightAlphabet.lookupObject (widx));
			out.print (": <DEFAULT_FEATURE> = "); out.print (parameters.defaultWeights[widx]); out.print ('\n');
			SparseVector transitionWeights = parameters.weights[widx];
			if (transitionWeights.numLocations () == 0)
				continue;
			// Rank features so they print in order of weight value.
			RankedFeatureVector rfv = new RankedFeatureVector (inputAlphabet, transitionWeights);
			for (int m = 0; m < rfv.numLocations (); m++) {
				double v = rfv.getValueAtRank (m);
				//int index = rfv.indexAtLocation (rfv.getIndexAtRank (m)); // This doesn't make any sense. How did this ever work? -akm 12/2007
				int index = rfv.getIndexAtRank (m);
				Object feature = inputAlphabet.lookupObject (index);
				// Skip zero weights to keep the dump compact.
				if (v != 0) {
					out.print (": "); out.print (feature); out.print (" = "); out.println (v);
				}
			}
		}
	}
	out.flush ();
}
/**
 * Serializes this CRF to the given file.  IOExceptions are reported to
 * System.err rather than propagated (preserving the original best-effort
 * behavior).
 */
public void write (File f) {
	ObjectOutputStream oos = null;
	try {
		oos = new ObjectOutputStream(new FileOutputStream(f));
		oos.writeObject(this);
	}
	catch (IOException e) {
		System.err.println("Exception writing file " + f + ": " + e);
	}
	finally {
		// Always close; the old code leaked the file handle if writeObject threw.
		if (oos != null) {
			try { oos.close(); }
			catch (IOException ignored) { /* best-effort close */ }
		}
	}
}
// gsc: Serialization for CRF class
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 1;
// Custom serialization; the field order here must match readObject exactly.
private void writeObject (ObjectOutputStream out) throws IOException {
	out.writeInt (CURRENT_SERIAL_VERSION);
	out.writeObject (inputAlphabet);
	out.writeObject (outputAlphabet);
	out.writeObject (states);
	out.writeObject (initialStates);
	out.writeObject (name2state);
	out.writeObject (parameters);
	out.writeObject (globalFeatureSelection);
	out.writeObject (featureSelections);
	out.writeObject (featureInducers);
	out.writeInt (weightsValueChangeStamp);
	out.writeInt (weightsStructureChangeStamp);
	out.writeInt (cachedNumParametersStamp);
	out.writeInt (numParameters);
}
@SuppressWarnings("unchecked")
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
	// Serial version is read but not currently branched on.
	in.readInt ();
	inputAlphabet = (Alphabet) in.readObject ();
	outputAlphabet = (Alphabet) in.readObject ();
	states = (ArrayList<State>) in.readObject ();
	initialStates = (ArrayList<State>) in.readObject ();
	name2state = (HashMap) in.readObject ();
	parameters = (Factors) in.readObject ();
	globalFeatureSelection = (FeatureSelection) in.readObject ();
	featureSelections = (FeatureSelection[]) in.readObject ();
	featureInducers = (ArrayList<FeatureInducer>) in.readObject ();
	weightsValueChangeStamp = in.readInt ();
	weightsStructureChangeStamp = in.readInt ();
	cachedNumParametersStamp = in.readInt ();
	numParameters = in.readInt ();
}
// Why is this "static"? Couldn't it be a non-static inner class? (In Transducer also) -akm 12/2007
/** A CRF state: named and indexed, with named destination states and per-transition weight-set indices. */
public static class State extends Transducer.State implements Serializable
{
	// Parameters indexed by destination state, feature index
	String name;               // unique state name (key in CRF.name2state)
	int index;                 // this state's position in CRF.states
	String[] destinationNames; // one entry per outgoing transition
	State[] destinations; // N.B. elements are null until getDestinationState(int) is called
	int[][] weightsIndices; // contains indices into CRF.weights[],
	String[] labels;           // output label emitted by each transition
	CRF crf;                   // owning CRF; the actual parameters live there
	// No arg constructor so serialization works
	protected State() {
		super ();
	}
	/**
	 * Constructs a state and registers its weight sets with the owning CRF.
	 * Transition i goes to destinationNames[i], emits labelNames[i], and uses
	 * the weight sets named in weightNames[i] (created in the CRF if new).
	 */
	protected State (String name, int index,
			double initialWeight, double finalWeight,
			String[] destinationNames,
			String[] labelNames,
			String[][] weightNames,
			CRF crf)
	{
		super ();
		assert (destinationNames.length == labelNames.length);
		assert (destinationNames.length == weightNames.length);
		this.name = name;
		this.index = index;
		// Note: setting these parameters here is actually redundant; they were set already in CRF.addState(...)
		// I'm considering removing initialWeight and finalWeight as arguments to this constructor, but need to think more -akm 12/2007
		// If CRF.State were non-static, then this constructor could add the state to the list of states, and put it in the name2state also.
		crf.parameters.initialWeights[index] = initialWeight;
		crf.parameters.finalWeights[index] = finalWeight;
		this.destinationNames = new String[destinationNames.length];
		this.destinations = new State[labelNames.length]; // resolved lazily in getDestinationState
		this.weightsIndices = new int[labelNames.length][];
		this.labels = new String[labelNames.length];
		this.crf = crf;
		for (int i = 0; i < labelNames.length; i++) {
			// Make sure this label appears in our output Alphabet
			crf.outputAlphabet.lookupIndex (labelNames[i]);
			this.destinationNames[i] = destinationNames[i];
			this.labels[i] = labelNames[i];
			this.weightsIndices[i] = new int[weightNames[i].length];
			for (int j = 0; j < weightNames[i].length; j++)
				this.weightsIndices[i][j] = crf.getWeightsIndex (weightNames[i][j]);
		}
		crf.weightsStructureChanged();
	}
public Transducer getTransducer () { return crf; }
public double getInitialWeight () { return crf.parameters.initialWeights[index]; }
public void setInitialWeight (double c) { crf.parameters.initialWeights[index]= c; }
public double getFinalWeight () { return crf.parameters.finalWeights[index]; }
public void setFinalWeight (double c) { crf.parameters.finalWeights[index] = c; }
public void print ()
{
System.out.println ("State #"+index+" \""+name+"\"");
System.out.println ("initialWeight="+crf.parameters.initialWeights[index]+", finalWeight="+crf.parameters.finalWeights[index]);
System.out.println ("#destinations="+destinations.length);
for (int i = 0; i < destinations.length; i++)
System.out.println ("-> "+destinationNames[i]);
}
public int numDestinations () { return destinations.length;}
public String[] getWeightNames (int index) {
int[] indices = this.weightsIndices[index];
String[] ret = new String[indices.length];
for (int i=0; i < ret.length; i++)
ret[i] = crf.parameters.weightAlphabet.lookupObject(indices[i]).toString();
return ret;
}
/** Adds the named weight set to the outgoing transition numbered {@code didx},
 * creating the weight set in the CRF if it does not yet exist. */
public void addWeight (int didx, String weightName) {
	int newWeightsIndex = crf.getWeightsIndex (weightName);
	weightsIndices[didx] = ArrayUtils.append (weightsIndices[didx], newWeightsIndex);
}
/** Returns the output label emitted on the outgoing transition numbered {@code index}. */
public String getLabelName (int index) {
	return labels[index];
}
/** Returns the destination state of the outgoing transition numbered {@code index}.
 * Destinations are resolved lazily by name and cached, since states may be
 * named before they are added to the CRF. */
public State getDestinationState (int index)
{
	State dest = destinations[index];
	if (dest != null)
		return dest;
	dest = crf.name2state.get (destinationNames[index]);
	destinations[index] = dest;
	if (dest == null)
		throw new IllegalArgumentException ("this.name="+this.name+" index="+index+" destinationNames[index]="+destinationNames[index]+" name2state.size()="+ crf.name2state.size());
	return dest;
}
/** Returns an iterator over the transitions leaving this state at the given
 * input position, optionally constrained to the given output label.
 * @throws UnsupportedOperationException if either position is negative
 *         (epsilon transitions are not supported) or the input is null
 *         (a CRF always conditions on an input sequence). */
public Transducer.TransitionIterator transitionIterator (Sequence inputSequence, int inputPosition,
		Sequence outputSequence, int outputPosition)
{
	if (inputPosition < 0 || outputPosition < 0)
		throw new UnsupportedOperationException ("Epsilon transitions not implemented.");
	if (inputSequence == null)
		throw new UnsupportedOperationException ("CRFs are not generative models; must have an input sequence.");
	String output = (outputSequence == null) ? null : (String) outputSequence.get (outputPosition);
	return new TransitionIterator (this, (FeatureVectorSequence) inputSequence, inputPosition, output, crf);
}
/** Returns an iterator over the transitions leaving this state for a single
 * feature vector, optionally constrained to the given output label. */
public Transducer.TransitionIterator transitionIterator (FeatureVector fv, String output)
{
	return new TransitionIterator (this, fv, output, crf);
}
/** Returns this state's name. */
public String getName ()
{
	return name;
}
/** Returns this state's index within the CRF.
 * Declared {@code final} so the call can be inlined inside incrementTransition. */
public final int getIndex ()
{
	return index;
}
// Serialization for class State.
// serialVersionUID is fixed at 1; format evolution is handled instead by the
// explicit version tag written/read below (CURRENT_SERIAL_VERSION).
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 0;
// Custom serialization: the field order written here is the wire format and
// must stay in lock-step with readObject below.
private void writeObject (ObjectOutputStream out) throws IOException {
	out.writeInt (CURRENT_SERIAL_VERSION); // version tag, checked (read) on load
	out.writeObject(name);
	out.writeInt(index);
	out.writeObject(destinationNames);
	out.writeObject(destinations);
	out.writeObject(weightsIndices);
	out.writeObject(labels);
	out.writeObject(crf);
}
// Custom deserialization: reads fields in exactly the order writeObject wrote them.
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
	in.readInt (); // serial version tag; currently unused (only version 0 exists)
	name = (String) in.readObject();
	index = in.readInt();
	destinationNames = (String[]) in.readObject();
	destinations = (CRF.State[]) in.readObject();
	weightsIndices = (int[][]) in.readObject();
	labels = (String[]) in.readObject();
	crf = (CRF) in.readObject();
}
}
/** Iterator over the outgoing transitions of a CRF {@link State} for one
 * input feature vector.  All transition weights are computed eagerly in the
 * constructor; iteration then skips transitions whose weight is
 * IMPOSSIBLE_WEIGHT (those whose label does not match the requested output). */
protected static class TransitionIterator extends Transducer.TransitionIterator implements Serializable
{
	State source;
	int index, nextIndex;
	// One weight per destination of `source`; IMPOSSIBLE_WEIGHT marks
	// transitions filtered out by the output-label constraint.
	protected double[] weights;
	FeatureVector input;
	CRF crf;
	public TransitionIterator (State source,
			FeatureVectorSequence inputSeq,
			int inputPosition,
			String output, CRF crf)
	{
		this (source, inputSeq.get(inputPosition), output, crf);
	}
	/** Computes the weight of every outgoing transition as the dot product of
	 * the input features with each attached weight set, plus each set's
	 * default weight.  If {@code output} is non-null, transitions whose label
	 * does not equal it are given IMPOSSIBLE_WEIGHT. */
	protected TransitionIterator (State source,
			FeatureVector fv,
			String output, CRF crf)
	{
		this.source = source;
		this.crf = crf;
		this.input = fv;
		this.weights = new double[source.destinations.length];
		int nwi, swi;
		for (int transIndex = 0; transIndex < source.destinations.length; transIndex++) {
			// xxx Or do we want output.equals(...) here?
			if (output == null || output.equals(source.labels[transIndex])) {
				// Here is the dot product of the feature weights with the lambda weights
				// for one transition
				weights[transIndex] = 0;
				nwi = source.weightsIndices[transIndex].length;
				for (int wi = 0; wi < nwi; wi++) {
					swi = source.weightsIndices[transIndex][wi];
					weights[transIndex] += (crf.parameters.weights[swi].dotProduct (fv)
							// include with implicit weight 1.0 the default feature
							+ crf.parameters.defaultWeights[swi]);
				}
				assert (!Double.isNaN(weights[transIndex]));
				assert (weights[transIndex] != Double.POSITIVE_INFINITY);
			}
			else
				weights[transIndex] = IMPOSSIBLE_WEIGHT;
		}
		// Prepare nextIndex, pointing at the next non-impossible transition
		nextIndex = 0;
		while (nextIndex < source.destinations.length && weights[nextIndex] == IMPOSSIBLE_WEIGHT)
			nextIndex++;
	}
	public boolean hasNext () { return nextIndex < source.destinations.length; }
	/** Advances to the next possible transition and returns its destination state. */
	public Transducer.State nextState ()
	{
		assert (nextIndex < source.destinations.length);
		index = nextIndex;
		nextIndex++;
		// Skip ahead over transitions ruled out by the output constraint.
		while (nextIndex < source.destinations.length && weights[nextIndex] == IMPOSSIBLE_WEIGHT)
			nextIndex++;
		return source.getDestinationState (index);
	}
	// These "final"s are just to try to make this more efficient. Perhaps some of them will have to go away
	public final int getIndex () { return index; }
	public final Object getInput () { return input; }
	public final Object getOutput () { return source.labels[index]; }
	public final double getWeight () { return weights[index]; }
	public final Transducer.State getSourceState () { return source; }
	public final Transducer.State getDestinationState () { return source.getDestinationState (index); }
	// Serialization for TransitionIterator.
	// The write/read order below is the wire format; keep the two methods in lock-step.
	private static final long serialVersionUID = 1;
	private static final int CURRENT_SERIAL_VERSION = 0;
	private void writeObject (ObjectOutputStream out) throws IOException {
		out.writeInt (CURRENT_SERIAL_VERSION);
		out.writeObject (source);
		out.writeInt (index);
		out.writeInt (nextIndex);
		out.writeObject(weights);
		out.writeObject (input);
		out.writeObject(crf);
	}
	private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
		in.readInt (); // serial version tag; currently unused (only version 0 exists)
		source = (State) in.readObject();
		index = in.readInt ();
		nextIndex = in.readInt ();
		weights = (double[]) in.readObject();
		input = (FeatureVector) in.readObject();
		crf = (CRF) in.readObject();
	}
	/** Returns an HTML description of the current transition: its total weight,
	 * then for each attached weight set its dot product, default weight, and
	 * the per-feature contributions ranked by absolute value (features whose
	 * absolute contribution falls below {@code cutoff} are omitted).
	 * Any exception during formatting is caught and reported in the output
	 * rather than propagated, since this method is diagnostic only. */
	public String describeTransition (double cutoff)
	{
		DecimalFormat f = new DecimalFormat ("0.###");
		// StringBuilder (not StringBuffer): purely local, no synchronization needed.
		StringBuilder buf = new StringBuilder ();
		buf.append ("Value: " + f.format (-getWeight ()) + " <br />\n");
		try {
			int[] theseWeights = source.weightsIndices[index];
			for (int i = 0; i < theseWeights.length; i++) {
				int wi = theseWeights[i];
				SparseVector w = crf.parameters.weights[wi];
				buf.append ("WEIGHTS <br />\n" + crf.parameters.weightAlphabet.lookupObject (wi) + "<br />\n");
				buf.append ("  d.p. = "+f.format (w.dotProduct (input))+"<br />\n");
				double[] vals = new double[input.numLocations ()];
				double[] absVals = new double[input.numLocations ()];
				for (int k = 0; k < vals.length; k++) {
					// Renamed from `index` to avoid shadowing the field of the same name.
					int featureIndex = input.indexAtLocation (k);
					vals[k] = w.value (featureIndex) * input.value (featureIndex);
					absVals[k] = Math.abs (vals[k]);
				}
				buf.append ("DEFAULT " + f.format (crf.parameters.defaultWeights[wi]) + "<br />\n");
				RankedFeatureVector rfv = new RankedFeatureVector (crf.inputAlphabet, input.getIndices (), absVals);
				for (int rank = 0; rank < absVals.length; rank++) {
					// NOTE(review): getIndexAtRank is used below both as an index into
					// the location-indexed arrays vals/absVals and as an argument to
					// input.indexAtLocation — this only lines up if rfv ranks by
					// location rather than alphabet index; verify against
					// RankedFeatureVector before relying on this output.
					int fidx = rfv.getIndexAtRank (rank);
					Object fname = crf.inputAlphabet.lookupObject (input.indexAtLocation (fidx));
					if (absVals[fidx] < cutoff) break; // Break looping over features
					if (vals[fidx] != 0) {
						buf.append (fname + " " + f.format (vals[fidx]) + "<br />\n");
					}
				}
			}
		} catch (Exception e) {
			// Deliberate best-effort: report the failure in the description
			// instead of letting a formatting error abort the caller.
			System.err.println ("Error writing transition descriptions.");
			e.printStackTrace ();
			buf.append ("ERROR WHILE WRITING OUTPUT...\n");
		}
		return buf.toString ();
	}
}
}