package cc.mallet.cluster.neighbor_evaluator;
//import weka.core.Instances;
import cc.mallet.classify.Classifier;
import cc.mallet.cluster.Clustering;
import cc.mallet.cluster.util.PairwiseMatrix;
import cc.mallet.types.MatrixOps;
/**
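* Uses a {@link Classifier} over pairs of instances to score
* {@link Neighbor}s, either by the single link between the medoids of
* the two clusters being merged or by an average of pairwise scores
* weighted by each instance's similarity to its cluster's medoid.
* Currently only supports {@link AgglomerativeNeighbor}s.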
*
* @author "Michael Wick" <mwick@cs.umass.edu>
* @version 1.0
* @since 1.0
* @see ClassifyingNeighborEvaluator
*/
public class MedoidEvaluator extends ClassifyingNeighborEvaluator {
private static final long serialVersionUID = 1L;
/**
* If single link is true, then the score of clusters A and B is the score of the link between the two medoids.
*/
boolean singleLink=false;
/**
* How to combine a set of pairwise scores (e.g. mean, max, ...)... [currently not supported in this class]
*/
CombiningStrategy combiningStrategy;
/**
* If true, score all edges involved in a merge. If false, only
* score the edges that croess the boundaries of the clusters being
* merged.
*/
boolean mergeFirst=true;
/**
* Cache for calls to getScore. In some experiments, reduced running
* time by nearly half.
*/
PairwiseMatrix scoreCache;
/**
*
* @param classifier Classifier to assign scores to {@link
* Neighbor}s for which a pair of Instances has been merged.
* @param scoringLabel The predicted label that corresponds to a
* positive example (e.g. "YES").
* @param combiningStrategy How to combine the pairwise scores
* (e.g. max, mean, ...).
* @param mergeFirst If true, score all edges involved in a
* merge. If false, only score the edges that cross the boundaries
* of the clusters being merged.
* @return
*/
public MedoidEvaluator(Classifier classifier, String scoringLabel)
{
super(classifier,scoringLabel);
System.out.println("Using Medoid Evaluator");
}
public MedoidEvaluator(Classifier classifier, String scoringLabel,boolean singleLink,boolean mergeFirst)
{
super(classifier,scoringLabel);
this.singleLink=singleLink;
this.mergeFirst=mergeFirst;
System.out.println("Using Medoid Evaluator. Single link="+singleLink+".");
}
/*
public MedoidEvaluator (Classifier classifier,
String scoringLabel,
CombiningStrategy combiningStrategy,
boolean mergeFirst) {
super(classifier, scoringLabel);
this.combiningStrategy = combiningStrategy;
this.mergeFirst = mergeFirst;
System.out.println("Using Centroid Evaluator (2)");
}
*/
public double[] evaluate (Neighbor[] neighbors) {
double[] scores = new double[neighbors.length];
for (int i = 0; i < neighbors.length; i++)
scores[i] = evaluate(neighbors[i]);
return scores;
}
public double evaluate(Neighbor neighbor)
{
int result[] = new int[2];
if (!(neighbor instanceof AgglomerativeNeighbor))
throw new IllegalArgumentException("Expect AgglomerativeNeighbor not " + neighbor.getClass().getName());
int[][] oldIndices = ((AgglomerativeNeighbor)neighbor).getOldClusters();
int[] mergedIndices=((AgglomerativeNeighbor)neighbor).getNewCluster();
Clustering original = neighbor.getOriginal();
result[0]=getCentroid(oldIndices[0],original);
result[1]=getCentroid(oldIndices[1],original);
if(singleLink) //scores a cluster based on link between medoid of each cluster
{
AgglomerativeNeighbor pwn = new AgglomerativeNeighbor(original,original,oldIndices[0][result[0]],oldIndices[1][result[1]]);
double score = getScore(pwn);
return score;
}
//
//Returns average weighted average where weights are proportional to similarity to medoid
double[] medsA=getMedWeights(result[0],oldIndices[0],original);
double[] medsB=getMedWeights(result[1],oldIndices[1],original);
double numerator=0;
double denominator=0;
for(int i=0;i<oldIndices[0].length;i++)
{
//
//cross-boundary
for(int j=0;j<oldIndices[1].length;j++)
{
AgglomerativeNeighbor pwn = new AgglomerativeNeighbor(original,original,oldIndices[0][i],oldIndices[1][j]);
double interScore=getScore(pwn);
numerator+=interScore*medsA[i]*medsB[j];
denominator+=medsA[i]*medsB[j];
}
//
//intra-cluster1
if(mergeFirst)
{
for(int j=i+1;j<oldIndices[0].length;j++)
{
AgglomerativeNeighbor pwn = new AgglomerativeNeighbor(original,original,oldIndices[0][i],oldIndices[0][j]);
double interScore=getScore(pwn);
numerator+=interScore*medsA[i]*medsA[j];
denominator+=medsA[i]*medsA[j];
}
}
}
//
//intra-cluster2
if(mergeFirst)
{
for(int i=0;i<oldIndices[1].length;i++)
{
for(int j=i+1;j<oldIndices[1].length;j++)
{
AgglomerativeNeighbor pwn = new AgglomerativeNeighbor(original,original,oldIndices[1][i],oldIndices[1][j]);
double interScore=getScore(pwn);
numerator+=interScore*medsB[i]*medsB[j];
denominator+=medsB[i]*medsB[j];
}
}
}
return numerator/denominator;
}
private double[] getMedWeights(int medIdx,int[] indices,Clustering original)
{
double result[] = new double[indices.length];
for(int i=0;i<result.length;i++)
{
if(medIdx==i)
result[i]=1;
else
{
AgglomerativeNeighbor an = new AgglomerativeNeighbor(original,original,indices[medIdx],indices[i]);
result[i] = getScore(an);
}
}
return result;
}
//
//a bettter strategy would use caching to incrimentally determine the centroid
private int getCentroid(int[] indices,Clustering original)
{
if(indices.length<2)
return 0;
//return indices[0];
double centDist=Double.NEGATIVE_INFINITY;
int centIdx=-1;
double[] scores = new double[indices.length];
for(int i=0;i<indices.length;i++)
{
double acc=0;
for(int k=0;k<indices.length;k++)
{
if(i==k)break;
AgglomerativeNeighbor pwn = new AgglomerativeNeighbor(original,original,indices[i],indices[k]);
double score=getScore(pwn);
acc+=score;
//scores[i] = getScore(pwn);
}
acc/=(indices.length-1);
scores[i]=acc;
}
for(int i=0;i<scores.length;i++)
{
if(scores[i]>centDist)
{
centDist=scores[i];
centIdx=i;
//centIdx=indices[i];
}
}
return centIdx;
}
/*
public double evaluate (Neighbor neighbor) {
if (!(neighbor instanceof AgglomerativeNeighbor))
throw new IllegalArgumentException("Expect AgglomerativeNeighbor not " + neighbor.getClass().getName());
Clustering original = neighbor.getOriginal();
int[] mergedIndices = ((AgglomerativeNeighbor)neighbor).getNewCluster();
ArrayList scores = new ArrayList();
for (int i = 0; i < mergedIndices.length; i++) {
for (int j = i + 1; j < mergedIndices.length; j++) {
if ((original.getLabel(mergedIndices[i]) != original.getLabel(mergedIndices[j])) || mergeFirst) {
AgglomerativeNeighbor pwneighbor =
new AgglomerativeNeighbor(original, original,
mergedIndices[i], mergedIndices[j]);
scores.add(new Double(getScore(pwneighbor)));
}
}
}
if (scores.size() < 1)
throw new IllegalStateException("No pairs of Instances were scored.");
double[] vals = new double[scores.size()];
for (int i = 0; i < vals.length; i++)
vals[i] = ((Double)scores.get(i)).doubleValue();
return combiningStrategy.combine(vals);
}
*/
public void reset () {
scoreCache = null;
}
public String toString () {
return "class=" + this.getClass().getName() +
" classifier=" + classifier.getClass().getName();
}
private double getScore (AgglomerativeNeighbor pwneighbor) {
if (scoreCache == null)
scoreCache = new PairwiseMatrix(pwneighbor.getOriginal().getNumInstances());
int[] indices = pwneighbor.getNewCluster();
if (scoreCache.get(indices[0], indices[1]) == 0.0) {
scoreCache.set(indices[0], indices[1],
classifier.classify(pwneighbor).getLabelVector().value(scoringLabel));
}
return scoreCache.get(indices[0], indices[1]);
}
/**
* Specifies how to combine a set of pairwise scores into a
* cluster-wise score.
*
* @author "Aron Culotta" <culotta@degas.cs.umass.edu>
* @version 1.0
* @since 1.0
*/
public static interface CombiningStrategy {
public double combine (double[] scores);
}
public static class Average implements CombiningStrategy {
public double combine (double[] scores) {
return MatrixOps.mean(scores);
}
}
public static class Minimum implements CombiningStrategy {
public double combine (double[] scores) {
return MatrixOps.min(scores);
}
}
public static class Maximum implements CombiningStrategy {
public double combine (double[] scores) {
return MatrixOps.max(scores);
}
}
}