Package: dkpro.similarity.experiments.sts2013.util

Source code of the class dkpro.similarity.experiments.sts2013.util.Evaluator

/*******************************************************************************
* Copyright 2013
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the GNU Public License v3.0
* which accompanies this distribution, and is available at
* http://www.gnu.org/licenses/gpl-3.0.txt
******************************************************************************/
package dkpro.similarity.experiments.sts2013.util;

import static dkpro.similarity.experiments.sts2013.Pipeline.DATASET_DIR;
import static dkpro.similarity.experiments.sts2013.Pipeline.GOLDSTANDARD_DIR;
import static dkpro.similarity.experiments.sts2013.Pipeline.MODELS_DIR;
import static dkpro.similarity.experiments.sts2013.Pipeline.OUTPUT_DIR;
import static dkpro.similarity.experiments.sts2013.Pipeline.EvaluationMetric.PearsonAll;
import static dkpro.similarity.experiments.sts2013.Pipeline.EvaluationMetric.PearsonMean;
import static dkpro.similarity.experiments.sts2013.Pipeline.EvaluationMetric.PearsonWeightedMean;
import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngine;
import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription;
import static org.apache.uima.fit.factory.CollectionReaderFactory.createReader;
import static org.apache.uima.fit.factory.ExternalResourceFactory.createExternalResourceDescription;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Random;

import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.math.stat.correlation.PearsonsCorrelation;
import org.apache.uima.UIMAException;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.collection.CollectionReader;
import org.apache.uima.fit.factory.AggregateBuilder;
import org.apache.uima.fit.pipeline.SimplePipeline;
import org.springframework.core.io.Resource;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.supervised.attribute.AddClassification;
import weka.filters.unsupervised.attribute.AddID;
import weka.filters.unsupervised.attribute.Remove;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Document;
import de.tudarmstadt.ukp.dkpro.core.tokit.BreakIteratorSegmenter;
import dkpro.similarity.experiments.sts2013.Pipeline.Dataset;
import dkpro.similarity.experiments.sts2013.Pipeline.EvaluationMetric;
import dkpro.similarity.experiments.sts2013.Pipeline.Mode;
import dkpro.similarity.experiments.sts2013.filter.LogFilter;
import dkpro.similarity.ml.io.SimilarityScoreWriter;
import dkpro.similarity.uima.annotator.SimilarityScorer;
import dkpro.similarity.uima.io.CombinationReader;
import dkpro.similarity.uima.io.SemEvalCorpusReader;
import dkpro.similarity.uima.io.CombinationReader.CombinationStrategy;
import dkpro.similarity.uima.resource.ml.LinearRegressionResource;


public class Evaluator
{
  public static final String LF = System.getProperty("line.separator");
 
  public static void runLinearRegression(Dataset train, Dataset... test)
    throws UIMAException, IOException
  {
    for (Dataset dataset : test)
    {
      CollectionReader reader = createReader(SemEvalCorpusReader.class,
          SemEvalCorpusReader.PARAM_INPUT_FILE, DATASET_DIR + "/test/STS.input." + dataset.toString() + ".txt",
          SemEvalCorpusReader.PARAM_COMBINATION_STRATEGY, CombinationStrategy.SAME_ROW_ONLY.toString());
     
      AnalysisEngineDescription seg = createEngineDescription(BreakIteratorSegmenter.class);
     
      AggregateBuilder builder = new AggregateBuilder();
      builder.add(seg, CombinationReader.INITIAL_VIEW, CombinationReader.VIEW_1);
      builder.add(seg, CombinationReader.INITIAL_VIEW, CombinationReader.VIEW_2);
      AnalysisEngine aggr_seg = builder.createAggregate();
 
      AnalysisEngine scorer = createEngine(SimilarityScorer.class,
            SimilarityScorer.PARAM_NAME_VIEW_1, CombinationReader.VIEW_1,
            SimilarityScorer.PARAM_NAME_VIEW_2, CombinationReader.VIEW_2,
            SimilarityScorer.PARAM_SEGMENT_FEATURE_PATH, Document.class.getName(),
            SimilarityScorer.PARAM_TEXT_SIMILARITY_RESOURCE, createExternalResourceDescription(
              LinearRegressionResource.class,
              LinearRegressionResource.PARAM_LOG_FILTER, "true",
              LinearRegressionResource.PARAM_TRAIN_ARFF, MODELS_DIR + "/train/" + train.toString() + ".arff",
              LinearRegressionResource.PARAM_TEST_ARFF, MODELS_DIR + "/test/" + dataset.toString() + ".arff")
            );
     
      AnalysisEngine writer = createEngine(SimilarityScoreWriter.class,
          SimilarityScoreWriter.PARAM_OUTPUT_FILE, OUTPUT_DIR + "/test/" + dataset.toString() + ".csv",
          SimilarityScoreWriter.PARAM_OUTPUT_SCORES_ONLY, true,
          SimilarityScoreWriter.PARAM_OUTPUT_GOLD_SCORES, false);
 
      SimplePipeline.runPipeline(reader, aggr_seg, scorer, writer);
    }
  }
 
