// One tree builder is shared with the sequential forest builder; setM() is
// called on it before each build() to switch the m used for that forest.
DefaultTreeBuilder builder = new DefaultTreeBuilder();
SequentialBuilder sequential = new SequentialBuilder(rng, builder, train);

// NOTE(review): the shared rng is passed to the forest builder and to every
// computePredictions(rng) call below, so the order of these calls affects
// reproducibility - keep it unchanged.

// ---- Forest #1: m = log2(M) + 1 ----
// Out-of-bag (OOB) predictions over the training set for this forest.
ForestPredictions oobPredM = new ForestPredictions(train.size(), nblabels);
builder.setM(m);
long startM = System.currentTimeMillis();
log.info("Growing a forest with m={}", m);
DecisionForest forestM = sequential.build(nbtrees, oobPredM);
sumTimeM += System.currentTimeMillis() - startM;
numNodesM += forestM.nbNodes();
// OOB error estimate of the m = log2(M) + 1 forest.
double oobM = ErrorEstimate.errorRate(trainLabels, oobPredM.computePredictions(rng));

// ---- Forest #2: m = 1 ----
// OOB predictions over the training set for the m = 1 forest.
ForestPredictions oobPredOne = new ForestPredictions(train.size(), nblabels);
builder.setM(1);
long startOne = System.currentTimeMillis();
log.info("Growing a forest with m=1");
DecisionForest forestOne = sequential.build(nbtrees, oobPredOne);
sumTimeOne += System.currentTimeMillis() - startOne;
numNodesOne += forestOne.nbNodes();
// OOB error estimate of the m = 1 forest.
double oobOne = ErrorEstimate.errorRate(trainLabels, oobPredOne.computePredictions(rng));

// ---- Test-set evaluation ----
// The forest with the lower OOB error is "selected":
//  - Selection Error: test error of the selected forest (selectedPreds)
//  - One Tree Error:  mean per-tree test error of the selected forest (perTreeErrors)
//  - Single Input Error: test error of the m = 1 forest, always (singleInputPreds)
ForestPredictions selectedPreds = new ForestPredictions(test.size(), nblabels);
MeanTreeCollector perTreeErrors = new MeanTreeCollector(test, nbtrees);
ForestPredictions singleInputPreds = new ForestPredictions(test.size(), nblabels);
// Strict '<': on a tie (or a NaN comparison) the m = 1 forest is selected.
boolean forestMSelected = oobM < oobOne;
if (forestMSelected) {
  forestM.classify(test, new MultiCallback(selectedPreds, perTreeErrors));
  forestOne.classify(test, singleInputPreds);
} else {
  // The m = 1 forest is both the selected and the single-input forest, so a
  // single classify pass is given all three collectors (MultiCallback
  // presumably forwards each prediction to every collector).
  forestOne.classify(test, new MultiCallback(selectedPreds, perTreeErrors, singleInputPreds));
}
sumTestErr += ErrorEstimate.errorRate(testLabels, selectedPreds.computePredictions(rng));
sumOneErr += ErrorEstimate.errorRate(testLabels, singleInputPreds.computePredictions(rng));
sumTreeErr += perTreeErrors.meanTreeError();
}