}
/**
 * Performs one mini-batch gradient step on the softmax regression parameters.
 *
 * <p>The weight matrix and bias held by {@code curr_param} are updated in place.
 * The step direction is the batch-averaged residual {@code (y - p)}, where
 * {@code p} is the result of {@code softmax(x * w^T + b)}; when regularization
 * is enabled the L1 and/or L2 penalty gradients are subtracted from that
 * direction before it is scaled by the learning rate.
 *
 * <p>An empty batch (zero rows in {@code x_samples}) leaves the parameters untouched.
 *
 * @param config     training configuration; supplies the learning rate, the
 *                   regularization switch and the two penalty coefficients
 *                   ({@code getLamada1()} for L1, {@code getLamada2()} for L2)
 * @param x_samples  input samples, one per row
 * @param y_samples  target values; must have the same shape as the prediction
 *                   matrix (presumably one-hot rows -- confirm against the caller)
 * @param curr_param parameters to update; must be an {@code LRParam}
 */
@Override
protected void gradientUpdateMiniBatch(SGDTrainConfig config, DoubleMatrix x_samples, DoubleMatrix y_samples, SGDParam curr_param) {
    int nbr_samples = x_samples.rows;
    if (nbr_samples == 0) {
        // Averaging over zero rows would divide by zero and push NaN into the parameters.
        return;
    }

    DoubleMatrix curr_w = ((LRParam) curr_param).w;
    DoubleMatrix curr_b = ((LRParam) curr_param).b;

    // Forward pass: linear scores plus bias, then softmax.
    // NOTE(review): softmax's return value is ignored, so it is assumed to normalize its argument in place.
    DoubleMatrix curr_predict_y = x_samples.mmul(curr_w.transpose()).addiRowVector(curr_b);
    softmax(curr_predict_y);

    // Residual (y - p); the same matrix drives both the weight and the bias direction.
    DoubleMatrix delta_b = y_samples.sub(curr_predict_y);
    DoubleMatrix delta_w = delta_b.transpose().mmul(x_samples);

    // Average over the batch. columnSums() collapses the residual to a single row.
    delta_b = delta_b.columnSums().divi(nbr_samples);
    delta_w.divi(nbr_samples);

    if (config.isUseRegularization()) {
        // The update below is "param += learningRate * delta", i.e. a step along (y - p).
        // Penalty gradients therefore have to be SUBTRACTED so that they pull the
        // parameters toward zero; adding them would make the parameters grow instead.
        if (0 != config.getLamada1()) {
            // L1 penalty: lambda1 * sign(param). signum() is expected to return a copy,
            // so the in-place scaling does not touch the parameters themselves.
            delta_w.subi(MatrixFunctions.signum(curr_w).mmuli(config.getLamada1()));
            delta_b.subi(MatrixFunctions.signum(curr_b).transpose().mmuli(config.getLamada1()));
        }
        if (0 != config.getLamada2()) {
            // L2 penalty: lambda2 * param.
            delta_w.subi(curr_w.mmul(config.getLamada2()));
            delta_b.subi(curr_b.transpose().mmul(config.getLamada2()));
        }
    }

    // Apply the step in place. delta_b is a row here, so it is transposed back
    // to the orientation of curr_b before being added.
    curr_w.addi(delta_w.muli(config.getLearningRate()));
    curr_b.addi(delta_b.transpose().muli(config.getLearningRate()));
}