Source Code of de.jungblut.math.minimize.GradientDescent$GradientDescentBuilder

package de.jungblut.math.minimize;

import java.util.Arrays;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.google.common.base.Preconditions;

import de.jungblut.math.DoubleVector;

/**
 * Gradient descent implementation with some neat features like momentum,
 * divergence detection, delta breaks, and adaptive learning rates via bold
 * driver and scheduled annealing. For more sophisticated configuration use the
* {@link GradientDescentBuilder}.
*
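 * <p>
 * Illustrative usage sketch; {@code costFunction} and {@code startVector} are
 * assumed to be defined elsewhere:
 *
 * <pre>
 * GradientDescent gd = GradientDescentBuilder.create(0.1) // learning rate alpha
 *     .momentum(0.9)
 *     .breakOnDifference(1e-6)
 *     .boldDriver()
 *     .build();
 * DoubleVector minimum = gd.minimize(costFunction, startVector, 1000, true);
 * </pre>
 *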
* @author thomas.jungblut
*
*/
public final class GradientDescent extends AbstractMinimizer {

  private static final Log LOG = LogFactory.getLog(GradientDescent.class);

  private static final int COST_HISTORY = 3;

  public static class GradientDescentBuilder {

    private final double alpha;
    private double breakDifference;
    private double momentum;
    private boolean breakOnDivergence;
    private boolean boldDriver;
    private double boldIncreasePercentage;
    private double boldDecreasePercentage;
    private int annealingIteration = -1;

    private GradientDescentBuilder(double alpha) {
      this.alpha = alpha;
    }

    public GradientDescent build() {
      return new GradientDescent(this);
    }

    /**
     * Add momentum to this gradient descent minimizer.
     *
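     * After the plain gradient step {@code theta = theta - alpha * gradient},
     * the minimizer adds {@code momentum * (lastTheta - theta)}, where
     * {@code lastTheta} is the parameter vector before the step.
     *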
     * @param momentum the momentum to use. Between 0 and 1.
     * @return the builder again.
     */
    public GradientDescentBuilder momentum(double momentum) {
      Preconditions.checkArgument(momentum >= 0d && momentum <= 1d,
          "Momentum must be between 0 and 1.");
      this.momentum = momentum;
      return this;
    }

    /**
     * BoldDriver will change the learning rate over time by observing the cost
     * of the cost function. If the cost decreases, it will increase the
     * learning rate by 5%. If the cost increases, it will cut the learning rate
     * in half.
     *
     * @return the builder again.
     */
    public GradientDescentBuilder boldDriver() {
      return boldDriver(0.5, 0.05);
    }

    /**
     * BoldDriver will change the learning rate over time by observing the cost
     * of the cost function. If the cost decreases, it will increase the
     * learning rate (typically by 5%). If the cost increases, it will
     * (typically) cut the learning rate in half.
     *
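     * For example, with {@code boldDriver(0.5, 0.05)} and a current learning
     * rate of 0.1: a decreasing cost raises the rate to 0.1 + 0.1 * 0.05 =
     * 0.105, while an increasing cost lowers it to 0.1 - 0.1 * 0.5 = 0.05 (and
     * the last parameter update is rolled back).
     *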
     * @param increasedCostPercentage the fraction by which the learning rate
     *          will be decreased when the cost increases.
     * @param decreasedCostPercentage the fraction by which the learning rate
     *          will be increased when the cost decreases.
     * @return the builder again.
     */
    public GradientDescentBuilder boldDriver(double increasedCostPercentage,
        double decreasedCostPercentage) {
      Preconditions.checkArgument(increasedCostPercentage >= 0d
          && increasedCostPercentage <= 1d,
          "increasedCostPercentage must be between 0 and 1.");
      Preconditions.checkArgument(decreasedCostPercentage >= 0d
          && decreasedCostPercentage <= 1d,
          "decreasedCostPercentage must be between 0 and 1.");
      this.boldDriver = true;
      this.boldIncreasePercentage = increasedCostPercentage;
      this.boldDecreasePercentage = decreasedCostPercentage;
      return this;
    }

    /**
     * If called, this breaks when the gradient descent minimizer starts to
     * diverge (costs are growing).
     *
     * @return the builder again.
     */
    public GradientDescentBuilder breakOnDivergence() {
      this.breakOnDivergence = true;
      return this;
    }

    /**
     * Breaks the minimization process when the given delta between two
     * successive costs has been achieved. Usually a fairly low value between
     * 1e-4 and 1e-8.
     *
     * @param delta the delta to break in difference between two costs.
     * @return the builder again.
     */
    public GradientDescentBuilder breakOnDifference(double delta) {
      this.breakDifference = delta;
      return this;
    }

    /**
     * Sets a simple annealing schedule (alpha / (1 + current_iteration / phi)),
     * where phi is the parameter given here. This will gradually lower the
     * global learning rate after the given number of iterations.
     *
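     * Since the iteration counter and phi are integers, the division is an
     * integer division: with alpha = 1.0 and phi = 100, iterations 0-99 use
     * 1.0, iterations 100-199 use 0.5, iterations 200-299 use 0.33, and so on.
     *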
     * @param iteration the iteration to start annealing.
     * @return the builder again.
     */
    public GradientDescentBuilder annealingAfter(int iteration) {
      Preconditions.checkArgument(iteration > 0,
          "Annealing can only kick in after the first iteration! Given: "
              + iteration);
      this.annealingIteration = iteration;
      return this;
    }

    /**
     * Creates a new builder.
     *
     * @param alpha the learning rate to set.
     * @return a new builder.
     */
    public static GradientDescentBuilder create(double alpha) {
      return new GradientDescentBuilder(alpha);
    }

  }

  private final boolean breakOnDivergence;
  private final double breakDifference;
  private final double momentum;
  private final double alpha;
  private final boolean boldDriver;
  private final double boldIncreasePercentage;
  private final double boldDecreasePercentage;
  private final int annealingIteration;

  private GradientDescent(GradientDescentBuilder builder) {
    this.alpha = builder.alpha;
    this.breakDifference = builder.breakDifference;
    this.momentum = builder.momentum;
    this.breakOnDivergence = builder.breakOnDivergence;
    this.boldDriver = builder.boldDriver;
    this.boldIncreasePercentage = builder.boldIncreasePercentage;
    this.boldDecreasePercentage = builder.boldDecreasePercentage;
    this.annealingIteration = builder.annealingIteration;
  }

  /**
   * @param alpha the learning rate.
   * @param limit the delta in cost to achieve to break the iterations.
   */
  public GradientDescent(double alpha, double limit) {
    this(GradientDescentBuilder.create(alpha).breakOnDifference(limit));
  }

  @Override
  public final DoubleVector minimize(CostFunction f, DoubleVector pInput,
      final int maxIterations, boolean verbose) {

    double[] lastCosts = new double[COST_HISTORY];
    Arrays.fill(lastCosts, Double.MAX_VALUE);
    final int lastIndex = lastCosts.length - 1;
    DoubleVector lastTheta = null;
    DoubleVector lastGradient = null;
    DoubleVector theta = pInput;
    double alpha = this.alpha;
    for (int iteration = 0; iteration < maxIterations; iteration++) {
      CostGradientTuple evaluateCost = f.evaluateCost(theta);
      if (verbose) {
        LOG.info("Iteration " + iteration + " | Cost: "
            + evaluateCost.getCost());
      }
      shiftLeft(lastCosts);
      lastCosts[lastIndex] = evaluateCost.getCost();
      // break if we converged below the limit
      if (converged(lastCosts, breakDifference)) {
        break;
      }
      // break if we are going in the wrong direction
      if (breakOnDivergence && ascending(lastCosts)) {
        break;
      }

      DoubleVector gradient = evaluateCost.getGradient();
      // check the bold driver
      if (boldDriver) {
        if (lastGradient != null) {
          double costDifference = getCostDifference(lastCosts);
          if (costDifference < 0) {
            // we can increase, because cost decreased
            alpha += (alpha * boldDecreasePercentage);
          } else {
            // we decrease, because cost increased
            // we undo the last theta change
            theta = lastTheta;
            gradient = lastGradient;
            alpha -= (alpha * boldIncreasePercentage);
          }
          if (verbose) {
            LOG.info("Iteration " + iteration + " | Alpha: " + alpha + "\n");
          }
        }
        lastGradient = gradient;
      }
      // check annealing
      if (annealingIteration > 0) {
        // anneal relative to the initial learning rate; the integer division
        // yields a stepped decay every annealingIteration iterations
        alpha = this.alpha / (1d + iteration / annealingIteration);
      }
      // save our last parameter
      lastTheta = theta;
      // basically subtract the gradient multiplied with the learning rate
      theta = theta.subtract(gradient.multiply(alpha));
      if (lastTheta != null && momentum != 0d) {
        // we add momentum as the parameter "m" multiplied by the difference of
        // both theta vectors
        theta = theta.add((lastTheta.subtract(theta)).multiply(momentum));
      }
      onIterationFinished(iteration, evaluateCost.getCost(), theta);
    }

    return theta;

  }

  /**
   * Minimize a given cost function f with the initial parameters pInput (also
   * called theta) with a learning rate alpha and a fixed number of iterations.
   * The loop breaks earlier once the difference between two successive costs
   * falls below the limit.
   *
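   * <p>
   * Illustrative sketch: minimizing f(x) = x^2, whose gradient is 2x. How a
   * {@code CostGradientTuple} and a concrete {@code DoubleVector} are created
   * may differ in this library; the two calls marked as assumed below are
   * assumptions.
   *
   * <pre>
   * CostFunction f = new CostFunction() {
   *   &#64;Override
   *   public CostGradientTuple evaluateCost(DoubleVector input) {
   *     double x = input.get(0);
   *     DoubleVector gradient = input.multiply(2d); // d/dx of x^2 is 2x
   *     return new CostGradientTuple(x * x, gradient); // assumed constructor
   *   }
   * };
   * DoubleVector start = new DenseDoubleVector(new double[] { 4d }); // assumed impl
   * DoubleVector min = GradientDescent.minimizeFunction(f, start, 0.1, 1e-8, 100, false);
   * </pre>
   *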
   * @param f the function to minimize.
   * @param pInput the starting parameters.
   * @param alpha the learning rate.
   * @param limit the cost delta to achieve to break the iterations.
   * @param length the number of iterations.
   * @param verbose if true, logs progress on every iteration.
   * @return the learned minimal parameters.
   */
  public static DoubleVector minimizeFunction(CostFunction f,
      DoubleVector pInput, double alpha, double limit, int length,
      final boolean verbose) {
    return new GradientDescent(alpha, limit).minimize(f, pInput, length,
        verbose);
  }

  static void shiftLeft(double[] lastCosts) {
    final int lastIndex = lastCosts.length - 1;
    for (int i = 0; i < lastIndex; i++) {
      lastCosts[i] = lastCosts[i + 1];
    }
    // shift MAX_VALUE into the last position
    lastCosts[lastIndex] = Double.MAX_VALUE;
  }

  static boolean converged(double[] lastCosts, double limit) {
    return Math.abs(getCostDifference(lastCosts)) < limit;
  }

  // true when the most recent cost is greater than the one right before it
  static boolean ascending(double[] lastCosts) {
    double last = lastCosts[0];
    boolean ascending = false;
    for (int i = 1; i < lastCosts.length; i++) {
      ascending = last < lastCosts[i];
      last = lastCosts[i];
    }
    return ascending;
  }

  private static double getCostDifference(double[] lastCosts) {
    return lastCosts[lastCosts.length - 1] - lastCosts[lastCosts.length - 2];
  }

}