Package org.apache.mahout.clustering.meanshift

Source Code of org.apache.mahout.clustering.meanshift.MeanShiftCanopy

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.clustering.meanshift;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.mahout.clustering.ClusterBase;
import org.apache.mahout.matrix.AbstractVector;
import org.apache.mahout.matrix.CardinalityException;
import org.apache.mahout.matrix.DenseVector;
import org.apache.mahout.matrix.JsonVectorAdapter;
import org.apache.mahout.matrix.PlusFunction;
import org.apache.mahout.matrix.Vector;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;

/**
* This class models a canopy as a center point, the number of points that are contained within it according to the
* application of some distance metric, and a point total which is the sum of all the points and is used to compute the
* centroid when needed.
*/
public class MeanShiftCanopy extends ClusterBase {

  // keys used by Driver, Mapper, Combiner & Reducer
  public static final String DISTANCE_MEASURE_KEY = "org.apache.mahout.clustering.canopy.measure";

  public static final String T1_KEY = "org.apache.mahout.clustering.canopy.t1";

  public static final String T2_KEY = "org.apache.mahout.clustering.canopy.t2";

  public static final String CONTROL_PATH_KEY = "org.apache.mahout.clustering.control.path";

  public static final String CLUSTER_CONVERGENCE_KEY = "org.apache.mahout.clustering.canopy.convergence";

  private static double convergenceDelta = 0;

  // the next canopyId to be allocated
  private static int nextCanopyId = 0;

  // the T1 distance threshold
  private static double t1;

  // the T2 distance threshold
  private static double t2;

  // the distance measure
  private static DistanceMeasure measure;

  // TODO: this is problematic, but how else to encode membership?
  private List<Vector> boundPoints = new ArrayList<Vector>();

  private boolean converged = false;

  static double getT1() {
    return t1;
  }

  static double getT2() {
    return t2;
  }

  /**
   * Configure the Canopy and its distance measure
   *
   * @param job the JobConf for this job
   */
  public static void configure(JobConf job) {
    try {
      measure = Class.forName(job.get(DISTANCE_MEASURE_KEY)).asSubclass(
          DistanceMeasure.class).newInstance();
      measure.configure(job);
    } catch (ClassNotFoundException e) {
      throw new IllegalStateException(e);
    } catch (IllegalAccessException e) {
      throw new IllegalStateException(e);
    } catch (InstantiationException e) {
      throw new IllegalStateException(e);
    }
    nextCanopyId = 0;
    t1 = Double.parseDouble(job.get(T1_KEY));
    t2 = Double.parseDouble(job.get(T2_KEY));
    convergenceDelta = Double.parseDouble(job.get(CLUSTER_CONVERGENCE_KEY));
  }

  /**
   * Configure the Canopy for unit tests
   *
   * @param aDelta the convergence criteria
   */
  public static void config(DistanceMeasure aMeasure, double aT1, double aT2,
                            double aDelta) {
    nextCanopyId = 100; // so canopyIds will sort properly
    measure = aMeasure;
    t1 = aT1;
    t2 = aT2;
    convergenceDelta = aDelta;
  }

  /**
   * Merge the given canopy into the canopies list. If it touches any existing canopy (norm<T1) then add the center of
   * each to the other. If it covers any other canopies (norm<T2), then merge the given canopy with the closest covering
   * canopy. If the given canopy does not cover any other canopies, add it to the canopies list.
   *
   * @param aCanopy  a MeanShiftCanopy to be merged
   * @param canopies the List<Canopy> to be appended
   */
  public static void mergeCanopy(MeanShiftCanopy aCanopy,
                                 List<MeanShiftCanopy> canopies) {
    MeanShiftCanopy closestCoveringCanopy = null;
    double closestNorm = Double.MAX_VALUE;
    for (MeanShiftCanopy canopy : canopies) {
      double norm = measure.distance(canopy.getCenter(), aCanopy.getCenter());
      if (norm < t1) {
        aCanopy.touch(canopy);
      }
      if (norm < t2) {
        if (closestCoveringCanopy == null || norm < closestNorm) {
          closestNorm = norm;
          closestCoveringCanopy = canopy;
        }
      }
    }
    if (closestCoveringCanopy == null) {
      canopies.add(aCanopy);
    } else {
      closestCoveringCanopy.merge(aCanopy);
    }
  }

  /** Format the canopy for output */
  public static String formatCanopy(MeanShiftCanopy canopy) {
    Type vectorType = new TypeToken<Vector>() {
    }.getType();
    GsonBuilder gBuilder = new GsonBuilder();
    gBuilder.registerTypeAdapter(vectorType, new JsonVectorAdapter());
    Gson gson = gBuilder.create();
    return gson.toJson(canopy, MeanShiftCanopy.class);
  }

  /**
   * Decodes and returns a Canopy from the formattedString
   *
   * @param formattedString a String produced by formatCanopy
   * @return a new Canopy
   */
  public static MeanShiftCanopy decodeCanopy(String formattedString) {
    Type vectorType = new TypeToken<Vector>() {
    }.getType();
    GsonBuilder gBuilder = new GsonBuilder();
    gBuilder.registerTypeAdapter(vectorType, new JsonVectorAdapter());
    Gson gson = gBuilder.create();
    return gson.fromJson(formattedString, MeanShiftCanopy.class);
  }

  public MeanShiftCanopy() {
    super();
  }

  /** Create a new Canopy with the given canopyId */
  public MeanShiftCanopy(String id) {
    this.setId(Integer.parseInt(id.substring(1)));
    this.setCenter(null);
    this.setPointTotal(null);
    this.setNumPoints(0);
  }

