Package cc.mallet.grmm.learning

Source Code of cc.mallet.grmm.learning.PwplACRFTrainer$Maxable$WrongWrong

/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://mallet.cs.umass.edu/
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.learning;


import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.types.*;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.*;
import cc.mallet.util.FileUtils;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import cc.mallet.util.Timing;
import cc.mallet.grmm.util.CachingOptimizable;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Logger;

/**
* Implementation of piecewise PL (Sutton and McCallum, 2007)
*
* NB The wrong-wrong options are for an extension that we tried that never quite worked
*
* Created: Mar 15, 2005
*
* @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A>
* @version $Id: BiconditionalPiecewiseACRFTrainer.java,v 1.1 2007/10/22 21:37:40 mccallum Exp $
*/
public class PwplACRFTrainer extends DefaultAcrfTrainer {

  private static final Logger logger = MalletLogger.getLogger (PwplACRFTrainer.class.getName ());
  public static boolean printGradient = false;

  public static final int NO_WRONG_WRONG = 0;
  public static final int CONDITION_WW = 1;
  private int wrongWrongType = NO_WRONG_WRONG;

  private int wrongWrongIter = 10;
  private double wrongWrongThreshold = 0.1;
  private File outputPrefix = new File (".");

  public Optimizable.ByGradientValue createOptimizable (ACRF acrf, InstanceList training)
  {
    return new PwplACRFTrainer.Maxable (acrf, training);
  }

  public double getWrongWrongThreshold ()
  {
    return wrongWrongThreshold;
  }

  public void setWrongWrongThreshold (double wrongWrongThreshold)
  {
    this.wrongWrongThreshold = wrongWrongThreshold;
  }

  public void setWrongWrongType (int wrongWrongType)
  {
    this.wrongWrongType = wrongWrongType;
  }

  public void setWrongWrongIter (int wrongWrongIter)
  {
    this.wrongWrongIter = wrongWrongIter;
  }

  public boolean train (ACRF acrf, InstanceList trainingList, InstanceList validationList, InstanceList testSet,
                        ACRFEvaluator eval, int numIter, Optimizable.ByGradientValue macrf)
  {
    if (wrongWrongType == NO_WRONG_WRONG) {
      return super.train (acrf, trainingList, validationList, testSet, eval, numIter, macrf);
    } else {
      Maxable bipwMaxable = (Maxable) macrf;
      // add wrong wrongs after 5 iterations
      logger.info ("BiconditionalPiecewiseACRFTrainer: Initial training");
      super.train (acrf, trainingList, validationList, testSet, eval, wrongWrongIter, macrf);
      FileUtils.writeGzippedObject (new File (outputPrefix, "initial-acrf.ser.gz"), acrf);
      logger.info ("BiconditionalPiecewiseACRFTrainer: Adding wrong-wrongs");
      bipwMaxable.addWrongWrong (trainingList);
      logger.info ("BiconditionalPiecewiseACRFTrainer: Adding wrong-wrongs");
      boolean converged = super.train (acrf, trainingList, validationList, testSet, eval, numIter, macrf);
      reportTrainingLikelihood (acrf, trainingList);
      return converged;
    }
  }

  // Reports true joint likelihood of estimated parameters on the training set.
  public static void reportTrainingLikelihood (ACRF acrf, InstanceList trainingList)
  {
    double total = 0;
    Inferencer inf = acrf.getInferencer ();
    for (int i = 0; i < trainingList.size (); i++) {
      Instance inst = trainingList.get (i);
      ACRF.UnrolledGraph unrolled = acrf.unroll (inst);
      inf.computeMarginals (unrolled);
      double lik = inf.lookupLogJoint (unrolled.getAssignment ());
      total += lik;
      logger.info ("...instance "+i+" likelihood = "+lik);
    }
    logger.info ("Unregularized joint likelihood = "+total);
  }

  public class Maxable extends CachingOptimizable.ByGradient {

    private ACRF acrf;
    InstanceList trainData;

    private ACRF.Template[] templates;

    protected BitSet infiniteValues = null;
    private int numParameters;

    private static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 10.0;

    public double getGaussianPriorVariance ()
    {
      return gaussianPriorVariance;
    }

