Package org.apache.mahout.clustering.dirichlet

Source Code of org.apache.mahout.clustering.dirichlet.TestMapReduce

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.clustering.dirichlet;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import junit.framework.TestCase;

import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution;
import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel;
import org.apache.mahout.clustering.dirichlet.models.Model;
import org.apache.mahout.clustering.dirichlet.models.NormalModel;
import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution;
import org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution;
import org.apache.mahout.clustering.dirichlet.models.SampledNormalModel;
import org.apache.mahout.clustering.kmeans.KMeansDriver;
import org.apache.mahout.matrix.DenseVector;
import org.apache.mahout.matrix.Vector;
import org.apache.mahout.utils.DummyOutputCollector;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;

public class TestMapReduce extends TestCase {

  // Sample points shared by the tests: the generateSamples helpers append to
  // this list, and the multi-file driver tests replace it with a fresh list
  // after each input file has been written.
  private List<Vector> sampleData = new ArrayList<Vector>();

  /**
   * Generate random samples and add them to the sampleData
   * @param num int number of samples to generate
   * @param mx double x-value of the sample mean
   * @param my double y-value of the sample mean
   * @param sd double standard deviation of the samples
   */
  private void generateSamples(int num, double mx, double my, double sd) {
    System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
        + "] sd=" + sd);
    for (int i = 0; i < num; i++)
      sampleData.add(new DenseVector(new double[] {
          UncommonDistributions.rNorm(mx, sd),
          UncommonDistributions.rNorm(my, sd) }));
  }

  public static void writePointsToFileWithPayload(List<Vector> points,
      String fileName, String payload) throws IOException {
    BufferedWriter output = new BufferedWriter(new OutputStreamWriter(
        new FileOutputStream(fileName), Charset.forName("UTF-8")));
    for (Vector point : points) {
      output.write(point.asFormatString());
      output.write(payload);
      output.write('\n');
    }
    output.flush();
    output.close();
  }

  @Override
  protected void setUp() throws Exception {
    super.setUp();
    UncommonDistributions.init("Mahout=Hadoop+ML".getBytes());
    File f = new File("input");
    if (!f.exists())
      f.mkdir();
  }

  /**
   * Test the basic Mapper
   * @throws Exception
   */
  public void testMapper() throws Exception {
    generateSamples(10, 0, 0, 1);
    DirichletState<Vector> state = new DirichletState<Vector>(
        new NormalModelDistribution(), 5, 1, 0, 0);
    DirichletMapper mapper = new DirichletMapper();
    mapper.configure(state);

    DummyOutputCollector<Text, Text> collector = new DummyOutputCollector<Text, Text>();
    for (Vector v : sampleData)
      mapper.map(null, new Text(v.asFormatString()), collector, null);
    Map<String, List<Text>> data = collector.getData();
    // this seed happens to produce two partitions, but they work
    assertEquals("output size", 3, data.size());
  }

  /**
   * Test the basic Reducer
   * @throws Exception
   */
  public void testReducer() throws Exception {
    generateSamples(100, 0, 0, 1);
    generateSamples(100, 2, 0, 1);
    generateSamples(100, 0, 2, 1);
    generateSamples(100, 2, 2, 1);
    DirichletState<Vector> state = new DirichletState<Vector>(
        new SampledNormalDistribution(), 20, 1, 1, 0);
    DirichletMapper mapper = new DirichletMapper();
    mapper.configure(state);

    DummyOutputCollector<Text, Text> mapCollector = new DummyOutputCollector<Text, Text>();
    for (Vector v : sampleData)
      mapper.map(null, new Text(v.asFormatString()), mapCollector, null);
    Map<String, List<Text>> data = mapCollector.getData();
    // this seed happens to produce three partitions, but they work
    assertEquals("output size", 7, data.size());

    DirichletReducer reducer = new DirichletReducer();
    reducer.configure(state);
    DummyOutputCollector<Text, Text> reduceCollector = new DummyOutputCollector<Text, Text>();
    for (String key : mapCollector.getKeys())
      reducer.reduce(new Text(key), mapCollector.getValue(key).iterator(),
          reduceCollector, null);

    Model<Vector>[] newModels = reducer.newModels;
    state.update(newModels);
  }

  private void printModels(List<Model<Vector>[]> results, int significant) {
    int row = 0;
    for (Model<Vector>[] r : results) {
      System.out.print("sample[" + row++ + "]= ");
      for (int k = 0; k < r.length; k++) {
        Model<Vector> model = r[k];
        if (model.count() > significant) {
          System.out.print("m" + k + model.toString() + ", ");
        }
      }
      System.out.println();
    }
    System.out.println();
  }

  /**
   * Test the Mapper and Reducer in an iteration loop
   * @throws Exception
   */
  public void testMRIterations() throws Exception {
    generateSamples(100, 0, 0, 1);
    generateSamples(100, 2, 0, 1);
    generateSamples(100, 0, 2, 1);
    generateSamples(100, 2, 2, 1);
    DirichletState<Vector> state = new DirichletState<Vector>(
        new SampledNormalDistribution(), 20, 1.0, 1, 0);

    List<Model<Vector>[]> models = new ArrayList<Model<Vector>[]>();

    for (int iteration = 0; iteration < 10; iteration++) {
      DirichletMapper mapper = new DirichletMapper();
      mapper.configure(state);
      DummyOutputCollector<Text, Text> mapCollector = new DummyOutputCollector<Text, Text>();
      for (Vector v : sampleData)
        mapper.map(null, new Text(v.asFormatString()), mapCollector, null);

      DirichletReducer reducer = new DirichletReducer();
      reducer.configure(state);
      DummyOutputCollector<Text, Text> reduceCollector = new DummyOutputCollector<Text, Text>();
      for (String key : mapCollector.getKeys())
        reducer.reduce(new Text(key), mapCollector.getValue(key).iterator(),
            reduceCollector, null);

      Model<Vector>[] newModels = reducer.newModels;
      state.update(newModels);
      models.add(newModels);
    }
    printModels(models, 0);
  }

  @SuppressWarnings("unchecked")
  public void testNormalModelSerialization() {
    double[] m = { 1.1, 2.2 };
    Model model = new NormalModel(new DenseVector(m), 3.3);
    GsonBuilder builder = new GsonBuilder();
    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
    Gson gson = builder.create();
    String jsonString = gson.toJson(model);
    Model model2 = gson.fromJson(jsonString, NormalModel.class);
    assertEquals("models", model.toString(), model2.toString());
  }

  @SuppressWarnings("unchecked")
  public void testNormalModelDistributionSerialization() {
    NormalModelDistribution dist = new NormalModelDistribution();
    Model[] models = dist.sampleFromPrior(20);
    GsonBuilder builder = new GsonBuilder();
    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
    Gson gson = builder.create();
    String jsonString = gson.toJson(models);
    Model[] models2 = gson.fromJson(jsonString, NormalModel[].class);
    assertEquals("models", models.length, models2.length);
    for (int i = 0; i < models.length; i++)
      assertEquals("model[" + i + "]", models[i].toString(), models2[i]
          .toString());
  }

  @SuppressWarnings("unchecked")
  public void testSampledNormalModelSerialization() {
    double[] m = { 1.1, 2.2 };
    Model model = new SampledNormalModel(new DenseVector(m), 3.3);
    GsonBuilder builder = new GsonBuilder();
    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
    Gson gson = builder.create();
    String jsonString = gson.toJson(model);
    Model model2 = gson.fromJson(jsonString, SampledNormalModel.class);
    assertEquals("models", model.toString(), model2.toString());
  }

  @SuppressWarnings("unchecked")
  public void testSampledNormalDistributionSerialization() {
    SampledNormalDistribution dist = new SampledNormalDistribution();
    Model[] models = dist.sampleFromPrior(20);
    GsonBuilder builder = new GsonBuilder();
    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
    Gson gson = builder.create();
    String jsonString = gson.toJson(models);
    Model[] models2 = gson.fromJson(jsonString, SampledNormalModel[].class);
    assertEquals("models", models.length, models2.length);
    for (int i = 0; i < models.length; i++)
      assertEquals("model[" + i + "]", models[i].toString(), models2[i]
          .toString());
  }

  @SuppressWarnings("unchecked")
  public void testAsymmetricSampledNormalModelSerialization() {
    double[] m = { 1.1, 2.2 };
    double[] s = { 3.3, 4.4 };
    Model model = new AsymmetricSampledNormalModel(new DenseVector(m),
        new DenseVector(s));
    GsonBuilder builder = new GsonBuilder();
    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
    Gson gson = builder.create();
    String jsonString = gson.toJson(model);
    Model model2 = gson
        .fromJson(jsonString, AsymmetricSampledNormalModel.class);
    assertEquals("models", model.toString(), model2.toString());
  }

  @SuppressWarnings("unchecked")
  public void testAsymmetricSampledNormalDistributionSerialization() {
    AsymmetricSampledNormalDistribution dist = new AsymmetricSampledNormalDistribution();
    Model[] models = dist.sampleFromPrior(20);
    GsonBuilder builder = new GsonBuilder();
    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
    Gson gson = builder.create();
    String jsonString = gson.toJson(models);
    Model[] models2 = gson.fromJson(jsonString,
        AsymmetricSampledNormalModel[].class);
    assertEquals("models", models.length, models2.length);
    for (int i = 0; i < models.length; i++)
      assertEquals("model[" + i + "]", models[i].toString(), models2[i]
          .toString());
  }

  @SuppressWarnings("unchecked")
  public void testModelHolderSerialization() {
    GsonBuilder builder = new GsonBuilder();
    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
    builder
        .registerTypeAdapter(ModelHolder.class, new JsonModelHolderAdapter());
    Gson gson = builder.create();
    double[] d = { 1.1, 2.2 };
    ModelHolder mh = new ModelHolder(new NormalModel(new DenseVector(d), 3.3));
    String format = gson.toJson(mh);
    System.out.println(format);
    ModelHolder mh2 = gson.fromJson(format, ModelHolder.class);
    assertEquals("mh", mh.model.toString(), mh2.model.toString());
  }

  @SuppressWarnings("unchecked")
  public void testModelHolderSerialization2() {
    GsonBuilder builder = new GsonBuilder();
    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
    builder
        .registerTypeAdapter(ModelHolder.class, new JsonModelHolderAdapter());
    Gson gson = builder.create();
    double[] d = { 1.1, 2.2 };
    double[] s = { 3.3, 4.4 };
    ModelHolder mh = new ModelHolder(new AsymmetricSampledNormalModel(
        new DenseVector(d), new DenseVector(s)));
    String format = gson.toJson(mh);
    System.out.println(format);
    ModelHolder mh2 = gson.fromJson(format, ModelHolder.class);
    assertEquals("mh", mh.model.toString(), mh2.model.toString());
  }

  @SuppressWarnings("unchecked")
  public void testStateSerialization() {
    GsonBuilder builder = new GsonBuilder();
    builder.registerTypeAdapter(DirichletState.class,
        new JsonDirichletStateAdapter());
    Gson gson = builder.create();
    DirichletState state = new DirichletState(new SampledNormalDistribution(),
        20, 1, 1, 0);
    String format = gson.toJson(state);
    System.out.println(format);
    DirichletState state2 = gson.fromJson(format, DirichletState.class);
    assertNotNull("State2 null", state2);
    assertEquals("numClusters", state.numClusters, state2.numClusters);
    assertEquals("modelFactory", state.modelFactory.getClass().getName(),
        state2.modelFactory.getClass().getName());
    assertEquals("clusters", state.clusters.size(), state2.clusters.size());
    assertEquals("mixture", state.mixture.cardinality(), state2.mixture
        .cardinality());
    assertEquals("dirichlet", state.offset, state2.offset);
  }

  /**
   * Test the Mapper and Reducer using the Driver
   * @throws Exception
   */
  public void testDriverMRIterations() throws Exception {
    File f = new File("input");
    for (File g : f.listFiles())
      g.delete();
    generateSamples(100, 0, 0, 0.5);
    generateSamples(100, 2, 0, 0.2);
    generateSamples(100, 0, 2, 0.3);
    generateSamples(100, 2, 2, 1);
    writePointsToFileWithPayload(sampleData, "input/data.txt", "");
    // Now run the driver
    DirichletDriver.runJob("input", "output",
        "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution", 20,
        10, 1.0, 1);
    // and inspect results
    List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
    JobConf conf = new JobConf(KMeansDriver.class);
    conf.set(DirichletDriver.MODEL_FACTORY_KEY,
        "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
    conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
    conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
    for (int i = 0; i < 11; i++) {
      conf.set(DirichletDriver.STATE_IN_KEY, "output/state-" + i);
      clusters.add(DirichletMapper.getDirichletState(conf).clusters);
    }
    printResults(clusters, 0);
  }

  private void printResults(List<List<DirichletCluster<Vector>>> clusters,
      int significant) {
    int row = 0;
    for (List<DirichletCluster<Vector>> r : clusters) {
      System.out.print("sample[" + row++ + "]= ");
      for (int k = 0; k < r.size(); k++) {
        Model<Vector> model = r.get(k).model;
        if (model.count() > significant) {
          int total = new Double(r.get(k).totalCount).intValue();
          System.out.print("m" + k + "(" + total + ")" + model.toString()
              + ", ");
        }
      }
      System.out.println();
    }
    System.out.println();
  }

  /**
   * Test the Mapper and Reducer using the Driver
   * @throws Exception
   */
  public void testDriverMnRIterations() throws Exception {
    File f = new File("input");
    for (File g : f.listFiles())
      g.delete();
    generateSamples(500, 0, 0, 0.5);
    writePointsToFileWithPayload(sampleData, "input/data1.txt", "");
    sampleData = new ArrayList<Vector>();
    generateSamples(500, 2, 0, 0.2);
    writePointsToFileWithPayload(sampleData, "input/data2.txt", "");
    sampleData = new ArrayList<Vector>();
    generateSamples(500, 0, 2, 0.3);
    writePointsToFileWithPayload(sampleData, "input/data3.txt", "");
    sampleData = new ArrayList<Vector>();
    generateSamples(500, 2, 2, 1);
    writePointsToFileWithPayload(sampleData, "input/data4.txt", "");
    // Now run the driver
    DirichletDriver.runJob("input", "output",
        "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution", 20,
        15, 1.0, 1);
    // and inspect results
    List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
    JobConf conf = new JobConf(KMeansDriver.class);
    conf.set(DirichletDriver.MODEL_FACTORY_KEY,
        "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
    conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
    conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
    for (int i = 0; i < 11; i++) {
      conf.set(DirichletDriver.STATE_IN_KEY, "output/state-" + i);
      clusters.add(DirichletMapper.getDirichletState(conf).clusters);
    }
    printResults(clusters, 0);
  }

  /**
   * Test the Mapper and Reducer using the Driver
   * @throws Exception
   */
  public void testDriverMnRnIterations() throws Exception {
    File f = new File("input");
    for (File g : f.listFiles())
      g.delete();
    generateSamples(500, 0, 0, 0.5);
    writePointsToFileWithPayload(sampleData, "input/data1.txt", "");
    sampleData = new ArrayList<Vector>();
    generateSamples(500, 2, 0, 0.2);
    writePointsToFileWithPayload(sampleData, "input/data2.txt", "");
    sampleData = new ArrayList<Vector>();
    generateSamples(500, 0, 2, 0.3);
    writePointsToFileWithPayload(sampleData, "input/data3.txt", "");
    sampleData = new ArrayList<Vector>();
    generateSamples(500, 2, 2, 1);
    writePointsToFileWithPayload(sampleData, "input/data4.txt", "");
    // Now run the driver
    DirichletDriver.runJob("input", "output",
        "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution", 20,
        15, 1.0, 2);
    // and inspect results
    List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
    JobConf conf = new JobConf(KMeansDriver.class);
    conf.set(DirichletDriver.MODEL_FACTORY_KEY,
        "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
    conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
    conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
    for (int i = 0; i < 11; i++) {
      conf.set(DirichletDriver.STATE_IN_KEY, "output/state-" + i);
      clusters.add(DirichletMapper.getDirichletState(conf).clusters);
    }
    printResults(clusters, 0);
  }

  /**
   * Generate random samples and add them to the sampleData
   * @param num int number of samples to generate
   * @param mx double x-value of the sample mean
   * @param my double y-value of the sample mean
   * @param sdx double x-standard deviation of the samples
   * @param sdy double y-standard deviation of the samples
   */
  private void generateSamples(int num, double mx, double my, double sdx,
      double sdy) {
    System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
        + "] sd=[" + sdx + ", " + sdy + "]");
    for (int i = 0; i < num; i++)
      sampleData.add(new DenseVector(new double[] {
          UncommonDistributions.rNorm(mx, sdx),
          UncommonDistributions.rNorm(my, sdy) }));
  }

  /**
   * Test the Mapper and Reducer using the Driver
   * @throws Exception
   */
  public void testDriverMnRnIterationsAsymmetric() throws Exception {
    File f = new File("input");
    for (File g : f.listFiles())
      g.delete();
    generateSamples(500, 0, 0, 0.5, 1.0);
    writePointsToFileWithPayload(sampleData, "input/data1.txt", "");
    sampleData = new ArrayList<Vector>();
    generateSamples(500, 2, 0, 0.2, 0.1);
    writePointsToFileWithPayload(sampleData, "input/data2.txt", "");
    sampleData = new ArrayList<Vector>();
    generateSamples(500, 0, 2, 0.3, 0.5);
    writePointsToFileWithPayload(sampleData, "input/data3.txt", "");
    sampleData = new ArrayList<Vector>();
    generateSamples(500, 2, 2, 1, 0.5);
    writePointsToFileWithPayload(sampleData, "input/data4.txt", "");
    // Now run the driver
    DirichletDriver
        .runJob(
            "input",
            "output",
            "org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution",
            20, 15, 1.0, 2);
    // and inspect results
    List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
    JobConf conf = new JobConf(KMeansDriver.class);
    conf
        .set(DirichletDriver.MODEL_FACTORY_KEY,
            "org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution");
    conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
    conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
    for (int i = 0; i < 11; i++) {
      conf.set(DirichletDriver.STATE_IN_KEY, "output/state-" + i);
      clusters.add(DirichletMapper.getDirichletState(conf).clusters);
    }
    printResults(clusters, 0);
  }

}
TOP

Related Classes of org.apache.mahout.clustering.dirichlet.TestMapReduce

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.