// Package cc.mallet.fst.semi_supervised
//
// Source Code of cc.mallet.fst.semi_supervised.CRFOptimizableByEntropyRegularization

/* Copyright (C) 2009 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.semi_supervised;

import java.io.Serializable;
import java.util.logging.Logger;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.InstanceList;
import cc.mallet.fst.CRF;
import cc.mallet.fst.SumLattice;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.optimize.Optimizable;
import cc.mallet.util.MalletLogger;

/**
* A CRF objective function that is the entropy of the CRF's
* predictions on unlabeled data.
*
* References:
* Feng Jiao, Shaojun Wang, Chi-Hoon Lee, Russell Greiner, Dale Schuurmans
* "Semi-supervised conditional random fields for improved sequence segmentation and labeling"
* ACL 2006
*
* Gideon Mann, Andrew McCallum
* "Efficient Computation of Entropy Gradient for Semi-Supervised Conditional Random Fields"
* HLT/NAACL 2007
*
* @author Gaurav Chandalia
* @author Gregory Druck
*/
public class CRFOptimizableByEntropyRegularization implements Optimizable.ByGradientValue,
                                                   Serializable {
  private static Logger logger = MalletLogger.getLogger(CRFOptimizableByEntropyRegularization.class.getName());

  private int cachedValueWeightsStamp = -1;
  private int cachedGradientWeightsStamp = -1;
 
  // model's expectations according to entropy reg.
  protected CRF.Factors expectations;
  // used to update gradient
  protected Transducer.Incrementor incrementor;

  // contains labeled and unlabeled data
  protected InstanceList data;
  // the model
  protected CRF crf;

  // scale entropy values,
  // typically, (entropyRegGamma * numLabeled / numUnlabeled)
  protected double scalingFactor;

  // log probability of all the sequences, this is also the entropy due to all
  // the instances (updated in computeExpectations())
  protected double cachedValue;
  // gradient due to this optimizable (updated in getValueGradient())
  protected double[] cachedGradient;

  /**
   * Initializes the structures.
   */
  public CRFOptimizableByEntropyRegularization(CRF crf, InstanceList ilist,
                                    double scalingFactor) {
    data = ilist;
    this.crf = crf;
    this.scalingFactor = scalingFactor;

    // initialize the expectations using the model
    expectations = new CRF.Factors(crf);
    incrementor = expectations.new Incrementor();

    cachedValue = 0.0;
    cachedGradient = new double[crf.getParameters().getNumFactors()];
  }

  /**
   * Initializes the structures (sets the scaling factor to 1.0).
   */
  public CRFOptimizableByEntropyRegularization(CRF crf, InstanceList ilist) {
    this(crf, ilist, 1.0);
  }

  public double getScalingFactor() {
    return scalingFactor;
  }

  public void setScalingFactor(double scalingFactor) {
    this.scalingFactor = scalingFactor;
  }

  /**
   * Resets, computes and fills expectations from all instances, also updating
   * the entropy value. <p>
   *
   * Analogous to <tt>CRFOptimizableByLabelLikelihood.getExpectationValue<tt>.
   */
  public void computeExpectations() {
    expectations.zero();

    // now, update the expectations due to each instance for entropy reg.
    for (int ii = 0; ii < data.size(); ii++) {
      FeatureVectorSequence input = (FeatureVectorSequence) data.get(ii).getData();
      SumLattice lattice = new SumLatticeDefault(crf,input, true);

      // udpate the expectations
      EntropyLattice entropyLattice = new EntropyLattice(
          input, lattice.getGammas(), lattice.getXis(), crf,
          incrementor, scalingFactor);
      cachedValue += entropyLattice.getEntropy();
    }
  }

  public double getValue() {
    if (crf.getWeightsValueChangeStamp() != cachedValueWeightsStamp) {
      // The cached value is not up to date; it was calculated for a different set of CRF weights.
      cachedValueWeightsStamp = crf.getWeightsValueChangeStamp();
     
      cachedValue = 0;
      computeExpectations();
      cachedValue = scalingFactor * cachedValue;
      assert(!Double.isNaN(cachedValue) && !Double.isInfinite(cachedValue))
          : "Likelihood due to Entropy Regularization is NaN/Infinite";
 
      logger.info("getValue() (entropy regularization) = " + cachedValue);
    }
    return cachedValue;
  }

  public void getValueGradient(double[] buffer) {
    if (cachedGradientWeightsStamp != crf.getWeightsValueChangeStamp()) {
      cachedGradientWeightsStamp = crf.getWeightsValueChangeStamp(); // cachedGradient will soon no longer be stale
   
      getValue();
   
      // if this fails then look in computeExpectations
      expectations.assertNotNaNOrInfinite();
      // fill up gradient
      expectations.getParameters(cachedGradient);
    }
    System.arraycopy(cachedGradient, 0, buffer, 0, cachedGradient.length);
  }

  // some get/set methods that have to be implemented
  public int getNumParameters() {
    return crf.getParameters().getNumFactors();
  }

  public void getParameters(double[] buffer) {
    crf.getParameters().getParameters(buffer);
  }

  public void setParameters(double[] buffer) {
    crf.getParameters().setParameters(buffer);
    crf.weightsValueChanged();
  }

  public double getParameter(int index) {
    return crf.getParameters().getParameter(index);
  }

  public void setParameter(int index, double value) {
    crf.getParameters().setParameter(index, value);
    crf.weightsValueChanged();
  }

  // serialization stuff
  private static final long serialVersionUID = 1;
}
// TOP
//
// Related Classes of cc.mallet.fst.semi_supervised.CRFOptimizableByEntropyRegularization
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.