// Package cc.mallet.fst.semi_supervised
//
// Source code of cc.mallet.fst.semi_supervised.CRFTrainerByGE

/* Copyright (C) 2010 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.semi_supervised;

import java.util.ArrayList;
import java.util.logging.Level;
import java.util.logging.Logger;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;

/**
* Trains a CRF using Generalized Expectation constraints that
* consider either a single label or a pair of labels of a linear chain CRF.
*
* See:
* "Generalized Expectation Criteria for Semi-Supervised Learning of Conditional Random Fields"
* Gideon Mann and Andrew McCallum
* ACL 2008
*
* @author Gregory Druck
*/
public class CRFTrainerByGE extends TransducerTrainer {

  private static Logger logger = MalletLogger.getLogger(CRFTrainerByGE.class.getName());

  private static final int DEFAULT_NUM_RESETS = 1;
  private static final int DEFAULT_GPV = 10;

  private boolean converged;
  private int iteration;
  private int numThreads;
  private int numResets;
  private double gaussianPriorVariance;
  private ArrayList<GEConstraint> constraints;
  private CRF crf;
  private StateLabelMap stateLabelMap;

  private CRFOptimizableByGE optimizable;
  private Optimizer optimizer;

  public CRFTrainerByGE(CRF crf, ArrayList<GEConstraint> constraints) {
    this(crf,constraints,1);
  }

  public CRFTrainerByGE(CRF crf, ArrayList<GEConstraint> constraints, int numThreads) {
    this.converged = false;
    this.iteration = 0;
    this.constraints = constraints;
    this.crf = crf;
    this.numThreads = numThreads;
    this.numResets = DEFAULT_NUM_RESETS;
    this.gaussianPriorVariance = DEFAULT_GPV;
    // default one to one state label map
    // other maps can be set with setStateLabelMap
    this.stateLabelMap = new StateLabelMap(crf.getOutputAlphabet(),true);
  }

  @Override
  public int getIteration() {
    return iteration;
  }

  @Override
  public Transducer getTransducer() {
    return crf;
  }

  @Override
  public boolean isFinishedTraining() {
    return converged;
  }

  public void setGaussianPriorVariance(double gpv) {
    this.gaussianPriorVariance = gpv;
  }

  /**
   * Sets number of resets of L-BFGS during
   * optimization.  Resetting more times
   * can be useful since the GE objective
   * function is non-convex
   *
   * @param numResets Number of resets of L-BFGS
   */
  public void setNumResets(int numResets) {
    this.numResets = numResets;
  }

  // map between states in CRF FST and labels
  public void setStateLabelMap(StateLabelMap map) {
    this.stateLabelMap = map;
  }

  public void setOptimizable(Optimizer optimizer) {
    this.optimizer = optimizer;
  }

  public Optimizable.ByGradientValue getOptimizable(InstanceList unlabeled) {
    if (optimizable == null) {
      optimizable = new CRFOptimizableByGE(crf, constraints, unlabeled, stateLabelMap,numThreads);
      optimizable.setGaussianPriorVariance(gaussianPriorVariance);
    }
    return optimizable;
  }

  public Optimizer getOptimizer(Optimizable.ByGradientValue optimizable) {
    if (optimizer == null) {
      optimizer = new LimitedMemoryBFGS(optimizable);
    }
    return optimizer;
  }

  @Override
  public boolean train(InstanceList unlabeledSet, int numIterations) {

    assert(constraints.size() > 0);
    if (constraints.size() == 0) {
      throw new RuntimeException("No constraints specified!");
    }

    getOptimizable(unlabeledSet);
    getOptimizer(optimizable);

    if (optimizer instanceof LimitedMemoryBFGS) {
      ((LimitedMemoryBFGS)optimizer).reset();
    }

    converged = false;
    logger.info ("CRF about to train with "+numIterations+" iterations");
    // sometimes resetting the optimizer helps to find
    // a better parameter setting
    int iter = 0;
    for (int reset = 0; reset < numResets + 1; reset++) {
      for (; iter < numIterations; iter++) {
        try {
          converged = optimizer.optimize (1);
          iteration++;
          logger.info ("CRF finished one iteration of maximizer, i="+iter);
          runEvaluators();
        } catch (IllegalArgumentException e) {
          e.printStackTrace();
          logger.info ("Catching exception; saying converged.");
          converged = true;
        } catch (Exception e) {
          e.printStackTrace();
          logger.info("Catching exception; saying converged.");
          converged = true;
        }
        if (converged) {
          logger.info ("CRF training has converged, i="+iter);
          break;
        }
      }
      if (optimizer instanceof LimitedMemoryBFGS) {
        ((LimitedMemoryBFGS)optimizer).reset();
      }
    }

    // shutdown threads
    optimizable.shutdown();

    return converged;
  }
}
// TOP
//
// Related Classes of cc.mallet.fst.semi_supervised.CRFTrainerByGE
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.