Package cc.mallet.fst.semi_supervised.pr

Source Code of cc.mallet.fst.semi_supervised.pr.CRFTrainerByPR

/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.semi_supervised.pr;

import java.util.ArrayList;
import java.util.BitSet;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;

/**
* Posterior regularization trainer.
*
* @author Gregory Druck
*/

public class CRFTrainerByPR extends TransducerTrainer implements TransducerTrainer.ByOptimization {

  private boolean converged;
  private int iter;
  private int numThreads;
  private double pGpv;
  private double tolerance;
  private double value;
  private double qValue;
  private ArrayList<PRConstraint> constraints;
  private LimitedMemoryBFGS bfgs;
  private CRF crf;
  private StateLabelMap stateLabelMap;
 
  public CRFTrainerByPR(CRF crf, ArrayList<PRConstraint> constraints) {
    this(crf,constraints,1);
  }
 
  public CRFTrainerByPR(CRF crf, ArrayList<PRConstraint> constraints, int numThreads) {
    this.crf = crf;
    this.iter = 0;
    this.value = Double.NEGATIVE_INFINITY;
    this.constraints = constraints;
    this.pGpv = 10;
    this.tolerance = 0.001;
    this.numThreads = numThreads;
    this.stateLabelMap = new StateLabelMap(crf.getOutputAlphabet(),true);
  }
 
  @Override
  public int getIteration() {
    return iter;
  }

  @Override
  public Transducer getTransducer() {
    return crf;
  }

  @Override
  public boolean isFinishedTraining() {
    return converged;
  }
 
  // map between states in CRF FST and labels
  public void setStateLabelMap(StateLabelMap map) {
    this.stateLabelMap = map;
  }
 
  public void setPGaussianPriorVariance(double pGpv) {
    this.pGpv = pGpv;
  }
 
  public void setTolerance(double tolerance) {
    this.tolerance = tolerance;
  }
 
  @Override
  public boolean train(InstanceList train, int numIterations) {
    return train(train,0,numIterations);
  }
 
  public boolean train(InstanceList train, int minIter, int maxIter) {
    return train(train,minIter,maxIter,Integer.MAX_VALUE);
  }
 
  public boolean train(InstanceList train, int minIter, int maxIter, int maxIterPerStep) {
    double oldValue = 0;
    int max = iter + maxIter;
   
    BitSet constrainedInstances = new BitSet();
    for (PRConstraint constraint : constraints) {
      constrainedInstances.or(constraint.preProcess(train));
      constraint.setStateLabelMap(stateLabelMap);
    }
   
    int removed = 0;
    InstanceList tempTrain = train.cloneEmpty();
    for (int ii = 0; ii < train.size(); ii++) {
      if (constrainedInstances.get(ii)) {
        tempTrain.add(train.get(ii));
      }
      else {
        removed++;
      }
    }
    train = tempTrain;
    System.err.println("Removed " + removed + " instances that do not contain constraints.");
   
    PRAuxiliaryModel model = new PRAuxiliaryModel(crf,constraints);

    for (; iter < max; iter++) {
      long startTime = System.currentTimeMillis();
     
      // train q
      ConstraintsOptimizableByPR opt = new ConstraintsOptimizableByPR(crf, train, model, numThreads);
      bfgs = new LimitedMemoryBFGS(opt);
      try {
        bfgs.optimize(maxIterPerStep);
      } catch (Exception e) {
        e.printStackTrace();
      }
      opt.shutdown();
     
      /*
      for (int j = 0; j < constraints.size(); j++) {
        constraints.get(j).print();
      }
      */
     
      qValue = opt.getCompleteValueContribution();
      assert(qValue > 0);
     
      // use to train p
      CRFOptimizableByKL optP = new CRFOptimizableByKL(crf, train, model, opt.getCachedDots(), numThreads, 1);
      optP.setGaussianPriorVariance(pGpv);
      LimitedMemoryBFGS bfgsP = new LimitedMemoryBFGS(optP);
     
      try {
        bfgsP.optimize(maxIterPerStep);
      } catch (Exception e) {
        e.printStackTrace();
      }
      optP.shutdown();
     
      value = optP.getValue() - qValue;
      assert(value < 0);
      System.err.println("Total value = " + value + " (pValue = " + optP.getValue() + ") (qValue = " + (-qValue) + ")");
     
      System.err.println("Time for iteration " + String.format("%.2f",((System.currentTimeMillis() - startTime) / 1000.)) + "s");
     
      // stopping criteria from BFGS
      //System.err.println("Convergence test: " + (2.0*Math.abs(value-oldValue)) + " <= " + (tolerance * (Math.abs(value)+Math.abs(oldValue) + 1e-5)));
      if((iter >= minIter) && 2.0*Math.abs(value-oldValue) <= tolerance *
          (Math.abs(value)+Math.abs(oldValue) + 1e-5)){
        System.err.println("AP value difference below tolerance (oldValue: "
          + oldValue + "newValue: " + value);
       
        break;
      }
     
     
      oldValue = value;
     
      runEvaluators();
    }
    converged = true;
    return converged;
  }
 
  public double getTotalValue() {
    return value;
  }
 
  public double getQValue() {
    return qValue;
  }

  public Optimizer getOptimizer() {
    return bfgs;
  }
}
TOP

Related Classes of cc.mallet.fst.semi_supervised.pr.CRFTrainerByPR

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.