/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.logging.Logger;
import cc.mallet.classify.constraints.ge.MaxEntGEConstraint;
import cc.mallet.classify.constraints.ge.MaxEntKLFLGEConstraints;
import cc.mallet.classify.constraints.ge.MaxEntL2FLGEConstraints;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
/**
* Training of MaxEnt models with labeled features using
* Generalized Expectation Criteria.
*
* Based on:
* "Learning from Labeled Features using Generalized Expectation Criteria"
* Gregory Druck, Gideon Mann, Andrew McCallum
* SIGIR 2008
*
* @author Gregory Druck <a href="mailto:gdruck@cs.umass.edu">gdruck@cs.umass.edu</a>
*
 * Better explanations of the parameters are given in MaxEntOptimizableByGE.
*/
public class MaxEntGETrainer extends ClassifierTrainer<MaxEnt> implements ClassifierTrainer.ByOptimization<MaxEnt>, Boostable, Serializable {

  private static final long serialVersionUID = 1L;
  private static Logger logger = MalletLogger.getLogger(MaxEntGETrainer.class.getName());
  private static Logger progressLogger = MalletProgressMessageLogger.getLogger(MaxEntGETrainer.class.getName()+"-pl");

  // these are for using this code from the command line
  private boolean l2 = false;          // if true, use an L2 penalty on constraint violations instead of KL divergence
  private boolean normalize = true;    // whether L2 constraints are normalized (only used when l2 == true)
  private boolean useValues = false;   // whether to use feature values rather than just feature presence
  private String constraintsFile;      // file of labeled-feature constraints, read lazily in train(...)
  private int numIterations = 0;       // total optimizer iterations attempted across calls to train(...)
  private int maxIterations = Integer.MAX_VALUE;
  private double temperature = 1;
  private double gaussianPriorVariance = 1;

  protected ArrayList<MaxEntGEConstraint> constraints;
  private InstanceList trainingList = null;
  private MaxEnt classifier = null;
  private MaxEntOptimizableByGE ge = null;
  private Optimizer opt = null;

  public MaxEntGETrainer() {}

  public MaxEntGETrainer(ArrayList<MaxEntGEConstraint> constraints) {
    this.constraints = constraints;
  }

  public MaxEntGETrainer(ArrayList<MaxEntGEConstraint> constraints, MaxEnt classifier) {
    this.constraints = constraints;
    this.classifier = classifier;
  }

  /**
   * Sets the file from which labeled-feature constraints are read.
   * Only consulted by {@link #train} when no constraints were supplied
   * via a constructor.
   */
  public void setConstraintsFile(String filename) {
    this.constraintsFile = filename;
  }

  public void setTemperature(double temp) {
    this.temperature = temp;
  }

  public void setGaussianPriorVariance(double variance) {
    this.gaussianPriorVariance = variance;
  }

  public MaxEnt getClassifier () {
    return classifier;
  }

  public void setUseValues(boolean flag) {
    this.useValues = flag;
  }

  public void setL2(boolean flag) {
    l2 = flag;
  }

  public void setNormalize(boolean normalize) {
    this.normalize = normalize;
  }

  /**
   * Lazily builds (and caches) the GE objective for the given training data.
   * The cached objective is reused on subsequent calls regardless of the
   * argument, matching the original behavior.
   */
  public Optimizable.ByGradientValue getOptimizable (InstanceList trainingList) {
    if (ge == null) {
      ge = new MaxEntOptimizableByGE(trainingList,constraints,classifier);
      ge.setTemperature(temperature);
      ge.setGaussianPriorVariance(gaussianPriorVariance);
    }
    return ge;
  }

  /**
   * Lazily builds (and caches) the optimizer over the cached GE objective.
   * Note: uses the trainingList stored by a prior call to {@link #train}
   * or {@link #getOptimizable}; calling this before either may pass a
   * null InstanceList to MaxEntOptimizableByGE.
   */
  public Optimizer getOptimizer () {
    getOptimizable(trainingList);
    if (opt == null) {
      opt = new LimitedMemoryBFGS(ge);
    }
    return opt;
  }

