/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.learning;
import cc.mallet.grmm.types.*;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.SparseVector;
import cc.mallet.util.MalletLogger;
import cc.mallet.grmm.util.CachingOptimizable;
import gnu.trove.THashMap;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Logger;
/**
* Created: Mar 15, 2005
*
* @author <A HREF="mailto:casutton@cs.umass.edu">casutton@cs.umass.edu</A>
* @version $Id: PseudolikelihoodACRFTrainer.java,v 1.1 2007/10/22 21:37:40 mccallum Exp $
*/
public class PseudolikelihoodACRFTrainer extends DefaultAcrfTrainer {
/** Logger used by this trainer and by its nested {@code Maxable} objective. */
private static final Logger logger = MalletLogger.getLogger (PseudolikelihoodACRFTrainer.class.getName());
/** Compile-time debugging switch: when true, computeValueGradient prints every gradient component to stdout. */
private static final boolean printGradient = false;
/** Use per-variable pseudolikelihood. This is the classical version of Besag. */
public static final int BY_VARIABLE = 0;
/** Use per-edge structured pseudolikelihood. */
public static final int BY_EDGE = 1;
/** Selects which CliquesIterator makeCliquesIterator builds; defaults to {@link #BY_VARIABLE}. */
private int structureType = BY_VARIABLE;
/**
 * Returns the pseudolikelihood structure currently selected for training.
 *
 * @return the value last given to {@link #setStructureType(int)}, or
 *         {@link #BY_VARIABLE} if it was never called
 */
public int getStructureType() {
  return this.structureType;
}
/**
 * Selects the pseudolikelihood structure used when iterating over local
 * conditionals.  The value is not validated here; an unrecognized value is
 * rejected only when a cliques iterator is later requested.
 *
 * @param structureType {@link #BY_VARIABLE} or {@link #BY_EDGE}
 */
public void setStructureType(int structureType) {
  this.structureType = structureType;
}
/**
 * Builds the pseudolikelihood objective for the given model and data.
 *
 * @param acrf     model whose template weights will be optimized
 * @param training instances used to collect constraints and expectations
 * @return a new {@link Maxable} bound to this trainer
 */
public Optimizable.ByGradientValue createOptimizable(ACRF acrf, InstanceList training) {
  Maxable objective = new Maxable(acrf, training);
  return objective;
}
// Controls the structuredness of pl.
private static interface CliquesIterator {
boolean hasNext ();
void advance ();
Factor localConditional ();
ACRF.UnrolledVarSet[] cliques ();
}
/**
 * Per-variable cursor: visits the variables of an unrolled graph in index
 * order.  For each variable it multiplies together the factors of every
 * clique containing that variable, each factor sliced at the observed values
 * of the clique's other variables.
 */
private static class VariablesIterator implements CliquesIterator {

  private final ACRF.UnrolledGraph graph;
  private final Assignment observed;

  /** For each graph variable index, the unrolled cliques that contain that variable. */
  private final List[] cliquesByVar;

  // cursor state: index of the current variable (-1 before the first advance) and its factor
  private int vidx = -1;
  private Factor ptl;

  public VariablesIterator(ACRF.UnrolledGraph acrf, Assignment observed) {
    this.graph = acrf;
    this.observed = observed;

    this.cliquesByVar = new List[graph.numVariables()];
    for (int slot = 0; slot < cliquesByVar.length; slot++) {
      cliquesByVar[slot] = new ArrayList();
    }

    // Index every unrolled clique under each of its member variables.
    Iterator cliqueIt = acrf.unrolledVarSetIterator();
    while (cliqueIt.hasNext()) {
      ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) cliqueIt.next();
      for (int pos = 0; pos < clique.size(); pos++) {
        Variable member = clique.get(pos);
        cliquesByVar[graph.getIndex(member)].add(clique);
      }
    }
  }

  public boolean hasNext() {
    return vidx < graph.numVariables() - 1;
  }