    public void setGaussianPriorVariance (double gaussianPriorVariance)
    {
      this.gaussianPriorVariance = gaussianPriorVariance;
    }

    private double gaussianPriorVariance = PwplACRFTrainer.Maxable.DEFAULT_GAUSSIAN_PRIOR_VARIANCE;

    /* Vectors that contain the counts of features observed in the
       training data. Maps
       (clique-template x feature-number) => count
    */
    SparseVector constraints[][];

    /* Vectors that contain the expected value over the
     *  labels of all the features, have seen the training data
     *  (but not the training labels).
     */
    SparseVector expectations[][];

    SparseVector defaultConstraints[];
    SparseVector defaultExpectations[];

    private void initWeights (InstanceList training)
    {
      for (int tidx = 0; tidx < templates.length; tidx++) {
        numParameters += templates[tidx].initWeights (training);
      }
    }

    /* Initialize constraints[][] and expectations[][]
     *  to have the same dimensions as weights, but to
     *  be all zero.
     */
    private void initConstraintsExpectations ()
    {
      // Do the defaults first
      defaultConstraints = new SparseVector [templates.length];
      defaultExpectations = new SparseVector [templates.length];
      for (int tidx = 0; tidx < templates.length; tidx++) {
        SparseVector defaults = templates[tidx].getDefaultWeights ();
        defaultConstraints[tidx] = (SparseVector) defaults.cloneMatrixZeroed ();
        defaultExpectations[tidx] = (SparseVector) defaults.cloneMatrixZeroed ();
      }

      // And now the others
      constraints = new SparseVector [templates.length][];
      expectations = new SparseVector [templates.length][];
      for (int tidx = 0; tidx < templates.length; tidx++) {
        ACRF.Template tmpl = templates[tidx];
        SparseVector[] weights = tmpl.getWeights ();
        constraints[tidx] = new SparseVector [weights.length];
        expectations[tidx] = new SparseVector [weights.length];

        for (int i = 0; i < weights.length; i++) {
          constraints[tidx][i] = (SparseVector) weights[i].cloneMatrixZeroed ();
          expectations[tidx][i] = (SparseVector) weights[i].cloneMatrixZeroed ();
        }
      }
    }


    private int numCvgaCalls = 0;
    private long timePerCvgaCall = 0;

    void resetProfilingForCall ()
    {
      numCvgaCalls = 0;
      timePerCvgaCall = 0;
    }

    /**
     * Set all expectations to 0 after they've been
     * initialized.
     */
    void resetExpectations ()
    {
      for (int tidx = 0; tidx < expectations.length; tidx++) {
        defaultExpectations[tidx].setAll (0.0);
        for (int i = 0; i < expectations[tidx].length; i++) {
          expectations[tidx][i].setAll (0.0);
        }
      }
    }

    void resetConstraints ()
    {
      for (int tidx = 0; tidx < constraints.length; tidx++) {
        defaultConstraints[tidx].setAll (0.0);
        for (int i = 0; i < constraints[tidx].length; i++) {
          constraints[tidx][i].setAll (0.0);
        }
      }
    }

    protected Maxable (ACRF acrf, InstanceList ilist)
    {
      PwplACRFTrainer.logger.finest ("Initializing OptimizableACRF.");

      this.acrf = acrf;
      templates = acrf.getTemplates ();

      /* allocate for weights, constraints and expectations */
      this.trainData = ilist;
      initWeights (trainData);
      initConstraintsExpectations ();

      int numInstances = trainData.size ();

      cachedValueStale = cachedGradientStale = true;

/*
  if (cacheUnrolledGraphs) {
  unrolledGraphs = new UnrolledGraph [numInstances];
  }
*/

      PwplACRFTrainer.logger.info ("Number of training instances = " + numInstances);
      PwplACRFTrainer.logger.info ("Number of parameters = " + numParameters);
      describePrior ();

      PwplACRFTrainer.logger.fine ("Computing constraints");
      collectConstraints (trainData);
    }

    private void describePrior ()
    {
      PwplACRFTrainer.logger.info ("Using gaussian prior with variance " + gaussianPriorVariance);
    }

    public int getNumParameters () { return numParameters; }

