/**
 * Evaluates the predicted paraphrase derivations of {@code ex} against its target value,
 * logs a summary of the candidates, and stores the resulting statistics on the example.
 *
 * <p>Assumes that the paraphrase derivations in {@code ex.predParaDeriv} are already scored
 * and sorted; the derivation at index 0 is treated as the model's prediction.
 *
 * <p>Each derivation has a compatibility with the target (in [0,1]) and a model probability:
 * <ul>
 *   <li>true (correct): compatibility = 1</li>
 *   <li>partial: 0 &lt; compatibility &lt; 1</li>
 *   <li>wrong: compatibility = 0</li>
 * </ul>
 * When {@code ex.targetValue} is null no compatibilities are computed and the correctness
 * statistics are reported as zero / false.
 *
 * @param ex example whose derivations are evaluated; receives the evaluation via
 *           {@code ex.setEvaluation}
 * @param params forwarded to the feature-weight logging
 */
private void setEvaluation(ParsingExample ex, Params params) {
  final Evaluation eval = new Evaluation();
  int numCandidates = ex.predParaDeriv.size();
  LogInfo.begin_track_printAll("Parser.setEvaluation: %d candidates", numCandidates);

  List<ParaphraseDerivation> predictions = ex.predParaDeriv;

  // Every derivation must have been executed before its value can be compared to the target.
  for (ParaphraseDerivation candidate : predictions) {
    candidate.ensureExecuted(executor);
  }

  // Compatibility of each candidate with the target (only when a target is available).
  final boolean hasTarget = ex.targetValue != null;
  int firstCorrect = -1; // index of the first fully compatible derivation, -1 if none
  double bestCompatibility = 0.0; // highest compatibility seen, used as the partial oracle
  double[] compatibilities = null;
  if (hasTarget) {
    compatibilities = new double[numCandidates];
    for (int i = 0; i < numCandidates; i++) {
      ParaphraseDerivation candidate = predictions.get(i);
      candidate.compatibility = ex.targetValue.getCompatibility(candidate.value);
      compatibilities[i] = candidate.compatibility;
      // Only full compatibility counts as correct.
      if (firstCorrect == -1 && compatibilities[i] == 1) {
        firstCorrect = i;
      }
      bestCompatibility = Math.max(compatibilities[i], bestCompatibility);
    }
  }

  // Model probabilities, stored back on the derivations.
  double[] probs = ParaphraseDerivation.getProbs(predictions, 1);
  for (int i = 0; i < numCandidates; i++) {
    predictions.get(i).prob = probs[i];
  }
  // List<Pair<Value, DoubleContainer>> valueList = computeValueList(ex.predParaDeriv);
  // evaluateValues(eval, ex, valueList);

  // numTop is the length of the leading run of derivations whose score ties with the first
  // one (within 1e-10). Note that the run also ends at the first derivation whose
  // compatibility is not positive, so every member of the run is at least partially correct.
  int numTop = 0;
  double topMass = 0;
  double correct = 0;
  double partialCorrect = 0;
  if (hasTarget) {
    for (; numTop < numCandidates; numTop++) {
      if (!(compatibilities[numTop] > 0.0d)) {
        break;
      }
      if (!(Math.abs(predictions.get(numTop).score - predictions.get(0).score) < 1e-10)) {
        break;
      }
      topMass += probs[numTop];
    }
    // Credit is the probability share (within the run) of correct / partially correct members.
    for (int i = 0; i < numTop; i++) {
      if (compatibilities[i] == 1) {
        correct += probs[i] / topMass;
      }
      if (compatibilities[i] > 0) {
        partialCorrect += (compatibilities[i] * probs[i]) / topMass;
      }
    }
  }

  // Feature difference between the first correct derivation and the prediction. This is only
  // with respect to the first correct derivation and is NOT the gradient. Nothing is printed
  // when there is only partial compatibility.
  if (firstCorrect != -1 && correct != 1) {
    HashMap<String, Double> featureDiff = new HashMap<>();
    // TODO if features will go out of proof this needs to change
    predictions.get(firstCorrect).incrementAllFeatureVector(+1, featureDiff);
    predictions.get(0).incrementAllFeatureVector(-1, featureDiff);
    FeatureVector.logFeatureWeights(
        String.format("TopTrue (%d) - Pred (%d) = Diff", firstCorrect, 0), featureDiff, params);
  }

  if (compatibilities != null) {
    // Fully correct derivations.
    for (int i = 0; i < predictions.size(); i++) {
      if (compatibilities[i] != 1) {
        continue;
      }
      ParaphraseDerivation candidate = predictions.get(i);
      LogInfo.logs(
          "True@%04d: %s [score=%s, prob=%s%s]", i, candidate.toString(),
          Fmt.D(candidate.score), Fmt.D(probs[i]), ", comp=" + Fmt.D(compatibilities[i]));
    }
    // Partially correct derivations.
    for (int i = 0; i < predictions.size(); i++) {
      boolean partial = compatibilities[i] > 0 && compatibilities[i] < 1;
      if (!partial) {
        continue;
      }
      ParaphraseDerivation candidate = predictions.get(i);
      LogInfo.logs(
          "Part@%04d: %s [score=%s, prob=%s%s]", i, candidate.toString(),
          Fmt.D(candidate.score), Fmt.D(probs[i]), ", comp=" + Fmt.D(compatibilities[i]));
    }
  }

  // Predictions: always the first ten, plus any whose probability is at least half of the
  // first derivation's probability.
  for (int i = 0; i < predictions.size(); i++) {
    boolean shown = i < 10 || probs[i] >= probs[0] / 2;
    if (!shown) {
      continue;
    }
    ParaphraseDerivation candidate = predictions.get(i);
    String compSuffix;
    if (compatibilities != null) {
      compSuffix = ", comp=" + Fmt.D(compatibilities[i]);
    } else {
      compSuffix = "";
    }
    LogInfo.logs(
        "Pred@%04d: %s [score=%s, prob=%s%s]", i, candidate.toString(),
        Fmt.D(candidate.score), Fmt.D(probs[i]), compSuffix);
  }

  // Record the statistics.
  boolean oracle = firstCorrect != -1;
  eval.add("correct", correct);
  eval.add("oracle", oracle);
  eval.add("partCorrect", partialCorrect);
  eval.add("partOracle", bestCompatibility);
  eval.add("numCandidates", numCandidates); // From this parse
  if (!predictions.isEmpty()) {
    eval.add("parsedNumCandidates", predictions.size());
  }
  for (ParaphraseDerivation candidate : predictions) {
    if (candidate.executorStats != null) {
      eval.add(candidate.executorStats);
    }
  }

  // Finally, attach the collected statistics to the example.
  ex.setEvaluation(eval);
  LogInfo.end_track();
}