/*
 * Evaluates the negative (optionally weighted) log-likelihood of a softmax
 * model over one minibatch of batch.length data points, taken cyclically
 * from data starting at curElement.
 *
 *  - value receives  -sum_m w_m * log p(labels[m] | data[m]).
 *  - derivativeAD[i] receives the matching gradient
 *    sum_m w_m * (p_c - [c == labels[m]]) for parameter i = indexOf(f, c).
 *
 * Each parameter x[i] is seeded with direction v[i] as a DoubleAD, so the
 * dual ("dot") parts of derivativeAD presumably carry the Hess*V product.
 * NOTE(review): confirm against the code after this block that reads them.
 * When dataWeights is null, every datum has weight 1.0.
 */
value = 0.0;

// Parameters seeded with the Hess*V direction, plus a zeroed gradient accumulator.
DoubleAD[] derivativeAD = new DoubleAD[x.length];
DoubleAD[] xAD = new DoubleAD[x.length];
for (int i = 0; i < x.length; i++) {
  derivativeAD[i] = new DoubleAD(0.0, 0.0);
  xAD[i] = new DoubleAD(x[i], v[i]);
}

// Per-class scores and probabilities, reused for every datum in the batch.
DoubleAD[] sums = new DoubleAD[numClasses];
DoubleAD[] probs = new DoubleAD[numClasses];
for (int c = 0; c < numClasses; c++) {
  sums[c] = new DoubleAD(0, 0);
  probs[c] = new DoubleAD(0, 0);
}

for (int d = 0; d < batch.length; d++) {
  // Datum index: walk the data cyclically, starting at curElement.
  int m = (curElement + d) % data.length;
  int[] features = data[m];
  int label = labels[m];
  // Weight of this datum. It is indexed by m, like data and labels, not by
  // the batch position d. It is 1.0 when unweighted, so the unweighted
  // arithmetic below is unchanged.
  double weight = (dataWeights != null) ? dataWeights[m] : 1.0;

  // Unnormalized log-score of each class: sum of that class's feature weights.
  for (int c = 0; c < numClasses; c++) {
    sums[c].set(0.0, 0.0);
    for (int feature : features) {
      sums[c] = ADMath.plus(sums[c], xAD[indexOf(feature, c)]);
    }
  }
  DoubleAD total = ADMath.logSum(sums);

  for (int c = 0; c < numClasses; c++) {
    probs[c] = ADMath.exp(ADMath.minus(sums[c], total));
    if (dataWeights != null) {
      probs[c] = ADMath.multConst(probs[c], weight);
    }
    for (int feature : features) {
      int i = indexOf(feature, c);
      if (c == label) {
        // Empirical term of d(-w log p_y)/dx. It must carry the same weight
        // as the model term and the value, otherwise the weighted gradient
        // does not match the weighted objective.
        derivativeAD[i].plusEqualsConst(-weight);
      }
      derivativeAD[i].plusEquals(probs[c]);
    }
  }

  // Log-probability of the gold label, weighted like the gradient terms.
  double dV = sums[label].getval() - total.getval();
  if (dataWeights != null) {
    dV *= weight;
  }
  value -= dV;
}