package cc.mallet.cluster.examples;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.classify.evaluate.ConfusionMatrix;
import cc.mallet.cluster.Clusterer;
import cc.mallet.cluster.Clustering;
import cc.mallet.cluster.GreedyAgglomerativeByDensity;
import cc.mallet.cluster.evaluate.AccuracyEvaluator;
import cc.mallet.cluster.evaluate.BCubedEvaluator;
import cc.mallet.cluster.evaluate.ClusteringEvaluator;
import cc.mallet.cluster.evaluate.ClusteringEvaluators;
import cc.mallet.cluster.evaluate.MUCEvaluator;
import cc.mallet.cluster.evaluate.PairF1Evaluator;
import cc.mallet.cluster.iterator.ClusterSampleIterator;
import cc.mallet.cluster.neighbor_evaluator.AgglomerativeNeighbor;
import cc.mallet.cluster.neighbor_evaluator.ClassifyingNeighborEvaluator;
import cc.mallet.cluster.util.ClusterUtils;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.InfoGain;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.util.PropertyList;
import cc.mallet.util.Randoms;
/**
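* Illustrates the use of a supervised clustering method that uses
* features over clusters. Synthetic data is created in which two
* Instances belong to the same cluster if they both have a feature
* called "feature0"; with a small probability, pairs in which neither
* Instance has that feature are also merged as random noise.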
*
* @author "Aron Culotta" <culotta@degas.cs.umass.edu>
* @version 1.0
* @since 1.0
*/
public class FirstOrderClusterExample {

  /** Name of the feature that decides true cluster membership in the synthetic data. */
  private static final String KEY_FEATURE = "feature0";

  /** Random source shared by data generation and training-sample selection; seeded in the constructor. */
  Randoms random;

  /** Probability of merging a pair of instances in which neither instance has {@code "feature0"}. */
  double noise;

  /**
   * Creates the example with a fixed random seed and a noise probability of 0.01.
   */
  public FirstOrderClusterExample () {
    this.random = new Randoms(123456789);
    this.noise = 0.01;
  }

  /**
   * Runs the full example: samples a training clustering, trains a
   * maximum-entropy classifier over candidate merges, clusters a freshly
   * sampled test set with it, and prints diagnostics and evaluation
   * results to standard error.
   */
  public void run () {
    Alphabet alphabet = dictOfSize(20);

    // TRAIN
    Clustering training = sampleClustering(alphabet);
    Pipe clusterPipe = new OverlappingFeaturePipe();
    System.err.println("Training with " + training);
    InstanceList trainList = new InstanceList(clusterPipe);
    // NOTE(review): 0.5 and 100 are passed straight to ClusterSampleIterator;
    // presumably the proportion of positive samples and the number of
    // samples -- confirm against that class.
    trainList.addThruPipe(new ClusterSampleIterator(training, random, 0.5, 100));
    System.err.println("Created " + trainList.size() + " instances.");
    Classifier me = new MaxEntTrainer().train(trainList);
    // "YES" is the label OverlappingFeaturePipe assigns to merges whose
    // members all share one true cluster.
    ClassifyingNeighborEvaluator eval =
      new ClassifyingNeighborEvaluator(me, "YES");

    // Diagnostics on the training data itself.
    Trial trial = new Trial(me, trainList);
    System.err.println(new ConfusionMatrix(trial));
    InfoGain ig = new InfoGain(trainList);
    ig.print();

    // An earlier variant used
    // GreedyAgglomerative(training.getInstances().getPipe(), eval, 0.5) here.
    Clusterer clusterer = new GreedyAgglomerativeByDensity(training.getInstances().getPipe(),
                                                           eval, 0.5, false,
                                                           new java.util.Random(1));

    // TEST
    Clustering testing = sampleClustering(alphabet);
    InstanceList testList = testing.getInstances();
    Clustering predictedClusters = clusterer.cluster(testList);

    // EVALUATE
    System.err.println("\n\nEvaluating System: " + clusterer);
    ClusteringEvaluators evaluators = new ClusteringEvaluators(new ClusteringEvaluator[]{
        new BCubedEvaluator(),
        new PairF1Evaluator(),
        new MUCEvaluator(),
        new AccuracyEvaluator()});
    System.err.println("truth:" + testing);
    System.err.println("pred: " + predictedClusters);
    System.err.println(evaluators.evaluate(testing, predictedClusters));
  }

  /**
   * Samples an InstanceList and its true clustering.
   *
   * <p>Starting from singleton clusters, two instances are merged when
   * both contain {@code "feature0"}. When neither contains it, they are
   * merged with probability {@link #noise}.
   *
   * @param alph the feature alphabet used to generate the random instances
   * @return the true clustering over the first 20 sampled instances
   */
  private Clustering sampleClustering (Alphabet alph) {
    InstanceList instances =
      new InstanceList(random,
                       alph,
                       new String[]{"foo", "bar"},
                       30).subList(0, 20);
    Clustering singletons = ClusterUtils.createSingletonClustering(instances);

    for (int i = 0; i < instances.size(); i++) {
      FeatureVector fvi = (FeatureVector)instances.get(i).getData();
      // Loop-invariant for the inner loop, so look it up once.
      boolean iHasKey = fvi.contains(KEY_FEATURE);
      for (int j = i + 1; j < instances.size(); j++) {
        FeatureVector fvj = (FeatureVector)instances.get(j).getData();
        boolean jHasKey = fvj.contains(KEY_FEATURE);
        // Short-circuit order matters: the random draw must happen only
        // when neither instance has the key feature, so the sequence of
        // random numbers consumed stays the same from run to run.
        boolean merge = (iHasKey && jHasKey)
          || (!iHasKey && !jHasKey && random.nextUniform() < noise);
        if (merge) {
          // Labels are re-read on every merge because earlier merges
          // may have relabelled the instances.
          singletons = ClusterUtils.mergeClusters(singletons,
                                                  singletons.getLabel(i),
                                                  singletons.getLabel(j));
        }
      }
    }
    return singletons;
  }

  /**
   * Builds an alphabet containing the entries {@code "feature0"} through
   * {@code "feature" + (size - 1)}.
   *
   * @param size number of feature names to add
   * @return the populated alphabet
   */
  private Alphabet dictOfSize (int size) {
    Alphabet ret = new Alphabet ();
    for (int i = 0; i < size; i++) {
      ret.lookupIndex ("feature"+i);
    }
    return ret;
  }

  /**
   * Converts an {@link AgglomerativeNeighbor} into a classification
   * instance. The data becomes a single-feature vector: "Match" when
   * every pair of merged members that come from different clusters of
   * the original clustering both have a feature named "feature0", and
   * "NoMatch" otherwise. The target becomes "YES" when all merged
   * members share one label in the original clustering, else "NO".
   *
   * <p>Declared static because it uses no state of the enclosing
   * example; a non-static inner class would also carry a hidden
   * reference to the (non-serializable) enclosing instance.
   *
   * @author "Aron Culotta" <culotta@degas.cs.umass.edu>
   * @version 1.0
   * @since 1.0
   * @see Pipe
   */
  private static class OverlappingFeaturePipe extends Pipe {

    private static final long serialVersionUID = 1L;

    public OverlappingFeaturePipe () {
      super (new Alphabet(), new LabelAlphabet());
    }

    /**
     * Replaces the carrier's data with the Match/NoMatch feature vector
     * and its target with the YES/NO label.
     *
     * @param carrier instance whose data is an {@link AgglomerativeNeighbor}
     * @return the same instance, with data and target replaced
     */
    @Override
    public Instance pipe (Instance carrier) {
      AgglomerativeNeighbor neighbor = (AgglomerativeNeighbor)carrier.getData();
      Clustering original = neighbor.getOriginal();
      int[] mergedIndices = neighbor.getNewCluster();

      String featureName =
        crossClusterPairsShareKeyFeature(original, mergedIndices) ? "Match" : "NoMatch";
      PropertyList pl = null;
      pl = PropertyList.add(featureName, 1.0, pl);
      FeatureVector fv = new FeatureVector ((Alphabet)getDataAlphabet(),
                                            pl, true);
      carrier.setData(fv);

      LabelAlphabet ldict = (LabelAlphabet)getTargetAlphabet();
      String label = allInOneCluster(original, mergedIndices) ? "YES" : "NO";
      carrier.setTarget(ldict.lookupLabel(label));
      return carrier;
    }

    /**
     * Returns true unless some pair of members with different labels in
     * {@code original} has a member lacking the key feature.
     */
    private static boolean crossClusterPairsShareKeyFeature (Clustering original,
                                                             int[] members) {
      InstanceList list = original.getInstances();
      for (int i = 0; i < members.length; i++) {
        for (int j = i + 1; j < members.length; j++) {
          if (original.getLabel(members[i]) == original.getLabel(members[j])) {
            continue; // Same original cluster: this pair is not checked.
          }
          FeatureVector fvi = (FeatureVector)list.get(members[i]).getData();
          FeatureVector fvj = (FeatureVector)list.get(members[j]).getData();
          if (!(fvi.contains(KEY_FEATURE) && fvj.contains(KEY_FEATURE))) {
            return false; // One counter-example settles the answer.
          }
        }
      }
      return true;
    }

    /**
     * Returns true when every member has the same label in
     * {@code original}; trivially true for an empty array.
     */
    private static boolean allInOneCluster (Clustering original, int[] members) {
      if (members.length == 0) {
        return true;
      }
      // Comparing against the first member is equivalent to comparing all pairs.
      int firstLabel = original.getLabel(members[0]);
      for (int i = 1; i < members.length; i++) {
        if (original.getLabel(members[i]) != firstLabel) {
          return false;
        }
      }
      return true;
    }
  }

  /**
   * Entry point: constructs the example and runs it.
   *
   * @param args ignored
   */
  public static void main (String[] args) {
    FirstOrderClusterExample ex = new FirstOrderClusterExample();
    ex.run();
  }
}