/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.logging.Logger;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.MaxEntOptimizableByLabelDistribution;
import cc.mallet.classify.constraints.pr.MaxEntL2FLPRConstraints;
import cc.mallet.classify.constraints.pr.MaxEntPRConstraint;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.NullLabel;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
/**
* Penalty (soft) version of Posterior Regularization (PR) for training MaxEnt.
*
* @author Gregory Druck <a href="mailto:gdruck@cs.umass.edu">gdruck@cs.umass.edu</a>
*/
public class MaxEntPRTrainer extends ClassifierTrainer<MaxEnt> implements ClassifierTrainer.ByOptimization<MaxEnt> {
private static Logger logger = MalletLogger.getLogger(MaxEntPRTrainer.class.getName());
// for using this from the command line
private boolean normalize = true;
private boolean useValues = false;
private int minIterations = 10;
private int maxIterations = 500;
private double qGPV;
private String constraintsFile;
private boolean converged = false;
private int numIterations = 0;
private double tolerance = 0.001;
private double pGPV;
private ArrayList<MaxEntPRConstraint> constraints;
private MaxEnt p;
private PRAuxClassifier q;
public MaxEntPRTrainer() {}
public MaxEntPRTrainer(ArrayList<MaxEntPRConstraint> constraints) {
this.constraints = constraints;
}
public void setPGaussianPriorVariance(double pGPV) {
this.pGPV = pGPV;
}
public void setQGaussianPriorVariance(double qGPV) {
this.qGPV = qGPV;
}
public void setConstraintsFile(String filename) {
this.constraintsFile = filename;
}
public void setUseValues(boolean flag) {
this.useValues = flag;
}
public void setMinIterations(int minIterations) {
this.minIterations = minIterations;
}
public void setMaxIterations(int minIterations) {
this.maxIterations = minIterations;
}
public void setNormalize(boolean normalize) {
this.normalize = normalize;
}
public Optimizer getOptimizer() {
throw new RuntimeException("Not yet implemented!");
}
public int getIteration() {
return numIterations;
}
@Override
public boolean isFinishedTraining() {
return converged;
}
@Override
public MaxEnt getClassifier() {
return p;
}
@Override
public MaxEnt train(InstanceList trainingSet) {
return train(trainingSet,maxIterations);
}
public MaxEnt train(InstanceList trainingSet, int maxIterations) {
return train(trainingSet,Math.min(maxIterations,minIterations),maxIterations);
}
public MaxEnt train(InstanceList data, int minIterations, int maxIterations) {
if (constraints == null && constraintsFile != null) {
HashMap<Integer,double[]> constraintsMap =
FeatureConstraintUtil.readConstraintsFromFile(constraintsFile, data);
logger.info("number of constraints: " + constraintsMap.size());
constraints = new ArrayList<MaxEntPRConstraint>();
MaxEntL2FLPRConstraints prConstraints = new MaxEntL2FLPRConstraints(data.getDataAlphabet().size(),
data.getTargetAlphabet().size(),useValues,normalize);
for (int fi : constraintsMap.keySet()) {
prConstraints.addConstraint(fi, constraintsMap.get(fi), qGPV);
}
constraints.add(prConstraints);
}
BitSet instancesWithConstraints = new BitSet(data.size());
for (MaxEntPRConstraint constraint : constraints) {
BitSet bitset = constraint.preProcess(data);
instancesWithConstraints.or(bitset);
}
InstanceList unlabeled = data.cloneEmpty();
for (int ii = 0; ii < data.size(); ii++) {
if (instancesWithConstraints.get(ii)) {
boolean noLabel = data.get(ii).getTarget() == null;
if (noLabel) {
data.get(ii).unLock();
data.get(ii).setTarget(new NullLabel((LabelAlphabet)data.getTargetAlphabet()));
}
unlabeled.add(data.get(ii));
}
}
int numFeatures = unlabeled.getDataAlphabet().size();
// setup model
int numParameters = (numFeatures + 1) * unlabeled.getTargetAlphabet().size();
if (p == null) {
p = new MaxEnt(unlabeled.getPipe(),new double[numParameters]);
}
// setup aux model
q = new PRAuxClassifier(unlabeled.getPipe(),constraints);
double oldValue = -Double.MAX_VALUE;
for (numIterations = 0; numIterations < maxIterations; numIterations++) {
double[][] base = optimizeQ(unlabeled,p,numIterations==0);
double value = optimizePAndComputeValue(unlabeled,q,base,pGPV);
logger.info("iteration " + numIterations + " total value " + value);
if (numIterations >= (minIterations-1) && 2.0*Math.abs(value-oldValue) <= tolerance *
(Math.abs(value)+Math.abs(oldValue) + 1e-5)){
logger.info("PR value difference below tolerance (oldValue: " + oldValue + " newValue: " + value + ")");
converged = true;
break;
}
oldValue = value;
}
return p;
}
private double optimizePAndComputeValue(InstanceList data, PRAuxClassifier q, double[][] base, double pGPV) {
InstanceList dataLabeled = data.cloneEmpty();
double entropy = 0;
int numLabels = data.getTargetAlphabet().size();
for (int ii = 0; ii < data.size(); ii++) {
double[] scores = new double[numLabels];
q.getClassificationScores(data.get(ii), scores);
for (int li = 0; li < numLabels; li++) {
if (base != null && base[ii][li] == 0) {
scores[li] = Double.NEGATIVE_INFINITY;
}
else if (base != null) {
double logP = Math.log(base[ii][li]);
scores[li] += logP;
}
}
MatrixOps.expNormalize(scores);
entropy += Maths.getEntropy(scores);
LabelVector lv = new LabelVector((LabelAlphabet)data.getTargetAlphabet(), scores);
Instance instance = new Instance(data.get(ii).getData(),lv,null,null);
dataLabeled.add(instance);
}
// train supervised
MaxEntOptimizableByLabelDistribution opt = new MaxEntOptimizableByLabelDistribution(dataLabeled,p);
opt.setGaussianPriorVariance(pGPV);
LimitedMemoryBFGS bfgs = new LimitedMemoryBFGS(opt);
try { bfgs.optimize(); } catch (Exception e) { e.printStackTrace(); }
bfgs.reset();
try { bfgs.optimize(); } catch (Exception e) { e.printStackTrace(); }
double value = 0;
for (MaxEntPRConstraint constraint : q.getConstraintFeatures()) {
// plus sign because this returns negative values
value += constraint.getCompleteValueContribution();
}
value += entropy + opt.getValue();
return value;
}
private double[][] optimizeQ(InstanceList data, Classifier p, boolean firstIter) {
int numLabels = data.getTargetAlphabet().size();
double[][] base;
if (firstIter) {
base = null;
}
else {
base = new double[data.size()][numLabels];
for (int ii = 0; ii < data.size(); ii++) {
p.classify(data.get(ii)).getLabelVector().addTo(base[ii]);
}
}
PRAuxClassifierOptimizable optimizable = new PRAuxClassifierOptimizable(data,base,q);
LimitedMemoryBFGS bfgs = new LimitedMemoryBFGS(optimizable);
try { bfgs.optimize(); } catch (Exception e) { e.printStackTrace(); }
bfgs.reset();
try { bfgs.optimize(); } catch (Exception e) { e.printStackTrace(); }
return base;
}
}