Package de.jungblut.classification.nn

Source Code of de.jungblut.classification.nn.RBMCostFunction

package de.jungblut.classification.nn;

import java.util.Random;

import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.activation.ActivationFunction;
import de.jungblut.math.cuda.JCUDAMatrixUtils;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.AbstractMiniBatchCostFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.math.minimize.DenseMatrixFolder;
import de.jungblut.math.minimize.Fmincg;
import de.jungblut.math.minimize.GradientDescent;

/**
* Restricted Boltzmann machine implementation using Contrastive Divergence 1
* (CD-1). This algorithm is based on what Prof. Hinton taught in the Coursera
* course "Neural Networks for Machine Learning" in 2012. It is an unsupervised
* learning algorithm for training high-level feature detectors. NOTE:
* Sutskever and Tieleman have shown that CD is not following the gradient of
* any function (Sutskever and Tieleman, 2010), so it can't be minimized with
* line-searching optimizers like {@link Fmincg}. {@link GradientDescent} does
* a good job though, as it doesn't care whether you're moving in the right
* downhill direction.
*
* @author thomas.jungblut
*
*/
public final class RBMCostFunction extends AbstractMiniBatchCostFunction {

  private final ActivationFunction activationFunction;
  private final int[][] unfoldParameters;

  private final TrainingType type;
  private final double lambda;
  private final Random random;

  public RBMCostFunction(DoubleVector[] currentTrainingSet, int batchSize,
      int numThreads, int numHiddenUnits,
      ActivationFunction activationFunction, TrainingType type, double lambda,
      long seed, boolean stochastic) {
    super(currentTrainingSet, null, batchSize, numThreads, stochastic);
    this.activationFunction = activationFunction;
    this.type = type;
    this.lambda = lambda;
    this.random = new Random(seed);
    this.unfoldParameters = MultilayerPerceptronCostFunction
        .computeUnfoldParameters(new int[] {
            currentTrainingSet[0].getDimension(), numHiddenUnits + 1 });
  }

  @Override
  protected CostGradientTuple evaluateBatch(DoubleVector input,
      DoubleMatrix data, DoubleMatrix outcomeBatch) {
    // input contains the weights between the visible and the hidden units
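    // theta is unfolded from the flat parameter vector and transposed so it
    // can be multiplied directly with the row-wise batch below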
    DoubleMatrix theta = DenseMatrixFolder.unfoldMatrices(input,
        unfoldParameters)[0].transpose();

    /*
     * POSITIVE PHASE
     */
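    // the positive phase measures how often visible and hidden units are
    // active together when the hidden units are driven by the training data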
    DoubleMatrix positiveHiddenProbs = activationFunction.apply(multiply(data,
        theta, false, false));
    // set our hidden bias column back to 1
    positiveHiddenProbs.setColumnVector(0,
        DenseDoubleVector.ones(positiveHiddenProbs.getRowCount()));
    DoubleMatrix positiveAssociations = multiply(data, positiveHiddenProbs,
        true, false);
    /*
     * END OF POSITIVE PHASE
     */
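    // sample binary states from the hidden probabilities, so a hidden unit can
    // pass at most one bit of information on to the reconstruction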
    binarize(random, positiveHiddenProbs);
    /*
     * START NEGATIVE PHASE
     */
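    // the negative phase reconstructs the visible units from the sampled
    // hidden states and measures the same associations on the reconstruction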
    DoubleMatrix negativeData = activationFunction.apply(multiply(
        positiveHiddenProbs, theta, false, true));
    negativeData.setColumnVector(0,
        DenseDoubleVector.ones(negativeData.getRowCount()));
    DoubleMatrix negativeHiddenProbs = activationFunction.apply(multiply(
        negativeData, theta, false, false));
    negativeHiddenProbs.setColumnVector(0,
        DenseDoubleVector.ones(negativeHiddenProbs.getRowCount()));
    DoubleMatrix negativeAssociations = multiply(negativeData,
        negativeHiddenProbs, true, false);
    /*
     * END OF NEGATIVE PHASE
     */

    // measure a very simple reconstruction error
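    // (it is only a progress indicator; as noted in the class comment, CD-1
    // does not directly minimize this quantity)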
    double j = data.subtract(negativeData).pow(2).sum();

    // calculate the approx. gradient
    DoubleMatrix thetaGradient = positiveAssociations.subtract(
        negativeAssociations).divide(data.getRowCount());

    // calculate the weight decay and apply it
    if (lambda != 0d) {
      DoubleVector bias = thetaGradient.getColumnVector(0);
      thetaGradient = thetaGradient.subtract(thetaGradient.multiply(lambda
          / data.getRowCount()));
      thetaGradient.setColumnVector(0, bias);
    }

    // transpose the gradient and negate it, because we transposed theta at the
    // top and our gradient descent subtracts the gradient instead of adding it
    return new CostGradientTuple(j,
        DenseMatrixFolder.foldMatrices((DenseDoubleMatrix) thetaGradient
            .multiply(-1).transpose()));
  }

  private DoubleMatrix multiply(DoubleMatrix a1, DoubleMatrix a2,
      boolean a1Transpose, boolean a2Transpose) {
    switch (type) {
      case CPU:
        return multiplyCPU(a1, a2, a1Transpose, a2Transpose);
      case GPU:
        return multiplyGPU(a1, a2, a1Transpose, a2Transpose);
    }
    throw new IllegalArgumentException(
        "TrainingType could not be handled by the switch.");
  }

  private static DoubleMatrix multiplyCPU(DoubleMatrix a1, DoubleMatrix a2,
      boolean a1Transpose, boolean a2Transpose) {
    a2 = a2Transpose ? a2.transpose() : a2;
    a1 = a1Transpose ? a1.transpose() : a1;
    return a1.multiply(a2);
  }

  private static DoubleMatrix multiplyGPU(DoubleMatrix a1, DoubleMatrix a2,
      boolean a1Transpose, boolean a2Transpose) {
    return JCUDAMatrixUtils.multiply((DenseDoubleMatrix) a1,
        (DenseDoubleMatrix) a2, a1Transpose, a2Transpose);
  }

  int[][] getUnfoldParameters() {
    return this.unfoldParameters;
  }

  static DoubleVector[] binarize(Random r, DoubleVector[] hiddenActivations) {
    for (int i = 0; i < hiddenActivations.length; i++) {
      binarize(r, hiddenActivations[i]);
    }
    return hiddenActivations;
  }

  static DoubleMatrix binarize(Random r, DoubleMatrix hiddenActivations) {
    for (int i = 0; i < hiddenActivations.getRowCount(); i++) {
      for (int j = 0; j < hiddenActivations.getColumnCount(); j++) {
        hiddenActivations.set(i, j,
            hiddenActivations.get(i, j) > r.nextDouble() ? 1d : 0d);
      }
    }
    return hiddenActivations;
  }

  static DoubleVector binarize(Random r, DoubleVector v) {
    for (int j = 0; j < v.getDimension(); j++) {
      v.set(j, v.get(j) > r.nextDouble() ? 1d : 0d);
    }
    return v;
  }

}
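
For reference, here is a minimal standalone sketch of the CD-1 update that evaluateBatch computes above, written against plain double arrays with a sigmoid activation. Every name in it is illustrative; it leaves out the bias handling, weight decay and GPU path of the real class.

import java.util.Random;

public class CD1Sketch {

  // sigmoid activation, the usual choice for binary RBM units
  static double sigmoid(double x) {
    return 1d / (1d + Math.exp(-x));
  }

  /**
   * One CD-1 step on a batch: data is (examples x visible units), weights is
   * (visible x hidden units). Returns (positive - negative associations) / m,
   * the analogue of thetaGradient above.
   */
  static double[][] cd1Gradient(double[][] data, double[][] weights, Random rnd) {
    int m = data.length, v = weights.length, h = weights[0].length;

    // positive phase: hidden probabilities driven by the data
    double[][] posHidden = new double[m][h];
    for (int i = 0; i < m; i++)
      for (int j = 0; j < h; j++) {
        double sum = 0d;
        for (int k = 0; k < v; k++)
          sum += data[i][k] * weights[k][j];
        posHidden[i][j] = sigmoid(sum);
      }

    // sample binary hidden states from the probabilities
    double[][] hiddenStates = new double[m][h];
    for (int i = 0; i < m; i++)
      for (int j = 0; j < h; j++)
        hiddenStates[i][j] = posHidden[i][j] > rnd.nextDouble() ? 1d : 0d;

    // negative phase: reconstruct the visible units, then the hidden probabilities
    double[][] negData = new double[m][v];
    for (int i = 0; i < m; i++)
      for (int k = 0; k < v; k++) {
        double sum = 0d;
        for (int j = 0; j < h; j++)
          sum += hiddenStates[i][j] * weights[k][j];
        negData[i][k] = sigmoid(sum);
      }
    double[][] negHidden = new double[m][h];
    for (int i = 0; i < m; i++)
      for (int j = 0; j < h; j++) {
        double sum = 0d;
        for (int k = 0; k < v; k++)
          sum += negData[i][k] * weights[k][j];
        negHidden[i][j] = sigmoid(sum);
      }

    // gradient = (data^T * posHidden - negData^T * negHidden) / m
    double[][] gradient = new double[v][h];
    for (int k = 0; k < v; k++)
      for (int j = 0; j < h; j++) {
        double pos = 0d, neg = 0d;
        for (int i = 0; i < m; i++) {
          pos += data[i][k] * posHidden[i][j];
          neg += negData[i][k] * negHidden[i][j];
        }
        gradient[k][j] = (pos - neg) / m;
      }
    return gradient;
  }

  public static void main(String[] args) {
    Random rnd = new Random(0);
    double[][] data = { { 1, 0, 1, 0 }, { 1, 1, 0, 0 } }; // two toy examples
    double[][] weights = new double[4][3]; // 4 visible x 3 hidden units
    double[][] gradient = cd1Gradient(data, weights, rnd);
    System.out.println("gradient[0][0] = " + gradient[0][0]);
  }
}

The plain update would then be weights[k][j] += learningRate * gradient[k][j]; the cost function above instead negates and folds the gradient so that a generic, subtracting gradient descent ends up performing the equivalent addition.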