Source Code of com.numb3r3.common.opt.Softmax

package com.numb3r3.common.opt;

import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

public class Softmax extends DifferentiableMatrixFunction {
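    /**
     * Column-wise softmax: each column of M is treated as one input vector eta.
     * No maximum is subtracted before exponentiating, so very large entries
     * can overflow to infinity.
     *
     * @param M input matrix, one vector eta per column
     * @return matrix of the same shape with entries e^{\eta_j} / \sum_i e^{\eta_i},
     *         so every column sums to 1
     */
    @Override
    public DoubleMatrix valueAt(DoubleMatrix M) {
        // MatrixFunctions.exp returns a new matrix, so M itself is not modified.
        DoubleMatrix exp = MatrixFunctions.exp(M);
        // columnSums() yields a row vector holding one normalizer per column.
        DoubleMatrix sums = exp.columnSums();
        // Divide every row by the normalizers in place and return the result.
        return exp.diviRowVector(sums);
    }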

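    /**
     * Element-wise derivative s * (1 - s), where s = valueAt(X). This has the same
     * form as the sigmoid derivative and corresponds to the diagonal of the softmax
     * Jacobian; the off-diagonal terms -s_i * s_k are not returned.
     *
     * @param X input matrix, one vector per column
     * @return matrix of the same shape as X holding s * (1 - s)
     */
    @Override
    public DoubleMatrix derivativeAt(DoubleMatrix X) {
        DoubleMatrix M = valueAt(X);
        // M.mul(-1) allocates a copy, so addi(1) builds (1 - M) without touching M.
        return M.mul((M.mul(-1)).addi(1));
    }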
}
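
The two methods above can be illustrated without jblas. The following is a minimal sketch on plain double arrays, not the original author's code: the class name SoftmaxSketch and both method names are hypothetical. It assumes the same column-wise convention as valueAt, subtracts each column's maximum before exponentiating (a common way to avoid overflow that leaves the result mathematically unchanged), and builds the full Jacobian for a single vector so that its diagonal can be compared with the s * (1 - s) value returned by derivativeAt.

```java
// Hypothetical illustration class; not part of com.numb3r3.common.opt.
final class SoftmaxSketch {

    /** Column-wise softmax of a rows x cols array; each column is one input vector. */
    public static double[][] softmaxColumns(double[][] m) {
        int rows = m.length;
        int cols = m[0].length;
        double[][] out = new double[rows][cols];
        for (int j = 0; j < cols; j++) {
            // Find the column maximum so the largest exponent becomes exp(0) = 1.
            double max = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < rows; i++) {
                max = Math.max(max, m[i][j]);
            }
            // Exponentiate the shifted values and accumulate the normalizer.
            double sum = 0.0;
            for (int i = 0; i < rows; i++) {
                out[i][j] = Math.exp(m[i][j] - max);
                sum += out[i][j];
            }
            // Normalize so the column sums to 1.
            for (int i = 0; i < rows; i++) {
                out[i][j] /= sum;
            }
        }
        return out;
    }

    /** Full softmax Jacobian for one output vector s: J[i][k] = s_i * (delta_ik - s_k). */
    public static double[][] jacobian(double[] s) {
        int n = s.length;
        double[][] jac = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int k = 0; k < n; k++) {
                double delta = (i == k) ? 1.0 : 0.0;
                jac[i][k] = s[i] * (delta - s[k]);
            }
        }
        return jac;
    }

    public static void main(String[] args) {
        // One column with large entries; the max shift keeps exp finite.
        double[][] s = softmaxColumns(new double[][] {{1000.0}, {1001.0}, {1002.0}});
        double[] column = {s[0][0], s[1][0], s[2][0]};
        double[][] jac = jacobian(column);
        for (int i = 0; i < column.length; i++) {
            // The diagonal entry matches the element-wise form s * (1 - s).
            System.out.println(column[i] + " " + jac[i][i]);
        }
    }
}
```

The diagonal of the Jacobian equals s_i * (1 - s_i), which is what derivativeAt returns; whether dropping the off-diagonal terms is acceptable depends on how the caller combines this derivative with the loss gradient.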