Package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans

Source Code of de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.AbstractKMeans

package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import de.lmu.ifi.dbs.elki.algorithm.AbstractPrimitiveDistanceBasedAlgorithm;
import de.lmu.ifi.dbs.elki.data.Clustering;
import de.lmu.ifi.dbs.elki.data.NumberVector;
import de.lmu.ifi.dbs.elki.data.model.MeanModel;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.ids.DBID;
import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDistanceFunction;
import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDoubleDistanceFunction;
import de.lmu.ifi.dbs.elki.distance.distancevalue.Distance;
import de.lmu.ifi.dbs.elki.math.linearalgebra.Vector;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;

/*
This file is part of ELKI:
Environment for Developing KDD-Applications Supported by Index-Structures

Copyright (C) 2012
Ludwig-Maximilians-Universität München
Lehr- und Forschungseinheit für Datenbanksysteme
ELKI Development Team

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

/**
* Abstract base class for k-means implementations.
*
* @author Erich Schubert
*
* @param <V> Vector type
* @param <D> Distance type
*/
public abstract class AbstractKMeans<V extends NumberVector<V, ?>, D extends Distance<D>> extends AbstractPrimitiveDistanceBasedAlgorithm<NumberVector<?, ?>, D, Clustering<MeanModel<V>>> {
  /**
   * Parameter to specify the number of clusters to find, must be an integer
   * greater than 0.
   */
  public static final OptionID K_ID = OptionID.getOrCreateOptionID("kmeans.k", "The number of clusters to find.");

  /**
   * Parameter to specify the number of clusters to find, must be an integer
   * greater or equal to 0, where 0 means no limit.
   */
  public static final OptionID MAXITER_ID = OptionID.getOrCreateOptionID("kmeans.maxiter", "The maximum number of iterations to do. 0 means no limit.");

  /**
   * Parameter to specify the random generator seed.
   */
  public static final OptionID SEED_ID = OptionID.getOrCreateOptionID("kmeans.seed", "The random number generator seed.");

  /**
   * Parameter to specify the initialization method
   */
  public static final OptionID INIT_ID = OptionID.getOrCreateOptionID("kmeans.initialization", "Method to choose the initial means.");

  /**
   * Holds the value of {@link #K_ID}.
   */
  protected int k;

  /**
   * Holds the value of {@link #MAXITER_ID}.
   */
  protected int maxiter;

  /**
   * Method to choose initial means.
   */
  protected KMeansInitialization<V> initializer;

  /**
   * Constructor.
   *
   * @param distanceFunction distance function
   * @param k k parameter
   * @param maxiter Maxiter parameter
   */
  public AbstractKMeans(PrimitiveDistanceFunction<NumberVector<?, ?>, D> distanceFunction, int k, int maxiter, KMeansInitialization<V> initializer) {
    super(distanceFunction);
    this.k = k;
    this.maxiter = maxiter;
    this.initializer = initializer;
  }

  /**
   * Returns a list of clusters. The k<sup>th</sup> cluster contains the ids of
   * those FeatureVectors, that are nearest to the k<sup>th</sup> mean.
   *
   * @param relation the database to cluster
   * @param means a list of k means
   * @param clusters cluster assignment
   * @return true when the object was reassigned
   */
  protected boolean assignToNearestCluster(Relation<V> relation, List<Vector> means, List<? extends ModifiableDBIDs> clusters) {
    boolean changed = false;

    if(getDistanceFunction() instanceof PrimitiveDoubleDistanceFunction) {
      @SuppressWarnings("unchecked")
      final PrimitiveDoubleDistanceFunction<? super NumberVector<?, ?>> df = (PrimitiveDoubleDistanceFunction<? super NumberVector<?, ?>>) getDistanceFunction();
      for(DBID id : relation.iterDBIDs()) {
        double mindist = Double.POSITIVE_INFINITY;
        V fv = relation.get(id);
        int minIndex = 0;
        for(int i = 0; i < k; i++) {
          double dist = df.doubleDistance(fv, means.get(i));
          if(dist < mindist) {
            minIndex = i;
            mindist = dist;
          }
        }
        if(clusters.get(minIndex).add(id)) {
          changed = true;
          // Remove from previous cluster
          // TODO: keep a list of cluster assignments to save this search?
          for(int i = 0; i < k; i++) {
            if(i != minIndex) {
              if(clusters.get(i).remove(id)) {
                break;
              }
            }
          }
        }
      }
    }
    else {
      final PrimitiveDistanceFunction<? super NumberVector<?, ?>, D> df = getDistanceFunction();
      for(DBID id : relation.iterDBIDs()) {
        D mindist = df.getDistanceFactory().infiniteDistance();
        V fv = relation.get(id);
        int minIndex = 0;
        for(int i = 0; i < k; i++) {
          D dist = df.distance(fv, means.get(i));
          if(dist.compareTo(mindist) < 0) {
            minIndex = i;
            mindist = dist;
          }
        }
        if(clusters.get(minIndex).add(id)) {
          changed = true;
          // Remove from previous cluster
          // TODO: keep a list of cluster assignments to save this search?
          for(int i = 0; i < k; i++) {
            if(i != minIndex) {
              if(clusters.get(i).remove(id)) {
                break;
              }
            }
          }
        }
      }
    }
    return changed;
  }

