Package org.apache.mahout.clustering.dirichlet

Source Code of org.apache.mahout.clustering.dirichlet.TestDirichletClustering

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.clustering.dirichlet;

import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.dirichlet.models.DistanceMeasureClusterDistribution;
import org.apache.mahout.clustering.dirichlet.models.DistributionDescription;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.VectorWritable;
import org.junit.Before;
import org.junit.Test;

import com.google.common.collect.Lists;

@Deprecated
public final class TestDirichletClustering extends MahoutTestCase {

  private List<VectorWritable> sampleData;
 
  @Override
  @Before
  public void setUp() throws Exception {
    super.setUp();
    sampleData = Lists.newArrayList();
  }

  /**
   * Generate random samples and add them to the sampleData
   *
   * @param num int number of samples to generate
   * @param mx  double x-value of the sample mean
   * @param my  double y-value of the sample mean
   * @param sd  double standard deviation of the samples
   * @param card int cardinality of the generated sample vectors
   */
  private void generateSamples(int num, double mx, double my, double sd, int card) {
    System.out.println("Generating " + num + " samples m=[" + mx + ", " + my + "] sd=" + sd);
    for (int i = 0; i < num; i++) {
      DenseVector v = new DenseVector(card);
      for (int j = 0; j < card; j++) {
        v.set(j, UncommonDistributions.rNorm(mx, sd));
      }
      sampleData.add(new VectorWritable(v));
    }
  }

  /**
   * Generate 2-d samples for backwards compatibility with existing tests
   * @param num int number of samples to generate
   * @param mx  double x-value of the sample mean
   * @param my  double y-value of the sample mean
   * @param sd  double standard deviation of the samples
   */
  private void generateSamples(int num, double mx, double my, double sd) {
    generateSamples(num, mx, my, sd, 2);
  }

  @Test
  public void testDirichletClusteringSeq() throws Exception {
    Path output = getTestTempDirPath("output");
    Configuration conf = getConfiguration();
    FileSystem fs = FileSystem.get(getConfiguration());
   
    generateSamples(40, 1, 1, 3);
    generateSamples(30, 1, 0, 0.1);
    generateSamples(30, 0, 1, 0.1);

    ClusteringTestUtils.writePointsToFile(sampleData,
            getTestTempFilePath("testdata/file1"), fs, conf);

    DenseVector prototype = (DenseVector) sampleData.get(0).get();
   
    DistributionDescription description = new DistributionDescription(
        DistanceMeasureClusterDistribution.class.getName(),
        RandomAccessSparseVector.class.getName(),
        ManhattanDistanceMeasure.class.getName(), prototype.size());
   
    DirichletDriver.run(conf, getTestTempDirPath("testdata"), output,
        description, 10, 1, 1.0, true, true, 0, true);
   
    Path path = new Path(output, "clusteredPoints/part-m-0");
    long count = HadoopUtil.countRecords(path, conf);
    assertEquals("number of points", sampleData.size(), count);
  }
 
  @Test
  public void testDirichletClusteringMR() throws Exception {
    Path output = getTestTempDirPath("output");
    Configuration conf = getConfiguration();
    FileSystem fs = FileSystem.get(getConfiguration());
   
    generateSamples(40, 1, 1, 3);
    generateSamples(30, 1, 0, 0.1);

    ClusteringTestUtils.writePointsToFile(sampleData, true,
            getTestTempFilePath("testdata/file1"), fs, conf);

    DenseVector prototype = (DenseVector) sampleData.get(0).get();
   
    DistributionDescription description = new DistributionDescription(
        DistanceMeasureClusterDistribution.class.getName(),
        RandomAccessSparseVector.class.getName(),
        ManhattanDistanceMeasure.class.getName(), prototype.size());
   
    DirichletDriver.run(conf, getTestTempDirPath("testdata"), output,
        description, 10, 1, 1.0, true, true, 0, false);
   
    Path path = new Path(output, "clusteredPoints/part-m-00000");
    long count = HadoopUtil.countRecords(path, conf);
    assertEquals("number of points", sampleData.size(), count);
  }
 
}
TOP

Related Classes of org.apache.mahout.clustering.dirichlet.TestDirichletClustering

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.