center.update(row);
}
// With the centroid c in hand, accumulate \Delta_1^2(X) = \sum_x || x - c ||^2, the total
// squared distance of the datapoints from their centroid. Computing it once up front lets the
// first seed's selection weight be formed in a single pass over the data.
DistanceMeasure l2 = new SquaredEuclideanDistanceMeasure();
double radius = 0;
for (WeightedVector point : datapoints) {
  radius += l2.distance(point, center);
}
// Draw the first seed c_1 (and, conceptually, the second seed c_2) the way 2-means seeding
// would: the pair (c_1, c_2) should be chosen with probability proportional to
// || c_1 - c_2 ||^2. Summing over c_2, the first seed is picked with probability
//
//   p(c_1) = \sum_{c_2} || c_1 - c_2 ||^2 / \sum_{c_1, c_2} || c_1 - c_2 ||^2
//
// which simplifies to
//
//   p(c_1) = (\Delta_1^2(X) + n || c_1 - c ||^2) / (2 n \Delta_1^2(X))
//
// where n is the number of datapoints, c = \sum x / n and \Delta_1^2(X) = \sum || x - c ||^2.
// The denominator is identical for every candidate, so the loop below records only the
// numerator as each candidate's weight. Candidates are keyed by their index in datapoints.
//
// Every later seed c_i (c_2 included) is then drawn from the remaining points with probability
// proportional to min_{m < i} || c_m - x_j ||^2.
Multinomial<Integer> seedSelector = new Multinomial<Integer>();
for (int candidate = 0, numPoints = datapoints.size(); candidate < numPoints; ++candidate) {
  double unnormalized = radius + numPoints * l2.distance(datapoints.get(candidate), center);
  seedSelector.add(candidate, unnormalized);
}
// c_1 is built from a clone of the sampled datapoint.
Centroid c_1 = new Centroid((WeightedVector) datapoints.get(seedSelector.sample()).clone());
c_1.setIndex(0);
// Repurpose seedSelector for the k-means++ draws that follow. Each point's selection weight is
// overwritten with its squared distance to c_1, multiplied by the point's own weight. The
// datapoint c_1 was cloned from is at distance 0, so its weight becomes 0.
for (int candidate = 0; candidate < datapoints.size(); candidate++) {
  WeightedVector point = datapoints.get(candidate);
  double weightedDistance = l2.distance(c_1, point) * point.getWeight();
  seedSelector.set(candidate, weightedDistance);
}
// From here on, each new seed is drawn with probability proportional to
//
//   r_i = w_i * min_{c_j} || x_i - c_j ||^2
//
// where w_i is the weight of datapoint x_i and c_j ranges over the seeds chosen so far.
// With only c_1 chosen, these values are already stored in seedSelector. After each new seed
// is picked, a point's entry is lowered to its weighted squared distance to that seed
// whenever that value is smaller than the current entry.
centroids.add(c_1);
int clusterIndex = 1;
while (centroids.size() < numClusters) {
  // Draw the next seed according to the current weights.
  int seedIndex = seedSelector.sample();
  Centroid nextSeed = new Centroid((WeightedVector) datapoints.get(seedIndex).clone());
  nextSeed.setIndex(clusterIndex++);
  centroids.add(nextSeed);
  // Zero the chosen point's weight so it is never drawn again.
  // NOTE(review): if there are fewer than numClusters distinct points, every weight
  // eventually reaches 0; how Multinomial.sample() behaves then is not visible here -- confirm.
  seedSelector.set(seedIndex, 0);
  // Re-weight every point by its minimum distance to any seed. Scale by the candidate point's
  // own weight, matching the c_1 initialization, so that all entries share one scale and the
  // comparison below is a true minimum. The previous code scaled by the seed's weight instead.
  for (int currSeedIndex : seedSelector) {
    WeightedVector curr = datapoints.get(currSeedIndex);
    double newWeight = l2.distance(nextSeed, curr) * curr.getWeight();
    if (newWeight < seedSelector.getWeight(currSeedIndex)) {
      seedSelector.set(currSeedIndex, newWeight);
    }
  }
}