// Settings handed to the classifier below: impurity measure "gini",
// a maxDepth of 5 and a maxBins of 32.
String impurity = "gini";
Integer maxDepth = 5;
Integer maxBins = 32;

// Fit a DecisionTree classifier on the data.
final DecisionTreeModel model = DecisionTree.trainClassifier(
    data, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);

// Pair the model's prediction for every training instance with that instance's label.
JavaPairRDD<Double, Double> predictionAndLabel = data.mapToPair(
    new PairFunction<LabeledPoint, Double, Double>() {
      @Override
      public Tuple2<Double, Double> call(LabeledPoint point) {
        double predicted = model.predict(point.features());
        double actual = point.label();
        return new Tuple2<Double, Double>(predicted, actual);
      }
    });

// Count the pairs whose prediction differs from the label ...
long misclassifiedCount = predictionAndLabel.filter(
    new Function<Tuple2<Double, Double>, Boolean>() {
      @Override
      public Boolean call(Tuple2<Double, Double> pair) {
        Double predicted = pair._1();
        Double actual = pair._2();
        return !predicted.equals(actual);
      }
    }).count();

// ... and divide by the total number of instances. The leading 1.0 promotes
// the operands to double so the quotient is not truncated by integer division.
Double trainErr = 1.0 * misclassifiedCount / data.count();

System.out.println("Training error: " + trainErr);
System.out.println("Learned classification tree model:\n" + model);
// Switch the impurity measure to "variance" and fit a DecisionTree regressor,
// reusing the maxDepth and maxBins values set earlier.
impurity = "variance";
final DecisionTreeModel regressionModel = DecisionTree.trainRegressor(
    data, categoricalFeaturesInfo, impurity, maxDepth, maxBins);

// Pair the regressor's prediction for every training instance with that
// instance's label; these pairs feed the training-error computation.
JavaPairRDD<Double, Double> regressorPredictionAndLabel = data.mapToPair(
    new PairFunction<LabeledPoint, Double, Double>() {
      @Override
      public Tuple2<Double, Double> call(LabeledPoint point) {
        double predicted = regressionModel.predict(point.features());
        double actual = point.label();
        return new Tuple2<Double, Double>(predicted, actual);
      }
    });
Double trainMSE =
regressorPredictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
@Override public Double call(Tuple2<Double, Double> pl) {