/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://mallet.cs.umass.edu/
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.learning;
import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.types.*;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.*;
import cc.mallet.util.FileUtils;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import cc.mallet.util.Timing;
import cc.mallet.grmm.util.CachingOptimizable;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Logger;
/**
* Implementation of piecewise PL (Sutton and McCallum, 2007)
*
* NB The wrong-wrong options are for an extension that we tried that never quite worked
*
* Created: Mar 15, 2005
*
 * @author <A HREF="mailto:casutton@cs.umass.edu">casutton@cs.umass.edu</A>
* @version $Id: BiconditionalPiecewiseACRFTrainer.java,v 1.1 2007/10/22 21:37:40 mccallum Exp $
*/
public class PwplACRFTrainer extends DefaultAcrfTrainer {
private static final Logger logger = MalletLogger.getLogger (PwplACRFTrainer.class.getName ());
public static boolean printGradient = false;
public static final int NO_WRONG_WRONG = 0;
public static final int CONDITION_WW = 1;
private int wrongWrongType = NO_WRONG_WRONG;
private int wrongWrongIter = 10;
private double wrongWrongThreshold = 0.1;
private File outputPrefix = new File (".");
/** Builds the piecewise-PL objective for the given model and training set. */
public Optimizable.ByGradientValue createOptimizable (ACRF acrf, InstanceList training)
{
  return new Maxable (acrf, training);
}
/** Returns the marginal-probability threshold above which a wrong-wrong is added. */
public double getWrongWrongThreshold () { return wrongWrongThreshold; }
/** Sets the marginal-probability threshold above which a wrong-wrong is added. */
public void setWrongWrongThreshold (double wrongWrongThreshold) { this.wrongWrongThreshold = wrongWrongThreshold; }
/** Selects the wrong-wrong handling mode (NO_WRONG_WRONG or CONDITION_WW). */
public void setWrongWrongType (int wrongWrongType) { this.wrongWrongType = wrongWrongType; }
/** Sets how many initial training iterations run before wrong-wrongs are added. */
public void setWrongWrongIter (int wrongWrongIter) { this.wrongWrongIter = wrongWrongIter; }
/**
 * Trains the model.  With wrong-wrong handling disabled this delegates
 * directly to the base trainer.  Otherwise it trains for wrongWrongIter
 * iterations, serializes the intermediate model, augments the objective
 * with wrong-wrong terms, and continues training for numIter iterations.
 * Returns whether the final training phase converged.
 */
public boolean train (ACRF acrf, InstanceList trainingList, InstanceList validationList, InstanceList testSet,
                      ACRFEvaluator eval, int numIter, Optimizable.ByGradientValue macrf)
{
  if (wrongWrongType == NO_WRONG_WRONG) {
    return super.train (acrf, trainingList, validationList, testSet, eval, numIter, macrf);
  } else {
    Maxable bipwMaxable = (Maxable) macrf;
    // Phase 1: train without wrong-wrongs for a fixed number of iterations.
    logger.info ("PwplACRFTrainer: Initial training");
    super.train (acrf, trainingList, validationList, testSet, eval, wrongWrongIter, macrf);
    FileUtils.writeGzippedObject (new File (outputPrefix, "initial-acrf.ser.gz"), acrf);
    logger.info ("PwplACRFTrainer: Adding wrong-wrongs");
    bipwMaxable.addWrongWrong (trainingList);
    // Phase 2: retrain with the augmented objective.
    // (Previously this logged the same "Adding wrong-wrongs" message twice.)
    logger.info ("PwplACRFTrainer: Retraining with wrong-wrongs");
    boolean converged = super.train (acrf, trainingList, validationList, testSet, eval, numIter, macrf);
    reportTrainingLikelihood (acrf, trainingList);
    return converged;
  }
}
/**
 * Reports the true (unregularized) joint log-likelihood of the current
 * parameters on the training set, using the ACRF's configured inferencer.
 * This is the exact objective, as opposed to the piecewise-pseudolikelihood
 * surrogate that is actually optimized during training.
 */
public static void reportTrainingLikelihood (ACRF acrf, InstanceList trainingList)
{
double total = 0;
Inferencer inf = acrf.getInferencer ();
for (int i = 0; i < trainingList.size (); i++) {
Instance inst = trainingList.get (i);
ACRF.UnrolledGraph unrolled = acrf.unroll (inst);
inf.computeMarginals (unrolled);
// Log-probability of the true labeling under the current model.
double lik = inf.lookupLogJoint (unrolled.getAssignment ());
total += lik;
logger.info ("...instance "+i+" likelihood = "+lik);
}
logger.info ("Unregularized joint likelihood = "+total);
}
public class Maxable extends CachingOptimizable.ByGradient {
private ACRF acrf;
InstanceList trainData;
private ACRF.Template[] templates;
protected BitSet infiniteValues = null;
private int numParameters;
private static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 10.0;
/** Returns the variance of the Gaussian prior on weights. */
public double getGaussianPriorVariance () { return gaussianPriorVariance; }
/** Sets the variance of the Gaussian prior on weights. */
public void setGaussianPriorVariance (double gaussianPriorVariance) { this.gaussianPriorVariance = gaussianPriorVariance; }
private double gaussianPriorVariance = PwplACRFTrainer.Maxable.DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
/* Vectors that contain the counts of features observed in the
training data. Maps
(clique-template x feature-number) => count
*/
SparseVector constraints[][];
/* Vectors that contain the expected value over the
* labels of all the features, have seen the training data
* (but not the training labels).
*/
SparseVector expectations[][];
SparseVector defaultConstraints[];
SparseVector defaultExpectations[];
/** Allocates weight vectors in every template and tallies the total parameter count. */
private void initWeights (InstanceList training)
{
  for (int t = 0; t < templates.length; t++)
    numParameters += templates[t].initWeights (training);
}
/* Initialize constraints[][] and expectations[][]
* to have the same dimensions as weights, but to
* be all zero.
*/
/* Allocate constraints[][] and expectations[][] (and their default-weight
 * counterparts) with exactly the same dimensions as the template weights,
 * with every entry zero. */
private void initConstraintsExpectations ()
{
  int numTmpls = templates.length;
  defaultConstraints = new SparseVector [numTmpls];
  defaultExpectations = new SparseVector [numTmpls];
  constraints = new SparseVector [numTmpls][];
  expectations = new SparseVector [numTmpls][];
  for (int t = 0; t < numTmpls; t++) {
    // Zeroed clones of the default (per-assignment bias) weights.
    SparseVector defaults = templates[t].getDefaultWeights ();
    defaultConstraints[t] = (SparseVector) defaults.cloneMatrixZeroed ();
    defaultExpectations[t] = (SparseVector) defaults.cloneMatrixZeroed ();
    // Zeroed clones of the per-assignment feature weights.
    SparseVector[] w = templates[t].getWeights ();
    constraints[t] = new SparseVector [w.length];
    expectations[t] = new SparseVector [w.length];
    for (int a = 0; a < w.length; a++) {
      constraints[t][a] = (SparseVector) w[a].cloneMatrixZeroed ();
      expectations[t][a] = (SparseVector) w[a].cloneMatrixZeroed ();
    }
  }
}
private int numCvgaCalls = 0;
private long timePerCvgaCall = 0;
/** Clears the per-call profiling counters for computeValueGradientForAssn. */
void resetProfilingForCall ()
{
  numCvgaCalls = 0;
  timePerCvgaCall = 0;
}
/**
* Set all expectations to 0 after they've been
* initialized.
*/
void resetExpectations ()
{
for (int tidx = 0; tidx < expectations.length; tidx++) {
defaultExpectations[tidx].setAll (0.0);
for (int i = 0; i < expectations[tidx].length; i++) {
expectations[tidx][i].setAll (0.0);
}
}
}
/** Zeroes every constraint vector (default and per-assignment). */
void resetConstraints ()
{
  for (int t = 0; t < constraints.length; t++) {
    defaultConstraints[t].setAll (0.0);
    SparseVector[] row = constraints[t];
    for (int a = 0; a < row.length; a++)
      row[a].setAll (0.0);
  }
}
/**
 * Builds the optimizable objective for the given model and training data:
 * allocates weights, constraint and expectation accumulators, then collects
 * the (fixed) empirical feature counts from the training set.
 * Note: initWeights must run before initConstraintsExpectations, since the
 * accumulators are cloned from the freshly allocated weight vectors.
 */
protected Maxable (ACRF acrf, InstanceList ilist)
{
PwplACRFTrainer.logger.finest ("Initializing OptimizableACRF.");
this.acrf = acrf;
templates = acrf.getTemplates ();
/* allocate for weights, constraints and expectations */
this.trainData = ilist;
initWeights (trainData);
initConstraintsExpectations ();
int numInstances = trainData.size ();
// Force recomputation of value/gradient on first use.
cachedValueStale = cachedGradientStale = true;
/*
if (cacheUnrolledGraphs) {
unrolledGraphs = new UnrolledGraph [numInstances];
}
*/
PwplACRFTrainer.logger.info ("Number of training instances = " + numInstances);
PwplACRFTrainer.logger.info ("Number of parameters = " + numParameters);
describePrior ();
PwplACRFTrainer.logger.fine ("Computing constraints");
// Empirical feature counts; computed once, reused every iteration.
collectConstraints (trainData);
}
/** Logs the prior configuration at startup. */
private void describePrior ()
{
  PwplACRFTrainer.logger.info ("Using gaussian prior with variance " + gaussianPriorVariance);
}
public int getNumParameters () { return numParameters; }
/* Negate initialValue and finalValue because the parameters are in
* terms of "weights", not "values".
*/
/**
 * Copies the current parameters into buf: all templates' default weights
 * first, then each template's per-assignment feature weights, in order.
 */
public void getParameters (double[] buf)
{
  if (buf.length != numParameters) {
    throw new IllegalArgumentException ("Argument is not of the " +
                                        " correct dimensions");
  }
  int pos = 0;
  for (int t = 0; t < templates.length; t++) {
    double[] defaults = templates[t].getDefaultWeights ().getValues ();
    System.arraycopy (defaults, 0, buf, pos, defaults.length);
    pos += defaults.length;
  }
  for (int t = 0; t < templates.length; t++) {
    SparseVector[] weights = templates[t].getWeights ();
    for (int a = 0; a < weights.length; a++) {
      double[] vals = weights[a].getValues ();
      System.arraycopy (vals, 0, buf, pos, vals.length);
      pos += vals.length;
    }
  }
}
/**
 * Writes params back into the templates' weight vectors, in the same
 * layout produced by getParameters, and marks the cached value and
 * gradient as stale.
 */
protected void setParametersInternal (double[] params)
{
  cachedValueStale = cachedGradientStale = true;
  int pos = 0;
  for (int t = 0; t < templates.length; t++) {
    double[] defaults = templates[t].getDefaultWeights ().getValues ();
    System.arraycopy (params, pos, defaults, 0, defaults.length);
    pos += defaults.length;
  }
  for (int t = 0; t < templates.length; t++) {
    SparseVector[] weights = templates[t].getWeights ();
    for (int a = 0; a < weights.length; a++) {
      double[] vals = weights[a].getValues ();
      System.arraycopy (params, pos, vals, 0, vals.length);
      pos += vals.length;
    }
  }
}
// Functions for unit tests to get constraints and expectations
// I'm too lazy to make a deep copy. Callers should not
// modify these.
public SparseVector[] getExpectations (int cnum) { return expectations[cnum]; }
public SparseVector[] getConstraints (int cnum) { return constraints[cnum]; }
/**
* print weights
*/
/**
 * Prints all current parameter values to stdout, tab-separated on one line.
 */
public void printParameters ()
{
  double[] buf = new double[numParameters];
  getParameters (buf);
  for (int i = 0; i < buf.length; i++)
    System.out.print (buf[i] + "\t");
  System.out.println ();
}
/**
 * Computes the (penalized) piecewise-pseudolikelihood value over the
 * training set.  As a side effect, fills in the expectation accumulators
 * used later by computeValueGradient.  Returns Double.NEGATIVE_INFINITY
 * if any instance unexpectedly becomes infinite or NaN after the first
 * evaluation.
 */
protected double computeValue ()
{
double retval = 0.0;
int numInstances = trainData.size ();
long start = System.currentTimeMillis ();
long unrollTime = 0;
resetProfilingForCall ();
/* Instance values must either always or never be included in
* the total values; we can't just sometimes skip a value
* because it is infinite, that throws off the total values.
* We only allow an instance to have infinite value if it happens
* from the start (we don't compute the value for the instance
* after the first round. If any other instance has infinite
* value after that it is an error. */
boolean initializingInfiniteValues = false;
if (infiniteValues == null) {
/* We could initialize bitset with one slot for every
* instance, but it is *probably* cheaper not to, taking the
* time hit to allocate the space if a bit becomes
* necessary. */
infiniteValues = new BitSet ();
initializingInfiniteValues = true;
}
/* Clear the sufficient statistics that we are about to fill */
resetExpectations ();
/* Fill in expectations for each instance */
for (int i = 0; i < numInstances; i++) {
Instance instance = trainData.get (i);
/* Compute marginals for each clique */
long unrollStart = System.currentTimeMillis ();
// Structure-only unroll: piecewise training needs the graph and feature
// vectors, but not full joint inference over the unrolled graph.
ACRF.UnrolledGraph unrolled = acrf.unrollStructureOnly (instance);
// ACRF.UnrolledGraph unrolled = new ACRF.UnrolledGraph (instance, templates, Arrays.asList (fixedTmpls), false);
long unrollEnd = System.currentTimeMillis ();
unrollTime += (unrollEnd - unrollStart);
// if (unrolled.numVariables () == 0) continue; // Happens if all nodes are pruned.
/* Save the expected value of each feature for when we
compute the gradient. */
Assignment observations = unrolled.getAssignment ();
double value = collectExpectationsAndValue (unrolled, observations, i);
if (Double.isInfinite (value)) {
if (initializingInfiniteValues) {
// First round: remember this instance so later rounds can tolerate it.
PwplACRFTrainer.logger.warning ("Instance " + instance.getName () +
" has infinite value; skipping.");
infiniteValues.set (i);
// continue;
} else if (!infiniteValues.get (i)) {
// Instance was finite before but is infinite now; bail out rather
// than corrupt the running total.
PwplACRFTrainer.logger.warning ("Infinite value on instance " + instance.getName () +
"returning -infinity");
return Double.NEGATIVE_INFINITY;
/*
printDebugInfo (unrolled);
throw new IllegalStateException
("Instance " + instance.getName()+ " used to have non-infinite"
+ " value, but now it has infinite value.");
*/
}
} else if (Double.isNaN (value)) {
System.out.println ("NaN on instance " + i + " : " + instance.getName ());
printDebugInfo (unrolled);
/* throw new IllegalStateException
("Value is NaN in ACRF.getValue() Instance "+i);
*/
PwplACRFTrainer.logger.warning ("Value is NaN in ACRF.getValue() Instance " + i + " : " +
"returning -infinity... ");
return Double.NEGATIVE_INFINITY;
} else {
retval += value;
}
}
/* Incorporate Gaussian prior on parameters. This means
that for each weight, we will add w^2 / (2 * variance) to the
log probability. */
double priorDenom = 2 * gaussianPriorVariance;
for (int tidx = 0; tidx < templates.length; tidx++) {
SparseVector[] weights = templates[tidx].getWeights ();
for (int j = 0; j < weights.length; j++) {
for (int fnum = 0; fnum < weights[j].numLocations (); fnum++) {
double w = weights[j].valueAtLocation (fnum);
// Skip infinite/NaN weights; they are logged by weightValid.
if (weightValid (w, tidx, j)) {
retval += -w * w / priorDenom;
}
}
}
}
long end = System.currentTimeMillis ();
PwplACRFTrainer.logger.info ("ACRF Inference time (ms) = " + (end - start));
PwplACRFTrainer.logger.info ("ACRF unroll time (ms) = " + unrollTime);
PwplACRFTrainer.logger.info ("getValue (loglikelihood) = " + retval);
logger.info ("Number cVGA calls = " + numCvgaCalls);
logger.info ("Total cVGA time (ms) = " + timePerCvgaCall);
return retval;
}
/**
* Computes the gradient of the penalized log likelihood of the
* ACRF, and places it in cachedGradient[].
* <p/>
* Gradient is
* constraint - expectation - parameters/gaussianPriorVariance
*/
protected void computeValueGradient (double[] grad)
{
  /* Index into current element of cachedGradient[] array. */
  int gidx = 0;
  // First do gradient wrt defaultWeights
  for (int tidx = 0; tidx < templates.length; tidx++) {
    SparseVector theseWeights = templates[tidx].getDefaultWeights ();
    SparseVector theseConstraints = defaultConstraints[tidx];
    SparseVector theseExpectations = defaultExpectations[tidx];
    for (int j = 0; j < theseWeights.numLocations (); j++) {
      double weight = theseWeights.valueAtLocation (j);
      double constraint = theseConstraints.valueAtLocation (j);
      double expectation = theseExpectations.valueAtLocation (j);
      if (PwplACRFTrainer.printGradient) {
        System.out.println (" gradient [" + gidx + "] = " + constraint + " (ctr) - " + expectation + " (exp) - " +
                            (weight / gaussianPriorVariance) + " (reg) [feature=DEFAULT]");
      }
      grad[gidx++] = constraint - expectation - (weight / gaussianPriorVariance);
    }
  }
  // Now do other weights
  for (int tidx = 0; tidx < templates.length; tidx++) {
    ACRF.Template tmpl = templates[tidx];
    SparseVector[] weights = tmpl.getWeights ();
    for (int i = 0; i < weights.length; i++) {
      SparseVector thisWeightVec = weights[i];
      SparseVector thisConstraintVec = constraints[tidx][i];
      SparseVector thisExpectationVec = expectations[tidx][i];
      for (int j = 0; j < thisWeightVec.numLocations (); j++) {
        double w = thisWeightVec.valueAtLocation (j);
        double gradient; // Computed below
        double constraint = thisConstraintVec.valueAtLocation (j);
        double expectation = thisExpectationVec.valueAtLocation (j);
        /* A parameter may be set to -infinity by an external user.
         * We set gradient to 0 because the parameter's value can
         * never change anyway and it will mess up future calculations
         * on the matrix. */
        if (Double.isInfinite (w)) {
          // j is a location index; look up the feature name via the sparse
          // index, as the printGradient branch below does.  (Previously this
          // passed j directly, naming the wrong feature in the warning.)
          PwplACRFTrainer.logger.warning ("Infinite weight for node index " + i +
                                          " feature " +
                                          acrf.getInputAlphabet ().lookupObject (thisWeightVec.indexAtLocation (j)));
          gradient = 0.0;
        } else {
          gradient = constraint
                     - (w / gaussianPriorVariance)
                     - expectation;
        }
        if (PwplACRFTrainer.printGradient) {
          int idx = thisWeightVec.indexAtLocation (j);
          Object fname = acrf.getInputAlphabet ().lookupObject (idx);
          System.out.println (" gradient [" + gidx + "] = " + constraint + " (ctr) - " + expectation + " (exp) - " +
                              (w / gaussianPriorVariance) + " (reg) [feature=" + fname + "]");
        }
        grad[gidx++] = gradient;
      }
    }
  }
}
/**
* For every feature f_k, computes the expected value of f_k
* aver all possible label sequences given the list of instances
* we have.
* <p/>
* These values are stored in collector, that is,
* collector[i][j][k] gets the expected value for the
* feature for clique i, label assignment j, and input features k.
*/
// Sums, over every clique and every variable in it, the local conditional
// log-likelihood of that variable given the rest of the clique's observed
// assignment (the piecewise-PL objective for one instance); accumulates
// feature expectations as a side effect.  Optionally adds wrong-wrong terms.
private double collectExpectationsAndValue (ACRF.UnrolledGraph unrolled, Assignment observations, int inum)
{
double value = 0.0;
for (Iterator it = unrolled.unrolledVarSetIterator (); it.hasNext ();) {
ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next ();
ACRF.Template tmpl = clique.getTemplate ();
int tidx = tmpl.index;
// index == -1 marks a template excluded from training.
if (tidx == -1) continue;
// One pseudolikelihood term per variable in the clique.
for (int vi = 0; vi < clique.size (); vi++) {
Variable target = clique.get (vi);
value += computeValueGradientForAssn (observations, clique, target);
}
}
switch (wrongWrongType) {
case NO_WRONG_WRONG:
break;
case CONDITION_WW:
value += addConditionalWW (unrolled, inum);
break;
default:
throw new IllegalStateException ();
}
return value;
}
/**
 * Adds one piecewise-PL term for every recorded wrong-wrong of instance
 * inum, conditioning on the (wrong) clique assignment rather than the
 * observed one.  Returns 0 if no wrong-wrongs have been collected yet.
 */
private double addConditionalWW (ACRF.UnrolledGraph unrolled, int inum)
{
  if (allWrongWrongs == null) return 0;
  double value = 0;
  Iterator it = allWrongWrongs[inum].iterator ();
  while (it.hasNext ()) {
    WrongWrong ww = (WrongWrong) it.next ();
    Variable target = ww.findVariable (unrolled);
    ACRF.UnrolledVarSet clique = ww.findVarSet (unrolled);
    Assignment wrong = Assignment.makeFromSingleIndex (clique, ww.assnIdx);
    // System.out.println ("Computing for WW: "+clique+" idx "+ww.assnIdx+" target "+target);
    value += computeValueGradientForAssn (wrong, clique, target);
  }
  return value;
}
/**
 * Computes one local conditional term: the log-probability of target's
 * observed value given the rest of the clique's assignment, normalizing
 * only over target's outcomes.  As a side effect, adds the resulting
 * per-outcome marginals (times the clique's feature vector) into the
 * expectation accumulators.
 */
private double computeValueGradientForAssn (Assignment observations, ACRF.UnrolledVarSet clique, Variable target)
{
numCvgaCalls++;
Timing timing = new Timing ();
ACRF.Template tmpl = clique.getTemplate ();
int tidx = tmpl.index;
// Local copy of the observed assignment, restricted to this clique;
// we vary only the target variable below.
Assignment cliqueAssn = Assignment.restriction (observations, clique);
int M = target.getNumOutcomes ();
double[] vals = new double [M];
int[] singles = new int [M];
// Score every outcome of target, holding the rest of the clique fixed.
for (int assnIdx = 0; assnIdx < M; assnIdx++) {
cliqueAssn.setValue (target, assnIdx);
vals[assnIdx] = computeLogFactorValue (cliqueAssn, tmpl, clique.getFv ());
singles[assnIdx] = cliqueAssn.singleIndex ();
}
// Local log partition function over target's outcomes.
double logZ = Maths.sumLogProb (vals);
for (int assnIdx = 0; assnIdx < M; assnIdx++) {
double marginal = Math.exp (vals[assnIdx] - logZ);
int expIdx = singles[assnIdx];
expectations[tidx][expIdx].plusEqualsSparse (clique.getFv (), marginal);
// Default weights are sparse: only accumulate if this assignment is present.
if (defaultExpectations[tidx].location (expIdx) != -1) {
defaultExpectations[tidx].incrementValue (expIdx, marginal);
}
}
int observedVal = observations.get (target);
timePerCvgaCall += timing.elapsedTime ();
// Conditional log-likelihood of the observed outcome.
return vals[observedVal] - logZ;
}
/**
 * Log value of one factor under the given clique assignment:
 * the feature weights for that assignment dotted with fv, plus the
 * assignment's default (bias) weight.
 */
private double computeLogFactorValue (Assignment cliqueAssn, ACRF.Template tmpl, FeatureVector fv)
{
  int assnIdx = cliqueAssn.singleIndex ();
  SparseVector assnWeights = tmpl.getWeights ()[assnIdx];
  return assnWeights.dotProduct (fv) + tmpl.getDefaultWeight (assnIdx);
}
/**
 * Accumulates the empirical feature counts ("constraints") from the
 * labeled training data.  Each clique contributes its feature vector
 * scaled by clique.size(), matching the one-term-per-variable structure
 * of the piecewise-PL objective.  If wrong-wrongs have been collected,
 * each contributes an additional unit-weight count.
 */
public void collectConstraints (InstanceList ilist)
{
for (int inum = 0; inum < ilist.size (); inum++) {
PwplACRFTrainer.logger.finest ("*** Collecting constraints for instance " + inum);
Instance inst = ilist.get (inum);
ACRF.UnrolledGraph unrolled = new ACRF.UnrolledGraph (inst, templates, null, false);
for (Iterator it = unrolled.unrolledVarSetIterator (); it.hasNext ();) {
ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next ();
int tidx = clique.getTemplate ().index;
// index == -1 marks a template excluded from training.
if (tidx == -1) continue;
int assn = clique.lookupAssignmentNumber ();
// Weight by clique size: each variable in the clique has its own PL term.
constraints[tidx][assn].plusEqualsSparse (clique.getFv (), clique.size ());
if (defaultConstraints[tidx].location (assn) != -1) {
defaultConstraints[tidx].incrementValue (assn, clique.size ());
}
}
// constraints for wrong-wrongs for instance
if (allWrongWrongs != null) {
List wrongs = allWrongWrongs[inum];
for (Iterator wwIt = wrongs.iterator (); wwIt.hasNext ();) {
WrongWrong ww = (WrongWrong) wwIt.next ();
ACRF.UnrolledVarSet clique = ww.findVarSet (unrolled);
int tidx = clique.getTemplate ().index;
int wrong2rightId = ww.assnIdx;
// Each wrong-wrong adds a single unit count for its (wrong) assignment.
constraints[tidx][wrong2rightId].plusEqualsSparse (clique.getFv (), 1.0);
if (defaultConstraints[tidx].location (wrong2rightId) != -1) {
defaultConstraints[tidx].incrementValue (wrong2rightId, 1.0);
}
}
}
}
}
/**
 * Writes the current gradient, one component per line, to the named file.
 * Errors are reported to stderr rather than propagated (debugging aid).
 */
void dumpGradientToFile (String fileName)
{
  PrintStream w = null;
  try {
    double[] grad = new double [getNumParameters ()];
    getValueGradient (grad);
    w = new PrintStream (new FileOutputStream (fileName));
    for (int i = 0; i < numParameters; i++) {
      w.println (grad[i]);
    }
  } catch (IOException e) {
    System.err.println ("Could not open output file.");
    e.printStackTrace ();
  } finally {
    // Close in finally so the stream is not leaked if println/getValueGradient throws.
    if (w != null) w.close ();
  }
}
/** Prints the default constraint and expectation vectors for every template (debugging aid). */
void dumpDefaults ()
{
  System.out.println ("Default constraints");
  for (int t = 0; t < defaultConstraints.length; t++) {
    System.out.println ("Template " + t);
    defaultConstraints[t].print ();
  }
  System.out.println ("Default expectations");
  for (int t = 0; t < defaultExpectations.length; t++) {
    System.out.println ("Template " + t);
    defaultExpectations[t].print ();
  }
}
/** Dumps the model and, for each clique, its assignment and factor value (debugging aid). */
void printDebugInfo (ACRF.UnrolledGraph unrolled)
{
  acrf.print (System.err);
  Assignment assn = unrolled.getAssignment ();
  Iterator it = unrolled.unrolledVarSetIterator ();
  while (it.hasNext ()) {
    ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next ();
    System.out.println ("Clique " + clique);
    dumpAssnForClique (assn, clique);
    Factor ptl = unrolled.factorOf (clique);
    System.out.println ("Value = " + ptl.value (assn));
    System.out.println (ptl);
  }
}
/** Prints each variable in the clique together with its assigned value (debugging aid). */
void dumpAssnForClique (Assignment assn, ACRF.UnrolledVarSet clique)
{
  Iterator it = clique.iterator ();
  while (it.hasNext ()) {
    Variable var = (Variable) it.next ();
    System.out.println (var + " ==> " + assn.getObject (var)
                        + " (" + assn.get (var) + ")");
  }
}
/**
 * Returns true if w is a usable (finite, non-NaN) weight; otherwise logs
 * a warning and returns false so the weight is excluded from the prior term.
 */
private boolean weightValid (double w, int cnum, int j)
{
  // Note: both messages previously lacked the space before "assignment".
  if (Double.isInfinite (w)) {
    PwplACRFTrainer.logger.warning ("Weight is infinite for clique " + cnum + " assignment " + j);
    return false;
  } else if (Double.isNaN (w)) {
    PwplACRFTrainer.logger.warning ("Weight is Nan for clique " + cnum + " assignment " + j);
    return false;
  } else {
    return true;
  }
}
// WRONG WRONG HANDLING
/**
 * Records one "wrong-wrong": a high-marginal incorrect assignment of a
 * clique, stored by graph indices (rather than object references) so it
 * can be re-resolved against a freshly unrolled graph for the same instance.
 */
private class WrongWrong {
// Index of the target variable within the unrolled graph.
int varIdx;
// Index of the clique (var set) within the unrolled graph.
int vsIdx;
// Single-index encoding of the wrong assignment of the clique.
int assnIdx;
public WrongWrong (ACRF.UnrolledGraph graph, VarSet vs, Variable var, int assnIdx)
{
varIdx = graph.getIndex (var);
vsIdx = graph.getIndex (vs);
this.assnIdx = assnIdx;
}
/** Resolves the recorded clique in a (possibly new) unrolling of the same instance. */
public ACRF.UnrolledVarSet findVarSet (ACRF.UnrolledGraph unrolled)
{
return unrolled.getUnrolledVarSet (vsIdx);
}
/** Resolves the recorded target variable in a (possibly new) unrolling of the same instance. */
public Variable findVariable (ACRF.UnrolledGraph unrolled)
{
return unrolled.get (varIdx);
}
}
private List allWrongWrongs[];
/**
 * Scans the training data for "wrong-wrongs": clique assignments whose
 * marginal probability under the current model exceeds wrongWrongThreshold,
 * in which some variable other than the target is wrong but the target
 * itself is right.  Records them per-instance in allWrongWrongs, then
 * rebuilds the constraints (which now include wrong-wrong counts) and
 * invalidates the cached value/gradient.
 */
private void addWrongWrong (InstanceList training)
{
allWrongWrongs = new List [training.size ()];
int totalAdded = 0;
// if (!acrf.isCacheUnrolledGraphs ()) {
// throw new IllegalStateException ("Wrong-wrong won't work without caching unrolled graphs.");
// }
for (int i = 0; i < training.size (); i++) {
allWrongWrongs[i] = new ArrayList ();
int numAdded = 0;
Instance instance = training.get (i);
ACRF.UnrolledGraph unrolled = acrf.unroll (instance);
if (unrolled.factors ().size () == 0) {
System.err.println ("WARNING: FactorGraph for instance " + instance.getName () + " : no factors.");
continue;
}
// Full inference is needed here to obtain clique marginals.
Inferencer inf = acrf.getInferencer ();
inf.computeMarginals (unrolled);
Assignment target = unrolled.getAssignment ();
for (Iterator it = unrolled.unrolledVarSetIterator (); it.hasNext ();) {
ACRF.UnrolledVarSet vs = (ACRF.UnrolledVarSet) it.next ();
Factor marg = inf.lookupMarginal (vs);
for (AssignmentIterator assnIt = vs.assignmentIterator (); assnIt.hasNext (); assnIt.advance ()) {
// Only assignments the model currently finds plausible are candidates.
if (marg.value (assnIt) > wrongWrongThreshold) {
Assignment assn = assnIt.assignment ();
for (int vi = 0; vi < vs.size (); vi++) {
Variable var = vs.get (vi);
// A wrong-wrong for var: some other variable is wrong, var is right.
if (isWrong2RightAssn (target, assn, var)) {
int assnIdx = assn.singleIndex ();
// System.out.println ("Computing for WW: "+vs+" idx "+assnIdx+" target "+var);
allWrongWrongs[i].add (new WrongWrong (unrolled, vs, var, assnIdx));
numAdded++;
}
}
}
}
}
logger.info ("WrongWrongs: Instance " + i + " : " + instance.getName () + " Num added = " + numAdded);
totalAdded += numAdded;
}
// Constraints depend on allWrongWrongs, so they must be recomputed.
resetConstraints ();
collectConstraints (training);
forceStale ();
logger.info ("Total timesteps = " + totalTimesteps (training));
logger.info ("Total WrongWrongs = " + totalAdded);
}
/**
 * Counts the total number of sequence positions across all instances.
 * Assumes each instance's data is a Sequence — TODO confirm for non-sequence tasks.
 */
private int totalTimesteps (InstanceList ilist)
{
  int total = 0;
  for (int i = 0; i < ilist.size (); i++) {
    Sequence seq = (Sequence) ilist.get (i).getData ();
    total += seq.size ();
  }
  return total;
}
/**
 * Returns true if assn is a "wrong-to-right" assignment with respect to
 * toExclude: at least one variable other than toExclude disagrees with the
 * true labeling, and toExclude itself agrees with it.
 * Note the early return: as soon as any other variable is found wrong,
 * the answer depends only on whether toExclude is right — the remaining
 * variables need not be examined.
 */
private boolean isWrong2RightAssn (Assignment target, Assignment assn, Variable toExclude)
{
Variable[] vars = assn.getVars ();
for (int i = 0; i < vars.length; i++) {
Variable variable = vars[i];
if ((variable != toExclude) && (assn.get (variable) != target.get (variable))) {
// return true;
return assn.get (toExclude) == target.get (toExclude);
}
}
// No other variable is wrong: not a wrong-wrong.
return false;
}
} // OptimizableACRF
}