Package cc.mallet.cluster.neighbor_evaluator

Source Code of cc.mallet.cluster.neighbor_evaluator.MedoidEvaluator$Average

package cc.mallet.cluster.neighbor_evaluator;


//import weka.core.Instances;
import cc.mallet.classify.Classifier;
import cc.mallet.cluster.Clustering;
import cc.mallet.cluster.util.PairwiseMatrix;
import cc.mallet.types.MatrixOps;

/**
* Uses a {@link Classifier} over pairs of {@link Instances} to score
* {@link Neighbor}. Currently only supports {@link
* AgglomerativeNeighbor}s.
*
* @author "Michael Wick" <mwick@cs.umass.edu>
* @version 1.0
* @since 1.0
* @see ClassifyingNeighborEvaluator
*/
public class MedoidEvaluator extends ClassifyingNeighborEvaluator {

  private static final long serialVersionUID = 1L;

  /**
     * If single link is true, then the score of clusters A and B is the score of the link between the two medoids.
     */
    boolean singleLink=false;

  /**
   * How to combine a set of pairwise scores (e.g. mean, max, ...)... [currently not supported in this class]
   */
  CombiningStrategy combiningStrategy;

  /**
   * If true, score all edges involved in a merge. If false, only
   * score the edges that croess the boundaries of the clusters being
   * merged.
   */
  boolean mergeFirst=true;

  /**
   * Cache for calls to getScore. In some experiments, reduced running
   * time by nearly half.
   */
  PairwiseMatrix scoreCache;
 
  /**
   *
   * @param classifier Classifier to assign scores to {@link
   * Neighbor}s for which a pair of Instances has been merged.
   * @param scoringLabel The predicted label that corresponds to a
   * positive example (e.g. "YES").
   * @param combiningStrategy How to combine the pairwise scores
   * (e.g. max, mean, ...).
   * @param mergeFirst If true, score all edges involved in a
   * merge. If false, only score the edges that cross the boundaries
   * of the clusters being merged.
   * @return
   */
    public MedoidEvaluator(Classifier classifier, String scoringLabel)
    {
  super(classifier,scoringLabel);
  System.out.println("Using Medoid Evaluator");
    }
    public MedoidEvaluator(Classifier classifier, String scoringLabel,boolean singleLink,boolean mergeFirst)
    {
  super(classifier,scoringLabel);
  this.singleLink=singleLink;
  this.mergeFirst=mergeFirst;
  System.out.println("Using Medoid Evaluator. Single link="+singleLink+".");
    }

    /*
    public MedoidEvaluator (Classifier classifier,
            String scoringLabel,
            CombiningStrategy combiningStrategy,
            boolean mergeFirst) {
    super(classifier, scoringLabel);
    this.combiningStrategy = combiningStrategy;
    this.mergeFirst = mergeFirst;
  System.out.println("Using Centroid Evaluator (2)");

  }
    */

  public double[] evaluate (Neighbor[] neighbors) {
    double[] scores = new double[neighbors.length];
    for (int i = 0; i < neighbors.length; i++)
      scores[i] = evaluate(neighbors[i]);
    return scores;
  }
 

   
    public double evaluate(Neighbor neighbor)
    {
  int result[] = new int[2];
  if (!(neighbor instanceof AgglomerativeNeighbor))
      throw new IllegalArgumentException("Expect AgglomerativeNeighbor not " + neighbor.getClass().getName());
  int[][] oldIndices = ((AgglomerativeNeighbor)neighbor).getOldClusters();
  int[] mergedIndices=((AgglomerativeNeighbor)neighbor).getNewCluster();

  Clustering original = neighbor.getOriginal();

  result[0]=getCentroid(oldIndices[0],original);
  result[1]=getCentroid(oldIndices[1],original);
  if(singleLink) //scores a cluster based on link between medoid of each cluster
      {
    AgglomerativeNeighbor pwn = new AgglomerativeNeighbor(original,original,oldIndices[0][result[0]],oldIndices[1][result[1]]);
    double score = getScore(pwn);
    return score;
      }

  //
  //Returns average weighted average where weights are proportional to similarity to medoid
  double[] medsA=getMedWeights(result[0],oldIndices[0],original);
  double[] medsB=getMedWeights(result[1],oldIndices[1],original);

  double numerator=0;
  double denominator=0;
  for(int i=0;i<oldIndices[0].length;i++)
      {
    //
    //cross-boundary
    for(int j=0;j<oldIndices[1].length;j++)
        {

      AgglomerativeNeighbor pwn = new AgglomerativeNeighbor(original,original,oldIndices[0][i],oldIndices[1][j]);
      double interScore=getScore(pwn);
      numerator+=interScore*medsA[i]*medsB[j];
      denominator+=medsA[i]*medsB[j];
        }
    //
    //intra-cluster1
    if(mergeFirst)
        {
      for(int j=i+1;j<oldIndices[0].length;j++)
          {
        AgglomerativeNeighbor pwn = new AgglomerativeNeighbor(original,original,oldIndices[0][i],oldIndices[0][j]);
        double interScore=getScore(pwn);
        numerator+=interScore*medsA[i]*medsA[j];
        denominator+=medsA[i]*medsA[j];
          }
        }
      }
  //
  //intra-cluster2
  if(mergeFirst)
      {
    for(int i=0;i<oldIndices[1].length;i++)
        {
      for(int j=i+1;j<oldIndices[1].length;j++)
          {
        AgglomerativeNeighbor pwn = new AgglomerativeNeighbor(original,original,oldIndices[1][i],oldIndices[1][j]);
        double interScore=getScore(pwn);
        numerator+=interScore*medsB[i]*medsB[j];
        denominator+=medsB[i]*medsB[j];
          }
        }
      }

  return numerator/denominator;

    }
    private double[] getMedWeights(int medIdx,int[] indices,Clustering original)
    {
  double result[] = new double[indices.length];
  for(int i=0;i<result.length;i++)
      {
    if(medIdx==i)
        result[i]=1;
    else
        {
      AgglomerativeNeighbor an = new AgglomerativeNeighbor(original,original,indices[medIdx],indices[i]);
      result[i] = getScore(an);
        }
      }
  return result; 
    }


