m = (int) Math.ceil(Math.sqrt(e));
}
}
if (data.isEmpty()) {
return new Leaf(Double.NaN);
}
double sum = 0.0;
if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
// regression
// the sum and the sum of squares of the labels are computed
double sumSquared = 0.0;
for (int i = 0; i < data.size(); i++) {
double label = data.getDataset().getLabel(data.get(i));
sum += label;
sumSquared += label * label;
}
// computes the sum of squared deviations from the mean (the variance times data.size())
double var = sumSquared - (sum * sum) / data.size();
// computes the minimum variance
if (Double.compare(minVariance, Double.NaN) == 0) {
minVariance = var / data.size() * minVarianceProportion;
log.debug("minVariance:{}", minVariance);
}
// variance is compared with minimum variance
if ((var / data.size()) < minVariance) {
log.debug("variance({}) < minVariance({}) Leaf({})", var / data.size(), minVariance, sum / data.size());
return new Leaf(sum / data.size());
}
} else {
// classification
if (isIdentical(data)) {
return new Leaf(data.majorityLabel(rng));
}
if (data.identicalLabel()) {
return new Leaf(data.getDataset().getLabel(data.get(0)));
}
}
// store full set data
if (fullSet == null) {
fullSet = data;
}
int[] attributes = randomAttributes(rng, selected, m);
if (attributes == null || attributes.length == 0) {
// we tried all the attributes and could not split the data anymore
double label;
if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
// regression
label = sum / data.size();
} else {
// classification
label = data.majorityLabel(rng);
}
log.warn("attribute which can be selected is not found Leaf({})", label);
return new Leaf(label);
}
if (igSplit == null) {
if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
// regression
igSplit = new RegressionSplit();
} else {
// classification
igSplit = new OptIgSplit();
}
}
// find the best split
Split best = null;
for (int attr : attributes) {
Split split = igSplit.computeSplit(data, attr);
if (best == null || best.getIg() < split.getIg()) {
best = split;
}
}
// information gain is near to zero.
if (best.getIg() < EPSILON) {
double label;
if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
label = sum / data.size();
} else {
label = data.majorityLabel(rng);
}
log.debug("ig is near to zero Leaf({})", label);
return new Leaf(label);
}
log.debug("best split attr:{}, split:{}, ig:{}", best.getAttr(), best.getSplit(), best.getIg());
boolean alreadySelected = selected[best.getAttr()];
if (alreadySelected) {
// attribute already selected
log.warn("attribute {} already selected in a parent node", best.getAttr());
}
Node childNode;
if (data.getDataset().isNumerical(best.getAttr())) {
boolean[] temp = null;
Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit()));
Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit()));
if (loSubset.isEmpty() || hiSubset.isEmpty()) {
// the selected attribute did not change the data, avoid using it in the child nodes
selected[best.getAttr()] = true;
} else {
// the data changed, so we can unselect all previously selected NUMERICAL attributes
temp = selected;
selected = cloneCategoricalAttributes(data.getDataset(), selected);
}
// size of a subset is less than minSplitNum
if (loSubset.size() < minSplitNum || hiSubset.size() < minSplitNum) {
// branch is not split
double label;
if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
label = sum / data.size();
} else {
label = data.majorityLabel(rng);
}
log.debug("branch is not split Leaf({})", label);
return new Leaf(label);
}
Node loChild = build(rng, loSubset);
Node hiChild = build(rng, hiSubset);
// restore the selection state of the attributes
if (temp != null) {
selected = temp;
} else {
selected[best.getAttr()] = alreadySelected;
}
childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild);
} else { // CATEGORICAL attribute
double[] values = data.values(best.getAttr());
// tree is complemented
Collection<Double> subsetValues = null;
if (complemented) {
subsetValues = Sets.newHashSet();
for (double value : values) {
subsetValues.add(value);
}
values = fullSet.values(best.getAttr());
}
int cnt = 0;
Data[] subsets = new Data[values.length];
for (int index = 0; index < values.length; index++) {
if (complemented && !subsetValues.contains(values[index])) {
continue;
}
subsets[index] = data.subset(Condition.equals(best.getAttr(), values[index]));
if (subsets[index].size() >= minSplitNum) {
cnt++;
}
}
// fewer than two subsets have at least minSplitNum elements
if (cnt < 2) {
// branch is not split
double label;
if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
label = sum / data.size();
} else {
label = data.majorityLabel(rng);
}
log.debug("branch is not split Leaf({})", label);
return new Leaf(label);
}
selected[best.getAttr()] = true;
Node[] children = new Node[values.length];
for (int index = 0; index < values.length; index++) {
if (complemented && (subsetValues == null || !subsetValues.contains(values[index]))) {
// tree is complemented
double label;
if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
label = sum / data.size();
} else {
label = data.majorityLabel(rng);
}
log.debug("complemented Leaf({})", label);
children[index] = new Leaf(label);
continue;
}
children[index] = build(rng, subsets[index]);
}