Source Code of org.apache.mahout.clustering.dirichlet.DirichletClusterer

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.clustering.dirichlet;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile.Writer;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.Model;
import org.apache.mahout.clustering.ModelDistribution;
import org.apache.mahout.clustering.WeightedVectorWritable;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/**
* Performs Bayesian mixture modeling.
* <p/>
* The idea is to explain some observed data with a probabilistic mixture of models: each observed data
* point is assumed to have come from one of the models in the mixture, but we don't know which. We deal
* with that by introducing a so-called latent variable which specifies which model each data point came
* from.
* <p/>
* In addition, since this is a Bayesian clustering algorithm, we don't want to actually commit to any single
* explanation, but rather to sample from the distribution of models and latent assignments of data points to
* models given the observed data and the prior distributions of model parameters.
* <p/>
* This sampling process is initialized by taking models at random from the prior distribution for models.
* <p/>
* Then, we iteratively assign points to the different models using the mixture probabilities and the degree
* of fit between the point and each model expressed as a probability that the point was generated by that
* model.
* <p/>
* After points are assigned, new parameters for each model are sampled from the posterior distribution for
* the model parameters considering all of the observed data points that were assigned to the model. Models
* without any data points are also sampled, but since they have no points assigned, the new samples are
* effectively taken from the prior distribution for model parameters.
* <p/>
* The result is a number of samples that represent mixing probabilities, models and assignment of points to
* models. If the total number of possible models is substantially larger than the number that ever have
* points assigned to them, then this algorithm provides a (nearly) non-parametric clustering algorithm.
* <p/>
* These samples can give us interesting information that is lacking from a normal clustering that consists of
* a single assignment of points to clusters. Firstly, by examining the number of models in each sample that
* actually have any points assigned to them, we can get information about how many models (clusters) the
* data support.
* <p/>
* Moreover, by examining how often two points are assigned to the same model, we can get an approximate
* measure of how likely these points are to be explained by the same model. Such soft membership information
* is difficult to come by with conventional clustering methods.
* <p/>
* Finally, we can get an idea of how stable the description of the data is. Typically, aspects of the data
* with plenty of supporting observations wind up with stable descriptions, while at the edges there are
* phenomena we cannot commit to a solid description of, even though it is clear that the well-supported
* explanations are insufficient to explain these additional aspects.
* <p/>
* One thing that can be difficult about these samples is that we can't always establish a correspondence
* between the models in the different samples. Probably the best way to do this is to look for overlap in
* the assignments of data observations to the different models.
* <p/>
* In symbols, the generative model is:
* <pre>
*    \theta_i ~ prior()
*    \lambda ~ Dirichlet(\alpha_0)
*    z_j ~ Multinomial( \lambda )
*    x_j ~ model(\theta_{z_j})
* </pre>
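* <p/>
* For example, a minimal usage sketch (the {@code NormalModelDistribution} named here is illustrative;
* any {@code ModelDistribution<VectorWritable>} implementation will do):
*
* <pre>{@code
*   List<VectorWritable> points = ...; // the observations to cluster
*   ModelDistribution<VectorWritable> prior = new NormalModelDistribution();
*   // alpha0 = 1.0, 10 candidate clusters, sample every iteration (thin = 1), no burn-in, 20 iterations
*   List<Cluster[]> samples =
*       DirichletClusterer.clusterPoints(points, prior, 1.0, 10, 1, 0, 20);
*   Cluster[] lastSample = samples.get(samples.size() - 1);
* }</pre>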
*/
public class DirichletClusterer {

  // observed data
  private final List<VectorWritable> sampleData;

  // the ModelDistribution for the computation
  private final ModelDistribution<VectorWritable> modelFactory;

  // the state of the clustering process
  private final DirichletState state;

  private final int thin;

  private final int burnin;

  private final int numClusters;

  private final List<Cluster[]> clusterSamples = new ArrayList<Cluster[]>();

  private boolean emitMostLikely;

  private double threshold;

  /**
   * Cluster the sample data with the given parameters, returning the periodically-collected cluster samples
   *
   * @param points
   *          the observed data to be clustered
   * @param modelFactory
   *          the ModelDistribution to use
   * @param alpha0
   *          the double value for the beta distributions
   * @param numClusters
   *          the int number of clusters
   * @param thin
   *          the int thinning interval, used to report every n iterations
   * @param burnin
   *          the int burnin interval, used to suppress early iterations
   * @param numIterations
   *          number of iterations to be performed
   * @return the List of Cluster[] samples collected every thin iterations after the burnin period
   */
  public static List<Cluster[]> clusterPoints(List<VectorWritable> points,
                                              ModelDistribution<VectorWritable> modelFactory,
                                              double alpha0,
                                              int numClusters,
                                              int thin,
                                              int burnin,
                                              int numIterations) {
    DirichletClusterer clusterer = new DirichletClusterer(points, modelFactory, alpha0, numClusters, thin, burnin);
    return clusterer.cluster(numIterations);
  }

  /**
   * Create a new instance on the sample data with the given additional parameters
   *
   * @param sampleData
   *          the observed data to be clustered
   * @param modelFactory
   *          the ModelDistribution to use
   * @param alpha0
   *          the double value for the beta distributions
   * @param numClusters
   *          the int number of clusters
   * @param thin
   *          the int thinning interval, used to report every n iterations
   * @param burnin
   *          the int burnin interval, used to suppress early iterations
   */
  public DirichletClusterer(List<VectorWritable> sampleData,
                            ModelDistribution<VectorWritable> modelFactory,
                            double alpha0,
                            int numClusters,
                            int thin,
                            int burnin) {
    this.sampleData = sampleData;
    this.modelFactory = modelFactory;
    this.thin = thin;
    this.burnin = burnin;
    this.numClusters = numClusters;
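    // initialize the clustering state, which samples numClusters models from the prior distribution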
    state = new DirichletState(modelFactory, numClusters, alpha0);
  }

  /**
   * This constructor is only used by DirichletClusterMapper to set up clustering parameters
   * @param emitMostLikely true to emit only the most likely cluster for each point
   * @param threshold the minimum normalized pdf for emitting a point to a cluster when emitMostLikely is false
   */
  public DirichletClusterer(boolean emitMostLikely, double threshold) {
    this.sampleData = null;
    this.modelFactory = null;
    this.thin = 0;
    this.burnin = 0;
    this.numClusters = 0;
    this.state = null;
    this.emitMostLikely = emitMostLikely;
    this.threshold = threshold;
  }

  /**
   * This constructor is used by DirichletMapper and DirichletReducer to set up their clusterer
   * @param state the DirichletState to use
   */
  public DirichletClusterer(DirichletState state) {
    this.state = state;
    this.modelFactory = state.getModelFactory();
    this.sampleData = null;
    this.numClusters = state.getNumClusters();
    this.thin = 0;
    this.burnin = 0;
  }

  /**
   * Iterate over the sample data, obtaining cluster samples periodically and returning them.
   *
   * @param numIterations
   *          the int number of iterations to perform
   * @return the List of Cluster[] samples collected during the iterations
   */
  public List<Cluster[]> cluster(int numIterations) {
    for (int iteration = 0; iteration < numIterations; iteration++) {
      iterate(iteration);
    }
    return clusterSamples;
  }

  /**
   * Perform one iteration of the clustering process, iterating over the samples to build a new array of
   * models, then updating the state for the next iteration
   */
  private void iterate(int iteration) {

    // create new posterior models
    Cluster[] newModels = (Cluster[]) modelFactory.sampleFromPosterior(state.getModels());

    // iterate over the samples, assigning each to a model
    for (VectorWritable observation : sampleData) {
      observe(newModels, observation);
    }

    // after the burn-in period, retain every thin-th set of models as a cluster sample (assumes thin > 0)
    if (iteration >= burnin && iteration % thin == 0) {
      clusterSamples.add(newModels);
    }
    // update the state from the new models
    state.update(newModels);
  }

  /**
   * Assign the observation to a sampled model and have that model observe it
   *
   * @param newModels the Model[] being built during this iteration
   * @param observation the VectorWritable observation to assign
   */
  protected void observe(Model<VectorWritable>[] newModels, VectorWritable observation) {
    int k = assignToModel(observation);
    // ask the selected model to observe the datum
    newModels[k].observe(observation);
  }

  /**
   * Assign the observation to one of the models by sampling from the mixture probabilities
   * @param observation the VectorWritable observation to assign
   * @return the int index of the assigned model
   */
  protected int assignToModel(VectorWritable observation) {
    // compute an unnormalized vector of probabilities that x is described by each model
    Vector pi = new DenseVector(numClusters);
    for (int k1 = 0; k1 < numClusters; k1++) {
      pi.set(k1, state.adjustedProbability(observation, k1));
    }
    // then pick one cluster by sampling a Multinomial distribution based upon them
    // see: http://en.wikipedia.org/wiki/Multinomial_distribution
    return UncommonDistributions.rMultinom(pi);
  }

  protected void updateModels(Cluster[] newModels) {
    state.update(newModels);
  }

  protected Model<VectorWritable>[] samplePosteriorModels() {
    return state.getModelFactory().sampleFromPosterior(state.getModels());
  }

  protected DirichletCluster updateCluster(Cluster model, int k) {
    model.computeParameters();
    DirichletCluster cluster = state.getClusters().get(k);
    cluster.setModel(model);
    return cluster;
  }

  /**
   * Emit the point to one or more clusters depending upon clusterer state
   *
   * @param vector a VectorWritable holding the Vector
   * @param clusters a List of DirichletClusters
   * @param context a Mapper.Context to emit to
   */
  public void emitPointToClusters(VectorWritable vector,
                                  List<DirichletCluster> clusters,
                                  Mapper<?,?,IntWritable,WeightedVectorWritable>.Context context)
    throws IOException, InterruptedException {
    Vector pi = new DenseVector(clusters.size());
    for (int i = 0; i < clusters.size(); i++) {
      pi.set(i, clusters.get(i).getModel().pdf(vector));
    }
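    // normalize so the per-cluster probabilities sum to 1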
    pi = pi.divide(pi.zSum());
    if (emitMostLikely) {
      emitMostLikelyCluster(vector, clusters, pi, context);
    } else {
      emitAllClusters(vector, clusters, pi, context);
    }
  }

  /**
   * Emit the point to the most likely cluster
   *
   * @param point a VectorWritable holding the Vector
   * @param clusters a List of DirichletClusters
   * @param pi the normalized pdf Vector for the point
   * @param context a Mapper.Context to emit to
   */
  private void emitMostLikelyCluster(VectorWritable point,
                                     Collection<DirichletCluster> clusters,
                                     Vector pi,
                                     Mapper<?,?,IntWritable,WeightedVectorWritable>.Context context)
    throws IOException, InterruptedException {
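    // track the highest-probability cluster; clusterId remains -1 only if every normalized pdf is zero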
    int clusterId = -1;
    double clusterPdf = 0;
    for (int i = 0; i < clusters.size(); i++) {
      double pdf = pi.get(i);
      if (pdf > clusterPdf) {
        clusterId = i;
        clusterPdf = pdf;
      }
    }
    context.write(new IntWritable(clusterId), new WeightedVectorWritable(clusterPdf, point.get()));
  }

  /**
   * Emit the point to each cluster whose normalized pdf exceeds the threshold and which has points assigned
   * @param point a VectorWritable holding the Vector
   * @param clusters a List of DirichletClusters
   * @param pi the normalized pdf Vector for the point
   * @param context a Mapper.Context to emit to
   */
  private void emitAllClusters(VectorWritable point,
                               List<DirichletCluster> clusters,
                               Vector pi,
                               Mapper<?,?,IntWritable,WeightedVectorWritable>.Context context)
    throws IOException, InterruptedException {
    for (int i = 0; i < clusters.size(); i++) {
      double pdf = pi.get(i);
      if (pdf > threshold && clusters.get(i).getTotalCount() > 0) {
        context.write(new IntWritable(i), new WeightedVectorWritable(pdf, point.get()));
      }
    }
  }

  /**
   * Emit the point to one or more clusters depending upon clusterer state
   *
   * @param vector a VectorWritable holding the Vector
   * @param clusters a List of DirichletClusters
   * @param writer a SequenceFile.Writer to emit to
   */
  public void emitPointToClusters(VectorWritable vector, List<DirichletCluster> clusters, Writer writer)
    throws IOException {
    Vector pi = new DenseVector(clusters.size());
    for (int i = 0; i < clusters.size(); i++) {
      double pdf = clusters.get(i).getModel().pdf(vector);
      pi.set(i, pdf);
    }
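    // normalize so the per-cluster probabilities sum to 1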
    pi = pi.divide(pi.zSum());
    if (emitMostLikely) {
      emitMostLikelyCluster(vector, clusters, pi, writer);
    } else {
      emitAllClusters(vector, clusters, pi, writer);
    }
  }

  /**
   * Emit the point to each cluster whose normalized pdf exceeds the threshold and which has points assigned
   *
   * @param vector a VectorWritable holding the Vector
   * @param clusters a List of DirichletClusters
   * @param pi the normalized pdf Vector for the point
   * @param writer a SequenceFile.Writer to emit to
   */
  private void emitAllClusters(VectorWritable vector, List<DirichletCluster> clusters, Vector pi, Writer writer)
    throws IOException {
    for (int i = 0; i < clusters.size(); i++) {
      double pdf = pi.get(i);
      if (pdf > threshold && clusters.get(i).getTotalCount() > 0) {
        writer.append(new IntWritable(i), new WeightedVectorWritable(pdf, vector.get()));
      }
    }
  }

  /**
   * Emit the point to the most likely cluster
   *
   * @param vector a VectorWritable holding the Vector
   * @param clusters a List of DirichletClusters
   * @param pi the normalized pdf Vector for the point
   * @param writer a SequenceFile.Writer to emit to
   */
  private static void emitMostLikelyCluster(VectorWritable vector,
                                            Collection<DirichletCluster> clusters,
                                            Vector pi,
                                            Writer writer) throws IOException {
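    // track the highest-probability cluster; clusterId remains -1 only if every normalized pdf is zero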
    double maxPdf = 0;
    int clusterId = -1;
    for (int i = 0; i < clusters.size(); i++) {
      double pdf = pi.get(i);
      if (pdf > maxPdf) {
        maxPdf = pdf;
        clusterId = i;
      }
    }
    writer.append(new IntWritable(clusterId), new WeightedVectorWritable(maxPdf, vector.get()));
  }
}