    //
    //a bettter strategy would use caching to incrimentally determine the centroid
    private int getCentroid(int[] indices,Clustering original)
    {
  if(indices.length<2)
      return 0;
      //return indices[0];

  double centDist=Double.NEGATIVE_INFINITY;
  int centIdx=-1;
  double[] scores = new double[indices.length];
  for(int i=0;i<indices.length;i++)
      {
    double acc=0;
    for(int k=0;k<indices.length;k++)
        {
      if(i==k)break;
      AgglomerativeNeighbor pwn = new AgglomerativeNeighbor(original,original,indices[i],indices[k]);
      double score=getScore(pwn);
      acc+=score;
      //scores[i] = getScore(pwn);
        }
    acc/=(indices.length-1);
    scores[i]=acc;
      }
  for(int i=0;i<scores.length;i++)
      {
    if(scores[i]>centDist)
        {
      centDist=scores[i];
      centIdx=i;
      //centIdx=indices[i];
        }
      }
  return centIdx;
    }
   
    /*
  public double evaluate (Neighbor neighbor) {
     if (!(neighbor instanceof AgglomerativeNeighbor))
       throw new IllegalArgumentException("Expect AgglomerativeNeighbor not " + neighbor.getClass().getName());

    Clustering original = neighbor.getOriginal();
    int[] mergedIndices = ((AgglomerativeNeighbor)neighbor).getNewCluster();
    ArrayList scores = new ArrayList();
     for (int i = 0; i < mergedIndices.length; i++) {
      for (int j = i + 1; j < mergedIndices.length; j++) {
        if ((original.getLabel(mergedIndices[i]) != original.getLabel(mergedIndices[j])) || mergeFirst) {
          AgglomerativeNeighbor pwneighbor =
            new AgglomerativeNeighbor(original,  original,
                                      mergedIndices[i], mergedIndices[j]);
          scores.add(new Double(getScore(pwneighbor)));
        }
      }
    }

    if (scores.size() < 1)
      throw new IllegalStateException("No pairs of Instances were scored.");
   
     double[] vals = new double[scores.size()];
    for (int i = 0; i < vals.length; i++)
      vals[i] = ((Double)scores.get(i)).doubleValue();
     return combiningStrategy.combine(vals);
  }
    */

  public void reset () {
    scoreCache = null;
  }
 
  public String toString () {
    return "class=" + this.getClass().getName() +
      " classifier=" + classifier.getClass().getName();
  }

  private double getScore (AgglomerativeNeighbor pwneighbor) {
    if (scoreCache == null)
      scoreCache = new PairwiseMatrix(pwneighbor.getOriginal().getNumInstances());
    int[] indices = pwneighbor.getNewCluster();
    if (scoreCache.get(indices[0], indices[1]) == 0.0) {
      scoreCache.set(indices[0], indices[1],
                 classifier.classify(pwneighbor).getLabelVector().value(scoringLabel));
    }
    return scoreCache.get(indices[0], indices[1]);
  }

  /**
   * Specifies how to combine a set of pairwise scores into a
   * cluster-wise score.
   *
   * @author "Aron Culotta" <culotta@degas.cs.umass.edu>
   * @version 1.0
   * @since 1.0
   */
  public static interface CombiningStrategy {
    public double combine (double[] scores);
  }

  public static class Average implements CombiningStrategy {
    public double combine (double[] scores) {
      return MatrixOps.mean(scores);
    }   
  }

  public static class Minimum implements CombiningStrategy {
    public double combine (double[] scores) {
      return MatrixOps.min(scores);
    }   
  }

  public static class Maximum implements CombiningStrategy {
    public double combine (double[] scores) {
      return MatrixOps.max(scores);
    }   
  }
}
TOP

Related Classes of cc.mallet.cluster.neighbor_evaluator.MedoidEvaluator$Average

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.