  @Override
  public TypeInformation[] getInputTypeRestriction() {
    return TypeUtil.array(TypeUtil.NUMBER_VECTOR_FIELD);
  }

  /**
   * Returns the mean vectors of the given clusters in the given database.
   *
   * @param clusters the clusters to compute the means
   * @param means the recent means
   * @param database the database containing the vectors
   * @return the mean vectors of the given clusters in the given database
   */
  protected List<Vector> means(List<? extends ModifiableDBIDs> clusters, List<Vector> means, Relation<V> database) {
    List<Vector> newMeans = new ArrayList<Vector>(k);
    for(int i = 0; i < k; i++) {
      ModifiableDBIDs list = clusters.get(i);
      Vector mean = null;
      for(Iterator<DBID> clusterIter = list.iterator(); clusterIter.hasNext();) {
        if(mean == null) {
          mean = database.get(clusterIter.next()).getColumnVector();
        }
        else {
          mean.plusEquals(database.get(clusterIter.next()).getColumnVector());
        }
      }
      if(list.size() > 0) {
        assert mean != null;
        mean.timesEquals(1.0 / list.size());
      }
      else {
        mean = means.get(i);
      }
      newMeans.add(mean);
    }
    return newMeans;
  }

  /**
   * Compute an incremental update for the mean
   *
   * @param mean Mean to update
   * @param vec Object vector
   * @param newsize (New) size of cluster
   * @param op Cluster size change / Weight change
   */
  protected void incrementalUpdateMean(Vector mean, V vec, int newsize, double op) {
    if(newsize == 0) {
      return; // Keep old mean
    }
    Vector delta = vec.getColumnVector();
    // Compute difference from mean
    delta.minusEquals(mean);
    delta.timesEquals(op / newsize);
    mean.plusEquals(delta);
  }

  /**
   * Perform a MacQueen style iteration.
   *
   * @param relation Relation
   * @param means Means
   * @param clusters Clusters
   * @return true when the means have changed
   */
  protected boolean macQueenIterate(Relation<V> relation, List<Vector> means, List<ModifiableDBIDs> clusters) {
    boolean changed = false;
 
    if(getDistanceFunction() instanceof PrimitiveDoubleDistanceFunction) {
      // Raw distance function
      @SuppressWarnings("unchecked")
      final PrimitiveDoubleDistanceFunction<? super NumberVector<?, ?>> df = (PrimitiveDoubleDistanceFunction<? super NumberVector<?, ?>>) getDistanceFunction();
 
      // Incremental update
      for(DBID id : relation.iterDBIDs()) {
        double mindist = Double.POSITIVE_INFINITY;
        V fv = relation.get(id);
        int minIndex = 0;
        for(int i = 0; i < k; i++) {
          double dist = df.doubleDistance(fv, means.get(i));
          if(dist < mindist) {
            minIndex = i;
            mindist = dist;
          }
        }
        // Update the cluster mean incrementally:
        for(int i = 0; i < k; i++) {
          ModifiableDBIDs ci = clusters.get(i);
          if(i == minIndex) {
            if(ci.add(id)) {
              incrementalUpdateMean(means.get(i), relation.get(id), ci.size(), +1);
              changed = true;
            }
          }
          else if(ci.remove(id)) {
            incrementalUpdateMean(means.get(i), relation.get(id), ci.size() + 1, -1);
            changed = true;
          }
        }
      }
    }
    else {
      // Raw distance function
      final PrimitiveDistanceFunction<? super NumberVector<?, ?>, D> df = getDistanceFunction();
 
      // Incremental update
      for(DBID id : relation.iterDBIDs()) {
        D mindist = df.getDistanceFactory().infiniteDistance();
        V fv = relation.get(id);
        int minIndex = 0;
        for(int i = 0; i < k; i++) {
          D dist = df.distance(fv, means.get(i));
          if(dist.compareTo(mindist) < 0) {
            minIndex = i;
            mindist = dist;
          }
        }
        // Update the cluster mean incrementally:
        for(int i = 0; i < k; i++) {
          ModifiableDBIDs ci = clusters.get(i);
          if(i == minIndex) {
            if(ci.add(id)) {
              incrementalUpdateMean(means.get(i), relation.get(id), ci.size(), +1);
              changed = true;
            }
          }
          else if(ci.remove(id)) {
            incrementalUpdateMean(means.get(i), relation.get(id), ci.size() + 1, -1);
            changed = true;
          }
        }
      }
    }
    return changed;
  }
}
TOP

Related Classes of de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.AbstractKMeans

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.