// Package fr.lip6.jkernelmachines.classifier.multiclass
//
// Source Code of fr.lip6.jkernelmachines.classifier.multiclass.OneAgainstAll

/**
    This file is part of JkernelMachines.

    JkernelMachines is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    JkernelMachines is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with JkernelMachines.  If not, see <http://www.gnu.org/licenses/>.

    Copyright David Picard - 2012

*/
package fr.lip6.jkernelmachines.classifier.multiclass;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;

import fr.lip6.jkernelmachines.classifier.Classifier;
import fr.lip6.jkernelmachines.threading.ThreadPoolServer;
import fr.lip6.jkernelmachines.type.TrainingSample;
import fr.lip6.jkernelmachines.util.DebugPrinter;

/**
* <p>
* Multiclass classifier with N times One against All scheme.
* </p>
* <p>
* The classification algorithm for each case is not set in this classifier, and
* should be provided.
* </p>
*
* @author picard
*
*/
public class OneAgainstAll<T> implements MulticlassClassifier<T> {

  Classifier<T> baseClassifier;

  List<Integer> classIndices;
  List<Classifier<T>> listOfClassifiers;
  List<TrainingSample<T>> tlist;
  int nbclasses = 0;

  DebugPrinter debug = new DebugPrinter();

  /**
   * <p>
   * Default constructor with underlying classifier algorithm.
   * </p>
   * <p>
   * The classifier given as argument is cloned N times at each training, in
   * order to provide a binary classification for each category.
   * </p>
   *
   * @param c
   *            the classifier class to use
   */
  public OneAgainstAll(Classifier<T> c) {
    baseClassifier = c;
  }

  /*
   * (non-Javadoc)
   *
   * @see fr.lip6.classifier.Classifier#train(fr.lip6.type.TrainingSample)
   */
  @Override
  public void train(TrainingSample<T> t) {
    if (tlist == null)
      tlist = new ArrayList<TrainingSample<T>>();

    tlist.add(t);
    train();
  }

  /*
   * (non-Javadoc)
   *
   * @see fr.lip6.classifier.Classifier#train(java.util.List)
   */
  @Override
  public void train(List<TrainingSample<T>> l) {
    tlist = new ArrayList<TrainingSample<T>>();
    tlist.addAll(l);

    train();
  }

  private void train() {
    // init
    classIndices = new ArrayList<Integer>();
    listOfClassifiers = new ArrayList<Classifier<T>>();

    // count classes
    nbclasses = 0;
    for (TrainingSample<T> t : tlist) {
      if (!classIndices.contains(t.label)) {
        classIndices.add(t.label);
        nbclasses++;
        // init classifiers
        listOfClassifiers.add(null);
      }
    }
    debug.println(1, "Number of Classes: " + nbclasses);

    ThreadPoolExecutor ex = ThreadPoolServer.getThreadPoolExecutor();
    List<Future<Object>> futures = new ArrayList<>();

    // learning N one against all classifiers
    for (int id = 0; id < nbclasses; id++) {
      final int i = id;
      futures.add(ex.submit(new Callable<Object>() {

        @Override
        public Object call() throws Exception {

          Classifier<T> cls = null;
          int c = 0;

          synchronized (listOfClassifiers) {
            c = classIndices.get(i);

            // building classifier
            try {
              cls = (Classifier<T>) baseClassifier.copy();
            } catch (Exception e) {
              debug.println(1, "ERROR: Classifier not Cloneable!");
              throw new UnsupportedOperationException(
                  baseClassifier.getClass().getSimpleName()
                      + " is not clonable.");
            }
          }

          debug.println(2, i + ": learning!");
          // building ad hoc trai list
          List<TrainingSample<T>> train = new ArrayList<TrainingSample<T>>();
          for (TrainingSample<T> t : tlist) {
            int y = -1;
            if (t.label == c)
              y = 1;
            train.add(new TrainingSample<T>(t.sample, y));
          }

          // training
          cls.train(train);

          // storing
          synchronized (listOfClassifiers) {
            listOfClassifiers.set(i, cls);
          }

          debug.println(1, i + ": done!");
          return null;
        }

      }));
    }

    for (Future<Object> f : futures) {
      try {
        f.get();
      } catch (InterruptedException e) {
        debug.println(1, "Error in learning on classifier");
        e.printStackTrace();
        throw new RuntimeException("Failed threading training");
      } catch (ExecutionException e) {
        debug.println(1, "Error in learning on classifier");
        e.printStackTrace();
        throw new RuntimeException("Failed threading training");
      }
    }

    ThreadPoolServer.shutdownNow(ex);
  }

