Package org.apache.mahout.classifier.df.node

Examples of org.apache.mahout.classifier.df.node.Leaf


        buff.append("|   ");
      }
      buff.append((attrNames == null ? attr : attrNames[attr]) + " >= " + doubleToString(split));
      buff.append(toStringNode(hiChild, dataset, attrNames, fields, layer + 1));
    } else if (node instanceof Leaf) {
      Leaf leaf = (Leaf) node;
      double label = (Double) fields.get("Leaf.label").get(leaf);
      if (dataset.isNumerical(dataset.getLabelId())) {
        buff.append(" : ").append(doubleToString(label));
      } else {
        buff.append(" : ").append(dataset.getLabelString((int) label));
View Full Code Here


          + doubleToString(instance.get(attr)) + ") >= " + doubleToString(split));
        buff.append(" -> ");
        buff.append(toStringPredict(hiChild, instance, dataset, attrNames, fields));
      }
    } else if (node instanceof Leaf) {
      Leaf leaf = (Leaf) node;
      double label = (Double) fields.get("Leaf.label").get(leaf);
      if (dataset.isNumerical(dataset.getLabelId())) {
        buff.append(doubleToString(label));
      } else {
        buff.append(dataset.getLabelString((int) label));
View Full Code Here

        m = (int) Math.ceil(Math.sqrt(e));
      }
    }

    if (data.isEmpty()) {
      return new Leaf(-1);
    }

    double sum = 0.0;
    if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
      // regression
      // sum and sum squared of a label is computed
      double sumSquared = 0.0;
      for (int i = 0; i < data.size(); i++) {
        double label = data.getDataset().getLabel(data.get(i));
        sum += label;
        sumSquared += label * label;
      }

      // computes the variance
      double var = sumSquared - (sum * sum) / data.size();

      // computes the minimum variance
      if (Double.compare(minVariance, Double.NaN) == 0) {
        minVariance = var / data.size() * minVarianceProportion;
        log.debug("minVariance:{}", minVariance);
      }

      // variance is compared with minimum variance
      if ((var / data.size()) < minVariance) {
        log.debug("variance(" + (var / data.size()) + ") < minVariance(" + minVariance + ") Leaf(" +
            (sum / data.size()) + ')');
        return new Leaf(sum / data.size());
      }
    } else {
      // classification
      if (isIdentical(data)) {
        return new Leaf(data.majorityLabel(rng));
      }
      if (data.identicalLabel()) {
        return new Leaf(data.getDataset().getLabel(data.get(0)));
      }
    }

    // store full set data
    if (fullSet == null) {
      fullSet = data;
    }

    int[] attributes = randomAttributes(rng, selected, m);
    if (attributes == null || attributes.length == 0) {
      // we tried all the attributes and could not split the data anymore
      double label;
      if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
        // regression
        label = sum / data.size();
      } else {
        // classification
        label = data.majorityLabel(rng);
      }
      log.warn("attribute which can be selected is not found Leaf({})", label);
      return new Leaf(label);
    }

    if (igSplit == null) {
      if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
        // regression
        igSplit = new RegressionSplit();
      } else {
        // classification
        igSplit = new OptIgSplit();
      }
    }

    // find the best split
    Split best = null;
    for (int attr : attributes) {
      Split split = igSplit.computeSplit(data, attr);
      if (best == null || best.getIg() < split.getIg()) {
        best = split;
      }
    }

    // information gain is near to zero.
    if (best.getIg() < EPSILON) {
      double label;
      if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
        label = sum / data.size();
      } else {
        label = data.majorityLabel(rng);
      }
      log.debug("ig is near to zero Leaf({})", label);
      return new Leaf(label);
    }

    log.debug("best split attr:" + best.getAttr() + ", split:" + best.getSplit() + ", ig:"
        + best.getIg());

    boolean alreadySelected = selected[best.getAttr()];
    if (alreadySelected) {
      // attribute already selected
      log.warn("attribute {} already selected in a parent node", best.getAttr());
    }

    Node childNode;
    if (data.getDataset().isNumerical(best.getAttr())) {
      boolean[] temp = null;

      Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit()));
      Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit()));

      if (loSubset.isEmpty() || hiSubset.isEmpty()) {
        // the selected attribute did not change the data, avoid using it in the child notes
        selected[best.getAttr()] = true;
      } else {
        // the data changed, so we can unselect all previousely selected NUMERICAL attributes
        temp = selected;
        selected = cloneCategoricalAttributes(data.getDataset(), selected);
      }

      // size of the subset is less than the minSpitNum
      if (loSubset.size() < minSplitNum || hiSubset.size() < minSplitNum) {
        // branch is not split
        double label;
        if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
          label = sum / data.size();
        } else {
          label = data.majorityLabel(rng);
        }
        log.debug("branch is not split Leaf({})", label);
        return new Leaf(label);
      }

      Node loChild = build(rng, loSubset);
      Node hiChild = build(rng, hiSubset);

      // restore the selection state of the attributes
      if (temp != null) {
        selected = temp;
      } else {
        selected[best.getAttr()] = alreadySelected;
      }

      childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild);
    } else { // CATEGORICAL attribute
      double[] values = data.values(best.getAttr());

      // tree is complemented
      Collection<Double> subsetValues = null;
      if (complemented) {
        subsetValues = Sets.newHashSet();
        for (double value : values) {
          subsetValues.add(value);
        }
        values = fullSet.values(best.getAttr());
      }

      int cnt = 0;
      Data[] subsets = new Data[values.length];
      for (int index = 0; index < values.length; index++) {
        if (complemented && !subsetValues.contains(values[index])) {
          continue;
        }
        subsets[index] = data.subset(Condition.equals(best.getAttr(), values[index]));
        if (subsets[index].size() >= minSplitNum) {
          cnt++;
        }
      }

      // size of the subset is less than the minSpitNum
      if (cnt < 2) {
        // branch is not split
        double label;
        if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
          label = sum / data.size();
        } else {
          label = data.majorityLabel(rng);
        }
        log.debug("branch is not split Leaf({})", label);
        return new Leaf(label);
      }

      selected[best.getAttr()] = true;

      Node[] children = new Node[values.length];
      for (int index = 0; index < values.length; index++) {
        if (complemented && (subsetValues == null || !subsetValues.contains(values[index]))) {
          // tree is complemented
          double label;
          if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
            label = sum / data.size();
          } else {
            label = data.majorityLabel(rng);
          }
          log.debug("complemented Leaf({})", label);
          children[index] = new Leaf(label);
          continue;
        }
        children[index] = build(rng, subsets[index]);
      }