    /* Negate initialValue and finalValue because the parameters are in
     * terms of "weights", not "values".
     */
    public void getParameters (double[] buf)
    {

      if (buf.length != numParameters) {
        throw new IllegalArgumentException ("Argument is not of the " +
                " correct dimensions");
      }
      int idx = 0;
      for (int tidx = 0; tidx < templates.length; tidx++) {
        ACRF.Template tmpl = templates[tidx];
        SparseVector defaults = tmpl.getDefaultWeights ();
        double[] values = defaults.getValues ();
        System.arraycopy (values, 0, buf, idx, values.length);
        idx += values.length;
      }

      for (int tidx = 0; tidx < templates.length; tidx++) {
        ACRF.Template tmpl = templates[tidx];
        SparseVector[] weights = tmpl.getWeights ();
        for (int assn = 0; assn < weights.length; assn++) {
          double[] values = weights[assn].getValues ();
          System.arraycopy (values, 0, buf, idx, values.length);
          idx += values.length;
        }
      }

    }


    protected void setParametersInternal (double[] params)
    {
      cachedValueStale = cachedGradientStale = true;

      int idx = 0;
      for (int tidx = 0; tidx < templates.length; tidx++) {
        ACRF.Template tmpl = templates[tidx];
        SparseVector defaults = tmpl.getDefaultWeights ();
        double[] values = defaults.getValues ();
        System.arraycopy (params, idx, values, 0, values.length);
        idx += values.length;
      }

      for (int tidx = 0; tidx < templates.length; tidx++) {
        ACRF.Template tmpl = templates[tidx];
        SparseVector[] weights = tmpl.getWeights ();
        for (int assn = 0; assn < weights.length; assn++) {
          double[] values = weights[assn].getValues ();
          System.arraycopy (params, idx, values, 0, values.length);
          idx += values.length;
        }
      }
    }

    // Functions for unit tests to get constraints and expectations
    //  I'm too lazy to make a deep copy.  Callers should not
    //  modify these.

    public SparseVector[] getExpectations (int cnum) { return expectations[cnum]; }

    public SparseVector[] getConstraints (int cnum) { return constraints[cnum]; }

    /**
     * print weights
     */
    public void printParameters ()
    {
      double[] buf = new double[numParameters];
      getParameters (buf);

      int len = buf.length;
      for (int w = 0; w < len; w++)
        System.out.print (buf[w] + "\t");
      System.out.println ();
    }


