Package edu.stanford.nlp.classify

Source Code of edu.stanford.nlp.classify.LogisticClassifier

// Stanford Classifier - a multiclass maxent classifier
// LogisticClassifier
// Copyright (c) 2003-2007 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
//
// For more information, bug reports, fixes, contact:
//    Christopher Manning
//    Dept of Computer Science, Gates 1A
//    Stanford CA 94305-9010
//    USA
//    Support/Questions: java-nlp-user@lists.stanford.edu
//    Licensing: java-nlp-support@lists.stanford.edu
//    http://www-nlp.stanford.edu/software/classifier.shtml

package edu.stanford.nlp.classify;

import java.io.File;
import java.io.Serializable;
import java.util.*;

import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.ReflectionLoading;
import edu.stanford.nlp.util.StringUtils;

/**
* A classifier for binary logistic regression problems.
* This uses the standard statistics textbook formulation of binary
* logistic regression, which is more efficient than using the
* LinearClassifier class.
*
* @author Galen Andrew
* @author Sarah Spikes (sdspikes@cs.stanford.edu) (Templatization)
* @author Ramesh Nallapati nmramesh@cs.stanford.edu {@link #justificationOf(Collection)}
*
* @param <L> The type of the labels in the Dataset
* @param <F> The type of the features in the Dataset
*/
public class LogisticClassifier<L, F> implements Classifier<L, F>, Serializable, RVFClassifier<L, F> {

  //TODO make it implement ProbabilisticClassifier as well. --Ramesh 12/03/2009.
  /**
   *
   */
  private static final long serialVersionUID = 6672245467246897192L;
  private double[] weights;
  private Index<F> featureIndex;
  private L[] classes = ErasureUtils.<L>mkTArray(Object.class,2);
  @Deprecated
  private LogPrior prior;
  @Deprecated
  private boolean biased = false;

  @Override
  public String toString() {
    if (featureIndex == null) {
      return "";
    }

    StringBuilder sb = new StringBuilder();
    for (F f : featureIndex) {
      sb.append(classes[1]).append(" / ").append(f).append(" = ").append(weights[featureIndex.indexOf(f)]);
    }

    return sb.toString();
  }

  public L getLabelForInternalPositiveClass(){
    return classes[1];
  }

  public L getLabelForInternalNegativeClass(){
    return classes[0];
  }

  // todo [cdm]: This method should be removed, and weightsAsGenericCounter renamed as weightsAsCounter!
  public Counter<String> weightsAsCounter() {
    Counter<String> c = new ClassicCounter<String>();
    for (F f : featureIndex) {
      c.incrementCount(classes[1]+" / "+f, weights[featureIndex.indexOf(f)]);
    }

    return c;
  }

  public Counter<F> weightsAsGenericCounter() {
    Counter<F> c = new ClassicCounter<F>();
    for (F f : featureIndex) {
      double w =  weights[featureIndex.indexOf(f)];
      if(w != 0.0)
        c.setCount(f, w);
    }
    return c;
  }

  public Index<F> getFeatureIndex() {
    return featureIndex;
  }

  public double[] getWeights() {
    return weights;
  }


  public LogisticClassifier(double[] weights, Index<F> featureIndex, L[] classes){
    this.weights = weights;
    this.featureIndex = featureIndex;
    this.classes = classes;
  }


  @Deprecated //use  LogisticClassifierFactory instead
  public LogisticClassifier(boolean biased) {
    this(new LogPrior(LogPrior.LogPriorType.QUADRATIC), biased);
  }

  @Deprecated //use  in LogisticClassifierFactory instead.
  public LogisticClassifier(LogPrior prior) {
    this.prior = prior;
  }


  @Deprecated //use  in LogisticClassifierFactory instead
  public LogisticClassifier(LogPrior prior, boolean biased) {
    this.prior = prior;
    this.biased = biased;
  }

  public Collection<L> labels() {
    Collection<L> l = new LinkedList<L>();
    l.add(classes[0]);
    l.add(classes[1]);
    return l;
  }