  public static void runLinearRegressionCV(Mode mode, Dataset... datasets)
    throws Exception
  {
    for (Dataset dataset : datasets)
    {
      // Set parameters
      int folds = 10;
      Classifier baseClassifier = new LinearRegression();
     
      // Set up the random number generator
        long seed = new Date().getTime();     
      Random random = new Random(seed)
       
      // Add IDs to the instances
      AddID.main(new String[] {"-i", MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".arff",
                     "-o", MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + "-plusIDs.arff" });
      Instances data = DataSource.read(MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + "-plusIDs.arff");
      data.setClassIndex(data.numAttributes() - 1);       
     
          // Instantiate the Remove filter
          Remove removeIDFilter = new Remove();
          removeIDFilter.setAttributeIndices("first");
     
      // Randomize the data
      data.randomize(random);
   
      // Perform cross-validation
        Instances predictedData = null;
        Evaluation eval = new Evaluation(data);
       
        for (int n = 0; n < folds; n++)
        {
          Instances train = data.trainCV(folds, n, random);
            Instances test = data.testCV(folds, n);
           
            // Apply log filter
          Filter logFilter = new LogFilter();
            logFilter.setInputFormat(train);
            train = Filter.useFilter(train, logFilter);       
            logFilter.setInputFormat(test);
            test = Filter.useFilter(test, logFilter);
           
            // Copy the classifier
            Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);
                                  
            // Instantiate the FilteredClassifier
            FilteredClassifier filteredClassifier = new FilteredClassifier();
            filteredClassifier.setFilter(removeIDFilter);
            filteredClassifier.setClassifier(classifier);
              
            // Build the classifier
            filteredClassifier.buildClassifier(train);
            
            // Evaluate
            eval.evaluateModel(classifier, test);
           
            // Add predictions
            AddClassification filter = new AddClassification();
            filter.setClassifier(classifier);
            filter.setOutputClassification(true);
            filter.setOutputDistribution(false);
            filter.setOutputErrorFlag(true);
            filter.setInputFormat(train);
            Filter.useFilter(train, filter)// trains the classifier
           
            Instances pred = Filter.useFilter(test, filter)// performs predictions on test set
            if (predictedData == null) {
                    predictedData = new Instances(pred, 0);
                }
            for (int j = 0; j < pred.numInstances(); j++) {
                    predictedData.add(pred.instance(j));
                }           
        }
       
        // Prepare output scores
        double[] scores = new double[predictedData.numInstances()];
       
        for (Instance predInst : predictedData)
        {
          int id = new Double(predInst.value(predInst.attribute(0))).intValue() - 1;
         
          int valueIdx = predictedData.numAttributes() - 2;
         
          double value = predInst.value(predInst.attribute(valueIdx));
         
          scores[id] = value;
         
          // Limit to interval [0;5]
        if (scores[id] > 5.0) {
                    scores[id] = 5.0;
                }
        if (scores[id] < 0.0) {
                    scores[id] = 0.0;
                }
        }
       
        // Output
        StringBuilder sb = new StringBuilder();
        for (Double score : scores) {
                sb.append(score.toString() + LF);
            }
       
        FileUtils.writeStringToFile(
          new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".csv"),
          sb.toString());
    }
  }
 
  @SuppressWarnings("unchecked")
  public static void runEvaluationMetric(Mode mode, EvaluationMetric metric, Dataset... datasets)
    throws IOException
  {
    StringBuilder sb = new StringBuilder();
   
    // Compute Pearson correlation for the specified datasets
    for (Dataset dataset : datasets)
    {
      computePearsonCorrelation(mode, dataset);
    }
   
    if (metric == PearsonAll)
    {
      List<Double> concatExp = new ArrayList<Double>();
      List<Double> concatGS = new ArrayList<Double>();
     
      // Concat the scores
      for (Dataset dataset : datasets)
      {
        File expScoresFile = new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".csv");
       
        List<String> lines = FileUtils.readLines(expScoresFile);
       
        for (String line : lines) {
                    concatExp.add(Double.parseDouble(line));
                }
      }
     
      // Concat the gold standard
      for (Dataset dataset : datasets)
      {
        String gsScoresFilePath = GOLDSTANDARD_DIR + "/" + mode.toString().toLowerCase() + "/" +
            "STS.gs." + dataset.toString() + ".txt";
       
        PathMatchingResourcePatternResolver r = new PathMatchingResourcePatternResolver();
            Resource res = r.getResource(gsScoresFilePath);       
        File gsScoresFile = res.getFile();
       
        List<String> lines = FileUtils.readLines(gsScoresFile);
       
        for (String line : lines) {
                    concatGS.add(Double.parseDouble(line));
                }
      }
     
      double[] concatExpArray = ArrayUtils.toPrimitive(concatExp.toArray(new Double[concatExp.size()]));
      double[] concatGSArray = ArrayUtils.toPrimitive(concatGS.toArray(new Double[concatGS.size()]));

      PearsonsCorrelation pearson = new PearsonsCorrelation();
      Double correl = pearson.correlation(concatExpArray, concatGSArray);
     
      sb.append(correl.toString());
    }
    else if (metric == PearsonMean)
    {
      List<Double> scores = new ArrayList<Double>();
     
      for (Dataset dataset : datasets)
      {
        File resultFile = new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".txt");
        double score = Double.parseDouble(FileUtils.readFileToString(resultFile));
       
        scores.add(score);
      }
     
      double mean = 0.0;
      for (Double score : scores) {
                mean += score;
            }
      mean = mean / scores.size();
     
      sb.append(mean);
    }
    else if (metric == PearsonWeightedMean)
    {
      List<Double> scores = new ArrayList<Double>();
      List<Integer> weights = new ArrayList<Integer>();
     
      for (Dataset dataset : datasets)
      {
        File resultFile = new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".txt");
        double score = Double.parseDouble(FileUtils.readFileToString(resultFile));
       
        File scoresFile = new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".csv");
        int weight = FileUtils.readLines(scoresFile).size();
       
        scores.add(score);
        weights.add(weight);
      }
     
      double mean = 0.0;
      int weightsum = 0;
      for (int i = 0; i < scores.size(); i++)
      {
        Double score = scores.get(i);
        Integer weight = weights.get(i);
               
        mean += weight * score;
        weightsum += weight;
      }
      mean = mean / weightsum;
     
      sb.append(mean);
    }

    FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + metric.toString() + ".txt"), sb.toString());
  }
 
  @SuppressWarnings("unchecked")
  private static void computePearsonCorrelation(Mode mode, Dataset dataset)
    throws IOException
  {
    File expScoresFile = new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".csv");
   
    String gsScoresFilePath = GOLDSTANDARD_DIR + "/" + mode.toString().toLowerCase() + "/" +
        "STS.gs." + dataset.toString() + ".txt";
   
    PathMatchingResourcePatternResolver r = new PathMatchingResourcePatternResolver();
        Resource res = r.getResource(gsScoresFilePath);       
    File gsScoresFile = res.getFile();
   
    List<Double> expScores = new ArrayList<Double>();
    List<Double> gsScores = new ArrayList<Double>();
   
    List<String> expLines = FileUtils.readLines(expScoresFile);
    List<String> gsLines = FileUtils.readLines(gsScoresFile);
   
    for (int i = 0; i < expLines.size(); i++)
    {
      expScores.add(Double.parseDouble(expLines.get(i)));
      gsScores.add(Double.parseDouble(gsLines.get(i)));
    }
   
    double[] expArray = ArrayUtils.toPrimitive(expScores.toArray(new Double[expScores.size()]));
    double[] gsArray = ArrayUtils.toPrimitive(gsScores.toArray(new Double[gsScores.size()]));

    PearsonsCorrelation pearson = new PearsonsCorrelation();
    Double correl = pearson.correlation(expArray, gsArray);
   
    FileUtils.writeStringToFile(
        new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".txt"),
        correl.toString());
  }
}
TOP

Related Classes of dkpro.similarity.experiments.sts2013.util.Evaluator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.