    protected double computeValue ()
    {
      double retval = 0.0;
      int numInstances = trainData.size ();

      long start = System.currentTimeMillis ();
      long unrollTime = 0;
      resetProfilingForCall ();

      /* Instance values must either always or never be included in
       * the total values; we can't just sometimes skip a value
       * because it is infinite, that throws off the total values.
       * We only allow an instance to have infinite value if it happens
       * from the start (we don't compute the value for the instance
       * after the first round. If any other instance has infinite
       * value after that it is an error. */

      boolean initializingInfiniteValues = false;

      if (infiniteValues == null) {
        /* We could initialize bitset with one slot for every
         * instance, but it is *probably* cheaper not to, taking the
         * time hit to allocate the space if a bit becomes
         * necessary. */
        infiniteValues = new BitSet ();
        initializingInfiniteValues = true;
      }

      /* Clear the sufficient statistics that we are about to fill */
      resetExpectations ();

      /* Fill in expectations for each instance */
      for (int i = 0; i < numInstances; i++) {
        Instance instance = trainData.get (i);

        /* Compute marginals for each clique */
        long unrollStart = System.currentTimeMillis ();
        ACRF.UnrolledGraph unrolled = acrf.unrollStructureOnly (instance);
//        ACRF.UnrolledGraph unrolled = new ACRF.UnrolledGraph (instance, templates, Arrays.asList (fixedTmpls), false);
        long unrollEnd = System.currentTimeMillis ();
        unrollTime += (unrollEnd - unrollStart);

//        if (unrolled.numVariables () == 0) continue;   // Happens if all nodes are pruned.

        /* Save the expected value of each feature for when we
           compute the gradient. */
        Assignment observations = unrolled.getAssignment ();
        double value = collectExpectationsAndValue (unrolled, observations, i);

        if (Double.isInfinite (value)) {
          if (initializingInfiniteValues) {
            PwplACRFTrainer.logger.warning ("Instance " + instance.getName () +
                    " has infinite value; skipping.");
            infiniteValues.set (i);
//            continue;
          } else if (!infiniteValues.get (i)) {
            PwplACRFTrainer.logger.warning ("Infinite value on instance " + instance.getName () +
                    "returning -infinity");
            return Double.NEGATIVE_INFINITY;
/*
            printDebugInfo (unrolled);
            throw new IllegalStateException
              ("Instance " + instance.getName()+ " used to have non-infinite"
               + " value, but now it has infinite value.");
*/
          }
        } else if (Double.isNaN (value)) {
          System.out.println ("NaN on instance " + i + " : " + instance.getName ());
          printDebugInfo (unrolled);
/*          throw new IllegalStateException
            ("Value is NaN in ACRF.getValue() Instance "+i);
*/
          PwplACRFTrainer.logger.warning ("Value is NaN in ACRF.getValue() Instance " + i + " : " +
                  "returning -infinity... ");
          return Double.NEGATIVE_INFINITY;
        } else {
          retval += value;
        }

      }

      /* Incorporate Gaussian prior on parameters. This means
         that for each weight, we will add w^2 / (2 * variance) to the
         log probability. */

      double priorDenom = 2 * gaussianPriorVariance;

      for (int tidx = 0; tidx < templates.length; tidx++) {
        SparseVector[] weights = templates[tidx].getWeights ();
        for (int j = 0; j < weights.length; j++) {
          for (int fnum = 0; fnum < weights[j].numLocations (); fnum++) {
            double w = weights[j].valueAtLocation (fnum);
            if (weightValid (w, tidx, j)) {
              retval += -w * w / priorDenom;
            }
          }
        }
      }

      long end = System.currentTimeMillis ();
      PwplACRFTrainer.logger.info ("ACRF Inference time (ms) = " + (end - start));
      PwplACRFTrainer.logger.info ("ACRF unroll time (ms) = " + unrollTime);
      PwplACRFTrainer.logger.info ("getValue (loglikelihood) = " + retval);

      logger.info ("Number cVGA calls = " + numCvgaCalls);
      logger.info ("Total cVGA time (ms) = " + timePerCvgaCall);

      return retval;
    }


    /**
     * Computes the gradient of the penalized log likelihood of the
     * ACRF, and places it in cachedGradient[].
     * <p/>
     * Gradient is
     * constraint - expectation - parameters/gaussianPriorVariance
     */
    protected void computeValueGradient (double[] grad)
    {
      /* Index into current element of cachedGradient[] array. */
      int gidx = 0;

      // First do gradient wrt defaultWeights
      for (int tidx = 0; tidx < templates.length; tidx++) {
        SparseVector theseWeights = templates[tidx].getDefaultWeights ();
        SparseVector theseConstraints = defaultConstraints[tidx];
        SparseVector theseExpectations = defaultExpectations[tidx];
        for (int j = 0; j < theseWeights.numLocations (); j++) {
          double weight = theseWeights.valueAtLocation (j);
          double constraint = theseConstraints.valueAtLocation (j);
          double expectation = theseExpectations.valueAtLocation (j);
          if (PwplACRFTrainer.printGradient) {
            System.out.println (" gradient [" + gidx + "] = " + constraint + " (ctr) - " + expectation + " (exp) - " +
                    (weight / gaussianPriorVariance) + " (reg)  [feature=DEFAULT]");
          }
          grad[gidx++] = constraint - expectation - (weight / gaussianPriorVariance);
        }
      }

      // Now do other weights
      for (int tidx = 0; tidx < templates.length; tidx++) {
        ACRF.Template tmpl = templates[tidx];
        SparseVector[] weights = tmpl.getWeights ();
        for (int i = 0; i < weights.length; i++) {
          SparseVector thisWeightVec = weights[i];
          SparseVector thisConstraintVec = constraints[tidx][i];
          SparseVector thisExpectationVec = expectations[tidx][i];

          for (int j = 0; j < thisWeightVec.numLocations (); j++) {
            double w = thisWeightVec.valueAtLocation (j);
            double gradient;  // Computed below

            double constraint = thisConstraintVec.valueAtLocation (j);
            double expectation = thisExpectationVec.valueAtLocation (j);

            /* A parameter may be set to -infinity by an external user.
             * We set gradient to 0 because the parameter's value can
             * never change anyway and it will mess up future calculations
             * on the matrix. */
            if (Double.isInfinite (w)) {
              PwplACRFTrainer.logger.warning ("Infinite weight for node index " + i +
                      " feature " +
                      acrf.getInputAlphabet ().lookupObject (j));
              gradient = 0.0;
            } else {
              gradient = constraint
                      - (w / gaussianPriorVariance)
                      - expectation;
            }

            if (PwplACRFTrainer.printGradient) {
              int idx = thisWeightVec.indexAtLocation (j);
              Object fname = acrf.getInputAlphabet ().lookupObject (idx);
              System.out.println (" gradient [" + gidx + "] = " + constraint + " (ctr) - " + expectation + " (exp) - " +
                      (w / gaussianPriorVariance) + " (reg)  [feature=" + fname + "]");
            }

            grad[gidx++] = gradient;
          }
        }
      }
    }

