Package cc.mallet.classify

Source Code of cc.mallet.classify.MaxEntPRTrainer

/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.classify;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

import cc.mallet.classify.Classifier;
import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.MaxEntOptimizableByLabelDistribution;
import cc.mallet.classify.constraints.pr.MaxEntL2FLPRConstraints;
import cc.mallet.classify.constraints.pr.MaxEntPRConstraint;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.NullLabel;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;


/**
* Penalty (soft) version of Posterior Regularization (PR) for training MaxEnt.
*
* @author Gregory Druck <a href="mailto:gdruck@cs.umass.edu">gdruck@cs.umass.edu</a>
*/

public class MaxEntPRTrainer extends ClassifierTrainer<MaxEnt> implements ClassifierTrainer.ByOptimization<MaxEnt> {

  private static Logger logger = MalletLogger.getLogger(MaxEntPRTrainer.class.getName());
 
  // for using this from the command line
  private boolean normalize = true;
  private boolean useValues = false;
  private int minIterations = 10;
  private int maxIterations = 500;
  private double qGPV;
  private String constraintsFile;

  private boolean converged = false;
  private int numIterations = 0;
  private double tolerance = 0.001;
  private double pGPV;
  private ArrayList<MaxEntPRConstraint> constraints;
  private MaxEnt p;
  private PRAuxClassifier q;
 
  public MaxEntPRTrainer() {}
 
  public MaxEntPRTrainer(ArrayList<MaxEntPRConstraint> constraints) {
    this.constraints = constraints;
  }
 
  public void setPGaussianPriorVariance(double pGPV) {
    this.pGPV = pGPV;
  }
 
  public void setQGaussianPriorVariance(double qGPV) {
    this.qGPV = qGPV;
  }
 
  public void setConstraintsFile(String filename) {
    this.constraintsFile = filename;
  }
 
  public void setUseValues(boolean flag) {
    this.useValues = flag;
  }
 
  public void setMinIterations(int minIterations) {
    this.minIterations = minIterations;
  }
 
  public void setMaxIterations(int minIterations) {
    this.maxIterations = minIterations;
  }
 
  public void setNormalize(boolean normalize) {
    this.normalize = normalize;
  }

  public Optimizer getOptimizer() {
    throw new RuntimeException("Not yet implemented!");
  }

  public int getIteration() {
    return numIterations;
  }
 
  @Override
  public boolean isFinishedTraining() {
    return converged;
  }

  @Override
  public MaxEnt getClassifier() {
    return p;
  }

  @Override
  public MaxEnt train(InstanceList trainingSet) {
    return train(trainingSet,maxIterations);
  }
 
  public MaxEnt train(InstanceList trainingSet, int maxIterations) {
    return train(trainingSet,Math.min(maxIterations,minIterations),maxIterations);
  }
 
