Package weka.gui.beans

Source Code of weka.gui.beans.CostBenefitAnalysis$AnalysisPanel$ConfusionCell

/*
*    This program is free software; you can redistribute it and/or modify
*    it under the terms of the GNU General Public License as published by
*    the Free Software Foundation; either version 2 of the License, or
*    (at your option) any later version.
*
*    This program is distributed in the hope that it will be useful,
*    but WITHOUT ANY WARRANTY; without even the implied warranty of
*    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
*    GNU General Public License for more details.
*
*    You should have received a copy of the GNU General Public License
*    along with this program; if not, write to the Free Software
*    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/

/*
*    CostBenefitAnalysis.java
*    Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
*
*/

package weka.gui.beans;

import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.GridLayout;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.FocusEvent;
import java.awt.event.FocusListener;
import java.beans.EventSetDescriptor;
import java.beans.PropertyVetoException;
import java.beans.VetoableChangeListener;
import java.beans.beancontext.BeanContext;
import java.beans.beancontext.BeanContextChild;
import java.beans.beancontext.BeanContextChildSupport;
import java.io.Serializable;
import java.util.Enumeration;
import java.util.Vector;

import javax.swing.BorderFactory;
import javax.swing.ButtonGroup;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JRadioButton;
import javax.swing.JSlider;
import javax.swing.JTextField;
import javax.swing.SwingConstants;
import javax.swing.event.ChangeEvent;
import javax.swing.event.ChangeListener;

import weka.classifiers.evaluation.ThresholdCurve;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.DenseInstance;
import weka.core.Instances;
import weka.core.Utils;
import weka.gui.Logger;
import weka.gui.visualize.VisualizePanel;
import weka.gui.visualize.Plot2D;
import weka.gui.visualize.PlotData2D;


