Package classification

Source Code of classification.MaxEntropy$Objective

package classification;

import java.util.ArrayList;

import algo.ConjugateGradient;
import algo.GradientAscent;
import types.Alphabet;
import types.ClassificationInstance;
import types.DifferentiableObjective;
import types.FeatureFunction;
import types.LinearClassifier;
import types.StaticUtils;

public class MaxEntropy {

  double gaussianPriorVariance;
  double numObservations;
  Alphabet xAlphabet;
  Alphabet yAlphabet;
  FeatureFunction fxy;

  public MaxEntropy(double gaussianPriorVariance, Alphabet xAlphabet,
      Alphabet yAlphabet, FeatureFunction fxy) {
    this.gaussianPriorVariance = gaussianPriorVariance;
    this.xAlphabet = xAlphabet;
    this.yAlphabet = yAlphabet;
    this.fxy = fxy;
  }

  public LinearClassifier batchTrain(
      ArrayList<ClassificationInstance> trainingData) {
    Objective obj = new Objective(trainingData);
    // perform gradient descent
    @SuppressWarnings("unused")
    GradientAscent gaoptimizer;
    ConjugateGradient optimizer = new ConjugateGradient(obj
        .getNumParameters());
    @SuppressWarnings("unused")
    boolean success = optimizer.maximize(obj);
    System.out.println("valCalls = " + obj.numValueCalls
        + "   gradientCalls=" + obj.numGradientCalls);
    return obj.classifier;
  }

  /**
   * An objective for our max-ent model. That is: max_\lambda sum_i log
   * Pr(y_i|x_i) - 1/var * ||\lambda||^2 where var is the Gaussian prior
   * variance, and p(y|x) = exp(f(x,y)*lambda)/Z(x).
   *
   * @author kuzman
   *
   */
  class Objective implements DifferentiableObjective {
    double[] empiricalExpectations;
    LinearClassifier classifier;
    ArrayList<ClassificationInstance> trainingData;
    int numValueCalls = 0;
    int numGradientCalls = 0;

    Objective(ArrayList<ClassificationInstance> trainingData) {
      this.trainingData = trainingData;
      // compute empirical expectations...
      empiricalExpectations = new double[fxy.wSize()];
      for (ClassificationInstance inst : trainingData) {
        StaticUtils.plusEquals(empiricalExpectations, fxy.apply(inst.x,
            inst.y));
      }
      classifier = new LinearClassifier(xAlphabet, yAlphabet, fxy);
    }

    public double getValue() {
      numValueCalls++;
      // value = log(prob(data)) - 1/gaussianPriorVariance * ||lambda||^2
      double val = 0;
      for (ClassificationInstance inst : trainingData) {
        double[] scores = classifier.scores(inst.x);
        double[] probs = StaticUtils.exp(scores);
        double Z = StaticUtils.sum(probs);
        val += scores[inst.y] - Math.log(Z);
      }
      val -= 1 / (2 * gaussianPriorVariance)
          * StaticUtils.twoNormSquared(classifier.w);
      return val;
    }

    public void getGradient(double[] gradient) {
      numGradientCalls++;
      // gradient = empiricalExpectations - modelExpectations
      // -2/gaussianPriorVariance * params
      double[] modelExpectations = new double[gradient.length];
      for (int i = 0; i < gradient.length; i++) {
        gradient[i] = empiricalExpectations[i];
        modelExpectations[i] = 0;
      }
      for (ClassificationInstance inst : trainingData) {
        double[] scores = classifier.scores(inst.x);
        double[] probs = StaticUtils.exp(scores);
        double Z = StaticUtils.sum(probs);
        for (int y = 0; y < yAlphabet.size(); y++) {
          StaticUtils.plusEquals(modelExpectations, fxy.apply(inst.x,
              y), probs[y] / Z);
        }
      }
      for (int i = 0; i < gradient.length; i++) {
        gradient[i] -= modelExpectations[i];
        gradient[i] -= 1 / gaussianPriorVariance * classifier.w[i];
      }

    }

    public void setParameters(double[] newParameters) {
      System.arraycopy(newParameters, 0, classifier.w, 0,
          newParameters.length);
    }

    public void getParameters(double[] params) {
      System.arraycopy(classifier.w, 0, params, 0, params.length);
    }

    public int getNumParameters() {
      return classifier.w.length;
    }

  }

}
TOP

Related Classes of classification.MaxEntropy$Objective

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.