// Load the GE constraint targets for the training data. Keys are feature
// indices; each value is indexed [label][0|1]. NOTE(review): presumably
// [li][0] / [li][1] are the lower / upper target bounds for label li, with an
// infinite value meaning "no target for this label" -- confirm against
// FSTConstraintUtil.loadGEConstraints.
HashMap<Integer,double[][]> constraints = FSTConstraintUtil.loadGEConstraints(constraintsFile,trainingData);
if (learningOption.value.equalsIgnoreCase("ge")) {
// Build the GE constraint objects for the selected penalty. Each entry of
// `constraints` maps a feature index to per-label {lower, upper} targets; an
// infinite lower bound marks a label that has no target (both branches below
// treat it that way).
ArrayList<GEConstraint> constraintsList = new ArrayList<GEConstraint>();
if (penaltyOption.value.equalsIgnoreCase("kl")) {
  // KL divergence needs one point-target distribution per feature: target
  // ranges are rejected and the targets must sum to 1.
  OneLabelKLGEConstraints geConstraints = new OneLabelKLGEConstraints();
  for (int fi : constraints.keySet()) {
    double[][] dist = constraints.get(fi);
    double[] prob = new double[dist.length];
    double sum = 0;
    for (int li = 0; li < dist.length; li++) {
      double lower = dist[li][0];
      double upper = dist[li][1];
      if (Double.isInfinite(lower) && lower == upper) {
        // No target for this label: it contributes zero probability. This is
        // tested before the range check on purpose -- if Maths.almostEquals
        // compares via |a - b| < eps (presumably; confirm), inf - inf is NaN
        // and two equal infinities would be misreported as a target range.
        prob[li] = 0;
      }
      else if (!Maths.almostEquals(lower, upper)) {
        throw new RuntimeException("A KL divergence penalty cannot be used with target ranges!");
      }
      else {
        prob[li] = lower;
      }
      sum += prob[li];
    }
    if (!Maths.almostEquals(sum, 1)) {
      throw new RuntimeException("Targets must sum to 1 when using a KL divergence penalty!");
    }
    // NOTE(review): the trailing 1 is presumably the constraint weight -- confirm.
    geConstraints.addConstraint(fi, prob, 1);
  }
  constraintsList.add(geConstraints);
}
else if (penaltyOption.value.equalsIgnoreCase("l2")) {
  // The L2 penalty accepts target ranges; labels with no target are skipped.
  OneLabelL2RangeGEConstraints geConstraints = new OneLabelL2RangeGEConstraints();
  for (int fi : constraints.keySet()) {
    double[][] dist = constraints.get(fi);
    for (int li = 0; li < dist.length; li++) {
      if (!Double.isInfinite(dist[li][0])) {
        geConstraints.addConstraint(fi, li, dist[li][0], dist[li][1], 1);
      }
    }
  }
  constraintsList.add(geConstraints);
}