  public void advance() {
    vidx++;
    Variable var = graph.get(vidx);
    ptl = new TableFactor(var);

    Iterator cliqueIt = cliquesByVar[vidx].iterator();
    while (cliqueIt.hasNext()) {
      ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) cliqueIt.next();
      Factor cliquePtl = graph.factorOf(clique);
      if (cliquePtl == null) {
        throw new IllegalStateException("Could not find potential for clique " + clique);
      }

      // Fix every variable of the clique except the current one to its observed value.
      VarSet neighbors = new HashVarSet(cliquePtl.varSet());
      neighbors.remove(var);
      Assignment nbrAssn = (Assignment) observed.marginalize(neighbors);
      ptl.multiplyBy(cliquePtl.slice(nbrAssn));
    }
  }

  public Factor localConditional() {
    return ptl;
  }

  public ACRF.UnrolledVarSet[] cliques() {
    List current = cliquesByVar[vidx];
    ACRF.UnrolledVarSet[] result = new ACRF.UnrolledVarSet[current.size()];
    return (ACRF.UnrolledVarSet[]) current.toArray(result);
  }
}
/**
 * Per-edge cursor.  At construction, every factor that shares a variable with
 * some unrolled clique becomes a key, mapped to the list of such cliques.
 * Each {@link #advance()} takes the next key factor, treats its first two
 * variables as the free pair, and multiplies together the factors of the
 * associated cliques with all other variables fixed to their observed values.
 */
private static class EdgesIterator implements CliquesIterator {

  private final ACRF.UnrolledGraph graph;
  private final Assignment observed;

  /** Factor => List of ACRF.UnrolledVarSet sharing at least one variable with it. */
  private THashMap cliquesByEdge;

  // cursor state
  private Iterator cursor;
  private List currentCliqueList;
  private Factor ptl;

  public EdgesIterator(ACRF.UnrolledGraph acrf, Assignment observed) {
    this.graph = acrf;
    this.observed = observed;
    this.cliquesByEdge = new THashMap();

    Iterator cliqueIt = acrf.unrolledVarSetIterator();
    while (cliqueIt.hasNext()) {
      ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) cliqueIt.next();
      for (int pos = 0; pos < clique.size(); pos++) {
        Variable member = clique.get(pos);
        Iterator factorIt = graph.allFactorsContaining(member).iterator();
        while (factorIt.hasNext()) {
          Factor factor = (Factor) factorIt.next();
          List members = (List) cliquesByEdge.get(factor);
          if (members == null) {
            members = new ArrayList();
            cliquesByEdge.put(factor, members);
          }
          // A clique can reach the same factor through several variables; record it once.
          if (!members.contains(clique)) {
            members.add(clique);
          }
        }
      }
    }

    this.cursor = cliquesByEdge.keySet().iterator();
  }

  public boolean hasNext() {
    return cursor.hasNext();
  }

  public void advance() {
    Factor pairFactor = (Factor) cursor.next();
    VarSet pairVarSet = pairFactor.varSet();
    assert pairVarSet.size() == 2; // for now

    Variable v1 = pairVarSet.get(0);
    Variable v2 = pairVarSet.get(1);
    ptl = new TableFactor(new Variable[] { v1, v2 });

    // localObs: the observed assignment restricted to everything EXCEPT v1 and v2
    VarSet fixedVars = new HashVarSet(observed.varSet());
    fixedVars.remove(v1);
    fixedVars.remove(v2);
    Assignment localObs = (Assignment) observed.marginalize(fixedVars);

    currentCliqueList = (List) cliquesByEdge.get(pairFactor);
    Iterator cliqueIt = currentCliqueList.iterator();
    while (cliqueIt.hasNext()) {
      ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) cliqueIt.next();
      Factor cliquePtl = graph.factorOf(clique);
      if (cliquePtl == null) {
        throw new IllegalStateException("Could not find potential for clique " + clique);
      }

      boolean hasV1 = clique.contains(v1);
      boolean hasV2 = clique.contains(v2);
      if (!hasV1 && !hasV2) {
        throw new RuntimeException("Illegal state: cliqu ehas neither edge variable");
      }

      Factor slice;
      if (hasV1 && hasV2 && cliquePtl.varSet().size() == 2) {
        // fast special case: the factor is over exactly the free pair, nothing to fix
        slice = cliquePtl;
      } else {
        slice = cliquePtl.slice(localObs);
      }
      ptl.multiplyBy(slice);
    }
  }

  public Factor localConditional() {
    return ptl;
  }

  public ACRF.UnrolledVarSet[] cliques() {
    List current = currentCliqueList;
    ACRF.UnrolledVarSet[] result = new ACRF.UnrolledVarSet[current.size()];
    return (ACRF.UnrolledVarSet[]) current.toArray(result);
  }
}
/**
 * Creates the cursor matching the configured structure type.
 *
 * @param acrf     unrolled graph for one instance
 * @param observed observed assignment to the graph's variables
 * @return a fresh iterator positioned before the first local term
 * @throws IllegalArgumentException if the structure type is neither
 *         {@link #BY_VARIABLE} nor {@link #BY_EDGE}
 */
