}
/**
 * Applies one contrastive-divergence (single Gibbs step) update to the
 * parameters held in {@code curr_param}, using the rows of {@code x_samples}
 * as the visible data of the mini-batch.
 *
 * <p>Steps: sample the hidden layer from the data (h1), reconstruct the
 * visible layer (v2), recompute the hidden probabilities from that
 * reconstruction (h2), then move {@code w}, {@code hbias} and {@code vbias}
 * along the difference between the data-driven and reconstruction-driven
 * statistics, scaled by the learning rate.
 *
 * <p>When {@code config.isUseHintonCD1()} is true the statistics use the
 * hidden probabilities (h1) and the visible probabilities (v2); otherwise they
 * use the sampled states of h1 and v2. The final hidden term always uses
 * probabilities.
 *
 * <p>An empty mini-batch (zero rows) is a no-op: without this guard the
 * averages below would divide by zero and write NaN/Infinity into the
 * parameters.
 *
 * @param config     training configuration; supplies the CD variant flag, the
 *                   regularization switch, the L2 coefficient and the learning rate
 * @param x_samples  mini-batch of visible vectors, one sample per row
 * @param y_samples  not read by this method
 * @param curr_param parameters to update in place; cast to {@link HiddenLayerParam}
 */
@Override
protected void gradientUpdateMiniBatch(SGDTrainConfig config, DoubleMatrix x_samples, DoubleMatrix y_samples, SGDParam curr_param) {
    int nbr_sample = x_samples.getRows();
    if (nbr_sample == 0) {
        // Nothing to learn from; also avoids the 0/0 divisions further down.
        return;
    }

    HiddenLayerParam param = (HiddenLayerParam) curr_param;
    DoubleMatrix curr_w = param.w;
    DoubleMatrix curr_hbias = param.hbias;
    DoubleMatrix curr_vbias = param.vbias;

    // NOTE(review): assumes isUseHintonCD1() is a plain getter, so reading it once is equivalent.
    boolean hintonCD1 = config.isUseHintonCD1();

    // Output buffers for the sampling helpers.
    DoubleMatrix h1_probability = new DoubleMatrix(nbr_sample, n_hidden);
    DoubleMatrix h1_sample = new DoubleMatrix(nbr_sample, n_hidden);
    DoubleMatrix v2_probability = new DoubleMatrix(nbr_sample, n_visible);
    DoubleMatrix h2_probability = new DoubleMatrix(nbr_sample, n_hidden);

    // Up pass from the data.
    sample_h_given_v(x_samples, h1_probability, h1_sample, curr_w, curr_hbias);

    // Down pass. h1_stat / v2_stat are the matrices that enter the update:
    // probabilities in Hinton mode, sampled states otherwise.
    DoubleMatrix h1_stat;
    DoubleMatrix v2_stat;
    if (hintonCD1) {
        sample_v_given_h(h1_sample, v2_probability, null, curr_w, curr_vbias);
        h1_stat = h1_probability;
        v2_stat = v2_probability;
    } else {
        // The sampled reconstruction is only needed in this mode.
        DoubleMatrix v2_sample = new DoubleMatrix(nbr_sample, n_visible);
        sample_v_given_h(h1_sample, v2_probability, v2_sample, curr_w, curr_vbias);
        h1_stat = h1_sample;
        v2_stat = v2_sample;
    }

    // Second up pass from the reconstruction; only probabilities are requested.
    sample_h_given_v(v2_stat, h2_probability, null, curr_w, curr_hbias);

    // Data-driven term minus reconstruction-driven term (n_hidden x n_visible for w).
    DoubleMatrix delta_w = h1_stat.transpose().mmul(x_samples)
            .subi(h2_probability.transpose().mmul(v2_stat));
    DoubleMatrix delta_hbias = h1_stat.sub(h2_probability).columnSums().divi(nbr_sample);
    DoubleMatrix delta_vbias = x_samples.sub(v2_stat).columnSums().divi(nbr_sample);

    // only L2 for RBM
    // The penalty is subtracted before the batch averaging below, so it is
    // also scaled by 1 / nbr_sample.
    if (config.isUseRegularization() && 0 != config.getLamada2()) {
        delta_w.subi(curr_w.mul(config.getLamada2()));
    }
    delta_w.divi(nbr_sample);

    // Apply the step in place. columnSums() yields row vectors, so the bias
    // deltas are transposed before being added to the bias vectors.
    double learningRate = config.getLearningRate();
    curr_w.addi(delta_w.muli(learningRate));
    curr_hbias.addi(delta_hbias.transpose().muli(learningRate));
    curr_vbias.addi(delta_vbias.transpose().muli(learningRate));
}