    /**
     * For every feature f_k, computes the expected value of f_k
     * aver all possible label sequences given the list of instances
     * we have.
     * <p/>
     * These values are stored in collector, that is,
     * collector[i][j][k]  gets the expected value for the
     * feature for clique i, label assignment j, and input features k.
     */
    private double collectExpectationsAndValue (ACRF.UnrolledGraph unrolled, Assignment observations, int inum)
    {
      double value = 0.0;
      for (Iterator it = unrolled.unrolledVarSetIterator (); it.hasNext ();) {
        ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next ();
        ACRF.Template tmpl = clique.getTemplate ();
        int tidx = tmpl.index;
        if (tidx == -1) continue;

        for (int vi = 0; vi < clique.size (); vi++) {
          Variable target = clique.get (vi);
          value += computeValueGradientForAssn (observations, clique, target);
        }
      }

      switch (wrongWrongType) {
        case NO_WRONG_WRONG:
          break;

        case CONDITION_WW:
          value += addConditionalWW (unrolled, inum);
          break;

        default:
          throw new IllegalStateException ();
      }

      return value;
    }

    private double addConditionalWW (ACRF.UnrolledGraph unrolled, int inum)
    {
      double value = 0;
      if (allWrongWrongs != null) {
        List wrongs = allWrongWrongs[inum];
        for (Iterator it = wrongs.iterator (); it.hasNext ();) {
          WrongWrong ww = (WrongWrong) it.next ();
          Variable target = ww.findVariable (unrolled);
          ACRF.UnrolledVarSet clique = ww.findVarSet (unrolled);
          Assignment wrong = Assignment.makeFromSingleIndex (clique, ww.assnIdx);
//          System.out.println ("Computing for WW: "+clique+" idx "+ww.assnIdx+" target "+target);
          value += computeValueGradientForAssn (wrong, clique, target);
        }
      }
      return value;
    }

    private double computeValueGradientForAssn (Assignment observations, ACRF.UnrolledVarSet clique, Variable target)
    {
      numCvgaCalls++;
      Timing timing = new Timing ();

      ACRF.Template tmpl = clique.getTemplate ();
      int tidx = tmpl.index;
      Assignment cliqueAssn = Assignment.restriction (observations, clique);
      int M = target.getNumOutcomes ();
      double[] vals = new double [M];
      int[] singles = new int [M];
      for (int assnIdx = 0; assnIdx < M; assnIdx++) {
        cliqueAssn.setValue (target, assnIdx);
        vals[assnIdx] = computeLogFactorValue (cliqueAssn, tmpl, clique.getFv ());
        singles[assnIdx] = cliqueAssn.singleIndex ();
      }
      double logZ = Maths.sumLogProb (vals);

      for (int assnIdx = 0; assnIdx < M; assnIdx++) {
        double marginal = Math.exp (vals[assnIdx] - logZ);
        int expIdx = singles[assnIdx];
        expectations[tidx][expIdx].plusEqualsSparse (clique.getFv (), marginal);
        if (defaultExpectations[tidx].location (expIdx) != -1) {
          defaultExpectations[tidx].incrementValue (expIdx, marginal);
        }
      }

      int observedVal = observations.get (target);

      timePerCvgaCall += timing.elapsedTime ();

      return vals[observedVal] - logZ;
    }

