Package org.apache.mahout.clustering.dirichlet

Source Code of org.apache.mahout.clustering.dirichlet.DisplayDirichlet

package org.apache.mahout.clustering.dirichlet;

import java.awt.Color;
import java.awt.Frame;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.Toolkit;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
import java.awt.geom.AffineTransform;
import java.awt.geom.Ellipse2D;
import java.awt.geom.Rectangle2D;
import java.util.ArrayList;
import java.util.List;

import org.apache.mahout.clustering.dirichlet.models.Model;
import org.apache.mahout.clustering.dirichlet.models.ModelDistribution;
import org.apache.mahout.matrix.DenseVector;
import org.apache.mahout.matrix.TimesFunction;
import org.apache.mahout.matrix.Vector;

class DisplayDirichlet extends Frame {
  private static final long serialVersionUID = 1L;

  int res; //screen resolution

  int ds = 72; //default scale = 72 pixels per inch

  int size = 8; // screen size in inches

  static List<Vector> sampleData = new ArrayList<Vector>();

  static List<Model<Vector>[]> result;

  static double significance = 0.05;

  static List<Vector> sampleParams = new ArrayList<Vector>();

  static Color[] colors = { Color.red, Color.orange, Color.yellow, Color.green,
      Color.blue, Color.magenta, Color.lightGray };

  /**
   * Licensed to the Apache Software Foundation (ASF) under one or more
   * contributor license agreements.  See the NOTICE file distributed with
   * this work for additional information regarding copyright ownership.
   * The ASF licenses this file to You under the Apache License, Version 2.0
   * (the "License"); you may not use this file except in compliance with
   * the License.  You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   * See the License for the specific language governing permissions and
   * limitations under the License.
   */

  DisplayDirichlet() {
    initialize();
  }

  void initialize() {
    //Get screen resolution
    res = Toolkit.getDefaultToolkit().getScreenResolution();

    //Set Frame size in inches
    this.setSize(size * res, size * res);
    this.setVisible(true);
    this.setTitle("Dirichlet Process Sample Data");

    //Window listener to terminate program.
    this.addWindowListener(new WindowAdapter() {
      public void windowClosing(WindowEvent e) {
        System.exit(0);
      }
    });
  }

  public static void main(String[] args) {
    UncommonDistributions.init("Mahout=Hadoop+ML".getBytes());
    generateSamples();
    new DisplayDirichlet();
  }

  // Override the paint() method
  public void paint(Graphics g) {
    Graphics2D g2 = (Graphics2D) g;
    plotSampleData(g);
    Vector v = new DenseVector(2);
    Vector dv = new DenseVector(2);
    g2.setColor(Color.RED);
    for (Vector param : sampleParams) {
      v.set(0, param.get(0));
      v.set(1, param.get(1));
      dv.set(0, param.get(2) * 3);
      dv.set(1, param.get(3) * 3);
      plotEllipse(g2, v, dv);
    }
  }

  void plotSampleData(Graphics g) {
    Graphics2D g2 = (Graphics2D) g;
    double sx = (double) res / ds;
    g2.setTransform(AffineTransform.getScaleInstance(sx, sx));

    // plot the axes
    g2.setColor(Color.BLACK);
    Vector dv = new DenseVector(2).assign(size / 2);
    plotRectangle(g2, new DenseVector(2).assign(2), dv);
    plotRectangle(g2, new DenseVector(2).assign(-2), dv);

    // plot the sample data
    g2.setColor(Color.DARK_GRAY);
    dv.assign(0.03);
    for (Vector v : sampleData)
      plotRectangle(g2, v, dv);
  }

