/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.DenseVector;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
/**
 * Maximum Entropy classifier.
 *
 * @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
 */
public class MCMaxEnt extends Classifier implements Serializable
{
  // Flattened weight matrix, row-major by class:
  // parameters[labelIndex * (defaultFeatureIndex+1) + featureIndex].
  double [] parameters;
  // Index of the always-on "default" (bias) feature.  Set to the data
  // alphabet's size at construction time, so each class's parameter row has
  // defaultFeatureIndex+1 entries.  All indexing into `parameters` must use
  // this stride — NOT the current alphabet size, which may have grown since
  // training.
  int defaultFeatureIndex;
  // Feature mask shared by all classes, or null.
  FeatureSelection featureSelection;
  // One feature mask per class, or null.  At most one of featureSelection
  // and perClassFeatureSelection is non-null (asserted in the constructor).
  FeatureSelection[] perClassFeatureSelection;

  // The default feature is always the feature with highest index

  /**
   * Constructs a classifier with the given weights and feature masks.
   *
   * @param dataPipe pipe whose data alphabet defines the feature space
   * @param parameters flattened weights: one row of (dataAlphabet.size()+1)
   *        entries per label, the final entry of each row being the weight of
   *        the default (bias) feature
   * @param featureSelection global feature mask, or null
   * @param perClassFeatureSelection per-label feature masks, or null; must be
   *        null when featureSelection is non-null
   */
  public MCMaxEnt (Pipe dataPipe,
                   double[] parameters,
                   FeatureSelection featureSelection,
                   FeatureSelection[] perClassFeatureSelection)
  {
    super (dataPipe);
    // The two selection mechanisms are mutually exclusive.
    assert (featureSelection == null || perClassFeatureSelection == null);
    this.parameters = parameters;
    this.featureSelection = featureSelection;
    this.perClassFeatureSelection = perClassFeatureSelection;
    this.defaultFeatureIndex = dataPipe.getDataAlphabet().size();
    // assert (parameters.getNumCols() == defaultFeatureIndex+1);
  }

  /** Constructs a classifier with a single global feature mask. */
  public MCMaxEnt (Pipe dataPipe,
                   double[] parameters,
                   FeatureSelection featureSelection)
  {
    this (dataPipe, parameters, featureSelection, null);
  }

  /** Constructs a classifier with per-class feature masks. */
  public MCMaxEnt (Pipe dataPipe,
                   double[] parameters,
                   FeatureSelection[] perClassFeatureSelection)
  {
    this (dataPipe, parameters, null, perClassFeatureSelection);
  }

  /** Constructs a classifier with no feature masks. */
  public MCMaxEnt (Pipe dataPipe,
                   double[] parameters)
  {
    this (dataPipe, parameters, null, null);
  }

  /** Returns the live (not copied) flattened weight array. */
  public double[] getParameters ()
  {
    return parameters;
  }

  /**
   * Sets a single weight.
   *
   * @param classIndex index into the label alphabet
   * @param featureIndex index into the data alphabet, or defaultFeatureIndex
   *        for the default (bias) feature
   * @param value the new weight
   */
  public void setParameter (int classIndex, int featureIndex, double value)
  {
    // Stride by defaultFeatureIndex+1 — the row length the parameters array
    // was allocated with — not getAlphabet().size()+1, which mis-indexes (or
    // overruns the array) if the alphabet has grown since training.
    parameters[classIndex*(defaultFeatureIndex+1) + featureIndex] = value;
  }

  /**
   * Fills {@code scores} with the raw (unnormalized, log-domain) score of
   * each label for the given instance.
   *
   * @param instance instance whose data is a FeatureVector over this
   *        classifier's data alphabet
   * @param scores output array; must have length equal to the label alphabet size
   */
  public void getUnnormalizedClassificationScores (Instance instance, double[] scores)
  {
    // arrayOutOfBounds if pipe has grown since training
    // int numFeatures = getAlphabet().size() + 1;
    int numFeatures = this.defaultFeatureIndex + 1;
    int numLabels = getLabelAlphabet().size();
    assert (scores.length == numLabels);
    FeatureVector fv = (FeatureVector) instance.getData ();
    // Make sure the feature vector's feature dictionary matches
    // what we are expecting from our data pipe (and thus our notion
    // of feature probabilities).
    assert (fv.getAlphabet ()
            == this.instancePipe.getDataAlphabet ());
    // Include the feature weights according to each label: bias weight plus
    // the dot product of the instance's features with the label's weight row,
    // restricted to the applicable feature mask (if any).
    for (int li = 0; li < numLabels; li++) {
      scores[li] = parameters[li*numFeatures + defaultFeatureIndex]
          + MatrixOps.rowDotProduct (parameters, numFeatures,
                                     li, fv,
                                     defaultFeatureIndex,
                                     (perClassFeatureSelection == null
                                      ? featureSelection
                                      : perClassFeatureSelection[li]));
    }
  }