private CliquesIterator makeCliquesIterator(ACRF.UnrolledGraph acrf, Assignment observed) {
  switch (structureType) {
    case BY_VARIABLE:
      return new VariablesIterator(acrf, observed);
    case BY_EDGE:
      return new EdgesIterator(acrf, observed);
    default:
      throw new IllegalArgumentException("Unknown structured pseudolikelihood type " + structureType);
  }
}
public class Maxable extends CachingOptimizable.ByGradient implements Serializable {
// Model passed to the constructor; used in this class for debug printing and feature-name lookups.
private ACRF acrf;
// Training instances; unrolled afresh on every computeValue call.
InstanceList trainData;
// acrf.getTemplates(): the templates whose weights form the parameter vector.
private ACRF.Template[] templates;
// acrf.getFixedTemplates(): passed through when unrolling graphs in computeValue.
private ACRF.Template[] fixedTmpls;
// Indices of instances whose value was infinite on the first computeValue call; null until that call.
protected BitSet infiniteValues = null;
// Sum of the counts returned by each template's initWeights (accumulated in initWeights).
private int numParameters;
// Initial value of gaussianPriorVariance.
private static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 10.0;
/**
 * Returns the variance of the Gaussian prior placed on the weights.
 *
 * @return the current prior variance
 */
public double getGaussianPriorVariance() {
  return this.gaussianPriorVariance;
}
/**
 * Sets the variance of the Gaussian prior placed on the weights.
 *
 * Both the objective value and its gradient depend on this variance, so any
 * cached value/gradient is marked stale; otherwise a later query could return
 * results computed with the previous variance.
 *
 * @param gaussianPriorVariance new prior variance
 */
public void setGaussianPriorVariance(double gaussianPriorVariance) {
  this.gaussianPriorVariance = gaussianPriorVariance;
  cachedValueStale = cachedGradientStale = true;
}
// Variance of the Gaussian prior; read by computeValue and computeValueGradient.
private double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
/* Vectors that contain the counts of features observed in the
training data. Maps
(clique-template x feature-number) => count.
Indexed [template index][assignment index]; filled once by
collectConstraintsForGraph.
*/
SparseVector constraints[][];
/* Vectors that contain the expected value, over the labels, of
* all the features, having seen the training data (but not the
* training labels).  Same indexing as constraints; zeroed and
* refilled on every computeValue call.
*/
SparseVector expectations[][];
// Per-template counterparts of constraints/expectations for the templates' default weights.
SparseVector defaultConstraints[];
SparseVector defaultExpectations[];
/**
 * Asks every template to initialize its weights from the training data and
 * adds the counts the templates report to {@code numParameters}.
 *
 * This duplicates logic that arguably belongs in ACRF: this objective does
 * not extend ACRF's own optimizable, so it cannot reuse that initWeights().
 *
 * @param training instances handed to each template's initWeights
 */
private void initWeights(InstanceList training) {
  int t = 0;
  while (t < templates.length) {
    ACRF.Template tmpl = templates[t];
    numParameters += tmpl.initWeights(training);
    t++;
  }
}
/**
 * Allocates constraints[][], expectations[][] and their default-weight
 * counterparts as zeroed clones of the templates' weight vectors, so that
 * they have the same dimensions as the weights.
 */
private void initConstraintsExpectations() {
  final int numTemplates = templates.length;

  // Default-weight statistics first: one vector per template.
  defaultConstraints = new SparseVector[numTemplates];
  defaultExpectations = new SparseVector[numTemplates];
  for (int t = 0; t < numTemplates; t++) {
    SparseVector defaultWeights = templates[t].getDefaultWeights();
    defaultConstraints[t] = (SparseVector) defaultWeights.cloneMatrixZeroed();
    defaultExpectations[t] = (SparseVector) defaultWeights.cloneMatrixZeroed();
  }

  // Then the per-assignment statistics: one vector per (template, assignment).
  constraints = new SparseVector[numTemplates][];
  expectations = new SparseVector[numTemplates][];
  for (int t = 0; t < numTemplates; t++) {
    SparseVector[] weights = templates[t].getWeights();
    SparseVector[] zeroedConstraints = new SparseVector[weights.length];
    SparseVector[] zeroedExpectations = new SparseVector[weights.length];
    constraints[t] = zeroedConstraints;
    expectations[t] = zeroedExpectations;
    for (int a = 0; a < weights.length; a++) {
      zeroedConstraints[a] = (SparseVector) weights[a].cloneMatrixZeroed();
      zeroedExpectations[a] = (SparseVector) weights[a].cloneMatrixZeroed();
    }
  }
}
/**
 * Zeroes every expectation vector (default and per-assignment) in place.
 * Must only be called after the vectors have been allocated.
 */
void resetExpectations() {
  for (int t = 0; t < expectations.length; t++) {
    defaultExpectations[t].setAll(0.0);
    SparseVector[] perAssignment = expectations[t];
    for (int a = 0; a < perAssignment.length; a++) {
      perAssignment[a].setAll(0.0);
    }
  }
}
/**
 * Creates the objective: initializes template weights from the data,
 * allocates the constraint/expectation vectors, and collects the constraints
 * from the training instances.
 *
 * @param acrf  model supplying the templates and fixed templates
 * @param ilist training instances
 */
protected Maxable(ACRF acrf, InstanceList ilist) {
  logger.finest("Initializing OptimizableACRF.");

  this.acrf = acrf;
  this.templates = acrf.getTemplates();
  this.fixedTmpls = acrf.getFixedTemplates();
  this.trainData = ilist;

  /* allocate for weights, constraints and expectations */
  initWeights(this.trainData);
  initConstraintsExpectations();

  // Nothing has been computed yet.
  cachedValueStale = true;
  cachedGradientStale = true;

  logger.info("Number of training instances = " + this.trainData.size());
  logger.info("Number of parameters = " + numParameters);
  describePrior();

  logger.fine("Computing constraints");
  collectConstraints(this.trainData);
}
/** Logs the variance of the Gaussian prior currently in effect. */
private void describePrior() {
  String message = "Using gaussian prior with variance " + gaussianPriorVariance;
  logger.info(message);
}
/**
 * Returns the length of the parameter vector.
 *
 * @return the total reported by the templates when their weights were initialized
 */
public int getNumParameters() {
  return this.numParameters;
}
/**
 * Copies the current parameters into {@code buf}.  Layout: the default-weight
 * values of every template, in template order, followed by the per-assignment
 * weight values of every template, in template then assignment order.
 *
 * @param buf destination; its length must equal {@link #getNumParameters()}
 * @throws IllegalArgumentException if {@code buf} has the wrong length
 */
public void getParameters(double[] buf) {
  if (buf.length != numParameters) {
    throw new IllegalArgumentException("Argument is not of the " + " correct dimensions");
  }

  int offset = 0;

  // Section 1: default weights.
  for (int t = 0; t < templates.length; t++) {
    double[] src = templates[t].getDefaultWeights().getValues();
    System.arraycopy(src, 0, buf, offset, src.length);
    offset += src.length;
  }

  // Section 2: per-assignment weights.
  for (int t = 0; t < templates.length; t++) {
    SparseVector[] weights = templates[t].getWeights();
    for (int a = 0; a < weights.length; a++) {
      double[] src = weights[a].getValues();
      System.arraycopy(src, 0, buf, offset, src.length);
      offset += src.length;
    }
  }
}
/**
 * Writes {@code params} back into the templates' weight vectors, using the
 * same layout that getParameters produces (all default weights first, then
 * all per-assignment weights), and marks the cached value and gradient stale.
 *
 * @param params new parameter values
 */
protected void setParametersInternal(double[] params) {
  cachedValueStale = true;
  cachedGradientStale = true;

  int offset = 0;

  // Section 1: default weights.
  for (int t = 0; t < templates.length; t++) {
    double[] dest = templates[t].getDefaultWeights().getValues();
    System.arraycopy(params, offset, dest, 0, dest.length);
    offset += dest.length;
  }

  // Section 2: per-assignment weights.
  for (int t = 0; t < templates.length; t++) {
    SparseVector[] weights = templates[t].getWeights();
    for (int a = 0; a < weights.length; a++) {
      double[] dest = weights[a].getValues();
      System.arraycopy(params, offset, dest, 0, dest.length);
      offset += dest.length;
    }
  }
}
// Functions for unit tests to get constraints and expectations.
// No deep copy is made, so callers should not modify the result.
/**
 * Returns the live expectation vectors of one template.
 *
 * @param cnum template index
 * @return the internal array, one vector per assignment; do not modify
 */
public SparseVector[] getExpectations(int cnum) {
  SparseVector[] forTemplate = expectations[cnum];
  return forTemplate;
}
/**
 * Returns the live constraint vectors of one template (intended for unit tests).
 *
 * @param cnum template index
 * @return the internal array, one vector per assignment; do not modify
 */
public SparseVector[] getConstraints(int cnum) {
  SparseVector[] forTemplate = constraints[cnum];
  return forTemplate;
}
/** Prints all parameters to stdout on one line, each followed by a tab. */
public void printParameters() {
  double[] params = new double[numParameters];
  getParameters(params);
  for (int i = 0; i < params.length; i++) {
    System.out.print(params[i] + "\t");
  }
  System.out.println();
}
/**
 * Computes the penalized log pseudolikelihood of the training data and, as a
 * side effect, refills the expectation vectors that computeValueGradient
 * reads.
 *
 * The Gaussian prior penalty {@code -w^2 / (2 * variance)} is applied to the
 * default weights as well as to the per-assignment weights, so that the value
 * agrees with the gradient, which subtracts {@code w / variance} for both.
 *
 * @return the objective value, or {@code Double.NEGATIVE_INFINITY} if an
 *         instance yields NaN, or yields an infinite value after the first call
 */
protected double computeValue () {
  double retval = 0.0;
  int numInstances = trainData.size();

  long start = System.currentTimeMillis();
  long unrollTime = 0;

  /* Instance values must either always or never be included in
   * the total values; we can't just sometimes skip a value
   * because it is infinite, that throws off the total values.
   * We only allow an instance to have infinite value if it happens
   * from the start.  If any other instance has infinite value after
   * that, -infinity is returned. */
  // NOTE(review): a flagged instance is still unrolled on later calls, its expectations
  // are still accumulated, and its value is added again if it becomes finite -- confirm
  // whether that is intended.
  boolean initializingInfiniteValues = false;
  if (infiniteValues == null) {
    /* We could initialize bitset with one slot for every
     * instance, but it is *probably* cheaper not to, taking the
     * time hit to allocate the space if a bit becomes
     * necessary. */
    infiniteValues = new BitSet ();
    initializingInfiniteValues = true;
  }

  /* Clear the sufficient statistics that we are about to fill */
  resetExpectations();

  /* Fill in expectations for each instance */
  for (int i = 0; i < numInstances; i++) {
    Instance instance = trainData.get(i);

    long unrollStart = System.currentTimeMillis ();
    ACRF.UnrolledGraph unrolled = new ACRF.UnrolledGraph (instance, templates, fixedTmpls);
    long unrollEnd = System.currentTimeMillis ();
    unrollTime += (unrollEnd - unrollStart);

    if (unrolled.numVariables () == 0) continue;   // Happens if all nodes are pruned.

    /* Save the expected value of each feature for when we
       compute the gradient. */
    Assignment observations = unrolled.getAssignment ();
    double value = collectExpectationsAndValue (unrolled, observations);

    if (Double.isInfinite(value)) {
      if (initializingInfiniteValues) {
        logger.warning ("Instance " + instance.getName() +
                        " has infinite value; skipping.");
        infiniteValues.set (i);
        continue;
      } else if (!infiniteValues.get(i)) {
        // Fixed: the message previously ran the instance name straight into "returning".
        logger.warning ("Infinite value on instance "+instance.getName()+
                        "; returning -infinity");
        return Double.NEGATIVE_INFINITY;
      }
      // Otherwise the instance was already flagged on the first call: leave it out of the total.
    } else if (Double.isNaN (value)) {
      System.out.println("NaN on instance "+i+" : "+instance.getName ());
      printDebugInfo (unrolled);
      logger.warning ("Value is NaN in ACRF.getValue() Instance "+i+" : "+
                      "returning -infinity... ");
      return Double.NEGATIVE_INFINITY;
    } else {
      retval += value;
    }
  }

  /* Incorporate Gaussian prior on parameters. This means
     that for each weight, we will subtract w^2 / (2 * variance) from the
     log probability. */
  double priorDenom = 2 * gaussianPriorVariance;

  for (int tidx = 0; tidx < templates.length; tidx++) {
    // Default weights: the gradient regularizes these too, so the value must as well.
    SparseVector defaults = templates [tidx].getDefaultWeights ();
    for (int loc = 0; loc < defaults.numLocations(); loc++) {
      double w = defaults.valueAtLocation (loc);
      if (weightValid (w, tidx, defaults.indexAtLocation (loc))) {
        retval += -w*w/priorDenom;
      }
    }

    // Per-assignment weights.
    SparseVector[] weights = templates [tidx].getWeights ();
    for (int j = 0; j < weights.length; j++) {
      for (int fnum = 0; fnum < weights[j].numLocations(); fnum++) {
        double w = weights [j].valueAtLocation (fnum);
        if (weightValid (w, tidx, j)) {
          retval += -w*w/priorDenom;
        }
      }
    }
  }

  long end = System.currentTimeMillis ();
  logger.info ("ACRF Inference time (ms) = "+(end-start));
  logger.info ("ACRF unroll time (ms) = "+unrollTime);
  logger.info ("getValue (loglikelihood) = "+retval);

  return retval;
}
/**
 * Computes the gradient of the penalized log pseudolikelihood and writes it
 * into {@code grad}, in the same layout as getParameters (all default weights
 * first, then all per-assignment weights).
 *
 * Each component is
 *   constraint - expectation - parameter/gaussianPriorVariance
 * using the expectations left behind by the most recent computeValue call.
 *
 * @param grad destination array; must hold at least getNumParameters() entries
 */
protected void computeValueGradient (double[] grad)
{
  /* Index into current element of grad[]. */
  int gidx = 0;

  // First do gradient wrt defaultWeights
  for (int tidx = 0; tidx < templates.length; tidx++) {
    SparseVector theseWeights = templates[tidx].getDefaultWeights ();
    SparseVector theseConstraints = defaultConstraints [tidx];
    SparseVector theseExpectations = defaultExpectations [tidx];
    for (int j = 0; j < theseWeights.numLocations(); j++) {
      double weight = theseWeights.valueAtLocation (j);
      double constraint = theseConstraints.valueAtLocation (j);
      double expectation = theseExpectations.valueAtLocation (j);
      if (printGradient) {
        System.out.println(" gradient ["+gidx+"] = "+constraint+" (ctr) - "+expectation+" (exp) - "+
                           (weight / gaussianPriorVariance)+" (reg) [feature=DEFAULT]");
      }
      grad [gidx++] = constraint - expectation - (weight / gaussianPriorVariance);
    }
  }

  // Now do other weights
  for (int tidx = 0; tidx < templates.length; tidx++) {
    ACRF.Template tmpl = templates [tidx];
    SparseVector[] weights = tmpl.getWeights ();
    for (int i = 0; i < weights.length; i++) {
      SparseVector thisWeightVec = weights [i];
      SparseVector thisConstraintVec = constraints [tidx][i];
      SparseVector thisExpectationVec = expectations [tidx][i];

      for (int j = 0; j < thisWeightVec.numLocations(); j++) {
        double w = thisWeightVec.valueAtLocation (j);
        double gradient;  // Computed below
        double constraint = thisConstraintVec.valueAtLocation(j);
        double expectation = thisExpectationVec.valueAtLocation(j);

        /* A parameter may be set to -infinity by an external user.
         * We set gradient to 0 because the parameter's value can
         * never change anyway and it will mess up future calculations
         * on the matrix. */
        if (Double.isInfinite(w)) {
          // j is a location within the sparse vector; the alphabet is keyed by the
          // feature index stored at that location, so translate before looking up.
          int featureIdx = thisWeightVec.indexAtLocation (j);
          logger.warning("Infinite weight for node index " +i+
                         " feature " +
                         acrf.getInputAlphabet().lookupObject(featureIdx) );
          gradient = 0.0;
        } else {
          gradient = constraint
                     - (w/gaussianPriorVariance)
                     - expectation;
        }

        if (printGradient) {
          int idx = thisWeightVec.indexAtLocation (j);
          Object fname = acrf.getInputAlphabet ().lookupObject (idx);
          System.out.println(" gradient ["+gidx+"] = "+constraint+" (ctr) - "+expectation+" (exp) - "+
                             (w / gaussianPriorVariance)+" (reg) [feature="+fname+"]");
        }

        grad [gidx++] = gradient;
      }
    }
  }
}
/**
 * Walks every local conditional of one unrolled graph.  For each assignment
 * to the free variables of a local term, the normalized probability of that
 * assignment is added to the expectation vectors of every contributing
 * clique, with all other variables held at their observed values.
 *
 * expectations[t][a] receives the clique's feature vector scaled by that
 * probability, where t is the clique's template index and a is the number of
 * the clique's assignment; defaultExpectations[t] is incremented at a when it
 * has a location for a.
 *
 * @param unrolled     unrolled graph of one instance
 * @param observations observed assignment to the graph's variables
 * @return the sum over local terms of log-value(observations) minus log-normalizer
 */
private double collectExpectationsAndValue(ACRF.UnrolledGraph unrolled, Assignment observations) {
  double total = 0.0;

  CliquesIterator localIt = makeCliquesIterator(unrolled, observations);
  while (localIt.hasNext()) {
    localIt.advance();
    TableFactor local = (TableFactor) localIt.localConditional();
    double logZ = local.logsum();
    ACRF.UnrolledVarSet[] cliques = localIt.cliques();

    // Scratch copy of the observations, overwritten on the free variables below.
    Assignment scratch = (Assignment) observations.duplicate();

    // for each assignment to the local term's free variables
    // xxx SLOW this will need to be sparsified
    for (AssignmentIterator assnIt = local.assignmentIterator(); assnIt.hasNext(); assnIt.advance()) {
      double marginal = Math.exp(local.logValue(assnIt) - logZ);

      // Map the assignment to the twiddled variables onto full clique assignments.
      Assignment current = assnIt.assignment();
      for (int vi = 0; vi < current.numVariables(); vi++) {
        Variable var = current.getVariable(vi);
        scratch.setValue(0, var, current.get(var));
      }

      for (int c = 0; c < cliques.length; c++) {
        ACRF.UnrolledVarSet clique = cliques[c];
        int tidx = clique.getTemplate().index;
        if (tidx == -1) {
          continue;
        }

        int assnIdx = clique.lookupNumberOfAssignment(scratch);
        expectations[tidx][assnIdx].plusEqualsSparse(clique.getFv(), marginal);
        if (defaultExpectations[tidx].location(assnIdx) != -1) {
          defaultExpectations[tidx].incrementValue(assnIdx, marginal);
        }
      }
    }

    total += (local.logValue(observations) - logZ);
  }

  return total;
}
/**
 * Adds one graph's observed feature counts to the constraint vectors.  For
 * every local term and every clique contributing to it, the clique's feature
 * vector is added (weight 1.0) at the clique's observed assignment.
 *
 * @param unrolled     unrolled graph of one instance
 * @param observations observed assignment to the graph's variables
 */
private void collectConstraintsForGraph(ACRF.UnrolledGraph unrolled, Assignment observations) {
  CliquesIterator localIt = makeCliquesIterator(unrolled, observations);
  while (localIt.hasNext()) {
    localIt.advance();
    ACRF.UnrolledVarSet[] cliques = localIt.cliques();

    for (int c = 0; c < cliques.length; c++) {
      ACRF.UnrolledVarSet clique = cliques[c];
      int tidx = clique.getTemplate().index;
      if (tidx >= 0) {
        int assnIdx = clique.lookupNumberOfAssignment(observations);
        constraints[tidx][assnIdx].plusEqualsSparse(clique.getFv(), 1.0);
        if (defaultConstraints[tidx].location(assnIdx) != -1) {
          defaultConstraints[tidx].incrementValue(assnIdx, 1.0);
        }
      }
    }
  }
}
/**
 * Unrolls every instance of the list and accumulates its observed feature
 * counts into the constraint vectors.
 *
 * @param ilist instances to collect constraints from
 */
public void collectConstraints(InstanceList ilist) {
  int inum = 0;
  while (inum < ilist.size()) {
    logger.finest("*** Collecting constraints for instance " + inum);
    ACRF.UnrolledGraph unrolled = new ACRF.UnrolledGraph(ilist.get(inum), templates, null, true);
    collectConstraintsForGraph(unrolled, unrolled.getAssignment());
    inum++;
  }
}
/**
 * Writes the current gradient to a text file, one component per line.
 * Failures are reported on stderr rather than thrown (best-effort debugging aid).
 *
 * The gradient is computed before the file is opened, and the stream is
 * closed in a finally block so it cannot leak.  PrintStream never throws on
 * write failures, so its error flag is checked explicitly.
 *
 * @param fileName path of the file to create or overwrite
 */
void dumpGradientToFile (String fileName)
{
  PrintStream w = null;
  try {
    double[] grad = new double [getNumParameters ()];
    getValueGradient (grad);

    w = new PrintStream (new FileOutputStream (fileName));
    for (int i = 0; i < numParameters; i++) {
      w.println (grad[i]);
    }
    // checkError() flushes and reports any write failure PrintStream swallowed.
    if (w.checkError ()) {
      System.err.println("Error while writing gradient to "+fileName);
    }
  } catch (IOException e) {
    System.err.println("Could not open output file.");
    e.printStackTrace ();
  } finally {
    if (w != null) {
      w.close ();
    }
  }
}
/** Prints the default-weight constraints, then the default-weight expectations, of every template to stdout. */
void dumpDefaults() {
  String[] headings = { "Default constraints", "Default expectations" };
  SparseVector[][] groups = { defaultConstraints, defaultExpectations };

  for (int g = 0; g < groups.length; g++) {
    System.out.println(headings[g]);
    SparseVector[] perTemplate = groups[g];
    for (int t = 0; t < perTemplate.length; t++) {
      System.out.println("Template " + t);
      perTemplate[t].print();
    }
  }
}
/**
 * Debugging dump: prints the model to stderr, then for every clique of the
 * graph prints (to stdout) the clique, its variables' observed values, and
 * its factor together with the factor's value at the observed assignment.
 *
 * @param unrolled unrolled graph to describe
 */
void printDebugInfo(ACRF.UnrolledGraph unrolled) {
  acrf.print(System.err);

  Assignment observed = unrolled.getAssignment();
  Iterator cliqueIt = unrolled.varSetIterator();
  while (cliqueIt.hasNext()) {
    ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) cliqueIt.next();
    System.out.println("Clique " + clique);
    dumpAssnForClique(observed, clique);

    Factor cliqueFactor = unrolled.factorOf(clique);
    System.out.println("Value = " + cliqueFactor.value(observed));
    System.out.println(cliqueFactor);
  }
}
/**
 * Prints, for each variable of the clique, the object and the integer value
 * that the assignment gives it.
 *
 * @param assn   assignment to read values from
 * @param clique clique whose variables are listed
 */
void dumpAssnForClique(Assignment assn, ACRF.UnrolledVarSet clique) {
  Iterator varIt = clique.iterator();
  while (varIt.hasNext()) {
    Variable var = (Variable) varIt.next();
    String line = var + " ==> " + assn.getObject(var) + " (" + assn.get(var) + ")";
    System.out.println(line);
  }
}
/**
 * Tells whether a weight may take part in the prior computation, logging a
 * warning when it may not.
 *
 * @param w    weight value to check
 * @param cnum clique (template) index, used only in the log message
 * @param j    assignment index, used only in the log message
 * @return false if {@code w} is infinite or NaN, true otherwise
 */
private boolean weightValid (double w, int cnum, int j)
{
  if (Double.isInfinite (w)) {
    // Fixed: a separator was missing, so the clique number ran straight into "assignment".
    logger.warning ("Weight is infinite for clique "+cnum+" assignment "+j);
    return false;
  } else if (Double.isNaN (w)) {
    logger.warning ("Weight is Nan for clique "+cnum+" assignment "+j);
    return false;
  } else {
    return true;
  }
}
} // OptimizableACRF
}