Package de.jungblut.math.minimize

Source Code of de.jungblut.math.minimize.AbstractMiniBatchCostFunction$CallableMiniBatch

package de.jungblut.math.minimize;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;

import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;

import de.jungblut.datastructure.ArrayUtils;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;
import de.jungblut.math.tuple.Tuple;
import de.jungblut.partition.Boundaries.Range;

/**
* Mini Batch cost function. It features parallel calculation of mini batches
* and averaging of the error and gradient. In addition you can calculate
* mini-batches in a stochastic fashion, that means that every called iteration
* a new mini-batch will be evaluated, its gradient will immediately be
* returned. <br/>
* <br/>
* The resulting costfunction can be minimized with every normal
* {@link Minimizer}. The extending cost functions should be as stateless as
* possible.
*
* @author thomas.jungblut
*
*/
public abstract class AbstractMiniBatchCostFunction implements CostFunction {

  private final Executor pool;
  private final List<Tuple<DoubleMatrix, DoubleMatrix>> batches;
  private final boolean stochastic;
  // offset for minibatches in stochastic mode
  private int batchOffset = 0;

  class CallableMiniBatch implements Callable<CostGradientTuple> {

    private final DoubleVector parameters;
    private final Tuple<DoubleMatrix, DoubleMatrix> featureOutcome;

    public CallableMiniBatch(Tuple<DoubleMatrix, DoubleMatrix> featureOutcome,
        DoubleVector parameters) {
      super();
      this.featureOutcome = featureOutcome;
      this.parameters = parameters;
    }

    @Override
    public CostGradientTuple call() throws Exception {
      return evaluateBatch(parameters, featureOutcome.getFirst(),
          featureOutcome.getSecond());
    }
  }

  /**
   * An abstract minibatch costfunction. This trains on all batches and averages
   * their results (gradient&cost).
   *
   * @param inputMatrix the data to crunch, bias will be added while calculating
   *          the batches.
   * @param outcomeMatrix the data that denotes the outcome of the features-
   *          while supervised learning. Can be null for unsupervised methods.
   * @param batchSize the batch size to use, 0 denotes full batch learning by
   *          default.
   * @param numThreads the number of threads to use to calculate the batches.
   */
  public AbstractMiniBatchCostFunction(DoubleVector[] inputMatrix,
      DoubleVector[] outcomeMatrix, int batchSize, int numThreads) {
    this(inputMatrix, outcomeMatrix, batchSize, numThreads, false);
  }

  /**
   * An abstract minibatch costfunction.
   *
   * @param inputMatrix the data to crunch, bias will be added while calculating
   *          the batches.
   * @param outcomeMatrix the data that denotes the outcome of the features-
   *          while supervised learning. Can be null for unsupervised methods.
   * @param batchSize the batch size to use, 0 denotes full batch learning by
   *          default.
   * @param numThreads the number of threads to use to calculate the batches.
   * @param stochastic if true the batches will be evaluated stochastically,
   *          that means that each iteration the next batch is chosen to
   *          evaluate and calculate gradient and cost.
   */
  public AbstractMiniBatchCostFunction(DoubleVector[] inputMatrix,
      DoubleVector[] outcomeMatrix, int batchSize, int numThreads,
      boolean stochastic) {
    Preconditions.checkArgument(batchSize >= 0
        && batchSize <= inputMatrix.length, "Batchsize wasn't in range of 0-"
        + inputMatrix.length);
    Preconditions.checkArgument(numThreads >= 1,
        "#Threads need to be at least > 0");
    this.stochastic = stochastic;

    ThreadFactory factory = new ThreadFactoryBuilder().setDaemon(true)
        .setNameFormat("MiniBatch Worker %d").build();
    // if we configured batchsize 0, we know that this is the default for full
    // batching, thus only use a single thread.
    Set<Range> partitions = new HashSet<>();
    if (batchSize == 0) {
      numThreads = 1;
      batchSize = inputMatrix.length;
      partitions.add(new Range(0, batchSize - 1));
    } else {
      int offset = 0;
      while (offset < inputMatrix.length) {
        partitions.add(new Range(offset, Math.min(inputMatrix.length - 1,
            offset + (batchSize - 1))));
        offset += batchSize;
      }
    }
    pool = Executors.newFixedThreadPool(stochastic ? 1 : numThreads, factory);

    batches = new ArrayList<>();
    for (Range r : partitions) {
      int start = r.getStart();
      int end = r.getEnd(); // inclusive
      DoubleVector[] featureSubArray = ArrayUtils.subArray(inputMatrix, start,
          end);
      boolean sparse = featureSubArray[0].isSparse();
      DoubleMatrix outcomeMat = null;
      if (outcomeMatrix != null) {
        DoubleVector[] outcomeSubArray = ArrayUtils.subArray(outcomeMatrix,
            start, end);
        outcomeMat = new DenseDoubleMatrix(outcomeSubArray);
      }
      DenseDoubleVector bias = DenseDoubleVector.ones(featureSubArray.length);
      DoubleMatrix featureMatrix = sparse ? new SparseDoubleRowMatrix(
          featureSubArray) : new DenseDoubleMatrix(featureSubArray);
      DoubleMatrix featuresWithBias = sparse ? new SparseDoubleRowMatrix(bias,
          featureMatrix) : new DenseDoubleMatrix(bias, featureMatrix);
      batches.add(new Tuple<>(featuresWithBias, outcomeMat));
    }
  }

  @Override
  public final CostGradientTuple evaluateCost(DoubleVector input) {
    // single batch case, take the low overhead route
    if (batches.size() == 1) {
      CallableMiniBatch batch = new CallableMiniBatch(batches.get(0), input);
      try {
        return batch.call();
      } catch (Exception e) {
        e.printStackTrace();
        // return null so minimizers should fast-fail
        return null;
      }
    }
    // multiple batches
    ExecutorCompletionService<CostGradientTuple> completionService = new ExecutorCompletionService<>(
        pool);

    int submittedBatches = 0;
    // submit all batches to the service, pool will take care of the parallelism
    for (int i = batchOffset; i < batches.size(); i++) {
      completionService.submit(new CallableMiniBatch(batches.get(i), input));
      submittedBatches++;
      if (stochastic) {
        // if stochastic, we will increase the batch
        batchOffset++;
        // check for wrap-ups
        if (batchOffset >= batches.size()) {
          batchOffset = 0;
        }
        break;
      }
    }
    double costSum = 0d;
    DoubleVector gradientSum = new DenseDoubleVector(input.getLength());
    try {
      // now collect the results
      for (int i = 0; i < submittedBatches; i++) {
        CostGradientTuple result = completionService.take().get();
        costSum += result.getCost();
        gradientSum = gradientSum.add(result.getGradient());
      }
    } catch (InterruptedException | ExecutionException e) {
      e.printStackTrace();
      // return null so minimizers should fast-fail
      return null;
    }
    // just return an average over the batches
    return new CostGradientTuple(costSum / submittedBatches,
        gradientSum.divide(submittedBatches));
  }

  /**
   * Evaluate the batch.
   *
   * @param theta the parameters to use.
   * @param featureBatch the batch matrix as input (already contains a bias!).
   * @param outcomeBatch the batch matrix denoting the output.
   * @return the cost/gradient tuple usually used when using
   *         {@link #evaluateCost(DoubleVector)}.
   */
  protected abstract CostGradientTuple evaluateBatch(DoubleVector theta,
      DoubleMatrix featureBatch, DoubleMatrix outcomeBatch);

}
TOP

Related Classes of de.jungblut.math.minimize.AbstractMiniBatchCostFunction$CallableMiniBatch

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.