  /**
   * Fills {@code scores} with the posterior probability of each label for the
   * given instance (a softmax over the unnormalized scores).
   *
   * @param instance instance whose data is a FeatureVector over this
   *        classifier's data alphabet
   * @param scores output array; must have length equal to the label alphabet size
   */
  public void getClassificationScores (Instance instance, double[] scores)
  {
    int numLabels = getLabelAlphabet().size();
    assert (scores.length == numLabels);
    FeatureVector fv = (FeatureVector) instance.getData ();
    // Make sure the feature vector's feature dictionary matches
    // what we are expecting from our data pipe (and thus our notion
    // of feature probabilities).
    assert (instancePipe == null || fv.getAlphabet () == this.instancePipe.getDataAlphabet ());
    // arrayOutOfBounds if pipe has grown since training
    // int numFeatures = getAlphabet().size() + 1;
    int numFeatures = this.defaultFeatureIndex + 1;
    // Include the feature weights according to each label
    for (int li = 0; li < numLabels; li++) {
      scores[li] = parameters[li*numFeatures + defaultFeatureIndex]
          + MatrixOps.rowDotProduct (parameters, numFeatures,
                                     li, fv,
                                     defaultFeatureIndex,
                                     (perClassFeatureSelection == null
                                      ? featureSelection
                                      : perClassFeatureSelection[li]));
      // xxxNaN assert (!Double.isNaN(scores[li])) : "li="+li;
    }
    // Move scores to a range where exp() is accurate, and normalize
    // (subtracting the max before exponentiating avoids overflow).
    double max = MatrixOps.max (scores);
    double sum = 0;
    for (int li = 0; li < numLabels; li++)
      sum += (scores[li] = Math.exp (scores[li] - max));
    for (int li = 0; li < numLabels; li++) {
      scores[li] /= sum;
      // xxxNaN assert (!Double.isNaN(scores[li]));
    }
  }

  /**
   * Classifies an instance, returning a Classification whose LabelVector
   * holds the posterior probability of every label.
   */
  public Classification classify (Instance instance)
  {
    int numClasses = getLabelAlphabet().size();
    double[] scores = new double[numClasses];
    getClassificationScores (instance, scores);
    // Create and return a Classification object
    return new Classification (instance, this,
                               new LabelVector (getLabelAlphabet(),
                                                scores));
  }

  /** Prints every feature's weight for every class to System.out. */
  public void print ()
  {
    final Alphabet dict = getAlphabet();
    final LabelAlphabet labelDict = getLabelAlphabet();
    // Stride rows by defaultFeatureIndex+1, matching the trained layout of
    // `parameters`; dict.size()+1 would mis-index if the alphabet has grown
    // since training.
    int numFeatures = defaultFeatureIndex + 1;
    int numLabels = labelDict.size();
    // Include the feature weights according to each label
    for (int li = 0; li < numLabels; li++) {
      System.out.println ("FEATURES FOR CLASS "+labelDict.lookupObject (li));
      System.out.println (" <default> "+parameters [li*numFeatures + defaultFeatureIndex]);
      for (int i = 0; i < defaultFeatureIndex; i++) {
        Object name = dict.lookupObject (i);
        double weight = parameters [li*numFeatures + i];
        System.out.println (" "+name+" "+weight);
      }
    }
  }

  // Serialization.  Format (version 1): serial version int, pipe, parameter
  // count + doubles, defaultFeatureIndex, then each optional FeatureSelection
  // preceded by an int flag (NULL_INTEGER = absent, 1 = present; for the
  // per-class array the first int is the array length or NULL_INTEGER).
  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 1;
  static final int NULL_INTEGER = -1;

  private void writeObject(ObjectOutputStream out) throws IOException
  {
    out.writeInt(CURRENT_SERIAL_VERSION);
    out.writeObject(getInstancePipe());
    int np = parameters.length;
    out.writeInt(np);
    for (int p = 0; p < np; p++)
      out.writeDouble(parameters[p]);
    out.writeInt(defaultFeatureIndex);
    if (featureSelection == null)
      out.writeInt(NULL_INTEGER);
    else
    {
      out.writeInt(1);
      out.writeObject(featureSelection);
    }
    if (perClassFeatureSelection == null)
      out.writeInt(NULL_INTEGER);
    else
    {
      out.writeInt(perClassFeatureSelection.length);
      for (int i = 0; i < perClassFeatureSelection.length; i++)
        if (perClassFeatureSelection[i] == null)
          out.writeInt(NULL_INTEGER);
        else
        {
          out.writeInt(1);
          out.writeObject(perClassFeatureSelection[i]);
        }
    }
  }

  private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
    int version = in.readInt();
    // Refuse to read a stream written by a different serialization version.
    if (version != CURRENT_SERIAL_VERSION)
      throw new ClassNotFoundException("Mismatched MCMaxEnt versions: wanted " +
                                       CURRENT_SERIAL_VERSION + ", got " +
                                       version);
    instancePipe = (Pipe) in.readObject();
    int np = in.readInt();
    parameters = new double[np];
    for (int p = 0; p < np; p++)
      parameters[p] = in.readDouble();
    defaultFeatureIndex = in.readInt();
    int opt = in.readInt();
    if (opt == 1)
      featureSelection = (FeatureSelection)in.readObject();
    int nfs = in.readInt();
    if (nfs >= 0)
    {
      perClassFeatureSelection = new FeatureSelection[nfs];
      for (int i = 0; i < nfs; i++)
      {
        opt = in.readInt();
        if (opt == 1)
          perClassFeatureSelection[i] = (FeatureSelection)in.readObject();
      }
    }
  }
}