/**
* Bean that aids in analyzing cost/benefit tradeoffs.
*
* @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
* @version $Revision: 6137 $
*/
public class CostBenefitAnalysis extends JPanel
  implements BeanCommon, ThresholdDataListener, Visible, UserRequestAcceptor,
  Serializable, BeanContextChild {
 
  /** For serialization */
  private static final long serialVersionUID = 8647471654613320469L;

  /**
   * Visual representation of this bean (NOTE(review): presumably the icon
   * shown in the KnowledgeFlow canvas — confirm).
   */
  protected BeanVisual m_visual;

  /**
   * Pop-up frame (presumably the one hosting the analysis panel — confirm).
   * Transient, so it is not serialized with the bean.
   */
  protected transient JFrame m_popupFrame;

  /** Presumably true while the pop-up frame is showing; maintained outside this view. */
  protected boolean m_framePoppedUp = false;

  /** Panel with the plots and controls; transient, so it is not serialized with the bean. */
  private transient AnalysisPanel m_analysisPanel;

  /**
   * True if this bean's appearance is the design mode appearance
   */
  protected boolean m_design;

  /**
   * BeanContext that this bean might be contained within
   */
  protected transient BeanContext m_beanContext = null;

  /**
   * BeanContextChild support
   */
  protected BeanContextChildSupport m_bcSupport =
    new BeanContextChildSupport(this);

  /**
   * The object sending us data (we allow only one connection at any one time)
   */
  protected Object m_listenee;
 
  /**
   * Inner class for displaying the plots and all control widgets.
   *
   * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
   */
  protected static class AnalysisPanel extends JPanel {
   
    /** For serialization */
    private static final long serialVersionUID = 5364871945448769003L;

    /** Displays the performance graph(s) */
    protected VisualizePanel m_performancePanel = new VisualizePanel();

    /** Displays the cost/benefit (profit/loss) graph */
    protected VisualizePanel m_costBenefitPanel = new VisualizePanel();

    /**
     * The class attribute from the data that was used to generate
     * the threshold curve
     */
    protected Attribute m_classAttribute;

    /** Data for the threshold curve (null until assigned outside the constructor) */
    protected PlotData2D m_masterPlot;

    /**
     * Data for the cost/benefit curve; column 1 of its instances holds the
     * cost/benefit value (null until assigned outside the constructor)
     */
    protected PlotData2D m_costBenefit;

    /**
     * The size of the points being plotted, one entry per curve point. The
     * selected point is enlarged to 10 and the previously selected one is
     * reset to 1.
     */
    protected int[] m_shapeSizes;

    /** The index of the previous plotted point that was highlighted (-1 = none yet) */
    protected int m_previousShapeIndex = -1;

    /**
     * The slider for adjusting the threshold. Range 0..100; its value divided
     * by 100 is used as a fraction along the selected metric.
     */
    protected JSlider m_thresholdSlider = new JSlider(0,100,0);

    // Mutually exclusive choice of the curve metric the slider moves along
    protected JRadioButton m_percPop = new JRadioButton("% of Population");
    protected JRadioButton m_percOfTarget = new JRadioButton("% of Target (recall)");
    protected JRadioButton m_threshold = new JRadioButton("Score Threshold");

    // Read-outs for the currently selected point of the threshold curve
    protected JLabel m_percPopLab = new JLabel();
    protected JLabel m_percOfTargetLab = new JLabel();
    protected JLabel m_thresholdLab = new JLabel();

    // Confusion matrix stuff. Rows are the actual class (a, b), columns the
    // predicted class (a, b): aa = TP, ab = FN, ba = FP, bb = TN.
    protected JLabel m_conf_predictedA = new JLabel("Predicted (a)", SwingConstants.RIGHT);
    protected JLabel m_conf_predictedB = new JLabel("Predicted (b)", SwingConstants.RIGHT);
    protected JLabel m_conf_actualA = new JLabel(" Actual (a):");
    protected JLabel m_conf_actualB = new JLabel(" Actual (b):");
    protected ConfusionCell m_conf_aa = new ConfusionCell();
    protected ConfusionCell m_conf_ab = new ConfusionCell();
    protected ConfusionCell m_conf_ba = new ConfusionCell();
    protected ConfusionCell m_conf_bb = new ConfusionCell();

    // Cost matrix stuff. Same layout as the confusion matrix:
    // aa = TP cost, ab = FN cost, ba = FP cost, bb = TN cost.
    protected JLabel m_cost_predictedA = new JLabel("Predicted (a)", SwingConstants.RIGHT);
    protected JLabel m_cost_predictedB = new JLabel("Predicted (b)", SwingConstants.RIGHT);
    protected JLabel m_cost_actualA = new JLabel(" Actual (a)");
    protected JLabel m_cost_actualB = new JLabel(" Actual (b)");
    protected JTextField m_cost_aa = new JTextField("0.0", 5);
    protected JTextField m_cost_ab = new JTextField("1.0", 5);
    protected JTextField m_cost_ba = new JTextField("1.0", 5);
    protected JTextField m_cost_bb = new JTextField("0.0" ,5);
    protected JButton m_maximizeCB = new JButton("Maximize Cost/Benefit");
    protected JButton m_minimizeCB = new JButton("Minimize Cost/Benefit");
    // Whether the matrix values are interpreted as costs or as benefits
    protected JRadioButton m_costR = new JRadioButton("Cost");
    protected JRadioButton m_benefitR = new JRadioButton("Benefit");
    // Cost/benefit, random-selection and gain read-outs for the selected point
    protected JLabel m_costBenefitL = new JLabel("Cost: ", SwingConstants.RIGHT);
    protected JLabel m_costBenefitV = new JLabel("0");
    protected JLabel m_randomV = new JLabel("0");
    protected JLabel m_gainV = new JLabel("0");

    // Default population size when scaling counts to the size typed into
    // m_totalPopField (NOTE(review): presumably the size of the evaluated
    // data; assigned outside this view).
    protected int m_originalPopSize;

    /** Population text field: the population size that counts are scaled to */
    protected JTextField m_totalPopField = new JTextField(6);
    // NOTE(review): not referenced in the visible code; presumably the last
    // value of m_totalPopField.
    protected int m_totalPopPrevious;

    /** Classification accuracy */
    protected JLabel m_classificationAccV = new JLabel("-");

    // Only update curve & stats if values in cost matrix have changed.
    // (Previously used cost matrix values; not referenced in the visible code.)
    protected double m_tpPrevious;
    protected double m_fpPrevious;
    protected double m_tnPrevious;
    protected double m_fnPrevious;
   
    /**
     * Inner class for handling a single cell in the confusion matrix.
     * Displays the value, value as a percentage of total population and
     * graphical depiction of percentage.
     *
     * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
     */
    protected static class ConfusionCell extends JPanel {

      /** For serialization */
      private static final long serialVersionUID = 6148640235434494767L;
     
      private JLabel m_conf_cell = new JLabel("-", SwingConstants.RIGHT);
      JLabel m_conf_perc = new JLabel("-", SwingConstants.RIGHT);
     
      private JPanel m_percentageP;
     
      protected double m_percentage = 0;
     
      public ConfusionCell() {
        setLayout(new BorderLayout());
        setBorder(BorderFactory.createEtchedBorder());
       
        add(m_conf_cell, BorderLayout.NORTH);
       
        m_percentageP = new JPanel() {
          public void paintComponent(Graphics gx) {
            super.paintComponent(gx);
           
            if (m_percentage > 0) {
              gx.setColor(Color.BLUE);
              int height = this.getHeight();
              double width = this.getWidth();
              int barWidth = (int)(m_percentage * width);
              gx.fillRect(0, 0, barWidth, height);
            }
          }
        };
       
        Dimension d = new Dimension(30,5);
        m_percentageP.setMinimumSize(d);
        m_percentageP.setPreferredSize(d);
        JPanel percHolder = new JPanel();
        percHolder.setLayout(new BorderLayout());
        percHolder.add(m_percentageP, BorderLayout.CENTER);
        percHolder.add(m_conf_perc, BorderLayout.EAST);
       
        add(percHolder, BorderLayout.SOUTH);
      }
     
      /**
       * Set the value of a cell.
       *
       * @param cellValue the value of the cell
       * @param max the max (for setting value as a percentage)
       * @param scaleFactor scale the value by this amount
       * @param precision precision for the percentage value
       */
      public void setCellValue(double cellValue, double max, double scaleFactor, int precision) {
        if (!Utils.isMissingValue(cellValue)) {
          m_percentage = cellValue / max;
        } else {
          m_percentage = 0;
        }
       
        m_conf_cell.setText(Utils.doubleToString((cellValue * scaleFactor), 0));
        m_conf_perc.setText(Utils.doubleToString(m_percentage * 100.0, precision) + "%");
       
        // refresh the percentage bar
        m_percentageP.repaint();
      }
    }
   
    public AnalysisPanel() {
      setLayout(new BorderLayout());
      m_performancePanel.setShowAttBars(false);
      m_performancePanel.setShowClassPanel(false);
      m_costBenefitPanel.setShowAttBars(false);
      m_costBenefitPanel.setShowClassPanel(false);
     
      Dimension size = new Dimension(500, 400);
      m_performancePanel.setPreferredSize(size);
      m_performancePanel.setMinimumSize(size);
     
      size = new Dimension(500, 400);
      m_costBenefitPanel.setMinimumSize(size);
      m_costBenefitPanel.setPreferredSize(size);
     
      m_thresholdSlider.addChangeListener(new ChangeListener() {
        public void stateChanged(ChangeEvent e) {
          updateInfoForSliderValue((double)m_thresholdSlider.getValue() / 100.0);
        }
      });
     
      JPanel plotHolder = new JPanel();
      plotHolder.setLayout(new GridLayout(1,2));     
      plotHolder.add(m_performancePanel);
      plotHolder.add(m_costBenefitPanel);
      add(plotHolder, BorderLayout.CENTER);
     
      JPanel lowerPanel = new JPanel();
      lowerPanel.setLayout(new BorderLayout());
     
      ButtonGroup bGroup = new ButtonGroup();
      bGroup.add(m_percPop);
      bGroup.add(m_percOfTarget);
      bGroup.add(m_threshold);
     
      ButtonGroup bGroup2 = new ButtonGroup();
      bGroup2.add(m_costR);
      bGroup2.add(m_benefitR);
      ActionListener rl = new ActionListener() {
        public void actionPerformed(ActionEvent e) {
          if (m_costR.isSelected()) {
            m_costBenefitL.setText("Cost: ");
          } else {
            m_costBenefitL.setText("Benefit: ");
          }

          double gain = Double.parseDouble(m_gainV.getText());
          gain = -gain;
          m_gainV.setText(Utils.doubleToString(gain, 2));
        }
      };
      m_costR.addActionListener(rl);
      m_benefitR.addActionListener(rl);
      m_costR.setSelected(true);
     
      m_percPop.setSelected(true);
      JPanel threshPanel = new JPanel();
      threshPanel.setLayout(new BorderLayout());
      JPanel radioHolder = new JPanel();
      radioHolder.setLayout(new FlowLayout());
      radioHolder.add(m_percPop);
      radioHolder.add(m_percOfTarget);
      radioHolder.add(m_threshold);
      threshPanel.add(radioHolder, BorderLayout.NORTH);
      threshPanel.add(m_thresholdSlider, BorderLayout.SOUTH);
     
      JPanel threshInfoPanel = new JPanel();
      threshInfoPanel.setLayout(new GridLayout(3,2));
      threshInfoPanel.add(new JLabel("% of Population: ", SwingConstants.RIGHT));
      threshInfoPanel.add(m_percPopLab);
      threshInfoPanel.add(new JLabel("% of Target: ", SwingConstants.RIGHT));
      threshInfoPanel.add(m_percOfTargetLab);
      threshInfoPanel.add(new JLabel("Score Threshold: ", SwingConstants.RIGHT));
      threshInfoPanel.add(m_thresholdLab);
     
      JPanel threshHolder = new JPanel();
      threshHolder.setBorder(BorderFactory.createTitledBorder("Threshold"));
      threshHolder.setLayout(new BorderLayout());
      threshHolder.add(threshPanel, BorderLayout.CENTER);
      threshHolder.add(threshInfoPanel, BorderLayout.EAST);
     
      lowerPanel.add(threshHolder, BorderLayout.NORTH);
     
      // holder for the two matrixes
      JPanel matrixHolder = new JPanel();
      matrixHolder.setLayout(new GridLayout(1,2));
     
      // confusion matrix
      JPanel confusionPanel = new JPanel();
      confusionPanel.setLayout(new GridLayout(3,3));
      confusionPanel.add(m_conf_predictedA);
      confusionPanel.add(m_conf_predictedB);
      confusionPanel.add(new JLabel()); // dummy
      confusionPanel.add(m_conf_aa);
      confusionPanel.add(m_conf_ab);
      confusionPanel.add(m_conf_actualA);
      confusionPanel.add(m_conf_ba);
      confusionPanel.add(m_conf_bb);
      confusionPanel.add(m_conf_actualB);
      JPanel tempHolderCA = new JPanel();
      tempHolderCA.setLayout(new BorderLayout());
      tempHolderCA.setBorder(BorderFactory.createTitledBorder("Confusion Matrix"));
      tempHolderCA.add(confusionPanel, BorderLayout.CENTER);
     
      JPanel accHolder = new JPanel();
      accHolder.setLayout(new FlowLayout(FlowLayout.LEFT));
      accHolder.add(new JLabel("Classification Accuracy: "));
      accHolder.add(m_classificationAccV);
      tempHolderCA.add(accHolder, BorderLayout.SOUTH);
     
      matrixHolder.add(tempHolderCA);
     
      // cost matrix
      JPanel costPanel = new JPanel();
      costPanel.setBorder(BorderFactory.createTitledBorder("Cost Matrix"));
      costPanel.setLayout(new BorderLayout());
     
      JPanel cmHolder = new JPanel();
      cmHolder.setLayout(new GridLayout(3, 3));
      cmHolder.add(m_cost_predictedA);     
      cmHolder.add(m_cost_predictedB);
      cmHolder.add(new JLabel()); // dummy
      cmHolder.add(m_cost_aa);
      cmHolder.add(m_cost_ab);
      cmHolder.add(m_cost_actualA);
      cmHolder.add(m_cost_ba);
      cmHolder.add(m_cost_bb);
      cmHolder.add(m_cost_actualB);
      costPanel.add(cmHolder, BorderLayout.CENTER);
     
      FocusListener fl = new FocusListener() {
        public void focusGained(FocusEvent e) {
         
        }
       
        public void focusLost(FocusEvent e) {
          if (constructCostBenefitData()) {
            try {
              m_costBenefitPanel.setMasterPlot(m_costBenefit);
              m_costBenefitPanel.validate(); m_costBenefitPanel.repaint();
            } catch (Exception ex) {
              ex.printStackTrace();
            }
            updateCostBenefit();
          }
        }
      };
     
      ActionListener al = new ActionListener() {
        public void actionPerformed(ActionEvent e) {
          if (constructCostBenefitData()) {
            try {
              m_costBenefitPanel.setMasterPlot(m_costBenefit);
              m_costBenefitPanel.validate(); m_costBenefitPanel.repaint();
            } catch (Exception ex) {
              ex.printStackTrace();
            }
            updateCostBenefit();
          }
        }
      };
           
      m_cost_aa.addFocusListener(fl);
      m_cost_aa.addActionListener(al);
      m_cost_ab.addFocusListener(fl);
      m_cost_ab.addActionListener(al);
      m_cost_ba.addFocusListener(fl);
      m_cost_ba.addActionListener(al);
      m_cost_bb.addFocusListener(fl);
      m_cost_bb.addActionListener(al);
     
      m_totalPopField.addFocusListener(fl);
      m_totalPopField.addActionListener(al);
     
      JPanel cbHolder = new JPanel();
      cbHolder.setLayout(new BorderLayout());
      JPanel tempP = new JPanel();
      tempP.setLayout(new GridLayout(3, 2));
      tempP.add(m_costBenefitL);
      tempP.add(m_costBenefitV);
      tempP.add(new JLabel("Random: ", SwingConstants.RIGHT));
      tempP.add(m_randomV);
      tempP.add(new JLabel("Gain: ", SwingConstants.RIGHT));
      tempP.add(m_gainV);
      cbHolder.add(tempP, BorderLayout.NORTH);
      JPanel butHolder = new JPanel();
      butHolder.setLayout(new GridLayout(2, 1));
      butHolder.add(m_maximizeCB);
      butHolder.add(m_minimizeCB);
      m_maximizeCB.addActionListener(new ActionListener() {
        public void actionPerformed(ActionEvent e) {
          findMaxMinCB(true);
        }
      });
     
      m_minimizeCB.addActionListener(new ActionListener() {
        public void actionPerformed(ActionEvent e) {
          findMaxMinCB(false);
        }
      });
     
      cbHolder.add(butHolder, BorderLayout.SOUTH);
      costPanel.add(cbHolder, BorderLayout.EAST);
     
      JPanel popCBR = new JPanel();
      popCBR.setLayout(new GridLayout(1, 2));
      JPanel popHolder = new JPanel();
      popHolder.setLayout(new FlowLayout(FlowLayout.LEFT));
      popHolder.add(new JLabel("Total Population: "));
      popHolder.add(m_totalPopField);
     
      JPanel radioHolder2 = new JPanel();
      radioHolder2.setLayout(new FlowLayout(FlowLayout.RIGHT));
      radioHolder2.add(m_costR);
      radioHolder2.add(m_benefitR);
      popCBR.add(popHolder);
      popCBR.add(radioHolder2);
     
      costPanel.add(popCBR, BorderLayout.SOUTH);
     
      matrixHolder.add(costPanel);
     
     
      lowerPanel.add(matrixHolder, BorderLayout.SOUTH);
     


//      popAccHolder.add(popHolder);
     
      //popAccHolder.add(accHolder);
     
      /*JPanel lowerPanel2 = new JPanel();
      lowerPanel2.setLayout(new BorderLayout());
      lowerPanel2.add(lowerPanel, BorderLayout.NORTH);
      lowerPanel2.add(popAccHolder, BorderLayout.SOUTH); */
     
      add(lowerPanel, BorderLayout.SOUTH);
     
    }
   
    private void findMaxMinCB(boolean max) {
      double maxMin = (max)
      ? Double.NEGATIVE_INFINITY
          : Double.POSITIVE_INFINITY;
     
      Instances cBCurve = m_costBenefit.getPlotInstances();
      int maxMinIndex = 0;
     
      for (int i = 0; i < cBCurve.numInstances(); i++) {
        Instance current = cBCurve.instance(i);
        if (max) {
          if (current.value(1) > maxMin) {
            maxMin = current.value(1);
            maxMinIndex = i;
          }
        } else {
          if (current.value(1) < maxMin) {
            maxMin = current.value(1);
            maxMinIndex = i;
          }
        }
      }
     
     
      // set the slider to the correct position
      int indexOfSampleSize =
        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.SAMPLE_SIZE_NAME).index();
      int indexOfPercOfTarget =
        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.RECALL_NAME).index();
      int indexOfThreshold =
        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.THRESHOLD_NAME).index();
      int indexOfMetric;
     
      if (m_percPop.isSelected()) {
        indexOfMetric = indexOfSampleSize;          
      } else if (m_percOfTarget.isSelected()) {
        indexOfMetric = indexOfPercOfTarget;
      } else {
        indexOfMetric = indexOfThreshold;
      }
     
      double valueOfMetric = m_masterPlot.getPlotInstances().instance(maxMinIndex).value(indexOfMetric);
      valueOfMetric *= 100.0;
     
      // set the approximate location of the slider
      m_thresholdSlider.setValue((int)valueOfMetric);
     
      // make sure the actual values relate to the true min/max rather
      // than being off due to slider location error.
      updateInfoGivenIndex(maxMinIndex);
    }
   
    private void updateCostBenefit() {
      double value = (double)m_thresholdSlider.getValue() / 100.0;
      Instances plotInstances = m_masterPlot.getPlotInstances();
      int indexOfSampleSize =
        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.SAMPLE_SIZE_NAME).index();
      int indexOfPercOfTarget =
        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.RECALL_NAME).index();
      int indexOfThreshold =
        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.THRESHOLD_NAME).index();
      int indexOfMetric;
     
      if (m_percPop.isSelected()) {
        indexOfMetric = indexOfSampleSize;          
      } else if (m_percOfTarget.isSelected()) {
        indexOfMetric = indexOfPercOfTarget;
      } else {
        indexOfMetric = indexOfThreshold;
      }
     
      int index = findIndexForValue(value, plotInstances, indexOfMetric);
      updateCBRandomGainInfo(index);
    }
   
    private void updateCBRandomGainInfo(int index) {
      double requestedPopSize = m_originalPopSize;
      try {
        requestedPopSize = Double.parseDouble(m_totalPopField.getText());
      } catch (NumberFormatException e) {}
      double scaleFactor = requestedPopSize / m_originalPopSize;
     
      double CB = m_costBenefit.
        getPlotInstances().instance(index).value(1);
      m_costBenefitV.setText(Utils.doubleToString(CB,2));
     
      double totalRandomCB = 0.0;
      Instance first = m_masterPlot.getPlotInstances().instance(0);
      double totalPos = first.value(m_masterPlot.getPlotInstances().
          attribute(ThresholdCurve.TRUE_POS_NAME).index()) * scaleFactor;
      double totalNeg = first.value(m_masterPlot.getPlotInstances().
          attribute(ThresholdCurve.FALSE_POS_NAME)) * scaleFactor;

      double posInSample = (totalPos * (Double.parseDouble(m_percPopLab.getText()) / 100.0));
      double negInSample = (totalNeg * (Double.parseDouble(m_percPopLab.getText()) / 100.0));
      double posOutSample = totalPos - posInSample;
      double negOutSample = totalNeg - negInSample;
     
      double tpCost = 0.0;
      try {
        tpCost = Double.parseDouble(m_cost_aa.getText());
      } catch (NumberFormatException n) {}
      double fpCost = 0.0;
      try {
        fpCost = Double.parseDouble(m_cost_ba.getText());
      } catch (NumberFormatException n) {}
      double tnCost = 0.0;
      try {
        tnCost = Double.parseDouble(m_cost_bb.getText());
      } catch (NumberFormatException n) {}
      double fnCost = 0.0;
      try {
        fnCost = Double.parseDouble(m_cost_ab.getText());
      } catch (NumberFormatException n) {}
           
      totalRandomCB += posInSample * tpCost;
      totalRandomCB += negInSample * fpCost;
      totalRandomCB += posOutSample * fnCost;
      totalRandomCB += negOutSample * tnCost;
     
      m_randomV.setText(Utils.doubleToString(totalRandomCB, 2));
      double gain = (m_costR.isSelected())
      ? totalRandomCB - CB
          : CB - totalRandomCB;
      m_gainV.setText(Utils.doubleToString(gain, 2));
     
      // update classification rate
      Instance currentInst = m_masterPlot.getPlotInstances().instance(index);
      double tp = currentInst.value(m_masterPlot.getPlotInstances().
          attribute(ThresholdCurve.TRUE_POS_NAME).index());
      double tn = currentInst.value(m_masterPlot.getPlotInstances().
          attribute(ThresholdCurve.TRUE_NEG_NAME).index());
      m_classificationAccV.
        setText(Utils.doubleToString((tp + tn) / (totalPos + totalNeg) * 100.0, 4) + "%");     
    }
   
    private void updateInfoGivenIndex(int index) {
      Instances plotInstances = m_masterPlot.getPlotInstances();
      int indexOfSampleSize =
        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.SAMPLE_SIZE_NAME).index();
      int indexOfPercOfTarget =
        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.RECALL_NAME).index();
      int indexOfThreshold =
        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.THRESHOLD_NAME).index();
     
      // update labels
      m_percPopLab.setText(Utils.
          doubleToString(100.0 * plotInstances.instance(index).value(indexOfSampleSize), 4));
      m_percOfTargetLab.setText(Utils.doubleToString(
          100.0 * plotInstances.instance(index).value(indexOfPercOfTarget), 4));
      m_thresholdLab.setText(Utils.doubleToString(plotInstances.instance(index).value(indexOfThreshold), 4));
      /*if (m_percPop.isSelected()) {
        m_percPopLab.setText(Utils.doubleToString(100.0 * value, 4));
      } else if (m_percOfTarget.isSelected()) {
        m_percOfTargetLab.setText(Utils.doubleToString(100.0 * value, 4));
      } else {
        m_thresholdLab.setText(Utils.doubleToString(value, 4));
      }*/
     
      // Update the highlighted point on the graphs */
      if (m_previousShapeIndex >= 0) {
        m_shapeSizes[m_previousShapeIndex] = 1;
      }
    
      m_shapeSizes[index] = 10;
      m_previousShapeIndex = index;
     
      // Update the confusion matrix
