Problem trainProb = null;
Problem testProb = null;
Map<Double, Double> cc = null;
Parameter linearParams = params.getParamsCopy();
/*
* We set the weights here, because LibLINEAR crashes when there are weights for labels that are in the total set, but not in the train set.
* However in the case of cross-validation for parameter optimisation we cannot make sure that this does not happen, since the CV is done by the LibLINEAR CV method.
* In the future we should implement our own CV. Moreover, we introduce a slight bias in this way, since the weights are computed over train+test in the CV case.
*
*/
if (params.isDoCrossValidation()) {
target = prob.y;
cc = EvaluationUtils.computeClassCounts(EvaluationUtils.doubles2target(prob.y));
/*
if (params.isDoWeightLabels()) {
linearParams.setWeights(EvaluationUtils.computeWeights(EvaluationUtils.doubles2target(prob.y)), EvaluationUtils.computeWeightLabels(EvaluationUtils.doubles2target(prob.y)));
}
*/
} else {
trainProb = createProblemTrainSplit(prob, params.getSplitFraction());
testProb = createProblemTestSplit(prob, params.getSplitFraction());
target = testProb.y;
cc = EvaluationUtils.computeClassCounts(EvaluationUtils.doubles2target(trainProb.y));
/*
if (params.isDoWeightLabels()) {
linearParams.setWeights(EvaluationUtils.computeWeights(EvaluationUtils.doubles2target(trainProb.y)), EvaluationUtils.computeWeightLabels(EvaluationUtils.doubles2target(trainProb.y)));
}
*/
}
if (params.isDoWeightLabels()) {
List<Double> wl = new ArrayList<Double>();
List<Integer> wll = new ArrayList<Integer>();
for (int i = 0; i < params.getWeightLabels().length; i++) {
if (cc.containsKey((double) params.getWeightLabels()[i])) {
wll.add(params.getWeightLabels()[i]);
wl.add(params.getWeights()[i]);
}
}
linearParams.setWeights(EvaluationUtils.target2Doubles(wl), EvaluationUtils.target2Integers(wll));
}
double score = 0, bestScore = 0, bestC = 0, bestP = 0;
for (double p : params.getPs()) {
linearParams.setP(p);
for (double c : params.getCs()) {
linearParams.setC(c);
if (params.isDoCrossValidation()) {
prediction = crossValidate(prob, linearParams, params.getNumFolds());
} else {
prediction = testLinearModel(new LibLINEARModel(Linear.train(trainProb, linearParams)), testProb.x);
}
score = params.getEvalFunction().computeScore(target, prediction);
if (bestC == 0 || params.getEvalFunction().isBetter(score, bestScore)) {
bestC = c;
bestP = p;
bestScore = score;
}
}
}
linearParams.setC(bestC);
System.out.println("Training model for C: " + bestC + " and P (SVR only): " + bestP);
return new LibLINEARModel(Linear.train(prob, linearParams));
}