// Package: de.jungblut.classification.meta
// Source code of de.jungblut.classification.meta.Voter

package de.jungblut.classification.meta;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import de.jungblut.classification.AbstractClassifier;
import de.jungblut.classification.Classifier;
import de.jungblut.classification.ClassifierFactory;
import de.jungblut.datastructure.ArrayUtils;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.partition.BlockPartitioner;
import de.jungblut.partition.Boundaries.Range;

/**
* Implementation of vote ensembling. This features multithreading, different
* combination and selection techniques.
*
* @author thomas.jungblut
*
*/
public final class Voter<A extends Classifier> extends AbstractClassifier {

  private static final Log LOG = LogFactory.getLog(Voter.class);

  public static enum CombiningType {
    MAJORITY, AVERAGE, PROBABILITY
  }

  public static enum SelectionType {
    NONE, SHUFFLE, BAGGING
  }

  private final Classifier[] classifier;

  private CombiningType type;
  private SelectionType selection = SelectionType.NONE;
  private int threads = 1;
  private boolean verbose;

  private Voter(CombiningType type, int numClassifiers,
      ClassifierFactory<A> classifierFactory) {
    this.type = type;
    this.classifier = new Classifier[numClassifiers];
    for (int i = 0; i < numClassifiers; i++) {
      this.classifier[i] = classifierFactory.newInstance();
    }
  }

  @Override
  public void train(DoubleVector[] features, DoubleVector[] outcome) {
    ExecutorService pool = Executors.newFixedThreadPool(threads);
    try {
      ExecutorCompletionService<Boolean> completionService = new ExecutorCompletionService<>(
          pool);

      // do the selection of the data
      List<TrainingSplit> splits = null;
      switch (selection) {
        case BAGGING:
          splits = bag(features, outcome);
          break;
        case SHUFFLE:
          splits = partition(features, outcome, true);
          break;
        default:
          splits = partition(features, outcome, false);
          break;
      }

      // submit to the threadpool
      for (int i = 0; i < classifier.length; i++) {
        completionService.submit(new TrainingWorker(classifier[i], splits
            .get(i)));
      }

      // let the training happen meanwhile wait for the result
      for (int i = 0; i < classifier.length; i++) {
        completionService.take();
        if (verbose) {
          LOG.info("Finished with training classifier " + (i + 1) + " of "
              + classifier.length);
        }

      }
    } catch (InterruptedException e) {
      e.printStackTrace();
    } finally {
      pool.shutdownNow();
    }
    if (verbose) {
      LOG.info("Successfully finished training!");
    }

  }

  @Override
  public DoubleVector predict(DoubleVector features) {
    // predict
    DoubleVector[] result = new DoubleVector[classifier.length];
    for (int i = 0; i < classifier.length; i++) {
      result[i] = classifier[i].predict(features);
    }
    int numPossibleOutcomes = result[0].getDimension() == 1 ? 2 : result[0]
        .getDimension();
    DoubleVector toReturn = new DenseDoubleVector(
        result[0].getDimension() == 1 ? 1 : numPossibleOutcomes);
    // now combine the results based on the rule
    switch (type) {
      case MAJORITY:
        double[] histogram = createPredictionHistogram(result,
            numPossibleOutcomes);
        if (numPossibleOutcomes == 2) {
          toReturn.set(0, ArrayUtils.maxIndex(histogram));
        } else {
          toReturn.set(ArrayUtils.maxIndex(histogram), 1d);
        }
        break;
      case PROBABILITY:
        histogram = createPredictionHistogram(result, numPossibleOutcomes);
        double histSum = 0;
        for (double d : histogram) {
          histSum += d;
        }
        if (numPossibleOutcomes == 2) {
          toReturn.set(0, histogram[1] / histSum);
        } else {
          for (int i = 0; i < histogram.length; i++) {
            toReturn.set(i, histogram[i] / histSum);
          }
        }
        break;
      case AVERAGE:
        for (int i = 0; i < result.length; i++) {
          toReturn = toReturn.add(result[i]);
        }
        toReturn = toReturn.divide(classifier.length);
        break;
      default:
        throw new UnsupportedOperationException("Type " + type
            + " isn't supported yet!");
    }
    return toReturn;
  }