  public L classOf(Datum<L, F> datum) {
    if(datum instanceof RVFDatum<?,?>){
      return classOfRVFDatum((RVFDatum<L,F>) datum);
    }
    return classOf(datum.asFeatures());
  }

  @Deprecated //use classOf(Datum) instead.
  public L classOf(RVFDatum<L, F> example) {
    return classOf(example.asFeaturesCounter());
  }

  private L classOfRVFDatum(RVFDatum<L, F> example) {
    return classOf(example.asFeaturesCounter());
  }

  public L classOf(Counter<F> features) {
    if (scoreOf(features) > 0) {
      return classes[1];
    }
    return classes[0];
  }

  public L classOf(Collection<F> features) {
    if (scoreOf(features) > 0) {
      return classes[1];
    }
    return classes[0];
  }


  public double scoreOf(Collection<F> features) {
    double sum = 0;
    for (F feature : features) {
      int f = featureIndex.indexOf(feature);
      if (f >= 0) {
        sum += weights[f];
      }
    }
    return sum;
  }

  public double scoreOf(Counter<F> features) {
    double sum = 0;
    for (F feature : features.keySet()) {
      int f = featureIndex.indexOf(feature);
      if (f >= 0) {
        sum += weights[f]*features.getCount(feature);
      }
    }
    return sum;
  }
  /*
   * returns the weights to each feature assigned by the classifier
   * nmramesh@cs.stanford.edu
   */
  public Counter<F> justificationOf(Counter<F> features){
    Counter<F> fWts = new ClassicCounter<F>();
    for (F feature : features.keySet()) {
      int f = featureIndex.indexOf(feature);
      if (f >= 0) {
        fWts.incrementCount(feature,weights[f]*features.getCount(feature));
      }
    }
    return fWts;
  }
  /**
   * returns the weights assigned by the classifier to each feature
   */
  public Counter<F> justificationOf(Collection<F> features){
    Counter<F> fWts = new ClassicCounter<F>();
    for (F feature : features) {
      int f = featureIndex.indexOf(feature);
      if (f >= 0) {
        fWts.incrementCount(feature,weights[f]);
      }
    }
    return fWts;
  }

  /**
   * returns the scores for both the classes
   */
  public Counter<L> scoresOf(Datum<L, F> datum) {
    if(datum instanceof RVFDatum<?,?>)return scoresOfRVFDatum((RVFDatum<L,F>)datum);
    Collection<F> features = datum.asFeatures();
    double sum = scoreOf(features);
    Counter<L> c = new ClassicCounter<L>();
    c.setCount(classes[0], -sum);
    c.setCount(classes[1], sum);
    return c;
  }


  @Deprecated //use scoresOfDatum(Datum) instead.
  public Counter<L> scoresOf(RVFDatum<L, F> example) {
    return scoresOfRVFDatum(example);
  }


  private Counter<L> scoresOfRVFDatum(RVFDatum<L, F> example) {
    Counter<F> features = example.asFeaturesCounter();
    double sum = scoreOf(features);
    Counter<L> c = new ClassicCounter<L>();
    c.setCount(classes[0], -sum);
    c.setCount(classes[1], sum);
    return c;
  }

  public double probabilityOf(Datum<L, F> example) {
    if (example instanceof RVFDatum<?,?>) {
      return probabilityOfRVFDatum((RVFDatum<L,F>)example);
    }
    return probabilityOf(example.asFeatures(), example.label());
  }

  public double probabilityOf(Collection<F> features, L label) {
    short sign = (short)(label.equals(classes[0]) ? 1 : -1);
    return 1.0 / (1.0 + Math.exp(sign * scoreOf(features)));
  }

  public double probabilityOf(RVFDatum<L, F> example) {
    return probabilityOfRVFDatum(example);
  }

  private double probabilityOfRVFDatum(RVFDatum<L, F> example) {
    return probabilityOf(example.asFeaturesCounter(), example.label());
  }

  public double probabilityOf(Counter<F> features, L label) {
    short sign = (short)(label.equals(classes[0]) ? 1 : -1);
    return 1.0 / (1.0 + Math.exp(sign * scoreOf(features)));
  }

