Package tv.floe.metronome.deeplearning.neuralnetwork.optimize

Source Code of tv.floe.metronome.deeplearning.neuralnetwork.optimize.LogisticRegressionOptimizer

package tv.floe.metronome.deeplearning.neuralnetwork.optimize;

import java.io.Serializable;

import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;


import tv.floe.metronome.deeplearning.neuralnetwork.core.LogisticRegression;
import tv.floe.metronome.deeplearning.neuralnetwork.gradient.LogisticRegressionGradient;
import tv.floe.metronome.math.ArrayUtils;
import tv.floe.metronome.math.MatrixUtils;

import cc.mallet.optimize.Optimizable;

public class LogisticRegressionOptimizer implements Optimizable.ByGradientValue,OptimizableByGradientValueMatrix {

  private LogisticRegression logReg;
  private double lr;
 
 
 
  public LogisticRegressionOptimizer(LogisticRegression logReg, double lr) {
    super();
    this.logReg = logReg;
    this.lr = lr;
  }

  @Override
  public int getNumParameters() {
    return MatrixUtils.length( logReg.connectionWeights ) + MatrixUtils.length( logReg.biasTerms );
  }

  @Override
  public void getParameters(double[] buffer) {

    for(int i = 0; i < buffer.length; i++) {
      buffer[i] = getParameter(i);
    }

  }

  @Override
  public double getParameter(int index) {
   
    if ( index >= MatrixUtils.length(logReg.connectionWeights)) {
   
      return MatrixUtils.getElement( logReg.biasTerms, index - MatrixUtils.length(logReg.connectionWeights) );
     
    }
   
    return MatrixUtils.getElement( logReg.connectionWeights , index);
   
  }

  @Override
  public void setParameters(double[] params) {
   
    for(int i = 0; i < params.length; i++) {
     
      setParameter(i,params[i]);
     
    }
   
  }

  @Override
  public void setParameter(int index, double value) {
   
    if (index >= MatrixUtils.length( logReg.connectionWeights )) {
     
      MatrixUtils.setElement( logReg.biasTerms, index - MatrixUtils.length(logReg.connectionWeights), value);
     
    } else {
     
      MatrixUtils.setElement( logReg.connectionWeights, index, value);
     
    }
   
  }

  @Override
  public void getValueGradient(double[] buffer) {
   
    LogisticRegressionGradient grad = logReg.getGradient( lr );
   
    for (int i = 0; i < buffer.length; i++) {
   
      if ( i < MatrixUtils.length( logReg.connectionWeights )) {
       
        buffer[ i ] = MatrixUtils.getElement( grad.getwGradient(), i );
       
      } else {
       
        buffer[ i ] = MatrixUtils.getElement( grad.getbGradient(), i - MatrixUtils.length( logReg.connectionWeights ) );
       
      }
     
    }
   
  }

  @Override
  public double getValue() {
 
    return -logReg.negativeLogLikelihood();

  }

  @Override
  public Matrix getParameters() {
   
    Matrix params = new DenseMatrix(1, getNumParameters() );
   
    for (int i = 0; i < MatrixUtils.length( params ); i++) {
     
    //  params.put(i,getParameter(i));
      MatrixUtils.setElement( params, i, this.getParameter( i ) );
     
    }
   
    return params;
   
  }

  @Override
  public void setParameters(Matrix params) {
   
    //this.setParameters(params.toArray());
    this.setParameters( ArrayUtils.flatten( MatrixUtils.fromMatrix( params ) ) );
   
  }

  @Override
  public Matrix getValueGradient() {
   
    LogisticRegressionGradient grad = logReg.getGradient( lr );
    Matrix ret = new DenseMatrix( 1, getNumParameters() );
   
    for (int i = 0; i < MatrixUtils.length( ret ); i++) {
     
      if ( i < MatrixUtils.length( logReg.connectionWeights  ) ) {
       
        MatrixUtils.setElement( ret, i, MatrixUtils.getElement( grad.getwGradient(), i ) );
       
      } else {
       
        MatrixUtils.setElement( ret, i, MatrixUtils.getElement( grad.getbGradient(), i - MatrixUtils.length( logReg.connectionWeights ) ) );
       
      }
     
    }
    return ret;
  }
 
 
}
TOP

Related Classes of tv.floe.metronome.deeplearning.neuralnetwork.optimize.LogisticRegressionOptimizer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.