  /**
   * Create a new Canopy containing the given point
   *
   * @param point a Vector
   */
  public MeanShiftCanopy(Vector point) {
    this.setId(nextCanopyId++);
    this.setCenter(point);
    this.setPointTotal(point.clone());
    this.setNumPoints(1);
    this.boundPoints.add(point);
  }

  /**
   * Create a new Canopy containing the given point, id and bound points
   *
   * @param point       a Vector
   * @param id          an int identifying the canopy local to this process only
   * @param boundPoints a List<Vector> containing points bound to the canopy
   * @param converged   true if the canopy has converged
   */
  MeanShiftCanopy(Vector point, int id, List<Vector> boundPoints,
                  boolean converged) {
    this.setId(id);
    this.setCenter(point);
    this.setPointTotal(point.clone());
    this.setNumPoints(1);
    this.boundPoints = boundPoints;
    this.converged = converged;
  }

  /**
   * Add a point to the canopy some number of times
   *
   * @param point   a Vector to add
   * @param nPoints the number of times to add the point
   * @throws CardinalityException if the cardinalities disagree
   */
  void addPoints(Vector point, int nPoints) {
    setNumPoints(getNumPoints() + nPoints);
    Vector subTotal = (nPoints == 1) ? point.clone() : point.times(nPoints);
    setPointTotal((getPointTotal() == null) ? subTotal : getPointTotal().plus(subTotal));
  }

  /**
   * Return if the point is closely covered by this canopy
   *
   * @param point a Vector point
   * @return if the point is covered
   */
  public boolean closelyBound(Vector point) {
    return measure.distance(getCenter(), point) < t2;
  }

  /**
   * Compute the bound centroid by averaging the bound points
   *
   * @return a Vector which is the new bound centroid
   */
  public Vector computeBoundCentroid() {
    Vector result = new DenseVector(getCenter().size());
    for (Vector v : boundPoints) {
      result.assign(v, new PlusFunction());
    }
    return result.divide(boundPoints.size());
  }

  /**
   * Compute the centroid by normalizing the pointTotal
   *
   * @return a Vector which is the new centroid
   */
  public Vector computeCentroid() {
    if (getNumPoints() == 0) {
      return getCenter();
    } else {
      return getPointTotal().divide(getNumPoints());
    }
  }

  /**
   * Return if the point is covered by this canopy
   *
   * @param point a Vector point
   * @return if the point is covered
   */
  boolean covers(Vector point) {
    return measure.distance(getCenter(), point) < t1;
  }

  /** Emit the new canopy to the collector, keyed by the canopy's Id */
  void emitCanopy(MeanShiftCanopy canopy,
                  OutputCollector<Text, WritableComparable<?>> collector)
      throws IOException {
    String identifier = this.getIdentifier();
    collector.collect(new Text(identifier),
        new Text("new " + canopy.toString()));
  }

  public List<Vector> getBoundPoints() {
    return boundPoints;
  }

  public int getCanopyId() {
    return getId();
  }

  public String getIdentifier() {
    return converged ? "V" + getId() : "C" + getId();
  }

  void init(MeanShiftCanopy canopy) {
    setId(canopy.getId());
    setCenter(canopy.getCenter());
    addPoints(getCenter(), 1);
    boundPoints.addAll(canopy.getBoundPoints());
  }

  public boolean isConverged() {
    return converged;
  }

  /**
   * The receiver overlaps the given canopy. Touch it and add my bound points to it.
   *
   * @param canopy an existing MeanShiftCanopy
   */
  void merge(MeanShiftCanopy canopy) {
    boundPoints.addAll(canopy.boundPoints);
  }

  /**
   * Shift the center to the new centroid of the cluster
   *
   * @return if the cluster is converged
   */
  public boolean shiftToMean() {
    Vector centroid = computeCentroid();
    converged = new EuclideanDistanceMeasure().distance(centroid, getCenter()) < convergenceDelta;
    setCenter(centroid);
    setNumPoints(1);
    setPointTotal(centroid.clone());
    return converged;
  }

  @Override
  public String toString() {
    return formatCanopy(this);
  }

  /**
   * The receiver touches the given canopy. Add respective centers.
   *
   * @param canopy an existing MeanShiftCanopy
   */
  void touch(MeanShiftCanopy canopy) {
    canopy.addPoints(getCenter(), boundPoints.size());
    addPoints(canopy.getCenter(), canopy.boundPoints.size());
  }

  @Override
  public void readFields(DataInput in) throws IOException {
    super.readFields(in);
    this.setCenter(AbstractVector.readVector(in));
    int numpoints = in.readInt();
    this.boundPoints = new ArrayList<Vector>();
    for (int i = 0; i < numpoints; i++) {
      this.boundPoints.add(AbstractVector.readVector(in));
    }
  }

  @Override
  public void write(DataOutput out) throws IOException {
    super.write(out);
    AbstractVector.writeVector(out, computeCentroid());
    out.writeInt(boundPoints.size());
    for (Vector v : boundPoints) {
      AbstractVector.writeVector(out, v);
    }
  }

  public MeanShiftCanopy shallowCopy() {
    MeanShiftCanopy result = new MeanShiftCanopy();
    result.setId(this.getId());
    result.setCenter(this.getCenter());
    result.setPointTotal(this.getPointTotal());
    result.setNumPoints(this.getNumPoints());
    result.boundPoints = this.boundPoints;
    return result;
  }

  @Override
  public String asFormatString() {
    return formatCanopy(this);
  }
}
TOP

Related Classes of org.apache.mahout.clustering.meanshift.MeanShiftCanopy

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.