    private double computeLogFactorValue (Assignment cliqueAssn, ACRF.Template tmpl, FeatureVector fv)
    {
      SparseVector[] weights = tmpl.getWeights ();
      int idx = cliqueAssn.singleIndex ();
      SparseVector w = weights[idx];
      double dp = w.dotProduct (fv);
      dp += tmpl.getDefaultWeight (idx);
      return dp;
    }


    public void collectConstraints (InstanceList ilist)
    {
      for (int inum = 0; inum < ilist.size (); inum++) {
        PwplACRFTrainer.logger.finest ("*** Collecting constraints for instance " + inum);
        Instance inst = ilist.get (inum);
        ACRF.UnrolledGraph unrolled = new ACRF.UnrolledGraph (inst, templates, null, false);
        for (Iterator it = unrolled.unrolledVarSetIterator (); it.hasNext ();) {
          ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next ();
          int tidx = clique.getTemplate ().index;
          if (tidx == -1) continue;

          int assn = clique.lookupAssignmentNumber ();
          constraints[tidx][assn].plusEqualsSparse (clique.getFv (), clique.size ());
          if (defaultConstraints[tidx].location (assn) != -1) {
            defaultConstraints[tidx].incrementValue (assn, clique.size ());
          }
        }

        // constraints for wrong-wrongs for instance
        if (allWrongWrongs != null) {
          List wrongs = allWrongWrongs[inum];
          for (Iterator wwIt = wrongs.iterator (); wwIt.hasNext ();) {
            WrongWrong ww = (WrongWrong) wwIt.next ();
            ACRF.UnrolledVarSet clique = ww.findVarSet (unrolled);
            int tidx = clique.getTemplate ().index;
            int wrong2rightId = ww.assnIdx;
            constraints[tidx][wrong2rightId].plusEqualsSparse (clique.getFv (), 1.0);
            if (defaultConstraints[tidx].location (wrong2rightId) != -1) {
              defaultConstraints[tidx].incrementValue (wrong2rightId, 1.0);
            }
          }
        }
      }
    }

    void dumpGradientToFile (String fileName)
    {
      try {
        double[] grad = new double [getNumParameters ()];
        getValueGradient (grad);

        PrintStream w = new PrintStream (new FileOutputStream (fileName));
        for (int i = 0; i < numParameters; i++) {
          w.println (grad[i]);
        }
        w.close ();
      } catch (IOException e) {
        System.err.println ("Could not open output file.");
        e.printStackTrace ();
      }
    }

    void dumpDefaults ()
    {
      System.out.println ("Default constraints");
      for (int i = 0; i < defaultConstraints.length; i++) {
        System.out.println ("Template " + i);
        defaultConstraints[i].print ();
      }
      System.out.println ("Default expectations");
      for (int i = 0; i < defaultExpectations.length; i++) {
        System.out.println ("Template " + i);
        defaultExpectations[i].print ();
      }
    }

    void printDebugInfo (ACRF.UnrolledGraph unrolled)
    {
      acrf.print (System.err);
      Assignment assn = unrolled.getAssignment ();
      for (Iterator it = unrolled.unrolledVarSetIterator (); it.hasNext ();) {
        ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next ();
        System.out.println ("Clique " + clique);
        dumpAssnForClique (assn, clique);
        Factor ptl = unrolled.factorOf (clique);
        System.out.println ("Value = " + ptl.value (assn));
        System.out.println (ptl);
      }
    }

    void dumpAssnForClique (Assignment assn, ACRF.UnrolledVarSet clique)
    {
      for (Iterator it = clique.iterator (); it.hasNext ();) {
        Variable var = (Variable) it.next ();
        System.out.println (var + " ==> " + assn.getObject (var)
                + "  (" + assn.get (var) + ")");
      }
    }


