/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify;
import java.io.Serializable;
import java.util.Arrays;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
/**
* An implementation of the training methods of a BalancedWinnow
* on-line classifier. Given a labeled instance (x, y), the algorithm
* computes dot(x, w_i) for w_1, ..., w_c, where w_i is the weight
* vector for class i. The instance is classified as class j
* if dot(x, w_j) is the largest among the c dot
* products.
*
* <p>The weight vectors are updated whenever the classifier
* makes a mistake or is only barely correct (the highest dot
* product exceeds the second highest by a relative margin of
* less than delta). Suppose the classifier guessed j and the
* answer was j'. For each feature i that is present, multiply
* w_ji by (1-epsilon) and multiply w_j'i by (1+epsilon). When
* the answer is correct but within the margin, the runner-up
* class is demoted instead of the (correct) winner.
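* For instance, with the default epsilon of 0.5, each active weight of
* the wrongly predicted class is multiplied by 0.5 and each active
* weight of the correct class by 1.5.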
*
* <p>The above procedure is repeated over the training examples
* for multiple iterations (default is 30), and epsilon is reduced
* by the cooling rate after each iteration (by default, epsilon
* is halved).
*
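* <p>A minimal usage sketch; the training {@code InstanceList} and the
* test {@code Instance} are assumed to be built elsewhere with a
* suitable pipe:
* <pre>{@code
* InstanceList training = ...;
* BalancedWinnowTrainer trainer = new BalancedWinnowTrainer();
* BalancedWinnow classifier = trainer.train(training);
* Classification result = classifier.classify(testInstance);
* }</pre>
*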
* @author Gary Huang <a href="mailto:ghuang@cs.umass.edu">ghuang@cs.umass.edu</a>
*/
public class BalancedWinnowTrainer extends ClassifierTrainer<BalancedWinnow> implements Boostable, Serializable
{
private static final long serialVersionUID = 1L;
/**
* The default promotion/demotion step, epsilon: 0.5
*/
public static final double DEFAULT_EPSILON = .5;
/**
* The default margin width, delta: 0.1
*/
public static final double DEFAULT_DELTA = .1;
/**
* The default maximum number of training iterations: 30
*/
public static final int DEFAULT_MAX_ITERATIONS = 30;
/**
* The default cooling rate: 0.5 (epsilon is halved after each iteration)
*/
public static final double DEFAULT_COOLING_RATE = .5;
double m_epsilon;
double m_delta;
int m_maxIterations;
double m_coolingRate;
/**
* Array of weights, one for each class and feature, initialized to 1.
* For each class, there is an additional default "feature" whose value
* is a constant 1 in every example; the feature value is constant, but
* its weight is trained like any other. This prevents an instance from
* having a 0 dot product with a class.
*/
double[][] m_weights;
BalancedWinnow classifier;
public BalancedWinnow getClassifier () { return classifier; }
/**
* Default constructor. Sets all parameters to their default values.
*/
public BalancedWinnowTrainer()
{
this(DEFAULT_EPSILON, DEFAULT_DELTA, DEFAULT_MAX_ITERATIONS, DEFAULT_COOLING_RATE);
}
/**
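* Constructs a trainer with the given training parameters. A sketch,
* with illustrative (not recommended) values:
* <pre>{@code
* // smaller steps, a wider margin, and slower cooling than the defaults
* BalancedWinnowTrainer trainer = new BalancedWinnowTrainer(0.25, 0.2, 50, 0.25);
* }</pre>
*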
* @param epsilon fraction by which to promote/demote weight vectors
* when an example is misclassified or only barely correct.
* @param delta margin width: the relative amount by which the highest
* (and correct) dot product must exceed the second highest for an
* example to count as confidently correct and leave the weights unchanged.
* @param maxIterations maximum number of passes over the training examples.
* @param coolingRate fraction by which epsilon is decreased after each iteration.
*/
public BalancedWinnowTrainer(double epsilon,
double delta,
int maxIterations,
double coolingRate)
{
m_epsilon = epsilon;
m_delta = delta;
m_maxIterations = maxIterations;
m_coolingRate = coolingRate;
}
/**
* Trains the classifier on the instance list, updating
* class weight vectors as appropriate.
* @param trainingList Instance list to be trained on
* @return Classifier object containing learned weights
*/
public BalancedWinnow train (InstanceList trainingList)
{
FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
if (selectedFeatures != null)
// xxx Attend to FeatureSelection!!!
throw new UnsupportedOperationException ("FeatureSelection not yet implemented.");
// Anneal a local copy of epsilon so the configured value survives repeated calls
double epsilon = m_epsilon;
Alphabet dict = (Alphabet) trainingList.getDataAlphabet ();
int numLabels = trainingList.getTargetAlphabet().size();
int numFeats = dict.size();
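// one extra slot per class for the always-on default "feature" (bias term)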
m_weights = new double [numLabels][numFeats+1];
// init weights to 1
for(int i = 0; i < numLabels; i++)
Arrays.fill(m_weights[i], 1.0);
// Loop through training instances multiple times
double[] results = new double[numLabels];
for (int iter = 0; iter < m_maxIterations; iter++) {
// loop through all instances
for (int ii = 0; ii < trainingList.size(); ii++) {
Instance inst = trainingList.get(ii);
Labeling labeling = inst.getLabeling ();
FeatureVector fv = (FeatureVector) inst.getData();
int fvisize = fv.numLocations();
int correctIndex = labeling.getBestIndex();
Arrays.fill(results, 0);
// compute dot(x, wi) for each class i
for(int lpos = 0; lpos < numLabels; lpos++) {
for(int fvi = 0; fvi < fvisize; fvi++) {
int fi = fv.indexAtLocation(fvi);
double vi = fv.valueAtLocation(fvi);
results[lpos] += vi * m_weights[lpos][fi];
}
// This extra value comes from the extra
// "feature" present in all examples
results[lpos] += m_weights[lpos][numFeats];
}
// Get indices of the classes with the 2 highest dot products
int predictedIndex = 0;
int secondHighestIndex = 0;
// Double.MIN_VALUE is the smallest positive double, not the most negative,
// so it is not a safe "unset" sentinel; use negative infinity instead
double max = Double.NEGATIVE_INFINITY;
double secondMax = Double.NEGATIVE_INFINITY;
for (int i = 0; i < numLabels; i++) {
if (results[i] > max) {
secondMax = max;
max = results[i];
secondHighestIndex = predictedIndex;
predictedIndex = i;
}
else if (results[i] > secondMax) {
secondMax = results[i];
secondHighestIndex = i;
}
}
// Adjust weights if this example is mispredicted
// or just barely correct
if (predictedIndex != correctIndex) {
for (int fvi = 0; fvi < fvisize; fvi++) {
int fi = fv.indexAtLocation(fvi);
m_weights[predictedIndex][fi] *= (1 - epsilon);
m_weights[correctIndex][fi] *= (1 + epsilon);
}
m_weights[predictedIndex][numFeats] *= (1 - epsilon);
m_weights[correctIndex][numFeats] *= (1 + epsilon);
}
// Correct, but the winning margin is below delta: demote the
// runner-up and promote the correct class
else if (max/secondMax - 1 < m_delta) {
for (int fvi = 0; fvi < fvisize; fvi++) {
int fi = fv.indexAtLocation(fvi);
m_weights[secondHighestIndex][fi] *= (1 - epsilon);
m_weights[correctIndex][fi] *= (1 + epsilon);
}
m_weights[secondHighestIndex][numFeats] *= (1 - epsilon);
m_weights[correctIndex][numFeats] *= (1 + epsilon);
}
}
// Cut epsilon by the cooling rate
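// (with the default rate of 0.5, epsilon halves each pass: 0.5, 0.25, 0.125, ...)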
epsilon *= (1-m_coolingRate);
}
this.classifier = new BalancedWinnow (trainingList.getPipe(), m_weights);
return classifier;
}
}