Package classification

Source Code of classification.AdaBoost

package classification;

import java.util.ArrayList;
import java.util.Random;

import types.Alphabet;
import types.ClassificationInstance;
import types.FeatureFunction;
import types.LinearClassifier;
import types.SparseVector;
import types.StaticUtils;

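/**
 * AdaBoost over single indicator features: each boosting round treats
 * every component of the joint feature function fxy as a weak learner,
 * picks the one with the largest weighted correct-minus-wrong score,
 * and adds its vote alpha into a LinearClassifier's weight vector.
 */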
public class AdaBoost {

  int numIterations;
  Alphabet xAlphabet;
  Alphabet yAlphabet;
  FeatureFunction fxy;
  double smooth = 0.01;

  public AdaBoost(int numIterations, Alphabet xAlphabet, Alphabet yAlphabet,
      FeatureFunction fxy) {
    this.numIterations = numIterations;
    this.xAlphabet = xAlphabet;
    this.yAlphabet = yAlphabet;
    this.fxy = fxy;
  }

  public void printArray(double[] a) {
    for (int i = 0; i < a.length; i++) {
      System.out.print(a[i] + " ");
    }
    System.out.println();
  }

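  /**
   * Trains for numIterations rounds. Instance weights start uniform;
   * each round re-scores all features under the current weights, boosts
   * the best feature's classifier weight by alpha, and reweights the
   * training instances (up where the chosen feature fires with a wrong
   * label, down where it fires with the true label).
   */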
  public LinearClassifier batchTrain(
      ArrayList<ClassificationInstance> trainingData) {
    LinearClassifier result = new LinearClassifier(xAlphabet, yAlphabet,
        fxy);
    double[] w = new double[trainingData.size()];
    for (int i = 0; i < w.length; i++)
      w[i] = 1.0 / trainingData.size();
    // Per-feature weighted correct/wrong tallies, refreshed each round.
    double[] correct = new double[fxy.wSize()];
    double[] wrong = new double[fxy.wSize()];

    for (int iter = 0; iter < numIterations; iter++) {
      computeAccuracies(correct, wrong, trainingData, w);
      // System.out.println();
      // System.out.print(" correct = ");
      // printArray(correct);
      // System.out.print(" wrong = ");
      // printArray(wrong);
      int bestFeature = chooseBest(correct, wrong);
      // Weak learner's vote: alpha = 0.5 * ln(correct / wrong), positive
      // when the feature fires more on correct labels than wrong ones.
      double alpha = 0.5 * Math.log(correct[bestFeature]
          / wrong[bestFeature]);
      result.w[bestFeature] += alpha;
      updateW(w, bestFeature, alpha, trainingData);
      // System.out.print(" w = ");
      // printArray(w);
    }
    return result;
  }

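  /**
   * Returns the feature index maximizing correct[i] - wrong[i], i.e.
   * the weak learner with the best weighted edge this round.
   */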
  private int chooseBest(double[] correct, double[] wrong) {
    int res = 0;
    // Start below any possible score so some feature is always chosen;
    // Double.MIN_VALUE is the smallest *positive* double, not the minimum.
    double bestval = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < correct.length; i++) {
      double val = correct[i] - wrong[i];
      if (val > bestval) {
        res = i;
        bestval = val;
      }
    }
    return res;
  }

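  /**
   * AdaBoost reweighting: multiply an instance's weight by exp(-alpha)
   * each time the chosen feature fires with its true label and by
   * exp(+alpha) each time it fires with a wrong label, then renormalize
   * w to sum to one.
   */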
  private void updateW(double[] w, int bestFeature, double alpha,
      ArrayList<ClassificationInstance> trainingData) {
    double wrongUpdate = Math.exp(alpha);
    double correctUpdate = Math.exp(-alpha);
    for (int instInd = 0; instInd < trainingData.size(); instInd++) {
      ClassificationInstance inst = trainingData.get(instInd);
      for (int y = 0; y < yAlphabet.size(); y++) {
        SparseVector fv = fxy.apply(inst.x, y);
        for (int i = 0; i < fv.numEntries(); i++) {
          if (fv.getIndexAt(i) == bestFeature) {
            if (y == inst.y)
              w[instInd] *= correctUpdate;
            else
              w[instInd] *= wrongUpdate;
          }
        }
      }
    }
    // Renormalize the instance weights back to a distribution.
    double sum = StaticUtils.sum(w);
    for (int i = 0; i < w.length; i++) {
      w[i] /= sum;
    }
  }

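  /**
   * Fills correct[i] / wrongs[i] with the smoothed, weighted mass of
   * instances on which feature i fires together with the true label
   * versus some other label, normalized by the total instance weight.
   */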
  private void computeAccuracies(double[] correct, double[] wrongs,
      ArrayList<ClassificationInstance> trainingData, double[] w) {
    double total = 2 * smooth;
    for (int i = 0; i < correct.length; i++) {
      correct[i] = smooth;
      wrongs[i] = smooth;
    }
    for (int instInd = 0; instInd < trainingData.size(); instInd++) {
      ClassificationInstance inst = trainingData.get(instInd);
      total += w[instInd];
      for (int y = 0; y < yAlphabet.size(); y++) {
        SparseVector fv = fxy.apply(inst.x, y);
        if (y == inst.y) {
          for (int i = 0; i < fv.numEntries(); i++) {
            correct[fv.getIndexAt(i)] += w[instInd];
          }
        } else {
          for (int i = 0; i < fv.numEntries(); i++) {
            wrongs[fv.getIndexAt(i)] += w[instInd];
          }
        }
      }
    }
    for (int i = 0; i < correct.length; i++) {
      correct[i] /= total;
      wrongs[i] /= total;
    }
  }

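  /**
   * Tiny synthetic demo: 10 two-class instances whose features encode
   * the label, with half the features dropped and half pointing at a
   * random class; trains 10 rounds and prints training accuracy.
   */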
  public static void main(String[] args) {
    ArrayList<ClassificationInstance> train = new ArrayList<ClassificationInstance>();
    Alphabet xAlphabet = new Alphabet();
    Alphabet yAlphabet = new Alphabet();
    String[] classes = new String[] { "a", "b" };
    Random r = new Random(10);
    int numFeats = 5;
    double randomFrac = 0.5;
    double missingFrac = 0.5;
    for (int instInd = 0; instInd < 10; instInd++) {
      String label = classes[r.nextInt(classes.length)];
      SparseVector sv = new SparseVector();
      for (int fInd = 0; fInd < numFeats; fInd++) {
        if (r.nextDouble() < missingFrac)
          continue;
        String tmpLab = label;
        if (r.nextDouble() < randomFrac)
          tmpLab = classes[r.nextInt(classes.length)];
        sv.add(xAlphabet.lookupObject(tmpLab + fInd), 1);
      }
      train.add(new ClassificationInstance(xAlphabet, yAlphabet, sv,
          label));
    }
    AdaBoost boost = new AdaBoost(10, xAlphabet, yAlphabet,
        new CompleteFeatureFunction(xAlphabet, yAlphabet));
    LinearClassifier h = boost.batchTrain(train);
    System.out.println(StaticUtils.computeAccuracy(h, train));
  }

}
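
A natural follow-up to the demo above (a sketch, not part of the original listing) is to check generalization on held-out data. The snippet below assumes it is appended inside main after h is trained, and reuses main's variables and synthetic generation scheme:

    // Hypothetical held-out evaluation, reusing main's generator.
    ArrayList<ClassificationInstance> test = new ArrayList<ClassificationInstance>();
    for (int instInd = 0; instInd < 10; instInd++) {
      String label = classes[r.nextInt(classes.length)];
      SparseVector sv = new SparseVector();
      for (int fInd = 0; fInd < numFeats; fInd++) {
        if (r.nextDouble() < missingFrac)
          continue; // feature missing for this instance
        // with probability randomFrac the feature points at a random class
        String tmpLab = r.nextDouble() < randomFrac
            ? classes[r.nextInt(classes.length)] : label;
        sv.add(xAlphabet.lookupObject(tmpLab + fInd), 1);
      }
      test.add(new ClassificationInstance(xAlphabet, yAlphabet, sv, label));
    }
    System.out.println("test accuracy: " + StaticUtils.computeAccuracy(h, test));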