  /**
   * Plot the points on the graphics context
   * @param g2 a Graphics2D context
   * @param v a Vector of rectangle centers
   * @param dv a Vector of rectangle sizes
   */
  void plotRectangle(Graphics2D g2, Vector v, Vector dv) {
    int h = size / 2;
    double[] flip = { 1, -1 };
    Vector v2 = v.copy().assign(new DenseVector(flip), new TimesFunction());
    v2 = v2.minus(dv.divide(2));
    double x = v2.get(0) + h;
    double y = v2.get(1) + h;
    g2.draw(new Rectangle2D.Double(x * ds, y * ds, dv.get(0) * ds, dv.get(1)
        * ds));
  }

  /**
   * Plot the points on the graphics context
   * @param g2 a Graphics2D context
   * @param v a Vector of rectangle centers
   * @param dv a Vector of rectangle sizes
   */
  void plotEllipse(Graphics2D g2, Vector v, Vector dv) {
    int h = size / 2;
    double[] flip = { 1, -1 };
    Vector v2 = v.copy().assign(new DenseVector(flip), new TimesFunction());
    v2 = v2.minus(dv.divide(2));
    double x = v2.get(0) + h;
    double y = v2.get(1) + h;
    g2
        .draw(new Ellipse2D.Double(x * ds, y * ds, dv.get(0) * ds, dv.get(1)
            * ds));
  }

  private static void printModels(List<Model<Vector>[]> results, int significant) {
    int row = 0;
    for (Model<Vector>[] r : results) {
      System.out.print("sample[" + row++ + "]= ");
      for (int k = 0; k < r.length; k++) {
        Model<Vector> model = r[k];
        if (model.count() > significant) {
          System.out.print("m" + k + model.toString() + ", ");
        }
      }
      System.out.println();
    }
    System.out.println();
  }

  static void generateSamples() {
    generateSamples(400, 1, 1, 3);
    generateSamples(300, 1, 0, 0.5);
    generateSamples(300, 0, 2, 0.1);
  }

  static void generate2dSamples() {
    generate2dSamples(400, 1, 1, 3, 1);
    generate2dSamples(300, 1, 0, 0.5, 1);
    generate2dSamples(300, 0, 2, 0.1, 0.5);
  }

  /**
   * Generate random samples and add them to the sampleData
   * @param num int number of samples to generate
   * @param mx double x-value of the sample mean
   * @param my double y-value of the sample mean
   * @param sd double standard deviation of the samples
   */
  public static void generateSamples(int num, double mx, double my, double sd) {
    double[] params = { mx, my, sd, sd };
    sampleParams.add(new DenseVector(params));
    System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
        + "] sd=" + sd);
    for (int i = 0; i < num; i++)
      sampleData.add(new DenseVector(new double[] {
          UncommonDistributions.rNorm(mx, sd),
          UncommonDistributions.rNorm(my, sd) }));
  }

  /**
   * Generate random samples and add them to the sampleData
   * @param num int number of samples to generate
   * @param mx double x-value of the sample mean
   * @param my double y-value of the sample mean
   * @param sdx double x-value standard deviation of the samples
   * @param sdy double y-value standard deviation of the samples
   */
  public static void generate2dSamples(int num, double mx, double my,
      double sdx, double sdy) {
    double[] params = { mx, my, sdx, sdy };
    sampleParams.add(new DenseVector(params));
    System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
        + "] sd=[" + sdx + ", " + sdy + "]");
    for (int i = 0; i < num; i++)
      sampleData.add(new DenseVector(new double[] {
          UncommonDistributions.rNorm(mx, sdx),
          UncommonDistributions.rNorm(my, sdy) }));
  }

  static void generateResults(ModelDistribution<Vector> modelDist) {
    DirichletClusterer<Vector> dc = new DirichletClusterer<Vector>(sampleData,
        modelDist, 1.0, 10, 2, 2);
    result = dc.cluster(20);
    printModels(result, 5);
  }

  static boolean isSignificant(Model<Vector> model) {
    return (((double) model.count() / sampleData.size()) > significance);
  }

}
TOP

Related Classes of org.apache.mahout.clustering.dirichlet.DisplayDirichlet

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware#gmail.com.