Package edu.stanford.nlp.classify

Source Code of edu.stanford.nlp.classify.NaiveBayesClassifierFactory

// Stanford Classifier - a multiclass maxent classifier
// NaiveBayesClassifierFactory
// Copyright (c) 2003-2007 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
//
// For more information, bug reports, fixes, contact:
//    Christopher Manning
//    Dept of Computer Science, Gates 1A
//    Stanford CA 94305-9010
//    USA
//    Support/Questions: java-nlp-user@lists.stanford.edu
//    Licensing: java-nlp-support@lists.stanford.edu
//    http://www-nlp.stanford.edu/software/classifier.shtml

package edu.stanford.nlp.classify;

import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.HashIndex;

import java.util.*;


/** Creates a NaiveBayesClassifier given an RVFDataset.
*
* @author Kristina Toutanova (kristina@cs.stanford.edu)
*/
public class NaiveBayesClassifierFactory<L, F> implements ClassifierFactory<L, F, NaiveBayesClassifier<L, F>> {

  private static final long serialVersionUID = -8164165428834534041L;
  public static final int JL = 0;
  public static final int CL = 1;
  public static final int UCL = 2;
  int kind = JL;
  double alphaClass;
  double alphaFeature;
  double sigma;
  int prior = LogPrior.LogPriorType.NULL.ordinal();
  Index<L> labelIndex;
  Index<F> featureIndex;

  public NaiveBayesClassifierFactory() {
  }

  public NaiveBayesClassifierFactory(double alphaC, double alphaF, double sigma, int prior, int kind) {
    alphaClass = alphaC;
    alphaFeature = alphaF;
    this.sigma = sigma;
    this.prior = prior;
    this.kind = kind;
  }

  private NaiveBayesClassifier<L, F> trainClassifier(int[][] data, int[] labels, int numFeatures,
      int numClasses, Index<L> labelIndex, Index<F> featureIndex) {
    Set<L> labelSet = Generics.newHashSet();
    NBWeights nbWeights = trainWeights(data, labels, numFeatures, numClasses);
    Counter<L> priors = new ClassicCounter<L>();
    double[] pr = nbWeights.priors;
    for (int i = 0; i < pr.length; i++) {
      priors.incrementCount(labelIndex.get(i), pr[i]);
      labelSet.add(labelIndex.get(i));
    }
    Counter<Pair<Pair<L, F>, Number>> weightsCounter = new ClassicCounter<Pair<Pair<L, F>, Number>>();
    double[][][] wts = nbWeights.weights;
    for (int c = 0; c < numClasses; c++) {
      L label = labelIndex.get(c);
      for (int f = 0; f < numFeatures; f++) {
        F feature = featureIndex.get(f);
        Pair<L, F> p = new Pair<L, F>(label, feature);
        for (int val = 0; val < wts[c][f].length; val++) {
          Pair<Pair<L, F>, Number> key = new Pair<Pair<L, F>, Number>(p, Integer.valueOf(val));
          weightsCounter.incrementCount(key, wts[c][f][val]);
        }
      }
    }
    return new NaiveBayesClassifier<L, F>(weightsCounter, priors, labelSet);

  }

