Package com.tamingtext.tagrecommender

Source Code of com.tamingtext.tagrecommender.TestStackOverflowTagger

/*
* Copyright 2008-2011 Grant Ingersoll, Thomas Morton and Drew Farris
*
*    Licensed under the Apache License, Version 2.0 (the "License");
*    you may not use this file except in compliance with the License.
*    You may obtain a copy of the License at
*
*        http://www.apache.org/licenses/LICENSE-2.0
*
*    Unless required by applicable law or agreed to in writing, software
*    distributed under the License is distributed on an "AS IS" BASIS,
*    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*    See the License for the specific language governing permissions and
*    limitations under the License.
* -------------------
* To purchase or learn more about Taming Text, by Grant Ingersoll, Thomas Morton and Drew Farris, visit
* http://www.manning.com/ingersoll
*/

package com.tamingtext.tagrecommender;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.net.MalformedURLException;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.HashSet;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.math.function.ObjectIntProcedure;
import org.apache.mahout.math.map.OpenObjectIntHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.tamingtext.tagrecommender.TagRecommenderClient.ScoreTag;

public class TestStackOverflowTagger {
  
  private static final Logger log = LoggerFactory.getLogger(TestStackOverflowTagger.class);

  private final NumberFormat nf = new DecimalFormat("##.##");
  private TagRecommenderClient client;
  private File   inputFile;
  private File   countFile;
  private File   outputFile;
 
  private String solrUrl;
  private int    maxTags = 5;
  public static void main(String[] args) {
    TestStackOverflowTagger t = new TestStackOverflowTagger();
    if (t.parseArgs(args)) {
      t.execute();
    }
  }
 
  public boolean parseArgs(String[] args) {
    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
    ArgumentBuilder abuilder = new ArgumentBuilder();
    GroupBuilder gbuilder = new GroupBuilder();
    Option helpOpt = DefaultOptionCreator.helpOption();
   
    Option inputFileOpt = obuilder.withLongName("inputFile").withRequired(true).withArgument(
        abuilder.withName("inputFile").withMinimum(1).withMaximum(1).create()).withDescription(
        "The input file").withShortName("i").create();
   
    Option countFileOpt = obuilder.withLongName("countFile").withRequired(true).withArgument(
        abuilder.withName("countFile").withMinimum(1).withMaximum(1).create()).withDescription(
        "The tag count file").withShortName("c").create();
   
    Option outputFileOpt = obuilder.withLongName("outputFile").withRequired(true).withArgument(
        abuilder.withName("outputFile").withMinimum(1).withMaximum(1).create()).withDescription(
        "The output file").withShortName("c").create();
   
    Option solrUrlOpt = obuilder.withLongName("solrUrl").withRequired(true).withArgument(
        abuilder.withName("solrUrl").withMinimum(1).withMaximum(1).create()).withDescription(
        "URL of the solr server").withShortName("s").create();
   
    Group group = gbuilder.withName("Options")
      .withOption(inputFileOpt)
      .withOption(countFileOpt)
      .withOption(outputFileOpt)
      .withOption(solrUrlOpt).create();
   
    try {
      Parser parser = new Parser();
      parser.setGroup(group);
      CommandLine cmdLine = parser.parse(args);
     
      if (cmdLine.hasOption(helpOpt)) {
        CommandLineUtil.printHelp(group);
        return false;
      }
     
      inputFile  = new File((String) cmdLine.getValue(inputFileOpt));
      countFile  = new File((String) cmdLine.getValue(countFileOpt));
      outputFile = new File((String) cmdLine.getValue(outputFileOpt));
      solrUrl    = (String) cmdLine.getValue(solrUrlOpt);
      client     = new TagRecommenderClient(solrUrl);
    } catch (OptionException e) {
      log.error("Command-line option Exception", e);
      CommandLineUtil.printHelp(group);
      return false;
    } catch (MalformedURLException e) {
      log.error("MalformedURLException", e);
      return false;
    }
   
    validate();
    return true;
  }
 

 
  public void validate() {
    Util.validateFileWritable(outputFile);
  }
 