View Full Code Here

  }
 
  @Test
  public void testForestVisualize() throws Exception {
    // Tree
    NumericalNode root = new NumericalNode(2, 90, new Leaf(0),
      new CategoricalNode(0, new double[] {0, 1, 2}, new Node[] {
        new NumericalNode(1, 71, new Leaf(0), new Leaf(1)), new Leaf(1),
        new Leaf(0)}));
    List<Node> trees = new ArrayList<Node>();
    trees.add(root);
   
    // Forest
    DecisionForest forest = new DecisionForest(trees);
View Full Code Here

      partitions.add(partition);

      int nbTrees = Step1Mapper.nbTrees(NUM_MAPS, NUM_TREES, partition);

      for (int treeId = 0; treeId < nbTrees; treeId++) {
        Node tree = new Leaf(rng.nextInt(100));

        keys[index] = new TreeID(partition, treeId);
        values[index] = new MapredOutput(tree, nextIntArray(rng, NUM_INSTANCES));

        index++;
View Full Code Here

    public Node build(Random rng, Data data) {
      for (int index = 0; index < data.size(); index++) {
        assertTrue(expected.contains(data.get(index)));
      }

      return new Leaf(-1);
    }
View Full Code Here

  }
 
  @Test
  public void testForestVisualize() throws Exception {
    // Tree
    NumericalNode root = new NumericalNode(2, 90, new Leaf(0),
        new CategoricalNode(0, new double[] {0, 1, 2}, new Node[] {
            new NumericalNode(1, 71, new Leaf(0), new Leaf(1)), new Leaf(1),
            new Leaf(0)}));
    List<Node> trees = Lists.newArrayList();
    trees.add(root);

    // Forest
    DecisionForest forest = new DecisionForest(trees);
View Full Code Here

    public Node build(Random rng, Data data) {
      for (int index = 0; index < data.size(); index++) {
        assertTrue(expected.contains(data.get(index)));
      }

      return new Leaf(Double.NaN);
    }
View Full Code Here

TOP

Related Classes of org.apache.mahout.classifier.df.node.Leaf

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., which is owned by Oracle Inc. Contact coftware#gmail.com.