Package cc.mallet.classify.evaluate

Source Code of cc.mallet.classify.evaluate.ConfusionMatrix

/* Copyright (C) 2002 Dept. of Computer Science, Univ. of Massachusetts, Amherst

   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet

   This program toolkit free software; you can redistribute it and/or
   modify it under the terms of the GNU General Public License as
   published by the Free Software Foundation; either version 2 of the
   License, or (at your option) any later version.

   This program is distributed in the hope that it will be useful, but
   WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  For more
   details see the GNU General Public License and the file README-LEGAL.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
   02111-1307, USA. */


/**
   @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/

package cc.mallet.classify.evaluate;


import java.util.ArrayList;
import java.util.HashMap;
import java.util.logging.*;
import java.text.*;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Trial;
import cc.mallet.types.Instance;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;

/**
* Calculates and prints confusion matrix, accuracy,
* and precision for a given clasification trial.
*/
public class ConfusionMatrix
{
  private static Logger logger = MalletLogger.getLogger(ConfusionMatrix.class.getName());
 
  int numClasses;
  /**
   * the list of classifications from the trial
   */
  ArrayList classifications;
  /**
   * 2-d confiusion matrix
   */
  int[][] values;

  Trial trial;

  /**
   * Constructs matrix and calculates values
   * @param t the trial to build matrix from
   */
  public ConfusionMatrix(Trial t)
  {
    this.trial = t;
    this.classifications = t;
    Labeling tempLabeling =
      ((Classification)classifications.get(0)).getLabeling();
    this.numClasses = tempLabeling.getLabelAlphabet().size();
    values = new int[numClasses][numClasses];
    for(int i=0; i < classifications.size(); i++)
    {
      LabelVector lv =
        ((Classification)classifications.get(i)).getLabelVector();
      Instance inst = ((Classification)classifications.get(i)).getInstance();
      int bestIndex = lv.getBestIndex();
      int correctIndex = inst.getLabeling().getBestIndex();
      assert(correctIndex != -1);
      //System.out.println("Best index="+bestIndex+". Correct="+correctIndex);
      values[correctIndex][bestIndex]++;
    }     
  }

  /** Return the count at row i (true) , column j (predicted) */
   double value(int i, int j)
   {
      assert(i >= 0 && j >= 0 && i < numClasses && j < numClasses);
      return values[i][j];     
   }

  static private void appendJustifiedInt (StringBuffer sb, int i, boolean zeroDot) {
    if (i < 100) sb.append (' ');
    if (i < 10) sb.append (' ');
    if (i == 0 && zeroDot)
      sb.append (".");
    else
      sb.append (""+i);
  }

