Package cc.mallet.classify

Source Code of cc.mallet.classify.BalancedWinnowTrainer

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */


package cc.mallet.classify;

import java.io.Serializable;
import java.util.Arrays;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;



/**
* An implementation of the training methods of a BalancedWinnow
* on-line classifier. Given a labeled instance (x, y) the algorithm
* computes dot(x, wi), for w1, ... , wc where wi is the weight
* vector for class i.  The instance is classified as class j
* if the value of dot(x, wj) is the largest among the c dot
* products.
*
* <p>The weight vectors are updated whenever the the classifier
* makes a mistake or just barely got the correct answer (highest
* dot product is within delta percent higher than the second highest).
* Suppose the classifier guessed j and answer was j'. For each
* feature i that is present, multiply w_ji by (1-epsilon) and
* multiply w_j'i by (1+epsilon)
*
* <p>The above procedure is done multiple times to the training
* examples (default is 5), and epsilon is cut by the cooling
* rate at each iteration (default is cutting epsilon by half).
*
* @author Gary Huang <a href="mailto:ghuang@cs.umass.edu">ghuang@cs.umass.edu</a>
*/
public class BalancedWinnowTrainer extends ClassifierTrainer<BalancedWinnow> implements Boostable, Serializable
{
  private static final long serialVersionUID = 1L;

  /**
   * 0.5
   */
  public static final double DEFAULT_EPSILON = .5;
  /**
   * 0.1
   */
  public static final double DEFAULT_DELTA = .1;
  /**
   * 30
   */
  public static final int DEFAULT_MAX_ITERATIONS = 30;
  /**
   * 0.5
   */
  public static final double DEFAULT_COOLING_RATE = .5;

  double m_epsilon;
  double m_delta;
  int m_maxIterations;
  double m_coolingRate;
  /**
   * Array of weights, one for each class and feature, initialized to 1.
   * For each class, there is an additional default "feature" weight
   * that is set to 1 in every example (it remains constant; this is
   * used to prevent the instance from having 0 dot product with a class).
   */
  double[][] m_weights;
 
  BalancedWinnow classifier;
  public BalancedWinnow getClassifier () { return classifier; }


  /**
   * Default constructor. Sets all features to defaults.
   */
  public BalancedWinnowTrainer()
  {
    this(DEFAULT_EPSILON, DEFAULT_DELTA, DEFAULT_MAX_ITERATIONS,  DEFAULT_COOLING_RATE);
  }

  /**
   * @param epsilon percentage by which to increase/decrease weight vectors
   * when an example is misclassified.
   * @param delta percentage by which the highest (and correct) dot product
   * should exceed the second highest dot product before we consider an example
   * to be correctly classified (margin width) when adjusting weights.
   * @param maxIterations maximum number of times to loop through training examples.
   * @param coolingRate percentage of epsilon to decrease after each iteration
   */
  public BalancedWinnowTrainer(double epsilon,
      double delta,
      int maxIterations,
      double coolingRate)
  {
    m_epsilon = epsilon;
    m_delta = delta;
    m_maxIterations = maxIterations;
    m_coolingRate = coolingRate;
  }

  /**
   * Trains the classifier on the instance list, updating
   * class weight vectors as appropriate
   * @param trainingList Instance list to be trained on
   * @return Classifier object containing learned weights
   */
  public BalancedWinnow train (InstanceList trainingList)
  {
    FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
    if (selectedFeatures != null)
      // xxx Attend to FeatureSelection!!!
      throw new UnsupportedOperationException ("FeatureSelection not yet implemented.");

    double epsilon = m_epsilon;
    Alphabet dict = (Alphabet) trainingList.getDataAlphabet ();
    int numLabels = trainingList.getTargetAlphabet().size();
    int numFeats = dict.size();
    m_weights = new double [numLabels][numFeats+1];

    // init weights to 1
    for(int i = 0; i < numLabels; i++)
      Arrays.fill(m_weights[i], 1.0);

    // Loop through training instances multiple times
    double[] results = new double[numLabels];
    for (int iter = 0; iter < m_maxIterations; iter++) {

      // loop through all instances
      for (int ii = 0; ii < trainingList.size(); ii++) {
        Instance inst = trainingList.get(ii);
        Labeling labeling = inst.getLabeling ();
        FeatureVector fv = (FeatureVector) inst.getData();
        int fvisize = fv.numLocations();
        int correctIndex = labeling.getBestIndex();
        Arrays.fill(results, 0);

        // compute dot(x, wi) for each class i
        for(int lpos = 0; lpos < numLabels; lpos++) {
          for(int fvi = 0; fvi < fvisize; fvi++) {
            int fi = fv.indexAtLocation(fvi);
            double vi = fv.valueAtLocation(fvi);
            results[lpos] += vi * m_weights[lpos][fi];
          }

          // This extra value comes from the extra
          // "feature" present in all examples
          results[lpos] += m_weights[lpos][numFeats];
        }

        // Get indices of the classes with the 2 highest dot products
        int predictedIndex = 0;
        int secondHighestIndex = 0;
        double max = Double.MIN_VALUE;
        double secondMax = Double.MIN_VALUE;
        for (int i = 0; i < numLabels; i++) {
          if (results[i] > max) {
            secondMax = max;
            max = results[i];
            secondHighestIndex = predictedIndex;
            predictedIndex = i;
          }
          else if (results[i] > secondMax) {
            secondMax = results[i];
            secondHighestIndex = i;
          }
        }

        // Adjust weights if this example is mispredicted
        // or just barely correct
        if (predictedIndex != correctIndex) {
          for (int fvi = 0; fvi < fvisize; fvi++) {
            int fi = fv.indexAtLocation(fvi);
            m_weights[predictedIndex][fi] *= (1 - epsilon);
            m_weights[correctIndex][fi] *= (1 + epsilon);
          }
          m_weights[predictedIndex][numFeats] *= (1 - epsilon);
          m_weights[correctIndex][numFeats] *= (1 + epsilon);
        }
        else if (max/secondMax - 1 < m_delta) {
          for (int fvi = 0; fvi < fvisize; fvi++) {
            int fi = fv.indexAtLocation(fvi);
            m_weights[secondHighestIndex][fi] *= (1 - epsilon);
            m_weights[correctIndex][fi] *= (1 + epsilon);
          }
          m_weights[secondHighestIndex][numFeats] *= (1 - epsilon);
          m_weights[correctIndex][numFeats] *= (1 + epsilon);
        }
      }
      // Cut epsilon by the cooling rate
      epsilon *= (1-m_coolingRate);
    }       
    this.classifier = new BalancedWinnow (trainingList.getPipe(), m_weights);
    return classifier;
  }

}
TOP

Related Classes of cc.mallet.classify.BalancedWinnowTrainer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.