package cc.mallet.cluster.iterator;
import cc.mallet.cluster.Clustering;
import cc.mallet.cluster.neighbor_evaluator.AgglomerativeNeighbor;
import cc.mallet.cluster.util.ClusterUtils;
import cc.mallet.types.Instance;
import cc.mallet.util.Randoms;
/**
* Samples merges of a singleton cluster with another (possibly
* non-singleton) cluster.
*
* @author "Aron Culotta" <culotta@degas.cs.umass.edu>
* @version 1.0
* @since 1.0
* @see PairSampleIterator
* @see NeighborIterator
*/
public class NodeClusterSampleIterator extends ClusterSampleIterator {

  /**
   * Creates an iterator that samples merges of a single node with a cluster.
   *
   * @param clustering True clustering.
   * @param random Source of randomness.
   * @param positiveProportion Proportion of Instances that should be positive examples.
   * @param numberSamples Total number of samples to generate.
   */
  public NodeClusterSampleIterator (Clustering clustering,
                                    Randoms random,
                                    double positiveProportion,
                                    int numberSamples) {
    super(clustering, random, positiveProportion, numberSamples);
    // NOTE(review): these inherited fields are presumably already initialized
    // by the superclass constructor; the assignments are kept so behavior is
    // unchanged. Confirm against ClusterSampleIterator before removing.
    this.random = random;
    this.positiveProportion = positiveProportion;
    this.numberSamples = numberSamples;
  }

  /**
   * Generates the next sampled merge.
   *
   * <p>While <code>positiveCount</code> is below <code>positiveTarget</code>
   * and at least one non-singleton cluster exists, a positive example is
   * produced: a subset of one non-singleton cluster is sampled and split into
   * a single node and the remaining nodes. Otherwise a negative example is
   * produced: nodes are drawn from two distinct clusters, and the proposed
   * clustering is a copy of the original with those two clusters merged.
   *
   * @return an Instance constructed with the sampled AgglomerativeNeighbor as
   *         its first argument and <code>null</code> for the remaining three.
   * @throws IllegalStateException if a negative example is required but the
   *         clustering has fewer than two clusters.
   */
  public Instance next () {
    AgglomerativeNeighbor neighbor = null;

    if (positiveCount < positiveTarget && nonsingletonClusters.length > 0) { // Sample positive.
      positiveCount++;
      int label = nonsingletonClusters[random.nextInt(nonsingletonClusters.length)];

      int[] instances = clustering.getIndicesWithLabel(label);
      // Presumably a random subset containing at least 2 members of the
      // cluster -- confirm against sampleFromArray in the superclass.
      int[] subcluster = sampleFromArray(instances, random, 2);
      int[] cluster1 = new int[]{subcluster[random.nextInt(subcluster.length)]}; // Singleton.
      int[] cluster2 = new int[subcluster.length - 1];
      // Everything except the chosen node goes into cluster2. The filter
      // compares values, so it relies on the sampled indices being distinct.
      int nadded = 0;
      for (int i = 0; i < subcluster.length; i++) {
        if (subcluster[i] != cluster1[0]) {
          cluster2[nadded++] = subcluster[i];
        }
      }

      // Both nodes already share a cluster, so the same clustering object is
      // passed as the original and as the proposed clustering.
      neighbor = new AgglomerativeNeighbor(clustering,
                                           clustering,
                                           cluster1,
                                           cluster2);
    } else { // Sample negative.
      int numClusters = clustering.getNumClusters();
      // Two distinct labels are needed. With a single cluster the rejection
      // loop below could never terminate, so fail fast with a clear message.
      if (numClusters < 2) {
        throw new IllegalStateException(
            "Cannot sample a negative example: clustering has " + numClusters
            + " cluster(s), but at least 2 are required.");
      }

      int labeli = random.nextInt(numClusters);
      int labelj = random.nextInt(numClusters);
      // Redraw until the second label differs from the first.
      while (labeli == labelj) {
        labelj = random.nextInt(numClusters);
      }

      int[] ii = sampleFromArray(clustering.getIndicesWithLabel(labeli), random, 1);
      int[] ij = sampleFromArray(clustering.getIndicesWithLabel(labelj), random, 1);

      // The sample from labeli is used as the cluster; a single node is
      // picked from the sample drawn from labelj.
      neighbor =
        new AgglomerativeNeighbor(clustering,
                                  ClusterUtils.copyAndMergeClusters(clustering,
                                                                    labeli,
                                                                    labelj),
                                  ii,
                                  new int[]{ij[random.nextInt(ij.length)]});
    }

    totalCount++;
    return new Instance(neighbor, null, null, null);
  }
}