Package org.apache.mahout.classifier.df

Source Code of org.apache.mahout.classifier.df.DecisionForestTest

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.classifier.df;

import java.util.List;
import java.util.Random;

import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataLoader;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.DescriptorException;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.classifier.df.node.Node;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.junit.Test;

import com.google.common.collect.Lists;

public final class DecisionForestTest extends MahoutTestCase {

  private static final String[] TRAIN_DATA = {"sunny,85,85,FALSE,no",
    "sunny,80,90,TRUE,no", "overcast,83,86,FALSE,yes",
    "rainy,70,96,FALSE,yes", "rainy,68,80,FALSE,yes", "rainy,65,70,TRUE,no",
    "overcast,64,65,TRUE,yes", "sunny,72,95,FALSE,no",
    "sunny,69,70,FALSE,yes", "rainy,75,80,FALSE,yes", "sunny,75,70,TRUE,yes",
    "overcast,72,90,TRUE,yes", "overcast,81,75,FALSE,yes",
    "rainy,71,91,TRUE,no"};
 
  private static final String[] TEST_DATA = {"rainy,70,96,TRUE,-",
    "overcast,64,65,TRUE,-", "sunny,75,90,TRUE,-",};

  private Random rng;

  @Override
  public void setUp() throws Exception {
    super.setUp();
    rng = RandomUtils.getRandom();
  }

  private static Data[] generateTrainingDataA() throws DescriptorException {
    // Dataset
    Dataset dataset = DataLoader.generateDataset("C N N C L", false, TRAIN_DATA);
   
    // Training data
    Data data = DataLoader.loadData(dataset, TRAIN_DATA);
    @SuppressWarnings("unchecked")
    List<Instance>[] instances = new List[3];
    for (int i = 0; i < instances.length; i++) {
      instances[i] = Lists.newArrayList();
    }
    for (int i = 0; i < data.size(); i++) {
      if (data.get(i).get(0) == 0.0d) {
        instances[0].add(data.get(i));
      } else {
        instances[1].add(data.get(i));
      }
    }
    Data[] datas = new Data[instances.length];
    for (int i = 0; i < datas.length; i++) {
      datas[i] = new Data(dataset, instances[i]);
    }

    return datas;
  }

  private static Data[] generateTrainingDataB() throws DescriptorException {

    // Training data
    String[] trainData = new String[20];
    for (int i = 0; i < trainData.length; i++) {
      if (i % 3 == 0) {
        trainData[i] = "A," + (40 - i) + ',' (i + 20);
      } else if (i % 3 == 1) {
        trainData[i] = "B," + (i + 20) + ',' (40 - i);
      } else {
        trainData[i] = "C," + (i + 20) + ',' (i + 20);
      }
    }
    // Dataset
    Dataset dataset = DataLoader.generateDataset("C N L", true, trainData);
    Data[] datas = new Data[3];
    datas[0] = DataLoader.loadData(dataset, trainData);

    // Training data
    trainData = new String[20];
    for (int i = 0; i < trainData.length; i++) {
      if (i % 2 == 0) {
        trainData[i] = "A," + (50 - i) + ',' (i + 10);
      } else {
        trainData[i] = "B," + (i + 10) + ',' (50 - i);
      }
    }
    datas[1] = DataLoader.loadData(dataset, trainData);

    // Training data
    trainData = new String[10];
    for (int i = 0; i < trainData.length; i++) {
      trainData[i] = "A," + (40 - i) + ',' (i + 20);
    }
    datas[2] = DataLoader.loadData(dataset, trainData);

    return datas;
  }
 
  private DecisionForest buildForest(Data[] datas) {
    List<Node> trees = Lists.newArrayList();
    for (Data data : datas) {
      // build tree
      DecisionTreeBuilder builder = new DecisionTreeBuilder();
      builder.setM(data.getDataset().nbAttributes() - 1);
      builder.setMinSplitNum(0);
      builder.setComplemented(false);
      trees.add(builder.build(rng, data));
    }
    return new DecisionForest(trees);
  }
 
  @Test
  public void testClassify() throws DescriptorException {
    // Training data
    Data[] datas = generateTrainingDataA();
    // Build Forest
    DecisionForest forest = buildForest(datas);
    // Test data
    Data testData = DataLoader.loadData(datas[0].getDataset(), TEST_DATA);

    assertEquals(1.0, forest.classify(testData.getDataset(), rng, testData.get(0)), EPSILON);
    // This one is tie-broken -- 1 is OK too
    assertEquals(0.0, forest.classify(testData.getDataset(), rng, testData.get(1)), EPSILON);
    assertEquals(1.0, forest.classify(testData.getDataset(), rng, testData.get(2)), EPSILON);
  }

  @Test
  public void testClassifyData() throws DescriptorException {
    // Training data
    Data[] datas = generateTrainingDataA();
    // Build Forest
    DecisionForest forest = buildForest(datas);
    // Test data
    Data testData = DataLoader.loadData(datas[0].getDataset(), TEST_DATA);

    double[][] predictions = new double[testData.size()][];
    forest.classify(testData, predictions);
    assertArrayEquals(new double[][]{{1.0, Double.NaN, Double.NaN},
        {1.0, 0.0, Double.NaN}, {1.0, 1.0, Double.NaN}}, predictions);
  }

  @Test
  public void testRegression() throws DescriptorException {
    Data[] datas = generateTrainingDataB();
    DecisionForest[] forests = new DecisionForest[datas.length];
    for (int i = 0; i < datas.length; i++) {
      Data[] subDatas = new Data[datas.length - 1];
      int k = 0;
      for (int j = 0; j < datas.length; j++) {
        if (j != i) {
          subDatas[k] = datas[j];
          k++;
        }
      }
      forests[i] = buildForest(subDatas);
    }
   
    double[][] predictions = new double[datas[0].size()][];
    forests[0].classify(datas[0], predictions);
    assertArrayEquals(new double[]{20.0, 20.0}, predictions[0], EPSILON);
    assertArrayEquals(new double[]{39.0, 29.0}, predictions[1], EPSILON);
    assertArrayEquals(new double[]{Double.NaN, 29.0}, predictions[2], EPSILON);
    assertArrayEquals(new double[]{Double.NaN, 23.0}, predictions[17], EPSILON);

    predictions = new double[datas[1].size()][];
    forests[1].classify(datas[1], predictions);
    assertArrayEquals(new double[]{30.0, 29.0}, predictions[19], EPSILON);

    predictions = new double[datas[2].size()][];
    forests[2].classify(datas[2], predictions);
    assertArrayEquals(new double[]{29.0, 28.0}, predictions[9], EPSILON);

    assertEquals(20.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(0)), EPSILON);
    assertEquals(34.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(1)), EPSILON);
    assertEquals(29.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(2)), EPSILON);
  }
}
TOP

Related Classes of org.apache.mahout.classifier.df.DecisionForestTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.