  public MaxEnt train(InstanceList data, int minIterations, int maxIterations) {
    if (constraints == null && constraintsFile != null) {
      HashMap<Integer,double[]> constraintsMap =
        FeatureConstraintUtil.readConstraintsFromFile(constraintsFile, data);
      logger.info("number of constraints: " + constraintsMap.size());
      constraints = new ArrayList<MaxEntPRConstraint>();
      MaxEntL2FLPRConstraints prConstraints = new MaxEntL2FLPRConstraints(data.getDataAlphabet().size(),
          data.getTargetAlphabet().size(),useValues,normalize);
      for (int fi : constraintsMap.keySet()) {
        prConstraints.addConstraint(fi, constraintsMap.get(fi), qGPV);
      }
      constraints.add(prConstraints);
    }

    BitSet instancesWithConstraints = new BitSet(data.size());
    for (MaxEntPRConstraint constraint : constraints) {
      BitSet bitset = constraint.preProcess(data);
      instancesWithConstraints.or(bitset);
    }
   
    InstanceList unlabeled = data.cloneEmpty();
    for (int ii = 0; ii < data.size(); ii++) {
      if (instancesWithConstraints.get(ii)) {
        boolean noLabel = data.get(ii).getTarget() == null;
        if (noLabel) {
          data.get(ii).unLock();
          data.get(ii).setTarget(new NullLabel((LabelAlphabet)data.getTargetAlphabet()));
        }
        unlabeled.add(data.get(ii));
      }
    }

    int numFeatures = unlabeled.getDataAlphabet().size();
   
    // setup model
    int numParameters = (numFeatures + 1) * unlabeled.getTargetAlphabet().size();
    if (p == null) {
      p = new MaxEnt(unlabeled.getPipe(),new double[numParameters]);
    }

    // setup aux model
    q = new PRAuxClassifier(unlabeled.getPipe(),constraints);
   
    double oldValue = -Double.MAX_VALUE;
    for (numIterations = 0; numIterations < maxIterations; numIterations++) {

      double[][] base = optimizeQ(unlabeled,p,numIterations==0);
     
      double value = optimizePAndComputeValue(unlabeled,q,base,pGPV);
      logger.info("iteration " + numIterations + " total value " + value);

      if (numIterations >= (minIterations-1) && 2.0*Math.abs(value-oldValue) <= tolerance *
          (Math.abs(value)+Math.abs(oldValue) + 1e-5)){
        logger.info("PR value difference below tolerance (oldValue: " + oldValue + " newValue: " + value + ")");
        converged = true;
        break;
      }
      oldValue = value;
    }
    return p;
  }
 
  private double optimizePAndComputeValue(InstanceList data, PRAuxClassifier q, double[][] base, double pGPV) {
   
    InstanceList dataLabeled = data.cloneEmpty();
   
    double entropy = 0;
   
    int numLabels = data.getTargetAlphabet().size();
    for (int ii = 0; ii < data.size(); ii++) {
      double[] scores = new double[numLabels];
      q.getClassificationScores(data.get(ii), scores);
      for (int li = 0; li < numLabels; li++) {
        if (base != null && base[ii][li] == 0) {
          scores[li] = Double.NEGATIVE_INFINITY;
        }
        else if (base != null) {
          double logP = Math.log(base[ii][li]);
          scores[li] += logP; 
        }
      }
      MatrixOps.expNormalize(scores);
  
      entropy += Maths.getEntropy(scores);

      LabelVector lv = new LabelVector((LabelAlphabet)data.getTargetAlphabet(), scores);
      Instance instance = new Instance(data.get(ii).getData(),lv,null,null);
      dataLabeled.add(instance);
    }
   
    // train supervised
    MaxEntOptimizableByLabelDistribution opt = new  MaxEntOptimizableByLabelDistribution(dataLabeled,p);
    opt.setGaussianPriorVariance(pGPV);

    LimitedMemoryBFGS bfgs = new LimitedMemoryBFGS(opt);
    try { bfgs.optimize(); } catch (Exception e) { e.printStackTrace(); }
    bfgs.reset();
    try { bfgs.optimize(); } catch (Exception e) { e.printStackTrace(); }
   
    double value = 0;
    for (MaxEntPRConstraint constraint : q.getConstraintFeatures()) {
      // plus sign because this returns negative values
      value += constraint.getCompleteValueContribution();
    }
    value += entropy + opt.getValue();
    return value;
  }
 
  private double[][] optimizeQ(InstanceList data, Classifier p, boolean firstIter) {
    int numLabels = data.getTargetAlphabet().size();

    double[][] base;
    if (firstIter) {
      base = null;
    }
    else {
      base = new double[data.size()][numLabels];
      for (int ii = 0; ii < data.size(); ii++) {
        p.classify(data.get(ii)).getLabelVector().addTo(base[ii]);
      }
    }
   
    PRAuxClassifierOptimizable optimizable = new PRAuxClassifierOptimizable(data,base,q);
   
    LimitedMemoryBFGS bfgs = new LimitedMemoryBFGS(optimizable);
    try { bfgs.optimize(); } catch (Exception e) { e.printStackTrace(); }
    bfgs.reset();
    try { bfgs.optimize(); } catch (Exception e) { e.printStackTrace(); }
   
    return base;
  }
}
TOP

Related Classes of cc.mallet.classify.MaxEntPRTrainer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.