  public void setOptimizer(Optimizer opt) {
    this.opt = opt;
  }

  /**
   * Specifies the maximum number of iterations to run during a single call
   * to <code>train</code> or <code>trainWithFeatureInduction</code>.
   */
  public void setMaxIterations (int iter) {
    maxIterations = iter;
  }

  /** Returns the total number of optimizer iterations attempted so far. */
  public int getIteration () {
    return numIterations;
  }

  public MaxEnt train (InstanceList trainingList) {
    return train (trainingList, maxIterations);
  }

  /**
   * Trains a MaxEnt classifier on the given data using Generalized
   * Expectation criteria, running the optimizer for at most
   * maxIterations iterations.
   *
   * @param train training instances (used for alphabets and model expectations)
   * @param maxIterations iteration cap for each optimizer run
   * @return the trained classifier
   */
  public MaxEnt train (InstanceList train, int maxIterations) {
    trainingList = train;
    // Constraints supplied via a constructor take precedence over the file.
    if (constraints == null && constraintsFile != null) {
      loadConstraints(train);
    }
    getOptimizable(trainingList);
    getOptimizer();
    if (opt instanceof LimitedMemoryBFGS) {
      ((LimitedMemoryBFGS)opt).reset();
    }
    logger.fine ("trainingList.size() = "+trainingList.size());
    runOptimization(maxIterations);
    if (maxIterations == Integer.MAX_VALUE && opt instanceof LimitedMemoryBFGS) {
      // Run it again because in our and Sam Roweis' experience, BFGS can still
      // eke out more likelihood after first convergence by re-running without
      // being restricted by its gradient history.
      ((LimitedMemoryBFGS)opt).reset();
      runOptimization(maxIterations);
    }
    progressLogger.info("\n"); // progress messages are on one line; move on.
    classifier = ge.getClassifier();
    return classifier;
  }

  /**
   * Reads labeled-feature constraints from {@link #constraintsFile} and
   * wraps them in a single GE constraint object -- L2 or KL depending on
   * the {@link #l2} flag -- stored in {@link #constraints}.
   */
  private void loadConstraints(InstanceList train) {
    HashMap<Integer,double[]> constraintsMap =
      FeatureConstraintUtil.readConstraintsFromFile(constraintsFile, train);
    logger.info("number of constraints: " + constraintsMap.size());
    constraints = new ArrayList<MaxEntGEConstraint>();
    if (l2) {
      MaxEntL2FLGEConstraints geConstraints = new MaxEntL2FLGEConstraints(train.getDataAlphabet().size(),
        train.getTargetAlphabet().size(),useValues,normalize);
      for (int fi : constraintsMap.keySet()) {
        geConstraints.addConstraint(fi, constraintsMap.get(fi), 1);
      }
      constraints.add(geConstraints);
    }
    else {
      MaxEntKLFLGEConstraints geConstraints = new MaxEntKLFLGEConstraints(train.getDataAlphabet().size(),
        train.getTargetAlphabet().size(),useValues);
      for (int fi : constraintsMap.keySet()) {
        geConstraints.addConstraint(fi, constraintsMap.get(fi), 1);
      }
      constraints.add(geConstraints);
    }
  }

  /**
   * Runs the cached optimizer for up to maxIterations iterations,
   * treating any thrown exception as convergence (original behavior).
   * The iteration counter is updated with a saturating add so that the
   * default cap of Integer.MAX_VALUE cannot overflow it to a negative
   * value.
   */
  private void runOptimization(int maxIterations) {
    try {
      opt.optimize(maxIterations);
      numIterations = (int) Math.min((long) numIterations + maxIterations, Integer.MAX_VALUE);
    } catch (Exception e) {
      e.printStackTrace();
      logger.info ("Catching exception; saying converged.");
    }
  }
}