  public String toString () {
    StringBuffer sb = new StringBuffer ();
    int maxLabelNameLength = 0;
    LabelAlphabet labelAlphabet = trial.getClassifier().getLabelAlphabet();
    for (int i = 0; i < numClasses; i++) {
      int len = labelAlphabet.lookupLabel(i).toString().length();
      if (maxLabelNameLength < len)
        maxLabelNameLength = len;
    }

    sb.append ("Confusion Matrix, row=true, column=predicted  accuracy="+trial.getAccuracy()+"\n");
    for (int i = 0; i < maxLabelNameLength-5+4; i++) sb.append (' ');
    sb.append ("label");
    for (int c2 = 0; c2 < Math.min(10,numClasses); c2++sb.append ("   "+c2);
    for (int c2 = 10; c2 < numClasses; c2++sb.append ("  "+c2);
    sb.append ("  |total\n");
    for (int c = 0; c < numClasses; c++) {
      appendJustifiedInt (sb, c, false);
      String labelName = labelAlphabet.lookupLabel(c).toString();
      for (int i = 0; i < maxLabelNameLength-labelName.length(); i++) sb.append (' ');
      sb.append (" "+labelName+" ");
      for (int c2 = 0; c2 < numClasses; c2++) {
        appendJustifiedInt (sb, values[c][c2], true);
        sb.append (' ');
      }
      sb.append (" |"+ MatrixOps.sum(values[c]));
      sb.append ('\n');
    }
    return sb.toString();
  }
 
  /**
   * Returns the precision of this predicted class
   */
  public double getPrecision (int predictedClassIndex)
  {
    int total = 0;
    for (int trueClassIndex=0; trueClassIndex < this.numClasses; trueClassIndex++) {
      total += values[trueClassIndex][predictedClassIndex];
    }
    if (total == 0)
      return 0.0;
    else
      return (double) (values[predictedClassIndex][predictedClassIndex]) / total;
  }
 
  /**
   * Returns percent of time that class2 is true class when
   * class1 is predicted class
   *
   */
  public double getConfusionBetween (int class1, int class2)
  {
    int total = 0;
    for (int trueClassIndex=0; trueClassIndex < this.numClasses; trueClassIndex++) {
      total += values[trueClassIndex][class1];
    }
    if (total == 0)
      return 0.0;
    else
      return (double) (values[class2][class1]) / total;     
  }

  /**
   * Returns the percentage of instances with
   * true label = classIndex
   */
  public double getClassPrior (int classIndex)
  {
    int sum= 0;
    for(int i=0; i < numClasses; i++)
      sum += values[classIndex][i];
    return (double)sum / classifications.size();
  }



  /**
   * prints to stdout the confusion matrix,
   * class frequency, precision, and recall
   */
  /*
  public void print()
  {
    double totalPrecision = 0;
    double totalRecall = 0;
    double totalF1 = 0;
    HashMap index2class = new HashMap();
    LabelVector lv =
      ((Classification)classifications.get(0)).getLabelVector();
    DecimalFormat df = new DecimalFormat("###.##");
    int [] numInstances = new int[this.numClasses];
    for(int i=0; i<this.numClasses; i++){
      int count = 0;
      for(int j=0; j<this.numClasses; j++)
        count += values[i][j];
      numInstances[i] = count;
      String label = lv.labelAtLocation(i).toString();
      System.out.println("index "+i+": "+label+
                         " "+count+" instances "+
                         df.format(100*(double)count/classifications.size())
        +"%");
      index2class.put (new Integer (i), label);
    }
    System.out.println("Confusion Matrix");
    for(int i=0; i<this.numClasses; i++){
      for(int j=0; j<this.numClasses; j++)
        System.out.print(values[j][i]+"\t\t");
      System.out.println("");
    }
    for(int i=0; i<this.numClasses; i++){
      double recall = 100.0*(double)values[j][j]/numInstances[i];
      double precision;
      int rowCount = 0;
      for(int j=0; j<this.numClasses; j++)
        rowCount += values[j][i];
      if (rowCount == 0)
        precision = 0;
      else
        precision = 100.0*(double)values[j][j] / rowCount;
      double f1;
      if (precision + recall == 0.0)
        f1 = 0;
      else
        f1 = 2 * precision * recall / (precision + recall);
      System.out.println("Class " +
                         (String)index2class.get(new Integer (i)));
      System.out.println("F1="+df.format(f1)+"%");
      System.out.println("Recall="+df.format(recall)+"%");
      System.out.println("Precision="+df.format(precision)+"%");
      totalPrecision += precision;
      totalRecall += recall;
      totalF1 += f1;
    }
   
    int numCorrect = 0;
    int totalInstances = 0;
    for(int i=0; i<this.numClasses; i++)
    {
      numCorrect += values[j][j];
      totalInstances+=numInstances[i];
    }
    System.out.println("Overall Accuracy="+
                       df.format(100.0*(double)numCorrect/totalInstances)+"%");
    System.out.println ("Average F1: " + (totalF1 / this.numClasses) +
                        "\nAverage Precision: " +
                        (totalPrecision / this.numClasses) +
                        "\nAverage Recall: " + (totalRecall / this.numClasses));
  }
*/

}














 
TOP

Related Classes of cc.mallet.classify.evaluate.ConfusionMatrix

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.