Package org.apache.mahout.classifier.df.data

Examples of org.apache.mahout.classifier.df.data.Data


    Node childNode;
    if (data.getDataset().isNumerical(best.getAttr())) {
      boolean[] temp = null;

      Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit()));
      Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit()));

      if (loSubset.isEmpty() || hiSubset.isEmpty()) {
        // the selected attribute did not change the data, avoid using it in the child nodes
        selected[best.getAttr()] = true;
      } else {
        // the data changed, so we can unselect all previously selected NUMERICAL attributes
        temp = selected;
        selected = cloneCategoricalAttributes(data.getDataset(), selected);
      }

      Node loChild = build(rng, loSubset);
      Node hiChild = build(rng, hiSubset);

      // restore the selection state of the attributes
      if (temp != null) {
        selected = temp;
      } else {
        selected[best.getAttr()] = alreadySelected;
      }

      childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild);
    } else { // CATEGORICAL attribute
      selected[best.getAttr()] = true;

      double[] values = data.values(best.getAttr());
      Node[] children = new Node[values.length];

      for (int index = 0; index < values.length; index++) {
        Data subset = data.subset(Condition.equals(best.getAttr(), values[index]));
        children[index] = build(rng, subset);
      }

      selected[best.getAttr()] = alreadySelected;
View Full Code Here


  private static Data[] generateTrainingDataA() throws DescriptorException {
    // Dataset
    Dataset dataset = DataLoader.generateDataset("C N N C L", false, TRAIN_DATA);
   
    // Training data
    Data data = DataLoader.loadData(dataset, TRAIN_DATA);
    @SuppressWarnings("unchecked")
    List<Instance>[] instances = new List[3];
    for (int i = 0; i < instances.length; i++) {
      instances[i] = Lists.newArrayList();
    }
    for (int i = 0; i < data.size(); i++) {
      if (data.get(i).get(0) == 0.0d) {
        instances[0].add(data.get(i));
      } else {
        instances[1].add(data.get(i));
      }
    }
    Data[] datas = new Data[instances.length];
    for (int i = 0; i < datas.length; i++) {
      datas[i] = new Data(dataset, instances[i]);
    }

    return datas;
  }
View Full Code Here

    // Training data
    Data[] datas = generateTrainingDataA();
    // Build Forest
    DecisionForest forest = buildForest(datas);
    // Test data
    Data testData = DataLoader.loadData(datas[0].getDataset(), TEST_DATA);

    assertEquals(1.0, forest.classify(testData.getDataset(), rng, testData.get(0)), EPSILON);
    // This one is tie-broken -- 1 is OK too
    assertEquals(0.0, forest.classify(testData.getDataset(), rng, testData.get(1)), EPSILON);
    assertEquals(1.0, forest.classify(testData.getDataset(), rng, testData.get(2)), EPSILON);
  }
View Full Code Here

    // Training data
    Data[] datas = generateTrainingDataA();
    // Build Forest
    DecisionForest forest = buildForest(datas);
    // Test data
    Data testData = DataLoader.loadData(datas[0].getDataset(), TEST_DATA);

    double[][] predictions = new double[testData.size()][];
    forest.classify(testData, predictions);
    assertArrayEquals(new double[][]{{1.0, Double.NaN, Double.NaN},
        {1.0, 0.0, Double.NaN}, {1.0, 1.0, Double.NaN}}, predictions);
  }
View Full Code Here

    for (int i = 0; i < data.size(); i++) {
      if (data.get(i).get(0) != 0.0d) {
        instances.add(data.get(i));
      }
    }
    Data lessData = new Data(data.getDataset(), instances);
   
    // build tree
    DecisionTreeBuilder builder = new DecisionTreeBuilder();
    builder.setM(data.getDataset().nbAttributes() - 1);
    builder.setMinSplitNum(0);
View Full Code Here

    assertEquals("\noutlook = sunny\n|   humidity < 85 : yes\n|   humidity >= 85 : no\noutlook = overcast : yes", TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES));
  }
 
  @Test
  public void testEmpty() throws Exception {
    Data emptyData = new Data(data.getDataset());
   
    // build tree
    DecisionTreeBuilder builder = new DecisionTreeBuilder();
    Node tree = builder.build(rng, emptyData);
View Full Code Here

    // all the vectors have the same label (0)
    double[][] temp = Utils.randomDoublesWithSameLabel(rng, descriptor, false, 100, 0);
    String[] sData = Utils.double2String(temp);
    Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
    Data data = DataLoader.loadData(dataset, sData);
    DefaultIgSplit iG = new DefaultIgSplit();

    double expected = 0.0 - 1.0 * Math.log(1.0) / Math.log(2.0);
    assertEquals(expected, iG.entropy(data), EPSILON);

View Full Code Here

   *          number of trees to grow
   */
  private void runIteration(Random rng, Data data, int m, int nbtrees) {
   
    log.info("Splitting the data");
    Data train = data.clone();
    Data test = train.rsplit(rng, (int) (data.size() * 0.1));
   
    DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
   
    SequentialBuilder forestBuilder = new SequentialBuilder(rng, treeBuilder, train);
   
    // grow a forest with m = log2(M)+1
    treeBuilder.setM(m);
   
    long time = System.currentTimeMillis();
    log.info("Growing a forest with m={}", m);
    DecisionForest forestM = forestBuilder.build(nbtrees);
    sumTimeM += System.currentTimeMillis() - time;
    numNodesM += forestM.nbNodes();
   
    // grow a forest with m=1
    treeBuilder.setM(1);
   
    time = System.currentTimeMillis();
    log.info("Growing a forest with m=1");
    DecisionForest forestOne = forestBuilder.build(nbtrees);
    sumTimeOne += System.currentTimeMillis() - time;
    numNodesOne += forestOne.nbNodes();
   
    // compute the test set error (Selection Error), and mean tree error (One Tree Error),
    double[] testLabels = test.extractLabels();
    double[][] predictions = new double[test.size()][];
   
    forestM.classify(test, predictions);
    double[] sumPredictions = new double[test.size()];
    Arrays.fill(sumPredictions, 0.0);
    for (int i = 0; i < predictions.length; i++) {
      for (int j = 0; j < predictions[i].length; j++) {
        sumPredictions[i] += predictions[i][j];
      }
View Full Code Here

    }
   
    // load the data
    FileSystem fs = dataPath.getFileSystem(new Configuration());
    Dataset dataset = Dataset.load(getConf(), datasetPath);
    Data data = DataLoader.loadData(dataset, fs, dataPath);
   
    // take m to be the largest integer not greater than log2(M) + 1, where M is the
    // number of inputs
    int m = (int) Math.floor(FastMath.log(2.0, data.getDataset().nbAttributes()) + 1);
   
    Random rng = RandomUtils.getRandom();
    for (int iteration = 0; iteration < nbIterations; iteration++) {
      log.info("Iteration {}", iteration);
      runIteration(rng, data, m, nbTrees);
View Full Code Here

  }
 
  protected static Data loadData(Configuration conf, Path dataPath, Dataset dataset) throws IOException {
    log.info("Loading the data...");
    FileSystem fs = dataPath.getFileSystem(conf);
    Data data = DataLoader.loadData(dataset, fs, dataPath);
    log.info("Data Loaded");
   
    return data;
  }
View Full Code Here

TOP

Related Classes of org.apache.mahout.classifier.df.data.Data

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.