}
}).countByValue().size();
boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
// Train a GradientBoosting model for classification.
final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);
// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
@Override public Tuple2<Double, Double> call(LabeledPoint p) {
return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
}
});
Double trainErr =
1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
@Override public Boolean call(Tuple2<Double, Double> pl) {
return !pl._1().equals(pl._2());
}
}).count() / data.count();
System.out.println("Training error: " + trainErr);
System.out.println("Learned classification tree model:\n" + model);
} else if (algo.equals("Regression")) {
// Train a GradientBoosting model for regression.
final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);
// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
@Override public Tuple2<Double, Double> call(LabeledPoint p) {
return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
}
});
Double trainMSE =
predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
@Override public Double call(Tuple2<Double, Double> pl) {