  public Classifier[] getClassifier() {
    return classifier;
  }

  /**
   * @return sets this instance to verbose and returns it.
   */
  public Voter<A> verbose() {
    return verbose(true);
  }

  /**
   * @return sets this instance to verbose and returns it.
   */
  public Voter<A> verbose(boolean verb) {
    this.verbose = verb;
    return this;
  }

  /**
   * @return sets the selection type and returns this instance.
   */
  public Voter<A> selectionType(SelectionType type) {
    this.selection = type;
    return this;
  }

  /**
   * @return sets the number of threads and returns this instance.
   */
  public Voter<A> numThreads(int threads) {
    this.threads = threads;
    return this;
  }

  /**
   * @return sets the used combination type in this instance and returns it.
   */
  public Voter<A> setCombiningType(CombiningType type) {
    this.type = type;
    return this;
  }

  private double[] createPredictionHistogram(DoubleVector[] result,
      int possibleOutcomes) {
    double[] histogram = new double[possibleOutcomes];
    for (int i = 0; i < classifier.length; i++) {
      int clz = classifier[i].extractPredictedClass(result[i]);
      histogram[clz]++;
    }
    return histogram;
  }

  private List<TrainingSplit> bag(DoubleVector[] features,
      DoubleVector[] outcome) {
    List<TrainingSplit> splits = new ArrayList<>(classifier.length);
    Random rand = new Random();
    for (int i = 0; i < classifier.length; i++) {

      DoubleVector[] featureBag = new DoubleVector[features.length];
      DoubleVector[] outcomeBag = new DoubleVector[features.length];
      // bagging is basically filling random items into these arrays
      for (int n = 0; n < features.length; n++) {
        int nextInt = rand.nextInt(features.length);
        featureBag[n] = features[nextInt];
        outcomeBag[n] = outcome[nextInt];
      }
      splits.add(new TrainingSplit(featureBag, outcomeBag));
    }

    return splits;
  }

  private List<TrainingSplit> partition(DoubleVector[] features,
      DoubleVector[] outcome, boolean shuffle) {

    List<TrainingSplit> splits = new ArrayList<>(classifier.length);
    if (shuffle) {
      ArrayUtils.multiShuffle(features, outcome);
    }

    List<Range> partitions = new ArrayList<>(new BlockPartitioner().partition(
        classifier.length, features.length).getBoundaries());

    final int[] splitRanges = new int[classifier.length + 1];
    for (int i = 1; i < classifier.length; i++) {
      splitRanges[i] = partitions.get(i).getStart();
    }
    splitRanges[classifier.length] = features.length - 1;

    if (verbose) {
      LOG.info("Computed split ranges for 0-" + features.length + ": "
          + Arrays.toString(splitRanges) + "\n");
    }

    for (int i = 0; i < classifier.length; i++) {
      DoubleVector[] featureSplit = ArrayUtils.subArray(features,
          splitRanges[i], splitRanges[i + 1]);
      DoubleVector[] outcomeSplit = ArrayUtils.subArray(outcome,
          splitRanges[i], splitRanges[i + 1]);
      splits.add(new TrainingSplit(featureSplit, outcomeSplit));
    }

    return splits;
  }

  /**
   * Creates a new voting classificator. The training is single threaded, no
   * shuffling or bagging takes place.
   *
   * @param numClassifiers the number of base classifiers to use.
   * @param type the combining type to use.
   * @param classifierFactory the classifier factory to create base classifiers.
   * @return a new Voter.
   */
  public static <K extends Classifier> Voter<K> create(int numClassifiers,
      CombiningType type, ClassifierFactory<K> classifierFactory) {
    return new Voter<>(type, numClassifiers, classifierFactory);
  }

  final class TrainingWorker implements Callable<Boolean> {

    private final Classifier cls;
    private final TrainingSplit split;

    TrainingWorker(Classifier classifier, TrainingSplit split) {
      this.cls = classifier;
      this.split = split;
    }

    @Override
    public Boolean call() throws Exception {

      cls.train(split.getTrainFeatures(), split.getTrainOutcome());

      return true;
    }
  }

}
// Related classes of de.jungblut.classification.meta.Voter
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a
// trademark of Sun Microsystems, Inc., owned by Oracle Inc.
// Contact: coftware#gmail.com