// Accumulator for the first half of the dropout-prior gradient; it is filled further down
// through increScoreAllowNull(...).
// NOTE(review): sparseE is defined outside this view -- presumably it creates the per-feature
// entries for activeFeatures; confirm.
Map<Integer, double[]> dropoutPriorGradFirstHalf = sparseE(activeFeatures);
if (TIMED)
System.err.println("activeFeatures size: "+activeFeatures.length + ", dataLen: " + docData.length);
// A single timer is reused for each coarse phase below; it is only started and read under TIMED.
Timing timer = new Timing();
if (TIMED)
timer.start();
// Running value of the dropout prior term for this document (added to in the position loop).
double priorValue = 0;
// Scratch slot for the duration returned by timer.stop(); only written under TIMED.
long elapsedMs = 0;
// Pair of conditional label tables derived from the clique tree: first = prev-given-curr,
// second = next-given-curr (see the index conventions documented below).
Pair<double[][][], double[][][]> condProbs = getCondProbs(cliqueTree, docData);
if (TIMED) {
elapsedMs = timer.stop();
System.err.println("\t cond prob took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
}
// prevGivenCurr[pos][currClass][prevClass]:
// first index is the current position, second is the class at the current position,
// third is the class at the previous position.
// e.g. [1][2][3] means curr is at position 1 with class 2, prev is at position 0 with class 3
double[][][] prevGivenCurr = condProbs.first();
// nextGivenCurr[pos][currClass][nextClass]:
// first index is the current position, second is the class at the current position,
// third is the class at the next position.
// e.g. [0][2][3] means curr is at position 0 with class 2, next is at position 1 with class 3
double[][][] nextGivenCurr = condProbs.second();
// Forward (FAlpha) and backward (FBeta) tables, indexed as [i][fIndexPos][y][kk]:
//   first dim is doc length (i)
//   second dim is the position of the feature inside activeFeatures (fIndexPos)
//   third dim is numClasses (y)
//   fourth dim is labelIndexSize (matching the clique type of the feature, for \theta)
// Both stay null when dropoutApprox is set, because only the exact computation reads them.
// Only the two outer dimensions are allocated here; the [y][kk] planes are created lazily
// by the passes that fill the tables.
double[][][][] FAlpha = null;
double[][][][] FBeta = null;
if (!dropoutApprox) {
  FAlpha = new double[docData.length][][][];
  FBeta = new double[docData.length][][][];
  for (int i = 0; i < docData.length; i++) {
    FAlpha[i] = new double[activeFeatures.length][][];
    FBeta[i] = new double[activeFeatures.length][][];
  }
}
if (!dropoutApprox) {
if (TIMED) {
  timer.start();
}
// computing FAlpha (forward pass over positions 1 .. docData.length-1)
//
// For every active feature and every key y of currPrevLabelsMap, FAlpha[i][fIndexPos][y][kk] is
//   sum over yPrime in currPrevLabelsMap.get(y) of
//     prevGivenCurr[i-1][y][yPrime] * (indicator + FAlpha[i-1][fIndexPos][yPrime][kk])
// where indicator is 1 when the feature is present at position i-1 and label kk of the
// feature's clique type matches y (clique type 0) or the pair (yPrime, y) (clique type 1).
//
// fIndex, aa, bb and cc are declared here but are also reused by the FBeta pass that follows,
// so their declarations must stay in this scope.
int fIndex = 0;
double aa, bb, cc = 0;
boolean prevFeaturePresent = false;
for (int i = 1; i < docData.length; i++) {
  // features observed at the previous position
  Set<Integer> docDataHashIMinusOne = docDataHash.get(i-1);
  // for each possible clique at this position
  for (int fIndexPos = 0; fIndexPos < activeFeatures.length; fIndexPos++) {
    fIndex = activeFeatures[fIndexPos];
    prevFeaturePresent = docDataHashIMinusOne.contains(fIndex);
    int j = map[fIndex]; // clique type of the feature; presumably 0 = node, 1 = edge
    Index<CRFLabel> labelIndex = labelIndices.get(j);
    int labelIndexSize = labelIndex.size();
    if (FAlpha[i-1][fIndexPos] == null) {
      // Base case of the recursion: an all-zero plane. Java zero-initialises every row of a
      // fully dimensioned array, so no per-row re-allocation is needed.
      FAlpha[i-1][fIndexPos] = new double[numClasses][labelIndexSize];
    }
    for (Map.Entry<Integer, List<Integer>> entry : currPrevLabelsMap.entrySet()) {
      int y = entry.getKey(); // value at i-1
      double[] sum = new double[labelIndexSize];
      for (int yPrime : entry.getValue()) { // value at i-2
        for (int kk = 0; kk < labelIndexSize; kk++) {
          int[] prevLabel = labelIndex.get(kk).getLabel();
          aa = prevGivenCurr[i-1][y][yPrime];
          bb = (prevFeaturePresent && ((j == 0 && prevLabel[0] == y) || (j == 1 && prevLabel[1] == y && prevLabel[0] == yPrime))) ? 1 : 0;
          // Rows of the previous plane are only filled for keys of currPrevLabelsMap, so a
          // row can be null; a missing row contributes 0.
          cc = 0;
          if (FAlpha[i-1][fIndexPos][yPrime] != null) {
            cc = FAlpha[i-1][fIndexPos][yPrime][kk];
          }
          sum[kk] += aa * (bb + cc);
          if (DEBUG2) {
            // Print the values actually used above instead of re-reading the table, which
            // would throw on a null row.
            System.err.printf("alpha[%d][%d][%d][%d] += % 5.3f * (%d + % 5.3f), prevLabel=%s\n", i, fIndex, y, kk, aa, (int) bb, cc, Arrays.toString(prevLabel));
          }
        }
      }
      if (FAlpha[i][fIndexPos] == null) {
        // rows are attached one key at a time just below
        FAlpha[i][fIndexPos] = new double[numClasses][];
      }
      FAlpha[i][fIndexPos][y] = sum;
      if (DEBUG2) {
        System.err.println("FAlpha["+i+"]["+fIndexPos+"]["+y+"] = " + Arrays.toString(sum));
      }
    }
  }
}
if (TIMED) {
  elapsedMs = timer.stop();
  System.err.println("\t alpha took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
  timer.start();
}
// computing FBeta (backward pass over positions docData.length-2 .. 0)
//
// For every active feature and every key y of currNextLabelsMap, FBeta[i][fIndexPos][y][kk] is
//   sum over yPrime in currNextLabelsMap.get(y) of
//     nextGivenCurr[i][y][yPrime] * (indicator + FBeta[i+1][fIndexPos][yPrime][kk])
// where indicator is 1 when the feature is present at position i+1 and label kk of the
// feature's clique type matches yPrime (clique type 0) or the pair (y, yPrime) (clique type 1).
//
// fIndex, aa, bb and cc are the scratch variables declared before the FAlpha pass.
int docDataLen = docData.length;
for (int i = docDataLen-2; i >= 0; i--) {
  // features observed at the next position
  Set<Integer> docDataHashIPlusOne = docDataHash.get(i+1);
  // for each possible clique at this position
  for (int fIndexPos = 0; fIndexPos < activeFeatures.length; fIndexPos++) {
    fIndex = activeFeatures[fIndexPos];
    boolean nextFeaturePresent = docDataHashIPlusOne.contains(fIndex);
    int j = map[fIndex]; // clique type of the feature; presumably 0 = node, 1 = edge
    Index<CRFLabel> labelIndex = labelIndices.get(j);
    int labelIndexSize = labelIndex.size();
    if (FBeta[i+1][fIndexPos] == null) {
      // Base case of the recursion: an all-zero plane. Java zero-initialises every row of a
      // fully dimensioned array, so no per-row re-allocation is needed.
      FBeta[i+1][fIndexPos] = new double[numClasses][labelIndexSize];
    }
    for (Map.Entry<Integer, List<Integer>> entry : currNextLabelsMap.entrySet()) {
      int y = entry.getKey(); // value at i
      double[] sum = new double[labelIndexSize];
      for (int yPrime : entry.getValue()) { // value at i+1
        for (int kk = 0; kk < labelIndexSize; kk++) {
          int[] nextLabel = labelIndex.get(kk).getLabel();
          aa = nextGivenCurr[i][y][yPrime];
          bb = (nextFeaturePresent && ((j == 0 && nextLabel[0] == yPrime) || (j == 1 && nextLabel[0] == y && nextLabel[1] == yPrime))) ? 1 : 0;
          // Rows of the next plane are only filled for keys of currNextLabelsMap, so a row
          // can be null; a missing row contributes 0.
          cc = 0;
          if (FBeta[i+1][fIndexPos][yPrime] != null) {
            cc = FBeta[i+1][fIndexPos][yPrime][kk];
          }
          sum[kk] += aa * (bb + cc);
          if (DEBUG2) {
            // Print the values actually used above instead of re-reading the table, which
            // would throw on a null row.
            System.err.printf("beta[%d][%d][%d][%d] += % 5.3f * (%d + % 5.3f)\n", i, fIndex, y, kk, aa, (int) bb, cc);
          }
        }
      }
      if (FBeta[i][fIndexPos] == null) {
        // rows are attached one key at a time just below
        FBeta[i][fIndexPos] = new double[numClasses][];
      }
      FBeta[i][fIndexPos][y] = sum;
      if (DEBUG2) {
        System.err.println("FBeta["+i+"]["+fIndexPos+"]["+y+"] = " + Arrays.toString(sum));
      }
    }
  }
}
if (TIMED) {
  elapsedMs = timer.stop();
  System.err.println("\t beta took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
}
}
// Restart the coarse timer for the phase that follows (only under TIMED).
if (TIMED) {
timer.start();
}
// The per-edge-label term accumulated below is VarU * PtYYp * (1 - PtYYp).
// Its derivative, by the product rule and then simplified step by step:
// derivative equals: VarU' * PtYYp * (1-PtYYp) + VarU * PtYYp' * (1-PtYYp) + VarU * PtYYp * (1-PtYYp)'
// derivative equals: VarU' * PtYYp * (1-PtYYp) + VarU * PtYYp' * (1-PtYYp) + VarU * PtYYp * -PtYYp'
// derivative equals: VarU' * PtYYp * (1-PtYYp) + VarU * PtYYp' * (1 - 2 * PtYYp)
// The first summand is the "first half" and the second summand the "second half" in the loop below.
// Common factor delta / (1 - delta), hoisted out of the loop.
// NOTE(review): the denominator is 0 when delta == 1 -- confirm delta is validated elsewhere.
double deltaDivByOneMinusDelta = delta / (1.0-delta);
// Fine-grained timer used inside the position loop; its stop() results are summed into
// eTiming (first-half work) and dropoutTiming (second-half work) when TIMED is set.
Timing innerTimer = new Timing();
long eTiming = 0;
long dropoutTiming= 0;
// Scratch flag reused per active feature: whether the feature occurs at the current position.
boolean containsFeature = false;
// iterate over the positions in this document
// Position 0 is skipped: the loop starts at 1 (presumably because every edge label pairs
// position i with position i-1 -- confirm against cliqueTree.prob).
for (int i = 1; i < docData.length; i++) {
// features observed at position i
Set<Integer> docDataHashI = docDataHash.get(i);
// Per-position expectations, only looked up (and only read) on the approximate path.
Map<Integer, double[]> EForADocPosAtI = null;
if (dropoutApprox)
EForADocPosAtI = EForADocPos.get(i);
// for each possible clique at this position
for (int k = 0; k < edgeLabelIndexSize; k++) { // sum over (y, y')
int[] label = edgeLabels[k];
// The edge label is split into its two components; y indexes FAlpha and yP indexes FBeta below.
int y = label[0];
int yP = label[1];
if (TIMED)
innerTimer.start();
// important to use label as an int[] for calculating cliqueTree.prob()
// if it's a node clique, and label index is 2, if we don't use int[]{2} but just pass 2,
// cliqueTree is going to treat it as index of the edge clique labels, and convert 2
// into int[]{0,2}, and return the edge prob marginal instead of node marginal
double PtYYp = cliqueTree.prob(i, label);
// Factors of the two derivative halves that depend only on PtYYp.
double PtYYpTimesOneMinusPtYYp = PtYYp * (1.0 - PtYYp);
double oneMinus2PtYYp = (1.0 - 2 * PtYYp);
// Sum of weightSquare entries over the features present at position i for this label.
double USum = 0;
int fIndex;
// jjj ranges over the clique types; docData[i][jjj] lists the features of that type at i.
for (int jjj=0; jjj<labelIndices.size(); jjj++) {
for (int n = 0; n < docData[i][jjj].length; n++) {
fIndex = docData[i][jjj][n];
// Column of the weight matrix: the edge-label index k for clique type 1, otherwise yP.
int valIndex;
if (jjj == 1)
valIndex = k;
else
valIndex = yP;
double theta;
// On any failure of the lookup, print the indices and sizes involved, then rethrow
// wrapped in a RuntimeException.
try {
theta = weights[fIndex][valIndex];
}catch (Exception ex) {
System.err.printf("weights[%d][%d], map[%d]=%d, labelIndices.get(map[%d]).size() = %d, weights.length=%d\n", fIndex, valIndex, fIndex, map[fIndex], fIndex, labelIndices.get(map[fIndex]).size(), weights.length);
throw new RuntimeException(ex);
}
USum += weightSquare[fIndex][valIndex];
// first half of derivative: VarU' * PtYYp * (1-PtYYp)
double VarUp = deltaDivByOneMinusDelta * theta;
increScoreAllowNull(dropoutPriorGradFirstHalf, fIndex, valIndex, VarUp * PtYYpTimesOneMinusPtYYp);
}
}
if (TIMED) {
// time spent on the first half goes to eTiming; the timer is restarted for the second half
eTiming += innerTimer.stop();
innerTimer.start();
}
double VarU = 0.5 * deltaDivByOneMinusDelta * USum;
// update function objective
priorValue += VarU * PtYYpTimesOneMinusPtYYp;
double VarUTimesOneMinus2PtYYp = VarU * oneMinus2PtYYp;
// second half of derivative: VarU * PtYYp' * (1 - 2 * PtYYp)
// boolean prevFeaturePresent = false;
// boolean nextFeaturePresent = false;
// Visit every active feature (not only those present at i) and every label of its clique type.
for (int fIndexPos = 0; fIndexPos < activeFeatures.length; fIndexPos++) {
fIndex = activeFeatures[fIndexPos];
containsFeature = docDataHashI.contains(fIndex);
// if (!containsFeature) continue;
int jj = map[fIndex];
Index<CRFLabel> fLabelIndex = labelIndices.get(jj);
for (int kk = 0; kk < fLabelIndex.size(); kk++) { // for all parameter \theta
int[] fLabel = fLabelIndex.get(kk).getLabel();
// if (FAlpha[i] != null)
// System.err.println("fIndex: " + fIndex+", FAlpha[i].size:"+FAlpha[i].length);
// 1 when the feature occurs at i and label kk agrees with the current edge label
// (first component equals yP for clique type 0, same index as k for clique type 1), else 0.
double fCount = containsFeature && ((jj == 0 && fLabel[0] == yP) || (jj == 1 && k == kk)) ? 1 : 0;
double alpha;
double beta;
double condE;
double PtYYpPrime;
if (!dropoutApprox) {
// Exact path: add the forward and backward table entries (0 where a plane or row is missing)
// to the local count, then subtract the document-level expectation.
alpha = ((FAlpha[i][fIndexPos] == null || FAlpha[i][fIndexPos][y] == null) ? 0 : FAlpha[i][fIndexPos][y][kk]);
beta = ((FBeta[i][fIndexPos] == null || FBeta[i][fIndexPos][yP] == null) ? 0 : FBeta[i][fIndexPos][yP][kk]);
condE = fCount + alpha + beta;
if (DEBUG2)
System.err.printf("fLabel=%s, yP = %d, fCount:%f = ((jj == 0 && fLabel[0] == yP)=%b || (jj == 1 && k == kk))=%b\n", Arrays.toString(fLabel),yP, fCount,(jj == 0 && fLabel[0] == yP) , (jj == 1 && k == kk));
PtYYpPrime = PtYYp * (condE - EForADoc.get(fIndex)[kk]);
} else {
// Approximate path: only the local count is used, against the per-position expectation
// (treated as 0 when the feature has no entry at this position).
double E = 0;
if (EForADocPosAtI.containsKey(fIndex))
E = EForADocPosAtI.get(fIndex)[kk];
condE = fCount;
PtYYpPrime = PtYYp * (condE - E);
}
// NOTE(review): this debug line reads alpha and beta, which are assigned only on the exact
// path, and EForADoc even on the approximate path -- check before enabling DEBUG2 together
// with dropoutApprox.
if (DEBUG2)
System.err.printf("for i=%d, k=%d, y=%d, yP=%d, fIndex=%d, kk=%d, PtYYpPrime=% 5.3f, PtYYp=% 3.3f, (condE-E[fIndex][kk])=% 3.3f, condE=% 3.3f, E[fIndex][k]=% 3.3f, alpha=% 3.3f, beta=% 3.3f, fCount=% 3.3f\n", i, k, y, yP, fIndex, kk, PtYYpPrime, PtYYp, (condE - EForADoc.get(fIndex)[kk]), condE, EForADoc.get(fIndex)[kk], alpha, beta, fCount);
increScore(dropoutPriorGrad, fIndex, kk, VarUTimesOneMinus2PtYYp * PtYYpPrime);
}
if (DEBUG2)
System.err.println();
}
if (TIMED)
// time spent on the second half goes to dropoutTiming
dropoutTiming += innerTimer.stop();
}
}
if (CONDENSE) {
// copy for condensedFeaturesMap
for (Map.Entry<Integer, List<Integer>> entry: condensedFeaturesMap.entrySet()) {