//      double totalInstances =
      int tp = plotInstances.attribute(ThresholdCurve.TRUE_POS_NAME).index();
      int fp = plotInstances.attribute(ThresholdCurve.FALSE_POS_NAME).index();
      int tn = plotInstances.attribute(ThresholdCurve.TRUE_NEG_NAME).index();
      int fn = plotInstances.attribute(ThresholdCurve.FALSE_NEG_NAME).index();
      Instance temp = plotInstances.instance(index);
      double totalInstances = temp.value(tp) + temp.value(fp) + temp.value(tn) + temp.value(fn);
      // get the value out of the total pop field (if possible)
      double requestedPopSize = totalInstances;
      try {
        requestedPopSize = Double.parseDouble(m_totalPopField.getText());
      } catch (NumberFormatException e) {}
     
      m_conf_aa.setCellValue(temp.value(tp), totalInstances,
          requestedPopSize / totalInstances, 2);
      m_conf_ab.setCellValue(temp.value(fn), totalInstances,
          requestedPopSize / totalInstances, 2);
      m_conf_ba.setCellValue(temp.value(fp), totalInstances,
          requestedPopSize / totalInstances, 2);
      m_conf_bb.setCellValue(temp.value(tn), totalInstances,
            requestedPopSize / totalInstances, 2);
     
      updateCBRandomGainInfo(index);
     
      repaint();
    }
   
    private void updateInfoForSliderValue(double value) {
      int indexOfSampleSize =
        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.SAMPLE_SIZE_NAME).index();
      int indexOfPercOfTarget =
        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.RECALL_NAME).index();
      int indexOfThreshold =
        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.THRESHOLD_NAME).index();
      int indexOfMetric;
     
      if (m_percPop.isSelected()) {
        indexOfMetric = indexOfSampleSize;          
      } else if (m_percOfTarget.isSelected()) {
        indexOfMetric = indexOfPercOfTarget;
      } else {
        indexOfMetric = indexOfThreshold;
      }
     
      Instances plotInstances = m_masterPlot.getPlotInstances();
      int index = findIndexForValue(value, plotInstances, indexOfMetric);
      updateInfoGivenIndex(index);
    }
   
    /**
     * Finds the row of the threshold curve whose value in column
     * {@code indexOfMetric} is closest to {@code value}.
     *
     * <p>Uses a binary search that assumes the column is monotone: ascending
     * when the threshold radio button is selected, descending otherwise. The
     * search stops early on an exact match. Otherwise, once the bracket has
     * narrowed to two adjacent rows, the closer of the two is chosen; the lower
     * row wins when both are equally close. Runs of equal metric values are
     * then resolved by moving to the highest such row for non-threshold
     * metrics, or the lowest such row for the threshold.</p>
     *
     * <p>NOTE(review): with an empty {@code plotInstances}, {@code upper} is -1
     * and {@code instance(-1)} is requested, so callers must ensure the curve
     * has at least one row.</p>
     *
     * @param value the value to look for
     * @param plotInstances the threshold curve data
     * @param indexOfMetric the column to search in
     * @return the index of the chosen row
     */
    private int findIndexForValue(double value, Instances plotInstances, int indexOfMetric) {
      // binary search
      // threshold curve is sorted ascending in the threshold (thus
      // descending for recall and pop size)
      int index = -1;
      int lower = 0;
      int upper = plotInstances.numInstances() - 1;
      // lower is 0 here, so this is the midpoint of the bracket
      int mid = (upper - lower) / 2;
      boolean done = false;
      // while upper - lower >= 2, mid stays strictly between lower and upper,
      // so every iteration shrinks the bracket and the loop terminates
      while (!done) {
        if (upper - lower <= 1) {
          
          // choose the one closest to the value
          double comp1 = plotInstances.instance(upper).value(indexOfMetric);
          double comp2 = plotInstances.instance(lower).value(indexOfMetric);
          if (Math.abs(comp1 - value) < Math.abs(comp2 - value)) {
            index = upper;
          } else {
            index = lower;
          }
          
          break;
        }
        double comparisonVal = plotInstances.instance(mid).value(indexOfMetric);
        if (value > comparisonVal) {
          if (m_threshold.isSelected()) {
            // ascending column: the target lies above mid
            lower = mid;
            mid += (upper - lower) / 2;
          } else {
            // descending column: the target lies below mid
            upper = mid;
            mid -= (upper - lower) / 2;
          }
        } else if (value < comparisonVal) {
          if (m_threshold.isSelected()) {
            // ascending column: the target lies below mid
            upper = mid;
            mid -= (upper - lower) / 2;
          } else {
            // descending column: the target lies above mid
            lower = mid;
            mid += (upper - lower) / 2;
          }
        } else {
          // exact match
          index = mid;
          done = true;
        }
      }
      
      // now check for ties in the appropriate direction
      if (!m_threshold.isSelected()) {
        // non-threshold metric: move to the last row of a run of equal values
        while (index + 1 < plotInstances.numInstances()) {
          if (plotInstances.instance(index + 1).value(indexOfMetric) ==
            plotInstances.instance(index).value(indexOfMetric)) {
            index++;
          } else {
            break;
          }
        }
      } else {
        // threshold: move to the first row of a run of equal values
        while (index - 1 >= 0) {
          if (plotInstances.instance(index - 1).value(indexOfMetric) ==
            plotInstances.instance(index).value(indexOfMetric)) {
            index--;
          } else {
            break;
          }
        }
      }
      return index;
    }
   
    /**
     * Set the threshold data for the panel to use.
     *
     * @param data PlotData2D object encapsulating the threshold data.
     * @param classAtt the class attribute from the original data used to generate
     * the threshold data.
     * @throws Exception if something goes wrong.
     */
    public synchronized void setDataSet(PlotData2D data, Attribute classAtt) throws Exception {     
      // make a copy of the PlotData2D object
      m_masterPlot = new PlotData2D(data.getPlotInstances());
      boolean[] connectPoints = new boolean[m_masterPlot.getPlotInstances().numInstances()];
      for (int i = 1; i < connectPoints.length; i++) {
        connectPoints[i] = true;
      }
      m_masterPlot.setConnectPoints(connectPoints);

      m_masterPlot.m_alwaysDisplayPointsOfThisSize = 10;
      setClassForConfusionMatrix(classAtt);
      m_performancePanel.setMasterPlot(m_masterPlot);
      m_performancePanel.validate(); m_performancePanel.repaint();

      m_shapeSizes = new int[m_masterPlot.getPlotInstances().numInstances()];
      for (int i = 0; i < m_shapeSizes.length; i++) {
        m_shapeSizes[i] = 1;
      }
      m_masterPlot.setShapeSize(m_shapeSizes);
      constructCostBenefitData();
      m_costBenefitPanel.setMasterPlot(m_costBenefit);
      m_costBenefitPanel.validate(); m_costBenefitPanel.repaint();

      m_totalPopPrevious = 0;
      m_fpPrevious = 0;
      m_tpPrevious = 0;
      m_tnPrevious = 0;
      m_fnPrevious = 0;
      m_previousShapeIndex = -1;

      // set the total population size
      Instance first = m_masterPlot.getPlotInstances().instance(0);
      double totalPos = first.value(m_masterPlot.getPlotInstances().
          attribute(ThresholdCurve.TRUE_POS_NAME).index());
      double totalNeg = first.value(m_masterPlot.getPlotInstances().
          attribute(ThresholdCurve.FALSE_POS_NAME));
      m_originalPopSize = (int)(totalPos + totalNeg);
      m_totalPopField.setText("" + m_originalPopSize);

      m_performancePanel.setYIndex(5);
      m_performancePanel.setXIndex(10);
      m_costBenefitPanel.setXIndex(0);
      m_costBenefitPanel.setYIndex(1);
      //      System.err.println(m_masterPlot.getPlotInstances());
      updateInfoForSliderValue((double)m_thresholdSlider.getValue() / 100.0);
    }
   
    private void setClassForConfusionMatrix(Attribute classAtt) {
      m_classAttribute = classAtt;
      m_conf_actualA.setText(" Actual (a): " + classAtt.value(0));
      m_conf_actualA.setToolTipText(classAtt.value(0));
      String negClasses = "";
      for (int i = 1; i < classAtt.numValues(); i++) {
        negClasses += classAtt.value(i);
        if (i < classAtt.numValues() - 1) {
          negClasses += ",";
        }
      }
      m_conf_actualB.setText(" Actual (b): " + negClasses);
      m_conf_actualB.setToolTipText(negClasses);
    }
   
    private boolean constructCostBenefitData() {
      double tpCost = 0.0;
      try {
        tpCost = Double.parseDouble(m_cost_aa.getText());
      } catch (NumberFormatException n) {}
      double fpCost = 0.0;
      try {
        fpCost = Double.parseDouble(m_cost_ba.getText());
      } catch (NumberFormatException n) {}
      double tnCost = 0.0;
      try {
        tnCost = Double.parseDouble(m_cost_bb.getText());
      } catch (NumberFormatException n) {}
      double fnCost = 0.0;
      try {
        fnCost = Double.parseDouble(m_cost_ab.getText());
      } catch (NumberFormatException n) {}
     
      double requestedPopSize = m_originalPopSize;
      try {
        requestedPopSize = Double.parseDouble(m_totalPopField.getText());
      } catch (NumberFormatException e) {}
     
      double scaleFactor = 1.0;
      if (m_originalPopSize != 0) {
        scaleFactor = requestedPopSize / m_originalPopSize;
      }
     
      if (tpCost == m_tpPrevious && fpCost == m_fpPrevious &&
          tnCost == m_tnPrevious && fnCost == m_fnPrevious &&
          requestedPopSize == m_totalPopPrevious) {
        return false;
      }
     
      // First construct some Instances for the curve
      FastVector fv = new FastVector();
      fv.addElement(new Attribute("Sample Size"));
      fv.addElement(new Attribute("Cost/Benefit"));
      Instances costBenefitI = new Instances("Cost/Benefit Curve", fv, 100);
     
      // process the performance data to make this curve
      Instances performanceI = m_masterPlot.getPlotInstances();
     
      for (int i = 0; i < performanceI.numInstances(); i++) {
        Instance current = performanceI.instance(i);
       
        double[] vals = new double[2];
        vals[0] = current.value(10); // sample size
        vals[1] = (current.value(0) * tpCost
            + current.value(1) * fnCost
            + current.value(2) * fpCost
            + current.value(3) * tnCost) * scaleFactor;
        Instance newInst = new DenseInstance(1.0, vals);
        costBenefitI.add(newInst);
      }
     
      costBenefitI.compactify();
     
      // now set up the plot data
      m_costBenefit = new PlotData2D(costBenefitI);
      m_costBenefit.m_alwaysDisplayPointsOfThisSize = 10;
      m_costBenefit.setPlotName("Cost/benefit curve");
      boolean[] connectPoints = new boolean[costBenefitI.numInstances()];
     
      for (int i = 0; i < connectPoints.length; i++) {
        connectPoints[i] = true;
      }
      try {
        m_costBenefit.setConnectPoints(connectPoints);
        m_costBenefit.setShapeSize(m_shapeSizes);
      } catch (Exception ex) {
        // ignore
      }
     
      m_tpPrevious = tpCost;
      m_fpPrevious = fpCost;
      m_tnPrevious = tnCost;
      m_fnPrevious = fnCost;
     
      return true;
    }
  }
 
  /**
   * Constructor.
   */
  public CostBenefitAnalysis() {
    java.awt.GraphicsEnvironment ge =
      java.awt.GraphicsEnvironment.getLocalGraphicsEnvironment();
    if (!ge.isHeadless()) {
      appearanceFinal();
    }
  }
 
  /**
   * Global info for this bean
   *
   * @return a <code>String</code> value
   */
  public String globalInfo() {
    return "Visualize performance charts (such as ROC).";
  }

  /**
   * Accept a threshold data event and set up the visualization.
   * @param e a threshold data event
   */
  public void acceptDataSet(ThresholdDataEvent e) {
    try {
      setCurveData(e.getDataSet(), e.getClassAttribute());
    } catch (Exception ex) {
      System.err.println("[CostBenefitAnalysis] Problem setting up visualization.");
      ex.printStackTrace();
    }
  }
 
  /**
   * Set the threshold curve data to use.
   *
   * @param curveData a PlotData2D object set up with the curve data.
   * @param origClassAtt the class attribute from the original data used to
   * generate the curve.
   * @throws Exception if somthing goes wrong during the setup process.
   */
  public void setCurveData(PlotData2D curveData, Attribute origClassAtt)
    throws Exception {
    if (m_analysisPanel == null) {
      m_analysisPanel = new AnalysisPanel();
    }
    m_analysisPanel.setDataSet(curveData, origClassAtt);
  }

  public BeanVisual getVisual() {
    return m_visual;
  }

  public void setVisual(BeanVisual newVisual) {
    m_visual = newVisual;
  }

  public void useDefaultVisual() {
    m_visual.loadIcons(BeanVisual.ICON_PATH+"DefaultDataVisualizer.gif",
        BeanVisual.ICON_PATH+"DefaultDataVisualizer_animated.gif");
  }

  public Enumeration enumerateRequests() {
    Vector newVector = new Vector(0);
    if (m_analysisPanel != null) {
      if (m_analysisPanel.m_masterPlot != null) {
        newVector.addElement("Show analysis");
      }
    }
    return newVector.elements();
  }

  public void performRequest(String request) {
    if (request.compareTo("Show analysis") == 0) {
      try {
        // popup visualize panel
        if (!m_framePoppedUp) {
          m_framePoppedUp = true;

          final javax.swing.JFrame jf =
            new javax.swing.JFrame("Cost/Benefit Analysis");
          jf.setSize(1000,600);
          jf.getContentPane().setLayout(new BorderLayout());
          jf.getContentPane().add(m_analysisPanel, BorderLayout.CENTER);
          jf.addWindowListener(new java.awt.event.WindowAdapter() {
              public void windowClosing(java.awt.event.WindowEvent e) {
                jf.dispose();
                m_framePoppedUp = false;
              }
            });
          jf.setVisible(true);
          m_popupFrame = jf;
        } else {
          m_popupFrame.toFront();
        }
      } catch (Exception ex) {
        ex.printStackTrace();
        m_framePoppedUp = false;
      }
    } else {
      throw new IllegalArgumentException(request
          + " not supported (Cost/Benefit Analysis");
    }
  }

  public void addVetoableChangeListener(String name, VetoableChangeListener vcl) {
    m_bcSupport.addVetoableChangeListener(name, vcl);
  }

  public BeanContext getBeanContext() {
    return m_beanContext;
  }

  public void removeVetoableChangeListener(String name,
      VetoableChangeListener vcl) {
    m_bcSupport.removeVetoableChangeListener(name, vcl);
  }
 
  protected void appearanceFinal() {
    removeAll();
    setLayout(new BorderLayout());
    setUpFinal();
  }
 
  protected void setUpFinal() {
    if (m_analysisPanel == null) {
      m_analysisPanel = new AnalysisPanel();
    }
    add(m_analysisPanel, BorderLayout.CENTER);
  }
 
  protected void appearanceDesign() {
    removeAll();
    m_visual = new BeanVisual("CostBenefitAnalysis",
                              BeanVisual.ICON_PATH+"ModelPerformanceChart.gif",
                              BeanVisual.ICON_PATH
                              +"ModelPerformanceChart_animated.gif");
    setLayout(new BorderLayout());
    add(m_visual, BorderLayout.CENTER);
  }

  public void setBeanContext(BeanContext bc) throws PropertyVetoException {
    m_beanContext = bc;
    m_design = m_beanContext.isDesignTime();
    if (m_design) {
      appearanceDesign();
    } else {
      java.awt.GraphicsEnvironment ge =
        java.awt.GraphicsEnvironment.getLocalGraphicsEnvironment();
      if (!ge.isHeadless()) {
        appearanceFinal();
      }
    }
  }
 
  /**
   * Returns true if, at this time,
   * the object will accept a connection via the named event
   *
   * @param eventName the name of the event in question
   * @return true if the object will accept a connection
   */
  public boolean connectionAllowed(String eventName) {
    return (m_listenee == null);
  }

  /**
   * Notify this object that it has been registered as a listener with
   * a source for recieving events described by the named event
   * This object is responsible for recording this fact.
   *
   * @param eventName the event
   * @param source the source with which this object has been registered as
   * a listener
   */
  public void connectionNotification(String eventName, Object source) {
    if (connectionAllowed(eventName)) {
      m_listenee = source;
    }
  }
 
  /**
   * Returns true if, at this time,
   * the object will accept a connection according to the supplied
   * EventSetDescriptor
   *
   * @param esd the EventSetDescriptor
   * @return true if the object will accept a connection
   */
  public boolean connectionAllowed(EventSetDescriptor esd) {
    return connectionAllowed(esd.getName());
  }

  /**
   * Notify this object that it has been deregistered as a listener with
   * a source for named event. This object is responsible
   * for recording this fact.
   *
   * @param eventName the event
   * @param source the source with which this object has been registered as
   * a listener
   */
  public void disconnectionNotification(String eventName, Object source) {
    if (m_listenee == source) {
      m_listenee = null;
    }
   
  }

  /**
   * Get the custom (descriptive) name for this bean (if one has been set)
   *
   * @return the custom name (or the default name)
   */
  public String getCustomName() {
    return m_visual.getText();
  }

  /**
   * Returns true if. at this time, the bean is busy with some
   * (i.e. perhaps a worker thread is performing some calculation).
   *
   * @return true if the bean is busy.
   */
  public boolean isBusy() {
    return false;
  }

  /**
   * Set a custom (descriptive) name for this bean
   *
   * @param name the name to use
   */
  public void setCustomName(String name) {
    m_visual.setText(name);
  }

  /**
   * Set a logger
   *
   * @param logger a <code>weka.gui.Logger</code> value
   */
  public void setLog(Logger logger) {
    // we don't need to do any logging   
  }

  /**
   * Stop any processing that the bean might be doing.
   */
  public void stop() {
    // nothing to do here
  }
   
  public static void main(String[] args) {
    try {
      Instances train = new Instances(new java.io.BufferedReader(new java.io.FileReader(args[0])));
      train.setClassIndex(train.numAttributes() - 1);
      weka.classifiers.evaluation.ThresholdCurve tc =
        new weka.classifiers.evaluation.ThresholdCurve();
      weka.classifiers.evaluation.EvaluationUtils eu =
        new weka.classifiers.evaluation.EvaluationUtils();
      //weka.classifiers.Classifier classifier = new weka.classifiers.functions.Logistic();
      weka.classifiers.Classifier classifier = new weka.classifiers.bayes.NaiveBayes();
      FastVector predictions = new FastVector();
      eu.setSeed(1);
      predictions.appendElements(eu.getCVPredictions(classifier, train, 10));
      Instances result = tc.getCurve(predictions, 0);
      PlotData2D pd = new PlotData2D(result);
      pd.m_alwaysDisplayPointsOfThisSize = 10;

      boolean[] connectPoints = new boolean[result.numInstances()];
      for (int i = 1; i < connectPoints.length; i++) {
        connectPoints[i] = true;
      }
      pd.setConnectPoints(connectPoints);
      final javax.swing.JFrame jf =
        new javax.swing.JFrame("CostBenefitTest");
      jf.setSize(1000,600);
      //jf.pack();
      jf.getContentPane().setLayout(new BorderLayout());
      final CostBenefitAnalysis.AnalysisPanel analysisPanel =
        new CostBenefitAnalysis.AnalysisPanel();
     
      jf.getContentPane().add(analysisPanel, BorderLayout.CENTER);
      jf.addWindowListener(new java.awt.event.WindowAdapter() {
        public void windowClosing(java.awt.event.WindowEvent e) {
          jf.dispose();
          System.exit(0);
        }
      });
     
      jf.setVisible(true);
     
      analysisPanel.setDataSet(pd, train.classAttribute());
     
    } catch (Exception ex) {
      ex.printStackTrace();
    }
  }
}
TOP

Related Classes of weka.gui.beans.CostBenefitAnalysis$AnalysisPanel$ConfusionCell

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.