// k-fold cross validation over segmentSet.
// The segments are shuffled and stratified once. Then, for each of this.folds
// folds, a model is trained on that fold's training split and evaluated on its
// test split. Each per-fold EvaluationReport is added to evaluationReport as a
// partial.
System.out.println("Randomizing and stratifying segments.");
CrossValidationPrep<Segment> validationPrep = new CrossValidationPrep<Segment>();
// The return value of randomize(...) is not used, so it presumably reorders the
// list returned by getSegments() in place. TODO confirm against CrossValidationPrep.
validationPrep.randomize(segmentSet.getSegments());
// NOTE(review): this uses `folds` while the rest of the block uses `this.folds`.
// They are the same value unless a local/parameter named `folds` shadows the
// field. Confirm in the enclosing method.
segmentSet.setSegments(validationPrep.stratify(segmentSet.getSegments(), folds));
// Aggregate report for the whole run; the per-fold reports are added via addPartial below.
evaluationReport = new EvaluationReport(this.folds + " Cross Validation", falsePositiveCost, falseNegativeCost);
for (int foldIndex = 0; foldIndex < this.folds; foldIndex++) {
System.out.println("- Starting fold " + (foldIndex + 1));
// Split the data. The same split object supplies both the training portion
// (getTrainingForFold) and the testing portion (getTestingForFold) of this fold.
CrossValidationSplit<Segment> split = new CrossValidationSplit<Segment>();
System.out.println("- Splitting out training set");
SegmentSet trainingSegments = new SegmentSet();
trainingSegments.setSegments(split.getTrainingForFold(segmentSet.getSegments(), foldIndex, this.folds));
// Balancing is optional: it is applied to the training segments whenever a balancing implementation is configured.
if (getBalancingImpl() != null) {
trainingSegments = getBalancingImpl().balance(trainingSegments);
}
System.out.println("- Extracting basic features from training set");
ExampleSet basicTrainingExamples = trainingSegments.getBasicExamples();
// Intermediate data is dropped as soon as it is no longer needed.
// Presumably this lets large per-fold structures be garbage-collected earlier.
trainingSegments = null;
FeatureGeneration generation = getFeatureGenerationImpl();
System.out.println("- Generating features");
// The feature specification is derived from the training examples only.
// It is then reused, unchanged, to extract features from the test set below.
FeatureSpecification spec = generation.generateFeatures(basicTrainingExamples);
FeatureExtraction extraction = getFeatureExtractionImpl();
System.out.println("- Extracting features from training set");
ExampleSet trainingSet = extraction.extractFeatures(basicTrainingExamples, spec);
basicTrainingExamples = null;
Training training = getTrainingImpl();
Model model = training.train(trainingSet);
trainingSet = null;
System.out.println("- Splitting out test set");
SegmentSet testingSegments = new SegmentSet();
testingSegments.setSegments(split.getTestingForFold(segmentSet.getSegments(), foldIndex, this.folds));
// Unlike the training set, the test set is balanced only when balanceTestSet is also true.
if (getBalancingImpl() != null && balanceTestSet) {
testingSegments = getBalancingImpl().balance(testingSegments);
}
System.out.println("- Extracting basic features from test set");
ExampleSet basicTestingExamples = testingSegments.getBasicExamples();
System.out.println("- Extracting features from test set");
// Uses the extraction implementation and spec built from the training data above.
ExampleSet testingSet = extraction.extractFeatures(basicTestingExamples, spec);
basicTestingExamples = null;
Predictions predictions = model.getPredictions(testingSet);
EvaluationReport report = new EvaluationReport("Fold " + (foldIndex + 1), falsePositiveCost, falseNegativeCost);
report.addPredictions(predictions);
// The label mapping is applied to the test segments before they are handed to
// the report. Presumably it attaches predicted labels to the segments.
// TODO confirm against LabelMapping.map.
LabelMapping mapping = getMappingImpl();
mapping.map(predictions, testingSegments);
report.addLabeledTestData(testingSegments);
evaluationReport.addPartial(report);
// Correct = true negatives + true positives, reported against the number of test examples.
// testingSet is kept (not nulled) because its size is needed here.
int numCorrect = report.getTrueNegativeCount() + report.getTruePositiveCount();
System.out.println("- Fold " + (foldIndex + 1) + " completed (" + numCorrect + "/" + testingSet.size() + " correct).");
System.out.println();
}
} else {
System.out.println("== Skipping Cross Validation ==");