Package com.github.neuralnetworks.training.rbm

Source Code of com.github.neuralnetworks.training.rbm.CDTrainerBase

package com.github.neuralnetworks.training.rbm;

import com.github.neuralnetworks.architecture.types.RBM;
import com.github.neuralnetworks.calculation.RBMLayerCalculator;
import com.github.neuralnetworks.training.OneStepTrainer;
import com.github.neuralnetworks.training.TrainingInputData;
import com.github.neuralnetworks.util.Constants;
import com.github.neuralnetworks.util.Matrix;
import com.github.neuralnetworks.util.Properties;

/**
* Base class for Contrastive Divergence
* requires RBMLayerCalculator as the layer calculator. This allows for different implementations of the layer calculator, like GPU/CPU for example
*/
public abstract class CDTrainerBase extends OneStepTrainer<RBM> {

    private static final long serialVersionUID = 1L;

    /**
     * positive phase visible layer results
     */
    private Matrix posPhaseVisible;

    /**
     * negative phase visible layer results
     */
    private Matrix negPhaseVisible;

    /**
     * positive phase hidden layer results
     */
    private Matrix posPhaseHidden;

    /**
     * negative phase hidden layer results
     */
    private Matrix negPhaseHidden;

    public CDTrainerBase(Properties properties) {
  super(properties);
    }

    @Override
    protected void learnInput(TrainingInputData data, int batch) {
  RBM nn = getNeuralNetwork();

  posPhaseVisible = data.getInput();
  if (negPhaseVisible == null || negPhaseVisible.getColumns() != posPhaseVisible.getColumns()) {
      negPhaseVisible = new Matrix(posPhaseVisible.getRows(), posPhaseVisible.getColumns());
      posPhaseHidden = new Matrix(nn.getMainConnections().getConnectionGraph().getRows(), posPhaseVisible.getColumns());
      negPhaseHidden = new Matrix(nn.getMainConnections().getConnectionGraph().getRows(), posPhaseVisible.getColumns());
  }

  getLayerCalculator().gibbsSampling(nn, posPhaseVisible, posPhaseHidden, negPhaseVisible, negPhaseHidden, getGibbsSamplingCount(), batch == 0 ? true : getResetRBM(), true);

  // update weights
  updateWeights(posPhaseVisible, posPhaseHidden, negPhaseVisible, negPhaseHidden);
    }

    public Matrix getPosPhaseVisible() {
        return posPhaseVisible;
    }

    public Matrix getNegPhaseVisible() {
        return negPhaseVisible;
    }

    public Matrix getPosPhaseHidden() {
        return posPhaseHidden;
    }

    public Matrix getNegPhaseHidden() {
        return negPhaseHidden;
    }

    public RBMLayerCalculator getLayerCalculator() {
  return properties.getParameter(Constants.LAYER_CALCULATOR);
    }

    public void setLayerCalculator(RBMLayerCalculator layerCalculator) {
  properties.setParameter(Constants.LAYER_CALCULATOR, layerCalculator);
    }
   
    public Boolean getResetRBM() {
  return properties.getParameter(Constants.RESET_RBM);
    }
   
    public void setResetRBM(boolean resetRBM) {
  properties.setParameter(Constants.RESET_RBM, resetRBM);
    }

    public int getGibbsSamplingCount() {
  return properties.containsKey(Constants.GIBBS_SAMPLING_COUNT) ? (int) properties.get(Constants.GIBBS_SAMPLING_COUNT) : 1;
    }

    protected abstract void updateWeights(Matrix posPhaseVisible, Matrix posPhaseHidden, Matrix negPhaseVisible, Matrix negPhaseHidden);
}
TOP

Related Classes of com.github.neuralnetworks.training.rbm.CDTrainerBase

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.