Source Code of org.gd.spark.opendl.downpourSGD.Softmax.LR$PredictSpark

/*
* Copyright 2013 GuoDing
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.gd.spark.opendl.downpourSGD.Softmax;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.io.Writer;
import java.util.List;

import org.apache.log4j.Logger;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.gd.spark.opendl.downpourSGD.SGDTrainConfig;
import org.gd.spark.opendl.downpourSGD.SampleVector;
import org.gd.spark.opendl.downpourSGD.train.SGDBase;
import org.gd.spark.opendl.downpourSGD.train.SGDParam;
import org.gd.spark.opendl.util.MathUtil;
import org.gd.spark.opendl.util.MyConjugateGradient;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

import cc.mallet.optimize.Optimizable;


/**
* Logistic Regression (Softmax) node <p/>
*
* @author GuoDing
* @since 2013-08-01
*/
public final class LR extends SGDBase {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = Logger.getLogger(LR.class);
    private int x_num;
    private int y_num;
    private LRParam lrparam;

    /**
     * Constructor with randomly initialized W & B parameters
     *
     * @param x_feature_num Feature num
     * @param y_class_num Class num
     */
    public LR(int x_feature_num, int y_class_num) {
        this(x_feature_num, y_class_num, null, null);
    }

    /**
     * Constructor with specified initial W & B parameters
     *
     * @param x_feature_num Feature num
     * @param y_class_num Class num
     * @param _w Specify W param matrix
     * @param _b Specify B param vector
     */
    public LR(int x_feature_num, int y_class_num, double[][] _w, double[] _b) {
        x_num = x_feature_num;
        y_num = y_class_num;
        lrparam = new LRParam(x_feature_num, y_class_num, _w, _b);
        param = lrparam;
    }

    /**
     * Predict for multiple samples (standalone)
     * @param x Input samples matrix
     * @return Predict result matrix (rows = x's rows, columns = class num)
     */
    public final DoubleMatrix predict(DoubleMatrix x) {
        DoubleMatrix y = x.mmul(lrparam.w.transpose()).addiRowVector(lrparam.b);
        softmax(y);
        return y;
    }

    /**
     * Predict for a single sample
     * @param x Input sample
     * @param y Output array that receives the predicted class probabilities
     */
    public final void predict(double[] x, double[] y) {
        for (int i = 0; i < y_num; i++) {
            y[i] = 0;
            for (int j = 0; j < x_num; j++) {
                y[i] += lrparam.w.get(i, j) * x[j];
            }
            y[i] += lrparam.b.get(i, 0);
        }
        softmax(y);
    }
   
    /**
     * Predict on Spark (distributed)
     * @param samples Input data RDD
     * @param copyX Whether to copy the x data from the input sample into the output SampleVector
     * @return Predict result data RDD
     */
    public final JavaRDD<SampleVector> predict(JavaRDD<SampleVector> samples, boolean copyX) {
        return samples.map(new PredictSpark(copyX));
    }

    private void softmax(DoubleMatrix y) {
        DoubleMatrix max = y.rowMaxs();
        MatrixFunctions.expi(y.subiColumnVector(max));
        DoubleMatrix sum = y.rowSums();
        y.diviColumnVector(sum);
    }

    private void softmax(double[] y) {
        double max = y[0];
        double sum = 0.0;
        for (int i = 1; i < y_num; i++) {
            if (max < y[i]) {
                max = y[i];
            }
        }
        for (int i = 0; i < y_num; i++) {
            y[i] = Math.exp(y[i] - max);
            sum += y[i];
        }
        for (int i = 0; i < y_num; i++) {
            y[i] /= sum;
        }
    }
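
    // Note (added comment): both softmax() helpers above use the numerically stable
    // formulation softmax(y)_i = exp(y_i - m) / sum_j exp(y_j - m), where m = max_k y_k,
    // so that the exponentials cannot overflow for large activations.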

    /**
     *
     * @return W param matrix
     */
    public DoubleMatrix getW() {
        return lrparam.w;
    }

    /**
     *
     * @return B param vector
     */
    public DoubleMatrix getB() {
        return lrparam.b;
    }

    /**
     *
     * @return Feature num
     */
    public int getX() {
        return x_num;
    }

    /**
     *
     * @return Class num
     */
    public int getY() {
        return y_num;
    }

    @Override
    public final void read(DataInput in) throws IOException {
        x_num = in.readInt();
        y_num = in.readInt();
        for (int i = 0; i < y_num; i++) {
            for (int j = 0; j < x_num; j++) {
                lrparam.w.put(i, j, in.readDouble());
            }
        }
        for (int i = 0; i < y_num; i++) {
            lrparam.b.put(i, 0, in.readDouble());
        }
    }

    @Override
    public final void write(DataOutput out) throws IOException {
        out.writeInt(x_num);
        out.writeInt(y_num);
        for (int i = 0; i < y_num; i++) {
            for (int j = 0; j < x_num; j++) {
                out.writeDouble(lrparam.w.get(i, j));
            }
        }
        for (int i = 0; i < y_num; i++) {
            out.writeDouble(lrparam.b.get(i, 0));
        }
    }
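
    // Note (added comment): write() emits x_num and y_num, then the W matrix row by row,
    // then the B vector, which matches the order read() consumes above.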

    @Override
    public final void print(Writer wr) throws IOException {
        String newLine = System.getProperty("line.separator");
        wr.write(String.valueOf(x_num));
        wr.write(",");
        wr.write(String.valueOf(y_num));
        wr.write(newLine);
        for (int i = 0; i < y_num; i++) {
            for (int j = 0; j < x_num; j++) {
                wr.write(String.valueOf(lrparam.w.get(i, j)));
                wr.write(",");
            }
            wr.write(newLine);
        }
        for (int i = 0; i < y_num; i++) {
            wr.write(String.valueOf(lrparam.b.get(i, 0)));
            wr.write(",");
        }
        wr.write(newLine);
    }

    private class LROptimizer implements Optimizable.ByGradientValue {
        private DoubleMatrix my_w;
        private DoubleMatrix my_b;
        private DoubleMatrix my_x_samples;
        private DoubleMatrix my_y_samples;
        private DoubleMatrix curr_predict_y;
        private int nbr_samples;
        private SGDTrainConfig my_config;

        public LROptimizer(SGDTrainConfig config, DoubleMatrix x_samples, DoubleMatrix y_samples, DoubleMatrix curr_w,
                DoubleMatrix curr_b) {
            my_x_samples = x_samples;
            my_y_samples = y_samples;
            my_w = curr_w;
            my_b = curr_b;
            nbr_samples = x_samples.getRows();
            my_config = config;
        }

        @Override
        public int getNumParameters() {
            return y_num * x_num + y_num;
        }
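
        // Note (added comment): parameters are flattened as the y_num * x_num entries of W
        // in row-major order, followed by the y_num entries of B; the index arithmetic in
        // the getters and setters below assumes this layout.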

        @Override
        public double getParameter(int arg) {
            if (arg < y_num * x_num) {
                int i = arg / x_num;
                int j = arg % x_num;
                return my_w.get(i, j);
            }
            return my_b.get(arg - y_num * x_num, 0);
        }

        @Override
        public void getParameters(double[] arg) {
            int idx = 0;
            for (int i = 0; i < y_num; i++) {
                for (int j = 0; j < x_num; j++) {
                    arg[idx++] = my_w.get(i, j);
                }
            }
            for (int i = 0; i < y_num; i++) {
                arg[idx++] = my_b.get(i, 0);
            }
        }

        @Override
        public void setParameter(int arg0, double arg1) {
            if (arg0 < y_num * x_num) {
                int i = arg0 / x_num;
                int j = arg0 % x_num;
                my_w.put(i, j, arg1);
            } else {
                my_b.put(arg0 - y_num * x_num, 0, arg1);
            }
        }

        @Override
        public void setParameters(double[] arg) {
            int idx = 0;
            for (int i = 0; i < y_num; i++) {
                for (int j = 0; j < x_num; j++) {
                    my_w.put(i, j, arg[idx++]);
                }
            }
            for (int i = 0; i < y_num; i++) {
                my_b.put(i, 0, arg[idx++]);
            }
        }

        @Override
        public double getValue() {
            curr_predict_y = my_x_samples.mmul(my_w.transpose()).addiRowVector(my_b);
            softmax(curr_predict_y);
            double loss = MatrixFunctions.powi(curr_predict_y.sub(my_y_samples), 2).sum() / nbr_samples;
            if (my_config.isUseRegularization()) {
                if (0 != my_config.getLamada1()) {
                    loss += my_config.getLamada1()
                            * (MatrixFunctions.abs(my_w).sum() + MatrixFunctions.abs(my_b).sum()); // L1
                }
                if (0 != my_config.getLamada2()) {
                    loss += 0.5 * my_config.getLamada2()
                            * (MatrixFunctions.pow(my_w, 2).sum() + MatrixFunctions.pow(my_b, 2).sum()); // L2
                }
            }
            return -loss;
        }
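
        // Note (added comment): MALLET's ByGradientValue optimizers maximize the objective,
        // so getValue() returns the (regularized) squared-error loss negated.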

        @Override
        public void getValueGradient(double[] arg) {
            DoubleMatrix delta_b = my_y_samples.sub(curr_predict_y);
            DoubleMatrix delta_w = delta_b.transpose().mmul(my_x_samples);
            delta_b = delta_b.columnSums().divi(nbr_samples);
            delta_w.divi(nbr_samples);

            if (my_config.isUseRegularization()) {
                if (0 != my_config.getLamada1()) {
                    delta_w.addi(MatrixFunctions.signum(my_w).mmuli(my_config.getLamada1()));
                    delta_b.addi(MatrixFunctions.signum(my_b).transpose().mmuli(my_config.getLamada1()));
                }
                if (0 != my_config.getLamada2()) {
                    delta_w.addi(my_w.mmul(my_config.getLamada2()));
                    delta_b.addi(my_b.transpose().mmul(my_config.getLamada2()));
                }
            }

            int idx = 0;
            for (int i = 0; i < y_num; i++) {
                for (int j = 0; j < x_num; j++) {
                    arg[idx++] = delta_w.get(i, j);
                }
            }
            for (int i = 0; i < y_num; i++) {
                arg[idx++] = delta_b.get(0, i);
            }
        }
    }
   
    private class PredictSpark extends Function<SampleVector, SampleVector> {
        private static final long serialVersionUID = 1L;
        private boolean copyX = false;

        public PredictSpark(boolean _copyX) {
            copyX = _copyX;
        }

        @Override
        public SampleVector call(SampleVector arg) throws Exception {
            SampleVector ret = new SampleVector(x_num, y_num);
            if (copyX) {
                for (int i = 0; i < x_num; i++) {
                    ret.getX()[i] = arg.getX()[i];
                }
            }
            predict(arg.getX(), ret.getY());
            return ret;
        }
    }

    @Override
    protected void gradientUpdateMiniBatch(SGDTrainConfig config, DoubleMatrix x_samples, DoubleMatrix y_samples, SGDParam curr_param) {
        int nbr_samples = x_samples.rows;
        DoubleMatrix curr_w = ((LRParam) curr_param).w;
        DoubleMatrix curr_b = ((LRParam) curr_param).b;

        DoubleMatrix curr_predict_y = x_samples.mmul(curr_w.transpose()).addiRowVector(curr_b);
        softmax(curr_predict_y);
        DoubleMatrix delta_b = y_samples.sub(curr_predict_y);
        DoubleMatrix delta_w = delta_b.transpose().mmul(x_samples);
        delta_b = delta_b.columnSums().divi(nbr_samples);
        delta_w.divi(nbr_samples);

        if (config.isUseRegularization()) {
            if (0 != config.getLamada1()) {
                delta_w.addi(MatrixFunctions.signum(curr_w).mmuli(config.getLamada1()));
                delta_b.addi(MatrixFunctions.signum(curr_b).transpose().mmuli(config.getLamada1()));
            }
            if (0 != config.getLamada2()) {
                delta_w.addi(curr_w.mmul(config.getLamada2()));
                delta_b.addi(curr_b.transpose().mmul(config.getLamada2()));
            }
        }

        curr_w.addi(delta_w.muli(config.getLearningRate()));
        curr_b.addi(delta_b.transpose().muli(config.getLearningRate()));
    }
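
    // Note (added comment): the mini-batch step above applies the usual softmax-regression
    // update direction, delta_W = (Y - Yhat)^T X / n and delta_b = colSums(Y - Yhat) / n,
    // via addi() (ascent on the log-likelihood), together with the L1/L2 terms when
    // regularization is enabled.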

    @Override
    protected void gradientUpdateCG(SGDTrainConfig config, DoubleMatrix x_samples, DoubleMatrix y_samples, SGDParam curr_param) {
        DoubleMatrix curr_w = ((LRParam) curr_param).w;
        DoubleMatrix curr_b = ((LRParam) curr_param).b;

        LROptimizer lropt = new LROptimizer(config, x_samples, y_samples, curr_w, curr_b);
        MyConjugateGradient cg = new MyConjugateGradient(lropt, config.getCgInitStepSize());
        cg.setTolerance(config.getCgTolerance());
        try {
            cg.optimize(config.getCgMaxIterations());
        } catch (Throwable e) {
            logger.error("", e);
        }
    }

    @Override
    protected void mergeParam(SGDParam new_param, int nrModelReplica) {
        LRParam new_lrparam = (LRParam) new_param;
        lrparam.w.addi(new_lrparam.w.sub(lrparam.w).divi(nrModelReplica));
        lrparam.b.addi(new_lrparam.b.sub(lrparam.b).divi(nrModelReplica));
    }
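
    // Note (added comment): mergeParam() moves the master parameters a 1/nrModelReplica
    // fraction toward each replica, i.e. w <- w + (w_replica - w) / nrModelReplica,
    // folding the replica updates back in downpour-SGD style.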

    @Override
    protected double loss(List<SampleVector> samples) {
        DoubleMatrix x_samples = MathUtil.convertX2Matrix(samples);
        DoubleMatrix y_samples = MathUtil.convertY2Matrix(samples);
        DoubleMatrix predict_y = predict(x_samples);
        return MatrixFunctions.powi(predict_y.sub(y_samples), 2).sum();
    }

    @Override
    protected boolean isSupervise() {
        return true;
    }
}
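
Usage sketch (not part of the original listing): the hypothetical snippet below shows one way the LR class above could be exercised locally, using only the constructor and the predict(double[], double[]) method shown in this listing. The feature count, class count, and input values are made-up placeholders, and the opendl jar (plus jblas) is assumed to be on the classpath.

import org.gd.spark.opendl.downpourSGD.Softmax.LR;

public class LRPredictExample {
    public static void main(String[] args) {
        int featureNum = 4;                      // hypothetical feature count
        int classNum = 3;                        // hypothetical class count
        LR lr = new LR(featureNum, classNum);    // random initial W & B, as documented above

        double[] x = { 0.5, 1.2, -0.3, 0.8 };    // one made-up input sample
        double[] y = new double[classNum];       // receives the softmax class probabilities
        lr.predict(x, y);                        // fills y with softmax(W * x + b)

        for (int k = 0; k < classNum; k++) {
            System.out.printf("P(class=%d) = %.4f%n", k, y[k]);
        }

        // Distributed variant (assumes an existing JavaRDD<SampleVector> named samplesRdd):
        // JavaRDD<SampleVector> predicted = lr.predict(samplesRdd, false);
    }
}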