  /**
   * The examples are assumed to be a list of RFVDatum.
   * The datums are assumed to contain the zeroes as well.
   */
  @Override
  @Deprecated
  public NaiveBayesClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
    RVFDatum<L, F> d0 = examples.get(0);
    int numFeatures = d0.asFeatures().size();
    int[][] data = new int[examples.size()][numFeatures];
    int[] labels = new int[examples.size()];
    labelIndex = new HashIndex<L>();
    featureIndex = new HashIndex<F>();
    for (int d = 0; d < examples.size(); d++) {
      RVFDatum<L, F> datum = examples.get(d);
      Counter<F> c = datum.asFeaturesCounter();
      for (F feature: c.keySet()) {
        if(featureIndex.add(feature)) {
          int fNo = featureIndex.indexOf(feature);
          int value = (int) c.getCount(feature);
          data[d][fNo] = value;
        }
      }
      labelIndex.add(datum.label());
      labels[d] = labelIndex.indexOf(datum.label());

    }
    int numClasses = labelIndex.size();
    return trainClassifier(data, labels, numFeatures, numClasses, labelIndex, featureIndex);
  }


  /**
   * The examples are assumed to be a list of RFVDatum.
   * The datums are assumed to not contain the zeroes and then they are added to each instance.
   */
  public NaiveBayesClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples, Set<F> featureSet) {
    int numFeatures = featureSet.size();
    int[][] data = new int[examples.size()][numFeatures];
    int[] labels = new int[examples.size()];
    labelIndex = new HashIndex<L>();
    featureIndex = new HashIndex<F>();
    for (F feat : featureSet) {
      featureIndex.add(feat);
    }
    for (int d = 0; d < examples.size(); d++) {
      RVFDatum<L, F> datum = examples.get(d);
      Counter<F> c = datum.asFeaturesCounter();
      for (F feature : c.keySet()) {
        int fNo = featureIndex.indexOf(feature);
        int value = (int) c.getCount(feature);
        data[d][fNo] = value;
      }
      labelIndex.add(datum.label());
      labels[d] = labelIndex.indexOf(datum.label());

    }
    int numClasses = labelIndex.size();
    return trainClassifier(data, labels, numFeatures, numClasses, labelIndex, featureIndex);
  }


  /**
   * Here the data is assumed to be for every instance, array of length numFeatures
   * and the value of the feature is stored including zeroes.
   *
   * @return label,fno,value -> weight
   */
  private NBWeights trainWeights(int[][] data, int[] labels, int numFeatures, int numClasses) {
    if (kind == JL) {
      return trainWeightsJL(data, labels, numFeatures, numClasses);
    }
    if (kind == UCL) {
      return trainWeightsUCL(data, labels, numFeatures, numClasses);
    }
    if (kind == CL) {
      return trainWeightsCL(data, labels, numFeatures, numClasses);
    }
    return null;
  }

  private NBWeights trainWeightsJL(int[][] data, int[] labels, int numFeatures, int numClasses) {
    int[] numValues = numberValues(data, numFeatures);
    double[] priors = new double[numClasses];
    double[][][] weights = new double[numClasses][numFeatures][];
    //init weights array
    for (int cl = 0; cl < numClasses; cl++) {
      for (int fno = 0; fno < numFeatures; fno++) {
        weights[cl][fno] = new double[numValues[fno]];
      }
    }
    for (int i = 0; i < data.length; i++) {
      priors[labels[i]]++;
      for (int fno = 0; fno < numFeatures; fno++) {
        weights[labels[i]][fno][data[i][fno]]++;
      }
    }
    for (int cl = 0; cl < numClasses; cl++) {
      for (int fno = 0; fno < numFeatures; fno++) {
        for (int val = 0; val < numValues[fno]; val++) {
          weights[cl][fno][val] = Math.log((weights[cl][fno][val] + alphaFeature) / (priors[cl] + alphaFeature * numValues[fno]));
        }
      }
      priors[cl] = Math.log((priors[cl] + alphaClass) / (data.length + alphaClass * numClasses));
    }
    return new NBWeights(priors, weights);
  }

  private NBWeights trainWeightsUCL(int[][] data, int[] labels, int numFeatures, int numClasses) {
    int[] numValues = numberValues(data, numFeatures);
    int[] sumValues = new int[numFeatures]; //how many feature-values are before this feature
    for (int j = 1; j < numFeatures; j++) {
      sumValues[j] = sumValues[j - 1] + numValues[j - 1];
    }
    int[][] newdata = new int[data.length][numFeatures + 1];
    for (int i = 0; i < data.length; i++) {
      newdata[i][0] = 0;
      for (int j = 0; j < numFeatures; j++) {
        newdata[i][j + 1] = sumValues[j] + data[i][j] + 1;
      }
    }
    int totalFeatures = sumValues[numFeatures - 1] + numValues[numFeatures - 1] + 1;
    System.err.println("total feats " + totalFeatures);
    LogConditionalObjectiveFunction<L, F> objective = new LogConditionalObjectiveFunction<L, F>(totalFeatures, numClasses, newdata, labels, prior, sigma, 0.0);
    Minimizer<DiffFunction> min = new QNMinimizer();
    double[] argmin = min.minimize(objective, 1e-4, objective.initial());
    double[][] wts = objective.to2D(argmin);
    System.out.println("weights have dimension " + wts.length);
    return new NBWeights(wts, numValues);
  }


  private NBWeights trainWeightsCL(int[][] data, int[] labels, int numFeatures, int numClasses) {

    LogConditionalEqConstraintFunction objective = new LogConditionalEqConstraintFunction(numFeatures, numClasses, data, labels, prior, sigma, 0.0);
    Minimizer<DiffFunction> min = new QNMinimizer();
    double[] argmin = min.minimize(objective, 1e-4, objective.initial());
    double[][][] wts = objective.to3D(argmin);
    double[] priors = objective.priors(argmin);
    return new NBWeights(priors, wts);
  }

  static int[] numberValues(int[][] data, int numFeatures) {
    int[] numValues = new int[numFeatures];
    for (int i = 0; i < data.length; i++) {
      for (int j = 0; j < data[i].length; j++) {
        if (numValues[j] < data[i][j] + 1) {
          numValues[j] = data[i][j] + 1;
        }
      }
    }
    return numValues;
  }

  static class NBWeights {
    double[] priors;
    double[][][] weights;

    NBWeights(double[] priors, double[][][] weights) {
      this.priors = priors;
      this.weights = weights;
    }

    /**
     * create the parameters from a coded representation
     * where feature 0 is the prior etc.
     *
     */
    NBWeights(double[][] wts, int[] numValues) {
      int numClasses = wts[0].length;
      priors = new double[numClasses];
      System.arraycopy(wts[0], 0, priors, 0, numClasses);
      int[] sumValues = new int[numValues.length];
      for (int j = 1; j < numValues.length; j++) {
        sumValues[j] = sumValues[j - 1] + numValues[j - 1];
      }
      weights = new double[priors.length][sumValues.length][];
      for (int fno = 0; fno < numValues.length; fno++) {
        for (int c = 0; c < numClasses; c++) {
          weights[c][fno] = new double[numValues[fno]];
        }

        for (int val = 0; val < numValues[fno]; val++) {
          int code = sumValues[fno] + val + 1;
          for (int cls = 0; cls < numClasses; cls++) {
            weights[cls][fno][val] = wts[code][cls];
          }
        }
      }
    }
  }

  public static void main(String[] args) {

    /*
    List examples = new ArrayList();
    String leftLight = "leftLight";
    String rightLight = "rightLight";
    String broken = "BROKEN";
    String ok = "OK";
    Counter c1 = new Counter();
    c1.incrementCount(leftLight, 0);
    c1.incrementCount(rightLight, 0);
    RVFDatum d1 = new RVFDatum(c1, broken);
    examples.add(d1);
    Counter c2 = new Counter();
    c2.incrementCount(leftLight, 1);
    c2.incrementCount(rightLight, 1);
    RVFDatum d2 = new RVFDatum(c2, ok);
    examples.add(d2);
    Counter c3 = new Counter();
    c3.incrementCount(leftLight, 0);
    c3.incrementCount(rightLight, 1);
    RVFDatum d3 = new RVFDatum(c3, ok);
    examples.add(d3);
    Counter c4 = new Counter();
    c4.incrementCount(leftLight, 1);
    c4.incrementCount(rightLight, 0);
    RVFDatum d4 = new RVFDatum(c4, ok);
    examples.add(d4);
    NaiveBayesClassifier classifier = (NaiveBayesClassifier) new NaiveBayesClassifierFactory(200, 200, 1.0, LogPrior.QUADRATIC.ordinal(), NaiveBayesClassifierFactory.CL).trainClassifier(examples);
    classifier.print();
    //now classifiy
    for (int i = 0; i < examples.size(); i++) {
        RVFDatum d = (RVFDatum) examples.get(i);
        Counter scores = classifier.scoresOf(d);
        System.out.println("for datum " + d + " scores are " + scores.toString());
        System.out.println(" class is " + scores.argmax());
    }

}
*/
    String trainFile = args[0];
    String testFile = args[1];
    NominalDataReader nR = new NominalDataReader();
    Map<Integer, Index<String>> indices = Generics.newHashMap();
    List<RVFDatum<String, Integer>> train = nR.readData(trainFile, indices);
    List<RVFDatum<String, Integer>> test = nR.readData(testFile, indices);
    System.out.println("Constrained conditional likelihood no prior :");
    for (int j = 0; j < 100; j++) {
      NaiveBayesClassifier<String, Integer> classifier = new NaiveBayesClassifierFactory<String, Integer>(0.1, 0.01, 0.6, LogPrior.LogPriorType.NULL.ordinal(), NaiveBayesClassifierFactory.CL).trainClassifier(train);
      classifier.print();
      //now classifiy

      float accTrain = classifier.accuracy(train.iterator());
      System.err.println("training accuracy " + accTrain);
      float accTest = classifier.accuracy(test.iterator());
      System.err.println("test accuracy " + accTest);

    }
    System.out.println("Unconstrained conditional likelihood no prior :");
    for (int j = 0; j < 100; j++) {
      NaiveBayesClassifier<String, Integer> classifier = new NaiveBayesClassifierFactory<String, Integer>(0.1, 0.01, 0.6, LogPrior.LogPriorType.NULL.ordinal(), NaiveBayesClassifierFactory.UCL).trainClassifier(train);
      classifier.print();
      //now classify

      float accTrain = classifier.accuracy(train.iterator());
      System.err.println("training accuracy " + accTrain);
      float accTest = classifier.accuracy(test.iterator());
      System.err.println("test accuracy " + accTest);

    }
  }

  @Override
  public NaiveBayesClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset) {
    if(dataset instanceof RVFDataset){
      throw new RuntimeException("Not sure if RVFDataset runs correctly in this method. Please update this code if it does.");
    }
    return trainClassifier(dataset.getDataArray(), dataset.labels, dataset.numFeatures(),
        dataset.numClasses(), dataset.labelIndex, dataset.featureIndex);
  }

}
TOP

Related Classes of edu.stanford.nlp.classify.NaiveBayesClassifierFactory

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.