package edu.stanford.nlp.classify;
import java.util.List;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.ReflectionLoading;
/**
* Builds a classifier for binary logistic regression problems.
* This uses the standard statistics textbook formulation of binary
* logistic regression, which is more efficient than using the
* LinearClassifier class.
*
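* A minimal usage sketch (the feature and label strings below are purely
* illustrative; {@code Dataset} and {@code BasicDatum} come from this package
* and {@code edu.stanford.nlp.ling}, and {@code Arrays} from {@code java.util}):
*
* <pre>{@code
* Dataset<String, String> train = new Dataset<String, String>();
* train.add(new BasicDatum<String, String>(Arrays.asList("f1", "f2"), "pos"));
* train.add(new BasicDatum<String, String>(Arrays.asList("f3"), "neg"));
*
* LogisticClassifierFactory<String, String> factory =
*     new LogisticClassifierFactory<String, String>();
* LogisticClassifier<String, String> lc = factory.trainClassifier(train);
*
* String predicted =
*     lc.classOf(new BasicDatum<String, String>(Arrays.asList("f1")));
* }</pre>
*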
* @author Ramesh Nallapati nmramesh@cs.stanford.edu
*
*/
public class LogisticClassifierFactory<L, F> implements ClassifierFactory<L, F, LogisticClassifier<L, F>> {

  private static final long serialVersionUID = 1L;

  private double[] weights;
  private Index<F> featureIndex;
  private L[] classes = ErasureUtils.<L>mkTArray(Object.class, 2);
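
  /**
   * Trains a binary logistic model in which each datum carries its own weight.
   * A quadratic (L2) prior is always used and the quasi-Newton minimizer is run
   * to a tolerance of 1e-4.
   *
   * @param data a binary {@link GeneralDataset} (exactly two labels)
   * @param dataWeights one weight per datum, parallel to the dataset
   * @return the trained {@link LogisticClassifier}
   */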
  public LogisticClassifier<L, F> trainWeightedData(GeneralDataset<L, F> data, float[] dataWeights) {
    if (data instanceof RVFDataset) {
      ((RVFDataset<L, F>) data).ensureRealValues();
    }
    if (data.labelIndex.size() != 2) {
      throw new RuntimeException("LogisticClassifier is only for binary classification!");
    }

    Minimizer<DiffFunction> minim;
    LogisticObjectiveFunction lof = null;
    if (data instanceof Dataset<?, ?>) {
      lof = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getLabelsArray(), new LogPrior(LogPrior.LogPriorType.QUADRATIC), dataWeights);
    } else if (data instanceof RVFDataset<?, ?>) {
      lof = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getValuesArray(), data.getLabelsArray(), new LogPrior(LogPrior.LogPriorType.QUADRATIC), dataWeights);
    }
    minim = new QNMinimizer(lof);
    weights = minim.minimize(lof, 1e-4, new double[data.numFeatureTypes()]);

    featureIndex = data.featureIndex;
    classes[0] = data.labelIndex.get(0);
    classes[1] = data.labelIndex.get(1);
    return new LogisticClassifier<L, F>(weights, featureIndex, classes);
  }
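
  /**
   * Trains with the default settings: no L1 regularization, tolerance 1e-4,
   * a quadratic (L2) prior, and the unbiased objective function.
   */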
  public LogisticClassifier<L, F> trainClassifier(GeneralDataset<L, F> data) {
    return trainClassifier(data, 0.0);
  }

  public LogisticClassifier<L, F> trainClassifier(GeneralDataset<L, F> data, LogPrior prior, boolean biased) {
    return trainClassifier(data, 0.0, 1e-4, prior, biased);
  }

  public LogisticClassifier<L, F> trainClassifier(GeneralDataset<L, F> data, double l1reg) {
    return trainClassifier(data, l1reg, 1e-4);
  }

  public LogisticClassifier<L, F> trainClassifier(GeneralDataset<L, F> data, double l1reg, double tol) {
    return trainClassifier(data, l1reg, tol, new LogPrior(LogPrior.LogPriorType.QUADRATIC), false);
  }

  public LogisticClassifier<L, F> trainClassifier(GeneralDataset<L, F> data, double l1reg, double tol, LogPrior prior) {
    return trainClassifier(data, l1reg, tol, prior, false);
  }

  public LogisticClassifier<L, F> trainClassifier(GeneralDataset<L, F> data, double l1reg, double tol, boolean biased) {
    return trainClassifier(data, l1reg, tol, new LogPrior(LogPrior.LogPriorType.QUADRATIC), biased);
  }
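
  /**
   * Trains a binary logistic model. When {@code l1reg > 0}, the OWL-QN
   * minimizer is loaded by reflection to apply L1 regularization; otherwise
   * the quasi-Newton {@link QNMinimizer} is used. When {@code biased} is true,
   * a {@link BiasedLogisticObjectiveFunction} is optimized instead of the
   * standard objective.
   *
   * @param data a binary {@link GeneralDataset} (exactly two labels)
   * @param l1reg L1 regularization strength; 0 disables L1 regularization
   * @param tol convergence tolerance passed to the minimizer
   * @param prior log-prior used by the objective function
   * @param biased whether to optimize the biased objective function
   * @return the trained {@link LogisticClassifier}
   */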
  public LogisticClassifier<L, F> trainClassifier(GeneralDataset<L, F> data, double l1reg, double tol, LogPrior prior, boolean biased) {
    if (data instanceof RVFDataset) {
      ((RVFDataset<L, F>) data).ensureRealValues();
    }
    if (data.labelIndex.size() != 2) {
      throw new RuntimeException("LogisticClassifier is only for binary classification!");
    }

    Minimizer<DiffFunction> minim;
    if (!biased) {
      LogisticObjectiveFunction lof = null;
      if (data instanceof Dataset<?, ?>) {
        lof = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getLabelsArray(), prior);
      } else if (data instanceof RVFDataset<?, ?>) {
        lof = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getValuesArray(), data.getLabelsArray(), prior);
      }
      if (l1reg > 0.0) {
        minim = ReflectionLoading.loadByReflection("edu.stanford.nlp.optimization.OWLQNMinimizer", l1reg);
      } else {
        minim = new QNMinimizer(lof);
      }
      weights = minim.minimize(lof, tol, new double[data.numFeatureTypes()]);
    } else {
      BiasedLogisticObjectiveFunction lof = new BiasedLogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getLabelsArray(), prior);
      if (l1reg > 0.0) {
        minim = ReflectionLoading.loadByReflection("edu.stanford.nlp.optimization.OWLQNMinimizer", l1reg);
      } else {
        minim = new QNMinimizer(lof);
      }
      weights = minim.minimize(lof, tol, new double[data.numFeatureTypes()]);
    }

    featureIndex = data.featureIndex;
    classes[0] = data.labelIndex.get(0);
    classes[1] = data.labelIndex.get(1);
    return new LogisticClassifier<L, F>(weights, featureIndex, classes);
  }
  @Deprecated // This method is no longer required by the ClassifierFactory interface.
  public LogisticClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
    // Unimplemented; retained only to satisfy the deprecated ClassifierFactory signature.
    return null;
  }

}