package tv.floe.metronome.deeplearning.rbm;

import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;

import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixWritable;



import tv.floe.metronome.deeplearning.neuralnetwork.core.BaseNeuralNetworkVectorized;
import tv.floe.metronome.deeplearning.neuralnetwork.gradient.NeuralNetworkGradient;
import tv.floe.metronome.deeplearning.neuralnetwork.layer.HiddenLayer;
import tv.floe.metronome.deeplearning.neuralnetwork.optimize.NeuralNetworkOptimizer;
import tv.floe.metronome.math.MathUtils;
import tv.floe.metronome.math.MatrixUtils;
import tv.floe.metronome.types.Pair;



/**
* Basic implementation of a Restricted Boltzmann Machine
*
* A restricted Boltzmann machine (RBM) is a generative stochastic neural
*  network that can learn a probability distribution over its set of inputs.
*
* Applications
*
* - dimensionality reduction, classification, collaborative filtering, feature learning, and topic modelling
*
* Based on work by Hinton et al. (2006)
*
* And inspired by the RBM implementation of Adam Gibson:
*
* https://github.com/agibsonccc/java-deeplearning
*
* @author josh
*
*/
public class RestrictedBoltzmannMachine extends BaseNeuralNetworkVectorized {
 
  //private double learningRate = 0.1d;
  public transient NeuralNetworkOptimizer optimizer;
  public double[] debugWeightAddsBuffer = null; // only set when we want to check things
 
  public RestrictedBoltzmannMachine() { }

  public RestrictedBoltzmannMachine(int nVisible, int nHidden, Matrix weights, Matrix hBias, Matrix vBias, RandomGenerator rng) {
    super(nVisible, nHidden, weights, hBias, vBias, rng);
  }
 
  public RestrictedBoltzmannMachine(Matrix input, int nVisible, int nHidden, Matrix weights, Matrix hBias, Matrix vBias, RandomGenerator rng) {
    super(input, nVisible, nHidden, weights, hBias, vBias, rng);
  }
 
  /**
   * CTOR
   *
   * TODO: at some point, make this a more elaborate setup for building RBMs?
   *
   * @param numVisibleNeurons number of visible units
   * @param numHiddenNeurons number of hidden units
   * @param rnd random generator; if null, a seeded MersenneTwister is used
   */
  public RestrictedBoltzmannMachine(int numVisibleNeurons, int numHiddenNeurons, RandomGenerator rnd) {
   
    this.numberVisibleNeurons = numVisibleNeurons;
    this.numberHiddenNeurons = numHiddenNeurons;
   
    if (rnd == null) {
     
      this.randNumGenerator = new MersenneTwister(1234);

    } else {
     
      this.randNumGenerator = rnd;
     
    }
   
    // initialize the connection weights from a zero-mean Gaussian with a
    // small standard deviation, N(0, 0.01)
    NormalDistribution u = new NormalDistribution(this.randNumGenerator,0,.01,NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);

    this.connectionWeights = new DenseMatrix( this.numberVisibleNeurons, this.numberHiddenNeurons );
    this.connectionWeights.assign(0.0);
   
    for (int r = 0; r < this.connectionWeights.numRows(); r++) {
     
      for(int c = 0; c < this.connectionWeights.numCols(); c++) {
     
        double init_weight = u.sample( );
       
      //  System.out.println("w: " + init_weight);
       
        this.connectionWeights.setQuick( r, c, init_weight );
     
      }

    }
   

    this.hiddenBiasNeurons = new DenseMatrix( 1, this.numberHiddenNeurons );
    // switch to column vector ?
    //this.hiddenBiasNeurons = new DenseMatrix( this.numberHiddenNeurons, 1 );
    this.hiddenBiasNeurons.assign(0.0);

    this.visibleBiasNeurons = new DenseMatrix( 1, this.numberVisibleNeurons );
    // switch to column vector ?
    //this.visibleBiasNeurons = new DenseMatrix( this.numberVisibleNeurons, 1 );
    this.visibleBiasNeurons.assign(0.0);

    //System.out.println("Creating RBM: visible: " + this.numberVisibleNeurons + ", hidden: " + this.numberHiddenNeurons );
   
   
  }
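
  /*
   * Construction example (illustrative, not part of the original source):
   * builds a 6-visible / 3-hidden machine with a seeded generator so the
   * N(0, 0.01) weight initialization above is reproducible.
   *
   *   RandomGenerator rng = new MersenneTwister(1234);
   *   RestrictedBoltzmannMachine rbm = new RestrictedBoltzmannMachine(6, 3, rng);
   */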
 
 
  public void setupCDkDebugBuffer(double[] buffer) {
    this.debugWeightAddsBuffer = buffer;
  }
 
  /**
   * Contrastive divergence (CD-k), based on Hinton's explanation of CD in (Hinton, 2002):
   *
   * 1. start by setting the states of the visible units to a training vector
   *
   * 2. the binary states of the hidden units are all computed in parallel using Equation (7)
   *
   * 3. once binary states have been chosen for the hidden units, a "reconstruction" is produced by setting each v_i to 1
   * with a probability given by Equation (8)
   *
   * 4. the change in weights is given by Equation (9)
   *
   * @param learningRate the learning rate applied to the updates
   * @param k the number of Gibbs sampling steps
   * @param input the training input (one example per row); if null, the current dataset is reused
   */
  public void contrastiveDivergence(double learningRate, int k, Matrix input) {

    if (input != null) {
     
      this.trainingDataset = input;
     
    }
   
    // actually performs CDk
    NeuralNetworkGradient gradient = getGradient(new Object[]{k,learningRate});
   
    //W.addi(gradient.getwGradient());
    this.connectionWeights = this.connectionWeights.plus(gradient.getwGradient());
   
    //hBias.addi(gradient.gethBiasGradient());
    this.hiddenBiasNeurons = this.hiddenBiasNeurons.plus(gradient.gethBiasGradient());
   
    //vBias.addi(gradient.getvBiasGradient());
    this.visibleBiasNeurons = this.visibleBiasNeurons.plus( gradient.getvBiasGradient() );
   

  }
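
  /*
   * Usage sketch (illustrative): one CD-1 update over a small binary batch,
   * one example per row. The 2x6 dataset below is hypothetical.
   *
   *   Matrix input = new DenseMatrix(new double[][] {
   *       { 1, 1, 1, 0, 0, 0 },
   *       { 0, 0, 0, 1, 1, 1 }
   *   });
   *   rbm.contrastiveDivergence(0.1, 1, input); // learningRate = 0.1, k = 1
   */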
 

 
  /**
   * Computes the CD-k gradients for the connection weights and the
   * visible/hidden bias terms.
   *
   */
  @Override
  public NeuralNetworkGradient getGradient(Object[] params) {
    int k = (Integer) params[0];
    double learningRate = (Double) params[1];
   
    // init CDk
   
    // do gibbs sampling given V to get the Hidden states based on the training input
    // compute positive phase
    Pair<Matrix, Matrix> hiddenProbsAndSamplesStart = this.sampleHiddenGivenVisible( this.trainingDataset );
       
    Pair<Pair<Matrix, Matrix>,Pair<Matrix, Matrix>> gibbsSamplingMatrices = null;
   
    // now run k full steps of alternating Gibbs sampling
   
    //negative visible "means" or "expected values"
    Matrix negativeVisibleExpectedValues = null;
    //negative value samples
    Matrix negativeVisibleSamples = null;
    //negative hidden means or expected values
    Matrix negativeHiddenExpectedValues = null;
    //negative hidden samples
    Matrix negativeHiddenSamples = null;
   
    for ( int x = 0; x < k; x++ ) {
     
      if (0 == x) {
       
        gibbsSamplingMatrices = this.gibbsSamplingStepFromHidden( hiddenProbsAndSamplesStart.getSecond() );
               
      } else {
       
        gibbsSamplingMatrices = this.gibbsSamplingStepFromHidden( negativeHiddenSamples );
       
      }
     
      // "free energy of the negative phase"
      // now create some easier to use aliases
      negativeVisibleExpectedValues = gibbsSamplingMatrices.getFirst().getFirst();
      negativeVisibleSamples = gibbsSamplingMatrices.getFirst().getSecond();
      negativeHiddenExpectedValues = gibbsSamplingMatrices.getSecond().getFirst();
      negativeHiddenSamples = gibbsSamplingMatrices.getSecond().getSecond();
     
     
    }
       
    // ----- now calculate equation (9) to get the weight changes ------
   
    // now compute the <vi hj>data   
    Matrix trainingDataTimesInitialHiddenStates = this.trainingDataset.transpose().times( hiddenProbsAndSamplesStart.getSecond() );

    // now compute the <vi hj>model (this may be vi * phj --- double check)
    Matrix negativeVisibleSamplesTransposeTimesNegHiddenExpValues = negativeVisibleSamples.transpose().times( negativeHiddenExpectedValues );
       
    // calc the delta between: data - model
    Matrix dataModelDelta = trainingDataTimesInitialHiddenStates.minus( negativeVisibleSamplesTransposeTimesNegHiddenExpValues );
   
    // learningRate * delta(data - model)
    Matrix wGradient = dataModelDelta.times( learningRate );
   
    if (useRegularization) {
    
      //wGradient.subi(W.muli(l2));
      
      // L2 weight decay: subtract l2 * W from the gradient. Note that the
      // original code reassigned this.connectionWeights here, destructively
      // scaling the network's weights as a side effect; computing the penalty
      // into a local avoids that.
      Matrix l2Penalty = this.connectionWeights.times(l2);
      wGradient = wGradient.minus(l2Penalty);
    
    }

   
    if (momentum != 0) {
    
      //wGradient.muli( 1 - momentum);
      
      // scale the current gradient by (1 - momentum); note that full momentum
      // would also add momentum * (previous update), which is not tracked here
      wGradient = wGradient.times( 1 - momentum );
      
    }

   
    // This is added normalization for mini-batching
   
    //wGradient.divi(input.rows);
//    wGradient = wGradient.divide( this.trainingDataset.numRows() );
   
   
    // ---- end of equation (9) section -----------------
   
    // update the connection weights and bias terms for visible/hidden units
    //this.connectionWeights = this.connectionWeights.plus( connectionWeightChanges );

    Matrix vBiasGradient = MatrixUtils.mean( this.trainingDataset.minus( negativeVisibleSamples ) , 0).times( learningRate );
    //this.visibleBiasNeurons = this.visibleBiasNeurons.plus( vBiasAdd );

    Matrix hBiasGradient = null;
   
    if(this.sparsity != 0) {
      //all hidden units must stay around this number
//      hBiasGradient = mean(probHidden.getSecond().add( -sparsity),0).mul(learningRate);
      hBiasGradient = MatrixUtils.mean( hiddenProbsAndSamplesStart.getSecond().plus( -sparsity ) , 0).times( learningRate ); //.times(this.learningRate);
    }
    else {
      //update rule: the expected values of the hidden input - the negative hidden  means adjusted by the learning rate
//      hBiasGradient = mean(probHidden.getSecond().sub(nhMeans), 0).mul(learningRate);
      hBiasGradient = MatrixUtils.mean( hiddenProbsAndSamplesStart.getSecond().minus( negativeHiddenExpectedValues ) , 0).times( learningRate ); //.times(this.learningRate);
    }
   
   
   
    //this.hiddenBiasNeurons = this.hiddenBiasNeurons.plus( hBiasAdd );   

    return new NeuralNetworkGradient(wGradient, vBiasGradient, hBiasGradient);
   
   
  }
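
  /*
   * Equation (9) in comment form, as implemented above (L2 decay and momentum
   * scaling aside):
   *
   *   deltaW = learningRate * ( v^T h_data  -  v_neg^T E[h_neg] )
   *
   * where v is the training batch, h_data the sampled hidden states for it,
   * and v_neg / E[h_neg] come from the k-th step of the Gibbs chain.
   */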

  /**
   * Reconstruction cross-entropy.
   *
   * Used to measure how well trained the RBM currently is
   *
   * - the input training set matrix is arranged such that there is a training example per row
   *
   * This compares the similarity of two probability
   * distributions: in this case, the input and the
   * (deterministic) reconstruction of the input.
   * This will account for either regularization or none
   * depending on the configuration.
   *
   * @return reconstruction error
   */
  public double getReConstructionCrossEntropy() {

    // 1. get sigmoid of the inputMatrix x weights
   
    // NOTE: could probably just call propUp() with the training dataset as the param
    Matrix preSigmoidHidden = this.trainingDataset.times( this.connectionWeights );
    preSigmoidHidden = MatrixUtils.addRowVector(preSigmoidHidden, this.hiddenBiasNeurons.viewRow(0));
    Matrix sigHidden = MatrixUtils.sigmoid(preSigmoidHidden);
   
   
   
    // 2. get sigmoid of the ( sigH from #1 x transpose(weights) )
   
    // could use propDown here
    Matrix preSigmoidVis = sigHidden.times( this.connectionWeights.transpose() );
    preSigmoidVis = MatrixUtils.addRowVector(preSigmoidVis, this.visibleBiasNeurons.viewRow(0));
    Matrix sigVis = MatrixUtils.sigmoid( preSigmoidVis );
   
   
    // 3. put together the partials to build the cross entropy
   
    Matrix logSigmoidVis = MatrixUtils.log(sigVis);
    Matrix oneMinusSigmoidVis = MatrixUtils.oneMinus(sigVis); //MatrixUtils.ones(sigVis.numRows(), sigVis.numCols()).minus(sigVis);


   
    Matrix oneMinusInput = MatrixUtils.oneMinus(this.trainingDataset);
   
   
    Matrix logOneMinusSigVisible = MatrixUtils.log(oneMinusSigmoidVis);
   
    // D
    // self.input * T.log(T.nnet.sigmoid(pre_sigmoid_nv))
    Matrix inputTimesLogSigVisible = MatrixUtils.elementWiseMultiplication(this.trainingDataset,logSigmoidVis);
   
    // E
    // (1 - self.input) * T.log(1 - T.nnet.sigmoid(pre_sigmoid_nv))
    Matrix oneMinusInputTimesLogOneMinusSigV = MatrixUtils.elementWiseMultiplication(oneMinusInput,logOneMinusSigVisible);
   
    // D + E
    Matrix Inner = inputTimesLogSigVisible.plus(oneMinusInputTimesLogOneMinusSigV);

    // get the sum for each row (example)
    Matrix crossEntRowSums = MatrixUtils.rowSums(Inner);
   
    //MatrixUtils.debug_print(crossEntRowSums);
   
    // now compute the average across the cross-entropies for each training example
    double crossEntFinal = MatrixUtils.mean(crossEntRowSums);
       
    return -crossEntFinal;
  }
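
  /*
   * The quantity computed above, for each training row x with reconstruction
   * probabilities x~:
   *
   *   -sum_i [ x_i * log(x~_i) + (1 - x_i) * log(1 - x~_i) ]
   *
   * averaged over the rows; lower values indicate better reconstructions.
   */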
 

 
  /**
   * "Compute Activations"
   *
   * Generate probabilities for each hidden unit being set to 1
   * Equation (7) in Hinton
   *
   * This function propagates the visible units' activations upward to
   * the hidden units.
   *
   * aka "propagate up"
   *
   * @param visible
   * @return
   */
  public Matrix propUp(Matrix visible) {
   
    Matrix preSigmoid = visible.times( this.connectionWeights );
    preSigmoid = MatrixUtils.addRowVector(preSigmoid, this.hiddenBiasNeurons.viewRow(0));

    return MatrixUtils.sigmoid(preSigmoid);
  }
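
  /*
   * Equation (7) in comment form: for hidden unit j,
   *
   *   p(h_j = 1 | v) = sigmoid( hBias_j + sum_i v_i * W_ij )
   *
   * which is exactly the matrix-times-weights-plus-row-vector computation above.
   */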

  /**
   * Binomial sampling of the hidden values given the visible values
   *
   * This function infers the state of the hidden units given the visible units
   *
   * Used in Gibbs sampling
   *
   * @param visible
   * @return a pair of (hidden probabilities, binary hidden samples)
   */
  public Pair<Matrix, Matrix> sampleHiddenGivenVisible(Matrix visible) {
       
    Matrix hiddenProbs = this.propUp(visible);

    Matrix hiddenBinomialSamples = MatrixUtils.genBinomialDistribution(hiddenProbs, 1, this.randNumGenerator);
   
    //apply dropout
    this.applyDropOutIfNecessary( hiddenBinomialSamples );
   
   
    return new Pair<Matrix, Matrix>(hiddenProbs, hiddenBinomialSamples);
  }
 
 
  /**
   * Generate probabilities for each visible unit being set to 1,
   * Equation (8) in Hinton
   *
   * This function propagates the hidden units' activations downward to
   * the visible units,
   *
   * aka "propagate down"
   *
   * @param hidden
   * @return the visible unit activation probabilities
   */
  public Matrix propDown(Matrix hidden) {
   
    Matrix preSigmoid = hidden.times( this.connectionWeights.transpose() );
    preSigmoid = MatrixUtils.addRowVector(preSigmoid, this.visibleBiasNeurons.viewRow(0));

    return MatrixUtils.sigmoid(preSigmoid);
  }
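
  /*
   * Equation (8) in comment form: for visible unit i,
   *
   *   p(v_i = 1 | h) = sigmoid( vBias_i + sum_j h_j * W_ij )
   *
   * i.e. the same computation as propUp with the weight matrix transposed.
   */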
 
  /**
   * This function infers the state of the visible units given the hidden units
   *
   * Used in Gibbs sampling
   *
   * @param hidden
   * @return a pair of (visible probabilities, binary visible samples)
   */
  public Pair<Matrix, Matrix> sampleVisibleGivenHidden(Matrix hidden) {
   
    Matrix visibleProb = this.propDown(hidden);

    Matrix visibleBinomialSample = MatrixUtils.genBinomialDistribution(visibleProb, 1, this.randNumGenerator);

    return new Pair<Matrix, Matrix>(visibleProb, visibleBinomialSample);
   
  }
 
  /**
   * based on:
   * "gibbs_vhv", which performs a step of Gibbs sampling starting from the visible units.
   *
   * @return a pair of (hidden probabilities and samples, visible probabilities and samples)
   *
   */
  public Pair<Pair<Matrix, Matrix>, Pair<Matrix, Matrix>> gibbsSamplingStepFromVisible(Matrix visible) {
 
    Pair<Matrix, Matrix> hiddenProbsAndSamples = this.sampleHiddenGivenVisible(visible);
    Pair<Matrix, Matrix> visibleProbsAndSamples = this.sampleVisibleGivenHidden( hiddenProbsAndSamples.getSecond() );

    return new Pair<Pair<Matrix, Matrix>, Pair<Matrix, Matrix>>(hiddenProbsAndSamples, visibleProbsAndSamples);
  }
 
  /**
   * based on: "gibbs_hvh"
   * - This function implements one step of Gibbs sampling,
   * starting from the hidden state
   *
   * @param hidden
   */
  public Pair<Pair<Matrix, Matrix>, Pair<Matrix, Matrix>> gibbsSamplingStepFromHidden(Matrix hidden) {
   
   
    Pair<Matrix, Matrix> visibleProbsAndSamples = this.sampleVisibleGivenHidden(hidden);
    Pair<Matrix, Matrix> hiddenProbsAndSamples = this.sampleHiddenGivenVisible(visibleProbsAndSamples.getSecond());
   
    return new Pair<Pair<Matrix, Matrix>, Pair<Matrix, Matrix>>(visibleProbsAndSamples, hiddenProbsAndSamples);
  }
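
  /*
   * Chaining sketch (illustrative): k alternating h -> v -> h steps, mirroring
   * the sampling loop in getGradient().
   *
   *   Pair<Matrix, Matrix> h0 = rbm.sampleHiddenGivenVisible(input);
   *   Matrix hSamples = h0.getSecond();
   *   for (int step = 0; step < k; step++) {
   *       hSamples = rbm.gibbsSamplingStepFromHidden(hSamples)
   *                     .getSecond().getSecond(); // hidden samples of this step
   *   }
   */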

  /**
   * Free energy for an RBM.
   * Configurations with lower free energy have higher
   * probability under the model.
   * @param visibleSample the sample to test on
   * @return the free energy for this sample
   */
  public double freeEnergy(Matrix visibleSample) {

    Matrix wxB = MatrixUtils.addRowVector( visibleSample.times( this.connectionWeights ), this.hiddenBiasNeurons.viewRow( 0 ) );
   
    double vBiasTerm = MathUtils.dot(visibleSample, this.visibleBiasNeurons);
   
    double hBiasTerm = MatrixUtils.sum( MatrixUtils.log( MatrixUtils.exp( wxB ).plus( 1.0 ) ) );
   
    return -hBiasTerm - vBiasTerm;
   
  }
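
  /*
   * The free energy computed above, in comment form:
   *
   *   F(v) = - v . vBias  -  sum_j log( 1 + exp( (v * W + hBias)_j ) )
   */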
 
  /**
   *
   * Reconstructs the visible input.
   * A reconstruction is a propagation down (propDown) of the hidden
   * activations produced by propagating the input up (propUp).
   *
   */
  public Matrix reconstructVisibleInput(Matrix visible) {

    // propUp
    Matrix propUpHiddenResult = this.propUp(visible);
   
    return this.propDown(propUpHiddenResult);
  }

 
  /**
   * TODO: Think about the object[] array as parameter mechanism a bit more
   *
   */
  @Override
  public void trainTillConvergence(Matrix input, double learningRate, Object[] params) {
   
    if ( null != input) {
      this.trainingDataset = input;
    }
   
    // still java 6!
    //int k = (Integer) params[0];
    //trainTillConvergence(learningRate, k, input);
   
    optimizer = new RestrictedBoltzmannMachineOptimizer( this, learningRate, params );
    optimizer.train( input );
   
   
  }
 
  /**
   * Trains until the optimizer converges (a local minimum; there is no
   * guarantee of a global one).
   *
   * TODO: implement the newer style optimizer stuff here!
   *
   * @param learningRate
   * @param k
   * @param input
   */
  public void trainTillConvergence(double learningRate, int k, Matrix input) {
   
    if (input != null) {
      this.trainingDataset = input;
    }

    //System.out.println("using Mallet's optimization");
   
    optimizer = new RestrictedBoltzmannMachineOptimizer(this, learningRate, new Object[]{ k });
    optimizer.train(input);

   
  }
 

  @Override
  public double lossFunction(Object[] params) {
    return getReConstructionCrossEntropy();
  }

  @Override
  public void train(Matrix input, double learningRate, Object[] params) {
    int k = (Integer) params[0];
    //this.learningRate = learningRate;
    contrastiveDivergence(learningRate, k, input);
  }

  /**
   * Reconstructs the visible input.
   * A reconstruction is a propDown of the hidden activations
   * produced by propUp.
   * @param v the visible input
   * @return the reconstruction of the visible input
   */
  @Override
  public Matrix reconstruct(Matrix v) {
 
    return propDown(propUp(v));
   
  }
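
  /*
   * Usage sketch (illustrative): compare an input batch to its reconstruction
   * as a quick qualitative check alongside the cross-entropy loss.
   *
   *   Matrix recon = rbm.reconstruct(input);
   *   double err = rbm.getReConstructionCrossEntropy(); // uses the stored training set
   */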
 

  /**
   * Copies params from the passed in network
   * to this one
   * @param n the network to copy
   */
  public void update(BaseNeuralNetworkVectorized n) {
    // NOTE: currently a no-op; the copy logic below was written against the
    // old field names and has not yet been ported to this class's members
    /*
    this.W = n.W;
    this.hBias = n.hBias;
    this.vBias = n.vBias;
    this.l2 = n.l2;
    this.useRegularization = n.useRegularization;
    this.momentum = n.momentum;
    this.nHidden = n.nHidden;
    this.nVisible = n.nVisible;
    this.rng = n.rng;
    this.sparsity = n.sparsity;
    */
  }


  /**
   * Serializes this to the output stream.
   * @param os the output stream to write to
   */
  public void write(OutputStream os) {
    try {

      DataOutput d = new DataOutputStream(os);
      // NOTE: constructing the ObjectOutputStream writes its stream header
      // immediately, so load() must construct its ObjectInputStream at the
      // matching point, before any data has been read
      ObjectOutputStream oos = new ObjectOutputStream(os);

      d.writeInt( this.numberVisibleNeurons );
      d.writeInt( this.numberHiddenNeurons );

      MatrixWritable.writeMatrix(d, this.hiddenBiasNeurons );
      MatrixWritable.writeMatrix(d, this.visibleBiasNeurons );
      MatrixWritable.writeMatrix(d, this.connectionWeights );
//      MatrixWritable.writeMatrix(d, this.trainingDataset );

      oos.writeObject( this.randNumGenerator );

      d.writeDouble( this.sparsity );
      d.writeDouble( this.momentum );
      d.writeDouble( this.l2 );
      d.writeInt( this.renderWeightsEveryNumEpochs );
      d.writeDouble( this.fanIn );
      d.writeBoolean( this.useRegularization );

    } catch (IOException e) {
      throw new RuntimeException(e);
    }

  }
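
  /*
   * Round-trip sketch (illustrative; "rbm.bin" and the java.io file streams
   * are assumptions, not part of this class). write() and load() must be
   * paired because the ObjectOutputStream header written at construction time
   * is consumed by the matching ObjectInputStream.
   *
   *   OutputStream os = new FileOutputStream("rbm.bin");
   *   rbm.write(os);
   *   os.close();
   *
   *   RestrictedBoltzmannMachine loaded = new RestrictedBoltzmannMachine();
   *   InputStream is = new FileInputStream("rbm.bin");
   *   loaded.load(is);
   *   is.close();
   */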
 
 
 
 
  /**
   * Load (using {@link ObjectInputStream})
   * @param is the input stream to load from (usually a file)
   */
  public void load(InputStream is) {
    try {

      DataInput di = new DataInputStream(is);
      // must mirror write(): the ObjectInputStream consumes the stream header
      // written by the ObjectOutputStream at construction time
      ObjectInputStream ois = new ObjectInputStream(is);

      this.numberVisibleNeurons = di.readInt();
      this.numberHiddenNeurons = di.readInt();

      this.hiddenBiasNeurons = MatrixWritable.readMatrix( di );
      this.visibleBiasNeurons = MatrixWritable.readMatrix( di );
      this.connectionWeights = MatrixWritable.readMatrix( di );
//      this.trainingDataset = MatrixWritable.readMatrix( di );

      this.randNumGenerator = (RandomGenerator) ois.readObject();

      this.sparsity = di.readDouble();
      this.momentum = di.readDouble();
      this.l2 = di.readDouble();
      this.renderWeightsEveryNumEpochs = di.readInt();
      this.fanIn = di.readDouble();
      this.useRegularization = di.readBoolean();

    } catch (Exception e) {
      throw new RuntimeException(e);
    }

  }
 
  /**
   * Serializes this to the output stream.
   *
   * Used in parameter averaging
   *
   * @param os the output stream to write to
   */
  public void serializeParameters(OutputStream os) {
    try {

        DataOutput d = new DataOutputStream(os);
       
        d.writeInt( this.numberVisibleNeurons );
        d.writeInt( this.numberHiddenNeurons );
       
        // the bias terms are shipped along with the weights (the original
        // notes questioned whether the biases are strictly needed for averaging)
        MatrixWritable.writeMatrix(d, this.hiddenBiasNeurons );
        MatrixWritable.writeMatrix(d, this.visibleBiasNeurons );
        MatrixWritable.writeMatrix(d, this.connectionWeights );
               
    } catch (IOException e) {
      throw new RuntimeException(e);
    }

  }   
 
  /**
   * Load parameter values from the byte stream
   *
   */
  public void loadParameterValues(InputStream is) {
    try {

      DataInput di = new DataInputStream(is);
     
        this.numberVisibleNeurons = di.readInt();
        this.numberHiddenNeurons = di.readInt();
       
        this.hiddenBiasNeurons = MatrixWritable.readMatrix( di );
        this.visibleBiasNeurons = MatrixWritable.readMatrix( di );
        this.connectionWeights = MatrixWritable.readMatrix( di );
               
    } catch (Exception e) {
      throw new RuntimeException(e);
    }

  }   
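
  /*
   * Parameter-averaging sketch (illustrative): a worker ships only the sizes,
   * biases, and weights; a peer can rehydrate them into a fresh instance.
   * ByteArrayOutputStream / ByteArrayInputStream are from java.io.
   *
   *   ByteArrayOutputStream bos = new ByteArrayOutputStream();
   *   rbm.serializeParameters(bos);
   *   RestrictedBoltzmannMachine copy = new RestrictedBoltzmannMachine();
   *   copy.loadParameterValues(new ByteArrayInputStream(bos.toByteArray()));
   */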
 
 
}