  /**
   * Trains on weighted dataset.
   * @param dataWeights weights of the data.
   */
  @Deprecated //Use LogisticClassifierFactory to train instead.
  public void trainWeightedData(GeneralDataset<L,F> data, float[] dataWeights){
    if (data.labelIndex.size() != 2) {
      throw new RuntimeException("LogisticClassifier is only for binary classification!");
    }

    Minimizer<DiffFunction> minim;
      LogisticObjectiveFunction lof = null;
      if(data instanceof Dataset<?,?>)
        lof = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getLabelsArray(), prior,dataWeights);
      else if(data instanceof RVFDataset<?,?>)
        lof = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getValuesArray(), data.getLabelsArray(), prior,dataWeights);
      minim = new QNMinimizer(lof);
      weights = minim.minimize(lof, 1e-4, new double[data.numFeatureTypes()]);

    featureIndex = data.featureIndex;
    classes[0] = data.labelIndex.get(0);
    classes[1] = data.labelIndex.get(1);
  }

  @Deprecated //Use LogisticClassifierFactory to train instead.
  public void train(GeneralDataset<L, F> data) {
    train(data, 0.0, 1e-4);
  }

  @Deprecated //Use LogisticClassifierFactory to train instead.
  public void train(GeneralDataset<L, F> data, double l1reg, double tol) {
    if (data.labelIndex.size() != 2) {
      throw new RuntimeException("LogisticClassifier is only for binary classification!");
    }

    Minimizer<DiffFunction> minim;
    if (!biased) {
      LogisticObjectiveFunction lof = null;
      if(data instanceof Dataset<?,?>)
        lof = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getLabelsArray(), prior);
      else if(data instanceof RVFDataset<?,?>)
        lof = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getValuesArray(), data.getLabelsArray(), prior);
      if (l1reg > 0.0) {
        minim = ReflectionLoading.loadByReflection("edu.stanford.nlp.optimization.OWLQNMinimizer", l1reg);
      } else {
        minim = new QNMinimizer(lof);
      }
      weights = minim.minimize(lof, tol, new double[data.numFeatureTypes()]);
    } else {
      BiasedLogisticObjectiveFunction lof = new BiasedLogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getLabelsArray(), prior);
      if (l1reg > 0.0) {
        minim = ReflectionLoading.loadByReflection("edu.stanford.nlp.optimization.OWLQNMinimizer", l1reg);
      } else {
        minim = new QNMinimizer(lof);
      }
      weights = minim.minimize(lof, tol, new double[data.numFeatureTypes()]);
    }

    featureIndex = data.featureIndex;
    classes[0] = data.labelIndex.get(0);
    classes[1] = data.labelIndex.get(1);
  }


  public static void main(String[] args) throws Exception {
    Properties prop = StringUtils.argsToProperties(args);

    double l1reg = Double.parseDouble(prop.getProperty("l1reg","0.0"));

    Dataset<String, String> ds = new Dataset<String, String>();
    for (String line : ObjectBank.getLineIterator(new File(prop.getProperty("trainFile")))) {
      String[] bits = line.split("\\s+");
      Collection<String> f = new LinkedList<String>(Arrays.asList(bits).subList(1, bits.length));
      String l = bits[0];
      ds.add(f, l);
    }

    ds.summaryStatistics();

    boolean biased = prop.getProperty("biased", "false").equals("true");
    LogisticClassifierFactory<String, String> factory = new LogisticClassifierFactory<String, String>();
    LogisticClassifier<String, String> lc = factory.trainClassifier(ds, l1reg, 1e-4, biased);


    for (String line : ObjectBank.getLineIterator(new File(prop.getProperty("testFile")))) {
      String[] bits = line.split("\\s+");
      Collection<String> f = new LinkedList<String>(Arrays.asList(bits).subList(1, bits.length));
      //String l = bits[0];
      String g = lc.classOf(f);
      System.out.println(g + '\t' + line);
    }
  }


}
TOP

Related Classes of edu.stanford.nlp.classify.LogisticClassifier

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by ORACLE Inc. Contact coftware#gmail.com.