Package cc.mallet.classify

Source Code of cc.mallet.classify.MaxEntGETrainer

/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.classify;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.logging.Logger;

import cc.mallet.classify.constraints.ge.MaxEntGEConstraint;
import cc.mallet.classify.constraints.ge.MaxEntKLFLGEConstraints;
import cc.mallet.classify.constraints.ge.MaxEntL2FLGEConstraints;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;

/**
* Training of MaxEnt models with labeled features using
* Generalized Expectation Criteria.
*
* Based on:
* "Learning from Labeled Features using Generalized Expectation Criteria"
* Gregory Druck, Gideon Mann, Andrew McCallum
* SIGIR 2008
*
* @author Gregory Druck <a href="mailto:gdruck@cs.umass.edu">gdruck@cs.umass.edu</a>
*
* Better explanations of the parameters are given in MaxEntOptimizableByGE.
*/

public class MaxEntGETrainer extends ClassifierTrainer<MaxEnt> implements ClassifierTrainer.ByOptimization<MaxEnt>, Boostable, Serializable {

  private static final long serialVersionUID = 1L;
  private static Logger logger = MalletLogger.getLogger(MaxEntGETrainer.class.getName());
  private static Logger progressLogger = MalletProgressMessageLogger.getLogger(MaxEntGETrainer.class.getName()+"-pl");

  // these are for using this code from the command line
  private boolean l2 = false;
  private boolean normalize = true;
  private boolean useValues = false;
  private String constraintsFile;
 
  private int numIterations = 0;
  private int maxIterations = Integer.MAX_VALUE;
  private double temperature = 1;
  private double gaussianPriorVariance = 1;

  protected ArrayList<MaxEntGEConstraint> constraints;
  private InstanceList trainingList = null;
  private MaxEnt classifier = null;
  private MaxEntOptimizableByGE ge = null;
  private Optimizer opt = null;

  public MaxEntGETrainer() {}
 
  public MaxEntGETrainer(ArrayList<MaxEntGEConstraint> constraints) {
    this.constraints = constraints;
  }
 
  public MaxEntGETrainer(ArrayList<MaxEntGEConstraint> constraints, MaxEnt classifier) {
    this.constraints = constraints;
    this.classifier = classifier;
  }
 
  public void setConstraintsFile(String filename) {
    this.constraintsFile = filename;
  }
 
  public void setTemperature(double temp) {
    this.temperature = temp;
  }
 
  public void setGaussianPriorVariance(double variance) {
    this.gaussianPriorVariance = variance;
  }
 
  public MaxEnt getClassifier () {
    return classifier;
  }

  public void setUseValues(boolean flag) {
    this.useValues = flag;
  }
 
  public void setL2(boolean flag) {
    l2 = flag;
  }
 
  public void setNormalize(boolean normalize) {
    this.normalize = normalize;
  }
 
  public Optimizable.ByGradientValue getOptimizable (InstanceList trainingList) {
    if (ge == null) {
      ge = new MaxEntOptimizableByGE(trainingList,constraints,classifier);
      ge.setTemperature(temperature);
      ge.setGaussianPriorVariance(gaussianPriorVariance);
    }
    return ge;
  }

  public Optimizer getOptimizer () {
    getOptimizable(trainingList);
    if (opt == null) {
      opt = new LimitedMemoryBFGS(ge);
    }
    return opt;
  }
 
  public void setOptimizer(Optimizer opt) {
    this.opt = opt;
  }

  /**
   * Specifies the maximum number of iterations to run during a single call
   * to <code>train</code> or <code>trainWithFeatureInduction</code>.
   * @return This trainer
   */
  public void setMaxIterations (int iter) {
    maxIterations = iter;
  }
 
  public int getIteration () {
    return numIterations;
  }

  public MaxEnt train (InstanceList trainingList) {
    return train (trainingList, maxIterations);
  }

  public MaxEnt train (InstanceList train, int maxIterations) {
    trainingList = train;

    if (constraints == null && constraintsFile != null) {
      HashMap<Integer,double[]> constraintsMap =
        FeatureConstraintUtil.readConstraintsFromFile(constraintsFile, trainingList);

      logger.info("number of constraints: " + constraintsMap.size());
      constraints = new ArrayList<MaxEntGEConstraint>();
      if (l2) {
        MaxEntL2FLGEConstraints geConstraints = new MaxEntL2FLGEConstraints(train.getDataAlphabet().size(),
            train.getTargetAlphabet().size(),useValues,normalize);
        for (int fi : constraintsMap.keySet()) {
          geConstraints.addConstraint(fi, constraintsMap.get(fi), 1);
        }
        constraints.add(geConstraints);
      }
      else {
        MaxEntKLFLGEConstraints geConstraints = new MaxEntKLFLGEConstraints(train.getDataAlphabet().size(),
            train.getTargetAlphabet().size(),useValues);
        for (int fi : constraintsMap.keySet()) {
          geConstraints.addConstraint(fi, constraintsMap.get(fi), 1);
        }
        constraints = new ArrayList<MaxEntGEConstraint>();
        constraints.add(geConstraints);
      }
    }
   
    getOptimizable(trainingList);
    getOptimizer();
   
    if (opt instanceof LimitedMemoryBFGS) {
      ((LimitedMemoryBFGS)opt).reset();
    }   
   
    logger.fine ("trainingList.size() = "+trainingList.size());

    try {
      opt.optimize(maxIterations);
      numIterations += maxIterations;
    } catch (Exception e) {
      e.printStackTrace();
      logger.info ("Catching exception; saying converged.");
    }

    if (maxIterations == Integer.MAX_VALUE && opt instanceof LimitedMemoryBFGS) {
      // Run it again because in our and Sam Roweis' experience, BFGS can still
      // eke out more likelihood after first convergence by re-running without
      // being restricted by its gradient history.
      ((LimitedMemoryBFGS)opt).reset();
      try {
        opt.optimize(maxIterations);
        numIterations += maxIterations;
      } catch (Exception e) {
        e.printStackTrace();
        logger.info ("Catching exception; saying converged.");
      }
    }
    progressLogger.info("\n"); //  progress messages are on one line; move on.
   
    classifier = ge.getClassifier();
    return classifier;
  }
}
TOP

Related Classes of cc.mallet.classify.MaxEntGETrainer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.