    private boolean weightValid (double w, int cnum, int j)
    {
      if (Double.isInfinite (w)) {
        PwplACRFTrainer.logger.warning ("Weight is infinite for clique " + cnum + "assignment " + j);
        return false;
      } else if (Double.isNaN (w)) {
        PwplACRFTrainer.logger.warning ("Weight is Nan for clique " + cnum + "assignment " + j);
        return false;
      } else {
        return true;
      }
    }

    //  WRONG WRONG HANDLING

    private class WrongWrong {
      int varIdx;
      int vsIdx;
      int assnIdx;

      public WrongWrong (ACRF.UnrolledGraph graph, VarSet vs, Variable var, int assnIdx)
      {
        varIdx = graph.getIndex (var);
        vsIdx = graph.getIndex (vs);
        this.assnIdx = assnIdx;
      }

      public ACRF.UnrolledVarSet findVarSet (ACRF.UnrolledGraph unrolled)
      {
        return unrolled.getUnrolledVarSet (vsIdx);
      }

      public Variable findVariable (ACRF.UnrolledGraph unrolled)
      {
        return unrolled.get (varIdx);
      }
    }

    private List allWrongWrongs[];

    private void addWrongWrong (InstanceList training)
    {
      allWrongWrongs = new List [training.size ()];
      int totalAdded = 0;

//      if (!acrf.isCacheUnrolledGraphs ()) {
//        throw new IllegalStateException ("Wrong-wrong won't work without caching unrolled graphs.");
//      }

      for (int i = 0; i < training.size (); i++) {
        allWrongWrongs[i] = new ArrayList ();
        int numAdded = 0;

        Instance instance = training.get (i);
        ACRF.UnrolledGraph unrolled = acrf.unroll (instance);
        if (unrolled.factors ().size () == 0) {
          System.err.println ("WARNING: FactorGraph for instance " + instance.getName () + " : no factors.");
          continue;
        }

        Inferencer inf = acrf.getInferencer ();
        inf.computeMarginals (unrolled);

        Assignment target = unrolled.getAssignment ();
        for (Iterator it = unrolled.unrolledVarSetIterator (); it.hasNext ();) {
          ACRF.UnrolledVarSet vs = (ACRF.UnrolledVarSet) it.next ();
          Factor marg = inf.lookupMarginal (vs);
          for (AssignmentIterator assnIt = vs.assignmentIterator (); assnIt.hasNext (); assnIt.advance ()) {
            if (marg.value (assnIt) > wrongWrongThreshold) {
              Assignment assn = assnIt.assignment ();
              for (int vi = 0; vi < vs.size (); vi++) {
                Variable var = vs.get (vi);
                if (isWrong2RightAssn (target, assn, var)) {
                  int assnIdx = assn.singleIndex ();
//                  System.out.println ("Computing for WW: "+vs+" idx "+assnIdx+" target "+var);
                  allWrongWrongs[i].add (new WrongWrong (unrolled, vs, var, assnIdx));
                  numAdded++;
                }
              }
            }
          }

        }

        logger.info ("WrongWrongs: Instance " + i + " : " + instance.getName () + " Num added = " + numAdded);
        totalAdded += numAdded;
      }

      resetConstraints ();
      collectConstraints (training);
      forceStale ();

      logger.info ("Total timesteps = " + totalTimesteps (training));
      logger.info ("Total WrongWrongs = " + totalAdded);
    }

    private int totalTimesteps (InstanceList ilist)
    {
      int total = 0;
      for (int i = 0; i < ilist.size (); i++) {
        Instance inst = ilist.get (i);
        Sequence seq = (Sequence) inst.getData ();
        total += seq.size ();
      }
      return total;
    }

    private boolean isWrong2RightAssn (Assignment target, Assignment assn, Variable toExclude)
    {
      Variable[] vars = assn.getVars ();
      for (int i = 0; i < vars.length; i++) {
        Variable variable = vars[i];
        if ((variable != toExclude) && (assn.get (variable) != target.get (variable))) {
//          return true;
          return assn.get (toExclude) == target.get (toExclude);
        }
      }
      return false;
    }

  } // OptimizableACRF

}
TOP

Related Classes of cc.mallet.grmm.learning.PwplACRFTrainer$Maxable$WrongWrong

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.