  /*
   * (non-Javadoc)
   *
   * @see fr.lip6.classifier.Classifier#valueOf(java.lang.Object)
   */
  @Override
  public double valueOf(T e) {
    final T t = e;
    if (listOfClassifiers == null || listOfClassifiers.isEmpty())
      return 0;

    final double[] values = new double[listOfClassifiers.size()];

    if (nbclasses > 2 * Runtime.getRuntime().availableProcessors()) {
      ThreadPoolExecutor ex = ThreadPoolServer.getThreadPoolExecutor();
      List<Future<Object>> futures = new ArrayList<>(
          listOfClassifiers.size());
      for (int i = 0; i < listOfClassifiers.size(); i++) {
        final int id = i;
        futures.add(ex.submit(new Callable<Object>() {

          @Override
          public Object call() throws Exception {
            values[id] = listOfClassifiers.get(id).valueOf(t);
            return null;
          }

        }));
      }

      for (Future<Object> f : futures) {
        try {
          f.get();
        } catch (InterruptedException | ExecutionException e1) {
          debug.println(1, "unable to thread evaluation");
          e1.printStackTrace();
          return -1;
        }
      }

      ThreadPoolServer.shutdownNow(ex);
    }
    else {
      for(int i = 0 ; i < nbclasses ; i++) {
        values[i] = listOfClassifiers.get(i).valueOf(e);
      }
    }

    // find max output
    int imax = -1;
    double max = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < listOfClassifiers.size(); i++) {
      if (values[i] > max) {
        max = values[i];
        imax = i;
      }
    }
    // return class corresponding to this output
    return classIndices.get(imax);
  }

  /*
   * (non-Javadoc)
   *
   * @see fr.lip6.jkernelmachines.classifier.multiclass.MulticlassClassifier#
   * getConfidence(java.lang.Object)
   */
  @Override
  public double getConfidence(T e) {
    if (listOfClassifiers == null || listOfClassifiers.isEmpty())
      return 0;

    // find max output
    double max = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < listOfClassifiers.size(); i++) {
      double v = listOfClassifiers.get(i).valueOf(e);
      if (v > max) {
        max = v;
      }
    }

    return max;
  }

  /*
   * (non-Javadoc)
   *
   * @see fr.lip6.jkernelmachines.classifier.multiclass.MulticlassClassifier#
   * getConfidences(java.lang.Object)
   */
  @Override
  public Map<Integer, Double> getConfidences(T e) {
    if (listOfClassifiers == null || listOfClassifiers.isEmpty())
      return null;

    HashMap<Integer, Double> map = new HashMap<>();
    for (int i = 0; i < listOfClassifiers.size(); i++) {
      map.put(classIndices.get(i), listOfClassifiers.get(i).valueOf(e));
    }
    return map;
  }

  /**
   * Returns the list of one against all classifiers used
   *
   * @return the list of classifiers
   */
  public List<Classifier<T>> getListOfClassifiers() {
    return listOfClassifiers;
  }

  /**
   * Returns a map with class labels as keys and corresponding one against all
   * classifiers as values
   *
   * @return the map of labels, classifiers
   */
  public Map<Integer, Classifier<T>> getMapOfClassifiers() {
    Map<Integer, Classifier<T>> map = new HashMap<Integer, Classifier<T>>();
    for (int i = 0; i < classIndices.size(); i++) {
      map.put(classIndices.get(i), listOfClassifiers.get(i));
    }
    return map;
  }

  /*
   * (non-Javadoc)
   *
   * @see fr.lip6.classifier.Classifier#copy()
   */
  @SuppressWarnings("unchecked")
  @Override
  public OneAgainstAll<T> copy() throws CloneNotSupportedException {
    return (OneAgainstAll<T>) super.clone();
  }

}
// TOP
//
// Related Classes of fr.lip6.jkernelmachines.classifier.multiclass.OneAgainstAll
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by ORACLE Inc. Contact coftware#gmail.com.