  public void loadTags(OpenObjectIntHashMap<String> tags) throws IOException {
    BufferedReader reader = new BufferedReader(new FileReader(countFile));
    String line;
    while ((line = reader.readLine()) != null) {
      int pos = line.lastIndexOf('\t');
      String tag = new String(line.substring(pos+1));
      tags.adjustOrPutValue(tag, 0, 0);
    }
  }
 
  public void execute() {
    PrintStream out = null;
   
    try {
      OpenObjectIntHashMap<String> tagCounts = new OpenObjectIntHashMap<String>();
      OpenObjectIntHashMap<String> tagCorrect = new OpenObjectIntHashMap<String>();
      loadTags(tagCounts);
     
      StackOverflowStream stream = new StackOverflowStream();
      stream.open(inputFile.getAbsolutePath());
     
      out = new PrintStream(new FileOutputStream(outputFile));
     
      int correctTagCount  = 0;
      int postCount        = 0;
     
      HashSet<String> postTags = new HashSet<String>();
      float postPctCorrect;
     
      int totalSingleCorrect = 0;
      int totalHalfCorrect   = 0;
     
      for (StackOverflowPost post: stream) {
        correctTagCount = 0;
        postCount++;
       
        postTags.clear();
        postTags.addAll(post.getTags());
        for (String tag: post.getTags()) {
          if (tagCounts.containsKey(tag)) {
            tagCounts.adjustOrPutValue(tag, 1, 1);
          }
        }
       
        ScoreTag[] tags = client.getTags(post.getTitle() + "\n" + post.getBody(), maxTags);
       
        for (ScoreTag tag: tags) {
          if (postTags.contains(tag.getTag())) {
            correctTagCount += 1;
            tagCorrect.adjustOrPutValue(tag.getTag(), 1, 1);
          }
        }
       
        if (correctTagCount > 0) {
          totalSingleCorrect += 1;
        }
       
        postPctCorrect = correctTagCount / (float) postTags.size();
        if (postPctCorrect >= 0.50f) {
          totalHalfCorrect += 1;
        }
       
        if ((postCount % 100) == 0 ) {
          dumpStats(System.err, postCount, totalSingleCorrect, totalHalfCorrect);
        }
       
      }
     
      dumpStats(System.err, postCount, totalSingleCorrect, totalHalfCorrect);
      dumpStats(out, postCount, totalSingleCorrect, totalHalfCorrect);
      dumpTags(out, tagCounts, tagCorrect);
    }
    catch (Exception ex) {
      throw (RuntimeException) new RuntimeException().initCause(ex);
    }
    finally {
      if (out != null) {
        out.close();
      }
    }
  }
 

  /** Dump the tag metrics */
  public void dumpTags(final PrintStream out,
      final OpenObjectIntHashMap<String> tagCounts,
      final OpenObjectIntHashMap<String> tagCorrect) {
   
    out.println("-- tag\ttotal\tcorrect\tpct-correct --");
   
    tagCounts.forEachPair(new ObjectIntProcedure<String>() {
      @Override
      public boolean apply(String tag, int total) {
        int correct = tagCorrect.get(tag);
       
        out.println(tag + "\t" + total + "\t" + correct + "\t"
            + nf.format(((correct * 100) / (float) total)));
        return true;
      }
    });
   
    out.println();
    out.flush();
  }
 
  /** Dump the overall metrics */
  public void dumpStats(PrintStream out, int postCount, int totalSingleCorrect, int totalHalfCorrect) {
    out.println("evaluated " + postCount + " posts; "
        + totalSingleCorrect + " with one correct tag, "
        + totalHalfCorrect + " with half correct");

    out.print("\t %single correct: " + nf.format((totalSingleCorrect * 100) / (float) postCount));
    out.println(", %half correct: " + nf.format((totalHalfCorrect * 100) / (float) postCount));
    out.println();
    out.flush();
  }
}
TOP

Related Classes of com.tamingtext.tagrecommender.TestStackOverflowTagger

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark originally of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware#gmail.com.