Package dkpro.similarity.experiments.rte.util

Source code of dkpro.similarity.experiments.rte.util.Evaluator (including the inner class CwsData)

/*******************************************************************************
* Copyright 2013
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the GNU Public License v3.0
* which accompanies this distribution, and is available at
* http://www.gnu.org/licenses/gpl-3.0.txt
******************************************************************************/
package dkpro.similarity.experiments.rte.util;

import static dkpro.similarity.experiments.rte.Pipeline.GOLD_DIR;
import static dkpro.similarity.experiments.rte.Pipeline.MODELS_DIR;
import static dkpro.similarity.experiments.rte.Pipeline.OUTPUT_DIR;
import static dkpro.similarity.experiments.rte.Pipeline.EvaluationMetric.Accuracy;
import static dkpro.similarity.experiments.rte.Pipeline.EvaluationMetric.AveragePrecision;
import static dkpro.similarity.experiments.rte.Pipeline.EvaluationMetric.CWS;

import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Random;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.filefilter.FileFilterUtils;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.output.prediction.AbstractOutput;
import weka.classifiers.evaluation.output.prediction.PlainText;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSink;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.supervised.attribute.AddClassification;
import weka.filters.unsupervised.attribute.AddID;
import weka.filters.unsupervised.attribute.Remove;
import dkpro.similarity.algorithms.ml.ClassifierSimilarityMeasure;
import dkpro.similarity.algorithms.ml.ClassifierSimilarityMeasure.WekaClassifier;
import dkpro.similarity.experiments.rte.Pipeline.Dataset;
import dkpro.similarity.experiments.rte.Pipeline.EvaluationMetric;
//import de.tudarmstadt.ukp.similarity.experiments.rte.Pipeline.EvaluationMetric;
//import de.tudarmstadt.ukp.similarity.experiments.rte.Pipeline.Mode;
//import de.tudarmstadt.ukp.similarity.experiments.rte.filter.LogFilter;


public class Evaluator
{
  public static final String LF = System.getProperty("line.separator");
 
//  public static void runClassifier(Dataset train, Dataset test)
//    throws UIMAException, IOException
//  {
//    CollectionReader reader = createCollectionReader(
//        RTECorpusReader.class,
//        RTECorpusReader.PARAM_INPUT_FILE, RteUtil.getInputFilePathForDataset(DATASET_DIR, test),
//        RTECorpusReader.PARAM_COMBINATION_STRATEGY, CombinationStrategy.SAME_ROW_ONLY.toString());
//   
//    AnalysisEngineDescription seg = createPrimitiveDescription(
//        BreakIteratorSegmenter.class);
//   
//    AggregateBuilder builder = new AggregateBuilder();
//    builder.add(seg, CombinationReader.INITIAL_VIEW, CombinationReader.VIEW_1);
//    builder.add(seg, CombinationReader.INITIAL_VIEW, CombinationReader.VIEW_2);
//    AnalysisEngine aggr_seg = builder.createAggregate();
//
//    AnalysisEngine scorer = createPrimitive(
//        SimilarityScorer.class,
//          SimilarityScorer.PARAM_NAME_VIEW_1, CombinationReader.VIEW_1,
//          SimilarityScorer.PARAM_NAME_VIEW_2, CombinationReader.VIEW_2,
//          SimilarityScorer.PARAM_SEGMENT_FEATURE_PATH, Document.class.getName(),
//          SimilarityScorer.PARAM_TEXT_SIMILARITY_RESOURCE, createExternalResourceDescription(
//            ClassifierResource.class,
//            ClassifierResource.PARAM_CLASSIFIER, wekaClassifier.toString(),
//            ClassifierResource.PARAM_TRAIN_ARFF, MODELS_DIR + "/" + train.toString() + ".arff",
//            ClassifierResource.PARAM_TEST_ARFF, MODELS_DIR + "/" + test.toString() + ".arff")
//          );
//   
//    AnalysisEngine writer = createPrimitive(
//        SimilarityScoreWriter.class,
//        SimilarityScoreWriter.PARAM_OUTPUT_FILE, OUTPUT_DIR + "/" + test.toString() + ".csv",
//        SimilarityScoreWriter.PARAM_OUTPUT_SCORES_ONLY, true,
//        SimilarityScoreWriter.PARAM_OUTPUT_GOLD_SCORES, false);
//
//    SimplePipeline.runPipeline(reader, aggr_seg, scorer, writer);
//  }
 
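  /**
   * Trains the given WEKA classifier on the training dataset's feature file
   * ({@code MODELS_DIR/<train>.arff}) and evaluates it on the test dataset's
   * feature file, writing the predicted classes ({@code .csv}), class
   * probabilities ({@code .probabilities.csv}), raw predictions
   * ({@code .predictions.txt}), and a summary ({@code .meta.txt}) below
   * {@code OUTPUT_DIR}. A hypothetical call (the enum constants shown are
   * illustrative only, not verified against the Pipeline class):
   *
   * <pre>
   * Evaluator.runClassifier(WekaClassifier.NAIVE_BAYES, Dataset.RTE1_dev, Dataset.RTE1_test);
   * </pre>
   */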
  public static void runClassifier(WekaClassifier wekaClassifier, Dataset trainDataset, Dataset testDataset)
      throws Exception
  {
    Classifier baseClassifier = ClassifierSimilarityMeasure.getClassifier(wekaClassifier);
   
    // Set up the random number generator
    long seed = new Date().getTime();
    Random random = new Random(seed);

    // Add IDs to the train instances and get the instances
    AddID.main(new String[] {"-i", MODELS_DIR + "/" + trainDataset.toString() + ".arff",
                  "-o", MODELS_DIR + "/" + trainDataset.toString() + "-plusIDs.arff" });
    Instances train = DataSource.read(MODELS_DIR + "/" + trainDataset.toString() + "-plusIDs.arff");
    train.setClassIndex(train.numAttributes() - 1);
   
    // Add IDs to the test instances and get the instances
    AddID.main(new String[] {"-i", MODELS_DIR + "/" + testDataset.toString() + ".arff",
                  "-o", MODELS_DIR + "/" + testDataset.toString() + "-plusIDs.arff" });
    Instances test = DataSource.read(MODELS_DIR + "/" + testDataset.toString() + "-plusIDs.arff");
    test.setClassIndex(test.numAttributes() - 1);   
   
    // Instantiate the Remove filter
    Remove removeIDFilter = new Remove();
    removeIDFilter.setAttributeIndices("first");

    // Randomize the data
    test.randomize(random);
   
    // Apply log filter
//      Filter logFilter = new LogFilter();
//      logFilter.setInputFormat(train);
//      train = Filter.useFilter(train, logFilter);       
//      logFilter.setInputFormat(test);
//      test = Filter.useFilter(test, logFilter);
       
    // Copy the classifier
    Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);

    // Instantiate the FilteredClassifier
    FilteredClassifier filteredClassifier = new FilteredClassifier();
    filteredClassifier.setFilter(removeIDFilter);
    filteredClassifier.setClassifier(classifier);

    // Build the classifier
    filteredClassifier.buildClassifier(train);

    // Prepare the output buffer
    AbstractOutput output = new PlainText();
    output.setBuffer(new StringBuffer());
    output.setHeader(test);
    output.setAttributes("first");

    Evaluation eval = new Evaluation(train);
    eval.evaluateModel(filteredClassifier, test, output);
       
    // Convert predictions to CSV
    // Format: inst#, actual, predicted, error, probability, (ID)
    String[] scores = new String[(int) eval.numInstances()];
    double[] probabilities = new double[(int) eval.numInstances()];
    for (String line : output.getBuffer().toString().split("\n"))
    {
      String[] linesplit = line.split("\\s+");

      // The leading whitespace yields an empty first token, so the split
      // produces 7 fields when the error flag "+" is present, otherwise 6.
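      // Illustrative PlainText lines (the values here are made up):
      //     1  2:TRUE   1:FALSE  +  0.6  (1)   -> 7 tokens after the split
      //     2  1:FALSE  1:FALSE     0.8  (2)   -> 6 tokens after the split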
         
          int id;
          String expectedValue, classification;
          double probability;
         
          if (line.contains("+"))
          {
               id = Integer.parseInt(linesplit[6].substring(1, linesplit[6].length() - 1));
            expectedValue = linesplit[2].substring(2);
            classification = linesplit[3].substring(2);
            probability = Double.parseDouble(linesplit[5]);
          } else {
            id = Integer.parseInt(linesplit[5].substring(1, linesplit[5].length() - 1));
            expectedValue = linesplit[2].substring(2);
            classification = linesplit[3].substring(2);
            probability = Double.parseDouble(linesplit[4]);
          }
         
          scores[id - 1] = classification;
          probabilities[id - 1] = probability;
        }
               
    System.out.println(eval.toSummaryString());
    System.out.println(eval.toMatrixString());

    // Output classifications
    StringBuilder sb = new StringBuilder();
    for (String score : scores)
      sb.append(score + LF);

    FileUtils.writeStringToFile(
      new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".csv"),
      sb.toString());
     
    // Output probabilities
    sb = new StringBuilder();
    for (Double probability : probabilities)
      sb.append(probability.toString() + LF);

    FileUtils.writeStringToFile(
      new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".probabilities.csv"),
      sb.toString());

    // Output predictions
    FileUtils.writeStringToFile(
      new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".predictions.txt"),
      output.getBuffer().toString());

    // Output meta information
    sb = new StringBuilder();
    sb.append(classifier.toString() + LF);
    sb.append(eval.toSummaryString() + LF);
    sb.append(eval.toMatrixString() + LF);

    FileUtils.writeStringToFile(
      new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".meta.txt"),
      sb.toString());
  }
 
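  /**
   * Evaluates the given WEKA classifier on the dataset's feature file
   * ({@code MODELS_DIR/<dataset>.arff}) via 10-fold cross-validation, writing
   * the predicted classes ({@code .csv}), the predicted instances
   * ({@code .predicted.arff}), and a summary ({@code .meta.txt}) below
   * {@code OUTPUT_DIR}.
   */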
  public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset)
    throws Exception
  {
    // Set parameters
    int folds = 10;
    Classifier baseClassifier = ClassifierSimilarityMeasure.getClassifier(wekaClassifier);
   
    // Set up the random number generator
    long seed = new Date().getTime();
    Random random = new Random(seed);
     
    // Add IDs to the instances
    AddID.main(new String[] {"-i", MODELS_DIR + "/" + dataset.toString() + ".arff",
                  "-o", MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff" });
    Instances data = DataSource.read(MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff");
    data.setClassIndex(data.numAttributes() - 1);       
   
    // Instantiate the Remove filter
    Remove removeIDFilter = new Remove();
    removeIDFilter.setAttributeIndices("first");

    // Randomize the data
    data.randomize(random);
 
    // Perform cross-validation
    Instances predictedData = null;
    Evaluation eval = new Evaluation(data);

    for (int n = 0; n < folds; n++)
    {
      Instances train = data.trainCV(folds, n, random);
      Instances test = data.testCV(folds, n);

      // Apply log filter
//      Filter logFilter = new LogFilter();
//      logFilter.setInputFormat(train);
//      train = Filter.useFilter(train, logFilter);
//      logFilter.setInputFormat(test);
//      test = Filter.useFilter(test, logFilter);

      // Copy the classifier
      Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);

      // Instantiate the FilteredClassifier
      FilteredClassifier filteredClassifier = new FilteredClassifier();
      filteredClassifier.setFilter(removeIDFilter);
      filteredClassifier.setClassifier(classifier);

      // Build the classifier
      filteredClassifier.buildClassifier(train);

      // Evaluate
      eval.evaluateModel(filteredClassifier, test);

      // Add predictions
      AddClassification filter = new AddClassification();
      filter.setClassifier(classifier);
      filter.setOutputClassification(true);
      filter.setOutputDistribution(false);
      filter.setOutputErrorFlag(true);
      filter.setInputFormat(train);
      Filter.useFilter(train, filter); // trains the classifier

      Instances pred = Filter.useFilter(test, filter); // performs predictions on the test set
      if (predictedData == null)
        predictedData = new Instances(pred, 0);
      for (int j = 0; j < pred.numInstances(); j++)
        predictedData.add(pred.instance(j));
    }
     
    System.out.println(eval.toSummaryString());
    System.out.println(eval.toMatrixString());

    // Prepare output scores
    String[] scores = new String[predictedData.numInstances()];

    for (Instance predInst : predictedData)
    {
      int id = (int) predInst.value(predInst.attribute(0)) - 1;

      int valueIdx = predictedData.numAttributes() - 2;

      String value = predInst.stringValue(predInst.attribute(valueIdx));

      scores[id] = value;
    }

    // Output classifications
    StringBuilder sb = new StringBuilder();
    for (String score : scores)
      sb.append(score + LF);

    FileUtils.writeStringToFile(
      new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".csv"),
      sb.toString());

    // Output prediction arff
    DataSink.write(
      OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".predicted.arff",
      predictedData);

    // Output meta information
    sb = new StringBuilder();
    sb.append(baseClassifier.toString() + LF);
    sb.append(eval.toSummaryString() + LF);
    sb.append(eval.toMatrixString() + LF);

    FileUtils.writeStringToFile(
      new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".meta.txt"),
      sb.toString());
  }

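  /**
   * Computes the given evaluation metric for every classifier that has
   * produced output for the dataset, i.e. for every non-hidden subdirectory
   * of {@code OUTPUT_DIR/<dataset>}.
   */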
  public static void runEvaluationMetric(EvaluationMetric metric, Dataset dataset)
    throws IOException
  {
    // Get all subdirectories (i.e. all classifiers)
    File outputDir = new File(OUTPUT_DIR + "/" + dataset.toString() + "/");
    File[] dirsArray = outputDir.listFiles((FileFilter) FileFilterUtils.directoryFileFilter());

    // Wrap in a modifiable list; the fixed-size list returned by
    // CollectionUtils.arrayToList would fail on remove() below
    List<File> dirs = new ArrayList<File>(Arrays.asList(dirsArray));

    // Don't list hidden dirs (such as .svn)
    for (int i = dirs.size() - 1; i >= 0; i--)
      if (dirs.get(i).getName().startsWith("."))
        dirs.remove(i);
   
    // Iteratively evaluate all classifiers' results
    for (File dir : dirs)
      runEvaluationMetric(
          WekaClassifier.valueOf(dir.getName()),
          metric,
          dataset);
  }
 
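  /**
   * Computes the given evaluation metric (accuracy, CWS, or average
   * precision) for a single classifier's output on the dataset, and writes
   * the result to
   * {@code OUTPUT_DIR/<dataset>/<classifier>/<dataset>_<metric>.txt}.
   */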
  public static void runEvaluationMetric(WekaClassifier wekaClassifier, EvaluationMetric metric, Dataset dataset)
      throws IOException
  {
    StringBuilder sb = new StringBuilder();
     
    if (metric == Accuracy)
    {
      // Read gold scores
      List<String> goldScores = FileUtils.readLines(new File(GOLD_DIR + "/" + dataset.toString() + ".txt"));
           
      // Read the experimental scores
      List<String> expScores = FileUtils.readLines(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".csv"));
           
      // Compute the accuracy
      double acc = 0.0;
      for (int i = 0; i < goldScores.size(); i++)
      {
        // The predictions have a max length of 8 characters...
        if (goldScores.get(i).substring(0, Math.min(goldScores.get(i).length(), 8)).equals(
            expScores.get(i).substring(0, Math.min(expScores.get(i).length(), 8))))
          acc++;
      }
      acc = acc / goldScores.size();
     
      sb.append(acc);
    }
    if (metric == CWS)
    {
      // Read gold scores
      List<String> goldScores = FileUtils.readLines(new File(GOLD_DIR + "/" + dataset.toString() + ".txt"));
           
      // Read the experimental scores
      List<String> expScores = FileUtils.readLines(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".csv"));
     
      // Read the confidence scores
      List<String> probabilities = FileUtils.readLines(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".probabilities.csv"));
     
      // Combine the data
      List<CwsData> data = new ArrayList<CwsData>();
     
      for (int i = 0; i < goldScores.size(); i++)
      {
        CwsData cws = new CwsData(
            Double.parseDouble(probabilities.get(i)),
            goldScores.get(i),
            expScores.get(i));
        data.add(cws);
      }
     
      // Sort in descending order
      Collections.sort(data, Collections.reverseOrder());
     
      // Compute the CWS score
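      // With the pairs ranked by descending confidence, the confidence-weighted
      // score is the mean over all ranks i of the precision within the top i:
      // cws = (1/n) * sum_{i=1..n} (#correct among the top i) / i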
      double cwsScore = 0.0;
      for (int i = 0; i < data.size(); i++)
      {
        double cws_sub = 0.0;
        for (int j = 0; j <= i; j++)
        {
          if (data.get(j).isCorrect())
            cws_sub++;
        }
        cws_sub /= (i+1);
       
        cwsScore += cws_sub;
      }
      cwsScore /= data.size();
           
      sb.append(cwsScore);
    }
    if (metric == AveragePrecision)
    {
      // Read gold scores
      List<String> goldScores = FileUtils.readLines(new File(GOLD_DIR + "/" + dataset.toString() + ".txt"));

      // Trim to 8 characters
      for (int i = 0; i < goldScores.size(); i++)
        if (goldScores.get(i).length() > 8)
          goldScores.set(i, goldScores.get(i).substring(0, 8));
     
      // Read the experimental scores
      List<String> expScores = FileUtils.readLines(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".csv"));
     
      // Trim to 8 characters
      for (int i = 0; i < expScores.size(); i++)
        if (expScores.get(i).length() > 8)
          expScores.set(i, expScores.get(i).substring(0, 8));
     
      // Read the confidence scores
      List<String> probabilities = FileUtils.readLines(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".probabilities.csv"));
     
      // Conflate UNKNOWN and CONTRADICTION classes for 3-way classifications
      if (RteUtil.hasThreeWayClassification(dataset))
      {
        // Gold
        for (int i = 0; i < goldScores.size(); i++)
          if (goldScores.get(i).equals("CONTRADI") || goldScores.get(i).equals("NO") || goldScores.get(i).equals("FALSE"))
            goldScores.set(i, "FALSE");
       
        // Experimental
        for (int i = 0; i < expScores.size(); i++)
          if (expScores.get(i).equals("CONTRADI") || expScores.get(i).equals("NO") || expScores.get(i).equals("FALSE"))
            expScores.set(i, "FALSE");
      }
     
      // Combine the data
      List<CwsData> data = new ArrayList<CwsData>();
     
      for (int i = 0; i < goldScores.size(); i++)
      {
        CwsData cws = new CwsData(
            Double.parseDouble(probabilities.get(i)),
            goldScores.get(i),
            expScores.get(i));
        data.add(cws);
      }
     
      // Sort in descending order
      Collections.sort(data, Collections.reverseOrder());
     
      // Compute the average precision
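      // With the pairs ranked by descending confidence, the average precision
      // is the mean, over the ranks i that hold a positive gold pair, of the
      // fraction of correct predictions within the top i, normalized by the
      // number of positive pairs R:
      // avgPrec = (1/R) * sum_{i : gold(i) positive} (#correct among the top i) / i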
      double avgPrec = 0.0;
      int numPositive = 0;
      for (int i = 0; i < data.size(); i++)
      {
        double ap_sub = 0.0;
        if (data.get(i).isPositivePair())
        {
          numPositive++;
         
          for (int j = 0; j <= i; j++)
          {
            if (data.get(j).isCorrect())
              ap_sub++;
          }
          ap_sub /= (i+1);
        }
       
        avgPrec += ap_sub;         
      }
      avgPrec /= numPositive;
           
      sb.append(avgPrec);
    }

    FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + "_" + metric.toString() + ".txt"), sb.toString());
   
    System.out.println("[" + wekaClassifier.toString() + "] " + metric.toString() + ": " + sb.toString());
  }
 
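  /**
   * Holds one text pair's classifier confidence, gold label, and predicted
   * label; instances order by confidence so that they can be ranked for the
   * CWS and average precision metrics.
   */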
  private static class CwsData
    implements Comparable<CwsData>
  {
    private double confidence;
    private String goldScore;
    private String expScore;
   
    public CwsData(double confidence, String goldScore, String expScore)
    {
      this.confidence = confidence;
      this.goldScore = goldScore;
      this.expScore = expScore;
    }
   
    public boolean isCorrect()
    {
      return goldScore.equals(expScore);
    }

    public int compareTo(CwsData other)
    {
      return Double.compare(this.getConfidence(), other.getConfidence());
    }
   
    public boolean isPositivePair()
    {
      return this.goldScore.equals("TRUE") || this.goldScore.equals("YES") || this.goldScore.equals("ENTAILMENT") || this.goldScore.equals("ENTAILME");
    }

    public double getConfidence()
    {
      return confidence;
    }

    public String getGoldScore()
    {
      return goldScore;
    }

    public String getExpScore()
    {
      return expScore;
    }
  }
 
// 
//  @SuppressWarnings("unchecked")
//  private static void computePearsonCorrelation(Mode mode, Dataset dataset)
//    throws IOException
//  {
//    File expScoresFile = new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".csv");
//   
//    String gsScoresFilePath = GOLDSTANDARD_DIR + "/" + mode.toString().toLowerCase() + "/" +
//        "STS.gs." + dataset.toString() + ".txt";
//   
//    PathMatchingResourcePatternResolver r = new PathMatchingResourcePatternResolver();
//        Resource res = r.getResource(gsScoresFilePath);       
//    File gsScoresFile = res.getFile();
//   
//    List<Double> expScores = new ArrayList<Double>();
//    List<Double> gsScores = new ArrayList<Double>();
//   
//    List<String> expLines = FileUtils.readLines(expScoresFile);
//    List<String> gsLines = FileUtils.readLines(gsScoresFile);
//   
//    for (int i = 0; i < expLines.size(); i++)
//    {
//      expScores.add(Double.parseDouble(expLines.get(i)));
//      gsScores.add(Double.parseDouble(gsLines.get(i)));
//    }
//   
//    double[] expArray = ArrayUtils.toPrimitive(expScores.toArray(new Double[expScores.size()]));
//    double[] gsArray = ArrayUtils.toPrimitive(gsScores.toArray(new Double[gsScores.size()]));
//
//    PearsonsCorrelation pearson = new PearsonsCorrelation();
//    Double correl = pearson.correlation(expArray, gsArray);
//   
//    FileUtils.writeStringToFile(
//        new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".txt"),
//        correl.toString());
//  }
}